oauth2.go 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. // Copyright 2014 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package oauth2 contains Martini handlers to provide
  15. // user login via an OAuth 2.0 backend.
  16. package oauth2
  17. import (
  18. "encoding/json"
  19. "fmt"
  20. "net/http"
  21. "net/url"
  22. "strings"
  23. "time"
  24. "code.google.com/p/goauth2/oauth"
  25. "github.com/go-martini/martini"
  26. "github.com/martini-contrib/sessions"
  27. )
  const (
  	// codeRedirect is the HTTP status used for every redirect this package
  	// issues (302 Found).
  	codeRedirect = http.StatusFound
  	// keyToken is the session key under which the JSON-encoded token is
  	// stored.
  	keyToken = "oauth2_token"
  	// keyNextPage is the query parameter naming the page to return to after
  	// login or logout.
  	keyNextPage = "next"
  )
  var (
  	// PathLogin is the path that starts an OAuth 2.0 login.
  	PathLogin = "/login"
  	// PathLogout is the path that removes the stored token from the session.
  	PathLogout = "/logout"
  	// PathCallback is the path the OAuth 2.0 backend redirects back to; the
  	// authorization code received there is exchanged for a token.
  	PathCallback = "/oauth2callback"
  	// PathError is where users are sent when the callback cannot obtain a
  	// token.
  	PathError = "/oauth2error"
  )
  // Options describes an OAuth 2.0 backend. Google, Github and Facebook
  // supply AuthUrl and TokenUrl themselves, ignoring any values set here;
  // NewOAuth2Provider uses every field as given.
  type Options struct {
  	// ClientId and ClientSecret are the application credentials.
  	ClientId, ClientSecret string
  	// RedirectURL is passed to the backend as the redirect URL; presumably
  	// it should point at PathCallback on this server — confirm per provider.
  	RedirectURL string
  	// Scopes are joined with single spaces into the requested scope.
  	Scopes []string
  	// AuthUrl and TokenUrl are the backend's authorization and token
  	// endpoints.
  	AuthUrl, TokenUrl string
  }
  // Tokens gives handlers access to the user's OAuth 2.0 access and refresh
  // tokens.
  type Tokens interface {
  	// Access returns the access token.
  	Access() string
  	// Refresh returns the refresh token; "" means none was issued.
  	Refresh() string
  	// IsExpired reports whether the access token has expired.
  	IsExpired() bool
  	// ExpiryTime returns when the access token expires.
  	ExpiryTime() time.Time
  	// ExtraData returns additional string fields kept with the token
  	// (presumably extra values from the token response — confirm).
  	ExtraData() map[string]string
  }
  62. type token struct {
  63. oauth.Token
  64. }
  65. func (t *token) ExtraData() map[string]string {
  66. return t.Extra
  67. }
  68. // Returns the access token.
  69. func (t *token) Access() string {
  70. return t.AccessToken
  71. }
  72. // Returns the refresh token.
  73. func (t *token) Refresh() string {
  74. return t.RefreshToken
  75. }
  76. // Returns whether the access token is
  77. // expired or not.
  78. func (t *token) IsExpired() bool {
  79. if t == nil {
  80. return true
  81. }
  82. return t.Expired()
  83. }
  84. // Returns the expiry time of the user's
  85. // access token.
  86. func (t *token) ExpiryTime() time.Time {
  87. return t.Expiry
  88. }
  89. // Formats tokens into string.
  90. func (t *token) String() string {
  91. return fmt.Sprintf("tokens: %v", t)
  92. }
  93. // Returns a new Google OAuth 2.0 backend endpoint.
  94. func Google(opts *Options) martini.Handler {
  95. opts.AuthUrl = "https://accounts.google.com/o/oauth2/auth"
  96. opts.TokenUrl = "https://accounts.google.com/o/oauth2/token"
  97. return NewOAuth2Provider(opts)
  98. }
  99. // Returns a new Github OAuth 2.0 backend endpoint.
  100. func Github(opts *Options) martini.Handler {
  101. opts.AuthUrl = "https://github.com/login/oauth/authorize"
  102. opts.TokenUrl = "https://github.com/login/oauth/access_token"
  103. return NewOAuth2Provider(opts)
  104. }
  105. func Facebook(opts *Options) martini.Handler {
  106. opts.AuthUrl = "https://www.facebook.com/dialog/oauth"
  107. opts.TokenUrl = "https://graph.facebook.com/oauth/access_token"
  108. return NewOAuth2Provider(opts)
  109. }
  110. // Returns a generic OAuth 2.0 backend endpoint.
  111. func NewOAuth2Provider(opts *Options) martini.Handler {
  112. config := &oauth.Config{
  113. ClientId: opts.ClientId,
  114. ClientSecret: opts.ClientSecret,
  115. RedirectURL: opts.RedirectURL,
  116. Scope: strings.Join(opts.Scopes, " "),
  117. AuthURL: opts.AuthUrl,
  118. TokenURL: opts.TokenUrl,
  119. }
  120. transport := &oauth.Transport{
  121. Config: config,
  122. Transport: http.DefaultTransport,
  123. }
  124. return func(s sessions.Session, c martini.Context, w http.ResponseWriter, r *http.Request) {
  125. if r.Method == "GET" {
  126. switch r.URL.Path {
  127. case PathLogin:
  128. login(transport, s, w, r)
  129. case PathLogout:
  130. logout(transport, s, w, r)
  131. case PathCallback:
  132. handleOAuth2Callback(transport, s, w, r)
  133. }
  134. }
  135. tk := unmarshallToken(s)
  136. if tk != nil {
  137. // check if the access token is expired
  138. if tk.IsExpired() && tk.Refresh() == "" {
  139. s.Delete(keyToken)
  140. tk = nil
  141. }
  142. }
  143. // Inject tokens.
  144. c.MapTo(tk, (*Tokens)(nil))
  145. }
  146. }
  147. // Handler that redirects user to the login page
  148. // if user is not logged in.
  149. // Sample usage:
  150. // m.Get("/login-required", oauth2.LoginRequired, func() ... {})
  151. var LoginRequired martini.Handler = func() martini.Handler {
  152. return func(s sessions.Session, c martini.Context, w http.ResponseWriter, r *http.Request) {
  153. token := unmarshallToken(s)
  154. if token == nil || token.IsExpired() {
  155. next := url.QueryEscape(r.URL.RequestURI())
  156. http.Redirect(w, r, PathLogin+"?next="+next, codeRedirect)
  157. }
  158. }
  159. }()
  160. func login(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) {
  161. next := extractPath(r.URL.Query().Get(keyNextPage))
  162. if s.Get(keyToken) == nil {
  163. // User is not logged in.
  164. http.Redirect(w, r, t.Config.AuthCodeURL(next), codeRedirect)
  165. return
  166. }
  167. // No need to login, redirect to the next page.
  168. http.Redirect(w, r, next, codeRedirect)
  169. }
  170. func logout(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) {
  171. next := extractPath(r.URL.Query().Get(keyNextPage))
  172. s.Delete(keyToken)
  173. http.Redirect(w, r, next, codeRedirect)
  174. }
  175. func handleOAuth2Callback(t *oauth.Transport, s sessions.Session, w http.ResponseWriter, r *http.Request) {
  176. next := extractPath(r.URL.Query().Get("state"))
  177. code := r.URL.Query().Get("code")
  178. tk, err := t.Exchange(code)
  179. if err != nil {
  180. // Pass the error message, or allow dev to provide its own
  181. // error handler.
  182. http.Redirect(w, r, PathError, codeRedirect)
  183. return
  184. }
  185. // Store the credentials in the session.
  186. val, _ := json.Marshal(tk)
  187. s.Set(keyToken, val)
  188. http.Redirect(w, r, next, codeRedirect)
  189. }
  190. func unmarshallToken(s sessions.Session) (t *token) {
  191. if s.Get(keyToken) == nil {
  192. return
  193. }
  194. data := s.Get(keyToken).([]byte)
  195. var tk oauth.Token
  196. json.Unmarshal(data, &tk)
  197. return &token{tk}
  198. }
  // extractPath reduces next to a same-origin redirect target.
  //
  // It keeps only the path and query of the parsed URL, dropping the scheme,
  // host, user info and fragment. The path is re-escaped. It returns "/" when
  // next does not parse, has no absolute path, or has a path a browser could
  // read as a different host ("//host" or "/\host"). This is intended to
  // prevent open redirects through the "next" and "state" parameters.
  func extractPath(next string) string {
  	u, err := url.Parse(next)
  	if err != nil {
  		return "/"
  	}
  	p := u.Path
  	if !strings.HasPrefix(p, "/") || strings.HasPrefix(p, "//") || strings.HasPrefix(p, `/\`) {
  		return "/"
  	}
  	// Rebuild through url.URL so the path is escaped for the Location header
  	// and the query captured by LoginRequired (via RequestURI) survives.
  	local := url.URL{Path: p, RawQuery: u.RawQuery}
  	return local.String()
  }