auth
auth/oauth2.go
Write Session tests. Add loginURI as a property to our Context, to keep track of where users should be redirected to log in. Implement the sessionStore in the memstore to let us test with Sessions. Catch when the HTTP Basic Auth header doesn't include two parts and return an ErrInvalidAuthFormat rather than panicking. Simplify the error handling in checkCookie. Log unexpected errors from request.Cookie. Stop checking for cookie expiration times--those aren't sent to the server, so we'll never get a valid session if we look for them. Add a helper to build a login redirect URI--a URI the user can be redirected to that has a URL-encoded URL to redirect the user back to after a successful login. Add a wrapper that turns handlers taking our Context into standard HTTP handlers. Create a RegisterOAuth2 helper that adds the OAuth2 endpoints to a Gorilla/mux router. Redirect users to the login page when they have no session set or an invalid session. Return a server error when we can't check our cookie for whatever reason. Log errors. Add sessions to our OAuth2 tests so the tests stop failing--the session check was interfering with them. Add a test for our getBasicAuth helper to ensure that we're parsing basic auth correctly. Add an ErrSessionAlreadyExists error to be returned when a session has an ID conflict. Test that our memstore implementation of the sessionStore works as intended.
| paddy@51 | 1 package auth |
| paddy@51 | 2 |
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	"errors"
	"html/template"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"code.secondbit.org/pass"
	"code.secondbit.org/uuid"
	"github.com/gorilla/mux"
)
| paddy@51 | 19 |
// Names and defaults shared by the OAuth2 handlers.
const (
	// authCookieName is the name of the cookie whose value is looked up as
	// the session ID.
	authCookieName = "auth"
	// getGrantTemplateName names the template rendered by the grant page.
	getGrantTemplateName = "get_grant"
	// defaultGrantExpiration is the ExpiresIn given to newly issued grants:
	// 600, i.e. ten minutes when read as seconds.
	defaultGrantExpiration = 600
)
| paddy@51 | 25 |
| paddy@69 | 26 var ( |
| paddy@69 | 27 // ErrNoAuth is returned when an Authorization header is not present or is empty. |
| paddy@69 | 28 ErrNoAuth = errors.New("no authorization header supplied") |
| paddy@69 | 29 // ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format. |
| paddy@69 | 30 ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format") |
| paddy@69 | 31 // ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values. |
| paddy@69 | 32 ErrIncorrectAuth = errors.New("invalid authentication") |
| paddy@69 | 33 // ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used. |
| paddy@69 | 34 ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme") |
| paddy@69 | 35 // ErrNoSession is returned when no session ID is passed with a request. |
| paddy@69 | 36 ErrNoSession = errors.New("no session ID found") |
| paddy@69 | 37 ) |
| paddy@69 | 38 |
// tokenResponse is the JSON body sent by GetTokenHandler on success
// (RFC 6749, Section 5.1). Field order controls the order of keys in the
// encoded output; optional fields are dropped when they hold zero values.
type tokenResponse struct {
	// AccessToken is always emitted.
	AccessToken string `json:"access_token"`
	// TokenType is omitted while empty.
	TokenType string `json:"token_type,omitempty"`
	// ExpiresIn is omitted when zero.
	ExpiresIn int32 `json:"expires_in,omitempty"`
	// RefreshToken is omitted while empty.
	RefreshToken string `json:"refresh_token,omitempty"`
}
| paddy@69 | 45 |
| paddy@69 | 46 func getBasicAuth(r *http.Request) (un, pass string, err error) { |
| paddy@69 | 47 auth := r.Header.Get("Authorization") |
| paddy@69 | 48 if auth == "" { |
| paddy@69 | 49 return "", "", ErrNoAuth |
| paddy@69 | 50 } |
| paddy@69 | 51 pieces := strings.SplitN(auth, " ", 2) |
| paddy@69 | 52 if pieces[0] != "Basic" { |
| paddy@69 | 53 return "", "", ErrInvalidAuthFormat |
| paddy@69 | 54 } |
| paddy@69 | 55 decoded, err := base64.StdEncoding.DecodeString(pieces[1]) |
| paddy@69 | 56 if err != nil { |
| paddy@69 | 57 return "", "", ErrInvalidAuthFormat |
| paddy@69 | 58 } |
| paddy@69 | 59 info := strings.SplitN(string(decoded), ":", 2) |
| paddy@77 | 60 if len(info) < 2 { |
| paddy@77 | 61 return "", "", ErrInvalidAuthFormat |
| paddy@77 | 62 } |
| paddy@69 | 63 return info[0], info[1], nil |
| paddy@69 | 64 } |
| paddy@69 | 65 |
| paddy@69 | 66 func checkCookie(r *http.Request, context Context) (Session, error) { |
| paddy@69 | 67 cookie, err := r.Cookie(authCookieName) |
| paddy@77 | 68 if err == http.ErrNoCookie { |
| paddy@77 | 69 return Session{}, ErrNoSession |
| paddy@77 | 70 } else if err != nil { |
| paddy@77 | 71 log.Println(err) |
| paddy@69 | 72 return Session{}, err |
| paddy@69 | 73 } |
| paddy@69 | 74 sess, err := context.GetSession(cookie.Value) |
| paddy@69 | 75 if err == ErrSessionNotFound { |
| paddy@69 | 76 return Session{}, ErrInvalidSession |
| paddy@69 | 77 } else if err != nil { |
| paddy@69 | 78 return Session{}, err |
| paddy@69 | 79 } |
| paddy@69 | 80 if !sess.Active { |
| paddy@69 | 81 return Session{}, ErrInvalidSession |
| paddy@69 | 82 } |
| paddy@69 | 83 return sess, nil |
| paddy@69 | 84 } |
| paddy@69 | 85 |
| paddy@77 | 86 func buildLoginRedirect(r *http.Request, context Context) string { |
| paddy@77 | 87 if context.loginURI == nil { |
| paddy@77 | 88 return "" |
| paddy@77 | 89 } |
| paddy@77 | 90 uri := *context.loginURI |
| paddy@77 | 91 q := uri.Query() |
| paddy@77 | 92 q.Set("from", url.QueryEscape(r.URL.String())) |
| paddy@77 | 93 uri.RawQuery = q.Encode() |
| paddy@77 | 94 return uri.String() |
| paddy@77 | 95 } |
| paddy@77 | 96 |
| paddy@69 | 97 func authenticate(user, passphrase string, context Context) (Profile, error) { |
| paddy@69 | 98 profile, err := context.GetProfileByLogin(user) |
| paddy@69 | 99 if err != nil { |
| paddy@69 | 100 if err == ErrProfileNotFound { |
| paddy@69 | 101 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 102 } |
| paddy@69 | 103 return Profile{}, err |
| paddy@69 | 104 } |
| paddy@69 | 105 switch profile.PassphraseScheme { |
| paddy@69 | 106 case 1: |
| paddy@69 | 107 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt)) |
| paddy@69 | 108 if !pass.Compare(candidate, []byte(profile.Passphrase)) { |
| paddy@69 | 109 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 110 } |
| paddy@69 | 111 default: |
| paddy@69 | 112 return Profile{}, ErrInvalidPassphraseScheme |
| paddy@69 | 113 } |
| paddy@69 | 114 return profile, nil |
| paddy@69 | 115 } |
| paddy@69 | 116 |
// wrap adapts f, a handler that also takes a Context, into an http.Handler.
// The given context is captured once and passed to f on every request.
func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler {
	var adapted http.HandlerFunc = func(w http.ResponseWriter, r *http.Request) {
		f(w, r, context)
	}
	return adapted
}
| paddy@77 | 122 |
| paddy@77 | 123 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints. |
| paddy@77 | 124 func RegisterOAuth2(r *mux.Router, context Context) { |
| paddy@77 | 125 r.Handle("/authorize", wrap(context, GetGrantHandler)) |
| paddy@77 | 126 r.Handle("/token", wrap(context, GetTokenHandler)) |
| paddy@77 | 127 } |
| paddy@77 | 128 |
// writeGrantPage sends status and renders the grant template with a single
// message stored under key ("error" for bad requests, "internal_error" for
// server-side failures).
func writeGrantPage(w http.ResponseWriter, context Context, status int, key string, msg template.HTML) {
	w.WriteHeader(status)
	context.Render(w, getGrantTemplateName, map[string]interface{}{
		key: msg,
	})
}

// writeGrantServerError logs err and renders it as an internal error on the
// grant page. The message is HTML-escaped before being wrapped in
// template.HTML; otherwise error text would bypass html/template's escaping
// and could inject markup into the page.
func writeGrantServerError(w http.ResponseWriter, context Context, err error) {
	log.Println(err.Error())
	writeGrantPage(w, context, http.StatusInternalServerError, "internal_error",
		template.HTML(template.HTMLEscapeString(err.Error())))
}

// redirectWithParams adds params to target's query string, overwriting
// target.RawQuery, and sends a 302 Found redirect to the result.
func redirectWithParams(w http.ResponseWriter, r *http.Request, target *url.URL, params url.Values) {
	q := target.Query()
	for key, values := range params {
		for _, v := range values {
			q.Add(key, v)
		}
	}
	target.RawQuery = q.Encode()
	http.Redirect(w, r, target.String(), http.StatusFound)
}

// GetGrantHandler presents and processes the page for asking a user to grant access
// to their data. See RFC 6749, Section 4.1.
//
// Users without a valid session are redirected to the configured login page.
// The client_id query parameter must name an existing Client. redirect_uri
// must be one of that client's endpoints; it may be omitted only when the
// client has exactly one endpoint. A GET renders the grant page. A POST with
// grant=approved saves a new Grant and redirects to the client with a code;
// any other POST redirects with error=access_denied.
func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) {
	session, err := checkCookie(r, context)
	if err != nil {
		if err == ErrNoSession || err == ErrInvalidSession {
			redir := buildLoginRedirect(r, context)
			if redir == "" {
				log.Println("No login URL configured.")
				writeGrantPage(w, context, http.StatusInternalServerError, "internal_error",
					template.HTML("Missing login URL."))
				return
			}
			http.Redirect(w, r, redir, http.StatusFound)
			return
		}
		writeGrantServerError(w, context, err)
		return
	}

	query := r.URL.Query()
	if query.Get("client_id") == "" {
		writeGrantPage(w, context, http.StatusBadRequest, "error",
			template.HTML("Client ID must be specified in the request."))
		return
	}
	clientID, err := uuid.Parse(query.Get("client_id"))
	if err != nil {
		writeGrantPage(w, context, http.StatusBadRequest, "error",
			template.HTML("client_id is not a valid Client ID."))
		return
	}
	redirectURI := query.Get("redirect_uri")
	redirectURL, err := url.Parse(redirectURI)
	if err != nil {
		writeGrantPage(w, context, http.StatusBadRequest, "error",
			template.HTML("The redirect_uri specified is not valid."))
		return
	}
	client, err := context.GetClient(clientID)
	if err == ErrClientNotFound {
		writeGrantPage(w, context, http.StatusBadRequest, "error",
			template.HTML("The specified Client couldn’t be found."))
		return
	} else if err != nil {
		writeGrantServerError(w, context, err)
		return
	}

	// whether a redirect URI is valid or not depends on the number of endpoints
	// the client has registered
	numEndpoints, err := context.CountEndpoints(clientID)
	if err != nil {
		writeGrantServerError(w, context, err)
		return
	}
	var validURI bool
	switch {
	case redirectURI != "":
		// BUG(paddy): We really should normalize URIs before trying to compare them.
		validURI, err = context.CheckEndpoint(clientID, redirectURI)
		if err != nil {
			writeGrantServerError(w, context, err)
			return
		}
	case numEndpoints == 1:
		// if we don't specify the endpoint and there's only one endpoint, the
		// request is valid, and we're redirecting to that one endpoint
		endpoints, err := context.ListEndpoints(clientID, 1, 0)
		if err != nil {
			writeGrantServerError(w, context, err)
			return
		}
		if len(endpoints) == 1 {
			validURI = true
			u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore
			redirectURL = &u
		}
	}
	if !validURI {
		writeGrantPage(w, context, http.StatusBadRequest, "error",
			template.HTML("The redirect_uri specified is not valid."))
		return
	}

	scope := query.Get("scope")
	state := query.Get("state")
	if query.Get("response_type") != "code" {
		redirectWithParams(w, r, redirectURL, url.Values{"error": {"invalid_request"}, "state": {state}})
		return
	}
	if r.Method == "POST" {
		// BUG(paddy): We need to implement CSRF protection when obtaining a grant code.
		if r.PostFormValue("grant") != "approved" {
			redirectWithParams(w, r, redirectURL, url.Values{"error": {"access_denied"}, "state": {state}})
			return
		}
		code := uuid.NewID().String()
		grant := Grant{
			Code:      code,
			Created:   time.Now(),
			ExpiresIn: defaultGrantExpiration,
			ClientID:  clientID,
			Scope:     scope,
			// Record redirect_uri exactly as requested (possibly empty): the
			// token endpoint compares it with the redirect_uri it receives.
			RedirectURI: query.Get("redirect_uri"),
			State:       state,
			ProfileID:   session.ProfileID,
		}
		if err := context.SaveGrant(grant); err != nil {
			log.Println(err.Error())
			redirectWithParams(w, r, redirectURL, url.Values{"error": {"server_error"}, "state": {state}})
			return
		}
		redirectWithParams(w, r, redirectURL, url.Values{"code": {code}, "state": {state}})
		return
	}
	w.WriteHeader(http.StatusOK)
	context.Render(w, getGrantTemplateName, map[string]interface{}{
		"client": client,
	})
}
| paddy@68 | 299 |
| paddy@69 | 300 // GetTokenHandler allows a client to exchange an authorization grant for an |
| paddy@69 | 301 // access token. See RFC 6749 Section 4.1.3. |
| paddy@69 | 302 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 303 enc := json.NewEncoder(w) |
| paddy@69 | 304 grantType := r.PostFormValue("grant_type") |
| paddy@69 | 305 if grantType != "authorization_code" { |
| paddy@69 | 306 // TODO(paddy): render invalid request JSON |
| paddy@69 | 307 return |
| paddy@69 | 308 } |
| paddy@69 | 309 code := r.PostFormValue("code") |
| paddy@69 | 310 if code == "" { |
| paddy@69 | 311 // TODO(paddy): render invalid request JSON |
| paddy@69 | 312 return |
| paddy@69 | 313 } |
| paddy@69 | 314 redirectURI := r.PostFormValue("redirect_uri") |
| paddy@69 | 315 clientIDStr, clientSecret, err := getBasicAuth(r) |
| paddy@69 | 316 if err != nil { |
| paddy@69 | 317 // TODO(paddy): render access denied |
| paddy@69 | 318 return |
| paddy@69 | 319 } |
| paddy@69 | 320 if clientIDStr == "" && err == nil { |
| paddy@69 | 321 clientIDStr = r.PostFormValue("client_id") |
| paddy@69 | 322 } |
| paddy@69 | 323 clientID, err := uuid.Parse(clientIDStr) |
| paddy@69 | 324 if err != nil { |
| paddy@69 | 325 // TODO(paddy): render invalid request JSON |
| paddy@69 | 326 return |
| paddy@69 | 327 } |
| paddy@69 | 328 client, err := context.GetClient(clientID) |
| paddy@69 | 329 if err != nil { |
| paddy@69 | 330 if err == ErrClientNotFound { |
| paddy@69 | 331 // TODO(paddy): render invalid request JSON |
| paddy@69 | 332 } else { |
| paddy@69 | 333 // TODO(paddy): render internal server error JSON |
| paddy@69 | 334 } |
| paddy@69 | 335 return |
| paddy@69 | 336 } |
| paddy@69 | 337 if client.Secret != clientSecret { |
| paddy@69 | 338 // TODO(paddy): render invalid request JSON |
| paddy@69 | 339 return |
| paddy@69 | 340 } |
| paddy@69 | 341 grant, err := context.GetGrant(code) |
| paddy@69 | 342 if err != nil { |
| paddy@69 | 343 if err == ErrGrantNotFound { |
| paddy@69 | 344 // TODO(paddy): return error |
| paddy@69 | 345 return |
| paddy@69 | 346 } |
| paddy@69 | 347 // TODO(paddy): return error |
| paddy@69 | 348 } |
| paddy@69 | 349 if grant.RedirectURI != redirectURI { |
| paddy@69 | 350 // TODO(paddy): return error |
| paddy@69 | 351 } |
| paddy@69 | 352 if !grant.ClientID.Equal(clientID) { |
| paddy@69 | 353 // TODO(paddy): return error |
| paddy@69 | 354 } |
| paddy@69 | 355 token := Token{ |
| paddy@69 | 356 AccessToken: uuid.NewID().String(), |
| paddy@69 | 357 RefreshToken: uuid.NewID().String(), |
| paddy@69 | 358 Created: time.Now(), |
| paddy@69 | 359 ExpiresIn: defaultTokenExpiration, |
| paddy@69 | 360 TokenType: "", // TODO(paddy): fill in token type |
| paddy@69 | 361 Scope: grant.Scope, |
| paddy@69 | 362 ProfileID: grant.ProfileID, |
| paddy@69 | 363 } |
| paddy@69 | 364 err = context.SaveToken(token) |
| paddy@69 | 365 if err != nil { |
| paddy@69 | 366 // TODO(paddy): return error |
| paddy@69 | 367 } |
| paddy@69 | 368 resp := tokenResponse{ |
| paddy@69 | 369 AccessToken: token.AccessToken, |
| paddy@69 | 370 RefreshToken: token.RefreshToken, |
| paddy@69 | 371 ExpiresIn: token.ExpiresIn, |
| paddy@69 | 372 TokenType: token.TokenType, |
| paddy@69 | 373 } |
| paddy@69 | 374 err = enc.Encode(resp) |
| paddy@69 | 375 if err != nil { |
| paddy@69 | 376 // TODO(paddy): log this or something |
| paddy@69 | 377 return |
| paddy@69 | 378 } |
| paddy@69 | 379 } |
| paddy@69 | 380 |
| paddy@68 | 381 // TODO(paddy): exchange user credentials for access token |
| paddy@68 | 382 // TODO(paddy): exchange client credentials for access token |
| paddy@68 | 383 // TODO(paddy): implicit grant for access token |
| paddy@68 | 384 // TODO(paddy): exchange refresh token for access token |