Browse Source

lfs: add unit test for middleware (#6070)

* Add unit test for `authenticate` middleware

* Add more cases

* Add tests for verifyOID and internalServerError

* Add tests for verifyHeader

* Add tests for authroize
ᴜɴᴋɴᴡᴏɴ 4 years ago
parent
commit
ee0ea2c5fc

+ 121 - 0
internal/db/mocks.go

@@ -0,0 +1,121 @@
+// Copyright 2020 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package db
+
+import (
+	"testing"
+)
+
+// NOTE: Mocks are sorted in alphabetical order.
+
+var _ AccessTokensStore = (*MockAccessTokensStore)(nil)
+
+type MockAccessTokensStore struct {
+	MockGetBySHA func(sha string) (*AccessToken, error)
+	MockSave     func(t *AccessToken) error
+}
+
+func (m *MockAccessTokensStore) GetBySHA(sha string) (*AccessToken, error) {
+	return m.MockGetBySHA(sha)
+}
+
+func (m *MockAccessTokensStore) Save(t *AccessToken) error {
+	return m.MockSave(t)
+}
+
+func SetMockAccessTokensStore(t *testing.T, mock AccessTokensStore) {
+	before := AccessTokens
+	AccessTokens = mock
+	t.Cleanup(func() {
+		AccessTokens = before
+	})
+}
+
+var _ PermsStore = (*MockPermsStore)(nil)
+
+type MockPermsStore struct {
+	MockAccessMode func(userID int64, repo *Repository) AccessMode
+	MockAuthorize  func(userID int64, repo *Repository, desired AccessMode) bool
+}
+
+func (m *MockPermsStore) AccessMode(userID int64, repo *Repository) AccessMode {
+	return m.MockAccessMode(userID, repo)
+}
+
+func (m *MockPermsStore) Authorize(userID int64, repo *Repository, desired AccessMode) bool {
+	return m.MockAuthorize(userID, repo, desired)
+}
+
+func SetMockPermsStore(t *testing.T, mock PermsStore) {
+	before := Perms
+	Perms = mock
+	t.Cleanup(func() {
+		Perms = before
+	})
+}
+
+var _ ReposStore = (*MockReposStore)(nil)
+
+type MockReposStore struct {
+	MockGetByName func(ownerID int64, name string) (*Repository, error)
+}
+
+func (m *MockReposStore) GetByName(ownerID int64, name string) (*Repository, error) {
+	return m.MockGetByName(ownerID, name)
+}
+
+func SetMockReposStore(t *testing.T, mock ReposStore) {
+	before := Repos
+	Repos = mock
+	t.Cleanup(func() {
+		Repos = before
+	})
+}
+
+var _ TwoFactorsStore = (*MockTwoFactorsStore)(nil)
+
+type MockTwoFactorsStore struct {
+	MockIsUserEnabled func(userID int64) bool
+}
+
+func (m *MockTwoFactorsStore) IsUserEnabled(userID int64) bool {
+	return m.MockIsUserEnabled(userID)
+}
+
+func SetMockTwoFactorsStore(t *testing.T, mock TwoFactorsStore) {
+	before := TwoFactors
+	TwoFactors = mock
+	t.Cleanup(func() {
+		TwoFactors = before
+	})
+}
+
+var _ UsersStore = (*MockUsersStore)(nil)
+
+type MockUsersStore struct {
+	MockAuthenticate  func(username, password string, loginSourceID int64) (*User, error)
+	MockGetByID       func(id int64) (*User, error)
+	MockGetByUsername func(username string) (*User, error)
+}
+
+func (m *MockUsersStore) Authenticate(username, password string, loginSourceID int64) (*User, error) {
+	return m.MockAuthenticate(username, password, loginSourceID)
+}
+
+func (m *MockUsersStore) GetByID(id int64) (*User, error) {
+	return m.MockGetByID(id)
+}
+
+func (m *MockUsersStore) GetByUsername(username string) (*User, error) {
+	return m.MockGetByUsername(username)
+}
+
+func SetMockUsersStore(t *testing.T, mock UsersStore) {
+	before := Users
+	Users = mock
+	t.Cleanup(func() {
+		Users = before
+	})
+}

+ 0 - 1
internal/gitutil/mock.go → internal/gitutil/mocks.go

@@ -12,7 +12,6 @@ import (
 
 var _ ModuleStore = (*MockModuleStore)(nil)
 
-// MockModuleStore is a mock implementation of ModuleStore interface.
 type MockModuleStore struct {
 	repoAddRemote    func(repoPath, name, url string, opts ...git.AddRemoteOptions) error
 	repoDiffNameOnly func(repoPath, base, head string, opts ...git.DiffNameOnlyOptions) ([]string, error)

+ 5 - 14
internal/mock/locale.go → internal/mocks/locale.go

@@ -2,7 +2,7 @@
 // Use of this source code is governed by a MIT-style
 // license that can be found in the LICENSE file.
 
-package mock
+package mocks
 
 import (
 	"gopkg.in/macaron.v1"
@@ -10,24 +10,15 @@ import (
 
 var _ macaron.Locale = (*Locale)(nil)
 
-// Locale is a mock that implements macaron.Locale.
 type Locale struct {
-	lang string
-	tr   func(string, ...interface{}) string
-}
-
-// NewLocale creates a new mock for macaron.Locale.
-func NewLocale(lang string, tr func(string, ...interface{}) string) *Locale {
-	return &Locale{
-		lang: lang,
-		tr:   tr,
-	}
+	MockLang string
+	MockTr   func(string, ...interface{}) string
 }
 
 func (l *Locale) Language() string {
-	return l.lang
+	return l.MockLang
 }
 
 func (l *Locale) Tr(format string, args ...interface{}) string {
-	return l.tr(format, args...)
+	return l.MockTr(format, args...)
 }

+ 0 - 6
internal/route/lfs/batch.go

@@ -174,9 +174,3 @@ func responseJSON(w http.ResponseWriter, status int, v interface{}) {
 		return
 	}
 }
-
-func internalServerError(w http.ResponseWriter) {
-	responseJSON(w, http.StatusInternalServerError, responseError{
-		Message: "Internal server error",
-	})
-}

+ 21 - 12
internal/route/lfs/route.go

@@ -34,10 +34,11 @@ func RegisterRoutes(r *macaron.Router) {
 	}, authenticate())
 }
 
-// authenticate tries to authenticate user via HTTP Basic Auth.
+// authenticate tries to authenticate user via HTTP Basic Auth. It first tries to authenticate
+// as plain username and password, then use username as access token if previous step failed.
 func authenticate() macaron.Handler {
 	askCredentials := func(w http.ResponseWriter) {
-		w.Header().Set("LFS-Authenticate", `Basic realm="Git LFS"`)
+		w.Header().Set("Lfs-Authenticate", `Basic realm="Git LFS"`)
 		responseJSON(w, http.StatusUnauthorized, responseError{
 			Message: "Credentials needed",
 		})
@@ -52,13 +53,13 @@ func authenticate() macaron.Handler {
 
 		user, err := db.Users.Authenticate(username, password, -1)
 		if err != nil && !db.IsErrUserNotExist(err) {
-			c.Status(http.StatusInternalServerError)
+			internalServerError(c.Resp)
 			log.Error("Failed to authenticate user [name: %s]: %v", username, err)
 			return
 		}
 
 		if err == nil && user.IsEnabledTwoFactor() {
-			c.Error(http.StatusBadRequest, `Users with 2FA enabled are not allowed to authenticate via username and password.`)
+			c.Error(http.StatusBadRequest, "Users with 2FA enabled are not allowed to authenticate via username and password.")
 			return
 		}
 
@@ -69,7 +70,7 @@ func authenticate() macaron.Handler {
 				if db.IsErrAccessTokenNotExist(err) {
 					askCredentials(c.Resp)
 				} else {
-					c.Status(http.StatusInternalServerError)
+					internalServerError(c.Resp)
 					log.Error("Failed to get access token [sha: %s]: %v", username, err)
 				}
 				return
@@ -83,7 +84,7 @@ func authenticate() macaron.Handler {
 			if err != nil {
 				// Once we found the token, we're supposed to find its related user,
 				// thus any error is unexpected.
-				c.Status(http.StatusInternalServerError)
+				internalServerError(c.Resp)
 				log.Error("Failed to get user: %v", err)
 				return
 			}
@@ -97,7 +98,7 @@ func authenticate() macaron.Handler {
 
 // authorize tries to authorize the user to the context repository with given access mode.
 func authorize(mode db.AccessMode) macaron.Handler {
-	return func(c *macaron.Context, user *db.User) {
+	return func(c *macaron.Context, actor *db.User) {
 		username := c.Params(":username")
 		reponame := strings.TrimSuffix(c.Params(":reponame"), ".git")
 
@@ -106,7 +107,7 @@ func authorize(mode db.AccessMode) macaron.Handler {
 			if db.IsErrUserNotExist(err) {
 				c.Status(http.StatusNotFound)
 			} else {
-				c.Status(http.StatusInternalServerError)
+				internalServerError(c.Resp)
 				log.Error("Failed to get user [name: %s]: %v", username, err)
 			}
 			return
@@ -117,18 +118,18 @@ func authorize(mode db.AccessMode) macaron.Handler {
 			if db.IsErrRepoNotExist(err) {
 				c.Status(http.StatusNotFound)
 			} else {
-				c.Status(http.StatusInternalServerError)
+				internalServerError(c.Resp)
 				log.Error("Failed to get repository [owner_id: %d, name: %s]: %v", owner.ID, reponame, err)
 			}
 			return
 		}
 
-		if !db.Perms.Authorize(user.ID, repo, mode) {
+		if !db.Perms.Authorize(actor.ID, repo, mode) {
 			c.Status(http.StatusNotFound)
 			return
 		}
 
-		c.Map(owner)
+		c.Map(owner) // NOTE: Override actor
 		c.Map(repo)
 	}
 }
@@ -149,10 +150,18 @@ func verifyOID() macaron.Handler {
 	return func(c *macaron.Context) {
 		oid := lfsutil.OID(c.Params(":oid"))
 		if !lfsutil.ValidOID(oid) {
-			c.Error(http.StatusBadRequest, "Invalid oid")
+			responseJSON(c.Resp, http.StatusBadRequest, responseError{
+				Message: "Invalid oid",
+			})
 			return
 		}
 
 		c.Map(oid)
 	}
 }
+
+func internalServerError(w http.ResponseWriter) {
+	responseJSON(w, http.StatusInternalServerError, responseError{
+		Message: "Internal server error",
+	})
+}

+ 382 - 0
internal/route/lfs/route_test.go

@@ -0,0 +1,382 @@
+// Copyright 2020 The Gogs Authors. All rights reserved.
+// Use of this source code is governed by a MIT-style
+// license that can be found in the LICENSE file.
+
+package lfs
+
+import (
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/pkg/errors"
+	"github.com/stretchr/testify/assert"
+	"gopkg.in/macaron.v1"
+
+	"gogs.io/gogs/internal/db"
+	"gogs.io/gogs/internal/lfsutil"
+)
+
+func Test_authenticate(t *testing.T) {
+	m := macaron.New()
+	m.Use(macaron.Renderer())
+	m.Get("/", authenticate(), func(w http.ResponseWriter, user *db.User) {
+		fmt.Fprintf(w, "ID: %d, Name: %s", user.ID, user.Name)
+	})
+
+	tests := []struct {
+		name                  string
+		header                http.Header
+		mockUsersStore        *db.MockUsersStore
+		mockTwoFactorsStore   *db.MockTwoFactorsStore
+		mockAccessTokensStore *db.MockAccessTokensStore
+		expStatusCode         int
+		expHeader             http.Header
+		expBody               string
+	}{
+		{
+			name:          "no authorization",
+			expStatusCode: http.StatusUnauthorized,
+			expHeader: http.Header{
+				"Lfs-Authenticate": []string{`Basic realm="Git LFS"`},
+				"Content-Type":     []string{"application/vnd.git-lfs+json"},
+			},
+			expBody: `{"message":"Credentials needed"}` + "\n",
+		},
+		{
+			name: "user has 2FA enabled",
+			header: http.Header{
+				"Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
+			},
+			mockUsersStore: &db.MockUsersStore{
+				MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
+					return &db.User{}, nil
+				},
+			},
+			mockTwoFactorsStore: &db.MockTwoFactorsStore{
+				MockIsUserEnabled: func(userID int64) bool {
+					return true
+				},
+			},
+			expStatusCode: http.StatusBadRequest,
+			expHeader:     http.Header{},
+			expBody:       "Users with 2FA enabled are not allowed to authenticate via username and password.",
+		},
+		{
+			name: "both user and access token do not exist",
+			header: http.Header{
+				"Authorization": []string{"Basic dXNlcm5hbWU="},
+			},
+			mockUsersStore: &db.MockUsersStore{
+				MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
+					return nil, db.ErrUserNotExist{}
+				},
+			},
+			mockAccessTokensStore: &db.MockAccessTokensStore{
+				MockGetBySHA: func(sha string) (*db.AccessToken, error) {
+					return nil, db.ErrAccessTokenNotExist{}
+				},
+			},
+			expStatusCode: http.StatusUnauthorized,
+			expHeader: http.Header{
+				"Lfs-Authenticate": []string{`Basic realm="Git LFS"`},
+				"Content-Type":     []string{"application/vnd.git-lfs+json"},
+			},
+			expBody: `{"message":"Credentials needed"}` + "\n",
+		},
+
+		{
+			name: "authenticated by username and password",
+			header: http.Header{
+				"Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
+			},
+			mockUsersStore: &db.MockUsersStore{
+				MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
+					return &db.User{ID: 1, Name: "unknwon"}, nil
+				},
+			},
+			mockTwoFactorsStore: &db.MockTwoFactorsStore{
+				MockIsUserEnabled: func(userID int64) bool {
+					return false
+				},
+			},
+			expStatusCode: http.StatusOK,
+			expHeader:     http.Header{},
+			expBody:       "ID: 1, Name: unknwon",
+		},
+		{
+			name: "authenticate by access token",
+			header: http.Header{
+				"Authorization": []string{"Basic dXNlcm5hbWU="},
+			},
+			mockUsersStore: &db.MockUsersStore{
+				MockAuthenticate: func(username, password string, loginSourceID int64) (*db.User, error) {
+					return nil, db.ErrUserNotExist{}
+				},
+				MockGetByID: func(id int64) (*db.User, error) {
+					return &db.User{ID: 1, Name: "unknwon"}, nil
+				},
+			},
+			mockAccessTokensStore: &db.MockAccessTokensStore{
+				MockGetBySHA: func(sha string) (*db.AccessToken, error) {
+					return &db.AccessToken{}, nil
+				},
+				MockSave: func(t *db.AccessToken) error {
+					if t.Updated.IsZero() {
+						return errors.New("Updated is zero")
+					}
+					return nil
+				},
+			},
+			expStatusCode: http.StatusOK,
+			expHeader:     http.Header{},
+			expBody:       "ID: 1, Name: unknwon",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			db.SetMockUsersStore(t, test.mockUsersStore)
+			db.SetMockTwoFactorsStore(t, test.mockTwoFactorsStore)
+			db.SetMockAccessTokensStore(t, test.mockAccessTokensStore)
+
+			r, err := http.NewRequest("GET", "/", nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			r.Header = test.header
+
+			rr := httptest.NewRecorder()
+			m.ServeHTTP(rr, r)
+
+			resp := rr.Result()
+			assert.Equal(t, test.expStatusCode, resp.StatusCode)
+			assert.Equal(t, test.expHeader, resp.Header)
+
+			body, err := ioutil.ReadAll(resp.Body)
+			if err != nil {
+				t.Fatal(err)
+			}
+			assert.Equal(t, test.expBody, string(body))
+		})
+	}
+}
+
+func Test_authorize(t *testing.T) {
+	tests := []struct {
+		name           string
+		authroize      macaron.Handler
+		mockUsersStore *db.MockUsersStore
+		mockReposStore *db.MockReposStore
+		mockPermsStore *db.MockPermsStore
+		expStatusCode  int
+		expBody        string
+	}{
+		{
+			name:      "user does not exist",
+			authroize: authorize(db.AccessModeNone),
+			mockUsersStore: &db.MockUsersStore{
+				MockGetByUsername: func(username string) (*db.User, error) {
+					return nil, db.ErrUserNotExist{}
+				},
+			},
+			expStatusCode: http.StatusNotFound,
+		},
+		{
+			name:      "repository does not exist",
+			authroize: authorize(db.AccessModeNone),
+			mockUsersStore: &db.MockUsersStore{
+				MockGetByUsername: func(username string) (*db.User, error) {
+					return &db.User{Name: username}, nil
+				},
+			},
+			mockReposStore: &db.MockReposStore{
+				MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
+					return nil, db.ErrRepoNotExist{}
+				},
+			},
+			expStatusCode: http.StatusNotFound,
+		},
+		{
+			name:      "actor is not authorized",
+			authroize: authorize(db.AccessModeWrite),
+			mockUsersStore: &db.MockUsersStore{
+				MockGetByUsername: func(username string) (*db.User, error) {
+					return &db.User{Name: username}, nil
+				},
+			},
+			mockReposStore: &db.MockReposStore{
+				MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
+					return &db.Repository{Name: name}, nil
+				},
+			},
+			mockPermsStore: &db.MockPermsStore{
+				MockAuthorize: func(userID int64, repo *db.Repository, desired db.AccessMode) bool {
+					return desired <= db.AccessModeRead
+				},
+			},
+			expStatusCode: http.StatusNotFound,
+		},
+
+		{
+			name:      "actor is authorized",
+			authroize: authorize(db.AccessModeRead),
+			mockUsersStore: &db.MockUsersStore{
+				MockGetByUsername: func(username string) (*db.User, error) {
+					return &db.User{Name: username}, nil
+				},
+			},
+			mockReposStore: &db.MockReposStore{
+				MockGetByName: func(ownerID int64, name string) (*db.Repository, error) {
+					return &db.Repository{Name: name}, nil
+				},
+			},
+			mockPermsStore: &db.MockPermsStore{
+				MockAuthorize: func(userID int64, repo *db.Repository, desired db.AccessMode) bool {
+					return desired <= db.AccessModeRead
+				},
+			},
+			expStatusCode: http.StatusOK,
+			expBody:       "owner.Name: owner, repo.Name: repo",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			db.SetMockUsersStore(t, test.mockUsersStore)
+			db.SetMockReposStore(t, test.mockReposStore)
+			db.SetMockPermsStore(t, test.mockPermsStore)
+
+			m := macaron.New()
+			m.Use(macaron.Renderer())
+			m.Use(func(c *macaron.Context) {
+				c.Map(&db.User{})
+			})
+			m.Get("/:username/:reponame", test.authroize, func(w http.ResponseWriter, owner *db.User, repo *db.Repository) {
+				fmt.Fprintf(w, "owner.Name: %s, repo.Name: %s", owner.Name, repo.Name)
+			})
+
+			r, err := http.NewRequest("GET", "/owner/repo", nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			rr := httptest.NewRecorder()
+			m.ServeHTTP(rr, r)
+
+			resp := rr.Result()
+			assert.Equal(t, test.expStatusCode, resp.StatusCode)
+
+			body, err := ioutil.ReadAll(resp.Body)
+			if err != nil {
+				t.Fatal(err)
+			}
+			assert.Equal(t, test.expBody, string(body))
+		})
+	}
+}
+
+func Test_verifyHeader(t *testing.T) {
+	tests := []struct {
+		name          string
+		verifyHeader  macaron.Handler
+		header        http.Header
+		expStatusCode int
+	}{
+		{
+			name:          "header not found",
+			verifyHeader:  verifyHeader("Accept", contentType, http.StatusNotAcceptable),
+			expStatusCode: http.StatusNotAcceptable,
+		},
+
+		{
+			name:         "header found",
+			verifyHeader: verifyHeader("Accept", "application/vnd.git-lfs+json", http.StatusNotAcceptable),
+			header: http.Header{
+				"Accept": []string{"application/vnd.git-lfs+json; charset=utf-8"},
+			},
+			expStatusCode: http.StatusOK,
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			m := macaron.New()
+			m.Use(macaron.Renderer())
+			m.Get("/", test.verifyHeader)
+
+			r, err := http.NewRequest("GET", "/", nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+			r.Header = test.header
+
+			rr := httptest.NewRecorder()
+			m.ServeHTTP(rr, r)
+
+			resp := rr.Result()
+			assert.Equal(t, test.expStatusCode, resp.StatusCode)
+		})
+	}
+}
+
+func Test_verifyOID(t *testing.T) {
+	m := macaron.New()
+	m.Get("/:oid", verifyOID(), func(w http.ResponseWriter, oid lfsutil.OID) {
+		fmt.Fprintf(w, "oid: %s", oid)
+	})
+
+	tests := []struct {
+		name          string
+		url           string
+		expStatusCode int
+		expBody       string
+	}{
+		{
+			name:          "bad oid",
+			url:           "/bad_oid",
+			expStatusCode: http.StatusBadRequest,
+			expBody:       `{"message":"Invalid oid"}` + "\n",
+		},
+
+		{
+			name:          "good oid",
+			url:           "/ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f",
+			expStatusCode: http.StatusOK,
+			expBody:       "oid: ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f",
+		},
+	}
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			r, err := http.NewRequest("GET", test.url, nil)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			rr := httptest.NewRecorder()
+			m.ServeHTTP(rr, r)
+
+			resp := rr.Result()
+			assert.Equal(t, test.expStatusCode, resp.StatusCode)
+
+			body, err := ioutil.ReadAll(resp.Body)
+			if err != nil {
+				t.Fatal(err)
+			}
+			assert.Equal(t, test.expBody, string(body))
+		})
+	}
+}
+
+func Test_internalServerError(t *testing.T) {
+	rr := httptest.NewRecorder()
+	internalServerError(rr)
+
+	resp := rr.Result()
+	assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
+
+	body, err := ioutil.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	assert.Equal(t, `{"message":"Internal server error"}`+"\n", string(body))
+}

+ 7 - 4
internal/route/repo/webhook_test.go

@@ -10,7 +10,7 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"gogs.io/gogs/internal/db"
-	"gogs.io/gogs/internal/mock"
+	"gogs.io/gogs/internal/mocks"
 )
 
 func Test_isLocalHostname(t *testing.T) {
@@ -33,9 +33,12 @@ func Test_isLocalHostname(t *testing.T) {
 }
 
 func Test_validateWebhook(t *testing.T) {
-	l := mock.NewLocale("en", func(s string, _ ...interface{}) string {
-		return s
-	})
+	l := &mocks.Locale{
+		MockLang: "en",
+		MockTr: func(s string, _ ...interface{}) string {
+			return s
+		},
+	}
 
 	tests := []struct {
 		name     string