// login_sources.go 7.8 KB
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "fmt"
  7. "strconv"
  8. "time"
  9. "github.com/jinzhu/gorm"
  10. jsoniter "github.com/json-iterator/go"
  11. "github.com/pkg/errors"
  12. "gogs.io/gogs/internal/auth/ldap"
  13. "gogs.io/gogs/internal/errutil"
  14. )
  15. // LoginSourcesStore is the persistent interface for login sources.
  16. //
  17. // NOTE: All methods are sorted in alphabetical order.
  18. type LoginSourcesStore interface {
  19. // Create creates a new login source and persist to database.
  20. // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
  21. Create(opts CreateLoginSourceOpts) (*LoginSource, error)
  22. // Count returns the total number of login sources.
  23. Count() int64
  24. // DeleteByID deletes a login source by given ID.
  25. // It returns ErrLoginSourceInUse if at least one user is associated with the login source.
  26. DeleteByID(id int64) error
  27. // GetByID returns the login source with given ID.
  28. // It returns ErrLoginSourceNotExist when not found.
  29. GetByID(id int64) (*LoginSource, error)
  30. // List returns a list of login sources filtered by options.
  31. List(opts ListLoginSourceOpts) ([]*LoginSource, error)
  32. // ResetNonDefault clears default flag for all the other login sources.
  33. ResetNonDefault(source *LoginSource) error
  34. // Save persists all values of given login source to database or local file.
  35. // The Updated field is set to current time automatically.
  36. Save(t *LoginSource) error
  37. }
  38. var LoginSources LoginSourcesStore
  39. // LoginSource represents an external way for authorizing users.
  40. type LoginSource struct {
  41. ID int64
  42. Type LoginType
  43. Name string `xorm:"UNIQUE"`
  44. IsActived bool `xorm:"NOT NULL DEFAULT false"`
  45. IsDefault bool `xorm:"DEFAULT false"`
  46. Config interface{} `xorm:"-" gorm:"-"`
  47. RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg"`
  48. Created time.Time `xorm:"-" gorm:"-" json:"-"`
  49. CreatedUnix int64
  50. Updated time.Time `xorm:"-" gorm:"-" json:"-"`
  51. UpdatedUnix int64
  52. File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
  53. }
  54. // NOTE: This is a GORM save hook.
  55. func (s *LoginSource) BeforeSave() (err error) {
  56. s.RawConfig, err = jsoniter.MarshalToString(s.Config)
  57. return err
  58. }
  59. // NOTE: This is a GORM create hook.
  60. func (s *LoginSource) BeforeCreate() {
  61. s.CreatedUnix = gorm.NowFunc().Unix()
  62. s.UpdatedUnix = s.CreatedUnix
  63. }
  64. // NOTE: This is a GORM update hook.
  65. func (s *LoginSource) BeforeUpdate() {
  66. s.UpdatedUnix = gorm.NowFunc().Unix()
  67. }
  68. // NOTE: This is a GORM query hook.
  69. func (s *LoginSource) AfterFind() error {
  70. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  71. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  72. switch s.Type {
  73. case LoginLDAP, LoginDLDAP:
  74. s.Config = new(LDAPConfig)
  75. case LoginSMTP:
  76. s.Config = new(SMTPConfig)
  77. case LoginPAM:
  78. s.Config = new(PAMConfig)
  79. case LoginGitHub:
  80. s.Config = new(GitHubConfig)
  81. default:
  82. return fmt.Errorf("unrecognized login source type: %v", s.Type)
  83. }
  84. return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
  85. }
  86. func (s *LoginSource) TypeName() string {
  87. return LoginNames[s.Type]
  88. }
  89. func (s *LoginSource) IsLDAP() bool {
  90. return s.Type == LoginLDAP
  91. }
  92. func (s *LoginSource) IsDLDAP() bool {
  93. return s.Type == LoginDLDAP
  94. }
  95. func (s *LoginSource) IsSMTP() bool {
  96. return s.Type == LoginSMTP
  97. }
  98. func (s *LoginSource) IsPAM() bool {
  99. return s.Type == LoginPAM
  100. }
  101. func (s *LoginSource) IsGitHub() bool {
  102. return s.Type == LoginGitHub
  103. }
  104. func (s *LoginSource) HasTLS() bool {
  105. return ((s.IsLDAP() || s.IsDLDAP()) &&
  106. s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
  107. s.IsSMTP()
  108. }
  109. func (s *LoginSource) UseTLS() bool {
  110. switch s.Type {
  111. case LoginLDAP, LoginDLDAP:
  112. return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
  113. case LoginSMTP:
  114. return s.SMTP().TLS
  115. }
  116. return false
  117. }
  118. func (s *LoginSource) SkipVerify() bool {
  119. switch s.Type {
  120. case LoginLDAP, LoginDLDAP:
  121. return s.LDAP().SkipVerify
  122. case LoginSMTP:
  123. return s.SMTP().SkipVerify
  124. }
  125. return false
  126. }
  127. func (s *LoginSource) LDAP() *LDAPConfig {
  128. return s.Config.(*LDAPConfig)
  129. }
  130. func (s *LoginSource) SMTP() *SMTPConfig {
  131. return s.Config.(*SMTPConfig)
  132. }
  133. func (s *LoginSource) PAM() *PAMConfig {
  134. return s.Config.(*PAMConfig)
  135. }
  136. func (s *LoginSource) GitHub() *GitHubConfig {
  137. return s.Config.(*GitHubConfig)
  138. }
  139. var _ LoginSourcesStore = (*loginSources)(nil)
  140. type loginSources struct {
  141. *gorm.DB
  142. files loginSourceFilesStore
  143. }
  144. type CreateLoginSourceOpts struct {
  145. Type LoginType
  146. Name string
  147. Activated bool
  148. Default bool
  149. Config interface{}
  150. }
  151. type ErrLoginSourceAlreadyExist struct {
  152. args errutil.Args
  153. }
  154. func IsErrLoginSourceAlreadyExist(err error) bool {
  155. _, ok := err.(ErrLoginSourceAlreadyExist)
  156. return ok
  157. }
  158. func (err ErrLoginSourceAlreadyExist) Error() string {
  159. return fmt.Sprintf("login source already exists: %v", err.args)
  160. }
  161. func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
  162. err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
  163. if err == nil {
  164. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  165. } else if !gorm.IsRecordNotFoundError(err) {
  166. return nil, err
  167. }
  168. source := &LoginSource{
  169. Type: opts.Type,
  170. Name: opts.Name,
  171. IsActived: opts.Activated,
  172. IsDefault: opts.Default,
  173. Config: opts.Config,
  174. }
  175. return source, db.DB.Create(source).Error
  176. }
  177. func (db *loginSources) Count() int64 {
  178. var count int64
  179. db.Model(new(LoginSource)).Count(&count)
  180. return count + int64(db.files.Len())
  181. }
  182. type ErrLoginSourceInUse struct {
  183. args errutil.Args
  184. }
  185. func IsErrLoginSourceInUse(err error) bool {
  186. _, ok := err.(ErrLoginSourceInUse)
  187. return ok
  188. }
  189. func (err ErrLoginSourceInUse) Error() string {
  190. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  191. }
  192. func (db *loginSources) DeleteByID(id int64) error {
  193. var count int64
  194. err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  195. if err != nil {
  196. return err
  197. } else if count > 0 {
  198. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  199. }
  200. return db.Where("id = ?", id).Delete(new(LoginSource)).Error
  201. }
  202. func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
  203. source := new(LoginSource)
  204. err := db.Where("id = ?", id).First(source).Error
  205. if err != nil {
  206. if gorm.IsRecordNotFoundError(err) {
  207. return db.files.GetByID(id)
  208. }
  209. return nil, err
  210. }
  211. return source, nil
  212. }
  213. type ListLoginSourceOpts struct {
  214. // Whether to only include activated login sources.
  215. OnlyActivated bool
  216. }
  217. func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
  218. var sources []*LoginSource
  219. query := db.Order("id ASC")
  220. if opts.OnlyActivated {
  221. query = query.Where("is_actived = ?", true)
  222. }
  223. err := query.Find(&sources).Error
  224. if err != nil {
  225. return nil, err
  226. }
  227. return append(sources, db.files.List(opts)...), nil
  228. }
  229. func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
  230. err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
  231. if err != nil {
  232. return err
  233. }
  234. for _, source := range db.files.List(ListLoginSourceOpts{}) {
  235. if source.File != nil && source.ID != dflt.ID {
  236. source.File.SetGeneral("is_default", "false")
  237. if err = source.File.Save(); err != nil {
  238. return errors.Wrap(err, "save file")
  239. }
  240. }
  241. }
  242. db.files.Update(dflt)
  243. return nil
  244. }
  245. func (db *loginSources) Save(source *LoginSource) error {
  246. if source.File == nil {
  247. return db.DB.Save(source).Error
  248. }
  249. source.File.SetGeneral("name", source.Name)
  250. source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  251. source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
  252. if err := source.File.SetConfig(source.Config); err != nil {
  253. return errors.Wrap(err, "set config")
  254. } else if err = source.File.Save(); err != nil {
  255. return errors.Wrap(err, "save file")
  256. }
  257. return nil
  258. }