login_sources_test.go 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. "gorm.io/gorm"
  10. "gogs.io/gogs/internal/errutil"
  11. )
  12. func TestLoginSource_BeforeSave(t *testing.T) {
  13. now := time.Now()
  14. db := &gorm.DB{
  15. Config: &gorm.Config{
  16. NowFunc: func() time.Time {
  17. return now
  18. },
  19. },
  20. }
  21. t.Run("Config has not been set", func(t *testing.T) {
  22. s := &LoginSource{}
  23. err := s.BeforeSave(db)
  24. if err != nil {
  25. t.Fatal(err)
  26. }
  27. assert.Empty(t, s.RawConfig)
  28. })
  29. t.Run("Config has been set", func(t *testing.T) {
  30. s := &LoginSource{
  31. Config: &PAMConfig{ServiceName: "pam_service"},
  32. }
  33. err := s.BeforeSave(db)
  34. if err != nil {
  35. t.Fatal(err)
  36. }
  37. assert.Equal(t, `{"ServiceName":"pam_service"}`, s.RawConfig)
  38. })
  39. }
  40. func TestLoginSource_BeforeCreate(t *testing.T) {
  41. now := time.Now()
  42. db := &gorm.DB{
  43. Config: &gorm.Config{
  44. NowFunc: func() time.Time {
  45. return now
  46. },
  47. },
  48. }
  49. t.Run("CreatedUnix has been set", func(t *testing.T) {
  50. s := &LoginSource{CreatedUnix: 1}
  51. _ = s.BeforeCreate(db)
  52. assert.Equal(t, int64(1), s.CreatedUnix)
  53. assert.Equal(t, int64(0), s.UpdatedUnix)
  54. })
  55. t.Run("CreatedUnix has not been set", func(t *testing.T) {
  56. s := &LoginSource{}
  57. _ = s.BeforeCreate(db)
  58. assert.Equal(t, db.NowFunc().Unix(), s.CreatedUnix)
  59. assert.Equal(t, db.NowFunc().Unix(), s.UpdatedUnix)
  60. })
  61. }
  62. func Test_loginSources(t *testing.T) {
  63. if testing.Short() {
  64. t.Skip()
  65. }
  66. t.Parallel()
  67. tables := []interface{}{new(LoginSource), new(User)}
  68. db := &loginSources{
  69. DB: initTestDB(t, "loginSources", tables...),
  70. }
  71. for _, tc := range []struct {
  72. name string
  73. test func(*testing.T, *loginSources)
  74. }{
  75. {"Create", test_loginSources_Create},
  76. {"Count", test_loginSources_Count},
  77. {"DeleteByID", test_loginSources_DeleteByID},
  78. {"GetByID", test_loginSources_GetByID},
  79. {"List", test_loginSources_List},
  80. {"ResetNonDefault", test_loginSources_ResetNonDefault},
  81. {"Save", test_loginSources_Save},
  82. } {
  83. t.Run(tc.name, func(t *testing.T) {
  84. t.Cleanup(func() {
  85. err := clearTables(t, db.DB, tables...)
  86. if err != nil {
  87. t.Fatal(err)
  88. }
  89. })
  90. tc.test(t, db)
  91. })
  92. }
  93. }
  94. func test_loginSources_Create(t *testing.T, db *loginSources) {
  95. // Create first login source with name "GitHub"
  96. source, err := db.Create(CreateLoginSourceOpts{
  97. Type: LoginGitHub,
  98. Name: "GitHub",
  99. Activated: true,
  100. Default: false,
  101. Config: &GitHubConfig{
  102. APIEndpoint: "https://api.github.com",
  103. },
  104. })
  105. if err != nil {
  106. t.Fatal(err)
  107. }
  108. // Get it back and check the Created field
  109. source, err = db.GetByID(source.ID)
  110. if err != nil {
  111. t.Fatal(err)
  112. }
  113. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Created.UTC().Format(time.RFC3339))
  114. assert.Equal(t, db.NowFunc().Format(time.RFC3339), source.Updated.UTC().Format(time.RFC3339))
  115. // Try create second login source with same name should fail
  116. _, err = db.Create(CreateLoginSourceOpts{Name: source.Name})
  117. expErr := ErrLoginSourceAlreadyExist{args: errutil.Args{"name": source.Name}}
  118. assert.Equal(t, expErr, err)
  119. }
  120. func test_loginSources_Count(t *testing.T, db *loginSources) {
  121. // Create two login sources, one in database and one as source file.
  122. _, err := db.Create(CreateLoginSourceOpts{
  123. Type: LoginGitHub,
  124. Name: "GitHub",
  125. Activated: true,
  126. Default: false,
  127. Config: &GitHubConfig{
  128. APIEndpoint: "https://api.github.com",
  129. },
  130. })
  131. if err != nil {
  132. t.Fatal(err)
  133. }
  134. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  135. MockLen: func() int {
  136. return 2
  137. },
  138. })
  139. assert.Equal(t, int64(3), db.Count())
  140. }
  141. func test_loginSources_DeleteByID(t *testing.T, db *loginSources) {
  142. t.Run("delete but in used", func(t *testing.T) {
  143. source, err := db.Create(CreateLoginSourceOpts{
  144. Type: LoginGitHub,
  145. Name: "GitHub",
  146. Activated: true,
  147. Default: false,
  148. Config: &GitHubConfig{
  149. APIEndpoint: "https://api.github.com",
  150. },
  151. })
  152. if err != nil {
  153. t.Fatal(err)
  154. }
  155. // Create a user that uses this login source
  156. _, err = (&users{DB: db.DB}).Create(CreateUserOpts{
  157. Name: "alice",
  158. LoginSource: source.ID,
  159. })
  160. if err != nil {
  161. t.Fatal(err)
  162. }
  163. // Delete the login source will result in error
  164. err = db.DeleteByID(source.ID)
  165. expErr := ErrLoginSourceInUse{args: errutil.Args{"id": source.ID}}
  166. assert.Equal(t, expErr, err)
  167. })
  168. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  169. MockGetByID: func(id int64) (*LoginSource, error) {
  170. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  171. },
  172. })
  173. // Create a login source with name "GitHub2"
  174. source, err := db.Create(CreateLoginSourceOpts{
  175. Type: LoginGitHub,
  176. Name: "GitHub2",
  177. Activated: true,
  178. Default: false,
  179. Config: &GitHubConfig{
  180. APIEndpoint: "https://api.github.com",
  181. },
  182. })
  183. if err != nil {
  184. t.Fatal(err)
  185. }
  186. // Delete a non-existent ID is noop
  187. err = db.DeleteByID(9999)
  188. if err != nil {
  189. t.Fatal(err)
  190. }
  191. // We should be able to get it back
  192. _, err = db.GetByID(source.ID)
  193. if err != nil {
  194. t.Fatal(err)
  195. }
  196. // Now delete this login source with ID
  197. err = db.DeleteByID(source.ID)
  198. if err != nil {
  199. t.Fatal(err)
  200. }
  201. // We should get token not found error
  202. _, err = db.GetByID(source.ID)
  203. expErr := ErrLoginSourceNotExist{args: errutil.Args{"id": source.ID}}
  204. assert.Equal(t, expErr, err)
  205. }
  206. func test_loginSources_GetByID(t *testing.T, db *loginSources) {
  207. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  208. MockGetByID: func(id int64) (*LoginSource, error) {
  209. if id != 101 {
  210. return nil, ErrLoginSourceNotExist{args: errutil.Args{"id": id}}
  211. }
  212. return &LoginSource{ID: id}, nil
  213. },
  214. })
  215. expConfig := &GitHubConfig{
  216. APIEndpoint: "https://api.github.com",
  217. }
  218. // Create a login source with name "GitHub"
  219. source, err := db.Create(CreateLoginSourceOpts{
  220. Type: LoginGitHub,
  221. Name: "GitHub",
  222. Activated: true,
  223. Default: false,
  224. Config: expConfig,
  225. })
  226. if err != nil {
  227. t.Fatal(err)
  228. }
  229. // Get the one in the database and test the read/write hooks
  230. source, err = db.GetByID(source.ID)
  231. if err != nil {
  232. t.Fatal(err)
  233. }
  234. assert.Equal(t, expConfig, source.Config)
  235. // Get the one in source file store
  236. _, err = db.GetByID(101)
  237. if err != nil {
  238. t.Fatal(err)
  239. }
  240. }
  241. func test_loginSources_List(t *testing.T, db *loginSources) {
  242. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  243. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  244. if opts.OnlyActivated {
  245. return []*LoginSource{
  246. {ID: 1},
  247. }
  248. }
  249. return []*LoginSource{
  250. {ID: 1},
  251. {ID: 2},
  252. }
  253. },
  254. })
  255. // Create two login sources in database, one activated and the other one not
  256. _, err := db.Create(CreateLoginSourceOpts{
  257. Type: LoginPAM,
  258. Name: "PAM",
  259. Config: &PAMConfig{
  260. ServiceName: "PAM",
  261. },
  262. })
  263. if err != nil {
  264. t.Fatal(err)
  265. }
  266. _, err = db.Create(CreateLoginSourceOpts{
  267. Type: LoginGitHub,
  268. Name: "GitHub",
  269. Activated: true,
  270. Config: &GitHubConfig{
  271. APIEndpoint: "https://api.github.com",
  272. },
  273. })
  274. if err != nil {
  275. t.Fatal(err)
  276. }
  277. // List all login sources
  278. sources, err := db.List(ListLoginSourceOpts{})
  279. if err != nil {
  280. t.Fatal(err)
  281. }
  282. assert.Equal(t, 4, len(sources), "number of sources")
  283. // Only list activated login sources
  284. sources, err = db.List(ListLoginSourceOpts{OnlyActivated: true})
  285. if err != nil {
  286. t.Fatal(err)
  287. }
  288. assert.Equal(t, 2, len(sources), "number of sources")
  289. }
  290. func test_loginSources_ResetNonDefault(t *testing.T, db *loginSources) {
  291. setMockLoginSourceFilesStore(t, db, &mockLoginSourceFilesStore{
  292. MockList: func(opts ListLoginSourceOpts) []*LoginSource {
  293. return []*LoginSource{
  294. {
  295. File: &mockLoginSourceFileStore{
  296. MockSetGeneral: func(name, value string) {
  297. assert.Equal(t, "is_default", name)
  298. assert.Equal(t, "false", value)
  299. },
  300. MockSave: func() error {
  301. return nil
  302. },
  303. },
  304. },
  305. }
  306. },
  307. MockUpdate: func(source *LoginSource) {},
  308. })
  309. // Create two login sources both have default on
  310. source1, err := db.Create(CreateLoginSourceOpts{
  311. Type: LoginPAM,
  312. Name: "PAM",
  313. Default: true,
  314. Config: &PAMConfig{
  315. ServiceName: "PAM",
  316. },
  317. })
  318. if err != nil {
  319. t.Fatal(err)
  320. }
  321. source2, err := db.Create(CreateLoginSourceOpts{
  322. Type: LoginGitHub,
  323. Name: "GitHub",
  324. Activated: true,
  325. Default: true,
  326. Config: &GitHubConfig{
  327. APIEndpoint: "https://api.github.com",
  328. },
  329. })
  330. if err != nil {
  331. t.Fatal(err)
  332. }
  333. // Set source 1 as default
  334. err = db.ResetNonDefault(source1)
  335. if err != nil {
  336. t.Fatal(err)
  337. }
  338. // Verify the default state
  339. source1, err = db.GetByID(source1.ID)
  340. if err != nil {
  341. t.Fatal(err)
  342. }
  343. assert.True(t, source1.IsDefault)
  344. source2, err = db.GetByID(source2.ID)
  345. if err != nil {
  346. t.Fatal(err)
  347. }
  348. assert.False(t, source2.IsDefault)
  349. }
  350. func test_loginSources_Save(t *testing.T, db *loginSources) {
  351. t.Run("save to database", func(t *testing.T) {
  352. // Create a login source with name "GitHub"
  353. source, err := db.Create(CreateLoginSourceOpts{
  354. Type: LoginGitHub,
  355. Name: "GitHub",
  356. Activated: true,
  357. Default: false,
  358. Config: &GitHubConfig{
  359. APIEndpoint: "https://api.github.com",
  360. },
  361. })
  362. if err != nil {
  363. t.Fatal(err)
  364. }
  365. source.IsActived = false
  366. source.Config = &GitHubConfig{
  367. APIEndpoint: "https://api2.github.com",
  368. }
  369. err = db.Save(source)
  370. if err != nil {
  371. t.Fatal(err)
  372. }
  373. source, err = db.GetByID(source.ID)
  374. if err != nil {
  375. t.Fatal(err)
  376. }
  377. assert.False(t, source.IsActived)
  378. assert.Equal(t, "https://api2.github.com", source.GitHub().APIEndpoint)
  379. })
  380. t.Run("save to file", func(t *testing.T) {
  381. calledSave := false
  382. source := &LoginSource{
  383. File: &mockLoginSourceFileStore{
  384. MockSetGeneral: func(name, value string) {},
  385. MockSetConfig: func(cfg interface{}) error { return nil },
  386. MockSave: func() error {
  387. calledSave = true
  388. return nil
  389. },
  390. },
  391. }
  392. err := db.Save(source)
  393. if err != nil {
  394. t.Fatal(err)
  395. }
  396. assert.True(t, calledSave)
  397. })
  398. }