auth
auth/oauth2.go
Start supporting pluggable grant_types. Define GrantType as a way to bundle the information used to validate token requests based on their grant_type parameter. Move our validation of the authorization_code grant_type out of GetTokenHandler and into its own function. Define RegisterGrantType as a way to register new GrantType bundles and associate them with the string passed as grant_type. This lets other packages call RegisterGrantType in their init() functions and plug in new grant types without forking this code. Use RegisterGrantType to register our authorization_code grant type.
| paddy@51 | 1 package auth |
| paddy@51 | 2 |
| paddy@51 | 3 import ( |
| paddy@82 | 4 "crypto/sha256" |
| paddy@79 | 5 "encoding/hex" |
| paddy@69 | 6 "encoding/json" |
| paddy@69 | 7 "errors" |
| paddy@61 | 8 "html/template" |
| paddy@77 | 9 "log" |
| paddy@51 | 10 "net/http" |
| paddy@60 | 11 "net/url" |
| paddy@84 | 12 "sync" |
| paddy@60 | 13 "time" |
| paddy@56 | 14 |
| paddy@69 | 15 "code.secondbit.org/pass" |
| paddy@56 | 16 "code.secondbit.org/uuid" |
| paddy@82 | 17 |
| paddy@82 | 18 "github.com/gorilla/mux" |
| paddy@51 | 19 ) |
| paddy@51 | 20 |
const (
	// authCookieName is the name of the cookie whose value checkCookie looks up
	// as a session ID.
	authCookieName = "auth"
	// getGrantTemplateName is the template GetGrantHandler renders.
	getGrantTemplateName = "get_grant"
	// defaultGrantExpiration is the lifetime, in seconds, assigned to newly
	// issued authorization grants: ten minutes.
	defaultGrantExpiration = 600
)
| paddy@51 | 26 |
| paddy@69 | 27 var ( |
| paddy@69 | 28 // ErrNoAuth is returned when an Authorization header is not present or is empty. |
| paddy@69 | 29 ErrNoAuth = errors.New("no authorization header supplied") |
| paddy@69 | 30 // ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format. |
| paddy@69 | 31 ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format") |
| paddy@69 | 32 // ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values. |
| paddy@69 | 33 ErrIncorrectAuth = errors.New("invalid authentication") |
| paddy@69 | 34 // ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used. |
| paddy@69 | 35 ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme") |
| paddy@69 | 36 // ErrNoSession is returned when no session ID is passed with a request. |
| paddy@69 | 37 ErrNoSession = errors.New("no session ID found") |
| paddy@84 | 38 |
| paddy@84 | 39 grantTypesMap = grantTypes{types: map[string]GrantType{}} |
| paddy@69 | 40 ) |
| paddy@69 | 41 |
| paddy@84 | 42 type grantTypes struct { |
| paddy@84 | 43 types map[string]GrantType |
| paddy@84 | 44 sync.RWMutex |
| paddy@84 | 45 } |
| paddy@84 | 46 |
| paddy@84 | 47 // GrantType defines a set of functions and metadata around a specific authorization grant strategy. |
| paddy@84 | 48 // |
| paddy@84 | 49 // The Validate function will be called when requests are made that match the GrantType, and should write any |
| paddy@84 | 50 // errors to the ResponseWriter. It is responsible for determining if the grant is valid and a token should be issued. |
| paddy@84 | 51 // It must return the scope the grant was for and the ID of the Profile that issued the grant, as well as if the grant |
| paddy@84 | 52 // is valid or not. It must not be nil. |
| paddy@84 | 53 // |
| paddy@84 | 54 // The Invalidate function will be called when the grant has successfully generated a token and the token has successfully |
| paddy@84 | 55 // been conveyed to the user. The Invalidate function is always called asynchronously, outside the request. It should take |
| paddy@84 | 56 // care of marking the grant as used, if the GrantType requires grants to be one-time only grants. The Invalidate function |
| paddy@84 | 57 // can be nil. |
| paddy@84 | 58 // |
| paddy@84 | 59 // IssuesRefresh determines whether the GrantType should yield a refresh token as well as an access token. If true, the client |
| paddy@84 | 60 // will be issued a refresh token. |
| paddy@84 | 61 type GrantType struct { |
| paddy@84 | 62 Validate func(w http.ResponseWriter, r *http.Request, context Context) (scope string, profileID uuid.ID, valid bool) |
| paddy@84 | 63 Invalidate func(r *http.Request, context Context) bool |
| paddy@84 | 64 IssuesRefresh bool |
| paddy@84 | 65 } |
| paddy@84 | 66 |
type (
	// tokenResponse is the JSON body sent to a client that successfully
	// obtained an access token. Fields are declared in the order their keys
	// should appear in the encoded JSON.
	tokenResponse struct {
		AccessToken  string `json:"access_token"`
		TokenType    string `json:"token_type,omitempty"`
		ExpiresIn    int32  `json:"expires_in,omitempty"`
		RefreshToken string `json:"refresh_token,omitempty"`
	}

	// errorResponse is the JSON body sent when a token request fails; only
	// Error is always present in the encoded output.
	errorResponse struct {
		Error       string `json:"error"`
		Description string `json:"error_description,omitempty"`
		URI         string `json:"error_uri,omitempty"`
	}
)
| paddy@82 | 79 |
| paddy@84 | 80 // RegisterGrantType associates a string with a GrantType. When the string is used as the value for "grant_type" when obtaining |
| paddy@84 | 81 // an access token, the associated GrantType's properties will be used. |
| paddy@84 | 82 // |
| paddy@84 | 83 // RegisterGrantType should be called in the `init()` function of packages, much like database/sql registers drivers. It will panic |
| paddy@84 | 84 // if a GrantType tries to register under a string that already has a GrantType registered for it. |
| paddy@84 | 85 func RegisterGrantType(name string, g GrantType) { |
| paddy@84 | 86 grantTypesMap.Lock() |
| paddy@84 | 87 defer grantTypesMap.Unlock() |
| paddy@84 | 88 if _, ok := grantTypesMap.types[name]; ok { |
| paddy@84 | 89 panic("Duplicate registration of grant_type " + name) |
| paddy@84 | 90 } |
| paddy@84 | 91 grantTypesMap.types[name] = g |
| paddy@84 | 92 } |
| paddy@84 | 93 |
| paddy@84 | 94 func findGrantType(name string) (GrantType, bool) { |
| paddy@84 | 95 grantTypesMap.RLock() |
| paddy@84 | 96 defer grantTypesMap.RUnlock() |
| paddy@84 | 97 t, ok := grantTypesMap.types[name] |
| paddy@84 | 98 return t, ok |
| paddy@84 | 99 } |
| paddy@84 | 100 |
| paddy@82 | 101 func renderJSONError(enc *json.Encoder, errorType string) { |
| paddy@82 | 102 err := enc.Encode(errorResponse{ |
| paddy@82 | 103 Error: errorType, |
| paddy@82 | 104 }) |
| paddy@82 | 105 if err != nil { |
| paddy@82 | 106 // TODO(paddy): log this or something |
| paddy@69 | 107 } |
| paddy@69 | 108 } |
| paddy@69 | 109 |
| paddy@69 | 110 func checkCookie(r *http.Request, context Context) (Session, error) { |
| paddy@69 | 111 cookie, err := r.Cookie(authCookieName) |
| paddy@77 | 112 if err == http.ErrNoCookie { |
| paddy@77 | 113 return Session{}, ErrNoSession |
| paddy@77 | 114 } else if err != nil { |
| paddy@77 | 115 log.Println(err) |
| paddy@69 | 116 return Session{}, err |
| paddy@69 | 117 } |
| paddy@69 | 118 sess, err := context.GetSession(cookie.Value) |
| paddy@69 | 119 if err == ErrSessionNotFound { |
| paddy@69 | 120 return Session{}, ErrInvalidSession |
| paddy@69 | 121 } else if err != nil { |
| paddy@69 | 122 return Session{}, err |
| paddy@69 | 123 } |
| paddy@69 | 124 if !sess.Active { |
| paddy@69 | 125 return Session{}, ErrInvalidSession |
| paddy@69 | 126 } |
| paddy@69 | 127 return sess, nil |
| paddy@69 | 128 } |
| paddy@69 | 129 |
| paddy@77 | 130 func buildLoginRedirect(r *http.Request, context Context) string { |
| paddy@77 | 131 if context.loginURI == nil { |
| paddy@77 | 132 return "" |
| paddy@77 | 133 } |
| paddy@77 | 134 uri := *context.loginURI |
| paddy@77 | 135 q := uri.Query() |
| paddy@78 | 136 q.Set("from", r.URL.String()) |
| paddy@77 | 137 uri.RawQuery = q.Encode() |
| paddy@77 | 138 return uri.String() |
| paddy@77 | 139 } |
| paddy@77 | 140 |
| paddy@69 | 141 func authenticate(user, passphrase string, context Context) (Profile, error) { |
| paddy@69 | 142 profile, err := context.GetProfileByLogin(user) |
| paddy@69 | 143 if err != nil { |
| paddy@79 | 144 if err == ErrProfileNotFound || err == ErrLoginNotFound { |
| paddy@69 | 145 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 146 } |
| paddy@69 | 147 return Profile{}, err |
| paddy@69 | 148 } |
| paddy@69 | 149 switch profile.PassphraseScheme { |
| paddy@69 | 150 case 1: |
| paddy@79 | 151 realPass, err := hex.DecodeString(profile.Passphrase) |
| paddy@79 | 152 if err != nil { |
| paddy@79 | 153 return Profile{}, err |
| paddy@79 | 154 } |
| paddy@69 | 155 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt)) |
| paddy@79 | 156 if !pass.Compare(candidate, realPass) { |
| paddy@69 | 157 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 158 } |
| paddy@69 | 159 default: |
| paddy@69 | 160 return Profile{}, ErrInvalidPassphraseScheme |
| paddy@69 | 161 } |
| paddy@69 | 162 return profile, nil |
| paddy@69 | 163 } |
| paddy@69 | 164 |
| paddy@77 | 165 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler { |
| paddy@77 | 166 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| paddy@77 | 167 f(w, r, context) |
| paddy@77 | 168 }) |
| paddy@77 | 169 } |
| paddy@77 | 170 |
| paddy@77 | 171 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints. |
| paddy@77 | 172 func RegisterOAuth2(r *mux.Router, context Context) { |
| paddy@77 | 173 r.Handle("/authorize", wrap(context, GetGrantHandler)) |
| paddy@77 | 174 r.Handle("/token", wrap(context, GetTokenHandler)) |
| paddy@77 | 175 } |
| paddy@77 | 176 |
| paddy@57 | 177 // GetGrantHandler presents and processes the page for asking a user to grant access |
| paddy@57 | 178 // to their data. See RFC 6749, Section 4.1. |
| paddy@51 | 179 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 180 session, err := checkCookie(r, context) |
| paddy@69 | 181 if err != nil { |
| paddy@76 | 182 if err == ErrNoSession || err == ErrInvalidSession { |
| paddy@77 | 183 redir := buildLoginRedirect(r, context) |
| paddy@77 | 184 if redir == "" { |
| paddy@77 | 185 log.Println("No login URL configured.") |
| paddy@77 | 186 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 187 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 188 "internal_error": template.HTML("Missing login URL."), |
| paddy@77 | 189 }) |
| paddy@77 | 190 return |
| paddy@77 | 191 } |
| paddy@77 | 192 http.Redirect(w, r, redir, http.StatusFound) |
| paddy@77 | 193 return |
| paddy@69 | 194 } |
| paddy@77 | 195 log.Println(err.Error()) |
| paddy@77 | 196 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 197 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 198 "internal_error": template.HTML(err.Error()), |
| paddy@77 | 199 }) |
| paddy@77 | 200 return |
| paddy@69 | 201 } |
| paddy@56 | 202 if r.URL.Query().Get("client_id") == "" { |
| paddy@56 | 203 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 204 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 205 "error": template.HTML("Client ID must be specified in the request."), |
| paddy@56 | 206 }) |
| paddy@56 | 207 return |
| paddy@56 | 208 } |
| paddy@56 | 209 clientID, err := uuid.Parse(r.URL.Query().Get("client_id")) |
| paddy@56 | 210 if err != nil { |
| paddy@56 | 211 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 212 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 213 "error": template.HTML("client_id is not a valid Client ID."), |
| paddy@56 | 214 }) |
| paddy@56 | 215 return |
| paddy@56 | 216 } |
| paddy@64 | 217 redirectURI := r.URL.Query().Get("redirect_uri") |
| paddy@64 | 218 redirectURL, err := url.Parse(redirectURI) |
| paddy@64 | 219 if err != nil { |
| paddy@64 | 220 w.WriteHeader(http.StatusBadRequest) |
| paddy@64 | 221 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@64 | 222 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@64 | 223 }) |
| paddy@64 | 224 return |
| paddy@64 | 225 } |
| paddy@56 | 226 client, err := context.GetClient(clientID) |
| paddy@56 | 227 if err != nil { |
| paddy@59 | 228 if err == ErrClientNotFound { |
| paddy@59 | 229 w.WriteHeader(http.StatusBadRequest) |
| paddy@59 | 230 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 231 "error": template.HTML("The specified Client couldn’t be found."), |
| paddy@59 | 232 }) |
| paddy@59 | 233 } else { |
| paddy@77 | 234 log.Println(err.Error()) |
| paddy@59 | 235 w.WriteHeader(http.StatusInternalServerError) |
| paddy@59 | 236 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 237 "internal_error": template.HTML(err.Error()), |
| paddy@59 | 238 }) |
| paddy@59 | 239 } |
| paddy@56 | 240 return |
| paddy@56 | 241 } |
| paddy@56 | 242 // whether a redirect URI is valid or not depends on the number of endpoints |
| paddy@56 | 243 // the client has registered |
| paddy@56 | 244 numEndpoints, err := context.CountEndpoints(clientID) |
| paddy@56 | 245 if err != nil { |
| paddy@77 | 246 log.Println(err.Error()) |
| paddy@56 | 247 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 248 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 249 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 250 }) |
| paddy@56 | 251 return |
| paddy@56 | 252 } |
| paddy@56 | 253 var validURI bool |
| paddy@58 | 254 if redirectURI != "" { |
| paddy@58 | 255 // BUG(paddy): We really should normalize URIs before trying to compare them. |
| paddy@58 | 256 validURI, err = context.CheckEndpoint(clientID, redirectURI) |
| paddy@56 | 257 if err != nil { |
| paddy@77 | 258 log.Println(err.Error()) |
| paddy@56 | 259 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 260 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 261 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 262 }) |
| paddy@56 | 263 return |
| paddy@56 | 264 } |
| paddy@56 | 265 } else if redirectURI == "" && numEndpoints == 1 { |
| paddy@56 | 266 // if we don't specify the endpoint and there's only one endpoint, the |
| paddy@56 | 267 // request is valid, and we're redirecting to that one endpoint |
| paddy@56 | 268 validURI = true |
| paddy@56 | 269 endpoints, err := context.ListEndpoints(clientID, 1, 0) |
| paddy@56 | 270 if err != nil { |
| paddy@77 | 271 log.Println(err.Error()) |
| paddy@56 | 272 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 273 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 274 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 275 }) |
| paddy@56 | 276 return |
| paddy@56 | 277 } |
| paddy@56 | 278 if len(endpoints) != 1 { |
| paddy@56 | 279 validURI = false |
| paddy@56 | 280 } else { |
| paddy@66 | 281 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore |
| paddy@66 | 282 redirectURI = u.String() |
| paddy@66 | 283 redirectURL = &u |
| paddy@56 | 284 } |
| paddy@56 | 285 } else { |
| paddy@56 | 286 validURI = false |
| paddy@56 | 287 } |
| paddy@56 | 288 if !validURI { |
| paddy@56 | 289 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 290 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 291 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@56 | 292 }) |
| paddy@56 | 293 return |
| paddy@56 | 294 } |
| paddy@60 | 295 scope := r.URL.Query().Get("scope") |
| paddy@60 | 296 state := r.URL.Query().Get("state") |
| paddy@56 | 297 if r.URL.Query().Get("response_type") != "code" { |
| paddy@65 | 298 q := redirectURL.Query() |
| paddy@65 | 299 q.Add("error", "invalid_request") |
| paddy@65 | 300 q.Add("state", state) |
| paddy@65 | 301 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 302 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 303 return |
| paddy@56 | 304 } |
| paddy@56 | 305 if r.Method == "POST" { |
| paddy@63 | 306 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code. |
| paddy@56 | 307 if r.PostFormValue("grant") == "approved" { |
| paddy@60 | 308 code := uuid.NewID().String() |
| paddy@60 | 309 grant := Grant{ |
| paddy@60 | 310 Code: code, |
| paddy@60 | 311 Created: time.Now(), |
| paddy@60 | 312 ExpiresIn: defaultGrantExpiration, |
| paddy@60 | 313 ClientID: clientID, |
| paddy@60 | 314 Scope: scope, |
| paddy@69 | 315 RedirectURI: r.URL.Query().Get("redirect_uri"), |
| paddy@60 | 316 State: state, |
| paddy@69 | 317 ProfileID: session.ProfileID, |
| paddy@60 | 318 } |
| paddy@60 | 319 err := context.SaveGrant(grant) |
| paddy@60 | 320 if err != nil { |
| paddy@66 | 321 q := redirectURL.Query() |
| paddy@66 | 322 q.Add("error", "server_error") |
| paddy@66 | 323 q.Add("state", state) |
| paddy@66 | 324 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 325 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 326 return |
| paddy@60 | 327 } |
| paddy@66 | 328 q := redirectURL.Query() |
| paddy@66 | 329 q.Add("code", code) |
| paddy@66 | 330 q.Add("state", state) |
| paddy@66 | 331 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 332 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 333 return |
| paddy@56 | 334 } |
| paddy@66 | 335 q := redirectURL.Query() |
| paddy@66 | 336 q.Add("error", "access_denied") |
| paddy@66 | 337 q.Add("state", state) |
| paddy@66 | 338 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 339 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 340 return |
| paddy@56 | 341 } |
| paddy@51 | 342 w.WriteHeader(http.StatusOK) |
| paddy@56 | 343 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@56 | 344 "client": client, |
| paddy@56 | 345 }) |
| paddy@51 | 346 } |
| paddy@68 | 347 |
| paddy@69 | 348 // GetTokenHandler allows a client to exchange an authorization grant for an |
| paddy@69 | 349 // access token. See RFC 6749 Section 4.1.3. |
| paddy@69 | 350 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 351 enc := json.NewEncoder(w) |
| paddy@69 | 352 grantType := r.PostFormValue("grant_type") |
| paddy@84 | 353 gt, ok := findGrantType(grantType) |
| paddy@84 | 354 if !ok { |
| paddy@82 | 355 w.WriteHeader(http.StatusBadRequest) |
| paddy@82 | 356 renderJSONError(enc, "invalid_request") |
| paddy@69 | 357 return |
| paddy@69 | 358 } |
| paddy@84 | 359 scope, profileID, valid := gt.Validate(w, r, context) |
| paddy@84 | 360 if !valid { |
| paddy@69 | 361 return |
| paddy@69 | 362 } |
| paddy@84 | 363 refresh := "" |
| paddy@84 | 364 if gt.IssuesRefresh { |
| paddy@84 | 365 refresh = uuid.NewID().String() |
| paddy@69 | 366 } |
| paddy@69 | 367 token := Token{ |
| paddy@69 | 368 AccessToken: uuid.NewID().String(), |
| paddy@84 | 369 RefreshToken: refresh, |
| paddy@69 | 370 Created: time.Now(), |
| paddy@69 | 371 ExpiresIn: defaultTokenExpiration, |
| paddy@81 | 372 TokenType: "bearer", |
| paddy@84 | 373 Scope: scope, |
| paddy@84 | 374 ProfileID: profileID, |
| paddy@69 | 375 } |
| paddy@84 | 376 err := context.SaveToken(token) |
| paddy@69 | 377 if err != nil { |
| paddy@82 | 378 w.WriteHeader(http.StatusInternalServerError) |
| paddy@82 | 379 renderJSONError(enc, "server_error") |
| paddy@81 | 380 return |
| paddy@69 | 381 } |
| paddy@69 | 382 resp := tokenResponse{ |
| paddy@69 | 383 AccessToken: token.AccessToken, |
| paddy@69 | 384 RefreshToken: token.RefreshToken, |
| paddy@69 | 385 ExpiresIn: token.ExpiresIn, |
| paddy@69 | 386 TokenType: token.TokenType, |
| paddy@69 | 387 } |
| paddy@69 | 388 err = enc.Encode(resp) |
| paddy@69 | 389 if err != nil { |
| paddy@69 | 390 // TODO(paddy): log this or something |
| paddy@69 | 391 return |
| paddy@69 | 392 } |
| paddy@81 | 393 // BUG(paddy): we need to invalidate the grant for future requests |
| paddy@69 | 394 } |
| paddy@69 | 395 |
| paddy@68 | 396 // TODO(paddy): exchange user credentials for access token |
| paddy@68 | 397 // TODO(paddy): exchange client credentials for access token |
| paddy@68 | 398 // TODO(paddy): implicit grant for access token |
| paddy@68 | 399 // TODO(paddy): exchange refresh token for access token |