auth
auth/oauth2.go
Add tests for redirecting to the login page. Make sure that we're redirecting to the configured login page (or returning an error) as expected when trying to obtain a grant code.
| paddy@51 | 1 package auth |
| paddy@51 | 2 |
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"html/template"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"code.secondbit.org/pass"
	"code.secondbit.org/uuid"
	"github.com/gorilla/mux"
)
| paddy@51 | 20 |
// Names and defaults shared by the OAuth2 handlers in this file.
const (
	// authCookieName is the name of the cookie checkCookie reads the session ID from.
	authCookieName = "auth"
	// defaultGrantExpiration is the ExpiresIn given to newly issued grants:
	// ten minutes, expressed in seconds.
	defaultGrantExpiration = 10 * 60
	// getGrantTemplateName is the template GetGrantHandler renders.
	getGrantTemplateName = "get_grant"
)
| paddy@51 | 26 |
var (
	// ErrNoAuth is returned when an Authorization header is not present or is empty.
	ErrNoAuth = errors.New("no authorization header supplied")

	// ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format.
	ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")

	// ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values.
	ErrIncorrectAuth = errors.New("invalid authentication")

	// ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used.
	ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")

	// ErrNoSession is returned when no session ID is passed with a request.
	ErrNoSession = errors.New("no session ID found")
)

// tokenResponse is the JSON body written by GetTokenHandler when an access
// token is successfully issued (RFC 6749 Section 5.1).
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type,omitempty"`
	ExpiresIn    int32  `json:"expires_in,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}

// getBasicAuth extracts the user name and password from r's HTTP Basic
// Authorization header (RFC 7617).
//
// It returns ErrNoAuth when the header is absent or empty. It returns
// ErrInvalidAuthFormat when the header:
//   - does not use the Basic scheme,
//   - has no credentials after the scheme,
//   - is not valid standard base64, or
//   - lacks the ':' separating user name from password.
//
// The scheme name is matched case-insensitively, as RFC 7235 Section 2.1
// requires. Only the first ':' separates the fields, so passwords may contain ':'.
func getBasicAuth(r *http.Request) (un, pass string, err error) {
	auth := r.Header.Get("Authorization")
	if auth == "" {
		return "", "", ErrNoAuth
	}
	pieces := strings.SplitN(auth, " ", 2)
	// A header with no space (e.g. just "Basic") carries no credentials;
	// without the length check, indexing pieces[1] would panic.
	if len(pieces) != 2 || !strings.EqualFold(pieces[0], "Basic") {
		return "", "", ErrInvalidAuthFormat
	}
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(pieces[1]))
	if err != nil {
		return "", "", ErrInvalidAuthFormat
	}
	info := strings.SplitN(string(decoded), ":", 2)
	if len(info) < 2 {
		return "", "", ErrInvalidAuthFormat
	}
	return info[0], info[1], nil
}
| paddy@69 | 66 |
| paddy@69 | 67 func checkCookie(r *http.Request, context Context) (Session, error) { |
| paddy@69 | 68 cookie, err := r.Cookie(authCookieName) |
| paddy@77 | 69 if err == http.ErrNoCookie { |
| paddy@77 | 70 return Session{}, ErrNoSession |
| paddy@77 | 71 } else if err != nil { |
| paddy@77 | 72 log.Println(err) |
| paddy@69 | 73 return Session{}, err |
| paddy@69 | 74 } |
| paddy@69 | 75 sess, err := context.GetSession(cookie.Value) |
| paddy@69 | 76 if err == ErrSessionNotFound { |
| paddy@69 | 77 return Session{}, ErrInvalidSession |
| paddy@69 | 78 } else if err != nil { |
| paddy@69 | 79 return Session{}, err |
| paddy@69 | 80 } |
| paddy@69 | 81 if !sess.Active { |
| paddy@69 | 82 return Session{}, ErrInvalidSession |
| paddy@69 | 83 } |
| paddy@69 | 84 return sess, nil |
| paddy@69 | 85 } |
| paddy@69 | 86 |
| paddy@77 | 87 func buildLoginRedirect(r *http.Request, context Context) string { |
| paddy@77 | 88 if context.loginURI == nil { |
| paddy@77 | 89 return "" |
| paddy@77 | 90 } |
| paddy@77 | 91 uri := *context.loginURI |
| paddy@77 | 92 q := uri.Query() |
| paddy@78 | 93 q.Set("from", r.URL.String()) |
| paddy@77 | 94 uri.RawQuery = q.Encode() |
| paddy@77 | 95 return uri.String() |
| paddy@77 | 96 } |
| paddy@77 | 97 |
| paddy@69 | 98 func authenticate(user, passphrase string, context Context) (Profile, error) { |
| paddy@69 | 99 profile, err := context.GetProfileByLogin(user) |
| paddy@69 | 100 if err != nil { |
| paddy@79 | 101 if err == ErrProfileNotFound || err == ErrLoginNotFound { |
| paddy@69 | 102 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 103 } |
| paddy@69 | 104 return Profile{}, err |
| paddy@69 | 105 } |
| paddy@69 | 106 switch profile.PassphraseScheme { |
| paddy@69 | 107 case 1: |
| paddy@79 | 108 realPass, err := hex.DecodeString(profile.Passphrase) |
| paddy@79 | 109 if err != nil { |
| paddy@79 | 110 return Profile{}, err |
| paddy@79 | 111 } |
| paddy@69 | 112 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt)) |
| paddy@79 | 113 if !pass.Compare(candidate, realPass) { |
| paddy@69 | 114 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 115 } |
| paddy@69 | 116 default: |
| paddy@69 | 117 return Profile{}, ErrInvalidPassphraseScheme |
| paddy@69 | 118 } |
| paddy@69 | 119 return profile, nil |
| paddy@69 | 120 } |
| paddy@69 | 121 |
| paddy@77 | 122 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler { |
| paddy@77 | 123 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| paddy@77 | 124 f(w, r, context) |
| paddy@77 | 125 }) |
| paddy@77 | 126 } |
| paddy@77 | 127 |
| paddy@77 | 128 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints. |
| paddy@77 | 129 func RegisterOAuth2(r *mux.Router, context Context) { |
| paddy@77 | 130 r.Handle("/authorize", wrap(context, GetGrantHandler)) |
| paddy@77 | 131 r.Handle("/token", wrap(context, GetTokenHandler)) |
| paddy@77 | 132 } |
| paddy@77 | 133 |
| paddy@57 | 134 // GetGrantHandler presents and processes the page for asking a user to grant access |
| paddy@57 | 135 // to their data. See RFC 6749, Section 4.1. |
| paddy@51 | 136 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 137 session, err := checkCookie(r, context) |
| paddy@69 | 138 if err != nil { |
| paddy@76 | 139 if err == ErrNoSession || err == ErrInvalidSession { |
| paddy@77 | 140 redir := buildLoginRedirect(r, context) |
| paddy@77 | 141 if redir == "" { |
| paddy@77 | 142 log.Println("No login URL configured.") |
| paddy@77 | 143 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 144 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 145 "internal_error": template.HTML("Missing login URL."), |
| paddy@77 | 146 }) |
| paddy@77 | 147 return |
| paddy@77 | 148 } |
| paddy@77 | 149 http.Redirect(w, r, redir, http.StatusFound) |
| paddy@77 | 150 return |
| paddy@69 | 151 } |
| paddy@77 | 152 log.Println(err.Error()) |
| paddy@77 | 153 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 154 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 155 "internal_error": template.HTML(err.Error()), |
| paddy@77 | 156 }) |
| paddy@77 | 157 return |
| paddy@69 | 158 } |
| paddy@56 | 159 if r.URL.Query().Get("client_id") == "" { |
| paddy@56 | 160 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 161 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 162 "error": template.HTML("Client ID must be specified in the request."), |
| paddy@56 | 163 }) |
| paddy@56 | 164 return |
| paddy@56 | 165 } |
| paddy@56 | 166 clientID, err := uuid.Parse(r.URL.Query().Get("client_id")) |
| paddy@56 | 167 if err != nil { |
| paddy@56 | 168 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 169 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 170 "error": template.HTML("client_id is not a valid Client ID."), |
| paddy@56 | 171 }) |
| paddy@56 | 172 return |
| paddy@56 | 173 } |
| paddy@64 | 174 redirectURI := r.URL.Query().Get("redirect_uri") |
| paddy@64 | 175 redirectURL, err := url.Parse(redirectURI) |
| paddy@64 | 176 if err != nil { |
| paddy@64 | 177 w.WriteHeader(http.StatusBadRequest) |
| paddy@64 | 178 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@64 | 179 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@64 | 180 }) |
| paddy@64 | 181 return |
| paddy@64 | 182 } |
| paddy@56 | 183 client, err := context.GetClient(clientID) |
| paddy@56 | 184 if err != nil { |
| paddy@59 | 185 if err == ErrClientNotFound { |
| paddy@59 | 186 w.WriteHeader(http.StatusBadRequest) |
| paddy@59 | 187 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 188 "error": template.HTML("The specified Client couldn’t be found."), |
| paddy@59 | 189 }) |
| paddy@59 | 190 } else { |
| paddy@77 | 191 log.Println(err.Error()) |
| paddy@59 | 192 w.WriteHeader(http.StatusInternalServerError) |
| paddy@59 | 193 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 194 "internal_error": template.HTML(err.Error()), |
| paddy@59 | 195 }) |
| paddy@59 | 196 } |
| paddy@56 | 197 return |
| paddy@56 | 198 } |
| paddy@56 | 199 // whether a redirect URI is valid or not depends on the number of endpoints |
| paddy@56 | 200 // the client has registered |
| paddy@56 | 201 numEndpoints, err := context.CountEndpoints(clientID) |
| paddy@56 | 202 if err != nil { |
| paddy@77 | 203 log.Println(err.Error()) |
| paddy@56 | 204 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 205 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 206 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 207 }) |
| paddy@56 | 208 return |
| paddy@56 | 209 } |
| paddy@56 | 210 var validURI bool |
| paddy@58 | 211 if redirectURI != "" { |
| paddy@58 | 212 // BUG(paddy): We really should normalize URIs before trying to compare them. |
| paddy@58 | 213 validURI, err = context.CheckEndpoint(clientID, redirectURI) |
| paddy@56 | 214 if err != nil { |
| paddy@77 | 215 log.Println(err.Error()) |
| paddy@56 | 216 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 217 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 218 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 219 }) |
| paddy@56 | 220 return |
| paddy@56 | 221 } |
| paddy@56 | 222 } else if redirectURI == "" && numEndpoints == 1 { |
| paddy@56 | 223 // if we don't specify the endpoint and there's only one endpoint, the |
| paddy@56 | 224 // request is valid, and we're redirecting to that one endpoint |
| paddy@56 | 225 validURI = true |
| paddy@56 | 226 endpoints, err := context.ListEndpoints(clientID, 1, 0) |
| paddy@56 | 227 if err != nil { |
| paddy@77 | 228 log.Println(err.Error()) |
| paddy@56 | 229 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 230 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 231 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 232 }) |
| paddy@56 | 233 return |
| paddy@56 | 234 } |
| paddy@56 | 235 if len(endpoints) != 1 { |
| paddy@56 | 236 validURI = false |
| paddy@56 | 237 } else { |
| paddy@66 | 238 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore |
| paddy@66 | 239 redirectURI = u.String() |
| paddy@66 | 240 redirectURL = &u |
| paddy@56 | 241 } |
| paddy@56 | 242 } else { |
| paddy@56 | 243 validURI = false |
| paddy@56 | 244 } |
| paddy@56 | 245 if !validURI { |
| paddy@56 | 246 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 247 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 248 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@56 | 249 }) |
| paddy@56 | 250 return |
| paddy@56 | 251 } |
| paddy@60 | 252 scope := r.URL.Query().Get("scope") |
| paddy@60 | 253 state := r.URL.Query().Get("state") |
| paddy@56 | 254 if r.URL.Query().Get("response_type") != "code" { |
| paddy@65 | 255 q := redirectURL.Query() |
| paddy@65 | 256 q.Add("error", "invalid_request") |
| paddy@65 | 257 q.Add("state", state) |
| paddy@65 | 258 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 259 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 260 return |
| paddy@56 | 261 } |
| paddy@56 | 262 if r.Method == "POST" { |
| paddy@63 | 263 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code. |
| paddy@56 | 264 if r.PostFormValue("grant") == "approved" { |
| paddy@60 | 265 code := uuid.NewID().String() |
| paddy@60 | 266 grant := Grant{ |
| paddy@60 | 267 Code: code, |
| paddy@60 | 268 Created: time.Now(), |
| paddy@60 | 269 ExpiresIn: defaultGrantExpiration, |
| paddy@60 | 270 ClientID: clientID, |
| paddy@60 | 271 Scope: scope, |
| paddy@69 | 272 RedirectURI: r.URL.Query().Get("redirect_uri"), |
| paddy@60 | 273 State: state, |
| paddy@69 | 274 ProfileID: session.ProfileID, |
| paddy@60 | 275 } |
| paddy@60 | 276 err := context.SaveGrant(grant) |
| paddy@60 | 277 if err != nil { |
| paddy@66 | 278 q := redirectURL.Query() |
| paddy@66 | 279 q.Add("error", "server_error") |
| paddy@66 | 280 q.Add("state", state) |
| paddy@66 | 281 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 282 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 283 return |
| paddy@60 | 284 } |
| paddy@66 | 285 q := redirectURL.Query() |
| paddy@66 | 286 q.Add("code", code) |
| paddy@66 | 287 q.Add("state", state) |
| paddy@66 | 288 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 289 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 290 return |
| paddy@56 | 291 } |
| paddy@66 | 292 q := redirectURL.Query() |
| paddy@66 | 293 q.Add("error", "access_denied") |
| paddy@66 | 294 q.Add("state", state) |
| paddy@66 | 295 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 296 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 297 return |
| paddy@56 | 298 } |
| paddy@51 | 299 w.WriteHeader(http.StatusOK) |
| paddy@56 | 300 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@56 | 301 "client": client, |
| paddy@56 | 302 }) |
| paddy@51 | 303 } |
| paddy@68 | 304 |
| paddy@69 | 305 // GetTokenHandler allows a client to exchange an authorization grant for an |
| paddy@69 | 306 // access token. See RFC 6749 Section 4.1.3. |
| paddy@69 | 307 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 308 enc := json.NewEncoder(w) |
| paddy@69 | 309 grantType := r.PostFormValue("grant_type") |
| paddy@69 | 310 if grantType != "authorization_code" { |
| paddy@69 | 311 // TODO(paddy): render invalid request JSON |
| paddy@69 | 312 return |
| paddy@69 | 313 } |
| paddy@69 | 314 code := r.PostFormValue("code") |
| paddy@69 | 315 if code == "" { |
| paddy@69 | 316 // TODO(paddy): render invalid request JSON |
| paddy@69 | 317 return |
| paddy@69 | 318 } |
| paddy@69 | 319 redirectURI := r.PostFormValue("redirect_uri") |
| paddy@69 | 320 clientIDStr, clientSecret, err := getBasicAuth(r) |
| paddy@69 | 321 if err != nil { |
| paddy@69 | 322 // TODO(paddy): render access denied |
| paddy@69 | 323 return |
| paddy@69 | 324 } |
| paddy@69 | 325 if clientIDStr == "" && err == nil { |
| paddy@69 | 326 clientIDStr = r.PostFormValue("client_id") |
| paddy@69 | 327 } |
| paddy@69 | 328 clientID, err := uuid.Parse(clientIDStr) |
| paddy@69 | 329 if err != nil { |
| paddy@69 | 330 // TODO(paddy): render invalid request JSON |
| paddy@69 | 331 return |
| paddy@69 | 332 } |
| paddy@69 | 333 client, err := context.GetClient(clientID) |
| paddy@69 | 334 if err != nil { |
| paddy@69 | 335 if err == ErrClientNotFound { |
| paddy@69 | 336 // TODO(paddy): render invalid request JSON |
| paddy@69 | 337 } else { |
| paddy@69 | 338 // TODO(paddy): render internal server error JSON |
| paddy@69 | 339 } |
| paddy@69 | 340 return |
| paddy@69 | 341 } |
| paddy@69 | 342 if client.Secret != clientSecret { |
| paddy@69 | 343 // TODO(paddy): render invalid request JSON |
| paddy@69 | 344 return |
| paddy@69 | 345 } |
| paddy@69 | 346 grant, err := context.GetGrant(code) |
| paddy@69 | 347 if err != nil { |
| paddy@69 | 348 if err == ErrGrantNotFound { |
| paddy@69 | 349 // TODO(paddy): return error |
| paddy@69 | 350 return |
| paddy@69 | 351 } |
| paddy@69 | 352 // TODO(paddy): return error |
| paddy@69 | 353 } |
| paddy@69 | 354 if grant.RedirectURI != redirectURI { |
| paddy@69 | 355 // TODO(paddy): return error |
| paddy@69 | 356 } |
| paddy@69 | 357 if !grant.ClientID.Equal(clientID) { |
| paddy@69 | 358 // TODO(paddy): return error |
| paddy@69 | 359 } |
| paddy@69 | 360 token := Token{ |
| paddy@69 | 361 AccessToken: uuid.NewID().String(), |
| paddy@69 | 362 RefreshToken: uuid.NewID().String(), |
| paddy@69 | 363 Created: time.Now(), |
| paddy@69 | 364 ExpiresIn: defaultTokenExpiration, |
| paddy@69 | 365 TokenType: "", // TODO(paddy): fill in token type |
| paddy@69 | 366 Scope: grant.Scope, |
| paddy@69 | 367 ProfileID: grant.ProfileID, |
| paddy@69 | 368 } |
| paddy@69 | 369 err = context.SaveToken(token) |
| paddy@69 | 370 if err != nil { |
| paddy@69 | 371 // TODO(paddy): return error |
| paddy@69 | 372 } |
| paddy@69 | 373 resp := tokenResponse{ |
| paddy@69 | 374 AccessToken: token.AccessToken, |
| paddy@69 | 375 RefreshToken: token.RefreshToken, |
| paddy@69 | 376 ExpiresIn: token.ExpiresIn, |
| paddy@69 | 377 TokenType: token.TokenType, |
| paddy@69 | 378 } |
| paddy@69 | 379 err = enc.Encode(resp) |
| paddy@69 | 380 if err != nil { |
| paddy@69 | 381 // TODO(paddy): log this or something |
| paddy@69 | 382 return |
| paddy@69 | 383 } |
| paddy@69 | 384 } |
| paddy@69 | 385 |
| paddy@68 | 386 // TODO(paddy): exchange user credentials for access token |
| paddy@68 | 387 // TODO(paddy): exchange client credentials for access token |
| paddy@68 | 388 // TODO(paddy): implicit grant for access token |
| paddy@68 | 389 // TODO(paddy): exchange refresh token for access token |