auth

Paddy 2015-01-24 Parent:4f5d13d2f7c7 Child:f474ce964dcf

130:6c755b23ec80 Go to Latest

auth/client.go

Change normalization flags to a constant. Using a single constant ensures we use the same flags everywhere. Otherwise, mismatched flags can silently corrupt data — for example, a stored normalized URI that no longer matches a freshly normalized one.

History
1 package auth
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"net/url"
	"strconv"
	"time"

	"github.com/PuerkitoBio/purell"
	"github.com/gorilla/mux"

	"code.secondbit.org/uuid.hg"
)
20 func init() {
21 RegisterGrantType("client_credentials", GrantType{
22 Validate: clientCredentialsValidate,
23 Invalidate: nil,
24 IssuesRefresh: true,
25 ReturnToken: RenderJSONToken,
26 AllowsPublic: false,
27 AuditString: clientCredentialsAuditString,
28 })
29 }
// Errors concerning the clientStore itself and the Clients it holds.
var (
	// ErrNoClientStore is returned when a Context tries to act on a clientStore without setting one first.
	ErrNoClientStore = errors.New("no clientStore was specified for the Context")
	// ErrClientNotFound is returned when a Client is requested but not found in a clientStore.
	ErrClientNotFound = errors.New("client not found in clientStore")
	// ErrClientAlreadyExists is returned when a Client is added to a clientStore while
	// another Client with the same ID is already stored there.
	ErrClientAlreadyExists = errors.New("client already exists in clientStore")
)

// Validation errors for Client and Endpoint data. The numeric limits in these
// messages must match the limits actually enforced by the validating code.
var (
	// ErrEmptyChange is returned when a ClientChange has every property set to nil.
	ErrEmptyChange = errors.New("change must have at least one property set")

	// ErrClientNameTooShort is returned when a Client's Name is shorter than allowed.
	ErrClientNameTooShort = errors.New("client name must be at least 2 characters")
	// ErrClientNameTooLong is returned when a Client's Name is longer than allowed.
	ErrClientNameTooLong = errors.New("client name must be at most 32 characters")

	// ErrClientLogoTooLong is returned when a Client's Logo is longer than allowed.
	ErrClientLogoTooLong = errors.New("client logo must be at most 1024 characters")
	// ErrClientLogoNotURL is returned when a Client's Logo is not a valid absolute URL.
	ErrClientLogoNotURL = errors.New("client logo must be a valid absolute URL")

	// ErrClientWebsiteTooLong is returned when a Client's Website is longer than allowed.
	ErrClientWebsiteTooLong = errors.New("client website must be at most 1024 characters")
	// ErrClientWebsiteNotURL is returned when a Client's Website is not a valid absolute URL.
	ErrClientWebsiteNotURL = errors.New("client website must be a valid absolute URL")

	// ErrEndpointURINotURL is returned when an Endpoint's URI is not a valid absolute URL.
	ErrEndpointURINotURL = errors.New("endpoint URI must be a valid absolute URL")
)
58 const (
59 clientTypePublic = "public"
60 clientTypeConfidential = "confidential"
61 minClientNameLen = 2
62 maxClientNameLen = 24
64 normalizeFlags = purell.FlagsUsuallySafeNonGreedy | purell.FlagSortQuery
65 )
67 // Client represents a client that grants access
68 // to the auth server, exchanging grants for tokens,
69 // and tokens for access.
70 type Client struct {
71 ID uuid.ID `json:"id,omitempty"`
72 Secret string `json:"secret,omitempty"`
73 OwnerID uuid.ID `json:"owner_id,omitempty"`
74 Name string `json:"name,omitempty"`
75 Logo string `json:"logo,omitempty"`
76 Website string `json:"website,omitempty"`
77 Type string `json:"type,omitempty"`
78 }
80 // ApplyChange applies the properties of the passed
81 // ClientChange to the Client object it is called on.
82 func (c *Client) ApplyChange(change ClientChange) {
83 if change.Secret != nil {
84 c.Secret = *change.Secret
85 }
86 if change.OwnerID != nil {
87 c.OwnerID = change.OwnerID
88 }
89 if change.Name != nil {
90 c.Name = *change.Name
91 }
92 if change.Logo != nil {
93 c.Logo = *change.Logo
94 }
95 if change.Website != nil {
96 c.Website = *change.Website
97 }
98 }
100 // ClientChange represents a bundle of options for
101 // updating a Client's mutable data.
102 type ClientChange struct {
103 Secret *string
104 OwnerID uuid.ID
105 Name *string
106 Logo *string
107 Website *string
108 }
110 // Validate checks the ClientChange it is called on
111 // and asserts its internal validity, or lack thereof.
112 func (c ClientChange) Validate() error {
113 if c.Secret == nil && c.OwnerID == nil && c.Name == nil && c.Logo == nil && c.Website == nil {
114 return ErrEmptyChange
115 }
116 if c.Name != nil && len(*c.Name) < 2 {
117 return ErrClientNameTooShort
118 }
119 if c.Name != nil && len(*c.Name) > 32 {
120 return ErrClientNameTooLong
121 }
122 if c.Logo != nil && *c.Logo != "" {
123 if len(*c.Logo) > 1024 {
124 return ErrClientLogoTooLong
125 }
126 u, err := url.Parse(*c.Logo)
127 if err != nil || !u.IsAbs() {
128 return ErrClientLogoNotURL
129 }
130 }
131 if c.Website != nil && *c.Website != "" {
132 if len(*c.Website) > 140 {
133 return ErrClientWebsiteTooLong
134 }
135 u, err := url.Parse(*c.Website)
136 if err != nil || !u.IsAbs() {
137 return ErrClientWebsiteNotURL
138 }
139 }
140 return nil
141 }
143 func getClientAuth(w http.ResponseWriter, r *http.Request, allowPublic bool) (uuid.ID, string, bool) {
144 enc := json.NewEncoder(w)
145 clientIDStr, clientSecret, fromAuthHeader := r.BasicAuth()
146 if !fromAuthHeader {
147 clientIDStr = r.PostFormValue("client_id")
148 }
149 if clientIDStr == "" {
150 w.WriteHeader(http.StatusUnauthorized)
151 if fromAuthHeader {
152 w.Header().Set("WWW-Authenticate", "Basic")
153 }
154 renderJSONError(enc, "invalid_client")
155 return nil, "", false
156 }
157 if !allowPublic && !fromAuthHeader {
158 w.WriteHeader(http.StatusBadRequest)
159 renderJSONError(enc, "unauthorized_client")
160 return nil, "", false
161 }
162 clientID, err := uuid.Parse(clientIDStr)
163 if err != nil {
164 log.Println("Error decoding client ID:", err)
165 w.WriteHeader(http.StatusUnauthorized)
166 if fromAuthHeader {
167 w.Header().Set("WWW-Authenticate", "Basic")
168 }
169 renderJSONError(enc, "invalid_client")
170 return nil, "", false
171 }
172 return clientID, clientSecret, true
173 }
175 func verifyClient(w http.ResponseWriter, r *http.Request, allowPublic bool, context Context) (uuid.ID, bool) {
176 enc := json.NewEncoder(w)
177 clientID, clientSecret, ok := getClientAuth(w, r, allowPublic)
178 if !ok {
179 return nil, false
180 }
181 _, _, fromAuthHeader := r.BasicAuth()
182 client, err := context.GetClient(clientID)
183 if err == ErrClientNotFound {
184 w.WriteHeader(http.StatusUnauthorized)
185 if fromAuthHeader {
186 w.Header().Set("WWW-Authenticate", "Basic")
187 }
188 renderJSONError(enc, "invalid_client")
189 return nil, false
190 } else if err != nil {
191 w.WriteHeader(http.StatusInternalServerError)
192 renderJSONError(enc, "server_error")
193 return nil, false
194 }
195 if client.Secret != clientSecret { // it's important that any client deemed "public" is not issued a client secret.
196 w.WriteHeader(http.StatusUnauthorized)
197 if fromAuthHeader {
198 w.Header().Set("WWW-Authenticate", "Basic")
199 }
200 renderJSONError(enc, "invalid_client")
201 return nil, false
202 }
203 return clientID, true
204 }
// Endpoint represents a single URI that a Client
// controls. Users will be redirected to these URIs
// following successful authorization grants and
// exchanges for access tokens.
type Endpoint struct {
	// ID uniquely identifies the Endpoint.
	ID uuid.ID `json:"id,omitempty"`
	// ClientID is the ID of the Client that owns the Endpoint.
	ClientID uuid.ID `json:"client_id,omitempty"`
	// URI is the absolute URI as registered.
	URI string `json:"uri,omitempty"`
	// NormalizedURI is the normalized form of URI, presumably produced by
	// normalizeURI/normalizeURIString (normalizeFlags) — confirm at the call
	// sites. memstore.checkEndpoint matches candidates against this field.
	// It is never serialized to JSON.
	NormalizedURI string `json:"-"`
	// Added is when the Endpoint was created; sortedEndpoints orders by it.
	Added time.Time `json:"added,omitempty"`
}
218 func normalizeURIString(in string) (string, error) {
219 n, err := purell.NormalizeURLString(in, normalizeFlags)
220 if err != nil {
221 log.Println(err)
222 return in, ErrEndpointURINotURL
223 }
224 return n, nil
225 }
227 func normalizeURI(in *url.URL) string {
228 return purell.NormalizeURL(in, normalizeFlags)
229 }
231 type sortedEndpoints []Endpoint
233 func (s sortedEndpoints) Len() int {
234 return len(s)
235 }
237 func (s sortedEndpoints) Less(i, j int) bool {
238 return s[i].Added.Before(s[j].Added)
239 }
241 func (s sortedEndpoints) Swap(i, j int) {
242 s[i], s[j] = s[j], s[i]
243 }
245 type clientStore interface {
246 getClient(id uuid.ID) (Client, error)
247 saveClient(client Client) error
248 updateClient(id uuid.ID, change ClientChange) error
249 deleteClient(id uuid.ID) error
250 listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error)
252 addEndpoints(client uuid.ID, endpoint []Endpoint) error
253 removeEndpoint(client, endpoint uuid.ID) error
254 checkEndpoint(client uuid.ID, endpoint string) (bool, error)
255 listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error)
256 countEndpoints(client uuid.ID) (int64, error)
257 }
259 func (m *memstore) getClient(id uuid.ID) (Client, error) {
260 m.clientLock.RLock()
261 defer m.clientLock.RUnlock()
262 c, ok := m.clients[id.String()]
263 if !ok {
264 return Client{}, ErrClientNotFound
265 }
266 return c, nil
267 }
269 func (m *memstore) saveClient(client Client) error {
270 m.clientLock.Lock()
271 defer m.clientLock.Unlock()
272 if _, ok := m.clients[client.ID.String()]; ok {
273 return ErrClientAlreadyExists
274 }
275 m.clients[client.ID.String()] = client
276 m.profileClientLookup[client.OwnerID.String()] = append(m.profileClientLookup[client.OwnerID.String()], client.ID)
277 return nil
278 }
280 func (m *memstore) updateClient(id uuid.ID, change ClientChange) error {
281 m.clientLock.Lock()
282 defer m.clientLock.Unlock()
283 c, ok := m.clients[id.String()]
284 if !ok {
285 return ErrClientNotFound
286 }
287 c.ApplyChange(change)
288 m.clients[id.String()] = c
289 return nil
290 }
292 func (m *memstore) deleteClient(id uuid.ID) error {
293 client, err := m.getClient(id)
294 if err != nil {
295 return err
296 }
297 m.clientLock.Lock()
298 defer m.clientLock.Unlock()
299 delete(m.clients, id.String())
300 pos := -1
301 for p, item := range m.profileClientLookup[client.OwnerID.String()] {
302 if item.Equal(id) {
303 pos = p
304 break
305 }
306 }
307 if pos >= 0 {
308 m.profileClientLookup[client.OwnerID.String()] = append(m.profileClientLookup[client.OwnerID.String()][:pos], m.profileClientLookup[client.OwnerID.String()][pos+1:]...)
309 }
310 return nil
311 }
313 func (m *memstore) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
314 ids := m.lookupClientsByProfileID(ownerID.String())
315 if len(ids) > num+offset {
316 ids = ids[offset : num+offset]
317 } else if len(ids) > offset {
318 ids = ids[offset:]
319 } else {
320 return []Client{}, nil
321 }
322 clients := []Client{}
323 for _, id := range ids {
324 client, err := m.getClient(id)
325 if err != nil {
326 return []Client{}, err
327 }
328 clients = append(clients, client)
329 }
330 return clients, nil
331 }
333 func (m *memstore) addEndpoints(client uuid.ID, endpoints []Endpoint) error {
334 m.endpointLock.Lock()
335 defer m.endpointLock.Unlock()
336 m.endpoints[client.String()] = append(m.endpoints[client.String()], endpoints...)
337 return nil
338 }
340 func (m *memstore) removeEndpoint(client, endpoint uuid.ID) error {
341 m.endpointLock.Lock()
342 defer m.endpointLock.Unlock()
343 pos := -1
344 for p, item := range m.endpoints[client.String()] {
345 if item.ID.Equal(endpoint) {
346 pos = p
347 break
348 }
349 }
350 if pos >= 0 {
351 m.endpoints[client.String()] = append(m.endpoints[client.String()][:pos], m.endpoints[client.String()][pos+1:]...)
352 }
353 return nil
354 }
356 func (m *memstore) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
357 m.endpointLock.RLock()
358 defer m.endpointLock.RUnlock()
359 for _, candidate := range m.endpoints[client.String()] {
360 if endpoint == candidate.NormalizedURI {
361 return true, nil
362 }
363 }
364 return false, nil
365 }
367 func (m *memstore) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
368 m.endpointLock.RLock()
369 defer m.endpointLock.RUnlock()
370 return m.endpoints[client.String()], nil
371 }
373 func (m *memstore) countEndpoints(client uuid.ID) (int64, error) {
374 m.endpointLock.RLock()
375 defer m.endpointLock.RUnlock()
376 return int64(len(m.endpoints[client.String()])), nil
377 }
// newClientReq is the JSON request body decoded by CreateClientHandler.
type newClientReq struct {
	// Name is required; its length is checked against minClientNameLen and
	// maxClientNameLen.
	Name string `json:"name"`
	// Logo is copied to Client.Logo; CreateClientHandler does not validate it.
	Logo string `json:"logo"`
	// Website is copied to Client.Website; CreateClientHandler does not validate it.
	Website string `json:"website"`
	// Type is required and must be clientTypePublic or clientTypeConfidential.
	Type string `json:"type"`
	// Endpoints lists redirect URIs to register; each must be an absolute URL.
	Endpoints []string `json:"endpoints"`
}
387 func RegisterClientHandlers(r *mux.Router, context Context) {
388 r.Handle("/clients", wrap(context, CreateClientHandler)).Methods("POST")
389 // BUG(paddy): We need to implement a handler to retrieve info on a client.
390 // BUG(paddy): We need to implement a handler to list clients.
391 // BUG(paddy): We need to implement a handler to update a client.
392 // BUG(paddy): We need to implement a handler to delete a client. Also, what should that do with the grants and tokens belonging to that client?
393 // BUG(paddy): We need to implement a handler to add an endpoint to a client.
394 // BUG(paddy): We need to implement a handler to remove an endpoint from a client.
395 // BUG(paddy): We need to implement a handler to list endpoints.
396 }
398 func CreateClientHandler(w http.ResponseWriter, r *http.Request, c Context) {
399 errors := []requestError{}
400 username, password, ok := r.BasicAuth()
401 if !ok {
402 errors = append(errors, requestError{Slug: requestErrAccessDenied})
403 encode(w, r, http.StatusUnauthorized, response{Errors: errors})
404 return
405 }
406 profile, err := authenticate(username, password, c)
407 if err != nil {
408 errors = append(errors, requestError{Slug: requestErrAccessDenied})
409 encode(w, r, http.StatusUnauthorized, response{Errors: errors})
410 return
411 }
412 var req newClientReq
413 decoder := json.NewDecoder(r.Body)
414 err = decoder.Decode(&req)
415 if err != nil {
416 encode(w, r, http.StatusBadRequest, invalidFormatResponse)
417 return
418 }
419 if req.Type == "" {
420 errors = append(errors, requestError{Slug: requestErrMissing, Field: "/type"})
421 } else if req.Type != clientTypePublic && req.Type != clientTypeConfidential {
422 errors = append(errors, requestError{Slug: requestErrInvalidValue, Field: "/type"})
423 }
424 if req.Name == "" {
425 errors = append(errors, requestError{Slug: requestErrMissing, Field: "/name"})
426 } else if len(req.Name) < minClientNameLen {
427 errors = append(errors, requestError{Slug: requestErrInsufficient, Field: "/name"})
428 } else if len(req.Name) > maxClientNameLen {
429 errors = append(errors, requestError{Slug: requestErrOverflow, Field: "/name"})
430 }
431 if len(errors) > 0 {
432 encode(w, r, http.StatusBadRequest, response{Errors: errors})
433 return
434 }
435 client := Client{
436 ID: uuid.NewID(),
437 OwnerID: profile.ID,
438 Name: req.Name,
439 Logo: req.Logo,
440 Website: req.Website,
441 Type: req.Type,
442 }
443 if client.Type == clientTypeConfidential {
444 secret := make([]byte, 32)
445 _, err = rand.Read(secret)
446 if err != nil {
447 encode(w, r, http.StatusInternalServerError, actOfGodResponse)
448 return
449 }
450 client.Secret = hex.EncodeToString(secret)
451 }
452 err = c.SaveClient(client)
453 if err != nil {
454 if err == ErrClientAlreadyExists {
455 errors = append(errors, requestError{Slug: requestErrConflict, Field: "/id"})
456 encode(w, r, http.StatusBadRequest, response{Errors: errors})
457 return
458 }
459 encode(w, r, http.StatusInternalServerError, actOfGodResponse)
460 return
461 }
462 endpoints := []Endpoint{}
463 for pos, u := range req.Endpoints {
464 uri, err := url.Parse(u)
465 if err != nil {
466 errors = append(errors, requestError{Slug: requestErrInvalidFormat, Field: "/endpoints/" + strconv.Itoa(pos)})
467 continue
468 }
469 if !uri.IsAbs() {
470 errors = append(errors, requestError{Slug: requestErrInvalidValue, Field: "/endpoints/" + strconv.Itoa(pos)})
471 continue
472 }
473 endpoint := Endpoint{
474 ID: uuid.NewID(),
475 ClientID: client.ID,
476 URI: uri.String(),
477 Added: time.Now(),
478 }
479 endpoints = append(endpoints, endpoint)
480 }
481 err = c.AddEndpoints(client.ID, endpoints)
482 if err != nil {
483 errors = append(errors, requestError{Slug: requestErrActOfGod})
484 encode(w, r, http.StatusInternalServerError, response{Errors: errors, Clients: []Client{client}})
485 return
486 }
487 resp := response{
488 Clients: []Client{client},
489 Endpoints: endpoints,
490 Errors: errors,
491 }
492 encode(w, r, http.StatusCreated, resp)
493 }
495 func clientCredentialsValidate(w http.ResponseWriter, r *http.Request, context Context) (scope string, profileID uuid.ID, valid bool) {
496 scope = r.PostFormValue("scope")
497 valid = true
498 return
499 }
// clientCredentialsAuditString returns the audit label for the
// "client_credentials" grant type. The request is not inspected.
func clientCredentialsAuditString(_ *http.Request) string {
	const grantName = "client_credentials"
	return grantName
}