auth

Paddy 2015-01-28 Parent:f474ce964dcf Child:bc842183181d

132:163ce22fa4c9 Go to Latest

auth/client_test.go

Enable CSRF protection and add expiration to sessions. Sessions now carry a CSRF token, which is passed as a parameter to the login page. The login page checks that token and logs a CSRF attempt if it does not match. I also added an expiration to sessions so they don't last forever. Sessions should be short-lived: we only need to stay logged in long enough to approve the OAuth request, and everything after that should be cookie-based. Finally, I added a configuration parameter that controls whether the session cookie is marked Secure, which requires HTTPS. The flag must be enabled in production, but it makes testing extremely difficult, so we need a way to disable it.

History
1 package auth
3 import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "sort"
12 "strings"
13 "testing"
14 "time"
16 "code.secondbit.org/uuid.hg"
17 )
// Bit flags used by TestClientUpdates to choose which ClientChange fields a
// variation sets. Each flag is a distinct power of two, so every combination
// of them is a unique mask in the range [0, 1<<5).
const (
	clientChangeSecret  = 1 << 0
	clientChangeOwnerID = 1 << 1
	clientChangeName    = 1 << 2
	clientChangeLogo    = 1 << 3
	clientChangeWebsite = 1 << 4
)
27 var clientStores = []clientStore{NewMemstore()}
29 func compareClients(client1, client2 Client) (success bool, field string, val1, val2 interface{}) {
30 if !client1.ID.Equal(client2.ID) {
31 return false, "ID", client1.ID, client2.ID
32 }
33 if client1.Secret != client2.Secret {
34 return false, "secret", client1.Secret, client2.Secret
35 }
36 if !client1.OwnerID.Equal(client2.OwnerID) {
37 return false, "owner ID", client1.OwnerID, client2.OwnerID
38 }
39 if client1.Name != client2.Name {
40 return false, "name", client1.Name, client2.Name
41 }
42 if client1.Logo != client2.Logo {
43 return false, "logo", client1.Logo, client2.Logo
44 }
45 if client1.Website != client2.Website {
46 return false, "website", client1.Website, client2.Website
47 }
48 if client1.Type != client2.Type {
49 return false, "type", client1.Type, client2.Type
50 }
51 return true, "", nil, nil
52 }
54 func compareEndpoints(endpoint1, endpoint2 Endpoint) (success bool, field string, val1, val2 interface{}) {
55 if !endpoint1.ID.Equal(endpoint2.ID) {
56 return false, "ID", endpoint1.ID, endpoint2.ID
57 }
58 if !endpoint1.ClientID.Equal(endpoint2.ClientID) {
59 return false, "OwnerID", endpoint1.ClientID, endpoint2.ClientID
60 }
61 if !endpoint1.Added.Equal(endpoint2.Added) {
62 return false, "Added", endpoint1.Added, endpoint2.Added
63 }
64 if endpoint1.URI != endpoint2.URI {
65 return false, "URI", endpoint1.URI, endpoint2.URI
66 }
67 return true, "", nil, nil
68 }
70 func TestClientStoreSuccess(t *testing.T) {
71 t.Parallel()
72 client := Client{
73 ID: uuid.NewID(),
74 Secret: "secret",
75 OwnerID: uuid.NewID(),
76 Name: "name",
77 Logo: "logo",
78 Website: "website",
79 }
80 for _, store := range clientStores {
81 context := Context{clients: store}
82 err := context.SaveClient(client)
83 if err != nil {
84 t.Fatalf("Error saving client to %T: %s", store, err)
85 }
86 err = context.SaveClient(client)
87 if err != ErrClientAlreadyExists {
88 t.Fatalf("Expected ErrClientAlreadyExists, got %v from %T", err, store)
89 }
90 retrieved, err := context.GetClient(client.ID)
91 if err != nil {
92 t.Fatalf("Error retrieving client from %T: %s", store, err)
93 }
94 success, field, expectation, result := compareClients(client, retrieved)
95 if !success {
96 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
97 }
98 clients, err := context.ListClientsByOwner(client.OwnerID, 25, 0)
99 if err != nil {
100 t.Fatalf("Error retrieving clients by owner from %T: %s", store, err)
101 }
102 if len(clients) != 1 {
103 t.Fatalf("Expected 1 client in response from %T, got %+v", store, clients)
104 }
105 success, field, expectation, result = compareClients(client, clients[0])
106 if !success {
107 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
108 }
109 err = context.DeleteClient(client.ID)
110 if err != nil {
111 t.Fatalf("Error deleting client from %T: %s", store, err)
112 }
113 err = context.DeleteClient(client.ID)
114 if err != ErrClientNotFound {
115 t.Fatalf("Expected ErrClientNotFound, got %s from %T", err, store)
116 }
117 retrieved, err = context.GetClient(client.ID)
118 if err != ErrClientNotFound {
119 t.Fatalf("Expected ErrClientNotFound from %T, got %+v and %s", store, retrieved, err)
120 }
121 clients, err = context.ListClientsByOwner(client.OwnerID, 25, 0)
122 if err != nil {
123 t.Fatalf("Error listing clients by owner from %T: %s", store, err)
124 }
125 if len(clients) != 0 {
126 t.Fatalf("Expected 0 clients in response from %T, got %+v", store, clients)
127 }
128 }
129 }
131 func TestEndpointStoreSuccess(t *testing.T) {
132 t.Parallel()
133 client := Client{
134 ID: uuid.NewID(),
135 Secret: "secret",
136 OwnerID: uuid.NewID(),
137 Name: "name",
138 Logo: "logo",
139 Website: "website",
140 }
141 endpoint1 := Endpoint{
142 ID: uuid.NewID(),
143 ClientID: client.ID,
144 Added: time.Now(),
145 URI: "https://www.example.com/",
146 }
147 endpoint2 := Endpoint{
148 ID: uuid.NewID(),
149 ClientID: client.ID,
150 Added: time.Now(),
151 URI: "https://www.example.com/my/full/path",
152 }
153 for _, store := range clientStores {
154 context := Context{clients: store}
155 err := context.SaveClient(client)
156 if err != nil {
157 t.Fatalf("Error saving client to %T: %s", store, err)
158 }
159 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
160 if err != nil {
161 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
162 }
163 endpoints, err := context.ListEndpoints(client.ID, 10, 0)
164 if err != nil {
165 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
166 }
167 if len(endpoints) != 1 {
168 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
169 }
170 success, field, expectation, result := compareEndpoints(endpoint1, endpoints[0])
171 if !success {
172 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
173 }
174 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
175 if err != nil {
176 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
177 }
178 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
179 if err != nil {
180 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
181 }
182 if len(endpoints) != 2 {
183 t.Fatalf("Expected %d endpoints, got %+v from %T", 2, endpoints, store)
184 }
185 sortedEnd := sortedEndpoints(endpoints)
186 sort.Sort(sortedEnd)
187 endpoints = []Endpoint(sortedEnd)
188 success, field, expectation, result = compareEndpoints(endpoint1, endpoints[0])
189 if !success {
190 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
191 }
192 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[1])
193 if !success {
194 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
195 }
196 err = context.RemoveEndpoint(client.ID, endpoint1.ID)
197 if err != nil {
198 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
199 }
200 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
201 if err != nil {
202 t.Fatalf("Error listing endpoints in %T: %s", store, err)
203 }
204 if len(endpoints) != 1 {
205 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
206 }
207 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[0])
208 if !success {
209 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
210 }
211 err = context.RemoveEndpoint(client.ID, endpoint2.ID)
212 if err != nil {
213 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
214 }
215 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
216 if err != nil {
217 t.Fatalf("Error listing endpoints in %T: %s", store, err)
218 }
219 if len(endpoints) != 0 {
220 t.Fatalf("Expected %d endpoints, got %+v from %T", 0, endpoints, store)
221 }
222 }
223 }
225 func TestClientUpdates(t *testing.T) {
226 t.Parallel()
227 variations := 1 << 5
228 client := Client{
229 ID: uuid.NewID(),
230 Secret: "secret",
231 OwnerID: uuid.NewID(),
232 Name: "name",
233 Logo: "logo",
234 Website: "website",
235 }
236 for i := 0; i < variations; i++ {
237 var secret, name, logo, website string
238 change := ClientChange{}
239 expectation := client
240 result := client
241 if i&clientChangeSecret != 0 {
242 secret = fmt.Sprintf("secret-%d", i)
243 change.Secret = &secret
244 expectation.Secret = secret
245 }
246 if i&clientChangeOwnerID != 0 {
247 change.OwnerID = uuid.NewID()
248 expectation.OwnerID = change.OwnerID
249 }
250 if i&clientChangeName != 0 {
251 name = fmt.Sprintf("name-%d", i)
252 change.Name = &name
253 expectation.Name = name
254 }
255 if i&clientChangeLogo != 0 {
256 logo = fmt.Sprintf("logo-%d", i)
257 change.Logo = &logo
258 expectation.Logo = logo
259 }
260 if i&clientChangeWebsite != 0 {
261 website = fmt.Sprintf("website-%d", i)
262 change.Website = &website
263 expectation.Website = website
264 }
265 result.ApplyChange(change)
266 match, field, expected, got := compareClients(expectation, result)
267 if !match {
268 t.Fatalf("Expected field `%s` to be `%v`, got `%v`", field, expected, got)
269 }
270 for _, store := range clientStores {
271 context := Context{clients: store}
272 err := context.SaveClient(client)
273 if err != nil {
274 t.Fatalf("Error saving client in %T: %s", store, err)
275 }
276 err = context.UpdateClient(client.ID, change)
277 if err != nil {
278 t.Fatalf("Error updating client in %T: %s", store, err)
279 }
280 retrieved, err := context.GetClient(client.ID)
281 if err != nil {
282 t.Fatalf("Error getting client from %T: %s", store, err)
283 }
284 match, field, expected, got = compareClients(expectation, retrieved)
285 if !match {
286 t.Fatalf("Expected field `%s` to be `%v`, got `%v` from %T", field, expected, got, store)
287 }
288 err = context.DeleteClient(client.ID)
289 if err != nil {
290 t.Fatalf("Error deleting client from %T: %s", store, err)
291 }
292 err = context.UpdateClient(client.ID, change)
293 if err != ErrClientNotFound {
294 t.Fatalf("Expected ErrClientNotFound, got %v from %T", err, store)
295 }
296 }
297 }
298 }
300 func TestClientEndpointChecks(t *testing.T) {
301 t.Parallel()
302 client := Client{
303 ID: uuid.NewID(),
304 Secret: "secret",
305 OwnerID: uuid.NewID(),
306 Name: "name",
307 Logo: "logo",
308 Website: "website",
309 }
310 endpoint1 := Endpoint{
311 ID: uuid.NewID(),
312 ClientID: client.ID,
313 Added: time.Now(),
314 URI: "https://www.example.com/first",
315 }
316 endpoint2 := Endpoint{
317 ID: uuid.NewID(),
318 ClientID: client.ID,
319 Added: time.Now(),
320 URI: "https://www.example.com/my/full/path",
321 }
322 candidates := map[string]bool{
323 "https://www.example.com/": false,
324 "https://www.example.com/first": true,
325 "https://www.example.com/first/extra/path": false,
326 "https://www.example.com/my": false,
327 "https://www.example.com/my/full/path": true,
328 }
329 for _, store := range clientStores {
330 context := Context{clients: store}
331 err := context.SaveClient(client)
332 if err != nil {
333 t.Fatalf("Error saving client in %T: %s", store, err)
334 }
335 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
336 if err != nil {
337 t.Fatalf("Error saving endpoint in %T: %s", store, err)
338 }
339 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
340 if err != nil {
341 t.Fatalf("Error saving endpoint in %T: %s", store, err)
342 }
343 for candidate, expectation := range candidates {
344 result, err := context.CheckEndpoint(client.ID, candidate)
345 if err != nil {
346 t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
347 }
348 if result != expectation {
349 expectStr := "no"
350 resultStr := "a"
351 if expectation {
352 expectStr = "a"
353 resultStr = "no"
354 }
355 t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
356 }
357 }
358 }
359 }
361 func TestClientEndpointChecksStrict(t *testing.T) {
362 t.Parallel()
363 client := Client{
364 ID: uuid.NewID(),
365 Secret: "secret",
366 OwnerID: uuid.NewID(),
367 Name: "name",
368 Logo: "logo",
369 Website: "website",
370 }
371 endpoint1 := Endpoint{
372 ID: uuid.NewID(),
373 ClientID: client.ID,
374 Added: time.Now(),
375 URI: "https://www.example.com/first",
376 }
377 endpoint2 := Endpoint{
378 ID: uuid.NewID(),
379 ClientID: client.ID,
380 Added: time.Now(),
381 URI: "https://www.example.com/my/full/path",
382 }
383 candidates := map[string]bool{
384 "https://www.example.com/": false,
385 "https://www.example.com/first": true,
386 "https://www.example.com/first/extra/path": false,
387 "https://www.example.com/my": false,
388 "https://www.example.com/my/full/path": true,
389 }
390 for _, store := range clientStores {
391 context := Context{clients: store}
392 err := context.SaveClient(client)
393 if err != nil {
394 t.Fatalf("Error saving client in %T: %s", store, err)
395 }
396 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
397 if err != nil {
398 t.Fatalf("Error saving endpoint in %T: %s", store, err)
399 }
400 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
401 if err != nil {
402 t.Fatalf("Error saving endpoint in %T: %s", store, err)
403 }
404 for candidate, expectation := range candidates {
405 result, err := context.CheckEndpoint(client.ID, candidate)
406 if err != nil {
407 t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
408 }
409 if result != expectation {
410 expectStr := "no"
411 resultStr := "a"
412 if expectation {
413 expectStr = "a"
414 resultStr = "no"
415 }
416 t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
417 }
418 }
419 }
420 }
422 func TestClientChangeValidation(t *testing.T) {
423 t.Parallel()
424 change := ClientChange{}
425 if err := change.Validate(); err != ErrEmptyChange {
426 t.Errorf("Expected %s to give an error of %s, gave %s", "empty change", ErrEmptyChange, err)
427 }
428 names := map[string]error{
429 "a": ErrClientNameTooShort,
430 "ab": nil,
431 "abc": nil,
432 "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopq": ErrClientNameTooLong,
433 }
434 for name, expectation := range names {
435 change = ClientChange{Name: &name}
436 if err := change.Validate(); err != expectation {
437 t.Errorf("Expected %s to give an error of %s, gave %s", name, expectation, err)
438 }
439 }
440 longPath := ""
441 for i := 0; i < 1025; i++ {
442 longPath = fmt.Sprintf("%s%d", longPath, i)
443 }
444 logos := map[string]error{
445 "https://www.example.com/" + longPath: ErrClientLogoTooLong,
446 "https://www.example.com/ab": nil,
447 "www.example.com/ab": ErrClientLogoNotURL,
448 "test": ErrClientLogoNotURL,
449 "": nil,
450 }
451 for logo, expectation := range logos {
452 change = ClientChange{Logo: &logo}
453 if err := change.Validate(); err != expectation {
454 t.Errorf("Expected %s to give an error of %s, gave %s", logo, expectation, err)
455 }
456 }
457 websites := map[string]error{
458 "https://www.example.com/" + longPath: ErrClientWebsiteTooLong,
459 "https://www.example.com/ab": nil,
460 "www.example.com/ab": ErrClientWebsiteNotURL,
461 "test": ErrClientWebsiteNotURL,
462 "": nil,
463 }
464 for website, expectation := range websites {
465 change = ClientChange{Website: &website}
466 if err := change.Validate(); err != expectation {
467 t.Errorf("Expected %s to give an error of %s, gave %s", website, expectation, err)
468 }
469 }
470 }
472 func TestGetClientAuth(t *testing.T) {
473 t.Parallel()
474 type clientAuthRequest struct {
475 username string
476 pass string
477 clientID string
478 allowPublic bool
479 expectedClientID uuid.ID
480 expectedClientSecret string
481 expectedValid bool
482 expectedCode int
483 expectedBody string
484 expectAuthenticateHeader bool
485 }
486 id := uuid.NewID()
487 tests := []clientAuthRequest{
488 {"", "", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
489 {"", "", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
490 {"", "no clientID set", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
491 {"", "no clientID set", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
492 {"not an actual id", "invalid client ID set", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
493 {"not an actual id", "invalid client ID set", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
494 {"", "", "not an actual id", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
495 {id.String(), "secret", "", true, id, "secret", true, http.StatusOK, "", false},
496 {id.String(), "secret", "", false, id, "secret", true, http.StatusOK, "", false},
497 {"", "", id.String(), true, id, "", true, http.StatusOK, "", false},
498 {"", "", id.String(), false, nil, "", false, http.StatusBadRequest, `{"error":"unauthorized_client"}`, false},
499 }
500 for pos, test := range tests {
501 t.Logf("Running test #%d, with request %+v", pos, test)
502 w := httptest.NewRecorder()
503 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/oauth2/grant", nil)
504 if err != nil {
505 t.Fatal("Can't build request:", err)
506 }
507 if test.username != "" || test.pass != "" {
508 r.SetBasicAuth(test.username, test.pass)
509 }
510 r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
511 params := url.Values{}
512 params.Set("client_id", test.clientID)
513 body := bytes.NewBufferString(params.Encode())
514 r.Body = ioutil.NopCloser(body)
515 respID, respSecret, success := getClientAuth(w, r, test.allowPublic)
516 if (respID == nil && test.expectedClientID != nil) || (respID != nil && test.expectedClientID == nil) || !respID.Equal(test.expectedClientID) {
517 t.Errorf("Expected response ID to be %v, got %v", test.expectedClientID, respID)
518 }
519 if test.expectedClientSecret != respSecret {
520 t.Errorf("Expected response secret to be '%s', got '%s'", test.expectedClientSecret, respSecret)
521 }
522 if test.expectedValid != success {
523 t.Errorf("Expected success result to be %v, got %v", test.expectedValid, success)
524 }
525 if test.expectedCode != w.Code {
526 t.Errorf("Expected response code to be %d, got %d", test.expectedCode, w.Code)
527 }
528 if test.expectedBody != strings.TrimSpace(w.Body.String()) {
529 t.Errorf("Expected body to be '%s', got '%s'", test.expectedBody, strings.TrimSpace(w.Body.String()))
530 }
531 if test.expectAuthenticateHeader && w.Header().Get("WWW-Authenticate") != "Basic" {
532 t.Errorf(`Expected header WWW-Authenticate to be set to "Basic", got "%s"`, w.Header().Get("WWW-Authenticate"))
533 }
534 }
535 }
537 func TestVerifyClient(t *testing.T) {
538 t.Parallel()
539 type verifyClientRequest struct {
540 username string
541 pass string
542 clientID string
543 allowPublic bool
544 expectedClientID uuid.ID
545 expectedValid bool
546 expectedCode int
547 expectedBody string
548 expectAuthenticateHeader bool
549 }
550 memstore := NewMemstore()
551 context := Context{
552 clients: memstore,
553 }
554 client := Client{
555 ID: uuid.NewID(),
556 Secret: "super secret!",
557 OwnerID: uuid.NewID(),
558 Name: "My test client",
559 Logo: "https://secondbit.org/logo.png",
560 Website: "https://secondbit.org/",
561 Type: "confidential",
562 }
563 err := context.SaveClient(client)
564 if err != nil {
565 t.Fatal("Could not save client:", err)
566 }
567 publicClient := Client{
568 ID: uuid.NewID(),
569 Secret: "",
570 OwnerID: uuid.NewID(),
571 Name: "A public client",
572 Logo: "https://secondbit.org/logo.png",
573 Website: "https://secondbit.org/",
574 Type: "public",
575 }
576 err = context.SaveClient(publicClient)
577 if err != nil {
578 t.Fatal("Could not save client:", err)
579 }
580 id := uuid.NewID()
581 tests := []verifyClientRequest{
582 {"", "", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
583 {"", "", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
584 {"", "no clientID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
585 {"", "no clientID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
586 {"not an actual id", "invalid client ID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
587 {"not an actual id", "invalid client ID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
588 {id.String(), "unsaved client ID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
589 {id.String(), "unsaved client ID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
590 {client.ID.String(), "wrong secret", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
591 {client.ID.String(), "wrong secret", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
592 {"", "", "not an actual id", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
593 {"", "", id.String(), true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
594 {client.ID.String(), client.Secret, "", true, client.ID, true, http.StatusOK, "", false},
595 {client.ID.String(), client.Secret, "", false, client.ID, true, http.StatusOK, "", false},
596 {"", "", publicClient.ID.String(), true, publicClient.ID, true, http.StatusOK, "", false},
597 {"", "", publicClient.ID.String(), false, nil, false, http.StatusBadRequest, `{"error":"unauthorized_client"}`, false},
598 }
600 for pos, test := range tests {
601 t.Logf("Running test #%d, with request %+v", pos, test)
602 w := httptest.NewRecorder()
603 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/oauth2/grant", nil)
604 if err != nil {
605 t.Fatal("Can't build request:", err)
606 }
607 if test.username != "" || test.pass != "" {
608 r.SetBasicAuth(test.username, test.pass)
609 }
610 r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
611 params := url.Values{}
612 params.Set("client_id", test.clientID)
613 body := bytes.NewBufferString(params.Encode())
614 r.Body = ioutil.NopCloser(body)
615 respID, success := verifyClient(w, r, test.allowPublic, context)
616 if (respID == nil && test.expectedClientID != nil) || (respID != nil && test.expectedClientID == nil) || !respID.Equal(test.expectedClientID) {
617 t.Errorf("Expected response ID to be %v, got %v", test.expectedClientID, respID)
618 }
619 if test.expectedValid != success {
620 t.Errorf("Expected success result to be %v, got %v", test.expectedValid, success)
621 }
622 if test.expectedCode != w.Code {
623 t.Errorf("Expected response code to be %d, got %d", test.expectedCode, w.Code)
624 }
625 if test.expectedBody != strings.TrimSpace(w.Body.String()) {
626 t.Errorf("Expected body to be '%s', got '%s'", test.expectedBody, strings.TrimSpace(w.Body.String()))
627 }
628 if test.expectAuthenticateHeader && w.Header().Get("WWW-Authenticate") != "Basic" {
629 t.Errorf(`Expected header WWW-Authenticate to be set to "Basic", got "%s"`, w.Header().Get("WWW-Authenticate"))
630 }
631 }
632 }
634 func TestCreateClientHandler(t *testing.T) {
635 t.Parallel()
636 memstore := NewMemstore()
637 c := Context{
638 clients: memstore,
639 profiles: memstore,
640 }
641 w := httptest.NewRecorder()
642 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/clients", nil)
643 if err != nil {
644 t.Fatal("Can't build request:", err)
645 }
646 r.Header.Set("Content-Type", "application/json")
647 CreateClientHandler(w, r, c)
648 if w.Code != http.StatusUnauthorized {
649 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
650 }
651 expected := `{"errors":[{"error":"access_denied"}]}`
652 result := strings.TrimSpace(w.Body.String())
653 if result != expected {
654 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
655 }
656 w = httptest.NewRecorder()
657 r.Header.Set("Authorization", "Not basic at all...")
658 CreateClientHandler(w, r, c)
659 if w.Code != http.StatusUnauthorized {
660 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
661 }
662 expected = `{"errors":[{"error":"access_denied"}]}`
663 result = strings.TrimSpace(w.Body.String())
664 if result != expected {
665 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
666 }
667 w = httptest.NewRecorder()
668 r.Header.Set("Authorization", "Basic TotallyNotBase64Encoded")
669 CreateClientHandler(w, r, c)
670 if w.Code != http.StatusUnauthorized {
671 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
672 }
673 expected = `{"errors":[{"error":"access_denied"}]}`
674 result = strings.TrimSpace(w.Body.String())
675 if result != expected {
676 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
677 }
678 w = httptest.NewRecorder()
679 r.Header.Set("Authorization", "Basic dGhpc2hhc25vY29sb24=")
680 CreateClientHandler(w, r, c)
681 if w.Code != http.StatusUnauthorized {
682 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
683 }
684 expected = `{"errors":[{"error":"access_denied"}]}`
685 result = strings.TrimSpace(w.Body.String())
686 if result != expected {
687 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
688 }
689 profile := Profile{
690 ID: uuid.NewID(),
691 Name: "Test User",
692 Passphrase: "f3a4ac4f1d657b2e6e776d24213e39406d50a87a52691a2a78891425af1271d0",
693 Iterations: 1,
694 Salt: "d82d92cfa8bfb5a08270ebbf39a3710d24b352b937fcc8959ebcb40384cc616b",
695 PassphraseScheme: 1,
696 Compromised: false,
697 LockedUntil: time.Time{},
698 PassphraseReset: "",
699 PassphraseResetCreated: time.Time{},
700 Created: time.Now(),
701 LastSeen: time.Time{},
702 }
703 login := Login{
704 Type: "email",
705 Value: "test@example.com",
706 ProfileID: profile.ID,
707 Created: time.Now(),
708 LastUsed: time.Time{},
709 }
710 w = httptest.NewRecorder()
711 r.SetBasicAuth("test@example.com", "mysecurepassphrase")
712 CreateClientHandler(w, r, c)
713 if w.Code != http.StatusUnauthorized {
714 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
715 }
716 expected = `{"errors":[{"error":"access_denied"}]}`
717 result = strings.TrimSpace(w.Body.String())
718 if result != expected {
719 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
720 }
721 err = c.SaveProfile(profile)
722 if err != nil {
723 t.Error("Error saving profile:", err)
724 }
725 err = c.AddLogin(login)
726 if err != nil {
727 t.Error("Error adding login:", err)
728 }
729 r.SetBasicAuth("test@example.com", "mysecurepassphrase")
730 type testStruct struct {
731 request string
732 code int
733 resp response
734 }
735 tests := []testStruct{
736 {``, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/"}}}},
737 {`{}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/type"}, {Slug: requestErrMissing, Field: "/name"}}}},
738 {`{"type":"notarealtype"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrMissing, Field: "/name"}}}},
739 {`{"type":"notarealtype","name":"myreallylongnameislongerthatthemaximumnamelength"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrOverflow, Field: "/name"}}}},
740 {`{"type":"notarealtype","name":"a"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrInsufficient, Field: "/name"}}}},
741 {`{"type":"public"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/name"}}}},
742 {`{"type":"public","name":"myreallylongnameislongerthatthemaximumnamelength"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrOverflow, Field: "/name"}}}},
743 {`{"type":"public","name":"a"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInsufficient, Field: "/name"}}}},
744 {`{"name":"My Client"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/type"}}}},
745 {`{"type":"notarealtype","name":"My Client"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}}}},
746 {`{"type":"public","name":"My Client"}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}}},
747 {`{"type":"public","name":"My Client", "endpoints": ["https://test.secondbit.org/", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://test.secondbit.org/"}, {URI: "https://paddy.io"}}}},
748 {`{"type":"public","name":"My Client", "endpoints": [":/not a url", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://paddy.io"}}, Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/endpoints/0"}}}},
749 {`{"type":"public","name":"My Client", "endpoints": [":/not a url", "/relative/uri", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://paddy.io"}}, Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/endpoints/0"}, {Slug: requestErrInvalidValue, Field: "/endpoints/1"}}}},
750 {`{"type":"confidential","name":"Secret Client", "endpoints": ["https://secondbit.org"]}`, http.StatusCreated, response{Clients: []Client{{Name: "Secret Client", OwnerID: profile.ID, Type: "confidential"}}, Endpoints: []Endpoint{{URI: "https://secondbit.org"}}}},
751 }
752 for pos, test := range tests {
753 t.Logf("Test #%d: `%s`", pos, test.request)
754 w = httptest.NewRecorder()
755 body := bytes.NewBufferString(test.request)
756 r.Body = ioutil.NopCloser(body)
757 CreateClientHandler(w, r, c)
758 if w.Code != test.code {
759 t.Errorf("Expected response code to be %d, got %d", test.code, w.Code)
760 }
761 t.Logf("Response: %s", w.Body.String())
762 var res response
763 err = json.Unmarshal(w.Body.Bytes(), &res)
764 if err != nil {
765 t.Error("Unexpected error unmarshalling response:", err)
766 }
767 if len(res.Clients) > 0 {
768 if res.Clients[0].Type == "confidential" && res.Clients[0].Secret == "" {
769 t.Log("Client:", res.Clients[0])
770 t.Error("Expected confidential client to have a secret, but does not.")
771 } else if res.Clients[0].Type == "public" && res.Clients[0].Secret != "" {
772 t.Log("Client:", res.Clients[0])
773 t.Error("Expected public client to not have a secret, but it does.")
774 }
775 }
776 fillInServerGenerated(test.resp, res)
777 success, field, expectation, result := compareResponses(test.resp, res)
778 if !success {
779 t.Errorf("Unexpected result for %s in response: expected %v, got %v", field, expectation, result)
780 }
781 }
782 }
784 // BUG(paddy): We need to test the clientCredentialsValidate function.
785 // BUG(paddy): We need to test the GetClientHandler.
786 // BUG(paddy): We need to test the ListClientsHandler.