auth
auth/oauth2.go
Test the authentication helper and fix bugs in authentication. Stored passphrases need to be hex-encoded, so it only makes sense to decode the stored hex string to get the bytes we compare against. Check for ErrLoginNotFound in addition to ErrProfileNotFound; ErrLoginNotFound is the more likely error anyway. Add unit tests for the authentication helper to verify it functions as expected.
| paddy@51 | 1 package auth |
| paddy@51 | 2 |
| paddy@51 | 3 import ( |
| paddy@69 | 4 "encoding/base64" |
| paddy@79 | 5 "encoding/hex" |
| paddy@69 | 6 "encoding/json" |
| paddy@69 | 7 "errors" |
| paddy@61 | 8 "html/template" |
| paddy@77 | 9 "log" |
| paddy@51 | 10 "net/http" |
| paddy@60 | 11 "net/url" |
| paddy@69 | 12 "strings" |
| paddy@60 | 13 "time" |
| paddy@77 | 14 "github.com/gorilla/mux" |
| paddy@56 | 15 |
| paddy@69 | 16 "crypto/sha256" |
| paddy@69 | 17 "code.secondbit.org/pass" |
| paddy@56 | 18 "code.secondbit.org/uuid" |
| paddy@51 | 19 ) |
| paddy@51 | 20 |
const (
	// authCookieName names the cookie whose value is looked up as a session ID.
	authCookieName = "auth"
	// getGrantTemplateName is the template rendered by GetGrantHandler.
	getGrantTemplateName = "get_grant"
	// defaultGrantExpiration is the lifetime of a new grant in seconds (ten minutes).
	defaultGrantExpiration = 10 * 60
)
| paddy@51 | 26 |
// Errors produced while parsing credentials and checking sessions.
var (
	// ErrNoSession is returned when no session ID is passed with a request.
	ErrNoSession = errors.New("no session ID found")

	// ErrNoAuth is returned when an Authorization header is not present or is empty.
	ErrNoAuth = errors.New("no authorization header supplied")
	// ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format.
	ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")

	// ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values.
	ErrIncorrectAuth = errors.New("invalid authentication")
	// ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used.
	ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")
)
| paddy@69 | 39 |
// tokenResponse is the JSON body the token endpoint writes on success
// (compare RFC 6749, Section 5.1). The declaration order of the fields is the
// order of the keys in the encoded JSON, so it is kept stable.
type tokenResponse struct {
	// AccessToken is the issued access token; always encoded.
	AccessToken string `json:"access_token"`
	// TokenType is left out of the JSON when empty.
	TokenType string `json:"token_type,omitempty"`
	// ExpiresIn is the token lifetime reported to the client (RFC 6749
	// defines it in seconds); left out of the JSON when zero.
	ExpiresIn int32 `json:"expires_in,omitempty"`
	// RefreshToken is left out of the JSON when empty.
	RefreshToken string `json:"refresh_token,omitempty"`
}
| paddy@69 | 46 |
| paddy@69 | 47 func getBasicAuth(r *http.Request) (un, pass string, err error) { |
| paddy@69 | 48 auth := r.Header.Get("Authorization") |
| paddy@69 | 49 if auth == "" { |
| paddy@69 | 50 return "", "", ErrNoAuth |
| paddy@69 | 51 } |
| paddy@69 | 52 pieces := strings.SplitN(auth, " ", 2) |
| paddy@69 | 53 if pieces[0] != "Basic" { |
| paddy@69 | 54 return "", "", ErrInvalidAuthFormat |
| paddy@69 | 55 } |
| paddy@69 | 56 decoded, err := base64.StdEncoding.DecodeString(pieces[1]) |
| paddy@69 | 57 if err != nil { |
| paddy@69 | 58 return "", "", ErrInvalidAuthFormat |
| paddy@69 | 59 } |
| paddy@69 | 60 info := strings.SplitN(string(decoded), ":", 2) |
| paddy@77 | 61 if len(info) < 2 { |
| paddy@77 | 62 return "", "", ErrInvalidAuthFormat |
| paddy@77 | 63 } |
| paddy@69 | 64 return info[0], info[1], nil |
| paddy@69 | 65 } |
| paddy@69 | 66 |
| paddy@69 | 67 func checkCookie(r *http.Request, context Context) (Session, error) { |
| paddy@69 | 68 cookie, err := r.Cookie(authCookieName) |
| paddy@77 | 69 if err == http.ErrNoCookie { |
| paddy@77 | 70 return Session{}, ErrNoSession |
| paddy@77 | 71 } else if err != nil { |
| paddy@77 | 72 log.Println(err) |
| paddy@69 | 73 return Session{}, err |
| paddy@69 | 74 } |
| paddy@69 | 75 sess, err := context.GetSession(cookie.Value) |
| paddy@69 | 76 if err == ErrSessionNotFound { |
| paddy@69 | 77 return Session{}, ErrInvalidSession |
| paddy@69 | 78 } else if err != nil { |
| paddy@69 | 79 return Session{}, err |
| paddy@69 | 80 } |
| paddy@69 | 81 if !sess.Active { |
| paddy@69 | 82 return Session{}, ErrInvalidSession |
| paddy@69 | 83 } |
| paddy@69 | 84 return sess, nil |
| paddy@69 | 85 } |
| paddy@69 | 86 |
| paddy@77 | 87 func buildLoginRedirect(r *http.Request, context Context) string { |
| paddy@77 | 88 if context.loginURI == nil { |
| paddy@77 | 89 return "" |
| paddy@77 | 90 } |
| paddy@77 | 91 uri := *context.loginURI |
| paddy@77 | 92 q := uri.Query() |
| paddy@78 | 93 q.Set("from", r.URL.String()) |
| paddy@77 | 94 uri.RawQuery = q.Encode() |
| paddy@77 | 95 return uri.String() |
| paddy@77 | 96 } |
| paddy@77 | 97 |
| paddy@69 | 98 func authenticate(user, passphrase string, context Context) (Profile, error) { |
| paddy@69 | 99 profile, err := context.GetProfileByLogin(user) |
| paddy@69 | 100 if err != nil { |
| paddy@79 | 101 if err == ErrProfileNotFound || err == ErrLoginNotFound { |
| paddy@69 | 102 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 103 } |
| paddy@69 | 104 return Profile{}, err |
| paddy@69 | 105 } |
| paddy@69 | 106 switch profile.PassphraseScheme { |
| paddy@69 | 107 case 1: |
| paddy@79 | 108 realPass, err := hex.DecodeString(profile.Passphrase) |
| paddy@79 | 109 if err != nil { |
| paddy@79 | 110 return Profile{}, err |
| paddy@79 | 111 } |
| paddy@69 | 112 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt)) |
| paddy@79 | 113 if !pass.Compare(candidate, realPass) { |
| paddy@69 | 114 return Profile{}, ErrIncorrectAuth |
| paddy@69 | 115 } |
| paddy@69 | 116 default: |
| paddy@69 | 117 return Profile{}, ErrInvalidPassphraseScheme |
| paddy@69 | 118 } |
| paddy@69 | 119 return profile, nil |
| paddy@69 | 120 } |
| paddy@69 | 121 |
| paddy@77 | 122 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler { |
| paddy@77 | 123 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| paddy@77 | 124 f(w, r, context) |
| paddy@77 | 125 }) |
| paddy@77 | 126 } |
| paddy@77 | 127 |
| paddy@77 | 128 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints. |
| paddy@77 | 129 func RegisterOAuth2(r *mux.Router, context Context) { |
| paddy@77 | 130 r.Handle("/authorize", wrap(context, GetGrantHandler)) |
| paddy@77 | 131 r.Handle("/token", wrap(context, GetTokenHandler)) |
| paddy@77 | 132 } |
| paddy@77 | 133 |
| paddy@57 | 134 // GetGrantHandler presents and processes the page for asking a user to grant access |
| paddy@57 | 135 // to their data. See RFC 6749, Section 4.1. |
| paddy@51 | 136 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 137 session, err := checkCookie(r, context) |
| paddy@69 | 138 if err != nil { |
| paddy@76 | 139 if err == ErrNoSession || err == ErrInvalidSession { |
| paddy@77 | 140 redir := buildLoginRedirect(r, context) |
| paddy@77 | 141 if redir == "" { |
| paddy@77 | 142 log.Println("No login URL configured.") |
| paddy@77 | 143 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 144 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 145 "internal_error": template.HTML("Missing login URL."), |
| paddy@77 | 146 }) |
| paddy@77 | 147 return |
| paddy@77 | 148 } |
| paddy@77 | 149 http.Redirect(w, r, redir, http.StatusFound) |
| paddy@77 | 150 return |
| paddy@69 | 151 } |
| paddy@77 | 152 log.Println(err.Error()) |
| paddy@77 | 153 w.WriteHeader(http.StatusInternalServerError) |
| paddy@77 | 154 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@77 | 155 "internal_error": template.HTML(err.Error()), |
| paddy@77 | 156 }) |
| paddy@77 | 157 return |
| paddy@69 | 158 } |
| paddy@56 | 159 if r.URL.Query().Get("client_id") == "" { |
| paddy@56 | 160 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 161 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 162 "error": template.HTML("Client ID must be specified in the request."), |
| paddy@56 | 163 }) |
| paddy@56 | 164 return |
| paddy@56 | 165 } |
| paddy@56 | 166 clientID, err := uuid.Parse(r.URL.Query().Get("client_id")) |
| paddy@56 | 167 if err != nil { |
| paddy@56 | 168 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 169 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 170 "error": template.HTML("client_id is not a valid Client ID."), |
| paddy@56 | 171 }) |
| paddy@56 | 172 return |
| paddy@56 | 173 } |
| paddy@64 | 174 redirectURI := r.URL.Query().Get("redirect_uri") |
| paddy@64 | 175 redirectURL, err := url.Parse(redirectURI) |
| paddy@64 | 176 if err != nil { |
| paddy@64 | 177 w.WriteHeader(http.StatusBadRequest) |
| paddy@64 | 178 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@64 | 179 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@64 | 180 }) |
| paddy@64 | 181 return |
| paddy@64 | 182 } |
| paddy@56 | 183 client, err := context.GetClient(clientID) |
| paddy@56 | 184 if err != nil { |
| paddy@59 | 185 if err == ErrClientNotFound { |
| paddy@59 | 186 w.WriteHeader(http.StatusBadRequest) |
| paddy@59 | 187 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 188 "error": template.HTML("The specified Client couldn’t be found."), |
| paddy@59 | 189 }) |
| paddy@59 | 190 } else { |
| paddy@77 | 191 log.Println(err.Error()) |
| paddy@59 | 192 w.WriteHeader(http.StatusInternalServerError) |
| paddy@59 | 193 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 194 "internal_error": template.HTML(err.Error()), |
| paddy@59 | 195 }) |
| paddy@59 | 196 } |
| paddy@56 | 197 return |
| paddy@56 | 198 } |
| paddy@56 | 199 // whether a redirect URI is valid or not depends on the number of endpoints |
| paddy@56 | 200 // the client has registered |
| paddy@56 | 201 numEndpoints, err := context.CountEndpoints(clientID) |
| paddy@56 | 202 if err != nil { |
| paddy@77 | 203 log.Println(err.Error()) |
| paddy@56 | 204 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 205 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 206 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 207 }) |
| paddy@56 | 208 return |
| paddy@56 | 209 } |
| paddy@56 | 210 var validURI bool |
| paddy@58 | 211 if redirectURI != "" { |
| paddy@58 | 212 // BUG(paddy): We really should normalize URIs before trying to compare them. |
| paddy@58 | 213 validURI, err = context.CheckEndpoint(clientID, redirectURI) |
| paddy@56 | 214 if err != nil { |
| paddy@77 | 215 log.Println(err.Error()) |
| paddy@56 | 216 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 217 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 218 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 219 }) |
| paddy@56 | 220 return |
| paddy@56 | 221 } |
| paddy@56 | 222 } else if redirectURI == "" && numEndpoints == 1 { |
| paddy@56 | 223 // if we don't specify the endpoint and there's only one endpoint, the |
| paddy@56 | 224 // request is valid, and we're redirecting to that one endpoint |
| paddy@56 | 225 validURI = true |
| paddy@56 | 226 endpoints, err := context.ListEndpoints(clientID, 1, 0) |
| paddy@56 | 227 if err != nil { |
| paddy@77 | 228 log.Println(err.Error()) |
| paddy@56 | 229 w.WriteHeader(http.StatusInternalServerError) |
| paddy@56 | 230 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 231 "internal_error": template.HTML(err.Error()), |
| paddy@56 | 232 }) |
| paddy@56 | 233 return |
| paddy@56 | 234 } |
| paddy@56 | 235 if len(endpoints) != 1 { |
| paddy@56 | 236 validURI = false |
| paddy@56 | 237 } else { |
| paddy@66 | 238 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore |
| paddy@66 | 239 redirectURI = u.String() |
| paddy@66 | 240 redirectURL = &u |
| paddy@56 | 241 } |
| paddy@56 | 242 } else { |
| paddy@56 | 243 validURI = false |
| paddy@56 | 244 } |
| paddy@56 | 245 if !validURI { |
| paddy@56 | 246 w.WriteHeader(http.StatusBadRequest) |
| paddy@56 | 247 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@61 | 248 "error": template.HTML("The redirect_uri specified is not valid."), |
| paddy@56 | 249 }) |
| paddy@56 | 250 return |
| paddy@56 | 251 } |
| paddy@60 | 252 scope := r.URL.Query().Get("scope") |
| paddy@60 | 253 state := r.URL.Query().Get("state") |
| paddy@56 | 254 if r.URL.Query().Get("response_type") != "code" { |
| paddy@65 | 255 q := redirectURL.Query() |
| paddy@65 | 256 q.Add("error", "invalid_request") |
| paddy@65 | 257 q.Add("state", state) |
| paddy@65 | 258 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 259 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 260 return |
| paddy@56 | 261 } |
| paddy@56 | 262 if r.Method == "POST" { |
| paddy@63 | 263 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code. |
| paddy@56 | 264 if r.PostFormValue("grant") == "approved" { |
| paddy@60 | 265 code := uuid.NewID().String() |
| paddy@60 | 266 grant := Grant{ |
| paddy@60 | 267 Code: code, |
| paddy@60 | 268 Created: time.Now(), |
| paddy@60 | 269 ExpiresIn: defaultGrantExpiration, |
| paddy@60 | 270 ClientID: clientID, |
| paddy@60 | 271 Scope: scope, |
| paddy@69 | 272 RedirectURI: r.URL.Query().Get("redirect_uri"), |
| paddy@60 | 273 State: state, |
| paddy@69 | 274 ProfileID: session.ProfileID, |
| paddy@60 | 275 } |
| paddy@60 | 276 err := context.SaveGrant(grant) |
| paddy@60 | 277 if err != nil { |
| paddy@66 | 278 q := redirectURL.Query() |
| paddy@66 | 279 q.Add("error", "server_error") |
| paddy@66 | 280 q.Add("state", state) |
| paddy@66 | 281 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 282 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 283 return |
| paddy@60 | 284 } |
| paddy@66 | 285 q := redirectURL.Query() |
| paddy@66 | 286 q.Add("code", code) |
| paddy@66 | 287 q.Add("state", state) |
| paddy@66 | 288 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 289 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 290 return |
| paddy@56 | 291 } |
| paddy@66 | 292 q := redirectURL.Query() |
| paddy@66 | 293 q.Add("error", "access_denied") |
| paddy@66 | 294 q.Add("state", state) |
| paddy@66 | 295 redirectURL.RawQuery = q.Encode() |
| paddy@60 | 296 http.Redirect(w, r, redirectURL.String(), http.StatusFound) |
| paddy@60 | 297 return |
| paddy@56 | 298 } |
| paddy@51 | 299 w.WriteHeader(http.StatusOK) |
| paddy@56 | 300 context.Render(w, getGrantTemplateName, map[string]interface{}{ |
| paddy@56 | 301 "client": client, |
| paddy@56 | 302 }) |
| paddy@51 | 303 } |
| paddy@68 | 304 |
| paddy@69 | 305 // GetTokenHandler allows a client to exchange an authorization grant for an |
| paddy@69 | 306 // access token. See RFC 6749 Section 4.1.3. |
| paddy@69 | 307 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) { |
| paddy@69 | 308 enc := json.NewEncoder(w) |
| paddy@69 | 309 grantType := r.PostFormValue("grant_type") |
| paddy@69 | 310 if grantType != "authorization_code" { |
| paddy@69 | 311 // TODO(paddy): render invalid request JSON |
| paddy@69 | 312 return |
| paddy@69 | 313 } |
| paddy@69 | 314 code := r.PostFormValue("code") |
| paddy@69 | 315 if code == "" { |
| paddy@69 | 316 // TODO(paddy): render invalid request JSON |
| paddy@69 | 317 return |
| paddy@69 | 318 } |
| paddy@69 | 319 redirectURI := r.PostFormValue("redirect_uri") |
| paddy@69 | 320 clientIDStr, clientSecret, err := getBasicAuth(r) |
| paddy@69 | 321 if err != nil { |
| paddy@69 | 322 // TODO(paddy): render access denied |
| paddy@69 | 323 return |
| paddy@69 | 324 } |
| paddy@69 | 325 if clientIDStr == "" && err == nil { |
| paddy@69 | 326 clientIDStr = r.PostFormValue("client_id") |
| paddy@69 | 327 } |
| paddy@69 | 328 clientID, err := uuid.Parse(clientIDStr) |
| paddy@69 | 329 if err != nil { |
| paddy@69 | 330 // TODO(paddy): render invalid request JSON |
| paddy@69 | 331 return |
| paddy@69 | 332 } |
| paddy@69 | 333 client, err := context.GetClient(clientID) |
| paddy@69 | 334 if err != nil { |
| paddy@69 | 335 if err == ErrClientNotFound { |
| paddy@69 | 336 // TODO(paddy): render invalid request JSON |
| paddy@69 | 337 } else { |
| paddy@69 | 338 // TODO(paddy): render internal server error JSON |
| paddy@69 | 339 } |
| paddy@69 | 340 return |
| paddy@69 | 341 } |
| paddy@69 | 342 if client.Secret != clientSecret { |
| paddy@69 | 343 // TODO(paddy): render invalid request JSON |
| paddy@69 | 344 return |
| paddy@69 | 345 } |
| paddy@69 | 346 grant, err := context.GetGrant(code) |
| paddy@69 | 347 if err != nil { |
| paddy@69 | 348 if err == ErrGrantNotFound { |
| paddy@69 | 349 // TODO(paddy): return error |
| paddy@69 | 350 return |
| paddy@69 | 351 } |
| paddy@69 | 352 // TODO(paddy): return error |
| paddy@69 | 353 } |
| paddy@69 | 354 if grant.RedirectURI != redirectURI { |
| paddy@69 | 355 // TODO(paddy): return error |
| paddy@69 | 356 } |
| paddy@69 | 357 if !grant.ClientID.Equal(clientID) { |
| paddy@69 | 358 // TODO(paddy): return error |
| paddy@69 | 359 } |
| paddy@69 | 360 token := Token{ |
| paddy@69 | 361 AccessToken: uuid.NewID().String(), |
| paddy@69 | 362 RefreshToken: uuid.NewID().String(), |
| paddy@69 | 363 Created: time.Now(), |
| paddy@69 | 364 ExpiresIn: defaultTokenExpiration, |
| paddy@69 | 365 TokenType: "", // TODO(paddy): fill in token type |
| paddy@69 | 366 Scope: grant.Scope, |
| paddy@69 | 367 ProfileID: grant.ProfileID, |
| paddy@69 | 368 } |
| paddy@69 | 369 err = context.SaveToken(token) |
| paddy@69 | 370 if err != nil { |
| paddy@69 | 371 // TODO(paddy): return error |
| paddy@69 | 372 } |
| paddy@69 | 373 resp := tokenResponse{ |
| paddy@69 | 374 AccessToken: token.AccessToken, |
| paddy@69 | 375 RefreshToken: token.RefreshToken, |
| paddy@69 | 376 ExpiresIn: token.ExpiresIn, |
| paddy@69 | 377 TokenType: token.TokenType, |
| paddy@69 | 378 } |
| paddy@69 | 379 err = enc.Encode(resp) |
| paddy@69 | 380 if err != nil { |
| paddy@69 | 381 // TODO(paddy): log this or something |
| paddy@69 | 382 return |
| paddy@69 | 383 } |
| paddy@69 | 384 } |
| paddy@69 | 385 |
| paddy@68 | 386 // TODO(paddy): exchange user credentials for access token |
| paddy@68 | 387 // TODO(paddy): exchange client credentials for access token |
| paddy@68 | 388 // TODO(paddy): implicit grant for access token |
| paddy@68 | 389 // TODO(paddy): exchange refresh token for access token |