auth
auth/client_test.go
Fix go vet error. A format string in our test output accidentally used `$` instead of `%` as a verb, which would have garbled that output when printed. This commit fixes it.
1 package auth
3 import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "io/ioutil"
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "sort"
12 "strings"
13 "testing"
14 "time"
16 "code.secondbit.org/uuid.hg"
17 )
// Bit flags selecting which ClientChange fields a test variation sets.
// TestClientUpdates enumerates every combination of them, so each flag must
// occupy its own bit.
const (
	clientChangeSecret  = 1 << 0 // set change.Secret
	clientChangeOwnerID = 1 << 1 // set change.OwnerID
	clientChangeName    = 1 << 2 // set change.Name
	clientChangeLogo    = 1 << 3 // set change.Logo
	clientChangeWebsite = 1 << 4 // set change.Website
)
27 var clientStores = []clientStore{NewMemstore()}
29 func compareClients(client1, client2 Client) (success bool, field string, val1, val2 interface{}) {
30 if !client1.ID.Equal(client2.ID) {
31 return false, "ID", client1.ID, client2.ID
32 }
33 if client1.Secret != client2.Secret {
34 return false, "secret", client1.Secret, client2.Secret
35 }
36 if !client1.OwnerID.Equal(client2.OwnerID) {
37 return false, "owner ID", client1.OwnerID, client2.OwnerID
38 }
39 if client1.Name != client2.Name {
40 return false, "name", client1.Name, client2.Name
41 }
42 if client1.Logo != client2.Logo {
43 return false, "logo", client1.Logo, client2.Logo
44 }
45 if client1.Website != client2.Website {
46 return false, "website", client1.Website, client2.Website
47 }
48 if client1.Type != client2.Type {
49 return false, "type", client1.Type, client2.Type
50 }
51 return true, "", nil, nil
52 }
54 func compareEndpoints(endpoint1, endpoint2 Endpoint) (success bool, field string, val1, val2 interface{}) {
55 if !endpoint1.ID.Equal(endpoint2.ID) {
56 return false, "ID", endpoint1.ID, endpoint2.ID
57 }
58 if !endpoint1.ClientID.Equal(endpoint2.ClientID) {
59 return false, "OwnerID", endpoint1.ClientID, endpoint2.ClientID
60 }
61 if !endpoint1.Added.Equal(endpoint2.Added) {
62 return false, "Added", endpoint1.Added, endpoint2.Added
63 }
64 if endpoint1.URI != endpoint2.URI {
65 return false, "URI", endpoint1.URI, endpoint2.URI
66 }
67 return true, "", nil, nil
68 }
70 func TestClientStoreSuccess(t *testing.T) {
71 t.Parallel()
72 client := Client{
73 ID: uuid.NewID(),
74 Secret: "secret",
75 OwnerID: uuid.NewID(),
76 Name: "name",
77 Logo: "logo",
78 Website: "website",
79 }
80 for _, store := range clientStores {
81 context := Context{clients: store}
82 err := context.SaveClient(client)
83 if err != nil {
84 t.Fatalf("Error saving client to %T: %s", store, err)
85 }
86 err = context.SaveClient(client)
87 if err != ErrClientAlreadyExists {
88 t.Fatalf("Expected ErrClientAlreadyExists, got %v from %T", err, store)
89 }
90 retrieved, err := context.GetClient(client.ID)
91 if err != nil {
92 t.Fatalf("Error retrieving client from %T: %s", store, err)
93 }
94 success, field, expectation, result := compareClients(client, retrieved)
95 if !success {
96 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
97 }
98 clients, err := context.ListClientsByOwner(client.OwnerID, 25, 0)
99 if err != nil {
100 t.Fatalf("Error retrieving clients by owner from %T: %s", store, err)
101 }
102 if len(clients) != 1 {
103 t.Fatalf("Expected 1 client in response from %T, got %+v", store, clients)
104 }
105 success, field, expectation, result = compareClients(client, clients[0])
106 if !success {
107 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
108 }
109 err = context.DeleteClient(client.ID)
110 if err != nil {
111 t.Fatalf("Error deleting client from %T: %s", store, err)
112 }
113 err = context.DeleteClient(client.ID)
114 if err != ErrClientNotFound {
115 t.Fatalf("Expected ErrClientNotFound, got %s from %T", err, store)
116 }
117 retrieved, err = context.GetClient(client.ID)
118 if err != ErrClientNotFound {
119 t.Fatalf("Expected ErrClientNotFound from %T, got %+v and %s", store, retrieved, err)
120 }
121 clients, err = context.ListClientsByOwner(client.OwnerID, 25, 0)
122 if err != nil {
123 t.Fatalf("Error listing clients by owner from %T: %s", store, err)
124 }
125 if len(clients) != 0 {
126 t.Fatalf("Expected 0 clients in response from %T, got %+v", store, clients)
127 }
128 }
129 }
131 func TestEndpointStoreSuccess(t *testing.T) {
132 t.Parallel()
133 client := Client{
134 ID: uuid.NewID(),
135 Secret: "secret",
136 OwnerID: uuid.NewID(),
137 Name: "name",
138 Logo: "logo",
139 Website: "website",
140 }
141 endpoint1 := Endpoint{
142 ID: uuid.NewID(),
143 ClientID: client.ID,
144 Added: time.Now(),
145 URI: "https://www.example.com/",
146 }
147 endpoint2 := Endpoint{
148 ID: uuid.NewID(),
149 ClientID: client.ID,
150 Added: time.Now(),
151 URI: "https://www.example.com/my/full/path",
152 }
153 for _, store := range clientStores {
154 context := Context{clients: store}
155 err := context.SaveClient(client)
156 if err != nil {
157 t.Fatalf("Error saving client to %T: %s", store, err)
158 }
159 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
160 if err != nil {
161 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
162 }
163 endpoints, err := context.ListEndpoints(client.ID, 10, 0)
164 if err != nil {
165 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
166 }
167 if len(endpoints) != 1 {
168 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
169 }
170 success, field, expectation, result := compareEndpoints(endpoint1, endpoints[0])
171 if !success {
172 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
173 }
174 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
175 if err != nil {
176 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
177 }
178 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
179 if err != nil {
180 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
181 }
182 if len(endpoints) != 2 {
183 t.Fatalf("Expected %d endpoints, got %+v from %T", 2, endpoints, store)
184 }
185 sortedEnd := sortedEndpoints(endpoints)
186 sort.Sort(sortedEnd)
187 endpoints = []Endpoint(sortedEnd)
188 success, field, expectation, result = compareEndpoints(endpoint1, endpoints[0])
189 if !success {
190 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
191 }
192 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[1])
193 if !success {
194 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
195 }
196 err = context.RemoveEndpoint(client.ID, endpoint1.ID)
197 if err != nil {
198 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
199 }
200 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
201 if err != nil {
202 t.Fatalf("Error listing endpoints in %T: %s", store, err)
203 }
204 if len(endpoints) != 1 {
205 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
206 }
207 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[0])
208 if !success {
209 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
210 }
211 err = context.RemoveEndpoint(client.ID, endpoint2.ID)
212 if err != nil {
213 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
214 }
215 endpoints, err = context.ListEndpoints(client.ID, 10, 0)
216 if err != nil {
217 t.Fatalf("Error listing endpoints in %T: %s", store, err)
218 }
219 if len(endpoints) != 0 {
220 t.Fatalf("Expected %d endpoints, got %+v from %T", 0, endpoints, store)
221 }
222 }
223 }
225 func TestClientUpdates(t *testing.T) {
226 t.Parallel()
227 variations := 1 << 5
228 client := Client{
229 ID: uuid.NewID(),
230 Secret: "secret",
231 OwnerID: uuid.NewID(),
232 Name: "name",
233 Logo: "logo",
234 Website: "website",
235 }
236 for i := 0; i < variations; i++ {
237 var secret, name, logo, website string
238 change := ClientChange{}
239 expectation := client
240 result := client
241 if i&clientChangeSecret != 0 {
242 secret = fmt.Sprintf("secret-%d", i)
243 change.Secret = &secret
244 expectation.Secret = secret
245 }
246 if i&clientChangeOwnerID != 0 {
247 change.OwnerID = uuid.NewID()
248 expectation.OwnerID = change.OwnerID
249 }
250 if i&clientChangeName != 0 {
251 name = fmt.Sprintf("name-%d", i)
252 change.Name = &name
253 expectation.Name = name
254 }
255 if i&clientChangeLogo != 0 {
256 logo = fmt.Sprintf("logo-%d", i)
257 change.Logo = &logo
258 expectation.Logo = logo
259 }
260 if i&clientChangeWebsite != 0 {
261 website = fmt.Sprintf("website-%d", i)
262 change.Website = &website
263 expectation.Website = website
264 }
265 result.ApplyChange(change)
266 match, field, expected, got := compareClients(expectation, result)
267 if !match {
268 t.Fatalf("Expected field `%s` to be `%v`, got `%v`", field, expected, got)
269 }
270 for _, store := range clientStores {
271 context := Context{clients: store}
272 err := context.SaveClient(client)
273 if err != nil {
274 t.Fatalf("Error saving client in %T: %s", store, err)
275 }
276 err = context.UpdateClient(client.ID, change)
277 if err != nil {
278 t.Fatalf("Error updating client in %T: %s", store, err)
279 }
280 retrieved, err := context.GetClient(client.ID)
281 if err != nil {
282 t.Fatalf("Error getting client from %T: %s", store, err)
283 }
284 match, field, expected, got = compareClients(expectation, retrieved)
285 if !match {
286 t.Fatalf("Expected field `%s` to be `%v`, got `%v` from %T", field, expected, got, store)
287 }
288 err = context.DeleteClient(client.ID)
289 if err != nil {
290 t.Fatalf("Error deleting client from %T: %s", store, err)
291 }
292 err = context.UpdateClient(client.ID, change)
293 if err != ErrClientNotFound {
294 t.Fatalf("Expected ErrClientNotFound, got %v from %T", err, store)
295 }
296 }
297 }
298 }
300 func TestClientEndpointChecks(t *testing.T) {
301 t.Parallel()
302 client := Client{
303 ID: uuid.NewID(),
304 Secret: "secret",
305 OwnerID: uuid.NewID(),
306 Name: "name",
307 Logo: "logo",
308 Website: "website",
309 }
310 endpoint1 := Endpoint{
311 ID: uuid.NewID(),
312 ClientID: client.ID,
313 Added: time.Now(),
314 URI: "https://www.example.com/first",
315 }
316 endpoint2 := Endpoint{
317 ID: uuid.NewID(),
318 ClientID: client.ID,
319 Added: time.Now(),
320 URI: "https://www.example.com/my/full/path",
321 }
322 candidates := map[string]bool{
323 "https://www.example.com/": false,
324 "https://www.example.com/first": true,
325 "https://www.example.com/first/extra/path": false,
326 "https://www.example.com/my": false,
327 "https://www.example.com/my/full/path": true,
328 }
329 for _, store := range clientStores {
330 context := Context{clients: store}
331 err := context.SaveClient(client)
332 if err != nil {
333 t.Fatalf("Error saving client in %T: %s", store, err)
334 }
335 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
336 if err != nil {
337 t.Fatalf("Error saving endpoint in %T: %s", store, err)
338 }
339 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
340 if err != nil {
341 t.Fatalf("Error saving endpoint in %T: %s", store, err)
342 }
343 for candidate, expectation := range candidates {
344 result, err := context.CheckEndpoint(client.ID, candidate)
345 if err != nil {
346 t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
347 }
348 if result != expectation {
349 expectStr := "no"
350 resultStr := "a"
351 if expectation {
352 expectStr = "a"
353 resultStr = "no"
354 }
355 t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
356 }
357 }
358 }
359 }
361 func TestClientEndpointChecksStrict(t *testing.T) {
362 t.Parallel()
363 client := Client{
364 ID: uuid.NewID(),
365 Secret: "secret",
366 OwnerID: uuid.NewID(),
367 Name: "name",
368 Logo: "logo",
369 Website: "website",
370 }
371 endpoint1 := Endpoint{
372 ID: uuid.NewID(),
373 ClientID: client.ID,
374 Added: time.Now(),
375 URI: "https://www.example.com/first",
376 }
377 endpoint2 := Endpoint{
378 ID: uuid.NewID(),
379 ClientID: client.ID,
380 Added: time.Now(),
381 URI: "https://www.example.com/my/full/path",
382 }
383 candidates := map[string]bool{
384 "https://www.example.com/": false,
385 "https://www.example.com/first": true,
386 "https://www.example.com/first/extra/path": false,
387 "https://www.example.com/my": false,
388 "https://www.example.com/my/full/path": true,
389 }
390 for _, store := range clientStores {
391 context := Context{clients: store}
392 err := context.SaveClient(client)
393 if err != nil {
394 t.Fatalf("Error saving client in %T: %s", store, err)
395 }
396 err = context.AddEndpoints(client.ID, []Endpoint{endpoint1})
397 if err != nil {
398 t.Fatalf("Error saving endpoint in %T: %s", store, err)
399 }
400 err = context.AddEndpoints(client.ID, []Endpoint{endpoint2})
401 if err != nil {
402 t.Fatalf("Error saving endpoint in %T: %s", store, err)
403 }
404 for candidate, expectation := range candidates {
405 result, err := context.CheckEndpoint(client.ID, candidate)
406 if err != nil {
407 t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
408 }
409 if result != expectation {
410 expectStr := "no"
411 resultStr := "a"
412 if expectation {
413 expectStr = "a"
414 resultStr = "no"
415 }
416 t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
417 }
418 }
419 }
420 }
422 func TestClientChangeValidation(t *testing.T) {
423 t.Parallel()
424 change := ClientChange{}
425 if err := change.Validate(); err[0] != ErrEmptyChange {
426 t.Errorf("Expected %s to give an error of %s, gave %s", "empty change", ErrEmptyChange, err)
427 }
428 names := map[string][]error{
429 "a": []error{ErrClientNameTooShort},
430 "ab": []error{},
431 "abc": []error{},
432 "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopq": []error{ErrClientNameTooLong},
433 }
434 for name, expectation := range names {
435 change = ClientChange{Name: &name}
436 errs := change.Validate()
437 if len(errs) != len(expectation) {
438 t.Errorf("Expected %s to give %d errors, gave %d", name, len(expectation), len(errs))
439 t.Logf("%+v", errs)
440 }
441 for pos, err := range errs {
442 if err != expectation[pos] {
443 t.Errorf("Expected %s to give an error of %s in position %d, gave %s", name, expectation[pos], pos, err)
444 }
445 }
446 }
447 longPath := ""
448 for i := 0; i < 1025; i++ {
449 longPath = fmt.Sprintf("%s%d", longPath, i)
450 }
451 logos := map[string][]error{
452 "https://www.example.com/" + longPath: []error{ErrClientLogoTooLong},
453 "https://www.example.com/ab": []error{},
454 "www.example.com/ab": []error{ErrClientLogoNotURL},
455 "test": []error{ErrClientLogoNotURL},
456 "": []error{},
457 }
458 for logo, expectation := range logos {
459 change = ClientChange{Logo: &logo}
460 errs := change.Validate()
461 if len(errs) != len(expectation) {
462 t.Errorf("Expected %s to give %d errors, gave %d", logo, len(expectation), len(errs))
463 }
464 for pos, err := range errs {
465 if err != expectation[pos] {
466 t.Errorf("Expected %s to give an error of %s in positiong %d, gave %s", logo, expectation[pos], pos, err)
467 }
468 }
469 }
470 websites := map[string][]error{
471 "https://www.example.com/" + longPath: []error{ErrClientWebsiteTooLong},
472 "https://www.example.com/ab": []error{},
473 "www.example.com/ab": []error{ErrClientWebsiteNotURL},
474 "test": []error{ErrClientWebsiteNotURL},
475 "": []error{},
476 }
477 for website, expectation := range websites {
478 change = ClientChange{Website: &website}
479 errs := change.Validate()
480 if len(errs) != len(expectation) {
481 t.Errorf("Expected %s to give %d errors, gave %d", website, len(expectation), len(errs))
482 }
483 for pos, err := range errs {
484 if err != expectation[pos] {
485 t.Errorf("Expected %s to give an error of %s in position %d, gave %s", website, expectation[pos], pos, err)
486 }
487 }
488 }
489 }
491 func TestGetClientAuth(t *testing.T) {
492 t.Parallel()
493 type clientAuthRequest struct {
494 username string
495 pass string
496 clientID string
497 allowPublic bool
498 expectedClientID uuid.ID
499 expectedClientSecret string
500 expectedValid bool
501 expectedCode int
502 expectedBody string
503 expectAuthenticateHeader bool
504 }
505 id := uuid.NewID()
506 tests := []clientAuthRequest{
507 {"", "", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
508 {"", "", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
509 {"", "no clientID set", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
510 {"", "no clientID set", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
511 {"not an actual id", "invalid client ID set", "", false, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
512 {"not an actual id", "invalid client ID set", "", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
513 {"", "", "not an actual id", true, nil, "", false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
514 {id.String(), "secret", "", true, id, "secret", true, http.StatusOK, "", false},
515 {id.String(), "secret", "", false, id, "secret", true, http.StatusOK, "", false},
516 {"", "", id.String(), true, id, "", true, http.StatusOK, "", false},
517 {"", "", id.String(), false, nil, "", false, http.StatusBadRequest, `{"error":"unauthorized_client"}`, false},
518 }
519 for pos, test := range tests {
520 t.Logf("Running test #%d, with request %+v", pos, test)
521 w := httptest.NewRecorder()
522 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/oauth2/grant", nil)
523 if err != nil {
524 t.Fatal("Can't build request:", err)
525 }
526 if test.username != "" || test.pass != "" {
527 r.SetBasicAuth(test.username, test.pass)
528 }
529 r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
530 params := url.Values{}
531 params.Set("client_id", test.clientID)
532 body := bytes.NewBufferString(params.Encode())
533 r.Body = ioutil.NopCloser(body)
534 respID, respSecret, success := getClientAuth(w, r, test.allowPublic)
535 if (respID == nil && test.expectedClientID != nil) || (respID != nil && test.expectedClientID == nil) || !respID.Equal(test.expectedClientID) {
536 t.Errorf("Expected response ID to be %v, got %v", test.expectedClientID, respID)
537 }
538 if test.expectedClientSecret != respSecret {
539 t.Errorf("Expected response secret to be '%s', got '%s'", test.expectedClientSecret, respSecret)
540 }
541 if test.expectedValid != success {
542 t.Errorf("Expected success result to be %v, got %v", test.expectedValid, success)
543 }
544 if test.expectedCode != w.Code {
545 t.Errorf("Expected response code to be %d, got %d", test.expectedCode, w.Code)
546 }
547 if test.expectedBody != strings.TrimSpace(w.Body.String()) {
548 t.Errorf("Expected body to be '%s', got '%s'", test.expectedBody, strings.TrimSpace(w.Body.String()))
549 }
550 if test.expectAuthenticateHeader && w.Header().Get("WWW-Authenticate") != "Basic" {
551 t.Errorf(`Expected header WWW-Authenticate to be set to "Basic", got "%s"`, w.Header().Get("WWW-Authenticate"))
552 }
553 }
554 }
556 func TestVerifyClient(t *testing.T) {
557 t.Parallel()
558 type verifyClientRequest struct {
559 username string
560 pass string
561 clientID string
562 allowPublic bool
563 expectedClientID uuid.ID
564 expectedValid bool
565 expectedCode int
566 expectedBody string
567 expectAuthenticateHeader bool
568 }
569 memstore := NewMemstore()
570 context := Context{
571 clients: memstore,
572 }
573 client := Client{
574 ID: uuid.NewID(),
575 Secret: "super secret!",
576 OwnerID: uuid.NewID(),
577 Name: "My test client",
578 Logo: "https://secondbit.org/logo.png",
579 Website: "https://secondbit.org/",
580 Type: "confidential",
581 }
582 err := context.SaveClient(client)
583 if err != nil {
584 t.Fatal("Could not save client:", err)
585 }
586 publicClient := Client{
587 ID: uuid.NewID(),
588 Secret: "",
589 OwnerID: uuid.NewID(),
590 Name: "A public client",
591 Logo: "https://secondbit.org/logo.png",
592 Website: "https://secondbit.org/",
593 Type: "public",
594 }
595 err = context.SaveClient(publicClient)
596 if err != nil {
597 t.Fatal("Could not save client:", err)
598 }
599 id := uuid.NewID()
600 tests := []verifyClientRequest{
601 {"", "", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
602 {"", "", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
603 {"", "no clientID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
604 {"", "no clientID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
605 {"not an actual id", "invalid client ID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
606 {"not an actual id", "invalid client ID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
607 {id.String(), "unsaved client ID set", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
608 {id.String(), "unsaved client ID set", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
609 {client.ID.String(), "wrong secret", "", false, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
610 {client.ID.String(), "wrong secret", "", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, true},
611 {"", "", "not an actual id", true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
612 {"", "", id.String(), true, nil, false, http.StatusUnauthorized, `{"error":"invalid_client"}`, false},
613 {client.ID.String(), client.Secret, "", true, client.ID, true, http.StatusOK, "", false},
614 {client.ID.String(), client.Secret, "", false, client.ID, true, http.StatusOK, "", false},
615 {"", "", publicClient.ID.String(), true, publicClient.ID, true, http.StatusOK, "", false},
616 {"", "", publicClient.ID.String(), false, nil, false, http.StatusBadRequest, `{"error":"unauthorized_client"}`, false},
617 }
619 for pos, test := range tests {
620 t.Logf("Running test #%d, with request %+v", pos, test)
621 w := httptest.NewRecorder()
622 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/oauth2/grant", nil)
623 if err != nil {
624 t.Fatal("Can't build request:", err)
625 }
626 if test.username != "" || test.pass != "" {
627 r.SetBasicAuth(test.username, test.pass)
628 }
629 r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
630 params := url.Values{}
631 params.Set("client_id", test.clientID)
632 body := bytes.NewBufferString(params.Encode())
633 r.Body = ioutil.NopCloser(body)
634 respID, success := verifyClient(w, r, test.allowPublic, context)
635 if (respID == nil && test.expectedClientID != nil) || (respID != nil && test.expectedClientID == nil) || !respID.Equal(test.expectedClientID) {
636 t.Errorf("Expected response ID to be %v, got %v", test.expectedClientID, respID)
637 }
638 if test.expectedValid != success {
639 t.Errorf("Expected success result to be %v, got %v", test.expectedValid, success)
640 }
641 if test.expectedCode != w.Code {
642 t.Errorf("Expected response code to be %d, got %d", test.expectedCode, w.Code)
643 }
644 if test.expectedBody != strings.TrimSpace(w.Body.String()) {
645 t.Errorf("Expected body to be '%s', got '%s'", test.expectedBody, strings.TrimSpace(w.Body.String()))
646 }
647 if test.expectAuthenticateHeader && w.Header().Get("WWW-Authenticate") != "Basic" {
648 t.Errorf(`Expected header WWW-Authenticate to be set to "Basic", got "%s"`, w.Header().Get("WWW-Authenticate"))
649 }
650 }
651 }
653 func TestCreateClientHandler(t *testing.T) {
654 t.Parallel()
655 memstore := NewMemstore()
656 c := Context{
657 clients: memstore,
658 profiles: memstore,
659 }
660 w := httptest.NewRecorder()
661 r, err := http.NewRequest("POST", "https://test.auth.secondbit.org/clients", nil)
662 if err != nil {
663 t.Fatal("Can't build request:", err)
664 }
665 r.Header.Set("Content-Type", "application/json")
666 CreateClientHandler(w, r, c)
667 if w.Code != http.StatusUnauthorized {
668 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
669 }
670 expected := `{"errors":[{"error":"access_denied"}]}`
671 result := strings.TrimSpace(w.Body.String())
672 if result != expected {
673 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
674 }
675 w = httptest.NewRecorder()
676 r.Header.Set("Authorization", "Not basic at all...")
677 CreateClientHandler(w, r, c)
678 if w.Code != http.StatusUnauthorized {
679 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
680 }
681 expected = `{"errors":[{"error":"access_denied"}]}`
682 result = strings.TrimSpace(w.Body.String())
683 if result != expected {
684 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
685 }
686 w = httptest.NewRecorder()
687 r.Header.Set("Authorization", "Basic TotallyNotBase64Encoded")
688 CreateClientHandler(w, r, c)
689 if w.Code != http.StatusUnauthorized {
690 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
691 }
692 expected = `{"errors":[{"error":"access_denied"}]}`
693 result = strings.TrimSpace(w.Body.String())
694 if result != expected {
695 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
696 }
697 w = httptest.NewRecorder()
698 r.Header.Set("Authorization", "Basic dGhpc2hhc25vY29sb24=")
699 CreateClientHandler(w, r, c)
700 if w.Code != http.StatusUnauthorized {
701 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
702 }
703 expected = `{"errors":[{"error":"access_denied"}]}`
704 result = strings.TrimSpace(w.Body.String())
705 if result != expected {
706 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
707 }
708 profile := Profile{
709 ID: uuid.NewID(),
710 Name: "Test User",
711 Passphrase: "f3a4ac4f1d657b2e6e776d24213e39406d50a87a52691a2a78891425af1271d0",
712 Iterations: 1,
713 Salt: "d82d92cfa8bfb5a08270ebbf39a3710d24b352b937fcc8959ebcb40384cc616b",
714 PassphraseScheme: 1,
715 Compromised: false,
716 LockedUntil: time.Time{},
717 PassphraseReset: "",
718 PassphraseResetCreated: time.Time{},
719 Created: time.Now(),
720 LastSeen: time.Time{},
721 }
722 login := Login{
723 Type: "email",
724 Value: "test@example.com",
725 ProfileID: profile.ID,
726 Created: time.Now(),
727 LastUsed: time.Time{},
728 }
729 w = httptest.NewRecorder()
730 r.SetBasicAuth("test@example.com", "mysecurepassphrase")
731 CreateClientHandler(w, r, c)
732 if w.Code != http.StatusUnauthorized {
733 t.Errorf("Expected status of %d, got status %d", http.StatusUnauthorized, w.Code)
734 }
735 expected = `{"errors":[{"error":"access_denied"}]}`
736 result = strings.TrimSpace(w.Body.String())
737 if result != expected {
738 t.Errorf("Expected response to be `%s`, got `%s`", expected, result)
739 }
740 err = c.SaveProfile(profile)
741 if err != nil {
742 t.Error("Error saving profile:", err)
743 }
744 err = c.AddLogin(login)
745 if err != nil {
746 t.Error("Error adding login:", err)
747 }
748 r.SetBasicAuth("test@example.com", "mysecurepassphrase")
749 type testStruct struct {
750 request string
751 code int
752 resp response
753 }
754 tests := []testStruct{
755 {``, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/"}}}},
756 {`{}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/type"}, {Slug: requestErrMissing, Field: "/name"}}}},
757 {`{"type":"notarealtype"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrMissing, Field: "/name"}}}},
758 {`{"type":"notarealtype","name":"myreallylongnameislongerthatthemaximumnamelength"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrOverflow, Field: "/name"}}}},
759 {`{"type":"notarealtype","name":"a"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}, {Slug: requestErrInsufficient, Field: "/name"}}}},
760 {`{"type":"public"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/name"}}}},
761 {`{"type":"public","name":"myreallylongnameislongerthatthemaximumnamelength"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrOverflow, Field: "/name"}}}},
762 {`{"type":"public","name":"a"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInsufficient, Field: "/name"}}}},
763 {`{"name":"My Client"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrMissing, Field: "/type"}}}},
764 {`{"type":"notarealtype","name":"My Client"}`, http.StatusBadRequest, response{Errors: []requestError{{Slug: requestErrInvalidValue, Field: "/type"}}}},
765 {`{"type":"public","name":"My Client"}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}}},
766 {`{"type":"public","name":"My Client", "endpoints": ["https://test.secondbit.org/", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://test.secondbit.org/"}, {URI: "https://paddy.io"}}}},
767 {`{"type":"public","name":"My Client", "endpoints": [":/not a url", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://paddy.io"}}, Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/endpoints/0"}}}},
768 {`{"type":"public","name":"My Client", "endpoints": [":/not a url", "/relative/uri", "https://paddy.io"]}`, http.StatusCreated, response{Clients: []Client{{Name: "My Client", OwnerID: profile.ID, Type: "public"}}, Endpoints: []Endpoint{{URI: "https://paddy.io"}}, Errors: []requestError{{Slug: requestErrInvalidFormat, Field: "/endpoints/0"}, {Slug: requestErrInvalidValue, Field: "/endpoints/1"}}}},
769 {`{"type":"confidential","name":"Secret Client", "endpoints": ["https://secondbit.org"]}`, http.StatusCreated, response{Clients: []Client{{Name: "Secret Client", OwnerID: profile.ID, Type: "confidential"}}, Endpoints: []Endpoint{{URI: "https://secondbit.org"}}}},
770 }
771 for pos, test := range tests {
772 t.Logf("Test #%d: `%s`", pos, test.request)
773 w = httptest.NewRecorder()
774 body := bytes.NewBufferString(test.request)
775 r.Body = ioutil.NopCloser(body)
776 CreateClientHandler(w, r, c)
777 if w.Code != test.code {
778 t.Errorf("Expected response code to be %d, got %d", test.code, w.Code)
779 }
780 t.Logf("Response: %s", w.Body.String())
781 var res response
782 err = json.Unmarshal(w.Body.Bytes(), &res)
783 if err != nil {
784 t.Error("Unexpected error unmarshalling response:", err)
785 }
786 if len(res.Clients) > 0 {
787 if res.Clients[0].Type == "confidential" && res.Clients[0].Secret == "" {
788 t.Log("Client:", res.Clients[0])
789 t.Error("Expected confidential client to have a secret, but does not.")
790 } else if res.Clients[0].Type == "public" && res.Clients[0].Secret != "" {
791 t.Log("Client:", res.Clients[0])
792 t.Error("Expected public client to not have a secret, but it does.")
793 }
794 }
795 fillInServerGenerated(test.resp, res)
796 success, field, expectation, result := compareResponses(test.resp, res)
797 if !success {
798 t.Errorf("Unexpected result for %s in response: expected %v, got %v", field, expectation, result)
799 }
800 }
801 }
803 // BUG(paddy): We need to test the clientCredentialsValidate function.
804 // BUG(paddy): We need to test the GetClientHandler.
805 // BUG(paddy): We need to test the ListClientsHandler.