auth
auth/oauth2.go
Test the authentication helper and fix authentication bugs. Stored passphrases are hex-encoded, so the stored hex string must be decoded to get the bytes we compare against. Check for ErrLoginNotFound in addition to ErrProfileNotFound; ErrLoginNotFound is the more likely error anyway. Add unit tests for the authentication helper to verify that it behaves as expected.
1 package auth
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"errors"
	"html/template"
	"log"
	"net/http"
	"net/url"
	"strings"
	"time"

	"code.secondbit.org/pass"
	"code.secondbit.org/uuid"
	"github.com/gorilla/mux"
)
const (
	// authCookieName names the cookie whose value is the session ID looked up
	// by checkCookie.
	authCookieName = "auth"

	// defaultGrantExpiration is the lifetime of an authorization grant, in
	// seconds: ten minutes.
	defaultGrantExpiration = 10 * 60

	// getGrantTemplateName is the template GetGrantHandler renders for the
	// consent page and its error messages.
	getGrantTemplateName = "get_grant"
)
// Errors produced while reading and checking client or user credentials.
var (
	// ErrNoAuth is returned when a request has no Authorization header, or the
	// header is empty.
	ErrNoAuth = errors.New("no authorization header supplied")

	// ErrInvalidAuthFormat is returned when an Authorization header is present
	// but cannot be parsed into credentials.
	ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")

	// ErrIncorrectAuth is returned when the credentials a user supplies do not
	// match what is stored for them.
	ErrIncorrectAuth = errors.New("invalid authentication")

	// ErrInvalidPassphraseScheme is returned when a stored passphrase uses a
	// scheme this package does not implement.
	ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")
)

// ErrNoSession is returned when a request does not carry a session ID.
var ErrNoSession = errors.New("no session ID found")
// tokenResponse is the JSON body sent by the token endpoint on success
// (RFC 6749, Section 5.1). Keys are emitted in field order; the optional
// fields are left out when they hold their zero value.
type tokenResponse struct {
	// AccessToken is the credential the client presents to access resources.
	AccessToken string `json:"access_token"`
	// TokenType tells the client how AccessToken is meant to be used.
	TokenType string `json:"token_type,omitempty"`
	// ExpiresIn is the access token lifetime; RFC 6749 expresses it in seconds.
	ExpiresIn int32 `json:"expires_in,omitempty"`
	// RefreshToken, per RFC 6749, lets the client obtain new access tokens.
	RefreshToken string `json:"refresh_token,omitempty"`
}
47 func getBasicAuth(r *http.Request) (un, pass string, err error) {
48 auth := r.Header.Get("Authorization")
49 if auth == "" {
50 return "", "", ErrNoAuth
51 }
52 pieces := strings.SplitN(auth, " ", 2)
53 if pieces[0] != "Basic" {
54 return "", "", ErrInvalidAuthFormat
55 }
56 decoded, err := base64.StdEncoding.DecodeString(pieces[1])
57 if err != nil {
58 return "", "", ErrInvalidAuthFormat
59 }
60 info := strings.SplitN(string(decoded), ":", 2)
61 if len(info) < 2 {
62 return "", "", ErrInvalidAuthFormat
63 }
64 return info[0], info[1], nil
65 }
67 func checkCookie(r *http.Request, context Context) (Session, error) {
68 cookie, err := r.Cookie(authCookieName)
69 if err == http.ErrNoCookie {
70 return Session{}, ErrNoSession
71 } else if err != nil {
72 log.Println(err)
73 return Session{}, err
74 }
75 sess, err := context.GetSession(cookie.Value)
76 if err == ErrSessionNotFound {
77 return Session{}, ErrInvalidSession
78 } else if err != nil {
79 return Session{}, err
80 }
81 if !sess.Active {
82 return Session{}, ErrInvalidSession
83 }
84 return sess, nil
85 }
87 func buildLoginRedirect(r *http.Request, context Context) string {
88 if context.loginURI == nil {
89 return ""
90 }
91 uri := *context.loginURI
92 q := uri.Query()
93 q.Set("from", r.URL.String())
94 uri.RawQuery = q.Encode()
95 return uri.String()
96 }
98 func authenticate(user, passphrase string, context Context) (Profile, error) {
99 profile, err := context.GetProfileByLogin(user)
100 if err != nil {
101 if err == ErrProfileNotFound || err == ErrLoginNotFound {
102 return Profile{}, ErrIncorrectAuth
103 }
104 return Profile{}, err
105 }
106 switch profile.PassphraseScheme {
107 case 1:
108 realPass, err := hex.DecodeString(profile.Passphrase)
109 if err != nil {
110 return Profile{}, err
111 }
112 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt))
113 if !pass.Compare(candidate, realPass) {
114 return Profile{}, ErrIncorrectAuth
115 }
116 default:
117 return Profile{}, ErrInvalidPassphraseScheme
118 }
119 return profile, nil
120 }
122 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler {
123 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124 f(w, r, context)
125 })
126 }
128 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints.
129 func RegisterOAuth2(r *mux.Router, context Context) {
130 r.Handle("/authorize", wrap(context, GetGrantHandler))
131 r.Handle("/token", wrap(context, GetTokenHandler))
132 }
134 // GetGrantHandler presents and processes the page for asking a user to grant access
135 // to their data. See RFC 6749, Section 4.1.
136 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) {
137 session, err := checkCookie(r, context)
138 if err != nil {
139 if err == ErrNoSession || err == ErrInvalidSession {
140 redir := buildLoginRedirect(r, context)
141 if redir == "" {
142 log.Println("No login URL configured.")
143 w.WriteHeader(http.StatusInternalServerError)
144 context.Render(w, getGrantTemplateName, map[string]interface{}{
145 "internal_error": template.HTML("Missing login URL."),
146 })
147 return
148 }
149 http.Redirect(w, r, redir, http.StatusFound)
150 return
151 }
152 log.Println(err.Error())
153 w.WriteHeader(http.StatusInternalServerError)
154 context.Render(w, getGrantTemplateName, map[string]interface{}{
155 "internal_error": template.HTML(err.Error()),
156 })
157 return
158 }
159 if r.URL.Query().Get("client_id") == "" {
160 w.WriteHeader(http.StatusBadRequest)
161 context.Render(w, getGrantTemplateName, map[string]interface{}{
162 "error": template.HTML("Client ID must be specified in the request."),
163 })
164 return
165 }
166 clientID, err := uuid.Parse(r.URL.Query().Get("client_id"))
167 if err != nil {
168 w.WriteHeader(http.StatusBadRequest)
169 context.Render(w, getGrantTemplateName, map[string]interface{}{
170 "error": template.HTML("client_id is not a valid Client ID."),
171 })
172 return
173 }
174 redirectURI := r.URL.Query().Get("redirect_uri")
175 redirectURL, err := url.Parse(redirectURI)
176 if err != nil {
177 w.WriteHeader(http.StatusBadRequest)
178 context.Render(w, getGrantTemplateName, map[string]interface{}{
179 "error": template.HTML("The redirect_uri specified is not valid."),
180 })
181 return
182 }
183 client, err := context.GetClient(clientID)
184 if err != nil {
185 if err == ErrClientNotFound {
186 w.WriteHeader(http.StatusBadRequest)
187 context.Render(w, getGrantTemplateName, map[string]interface{}{
188 "error": template.HTML("The specified Client couldn’t be found."),
189 })
190 } else {
191 log.Println(err.Error())
192 w.WriteHeader(http.StatusInternalServerError)
193 context.Render(w, getGrantTemplateName, map[string]interface{}{
194 "internal_error": template.HTML(err.Error()),
195 })
196 }
197 return
198 }
199 // whether a redirect URI is valid or not depends on the number of endpoints
200 // the client has registered
201 numEndpoints, err := context.CountEndpoints(clientID)
202 if err != nil {
203 log.Println(err.Error())
204 w.WriteHeader(http.StatusInternalServerError)
205 context.Render(w, getGrantTemplateName, map[string]interface{}{
206 "internal_error": template.HTML(err.Error()),
207 })
208 return
209 }
210 var validURI bool
211 if redirectURI != "" {
212 // BUG(paddy): We really should normalize URIs before trying to compare them.
213 validURI, err = context.CheckEndpoint(clientID, redirectURI)
214 if err != nil {
215 log.Println(err.Error())
216 w.WriteHeader(http.StatusInternalServerError)
217 context.Render(w, getGrantTemplateName, map[string]interface{}{
218 "internal_error": template.HTML(err.Error()),
219 })
220 return
221 }
222 } else if redirectURI == "" && numEndpoints == 1 {
223 // if we don't specify the endpoint and there's only one endpoint, the
224 // request is valid, and we're redirecting to that one endpoint
225 validURI = true
226 endpoints, err := context.ListEndpoints(clientID, 1, 0)
227 if err != nil {
228 log.Println(err.Error())
229 w.WriteHeader(http.StatusInternalServerError)
230 context.Render(w, getGrantTemplateName, map[string]interface{}{
231 "internal_error": template.HTML(err.Error()),
232 })
233 return
234 }
235 if len(endpoints) != 1 {
236 validURI = false
237 } else {
238 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore
239 redirectURI = u.String()
240 redirectURL = &u
241 }
242 } else {
243 validURI = false
244 }
245 if !validURI {
246 w.WriteHeader(http.StatusBadRequest)
247 context.Render(w, getGrantTemplateName, map[string]interface{}{
248 "error": template.HTML("The redirect_uri specified is not valid."),
249 })
250 return
251 }
252 scope := r.URL.Query().Get("scope")
253 state := r.URL.Query().Get("state")
254 if r.URL.Query().Get("response_type") != "code" {
255 q := redirectURL.Query()
256 q.Add("error", "invalid_request")
257 q.Add("state", state)
258 redirectURL.RawQuery = q.Encode()
259 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
260 return
261 }
262 if r.Method == "POST" {
263 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code.
264 if r.PostFormValue("grant") == "approved" {
265 code := uuid.NewID().String()
266 grant := Grant{
267 Code: code,
268 Created: time.Now(),
269 ExpiresIn: defaultGrantExpiration,
270 ClientID: clientID,
271 Scope: scope,
272 RedirectURI: r.URL.Query().Get("redirect_uri"),
273 State: state,
274 ProfileID: session.ProfileID,
275 }
276 err := context.SaveGrant(grant)
277 if err != nil {
278 q := redirectURL.Query()
279 q.Add("error", "server_error")
280 q.Add("state", state)
281 redirectURL.RawQuery = q.Encode()
282 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
283 return
284 }
285 q := redirectURL.Query()
286 q.Add("code", code)
287 q.Add("state", state)
288 redirectURL.RawQuery = q.Encode()
289 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
290 return
291 }
292 q := redirectURL.Query()
293 q.Add("error", "access_denied")
294 q.Add("state", state)
295 redirectURL.RawQuery = q.Encode()
296 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
297 return
298 }
299 w.WriteHeader(http.StatusOK)
300 context.Render(w, getGrantTemplateName, map[string]interface{}{
301 "client": client,
302 })
303 }
305 // GetTokenHandler allows a client to exchange an authorization grant for an
306 // access token. See RFC 6749 Section 4.1.3.
307 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) {
308 enc := json.NewEncoder(w)
309 grantType := r.PostFormValue("grant_type")
310 if grantType != "authorization_code" {
311 // TODO(paddy): render invalid request JSON
312 return
313 }
314 code := r.PostFormValue("code")
315 if code == "" {
316 // TODO(paddy): render invalid request JSON
317 return
318 }
319 redirectURI := r.PostFormValue("redirect_uri")
320 clientIDStr, clientSecret, err := getBasicAuth(r)
321 if err != nil {
322 // TODO(paddy): render access denied
323 return
324 }
325 if clientIDStr == "" && err == nil {
326 clientIDStr = r.PostFormValue("client_id")
327 }
328 clientID, err := uuid.Parse(clientIDStr)
329 if err != nil {
330 // TODO(paddy): render invalid request JSON
331 return
332 }
333 client, err := context.GetClient(clientID)
334 if err != nil {
335 if err == ErrClientNotFound {
336 // TODO(paddy): render invalid request JSON
337 } else {
338 // TODO(paddy): render internal server error JSON
339 }
340 return
341 }
342 if client.Secret != clientSecret {
343 // TODO(paddy): render invalid request JSON
344 return
345 }
346 grant, err := context.GetGrant(code)
347 if err != nil {
348 if err == ErrGrantNotFound {
349 // TODO(paddy): return error
350 return
351 }
352 // TODO(paddy): return error
353 }
354 if grant.RedirectURI != redirectURI {
355 // TODO(paddy): return error
356 }
357 if !grant.ClientID.Equal(clientID) {
358 // TODO(paddy): return error
359 }
360 token := Token{
361 AccessToken: uuid.NewID().String(),
362 RefreshToken: uuid.NewID().String(),
363 Created: time.Now(),
364 ExpiresIn: defaultTokenExpiration,
365 TokenType: "", // TODO(paddy): fill in token type
366 Scope: grant.Scope,
367 ProfileID: grant.ProfileID,
368 }
369 err = context.SaveToken(token)
370 if err != nil {
371 // TODO(paddy): return error
372 }
373 resp := tokenResponse{
374 AccessToken: token.AccessToken,
375 RefreshToken: token.RefreshToken,
376 ExpiresIn: token.ExpiresIn,
377 TokenType: token.TokenType,
378 }
379 err = enc.Encode(resp)
380 if err != nil {
381 // TODO(paddy): log this or something
382 return
383 }
384 }
386 // TODO(paddy): exchange user credentials for access token
387 // TODO(paddy): exchange client credentials for access token
388 // TODO(paddy): implicit grant for access token
389 // TODO(paddy): exchange refresh token for access token