Write Session tests.
Add loginURI as a property to our Context, to keep track of where users should
be redirected to log in.
Implement the sessionStore interface in the memstore so we can test with Sessions.
Return an ErrInvalidAuthFormat, rather than panicking, when the HTTP Basic Auth
header doesn't include two parts.
Simplify the error handling in checkCookie. Log unexpected errors from
request.Cookie.
Stop checking for cookie expiration times--those aren't sent to the server, so
we'll never get a valid session if we look for them.
Add a helper to build a login redirect URI--the login page's URI, carrying a
URL-encoded "from" parameter that points back to the page the user should be
returned to after a successful login.
Add a wrap helper that binds our Context into HTTP handlers.
Create a RegisterOAuth2 helper that adds the OAuth2 endpoints to a gorilla/mux
router.
Redirect users to the login page when they have no session or an invalid
one.
Return a server error when we can't check our cookie for any other reason.
Log errors.
Add sessions to our OAuth2 tests so the tests stop failing--the session check
was interfering with them.
Add a test for our getBasicAuth helper to ensure that we're parsing basic auth
correctly.
Add an ErrSessionAlreadyExists error to be returned when a session has an ID
conflict.
Test that our memstore implementation of the sessionStore works as intended.
13 "github.com/gorilla/mux"
16 "code.secondbit.org/pass"
17 "code.secondbit.org/uuid"
21 authCookieName = "auth"
22 defaultGrantExpiration = 600 // default to ten minute grant expirations
23 getGrantTemplateName = "get_grant"
27 // ErrNoAuth is returned when an Authorization header is not present or is empty.
28 ErrNoAuth = errors.New("no authorization header supplied")
29 // ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format.
30 ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")
31 // ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values.
32 ErrIncorrectAuth = errors.New("invalid authentication")
33 // ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used.
34 ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")
35 // ErrNoSession is returned when no session ID is passed with a request.
36 ErrNoSession = errors.New("no session ID found")
39 type tokenResponse struct {
40 AccessToken string `json:"access_token"`
41 TokenType string `json:"token_type,omitempty"`
42 ExpiresIn int32 `json:"expires_in,omitempty"`
43 RefreshToken string `json:"refresh_token,omitempty"`
46 func getBasicAuth(r *http.Request) (un, pass string, err error) {
47 auth := r.Header.Get("Authorization")
49 return "", "", ErrNoAuth
51 pieces := strings.SplitN(auth, " ", 2)
52 if pieces[0] != "Basic" {
53 return "", "", ErrInvalidAuthFormat
55 decoded, err := base64.StdEncoding.DecodeString(pieces[1])
57 return "", "", ErrInvalidAuthFormat
59 info := strings.SplitN(string(decoded), ":", 2)
61 return "", "", ErrInvalidAuthFormat
63 return info[0], info[1], nil
66 func checkCookie(r *http.Request, context Context) (Session, error) {
67 cookie, err := r.Cookie(authCookieName)
68 if err == http.ErrNoCookie {
69 return Session{}, ErrNoSession
70 } else if err != nil {
74 sess, err := context.GetSession(cookie.Value)
75 if err == ErrSessionNotFound {
76 return Session{}, ErrInvalidSession
77 } else if err != nil {
81 return Session{}, ErrInvalidSession
86 func buildLoginRedirect(r *http.Request, context Context) string {
87 if context.loginURI == nil {
90 uri := *context.loginURI
92 q.Set("from", url.QueryEscape(r.URL.String()))
93 uri.RawQuery = q.Encode()
97 func authenticate(user, passphrase string, context Context) (Profile, error) {
98 profile, err := context.GetProfileByLogin(user)
100 if err == ErrProfileNotFound {
101 return Profile{}, ErrIncorrectAuth
103 return Profile{}, err
105 switch profile.PassphraseScheme {
107 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt))
108 if !pass.Compare(candidate, []byte(profile.Passphrase)) {
109 return Profile{}, ErrIncorrectAuth
112 return Profile{}, ErrInvalidPassphraseScheme
117 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler {
118 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
123 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints.
124 func RegisterOAuth2(r *mux.Router, context Context) {
125 r.Handle("/authorize", wrap(context, GetGrantHandler))
126 r.Handle("/token", wrap(context, GetTokenHandler))
129 // GetGrantHandler presents and processes the page for asking a user to grant access
130 // to their data. See RFC 6749, Section 4.1.
131 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) {
132 session, err := checkCookie(r, context)
134 if err == ErrNoSession || err == ErrInvalidSession {
135 redir := buildLoginRedirect(r, context)
137 log.Println("No login URL configured.")
138 w.WriteHeader(http.StatusInternalServerError)
139 context.Render(w, getGrantTemplateName, map[string]interface{}{
140 "internal_error": template.HTML("Missing login URL."),
144 http.Redirect(w, r, redir, http.StatusFound)
147 log.Println(err.Error())
148 w.WriteHeader(http.StatusInternalServerError)
149 context.Render(w, getGrantTemplateName, map[string]interface{}{
150 "internal_error": template.HTML(err.Error()),
154 if r.URL.Query().Get("client_id") == "" {
155 w.WriteHeader(http.StatusBadRequest)
156 context.Render(w, getGrantTemplateName, map[string]interface{}{
157 "error": template.HTML("Client ID must be specified in the request."),
161 clientID, err := uuid.Parse(r.URL.Query().Get("client_id"))
163 w.WriteHeader(http.StatusBadRequest)
164 context.Render(w, getGrantTemplateName, map[string]interface{}{
165 "error": template.HTML("client_id is not a valid Client ID."),
169 redirectURI := r.URL.Query().Get("redirect_uri")
170 redirectURL, err := url.Parse(redirectURI)
172 w.WriteHeader(http.StatusBadRequest)
173 context.Render(w, getGrantTemplateName, map[string]interface{}{
174 "error": template.HTML("The redirect_uri specified is not valid."),
178 client, err := context.GetClient(clientID)
180 if err == ErrClientNotFound {
181 w.WriteHeader(http.StatusBadRequest)
182 context.Render(w, getGrantTemplateName, map[string]interface{}{
183 "error": template.HTML("The specified Client couldn’t be found."),
186 log.Println(err.Error())
187 w.WriteHeader(http.StatusInternalServerError)
188 context.Render(w, getGrantTemplateName, map[string]interface{}{
189 "internal_error": template.HTML(err.Error()),
194 // whether a redirect URI is valid or not depends on the number of endpoints
195 // the client has registered
196 numEndpoints, err := context.CountEndpoints(clientID)
198 log.Println(err.Error())
199 w.WriteHeader(http.StatusInternalServerError)
200 context.Render(w, getGrantTemplateName, map[string]interface{}{
201 "internal_error": template.HTML(err.Error()),
206 if redirectURI != "" {
207 // BUG(paddy): We really should normalize URIs before trying to compare them.
208 validURI, err = context.CheckEndpoint(clientID, redirectURI)
210 log.Println(err.Error())
211 w.WriteHeader(http.StatusInternalServerError)
212 context.Render(w, getGrantTemplateName, map[string]interface{}{
213 "internal_error": template.HTML(err.Error()),
217 } else if redirectURI == "" && numEndpoints == 1 {
218 // if we don't specify the endpoint and there's only one endpoint, the
219 // request is valid, and we're redirecting to that one endpoint
221 endpoints, err := context.ListEndpoints(clientID, 1, 0)
223 log.Println(err.Error())
224 w.WriteHeader(http.StatusInternalServerError)
225 context.Render(w, getGrantTemplateName, map[string]interface{}{
226 "internal_error": template.HTML(err.Error()),
230 if len(endpoints) != 1 {
233 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore
234 redirectURI = u.String()
241 w.WriteHeader(http.StatusBadRequest)
242 context.Render(w, getGrantTemplateName, map[string]interface{}{
243 "error": template.HTML("The redirect_uri specified is not valid."),
247 scope := r.URL.Query().Get("scope")
248 state := r.URL.Query().Get("state")
249 if r.URL.Query().Get("response_type") != "code" {
250 q := redirectURL.Query()
251 q.Add("error", "invalid_request")
252 q.Add("state", state)
253 redirectURL.RawQuery = q.Encode()
254 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
257 if r.Method == "POST" {
258 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code.
259 if r.PostFormValue("grant") == "approved" {
260 code := uuid.NewID().String()
264 ExpiresIn: defaultGrantExpiration,
267 RedirectURI: r.URL.Query().Get("redirect_uri"),
269 ProfileID: session.ProfileID,
271 err := context.SaveGrant(grant)
273 q := redirectURL.Query()
274 q.Add("error", "server_error")
275 q.Add("state", state)
276 redirectURL.RawQuery = q.Encode()
277 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
280 q := redirectURL.Query()
282 q.Add("state", state)
283 redirectURL.RawQuery = q.Encode()
284 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
287 q := redirectURL.Query()
288 q.Add("error", "access_denied")
289 q.Add("state", state)
290 redirectURL.RawQuery = q.Encode()
291 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
294 w.WriteHeader(http.StatusOK)
295 context.Render(w, getGrantTemplateName, map[string]interface{}{
300 // GetTokenHandler allows a client to exchange an authorization grant for an
301 // access token. See RFC 6749 Section 4.1.3.
302 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) {
303 enc := json.NewEncoder(w)
304 grantType := r.PostFormValue("grant_type")
305 if grantType != "authorization_code" {
306 // TODO(paddy): render invalid request JSON
309 code := r.PostFormValue("code")
311 // TODO(paddy): render invalid request JSON
314 redirectURI := r.PostFormValue("redirect_uri")
315 clientIDStr, clientSecret, err := getBasicAuth(r)
317 // TODO(paddy): render access denied
320 if clientIDStr == "" && err == nil {
321 clientIDStr = r.PostFormValue("client_id")
323 clientID, err := uuid.Parse(clientIDStr)
325 // TODO(paddy): render invalid request JSON
328 client, err := context.GetClient(clientID)
330 if err == ErrClientNotFound {
331 // TODO(paddy): render invalid request JSON
333 // TODO(paddy): render internal server error JSON
337 if client.Secret != clientSecret {
338 // TODO(paddy): render invalid request JSON
341 grant, err := context.GetGrant(code)
343 if err == ErrGrantNotFound {
344 // TODO(paddy): return error
347 // TODO(paddy): return error
349 if grant.RedirectURI != redirectURI {
350 // TODO(paddy): return error
352 if !grant.ClientID.Equal(clientID) {
353 // TODO(paddy): return error
356 AccessToken: uuid.NewID().String(),
357 RefreshToken: uuid.NewID().String(),
359 ExpiresIn: defaultTokenExpiration,
360 TokenType: "", // TODO(paddy): fill in token type
362 ProfileID: grant.ProfileID,
364 err = context.SaveToken(token)
366 // TODO(paddy): return error
368 resp := tokenResponse{
369 AccessToken: token.AccessToken,
370 RefreshToken: token.RefreshToken,
371 ExpiresIn: token.ExpiresIn,
372 TokenType: token.TokenType,
374 err = enc.Encode(resp)
376 // TODO(paddy): log this or something
381 // TODO(paddy): exchange user credentials for access token
382 // TODO(paddy): exchange client credentials for access token
383 // TODO(paddy): implicit grant for access token
384 // TODO(paddy): exchange refresh token for access token