login_sources.go 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. // Copyright 2020 The Gogs Authors. All rights reserved.
  2. // Use of this source code is governed by a MIT-style
  3. // license that can be found in the LICENSE file.
  4. package db
  5. import (
  6. "fmt"
  7. "strconv"
  8. "time"
  9. jsoniter "github.com/json-iterator/go"
  10. "github.com/pkg/errors"
  11. "gorm.io/gorm"
  12. "gogs.io/gogs/internal/auth/ldap"
  13. "gogs.io/gogs/internal/errutil"
  14. )
  15. // LoginSourcesStore is the persistent interface for login sources.
  16. //
  17. // NOTE: All methods are sorted in alphabetical order.
  18. type LoginSourcesStore interface {
  19. // Create creates a new login source and persist to database.
  20. // It returns ErrLoginSourceAlreadyExist when a login source with same name already exists.
  21. Create(opts CreateLoginSourceOpts) (*LoginSource, error)
  22. // Count returns the total number of login sources.
  23. Count() int64
  24. // DeleteByID deletes a login source by given ID.
  25. // It returns ErrLoginSourceInUse if at least one user is associated with the login source.
  26. DeleteByID(id int64) error
  27. // GetByID returns the login source with given ID.
  28. // It returns ErrLoginSourceNotExist when not found.
  29. GetByID(id int64) (*LoginSource, error)
  30. // List returns a list of login sources filtered by options.
  31. List(opts ListLoginSourceOpts) ([]*LoginSource, error)
  32. // ResetNonDefault clears default flag for all the other login sources.
  33. ResetNonDefault(source *LoginSource) error
  34. // Save persists all values of given login source to database or local file.
  35. // The Updated field is set to current time automatically.
  36. Save(t *LoginSource) error
  37. }
  38. var LoginSources LoginSourcesStore
// LoginSource represents an external way for authorizing users.
//
// A login source is either stored as a database row (File == nil) or backed by
// a local file (File != nil); Save writes to whichever applies.
type LoginSource struct {
	ID int64
	// Type selects how Config is decoded when the row is loaded (see AfterFind).
	Type LoginType
	// Name of the login source; declared UNIQUE in the schema tags.
	Name string `xorm:"UNIQUE" gorm:"UNIQUE"`
	// IsActived reports whether the login source is activated. The spelling is
	// kept as-is: it maps to the "is_actived" column that List filters on.
	IsActived bool `xorm:"NOT NULL DEFAULT false" gorm:"NOT NULL"`
	// IsDefault marks the default login source (cleared on others by ResetNonDefault).
	IsDefault bool `xorm:"DEFAULT false"`
	// Config is the type-specific configuration. It is not a column: BeforeSave
	// JSON-encodes it into RawConfig, and AfterFind sets it to a *LDAPConfig,
	// *SMTPConfig, *PAMConfig or *GitHubConfig according to Type.
	Config interface{} `xorm:"-" gorm:"-"`
	// RawConfig is the JSON encoding of Config, stored in the "cfg" column.
	RawConfig string `xorm:"TEXT cfg" gorm:"COLUMN:cfg;TYPE:TEXT"`
	// Created is derived from CreatedUnix by AfterFind; it is not persisted.
	Created time.Time `xorm:"-" gorm:"-" json:"-"`
	// CreatedUnix is the creation time in Unix seconds, set by BeforeCreate when zero.
	CreatedUnix int64
	// Updated is derived from UpdatedUnix by AfterFind; it is not persisted.
	Updated time.Time `xorm:"-" gorm:"-" json:"-"`
	// UpdatedUnix is the last update time in Unix seconds, set by the GORM hooks.
	UpdatedUnix int64
	// File is non-nil for file-backed login sources; Save writes the values
	// into it instead of the database.
	File loginSourceFileStore `xorm:"-" gorm:"-" json:"-"`
}
  54. // NOTE: This is a GORM save hook.
  55. func (s *LoginSource) BeforeSave(tx *gorm.DB) (err error) {
  56. if s.Config == nil {
  57. return nil
  58. }
  59. s.RawConfig, err = jsoniter.MarshalToString(s.Config)
  60. return err
  61. }
  62. // NOTE: This is a GORM create hook.
  63. func (s *LoginSource) BeforeCreate(tx *gorm.DB) error {
  64. if s.CreatedUnix == 0 {
  65. s.CreatedUnix = tx.NowFunc().Unix()
  66. s.UpdatedUnix = s.CreatedUnix
  67. }
  68. return nil
  69. }
  70. // NOTE: This is a GORM update hook.
  71. func (s *LoginSource) BeforeUpdate(tx *gorm.DB) error {
  72. s.UpdatedUnix = tx.NowFunc().Unix()
  73. return nil
  74. }
  75. // NOTE: This is a GORM query hook.
  76. func (s *LoginSource) AfterFind(tx *gorm.DB) error {
  77. s.Created = time.Unix(s.CreatedUnix, 0).Local()
  78. s.Updated = time.Unix(s.UpdatedUnix, 0).Local()
  79. switch s.Type {
  80. case LoginLDAP, LoginDLDAP:
  81. s.Config = new(LDAPConfig)
  82. case LoginSMTP:
  83. s.Config = new(SMTPConfig)
  84. case LoginPAM:
  85. s.Config = new(PAMConfig)
  86. case LoginGitHub:
  87. s.Config = new(GitHubConfig)
  88. default:
  89. return fmt.Errorf("unrecognized login source type: %v", s.Type)
  90. }
  91. return jsoniter.UnmarshalFromString(s.RawConfig, s.Config)
  92. }
  93. func (s *LoginSource) TypeName() string {
  94. return LoginNames[s.Type]
  95. }
  96. func (s *LoginSource) IsLDAP() bool {
  97. return s.Type == LoginLDAP
  98. }
  99. func (s *LoginSource) IsDLDAP() bool {
  100. return s.Type == LoginDLDAP
  101. }
  102. func (s *LoginSource) IsSMTP() bool {
  103. return s.Type == LoginSMTP
  104. }
  105. func (s *LoginSource) IsPAM() bool {
  106. return s.Type == LoginPAM
  107. }
  108. func (s *LoginSource) IsGitHub() bool {
  109. return s.Type == LoginGitHub
  110. }
  111. func (s *LoginSource) HasTLS() bool {
  112. return ((s.IsLDAP() || s.IsDLDAP()) &&
  113. s.LDAP().SecurityProtocol > ldap.SecurityProtocolUnencrypted) ||
  114. s.IsSMTP()
  115. }
  116. func (s *LoginSource) UseTLS() bool {
  117. switch s.Type {
  118. case LoginLDAP, LoginDLDAP:
  119. return s.LDAP().SecurityProtocol != ldap.SecurityProtocolUnencrypted
  120. case LoginSMTP:
  121. return s.SMTP().TLS
  122. }
  123. return false
  124. }
  125. func (s *LoginSource) SkipVerify() bool {
  126. switch s.Type {
  127. case LoginLDAP, LoginDLDAP:
  128. return s.LDAP().SkipVerify
  129. case LoginSMTP:
  130. return s.SMTP().SkipVerify
  131. }
  132. return false
  133. }
  134. func (s *LoginSource) LDAP() *LDAPConfig {
  135. return s.Config.(*LDAPConfig)
  136. }
  137. func (s *LoginSource) SMTP() *SMTPConfig {
  138. return s.Config.(*SMTPConfig)
  139. }
  140. func (s *LoginSource) PAM() *PAMConfig {
  141. return s.Config.(*PAMConfig)
  142. }
  143. func (s *LoginSource) GitHub() *GitHubConfig {
  144. return s.Config.(*GitHubConfig)
  145. }
  146. var _ LoginSourcesStore = (*loginSources)(nil)
  147. type loginSources struct {
  148. *gorm.DB
  149. files loginSourceFilesStore
  150. }
// CreateLoginSourceOpts contains the options for LoginSourcesStore.Create.
type CreateLoginSourceOpts struct {
	// Type is stored as LoginSource.Type.
	Type LoginType
	// Name is stored as LoginSource.Name; Create rejects a name already in the database.
	Name string
	// Activated is stored as LoginSource.IsActived.
	Activated bool
	// Default is stored as LoginSource.IsDefault.
	Default bool
	// Config is stored as LoginSource.Config and JSON-encoded on save.
	Config interface{}
}
  158. type ErrLoginSourceAlreadyExist struct {
  159. args errutil.Args
  160. }
  161. func IsErrLoginSourceAlreadyExist(err error) bool {
  162. _, ok := err.(ErrLoginSourceAlreadyExist)
  163. return ok
  164. }
  165. func (err ErrLoginSourceAlreadyExist) Error() string {
  166. return fmt.Sprintf("login source already exists: %v", err.args)
  167. }
  168. func (db *loginSources) Create(opts CreateLoginSourceOpts) (*LoginSource, error) {
  169. err := db.Where("name = ?", opts.Name).First(new(LoginSource)).Error
  170. if err == nil {
  171. return nil, ErrLoginSourceAlreadyExist{args: errutil.Args{"name": opts.Name}}
  172. } else if err != gorm.ErrRecordNotFound {
  173. return nil, err
  174. }
  175. source := &LoginSource{
  176. Type: opts.Type,
  177. Name: opts.Name,
  178. IsActived: opts.Activated,
  179. IsDefault: opts.Default,
  180. Config: opts.Config,
  181. }
  182. return source, db.DB.Create(source).Error
  183. }
  184. func (db *loginSources) Count() int64 {
  185. var count int64
  186. db.Model(new(LoginSource)).Count(&count)
  187. return count + int64(db.files.Len())
  188. }
  189. type ErrLoginSourceInUse struct {
  190. args errutil.Args
  191. }
  192. func IsErrLoginSourceInUse(err error) bool {
  193. _, ok := err.(ErrLoginSourceInUse)
  194. return ok
  195. }
  196. func (err ErrLoginSourceInUse) Error() string {
  197. return fmt.Sprintf("login source is still used by some users: %v", err.args)
  198. }
  199. func (db *loginSources) DeleteByID(id int64) error {
  200. var count int64
  201. err := db.Model(new(User)).Where("login_source = ?", id).Count(&count).Error
  202. if err != nil {
  203. return err
  204. } else if count > 0 {
  205. return ErrLoginSourceInUse{args: errutil.Args{"id": id}}
  206. }
  207. return db.Where("id = ?", id).Delete(new(LoginSource)).Error
  208. }
  209. func (db *loginSources) GetByID(id int64) (*LoginSource, error) {
  210. source := new(LoginSource)
  211. err := db.Where("id = ?", id).First(source).Error
  212. if err != nil {
  213. if err == gorm.ErrRecordNotFound {
  214. return db.files.GetByID(id)
  215. }
  216. return nil, err
  217. }
  218. return source, nil
  219. }
// ListLoginSourceOpts contains the filtering options for LoginSourcesStore.List.
type ListLoginSourceOpts struct {
	// Whether to only include activated login sources.
	OnlyActivated bool
}
  224. func (db *loginSources) List(opts ListLoginSourceOpts) ([]*LoginSource, error) {
  225. var sources []*LoginSource
  226. query := db.Order("id ASC")
  227. if opts.OnlyActivated {
  228. query = query.Where("is_actived = ?", true)
  229. }
  230. err := query.Find(&sources).Error
  231. if err != nil {
  232. return nil, err
  233. }
  234. return append(sources, db.files.List(opts)...), nil
  235. }
  236. func (db *loginSources) ResetNonDefault(dflt *LoginSource) error {
  237. err := db.Model(new(LoginSource)).Where("id != ?", dflt.ID).Updates(map[string]interface{}{"is_default": false}).Error
  238. if err != nil {
  239. return err
  240. }
  241. for _, source := range db.files.List(ListLoginSourceOpts{}) {
  242. if source.File != nil && source.ID != dflt.ID {
  243. source.File.SetGeneral("is_default", "false")
  244. if err = source.File.Save(); err != nil {
  245. return errors.Wrap(err, "save file")
  246. }
  247. }
  248. }
  249. db.files.Update(dflt)
  250. return nil
  251. }
  252. func (db *loginSources) Save(source *LoginSource) error {
  253. if source.File == nil {
  254. return db.DB.Save(source).Error
  255. }
  256. source.File.SetGeneral("name", source.Name)
  257. source.File.SetGeneral("is_activated", strconv.FormatBool(source.IsActived))
  258. source.File.SetGeneral("is_default", strconv.FormatBool(source.IsDefault))
  259. if err := source.File.SetConfig(source.Config); err != nil {
  260. return errors.Wrap(err, "set config")
  261. } else if err = source.File.Save(); err != nil {
  262. return errors.Wrap(err, "save file")
  263. }
  264. return nil
  265. }