login_sources.go 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "fmt"
  7. "strconv"
  8. "time"
  9. "github.com/jinzhu/gorm"
  10. jsoniter "github.com/json-iterator/go"
  11. "github.com/pkg/errors"
  12. "gogs.io/gogs/internal/auth/ldap"
  13. "gogs.io/gogs/internal/errutil"
  14. )
  15. // LoginSourcesStore is the persistent interface for login sources.
  16. //
  17. // NOTE: All methods are sorted in alphabetical order.
  18. type LoginSourcesStore interface {
  19. // Create creates a new login source and persist to database.
  20. // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
  21. Create(opts CreateLoginSourceOpts) (*LoginSource, error)
  22. // Count returns the total number of login sources.
  23. Count() int64
  24. // DeleteByID deletes a login source by given ID.
  25. // It returns ErrLoginSourceInUse if at least one user is associated with the login source.
  26. DeleteByID(id int64) error
  27. // GetByID returns the login source with given ID.
  28. // It returns ErrLoginSourceNotExist when not found.
  29. GetByID(id int64) (*LoginSource, error)
  30. // List returns a list of login sources filtered by options.
  31. List(opts ListLoginSourceOpts) ([]*LoginSource, error)
  32. // ResetNonDefault clears default flag for all the other login sources.
  33. ResetNonDefault(source *LoginSource) error
  34. // Save persists all values of given login source to database or local file.
  35. // The Updated field is set to current time automatically.
  36. Save(t *LoginSource) error
  37. }
  38. var LoginSources LoginSourcesStore
  39. // LoginSource represents an external way for authorizing users.
  40. type LoginSource struct {
  41. ID int64
  42. Type LoginType
  43. Name string `xorm:"UNIQUE" gorm:"UNIQUE"`
  44. IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"NOT NULL"`
  45. IsDefault bool `xorm:"DEFAULT false"`
  46. Config interface{} `xorm:"-" gorm:"-"`
  47. RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
  48. Created time.Time `xorm:"-" gorm:"-" json:"-"`
  49. CreatedUnix int64
  50. Updated time.Time `xorm:"-" gorm:"-" json:"-"`
  51. UpdatedUnix int64
  52. File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
  53. }
  54. // NOTE: This is a GORM save hook.
  55. func (s *LoginSource) BeforeSave() (err error) {
  56. if s.Config == nil {
  57. return nil
  58. }
  59. s.RawConfig, err = jsoniter.MarshalToString(s.Config)
  60. return err
  61. }
  62. // NOTE: This is a GORM create hook.
  63. func (s *LoginSource) BeforeCreate() {
  64. if s.CreatedUnix > 0 {
  65. return
  66. }
  67. s.CreatedUnix = gorm.NowFunc().Unix()
  68. s.UpdatedUnix = s.CreatedUnix
  69. }
  70. // NOTE: This is a GORM update hook.
  71. func (s *LoginSource) BeforeUpdate() {
  72. s.UpdatedUnix = gorm.NowFunc().Unix()
  73. }
  74. // NOTE: This is a GORM query hook.
  75. func (s *LoginSource) AfterFind() error {
  76. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  77. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  78. switch s.Type {
  79. case LoginLDAP, LoginDLDAP:
  80. s.Config = new(LDAPConfig)
  81. case LoginSMTP:
  82. s.Config = new(SMTPConfig)
  83. case LoginPAM:
  84. s.Config = new(PAMConfig)
  85. case LoginGitHub:
  86. s.Config = new(GitHubConfig)
  87. default:
  88. return fmt.Errorf("unrecognized login source type: %v", s.Type)
  89. }
  90. return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
  91. }
  92. func (s *LoginSource) TypeName() string {
  93. return LoginNames[s.Type]
  94. }
  95. func (s *LoginSource) IsLDAP() bool {
  96. return s.Type == LoginLDAP
  97. }
  98. func (s *LoginSource) IsDLDAP() bool {
  99. return s.Type == LoginDLDAP
  100. }
  101. func (s *LoginSource) IsSMTP() bool {
  102. return s.Type == LoginSMTP
  103. }
  104. func (s *LoginSource) IsPAM() bool {
  105. return s.Type == LoginPAM
  106. }
  107. func (s *LoginSource) IsGitHub() bool {
  108. return s.Type == LoginGitHub
  109. }
  110. func (s *LoginSource) HasTLS() bool {
  111. return ((s.IsLDAP() || s.IsDLDAP()) &&
  112. s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
  113. s.IsSMTP()
  114. }
  115. func (s *LoginSource) UseTLS() bool {
  116. switch s.Type {
  117. case LoginLDAP, LoginDLDAP:
  118. return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
  119. case LoginSMTP:
  120. return s.SMTP().TLS
  121. }
  122. return false
  123. }
  124. func (s *LoginSource) SkipVerify() bool {
  125. switch s.Type {
  126. case LoginLDAP, LoginDLDAP:
  127. return s.LDAP().SkipVerify
  128. case LoginSMTP:
  129. return s.SMTP().SkipVerify
  130. }
  131. return false
  132. }
  133. func (s *LoginSource) LDAP() *LDAPConfig {
  134. return s.Config.(*LDAPConfig)
  135. }
  136. func (s *LoginSource) SMTP() *SMTPConfig {
  137. return s.Config.(*SMTPConfig)
  138. }
  139. func (s *LoginSource) PAM() *PAMConfig {
  140. return s.Config.(*PAMConfig)
  141. }
  142. func (s *LoginSource) GitHub() *GitHubConfig {
  143. return s.Config.(*GitHubConfig)
  144. }
  145. var _ LoginSourcesStore = (*loginSources)(nil)
  146. type loginSources struct {
  147. *gorm.DB
  148. files loginSourceFilesStore
  149. }
  150. type CreateLoginSourceOpts struct {
  151. Type LoginType
  152. Name string
  153. Activated bool
  154. Default bool
  155. Config interface{}
  156. }
  157. type ErrLoginSourceAlreadyExist struct {
  158. args errutil.Args
  159. }
  160. func IsErrLoginSourceAlreadyExist(err error) bool {
  161. _, ok := err.(ErrLoginSourceAlreadyExist)
  162. return ok
  163. }
  164. func (err ErrLoginSourceAlreadyExist) Error() string {
  165. return fmt.Sprintf("login source already exists: %v", err.args)
  166. }
  167. func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
  168. err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
  169. if err == nil {
  170. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  171. } else if !gorm.IsRecordNotFoundError(err) {
  172. return nil, err
  173. }
  174. source := &LoginSource{
  175. Type: opts.Type,
  176. Name: opts.Name,
  177. IsActived: opts.Activated,
  178. IsDefault: opts.Default,
  179. Config: opts.Config,
  180. }
  181. return source, db.DB.Create(source).Error
  182. }
  183. func (db *loginSources) Count() int64 {
  184. var count int64
  185. db.Model(new(LoginSource)).Count(&count)
  186. return count + int64(db.files.Len())
  187. }
  188. type ErrLoginSourceInUse struct {
  189. args errutil.Args
  190. }
  191. func IsErrLoginSourceInUse(err error) bool {
  192. _, ok := err.(ErrLoginSourceInUse)
  193. return ok
  194. }
  195. func (err ErrLoginSourceInUse) Error() string {
  196. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  197. }
  198. func (db *loginSources) DeleteByID(id int64) error {
  199. var count int64
  200. err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  201. if err != nil {
  202. return err
  203. } else if count > 0 {
  204. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  205. }
  206. return db.Where("id = ?", id).Delete(new(LoginSource)).Error
  207. }
  208. func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
  209. source := new(LoginSource)
  210. err := db.Where("id = ?", id).First(source).Error
  211. if err != nil {
  212. if gorm.IsRecordNotFoundError(err) {
  213. return db.files.GetByID(id)
  214. }
  215. return nil, err
  216. }
  217. return source, nil
  218. }
  // ListLoginSourceOpts holds the filtering options for List.
  type ListLoginSourceOpts struct {
  	// OnlyActivated restricts the result to activated login sources when true.
  	OnlyActivated bool
  }
  223. func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
  224. var sources []*LoginSource
  225. query := db.Order("id ASC")
  226. if opts.OnlyActivated {
  227. query = query.Where("is_actived = ?", true)
  228. }
  229. err := query.Find(&sources).Error
  230. if err != nil {
  231. return nil, err
  232. }
  233. return append(sources, db.files.List(opts)...), nil
  234. }
  235. func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
  236. err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
  237. if err != nil {
  238. return err
  239. }
  240. for _, source := range db.files.List(ListLoginSourceOpts{}) {
  241. if source.File != nil && source.ID != dflt.ID {
  242. source.File.SetGeneral("is_default", "false")
  243. if err = source.File.Save(); err != nil {
  244. return errors.Wrap(err, "save file")
  245. }
  246. }
  247. }
  248. db.files.Update(dflt)
  249. return nil
  250. }
  // Save persists source.
  //
  // - Database-backed sources (File == nil) go through GORM's Save, which
  //   runs the model hooks.
  // - File-backed sources: the general name/is_activated/is_default keys and
  //   the config are written to the file, which is then saved. Errors are
  //   wrapped with "set config" or "save file".
  func (db *loginSources) Save(source *LoginSource) error {
  	f := source.File
  	if f == nil {
  		return db.DB.Save(source).Error
  	}

  	f.SetGeneral("name", source.Name)
  	f.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  	f.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))

  	if err := f.SetConfig(source.Config); err != nil {
  		return errors.Wrap(err, "set config")
  	}
  	if err := f.Save(); err != nil {
  		return errors.Wrap(err, "save file")
  	}
  	return nil
  }