auth
auth/oauth2.go
Add tests for redirecting to the login page. Make sure that we're redirecting to the configured login page (or returning an error) as expected when trying to obtain a grant code.
1 package auth
3 import (
4 "encoding/base64"
5 "encoding/hex"
6 "encoding/json"
7 "errors"
8 "html/template"
9 "log"
10 "net/http"
11 "net/url"
12 "strings"
13 "time"
14 "github.com/gorilla/mux"
16 "crypto/sha256"
17 "code.secondbit.org/pass"
18 "code.secondbit.org/uuid"
19 )
const (
	// authCookieName names the cookie whose value checkCookie looks up as a
	// session ID.
	authCookieName = "auth"
	// defaultGrantExpiration is the lifetime of a newly issued authorization
	// grant, in seconds (ten minutes).
	defaultGrantExpiration = 10 * 60
	// getGrantTemplateName is the template GetGrantHandler renders for every
	// response it does not answer with a redirect.
	getGrantTemplateName = "get_grant"
)

// Sentinel errors produced by the authentication helpers in this file.
// Callers compare against them directly with ==.
var (
	// ErrNoAuth is returned when an Authorization header is not present or is empty.
	ErrNoAuth = errors.New("no authorization header supplied")
	// ErrInvalidAuthFormat is returned when an Authorization header is present
	// but is not in the expected format.
	ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")
	// ErrIncorrectAuth is returned when a user authentication attempt does not
	// match the stored values (including an unknown login).
	ErrIncorrectAuth = errors.New("invalid authentication")
	// ErrInvalidPassphraseScheme is returned when a profile uses a passphrase
	// scheme this package does not know how to verify.
	ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")
	// ErrNoSession is returned when a request carries no session ID.
	ErrNoSession = errors.New("no session ID found")
)

// tokenResponse is the JSON body GetTokenHandler writes after a successful
// token exchange. encoding/json emits the fields in declaration order, and
// the optional ones are dropped when they hold their zero value.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type,omitempty"`
	ExpiresIn    int32  `json:"expires_in,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}
47 func getBasicAuth(r *http.Request) (un, pass string, err error) {
48 auth := r.Header.Get("Authorization")
49 if auth == "" {
50 return "", "", ErrNoAuth
51 }
52 pieces := strings.SplitN(auth, " ", 2)
53 if pieces[0] != "Basic" {
54 return "", "", ErrInvalidAuthFormat
55 }
56 decoded, err := base64.StdEncoding.DecodeString(pieces[1])
57 if err != nil {
58 return "", "", ErrInvalidAuthFormat
59 }
60 info := strings.SplitN(string(decoded), ":", 2)
61 if len(info) < 2 {
62 return "", "", ErrInvalidAuthFormat
63 }
64 return info[0], info[1], nil
65 }
67 func checkCookie(r *http.Request, context Context) (Session, error) {
68 cookie, err := r.Cookie(authCookieName)
69 if err == http.ErrNoCookie {
70 return Session{}, ErrNoSession
71 } else if err != nil {
72 log.Println(err)
73 return Session{}, err
74 }
75 sess, err := context.GetSession(cookie.Value)
76 if err == ErrSessionNotFound {
77 return Session{}, ErrInvalidSession
78 } else if err != nil {
79 return Session{}, err
80 }
81 if !sess.Active {
82 return Session{}, ErrInvalidSession
83 }
84 return sess, nil
85 }
87 func buildLoginRedirect(r *http.Request, context Context) string {
88 if context.loginURI == nil {
89 return ""
90 }
91 uri := *context.loginURI
92 q := uri.Query()
93 q.Set("from", r.URL.String())
94 uri.RawQuery = q.Encode()
95 return uri.String()
96 }
98 func authenticate(user, passphrase string, context Context) (Profile, error) {
99 profile, err := context.GetProfileByLogin(user)
100 if err != nil {
101 if err == ErrProfileNotFound || err == ErrLoginNotFound {
102 return Profile{}, ErrIncorrectAuth
103 }
104 return Profile{}, err
105 }
106 switch profile.PassphraseScheme {
107 case 1:
108 realPass, err := hex.DecodeString(profile.Passphrase)
109 if err != nil {
110 return Profile{}, err
111 }
112 candidate := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), []byte(profile.Salt))
113 if !pass.Compare(candidate, realPass) {
114 return Profile{}, ErrIncorrectAuth
115 }
116 default:
117 return Profile{}, ErrInvalidPassphraseScheme
118 }
119 return profile, nil
120 }
122 func wrap(context Context, f func(w http.ResponseWriter, r *http.Request, context Context)) http.Handler {
123 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124 f(w, r, context)
125 })
126 }
128 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints.
129 func RegisterOAuth2(r *mux.Router, context Context) {
130 r.Handle("/authorize", wrap(context, GetGrantHandler))
131 r.Handle("/token", wrap(context, GetTokenHandler))
132 }
134 // GetGrantHandler presents and processes the page for asking a user to grant access
135 // to their data. See RFC 6749, Section 4.1.
136 func GetGrantHandler(w http.ResponseWriter, r *http.Request, context Context) {
137 session, err := checkCookie(r, context)
138 if err != nil {
139 if err == ErrNoSession || err == ErrInvalidSession {
140 redir := buildLoginRedirect(r, context)
141 if redir == "" {
142 log.Println("No login URL configured.")
143 w.WriteHeader(http.StatusInternalServerError)
144 context.Render(w, getGrantTemplateName, map[string]interface{}{
145 "internal_error": template.HTML("Missing login URL."),
146 })
147 return
148 }
149 http.Redirect(w, r, redir, http.StatusFound)
150 return
151 }
152 log.Println(err.Error())
153 w.WriteHeader(http.StatusInternalServerError)
154 context.Render(w, getGrantTemplateName, map[string]interface{}{
155 "internal_error": template.HTML(err.Error()),
156 })
157 return
158 }
159 if r.URL.Query().Get("client_id") == "" {
160 w.WriteHeader(http.StatusBadRequest)
161 context.Render(w, getGrantTemplateName, map[string]interface{}{
162 "error": template.HTML("Client ID must be specified in the request."),
163 })
164 return
165 }
166 clientID, err := uuid.Parse(r.URL.Query().Get("client_id"))
167 if err != nil {
168 w.WriteHeader(http.StatusBadRequest)
169 context.Render(w, getGrantTemplateName, map[string]interface{}{
170 "error": template.HTML("client_id is not a valid Client ID."),
171 })
172 return
173 }
174 redirectURI := r.URL.Query().Get("redirect_uri")
175 redirectURL, err := url.Parse(redirectURI)
176 if err != nil {
177 w.WriteHeader(http.StatusBadRequest)
178 context.Render(w, getGrantTemplateName, map[string]interface{}{
179 "error": template.HTML("The redirect_uri specified is not valid."),
180 })
181 return
182 }
183 client, err := context.GetClient(clientID)
184 if err != nil {
185 if err == ErrClientNotFound {
186 w.WriteHeader(http.StatusBadRequest)
187 context.Render(w, getGrantTemplateName, map[string]interface{}{
188 "error": template.HTML("The specified Client couldn’t be found."),
189 })
190 } else {
191 log.Println(err.Error())
192 w.WriteHeader(http.StatusInternalServerError)
193 context.Render(w, getGrantTemplateName, map[string]interface{}{
194 "internal_error": template.HTML(err.Error()),
195 })
196 }
197 return
198 }
199 // whether a redirect URI is valid or not depends on the number of endpoints
200 // the client has registered
201 numEndpoints, err := context.CountEndpoints(clientID)
202 if err != nil {
203 log.Println(err.Error())
204 w.WriteHeader(http.StatusInternalServerError)
205 context.Render(w, getGrantTemplateName, map[string]interface{}{
206 "internal_error": template.HTML(err.Error()),
207 })
208 return
209 }
210 var validURI bool
211 if redirectURI != "" {
212 // BUG(paddy): We really should normalize URIs before trying to compare them.
213 validURI, err = context.CheckEndpoint(clientID, redirectURI)
214 if err != nil {
215 log.Println(err.Error())
216 w.WriteHeader(http.StatusInternalServerError)
217 context.Render(w, getGrantTemplateName, map[string]interface{}{
218 "internal_error": template.HTML(err.Error()),
219 })
220 return
221 }
222 } else if redirectURI == "" && numEndpoints == 1 {
223 // if we don't specify the endpoint and there's only one endpoint, the
224 // request is valid, and we're redirecting to that one endpoint
225 validURI = true
226 endpoints, err := context.ListEndpoints(clientID, 1, 0)
227 if err != nil {
228 log.Println(err.Error())
229 w.WriteHeader(http.StatusInternalServerError)
230 context.Render(w, getGrantTemplateName, map[string]interface{}{
231 "internal_error": template.HTML(err.Error()),
232 })
233 return
234 }
235 if len(endpoints) != 1 {
236 validURI = false
237 } else {
238 u := endpoints[0].URI // Copy it here to avoid grabbing a pointer to the memstore
239 redirectURI = u.String()
240 redirectURL = &u
241 }
242 } else {
243 validURI = false
244 }
245 if !validURI {
246 w.WriteHeader(http.StatusBadRequest)
247 context.Render(w, getGrantTemplateName, map[string]interface{}{
248 "error": template.HTML("The redirect_uri specified is not valid."),
249 })
250 return
251 }
252 scope := r.URL.Query().Get("scope")
253 state := r.URL.Query().Get("state")
254 if r.URL.Query().Get("response_type") != "code" {
255 q := redirectURL.Query()
256 q.Add("error", "invalid_request")
257 q.Add("state", state)
258 redirectURL.RawQuery = q.Encode()
259 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
260 return
261 }
262 if r.Method == "POST" {
263 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code.
264 if r.PostFormValue("grant") == "approved" {
265 code := uuid.NewID().String()
266 grant := Grant{
267 Code: code,
268 Created: time.Now(),
269 ExpiresIn: defaultGrantExpiration,
270 ClientID: clientID,
271 Scope: scope,
272 RedirectURI: r.URL.Query().Get("redirect_uri"),
273 State: state,
274 ProfileID: session.ProfileID,
275 }
276 err := context.SaveGrant(grant)
277 if err != nil {
278 q := redirectURL.Query()
279 q.Add("error", "server_error")
280 q.Add("state", state)
281 redirectURL.RawQuery = q.Encode()
282 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
283 return
284 }
285 q := redirectURL.Query()
286 q.Add("code", code)
287 q.Add("state", state)
288 redirectURL.RawQuery = q.Encode()
289 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
290 return
291 }
292 q := redirectURL.Query()
293 q.Add("error", "access_denied")
294 q.Add("state", state)
295 redirectURL.RawQuery = q.Encode()
296 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
297 return
298 }
299 w.WriteHeader(http.StatusOK)
300 context.Render(w, getGrantTemplateName, map[string]interface{}{
301 "client": client,
302 })
303 }
305 // GetTokenHandler allows a client to exchange an authorization grant for an
306 // access token. See RFC 6749 Section 4.1.3.
307 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) {
308 enc := json.NewEncoder(w)
309 grantType := r.PostFormValue("grant_type")
310 if grantType != "authorization_code" {
311 // TODO(paddy): render invalid request JSON
312 return
313 }
314 code := r.PostFormValue("code")
315 if code == "" {
316 // TODO(paddy): render invalid request JSON
317 return
318 }
319 redirectURI := r.PostFormValue("redirect_uri")
320 clientIDStr, clientSecret, err := getBasicAuth(r)
321 if err != nil {
322 // TODO(paddy): render access denied
323 return
324 }
325 if clientIDStr == "" && err == nil {
326 clientIDStr = r.PostFormValue("client_id")
327 }
328 clientID, err := uuid.Parse(clientIDStr)
329 if err != nil {
330 // TODO(paddy): render invalid request JSON
331 return
332 }
333 client, err := context.GetClient(clientID)
334 if err != nil {
335 if err == ErrClientNotFound {
336 // TODO(paddy): render invalid request JSON
337 } else {
338 // TODO(paddy): render internal server error JSON
339 }
340 return
341 }
342 if client.Secret != clientSecret {
343 // TODO(paddy): render invalid request JSON
344 return
345 }
346 grant, err := context.GetGrant(code)
347 if err != nil {
348 if err == ErrGrantNotFound {
349 // TODO(paddy): return error
350 return
351 }
352 // TODO(paddy): return error
353 }
354 if grant.RedirectURI != redirectURI {
355 // TODO(paddy): return error
356 }
357 if !grant.ClientID.Equal(clientID) {
358 // TODO(paddy): return error
359 }
360 token := Token{
361 AccessToken: uuid.NewID().String(),
362 RefreshToken: uuid.NewID().String(),
363 Created: time.Now(),
364 ExpiresIn: defaultTokenExpiration,
365 TokenType: "", // TODO(paddy): fill in token type
366 Scope: grant.Scope,
367 ProfileID: grant.ProfileID,
368 }
369 err = context.SaveToken(token)
370 if err != nil {
371 // TODO(paddy): return error
372 }
373 resp := tokenResponse{
374 AccessToken: token.AccessToken,
375 RefreshToken: token.RefreshToken,
376 ExpiresIn: token.ExpiresIn,
377 TokenType: token.TokenType,
378 }
379 err = enc.Encode(resp)
380 if err != nil {
381 // TODO(paddy): log this or something
382 return
383 }
384 }
386 // TODO(paddy): exchange user credentials for access token
387 // TODO(paddy): exchange client credentials for access token
388 // TODO(paddy): implicit grant for access token
389 // TODO(paddy): exchange refresh token for access token