auth
auth/oauth2.go
Rename http.go. We're going to have a lot of HTTP handlers, and I'd rather make it clear that this is taking care of our OAuth2 HTTP logic. So rename the file, and we'll put the API handlers in their files, or something.
1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000 1.2 +++ b/oauth2.go Tue Nov 11 21:22:57 2014 -0500 1.3 @@ -0,0 +1,347 @@ 1.4 +package auth 1.5 + 1.6 +import ( 1.7 + "encoding/base64" 1.8 + "encoding/json" 1.9 + "errors" 1.10 + "html/template" 1.11 + "net/http" 1.12 + "net/url" 1.13 + "strings" 1.14 + "time" 1.15 + 1.16 + "crypto/sha256" 1.17 + "code.secondbit.org/pass" 1.18 + "code.secondbit.org/uuid" 1.19 +) 1.20 + 1.21 +const ( 1.22 + authCookieName = "auth" 1.23 + defaultGrantExpiration = 600 // default to ten minute grant expirations 1.24 + getGrantTemplateName = "get_grant" 1.25 +) 1.26 + 1.27 +var ( 1.28 + // ErrNoAuth is returned when an Authorization header is not present or is empty. 1.29 + ErrNoAuth = errors.New("no authorization header supplied") 1.30 + // ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format. 1.31 + ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format") 1.32 + // ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values. 1.33 + ErrIncorrectAuth = errors.New("invalid authentication") 1.34 + // ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used. 1.35 + ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme") 1.36 + // ErrNoSession is returned when no session ID is passed with a request. 1.37 + ErrNoSession = errors.New("no session ID found") 1.38 +) 1.39 + 1.40 +type tokenResponse struct { 1.41 + AccessToken string `json:"access_token"` 1.42 + TokenType string `json:"token_type,omitempty"` 1.43 + ExpiresIn int32 `json:"expires_in,omitempty"` 1.44 + RefreshToken string `json:"refresh_token,omitempty"` 1.45 +} 1.46 + 1.47 +func getBasicAuth(r *http.Request) (un, pass string, err error) { 1.48 + auth := r.Header.Get("Authorization") 1.49 + if auth == "" { 1.50 + return "", "", ErrNoAuth 1.51 + } 1.52 + pieces := strings.SplitN(auth, " ", 2) 1.53 + if pieces[0] != "Basic" { 1.54 + return "", "", ErrInvalidAuthFormat 1.55 + } 1.56 + decoded, err := base64.StdEncoding.DecodeString(pieces[1]) 1.57 + if err != nil { 1.58 + return "", "", ErrInvalidAuthFormat 1.59 + } 1.60 + info := strings.SplitN(string(decoded), ":", 2) 1.61 + return info[0], info[1], nil 1.62 +} 1.63 + 1.64 +func checkCookie(r *http.Request, context Context) (Session, error) { 1.65 + cookie, err := r.Cookie(authCookieName) 1.66 + if err != nil { 1.67 + if err == http.ErrNoCookie { 1.68 + return Session{}, ErrNoSession 1.69 + } 1.70 + return Session{}, err 1.71 + } 1.72 + if cookie.Name != authCookieName || !cookie.Expires.After(time.Now()) || 1.73 + !cookie.Secure || !cookie.HttpOnly { 1.74 + return Session{}, ErrInvalidSession 1.75 + } 1.76 + sess, err := context.GetSession(cookie.Value) 1.77 + if err == ErrSessionNotFound { 1.78 + return Session{}, ErrInvalidSession 1.79 + } else if err != nil { 1.80 + return Session{}, err 1.81 + } 1.82 + if !sess.Active { 1.83 + return Session{}, ErrInvalidSession 1.84 + } 1.85 + return sess, nil 1.86 +} 1.87 + 1.88 +func authenticate(user, passphrase string, context Context) (Profile, error) { 1.89 + profile, err := context.GetProfileByLogin(user) 1.90 + if err != nil { 1.91 + if err == ErrProfileNotFound { 1.92 + return Profile{}, ErrIncorrectAuth 1.93 + } 1.94 + return Profile{}, err 1.95 + } 1.96 + switch profile.PassphraseScheme { 1.97 + case 1: 1.98 + candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt)) 1.99 + if !pass.Compare(candidate, []byte(profile.Passphrase)) { 1.100 + return Profile{}, ErrIncorrectAuth 1.101 + } 1.102 + default: 1.103 + return Profile{}, ErrInvalidPassphraseScheme 1.104 + } 1.105 + return profile, nil 1.106 +} 1.107 + 1.108 +// GetGrantHandler presents and processes the page for asking a user to grant access 1.109 +// to their data. See RFC 6749, Section 4.1. 1.110 +func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) { 1.111 + session, err := checkCookie(r, context) 1.112 + if err != nil { 1.113 + if err == ErrNoSession { 1.114 + // TODO(paddy): redirect to login screen 1.115 + //return 1.116 + } 1.117 + if err == ErrInvalidSession { 1.118 + // TODO(paddy): return an access denied error 1.119 + //return 1.120 + } 1.121 + // TODO(paddy): return a server error 1.122 + //return 1.123 + } 1.124 + if r.URL.Query().Get("client_id") == "" { 1.125 + w.WriteHeader(http.StatusBadRequest) 1.126 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.127 + "error": template.HTML("Client ID must be specified in the request."), 1.128 + }) 1.129 + return 1.130 + } 1.131 + clientID, err := uuid.Parse(r.URL.Query().Get("client_id")) 1.132 + if err != nil { 1.133 + w.WriteHeader(http.StatusBadRequest) 1.134 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.135 + "error": template.HTML("client_id is not a valid Client ID."), 1.136 + }) 1.137 + return 1.138 + } 1.139 + redirectURI := r.URL.Query().Get("redirect_uri") 1.140 + redirectURL, err := url.Parse(redirectURI) 1.141 + if err != nil { 1.142 + w.WriteHeader(http.StatusBadRequest) 1.143 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.144 + "error": template.HTML("The redirect_uri specified is not valid."), 1.145 + }) 1.146 + return 1.147 + } 1.148 + client, err := context.GetClient(clientID) 1.149 + if err != nil { 1.150 + if err == ErrClientNotFound { 1.151 + w.WriteHeader(http.StatusBadRequest) 1.152 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.153 + "error": template.HTML("The specified Client couldn’t be found."), 1.154 + }) 1.155 + } else { 1.156 + w.WriteHeader(http.StatusInternalServerError) 1.157 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.158 + "internal_error": template.HTML(err.Error()), 1.159 + }) 1.160 + } 1.161 + return 1.162 + } 1.163 + // whether a redirect URI is valid or not depends on the number of endpoints 1.164 + // the client has registered 1.165 + numEndpoints, err := context.CountEndpoints(clientID) 1.166 + if err != nil { 1.167 + w.WriteHeader(http.StatusInternalServerError) 1.168 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.169 + "internal_error": template.HTML(err.Error()), 1.170 + }) 1.171 + return 1.172 + } 1.173 + var validURI bool 1.174 + if redirectURI != "" { 1.175 + // BUG(paddy): We really should normalize URIs before trying to compare them. 1.176 + validURI, err = context.CheckEndpoint(clientID, redirectURI) 1.177 + if err != nil { 1.178 + w.WriteHeader(http.StatusInternalServerError) 1.179 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.180 + "internal_error": template.HTML(err.Error()), 1.181 + }) 1.182 + return 1.183 + } 1.184 + } else if redirectURI == "" && numEndpoints == 1 { 1.185 + // if we don't specify the endpoint and there's only one endpoint, the 1.186 + // request is valid, and we're redirecting to that one endpoint 1.187 + validURI = true 1.188 + endpoints, err := context.ListEndpoints(clientID, 1, 0) 1.189 + if err != nil { 1.190 + w.WriteHeader(http.StatusInternalServerError) 1.191 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.192 + "internal_error": template.HTML(err.Error()), 1.193 + }) 1.194 + return 1.195 + } 1.196 + if len(endpoints) != 1 { 1.197 + validURI = false 1.198 + } else { 1.199 + u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore 1.200 + redirectURI = u.String() 1.201 + redirectURL = &u 1.202 + } 1.203 + } else { 1.204 + validURI = false 1.205 + } 1.206 + if !validURI { 1.207 + w.WriteHeader(http.StatusBadRequest) 1.208 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.209 + "error": template.HTML("The redirect_uri specified is not valid."), 1.210 + }) 1.211 + return 1.212 + } 1.213 + scope := r.URL.Query().Get("scope") 1.214 + state := r.URL.Query().Get("state") 1.215 + if r.URL.Query().Get("response_type") != "code" { 1.216 + q := redirectURL.Query() 1.217 + q.Add("error", "invalid_request") 1.218 + q.Add("state", state) 1.219 + redirectURL.RawQuery = q.Encode() 1.220 + http.Redirect(w, r, redirectURL.String(), http.StatusFound) 1.221 + return 1.222 + } 1.223 + if r.Method == "POST" { 1.224 + // BUG(paddy): We need to implement CSRF protection when obtaining a grant code. 1.225 + if r.PostFormValue("grant") == "approved" { 1.226 + code := uuid.NewID().String() 1.227 + grant := Grant{ 1.228 + Code: code, 1.229 + Created: time.Now(), 1.230 + ExpiresIn: defaultGrantExpiration, 1.231 + ClientID: clientID, 1.232 + Scope: scope, 1.233 + RedirectURI: r.URL.Query().Get("redirect_uri"), 1.234 + State: state, 1.235 + ProfileID: session.ProfileID, 1.236 + } 1.237 + err := context.SaveGrant(grant) 1.238 + if err != nil { 1.239 + q := redirectURL.Query() 1.240 + q.Add("error", "server_error") 1.241 + q.Add("state", state) 1.242 + redirectURL.RawQuery = q.Encode() 1.243 + http.Redirect(w, r, redirectURL.String(), http.StatusFound) 1.244 + return 1.245 + } 1.246 + q := redirectURL.Query() 1.247 + q.Add("code", code) 1.248 + q.Add("state", state) 1.249 + redirectURL.RawQuery = q.Encode() 1.250 + http.Redirect(w, r, redirectURL.String(), http.StatusFound) 1.251 + return 1.252 + } 1.253 + q := redirectURL.Query() 1.254 + q.Add("error", "access_denied") 1.255 + q.Add("state", state) 1.256 + redirectURL.RawQuery = q.Encode() 1.257 + http.Redirect(w, r, redirectURL.String(), http.StatusFound) 1.258 + return 1.259 + } 1.260 + w.WriteHeader(http.StatusOK) 1.261 + context.Render(w, getGrantTemplateName, map[string]interface{}{ 1.262 + "client": client, 1.263 + }) 1.264 +} 1.265 + 1.266 +// GetTokenHandler allows a client to exchange an authorization grant for an 1.267 +// access token. See RFC 6749 Section 4.1.3. 1.268 +func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) { 1.269 + enc := json.NewEncoder(w) 1.270 + grantType := r.PostFormValue("grant_type") 1.271 + if grantType != "authorization_code" { 1.272 + // TODO(paddy): render invalid request JSON 1.273 + return 1.274 + } 1.275 + code := r.PostFormValue("code") 1.276 + if code == "" { 1.277 + // TODO(paddy): render invalid request JSON 1.278 + return 1.279 + } 1.280 + redirectURI := r.PostFormValue("redirect_uri") 1.281 + clientIDStr, clientSecret, err := getBasicAuth(r) 1.282 + if err != nil { 1.283 + // TODO(paddy): render access denied 1.284 + return 1.285 + } 1.286 + if clientIDStr == "" && err == nil { 1.287 + clientIDStr = r.PostFormValue("client_id") 1.288 + } 1.289 + clientID, err := uuid.Parse(clientIDStr) 1.290 + if err != nil { 1.291 + // TODO(paddy): render invalid request JSON 1.292 + return 1.293 + } 1.294 + client, err := context.GetClient(clientID) 1.295 + if err != nil { 1.296 + if err == ErrClientNotFound { 1.297 + // TODO(paddy): render invalid request JSON 1.298 + } else { 1.299 + // TODO(paddy): render internal server error JSON 1.300 + } 1.301 + return 1.302 + } 1.303 + if client.Secret != clientSecret { 1.304 + // TODO(paddy): render invalid request JSON 1.305 + return 1.306 + } 1.307 + grant, err := context.GetGrant(code) 1.308 + if err != nil { 1.309 + if err == ErrGrantNotFound { 1.310 + // TODO(paddy): return error 1.311 + return 1.312 + } 1.313 + // TODO(paddy): return error 1.314 + } 1.315 + if grant.RedirectURI != redirectURI { 1.316 + // TODO(paddy): return error 1.317 + } 1.318 + if !grant.ClientID.Equal(clientID) { 1.319 + // TODO(paddy): return error 1.320 + } 1.321 + token := Token{ 1.322 + AccessToken: uuid.NewID().String(), 1.323 + RefreshToken: uuid.NewID().String(), 1.324 + Created: time.Now(), 1.325 + ExpiresIn: defaultTokenExpiration, 1.326 + TokenType: "", // TODO(paddy): fill in token type 1.327 + Scope: grant.Scope, 1.328 + ProfileID: grant.ProfileID, 1.329 + } 1.330 + err = context.SaveToken(token) 1.331 + if err != nil { 1.332 + // TODO(paddy): return error 1.333 + } 1.334 + resp := tokenResponse{ 1.335 + AccessToken: token.AccessToken, 1.336 + RefreshToken: token.RefreshToken, 1.337 + ExpiresIn: token.ExpiresIn, 1.338 + TokenType: token.TokenType, 1.339 + } 1.340 + err = enc.Encode(resp) 1.341 + if err != nil { 1.342 + // TODO(paddy): log this or something 1.343 + return 1.344 + } 1.345 +} 1.346 + 1.347 +// TODO(paddy): exchange user credentials for access token 1.348 +// TODO(paddy): exchange client credentials for access token 1.349 +// TODO(paddy): implicit grant for access token 1.350 +// TODO(paddy): exchange refresh token for access token