login_sources_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "gorm.io/gorm"
  10. "gogs.io/gogs/internal/auth"
  11. "gogs.io/gogs/internal/auth/github"
  12. "gogs.io/gogs/internal/auth/pam"
  13. "gogs.io/gogs/internal/errutil"
  14. )
  15. func TestLoginSource_BeforeSave(t *testing.T) {
  16. now := time.Now()
  17. db := &gorm.DB{
  18. Config: &gorm.Config{
  19. NowFunc: func() time.Time {
  20. return now
  21. },
  22. },
  23. }
  24. t.Run("Config has not been set", func(t *testing.T) {
  25. s := &LoginSource{}
  26. err := s.BeforeSave(db)
  27. if err != nil {
  28. t.Fatal(err)
  29. }
  30. assert.Empty(t, s.Config)
  31. })
  32. t.Run("Config has been set", func(t *testing.T) {
  33. s := &LoginSource{
  34. Provider: pam.NewProvider(&pam.Config{
  35. ServiceName: "pam_service",
  36. }),
  37. }
  38. err := s.BeforeSave(db)
  39. if err != nil {
  40. t.Fatal(err)
  41. }
  42. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.Config)
  43. })
  44. }
  45. func TestLoginSource_BeforeCreate(t *testing.T) {
  46. now := time.Now()
  47. db := &gorm.DB{
  48. Config: &gorm.Config{
  49. NowFunc: func() time.Time {
  50. return now
  51. },
  52. },
  53. }
  54. t.Run("CreatedUnix has been set", func(t *testing.T) {
  55. s := &LoginSource{CreatedUnix: 1}
  56. _ = s.BeforeCreate(db)
  57. assert.Equal(t, int64(1), s.CreatedUnix)
  58. assert.Equal(t, int64(0), s.UpdatedUnix)
  59. })
  60. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  61. s := &LoginSource{}
  62. _ = s.BeforeCreate(db)
  63. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  64. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  65. })
  66. }
  67. func Test_loginSources(t *testing.T) {
  68. if testing.Short() {
  69. t.Skip()
  70. }
  71. t.Parallel()
  72. tables := []interface{}{new(LoginSource), new(User)}
  73. db := &loginSources{
  74. DB: initTestDB(t, "loginSources", tables...),
  75. }
  76. for _, tc := range []struct {
  77. name string
  78. test func(*testing.T, *loginSources)
  79. }{
  80. {"Create", test_loginSources_Create},
  81. {"Count", test_loginSources_Count},
  82. {"DeleteByID", test_loginSources_DeleteByID},
  83. {"GetByID", test_loginSources_GetByID},
  84. {"List", test_loginSources_List},
  85. {"ResetNonDefault", test_loginSources_ResetNonDefault},
  86. {"Save", test_loginSources_Save},
  87. } {
  88. t.Run(tc.name, func(t *testing.T) {
  89. t.Cleanup(func() {
  90. err := clearTables(t, db.DB, tables...)
  91. if err != nil {
  92. t.Fatal(err)
  93. }
  94. })
  95. tc.test(t, db)
  96. })
  97. }
  98. }
  99. func test_loginSources_Create(t *testing.T, db *loginSources) {
  100. // Create first login source with name "GitHub"
  101. source, err := db.Create(CreateLoginSourceOpts{
  102. Type: auth.GitHub,
  103. Name: "GitHub",
  104. Activated: true,
  105. Default: false,
  106. Config: &github.Config{
  107. APIEndpoint: "https://api.github.com",
  108. },
  109. })
  110. if err != nil {
  111. t.Fatal(err)
  112. }
  113. // Get it back and check the Created field
  114. source, err = db.GetByID(source.ID)
  115. if err != nil {
  116. t.Fatal(err)
  117. }
  118. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  119. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  120. // Try create second login source with same name should fail
  121. _, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
  122. expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  123. assert.Equal(t, expErr, err)
  124. }
  125. func test_loginSources_Count(t *testing.T, db *loginSources) {
  126. // Create two login sources, one in database and one as source file.
  127. _, err := db.Create(CreateLoginSourceOpts{
  128. Type: auth.GitHub,
  129. Name: "GitHub",
  130. Activated: true,
  131. Default: false,
  132. Config: &github.Config{
  133. APIEndpoint: "https://api.github.com",
  134. },
  135. })
  136. if err != nil {
  137. t.Fatal(err)
  138. }
  139. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  140. MockLen: func() int {
  141. return 2
  142. },
  143. })
  144. assert.Equal(t, int64(3), db.Count())
  145. }
  146. func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
  147. t.Run("delete but in used", func(t *testing.T) {
  148. source, err := db.Create(CreateLoginSourceOpts{
  149. Type: auth.GitHub,
  150. Name: "GitHub",
  151. Activated: true,
  152. Default: false,
  153. Config: &github.Config{
  154. APIEndpoint: "https://api.github.com",
  155. },
  156. })
  157. if err != nil {
  158. t.Fatal(err)
  159. }
  160. // Create a user that uses this login source
  161. _, err = (&users{DB: db.DB}).Create("alice", "", CreateUserOpts{
  162. LoginSource: source.ID,
  163. })
  164. if err != nil {
  165. t.Fatal(err)
  166. }
  167. // Delete the login source will result in error
  168. err = db.DeleteByID(source.ID)
  169. expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  170. assert.Equal(t, expErr, err)
  171. })
  172. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  173. MockGetByID: func(id int64) (*LoginSource, error) {
  174. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  175. },
  176. })
  177. // Create a login source with name "GitHub2"
  178. source, err := db.Create(CreateLoginSourceOpts{
  179. Type: auth.GitHub,
  180. Name: "GitHub2",
  181. Activated: true,
  182. Default: false,
  183. Config: &github.Config{
  184. APIEndpoint: "https://api.github.com",
  185. },
  186. })
  187. if err != nil {
  188. t.Fatal(err)
  189. }
  190. // Delete a non-existent ID is noop
  191. err = db.DeleteByID(9999)
  192. if err != nil {
  193. t.Fatal(err)
  194. }
  195. // We should be able to get it back
  196. _, err = db.GetByID(source.ID)
  197. if err != nil {
  198. t.Fatal(err)
  199. }
  200. // Now delete this login source with ID
  201. err = db.DeleteByID(source.ID)
  202. if err != nil {
  203. t.Fatal(err)
  204. }
  205. // We should get token not found error
  206. _, err = db.GetByID(source.ID)
  207. expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  208. assert.Equal(t, expErr, err)
  209. }
  210. func test_loginSources_GetByID(t *testing.T, db *loginSources) {
  211. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  212. MockGetByID: func(id int64) (*LoginSource, error) {
  213. if id != 101 {
  214. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  215. }
  216. return &LoginSource{ID: id}, nil
  217. },
  218. })
  219. expConfig := &github.Config{
  220. APIEndpoint: "https://api.github.com",
  221. }
  222. // Create a login source with name "GitHub"
  223. source, err := db.Create(CreateLoginSourceOpts{
  224. Type: auth.GitHub,
  225. Name: "GitHub",
  226. Activated: true,
  227. Default: false,
  228. Config: expConfig,
  229. })
  230. if err != nil {
  231. t.Fatal(err)
  232. }
  233. // Get the one in the database and test the read/write hooks
  234. source, err = db.GetByID(source.ID)
  235. if err != nil {
  236. t.Fatal(err)
  237. }
  238. assert.Equal(t, expConfig, source.Provider.Config())
  239. // Get the one in source file store
  240. _, err = db.GetByID(101)
  241. if err != nil {
  242. t.Fatal(err)
  243. }
  244. }
  245. func test_loginSources_List(t *testing.T, db *loginSources) {
  246. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  247. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  248. if opts.OnlyActivated {
  249. return []*LoginSource{
  250. {ID: 1},
  251. }
  252. }
  253. return []*LoginSource{
  254. {ID: 1},
  255. {ID: 2},
  256. }
  257. },
  258. })
  259. // Create two login sources in database, one activated and the other one not
  260. _, err := db.Create(CreateLoginSourceOpts{
  261. Type: auth.PAM,
  262. Name: "PAM",
  263. Config: &pam.Config{
  264. ServiceName: "PAM",
  265. },
  266. })
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. _, err = db.Create(CreateLoginSourceOpts{
  271. Type: auth.GitHub,
  272. Name: "GitHub",
  273. Activated: true,
  274. Config: &github.Config{
  275. APIEndpoint: "https://api.github.com",
  276. },
  277. })
  278. if err != nil {
  279. t.Fatal(err)
  280. }
  281. // List all login sources
  282. sources, err := db.List(ListLoginSourceOpts{})
  283. if err != nil {
  284. t.Fatal(err)
  285. }
  286. assert.Equal(t, 4, len(sources), "number of sources")
  287. // Only list activated login sources
  288. sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
  289. if err != nil {
  290. t.Fatal(err)
  291. }
  292. assert.Equal(t, 2, len(sources), "number of sources")
  293. }
  294. func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
  295. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  296. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  297. return []*LoginSource{
  298. {
  299. File: &mockLoginSourceFileStore{
  300. MockSetGeneral: func(name, value string) {
  301. assert.Equal(t, "is_default", name)
  302. assert.Equal(t, "false", value)
  303. },
  304. MockSave: func() error {
  305. return nil
  306. },
  307. },
  308. },
  309. }
  310. },
  311. MockUpdate: func(source *LoginSource) {},
  312. })
  313. // Create two login sources both have default on
  314. source1, err := db.Create(CreateLoginSourceOpts{
  315. Type: auth.PAM,
  316. Name: "PAM",
  317. Default: true,
  318. Config: &pam.Config{
  319. ServiceName: "PAM",
  320. },
  321. })
  322. if err != nil {
  323. t.Fatal(err)
  324. }
  325. source2, err := db.Create(CreateLoginSourceOpts{
  326. Type: auth.GitHub,
  327. Name: "GitHub",
  328. Activated: true,
  329. Default: true,
  330. Config: &github.Config{
  331. APIEndpoint: "https://api.github.com",
  332. },
  333. })
  334. if err != nil {
  335. t.Fatal(err)
  336. }
  337. // Set source 1 as default
  338. err = db.ResetNonDefault(source1)
  339. if err != nil {
  340. t.Fatal(err)
  341. }
  342. // Verify the default state
  343. source1, err = db.GetByID(source1.ID)
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. assert.True(t, source1.IsDefault)
  348. source2, err = db.GetByID(source2.ID)
  349. if err != nil {
  350. t.Fatal(err)
  351. }
  352. assert.False(t, source2.IsDefault)
  353. }
  354. func test_loginSources_Save(t *testing.T, db *loginSources) {
  355. t.Run("save to database", func(t *testing.T) {
  356. // Create a login source with name "GitHub"
  357. source, err := db.Create(CreateLoginSourceOpts{
  358. Type: auth.GitHub,
  359. Name: "GitHub",
  360. Activated: true,
  361. Default: false,
  362. Config: &github.Config{
  363. APIEndpoint: "https://api.github.com",
  364. },
  365. })
  366. if err != nil {
  367. t.Fatal(err)
  368. }
  369. source.IsActived = false
  370. source.Provider = github.NewProvider(&github.Config{
  371. APIEndpoint: "https://api2.github.com",
  372. })
  373. err = db.Save(source)
  374. if err != nil {
  375. t.Fatal(err)
  376. }
  377. source, err = db.GetByID(source.ID)
  378. if err != nil {
  379. t.Fatal(err)
  380. }
  381. assert.False(t, source.IsActived)
  382. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  383. })
  384. t.Run("save to file", func(t *testing.T) {
  385. calledSave := false
  386. source := &LoginSource{
  387. Provider: github.NewProvider(&github.Config{
  388. APIEndpoint: "https://api.github.com",
  389. }),
  390. File: &mockLoginSourceFileStore{
  391. MockSetGeneral: func(name, value string) {},
  392. MockSetConfig: func(cfg interface{}) error { return nil },
  393. MockSave: func() error {
  394. calledSave = true
  395. return nil
  396. },
  397. },
  398. }
  399. err := db.Save(source)
  400. if err != nil {
  401. t.Fatal(err)
  402. }
  403. assert.True(t, calledSave)
  404. })
  405. }