auth

Paddy 2015-01-14 Parent:5bd46746b809 Child:e000b1c24fc0

115:fa8ee6a4507c Go to Latest

auth/client.go

Turn AddEndpoint into AddEndpoints. Because one is a special case of many, it makes sense to be able to add multiple endpoints in a single call to the database. So we've converted the AddEndpoint method into an AddEndpoints method and updated our tests appropriately. We also filled in the errors when creating a client through the API, and moved things around to optimize for the maximum number of errors returned in a single call.

History
1 package auth
3 import (
4 "crypto/rand"
5 "encoding/hex"
6 "encoding/json"
7 "errors"
8 "github.com/gorilla/mux"
9 "net/http"
10 "net/url"
11 "strconv"
12 "time"
14 "code.secondbit.org/uuid.hg"
15 )
// Sentinel errors for client storage and for validating changes to a Client.
var (
	// ErrNoClientStore is returned when a Context tries to act on a clientStore without setting one first.
	ErrNoClientStore = errors.New("no clientStore was specified for the Context")
	// ErrClientNotFound is returned when a Client is requested but not found in a clientStore.
	ErrClientNotFound = errors.New("client not found in clientStore")
	// ErrClientAlreadyExists is returned when a Client is added to a clientStore, but another Client with
	// the same ID already exists in the clientStore.
	ErrClientAlreadyExists = errors.New("client already exists in clientStore")

	// ErrEmptyChange is returned when a ClientChange has all its properties set to nil.
	ErrEmptyChange = errors.New("change must have at least one property set")
	// ErrClientNameTooShort is returned when a Client's Name property is too short.
	ErrClientNameTooShort = errors.New("client name must be at least 2 characters")
	// ErrClientNameTooLong is returned when a Client's Name property is too long.
	ErrClientNameTooLong = errors.New("client name must be at most 32 characters")
	// ErrClientLogoTooLong is returned when a Client's Logo property is too long.
	ErrClientLogoTooLong = errors.New("client logo must be at most 1024 characters")
	// ErrClientLogoNotURL is returned when a Client's Logo property is not a valid absolute URL.
	ErrClientLogoNotURL = errors.New("client logo must be a valid absolute URL")
	// ErrClientWebsiteTooLong is returned when a Client's Website property is too long.
	ErrClientWebsiteTooLong = errors.New("client website must be at most 1024 characters")
	// ErrClientWebsiteNotURL is returned when a Client's Website property is not a valid absolute URL.
	ErrClientWebsiteNotURL = errors.New("client website must be a valid absolute URL")
)
// The only values CreateClientHandler accepts for a Client's Type.
const (
	clientTypePublic       = "public"
	clientTypeConfidential = "confidential"
)
// Client represents a client that grants access
// to the auth server, exchanging grants for tokens,
// and tokens for access.
type Client struct {
	// ID uniquely identifies the Client; stores key their clients by it.
	ID uuid.ID
	// Secret is the credential verifyClient compares against the password
	// presented by the client.
	Secret string
	// OwnerID is the ID of the profile that created the Client.
	OwnerID uuid.ID
	// Name, Logo and Website are descriptive properties supplied when the
	// Client is created; ClientChange.Validate constrains their updates.
	Name    string
	Logo    string
	Website string
	// Type is one of clientTypePublic or clientTypeConfidential.
	Type string
}
60 // ApplyChange applies the properties of the passed
61 // ClientChange to the Client object it is called on.
62 func (c *Client) ApplyChange(change ClientChange) {
63 if change.Secret != nil {
64 c.Secret = *change.Secret
65 }
66 if change.OwnerID != nil {
67 c.OwnerID = change.OwnerID
68 }
69 if change.Name != nil {
70 c.Name = *change.Name
71 }
72 if change.Logo != nil {
73 c.Logo = *change.Logo
74 }
75 if change.Website != nil {
76 c.Website = *change.Website
77 }
78 }
// ClientChange represents a bundle of options for
// updating a Client's mutable data. A nil property means
// "leave this field alone"; see Client.ApplyChange.
type ClientChange struct {
	// Secret, when non-nil, replaces the Client's Secret.
	Secret *string
	// OwnerID, when non-nil, replaces the Client's OwnerID.
	OwnerID uuid.ID
	// Name, when non-nil, replaces the Client's Name.
	Name *string
	// Logo, when non-nil, replaces the Client's Logo.
	Logo *string
	// Website, when non-nil, replaces the Client's Website.
	Website *string
}
90 // Validate checks the ClientChange it is called on
91 // and asserts its internal validity, or lack thereof.
92 func (c ClientChange) Validate() error {
93 if c.Secret == nil && c.OwnerID == nil && c.Name == nil && c.Logo == nil && c.Website == nil {
94 return ErrEmptyChange
95 }
96 if c.Name != nil && len(*c.Name) < 2 {
97 return ErrClientNameTooShort
98 }
99 if c.Name != nil && len(*c.Name) > 32 {
100 return ErrClientNameTooLong
101 }
102 if c.Logo != nil && *c.Logo != "" {
103 if len(*c.Logo) > 1024 {
104 return ErrClientLogoTooLong
105 }
106 u, err := url.Parse(*c.Logo)
107 if err != nil || !u.IsAbs() {
108 return ErrClientLogoNotURL
109 }
110 }
111 if c.Website != nil && *c.Website != "" {
112 if len(*c.Website) > 140 {
113 return ErrClientWebsiteTooLong
114 }
115 u, err := url.Parse(*c.Website)
116 if err != nil || !u.IsAbs() {
117 return ErrClientWebsiteNotURL
118 }
119 }
120 return nil
121 }
123 func verifyClient(w http.ResponseWriter, r *http.Request, allowPublic bool, context Context) (uuid.ID, bool) {
124 enc := json.NewEncoder(w)
125 clientIDStr, clientSecret, fromAuthHeader := r.BasicAuth()
126 if !fromAuthHeader {
127 if !allowPublic {
128 w.WriteHeader(http.StatusBadRequest)
129 renderJSONError(enc, "unauthorized_client")
130 return nil, false
131 }
132 clientIDStr = r.PostFormValue("client_id")
133 }
134 clientID, err := uuid.Parse(clientIDStr)
135 if err != nil {
136 w.WriteHeader(http.StatusUnauthorized)
137 if fromAuthHeader {
138 w.Header().Set("WWW-Authenticate", "Basic")
139 }
140 renderJSONError(enc, "invalid_client")
141 return nil, false
142 }
143 client, err := context.GetClient(clientID)
144 if err == ErrClientNotFound {
145 w.WriteHeader(http.StatusUnauthorized)
146 if fromAuthHeader {
147 w.Header().Set("WWW-Authenticate", "Basic")
148 }
149 renderJSONError(enc, "invalid_client")
150 return nil, false
151 } else if err != nil {
152 w.WriteHeader(http.StatusInternalServerError)
153 renderJSONError(enc, "server_error")
154 return nil, false
155 }
156 if client.Secret != clientSecret { // it's important that any client deemed "public" is not issued a client secret.
157 w.WriteHeader(http.StatusUnauthorized)
158 if fromAuthHeader {
159 w.Header().Set("WWW-Authenticate", "Basic")
160 }
161 renderJSONError(enc, "invalid_client")
162 return nil, false
163 }
164 return clientID, true
165 }
// Endpoint represents a single URI that a Client
// controls. Users will be redirected to these URIs
// following successful authorization grants and
// exchanges for access tokens.
type Endpoint struct {
	// ID identifies the Endpoint; removeEndpoint matches on it.
	ID uuid.ID
	// ClientID is the ID of the Client the Endpoint belongs to.
	ClientID uuid.ID
	// URI is the parsed URI; checkEndpoint compares its String() form.
	URI url.URL
	// Added is when the Endpoint was created; sortedEndpoints orders by it.
	Added time.Time
}
178 type sortedEndpoints []Endpoint
180 func (s sortedEndpoints) Len() int {
181 return len(s)
182 }
184 func (s sortedEndpoints) Less(i, j int) bool {
185 return s[i].Added.Before(s[j].Added)
186 }
188 func (s sortedEndpoints) Swap(i, j int) {
189 s[i], s[j] = s[j], s[i]
190 }
// clientStore is the storage abstraction for Clients and their Endpoints.
// memstore provides an in-memory implementation in this file.
type clientStore interface {
	// getClient returns the Client with the given ID; memstore returns
	// ErrClientNotFound when there is none.
	getClient(id uuid.ID) (Client, error)
	// saveClient stores a new Client; memstore returns
	// ErrClientAlreadyExists when the ID is already taken.
	saveClient(client Client) error
	// updateClient applies change to the Client with the given ID.
	updateClient(id uuid.ID, change ClientChange) error
	// deleteClient removes the Client with the given ID.
	deleteClient(id uuid.ID) error
	// listClientsByOwner returns up to num Clients owned by ownerID,
	// skipping the first offset.
	listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error)

	// addEndpoints associates the given Endpoints with a Client.
	addEndpoints(client uuid.ID, endpoint []Endpoint) error
	// removeEndpoint removes the Endpoint with the given ID from a Client.
	removeEndpoint(client, endpoint uuid.ID) error
	// checkEndpoint reports whether endpoint is one of the Client's URIs.
	checkEndpoint(client uuid.ID, endpoint string) (bool, error)
	// listEndpoints returns Endpoints belonging to a Client.
	listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error)
	// countEndpoints returns how many Endpoints a Client has.
	countEndpoints(client uuid.ID) (int64, error)
}
206 func (m *memstore) getClient(id uuid.ID) (Client, error) {
207 m.clientLock.RLock()
208 defer m.clientLock.RUnlock()
209 c, ok := m.clients[id.String()]
210 if !ok {
211 return Client{}, ErrClientNotFound
212 }
213 return c, nil
214 }
216 func (m *memstore) saveClient(client Client) error {
217 m.clientLock.Lock()
218 defer m.clientLock.Unlock()
219 if _, ok := m.clients[client.ID.String()]; ok {
220 return ErrClientAlreadyExists
221 }
222 m.clients[client.ID.String()] = client
223 m.profileClientLookup[client.OwnerID.String()] = append(m.profileClientLookup[client.OwnerID.String()], client.ID)
224 return nil
225 }
227 func (m *memstore) updateClient(id uuid.ID, change ClientChange) error {
228 m.clientLock.Lock()
229 defer m.clientLock.Unlock()
230 c, ok := m.clients[id.String()]
231 if !ok {
232 return ErrClientNotFound
233 }
234 c.ApplyChange(change)
235 m.clients[id.String()] = c
236 return nil
237 }
239 func (m *memstore) deleteClient(id uuid.ID) error {
240 client, err := m.getClient(id)
241 if err != nil {
242 return err
243 }
244 m.clientLock.Lock()
245 defer m.clientLock.Unlock()
246 delete(m.clients, id.String())
247 pos := -1
248 for p, item := range m.profileClientLookup[client.OwnerID.String()] {
249 if item.Equal(id) {
250 pos = p
251 break
252 }
253 }
254 if pos >= 0 {
255 m.profileClientLookup[client.OwnerID.String()] = append(m.profileClientLookup[client.OwnerID.String()][:pos], m.profileClientLookup[client.OwnerID.String()][pos+1:]...)
256 }
257 return nil
258 }
260 func (m *memstore) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
261 ids := m.lookupClientsByProfileID(ownerID.String())
262 if len(ids) > num+offset {
263 ids = ids[offset : num+offset]
264 } else if len(ids) > offset {
265 ids = ids[offset:]
266 } else {
267 return []Client{}, nil
268 }
269 clients := []Client{}
270 for _, id := range ids {
271 client, err := m.getClient(id)
272 if err != nil {
273 return []Client{}, err
274 }
275 clients = append(clients, client)
276 }
277 return clients, nil
278 }
280 func (m *memstore) addEndpoints(client uuid.ID, endpoints []Endpoint) error {
281 m.endpointLock.Lock()
282 defer m.endpointLock.Unlock()
283 m.endpoints[client.String()] = append(m.endpoints[client.String()], endpoints...)
284 return nil
285 }
287 func (m *memstore) removeEndpoint(client, endpoint uuid.ID) error {
288 m.endpointLock.Lock()
289 defer m.endpointLock.Unlock()
290 pos := -1
291 for p, item := range m.endpoints[client.String()] {
292 if item.ID.Equal(endpoint) {
293 pos = p
294 break
295 }
296 }
297 if pos >= 0 {
298 m.endpoints[client.String()] = append(m.endpoints[client.String()][:pos], m.endpoints[client.String()][pos+1:]...)
299 }
300 return nil
301 }
303 func (m *memstore) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
304 m.endpointLock.RLock()
305 defer m.endpointLock.RUnlock()
306 for _, candidate := range m.endpoints[client.String()] {
307 if endpoint == candidate.URI.String() {
308 return true, nil
309 }
310 }
311 return false, nil
312 }
314 func (m *memstore) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
315 m.endpointLock.RLock()
316 defer m.endpointLock.RUnlock()
317 return m.endpoints[client.String()], nil
318 }
320 func (m *memstore) countEndpoints(client uuid.ID) (int64, error) {
321 m.endpointLock.RLock()
322 defer m.endpointLock.RUnlock()
323 return int64(len(m.endpoints[client.String()])), nil
324 }
// newClientReq is the JSON request body CreateClientHandler decodes when
// creating a Client.
type newClientReq struct {
	// Name, Logo and Website are copied onto the new Client as-is.
	Name    string `json:"name"`
	Logo    string `json:"logo"`
	Website string `json:"website"`
	// Type must be clientTypePublic or clientTypeConfidential.
	Type string `json:"type"`
	// Endpoints holds the URIs to register for the new Client; each is
	// parsed with url.Parse.
	Endpoints []string `json:"endpoints"`
}
334 func RegisterClientHandlers(r *mux.Router, context Context) {
335 r.Handle("/clients", wrap(context, CreateClientHandler)).Methods("POST")
336 }
338 func CreateClientHandler(w http.ResponseWriter, r *http.Request, c Context) {
339 errors := []requestError{}
340 username, password, ok := r.BasicAuth()
341 if !ok {
342 errors = append(errors, requestError{Slug: requestErrAccessDenied})
343 encode(w, r, http.StatusUnauthorized, response{Errors: errors})
344 return
345 }
346 profile, err := authenticate(username, password, c)
347 if err != nil {
348 errors = append(errors, requestError{Slug: requestErrAccessDenied})
349 encode(w, r, http.StatusUnauthorized, response{Errors: errors})
350 return
351 }
352 var req newClientReq
353 decoder := json.NewDecoder(r.Body)
354 err = decoder.Decode(&req)
355 if err != nil {
356 encode(w, r, http.StatusBadRequest, invalidFormatResponse)
357 return
358 }
359 if req.Type != clientTypePublic && req.Type != clientTypeConfidential {
360 errors = append(errors, requestError{Slug: requestErrInvalidValue, Field: "/type"})
361 encode(w, r, http.StatusBadRequest, response{Errors: errors})
362 return
363 }
364 client := Client{
365 ID: uuid.NewID(),
366 OwnerID: profile.ID,
367 Name: req.Name,
368 Logo: req.Logo,
369 Website: req.Website,
370 Type: req.Type,
371 }
372 if client.Type == clientTypePublic {
373 secret := make([]byte, 32)
374 _, err = rand.Read(secret)
375 if err != nil {
376 encode(w, r, http.StatusInternalServerError, actOfGodResponse)
377 return
378 }
379 client.Secret = hex.EncodeToString(secret)
380 }
381 err = c.SaveClient(client)
382 if err != nil {
383 if err == ErrClientAlreadyExists {
384 errors = append(errors, requestError{Slug: requestErrConflict, Field: "/id"})
385 encode(w, r, http.StatusBadRequest, response{Errors: errors})
386 return
387 }
388 encode(w, r, http.StatusInternalServerError, actOfGodResponse)
389 return
390 }
391 endpoints := []Endpoint{}
392 for pos, u := range req.Endpoints {
393 uri, err := url.Parse(u)
394 if err != nil {
395 errors = append(errors, requestError{Slug: requestErrInvalidFormat, Field: "/endpoints/" + strconv.Itoa(pos)})
396 continue
397 }
398 endpoint := Endpoint{
399 ID: uuid.NewID(),
400 ClientID: client.ID,
401 URI: *uri,
402 Added: time.Now(),
403 }
404 endpoints = append(endpoints, endpoint)
405 }
406 err = c.AddEndpoints(client.ID, endpoints)
407 if err != nil {
408 errors = append(errors, requestError{Slug: requestErrActOfGod})
409 encode(w, r, http.StatusInternalServerError, response{Errors: errors, Clients: []Client{client}})
410 return
411 }
412 resp := response{
413 Clients: []Client{client},
414 Endpoints: endpoints,
415 }
416 encode(w, r, http.StatusCreated, resp)
417 }