auth
auth/oauth2.go
Change normalization flags to a constant. Let's use a constant so we can ensure we're using the same flags everywhere. Otherwise, we can get weird data corruption because we use the wrong flags.
1 package auth
3 import (
4 "encoding/json"
5 "errors"
6 "html/template"
7 "log"
8 "net/http"
9 "net/url"
10 "strconv"
11 "sync"
12 "time"
14 "code.secondbit.org/uuid.hg"
15 "github.com/gorilla/mux"
16 )
const (
	// authCookieName is presumably the name of the cookie that carries the session; it is not
	// referenced in this part of the file, so confirm against checkCookie.
	authCookieName = "auth"

	// defaultAuthorizationCodeExpiration is the ExpiresIn value given to newly created
	// authorization codes: ten minutes, expressed in seconds.
	defaultAuthorizationCodeExpiration = 600

	// getAuthorizationCodeTemplateName is the template name handed to Context.Render when the
	// authorization page (or one of its error states) is displayed.
	getAuthorizationCodeTemplateName = "get_grant"
)
24 var (
25 // ErrNoAuth is returned when an Authorization header is not present or is empty.
26 ErrNoAuth = errors.New("no authorization header supplied")
27 // ErrInvalidAuthFormat is returned when an Authorization header is present but not the correct format.
28 ErrInvalidAuthFormat = errors.New("authorization header is not in a valid format")
29 // ErrIncorrectAuth is returned when a user authentication attempt does not match the stored values.
30 ErrIncorrectAuth = errors.New("invalid authentication")
31 // ErrInvalidPassphraseScheme is returned when an undefined passphrase scheme is used.
32 ErrInvalidPassphraseScheme = errors.New("invalid passphrase scheme")
33 // ErrNoSession is returned when no session ID is passed with a request.
34 ErrNoSession = errors.New("no session ID found")
36 grantTypesMap = grantTypes{types: map[string]GrantType{}}
37 )
39 type grantTypes struct {
40 types map[string]GrantType
41 sync.RWMutex
42 }
44 // GrantType defines a set of functions and metadata around a specific authorization grant strategy.
45 //
46 // The Validate function will be called when requests are made that match the GrantType, and should write any
47 // errors to the ResponseWriter. It is responsible for determining if the grant is valid and a token should be issued.
48 // It must return the scope the grant was for and the ID of the Profile that issued the grant, as well as if the grant
49 // is valid or not. It must not be nil.
50 //
51 // The Invalidate function will be called when the grant has successfully generated a token and the token has successfully
52 // been conveyed to the user. The Invalidate function is always called asynchronously, outside the request. It should take
53 // care of marking the grant as used, if the GrantType requires grants to be one-time only grants. The Invalidate function
54 // can be nil.
55 //
56 // IssuesRefresh determines whether the GrantType should yield a refresh token as well as an access token. If true, the client
57 // will be issued a refresh token.
58 //
59 // AllowsPublic determines whether the GrantType should allow public clients to use that grant. If true, clients without
60 // credentials will be able to use the grant to obtain a token.
61 //
62 // AuditString should return the string that will be saved in the resulting Token's CreatedFrom field, as an audit log of how
63 // the Token was authorized.
64 //
65 // The ReturnToken will be called when a token is created and needs to be returned to the client. If it returns true, the token
66 // was successfully returned and the Invalidate function will be called asynchronously.
67 type GrantType struct {
68 Validate func(w http.ResponseWriter, r *http.Request, context Context) (scope string, profileID uuid.ID, valid bool)
69 Invalidate func(r *http.Request, context Context) error
70 ReturnToken func(w http.ResponseWriter, r *http.Request, token Token, context Context) bool
71 AuditString func(r *http.Request) string
72 IssuesRefresh bool
73 AllowsPublic bool
74 }
// tokenResponse is the JSON body written by RenderJSONToken when a token is issued. Every field
// except access_token is omitted from the output when it holds its zero value.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type,omitempty"`
	ExpiresIn    int32  `json:"expires_in,omitempty"`
	RefreshToken string `json:"refresh_token,omitempty"`
}
// errorResponse is the JSON body written by renderJSONError. The description and URI are left
// out of the output when empty.
type errorResponse struct {
	Error       string `json:"error"`
	Description string `json:"error_description,omitempty"`
	URI         string `json:"error_uri,omitempty"`
}
89 // RegisterGrantType associates a string with a GrantType. When the string is used as the value for "grant_type" when obtaining
90 // an access token, the associated GrantType's properties will be used.
91 //
92 // RegisterGrantType should be called in the `init()` function of packages, much like database/sql registers drivers. It will panic
93 // if a GrantType tries to register under a string that already has a GrantType registered for it.
94 func RegisterGrantType(name string, g GrantType) {
95 grantTypesMap.Lock()
96 defer grantTypesMap.Unlock()
97 if _, ok := grantTypesMap.types[name]; ok {
98 panic("Duplicate registration of grant_type " + name)
99 }
100 grantTypesMap.types[name] = g
101 }
103 func findGrantType(name string) (GrantType, bool) {
104 grantTypesMap.RLock()
105 defer grantTypesMap.RUnlock()
106 t, ok := grantTypesMap.types[name]
107 return t, ok
108 }
110 func renderJSONError(enc *json.Encoder, errorType string) {
111 err := enc.Encode(errorResponse{
112 Error: errorType,
113 })
114 if err != nil {
115 log.Println(err)
116 }
117 }
119 // RenderJSONToken is an implementation of the ReturnToken function for GrantTypes. It returns the token using JSON
120 // according to the spec. See RFC 6479, Section 4.1.4.
121 func RenderJSONToken(w http.ResponseWriter, r *http.Request, token Token, context Context) bool {
122 enc := json.NewEncoder(w)
123 resp := tokenResponse{
124 AccessToken: token.AccessToken,
125 RefreshToken: token.RefreshToken,
126 ExpiresIn: token.ExpiresIn,
127 TokenType: token.TokenType,
128 }
129 w.Header().Set("Content-Type", "application/json")
130 err := enc.Encode(resp)
131 if err != nil {
132 log.Println(err)
133 return false
134 }
135 return true
136 }
138 // RegisterOAuth2 adds handlers to the passed router to handle the OAuth2 endpoints.
139 func RegisterOAuth2(r *mux.Router, context Context) {
140 r.Handle("/authorize", wrap(context, GetAuthorizationCodeHandler))
141 r.Handle("/token", wrap(context, GetTokenHandler))
142 }
144 // GetAuthorizationCodeHandler presents and processes the page for asking a user to grant access
145 // to their data. See RFC 6749, Section 4.1.
146 func GetAuthorizationCodeHandler(w http.ResponseWriter, r *http.Request, context Context) {
147 session, err := checkCookie(r, context)
148 if err != nil {
149 if err == ErrNoSession || err == ErrInvalidSession {
150 redir := buildLoginRedirect(r, context)
151 if redir == "" {
152 log.Println("No login URL configured.")
153 w.WriteHeader(http.StatusInternalServerError)
154 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
155 "internal_error": template.HTML("Missing login URL."),
156 })
157 return
158 }
159 http.Redirect(w, r, redir, http.StatusFound)
160 return
161 }
162 log.Println(err.Error())
163 w.WriteHeader(http.StatusInternalServerError)
164 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
165 "internal_error": template.HTML(err.Error()),
166 })
167 return
168 }
169 if r.URL.Query().Get("client_id") == "" {
170 w.WriteHeader(http.StatusBadRequest)
171 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
172 "error": template.HTML("Client ID must be specified in the request."),
173 })
174 return
175 }
176 clientID, err := uuid.Parse(r.URL.Query().Get("client_id"))
177 if err != nil {
178 w.WriteHeader(http.StatusBadRequest)
179 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
180 "error": template.HTML("client_id is not a valid Client ID."),
181 })
182 return
183 }
184 redirectURI := r.URL.Query().Get("redirect_uri")
185 client, err := context.GetClient(clientID)
186 if err != nil {
187 if err == ErrClientNotFound {
188 w.WriteHeader(http.StatusBadRequest)
189 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
190 "error": template.HTML("The specified Client couldn’t be found."),
191 })
192 } else {
193 log.Println(err.Error())
194 w.WriteHeader(http.StatusInternalServerError)
195 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
196 "internal_error": template.HTML(err.Error()),
197 })
198 }
199 return
200 }
201 // BUG(paddy): Checking if the redirect URI is valid should be a helper function.
203 // whether a redirect URI is valid or not depends on the number of endpoints
204 // the client has registered
205 numEndpoints, err := context.CountEndpoints(clientID)
206 if err != nil {
207 log.Println(err.Error())
208 w.WriteHeader(http.StatusInternalServerError)
209 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
210 "internal_error": template.HTML(err.Error()),
211 })
212 return
213 }
214 var validURI bool
215 if redirectURI != "" {
216 validURI, err = context.CheckEndpoint(clientID, redirectURI)
217 if err != nil {
218 if err == ErrEndpointURINotURL {
219 w.WriteHeader(http.StatusBadRequest)
220 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
221 "error": template.HTML("The redirect_uri specified is not valid."),
222 })
223 return
224 }
225 log.Println(err.Error())
226 w.WriteHeader(http.StatusInternalServerError)
227 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
228 "internal_error": template.HTML(err.Error()),
229 })
230 return
231 }
232 } else if redirectURI == "" && numEndpoints == 1 {
233 // if we don't specify the endpoint and there's only one endpoint, the
234 // request is valid, and we're redirecting to that one endpoint
235 validURI = true
236 endpoints, err := context.ListEndpoints(clientID, 1, 0)
237 if err != nil {
238 log.Println(err.Error())
239 w.WriteHeader(http.StatusInternalServerError)
240 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
241 "internal_error": template.HTML(err.Error()),
242 })
243 return
244 }
245 if len(endpoints) != 1 {
246 validURI = false
247 } else {
248 redirectURI = endpoints[0].URI
249 }
250 } else {
251 validURI = false
252 }
253 if !validURI {
254 w.WriteHeader(http.StatusBadRequest)
255 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
256 "error": template.HTML("The redirect_uri specified is not valid."),
257 })
258 return
259 }
260 redirectURL, err := url.Parse(redirectURI)
261 if err != nil {
262 w.WriteHeader(http.StatusBadRequest)
263 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
264 "error": template.HTML("The redirect_uri specified is not valid."),
265 })
266 return
267 }
268 scope := r.URL.Query().Get("scope")
269 state := r.URL.Query().Get("state")
270 responseType := r.URL.Query().Get("response_type")
271 q := redirectURL.Query()
272 q.Add("state", state)
273 if responseType != "code" && responseType != "token" {
274 q.Add("error", "invalid_request")
275 redirectURL.RawQuery = q.Encode()
276 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
277 return
278 }
279 if r.Method == "POST" {
280 // BUG(paddy): We need to implement CSRF protection when obtaining a grant code.
281 if r.PostFormValue("grant") == "approved" {
282 var fragment bool
283 switch responseType {
284 case "code":
285 code := uuid.NewID().String()
286 authCode := AuthorizationCode{
287 Code: code,
288 Created: time.Now(),
289 ExpiresIn: defaultAuthorizationCodeExpiration,
290 ClientID: clientID,
291 Scope: scope,
292 RedirectURI: r.URL.Query().Get("redirect_uri"),
293 State: state,
294 ProfileID: session.ProfileID,
295 }
296 err := context.SaveAuthorizationCode(authCode)
297 if err != nil {
298 log.Println("Error saving authorization code:", err)
299 q.Add("error", "server_error")
300 break
301 }
302 q.Add("code", code)
303 case "token":
304 token := Token{
305 AccessToken: uuid.NewID().String(),
306 Created: time.Now(),
307 CreatedFrom: "implicit",
308 ExpiresIn: defaultTokenExpiration,
309 TokenType: "bearer",
310 Scope: scope,
311 ProfileID: session.ProfileID,
312 ClientID: clientID,
313 }
314 err := context.SaveToken(token)
315 if err != nil {
316 log.Println("Error saving token:", err)
317 q.Add("error", "server_error")
318 break
319 }
320 q = url.Values{} // we're not altering the querystring, so don't clone it
321 q.Add("access_token", token.AccessToken)
322 q.Add("token_type", token.TokenType)
323 q.Add("expires_in", strconv.FormatInt(int64(token.ExpiresIn), 10))
324 q.Add("scope", token.Scope)
325 q.Add("state", state) // we wiped out the old values, so we need to set the state again
326 fragment = true
327 }
328 if fragment {
329 redirectURL.Fragment = q.Encode()
330 } else {
331 redirectURL.RawQuery = q.Encode()
332 }
333 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
334 return
335 }
336 q.Add("error", "access_denied")
337 redirectURL.RawQuery = q.Encode()
338 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
339 return
340 }
341 profile, err := context.GetProfileByID(session.ProfileID)
342 if err != nil {
343 log.Println("Error getting profile from session:", err)
344 q.Add("error", "server_error")
345 redirectURL.RawQuery = q.Encode()
346 http.Redirect(w, r, redirectURL.String(), http.StatusFound)
347 return
348 }
349 w.WriteHeader(http.StatusOK)
350 context.Render(w, getAuthorizationCodeTemplateName, map[string]interface{}{
351 "client": client,
352 "redirectURL": redirectURL,
353 "scope": scope,
354 "profile": profile,
355 })
356 }
358 // GetTokenHandler allows a client to exchange an authorization grant for an
359 // access token. See RFC 6749 Section 4.1.3.
360 func GetTokenHandler(w http.ResponseWriter, r *http.Request, context Context) {
361 enc := json.NewEncoder(w)
362 grantType := r.PostFormValue("grant_type")
363 gt, ok := findGrantType(grantType)
364 if !ok {
365 w.WriteHeader(http.StatusBadRequest)
366 renderJSONError(enc, "invalid_request")
367 return
368 }
369 clientID, success := verifyClient(w, r, gt.AllowsPublic, context)
370 if !success {
371 return
372 }
373 scope, profileID, valid := gt.Validate(w, r, context)
374 if !valid {
375 return
376 }
377 refresh := ""
378 if gt.IssuesRefresh {
379 refresh = uuid.NewID().String()
380 }
381 token := Token{
382 AccessToken: uuid.NewID().String(),
383 RefreshToken: refresh,
384 Created: time.Now(),
385 CreatedFrom: gt.AuditString(r),
386 ExpiresIn: defaultTokenExpiration,
387 TokenType: "bearer",
388 Scope: scope,
389 ProfileID: profileID,
390 ClientID: clientID,
391 }
392 err := context.SaveToken(token)
393 if err != nil {
394 w.WriteHeader(http.StatusInternalServerError)
395 renderJSONError(enc, "server_error")
396 return
397 }
398 if gt.ReturnToken(w, r, token, context) && gt.Invalidate != nil {
399 go gt.Invalidate(r, context)
400 }
401 }