auth
auth/client_test.go
Added validation for clients and split endpoints out into their own type, with associated methods on the ClientStores, so each client can now have more than one redirect endpoint. Added unit tests for the endpoint methods. Added validation code to validate client changes.
1 package auth
3 import (
4 "fmt"
5 "net/url"
6 "testing"
7 "time"
9 "sort"
10 "secondbit.org/uuid"
11 )
// Bit flags used by TestClientUpdates: each loop index is read as a bitmask,
// and every set bit selects one ClientChange field to modify in that
// iteration.
const (
	clientChangeSecret  = 1 << 0
	clientChangeOwnerID = 1 << 1
	clientChangeName    = 1 << 2
	clientChangeLogo    = 1 << 3
	clientChangeWebsite = 1 << 4
)
21 var clientStores = []ClientStore{NewMemstore()}
23 func compareClients(client1, client2 Client) (success bool, field string, val1, val2 interface{}) {
24 if !client1.ID.Equal(client2.ID) {
25 return false, "ID", client1.ID, client2.ID
26 }
27 if client1.Secret != client2.Secret {
28 return false, "secret", client1.Secret, client2.Secret
29 }
30 if !client1.OwnerID.Equal(client2.OwnerID) {
31 return false, "owner ID", client1.OwnerID, client2.OwnerID
32 }
33 if client1.Name != client2.Name {
34 return false, "name", client1.Name, client2.Name
35 }
36 if client1.Logo != client2.Logo {
37 return false, "logo", client1.Logo, client2.Logo
38 }
39 if client1.Website != client2.Website {
40 return false, "website", client1.Website, client2.Website
41 }
42 if client1.Type != client2.Type {
43 return false, "type", client1.Type, client2.Type
44 }
45 return true, "", nil, nil
46 }
48 func compareEndpoints(endpoint1, endpoint2 Endpoint) (success bool, field string, val1, val2 interface{}) {
49 if !endpoint1.ID.Equal(endpoint2.ID) {
50 return false, "ID", endpoint1.ID, endpoint2.ID
51 }
52 if !endpoint1.ClientID.Equal(endpoint2.ClientID) {
53 return false, "OwnerID", endpoint1.ClientID, endpoint2.ClientID
54 }
55 if !endpoint1.Added.Equal(endpoint2.Added) {
56 return false, "Added", endpoint1.Added, endpoint2.Added
57 }
58 if endpoint1.URI.String() != endpoint2.URI.String() {
59 return false, "URI", endpoint1.URI, endpoint2.URI
60 }
61 return true, "", nil, nil
62 }
64 func TestClientStoreSuccess(t *testing.T) {
65 t.Parallel()
66 client := Client{
67 ID: uuid.NewID(),
68 Secret: "secret",
69 OwnerID: uuid.NewID(),
70 Name: "name",
71 Logo: "logo",
72 Website: "website",
73 }
74 for _, store := range clientStores {
75 err := store.SaveClient(client)
76 if err != nil {
77 t.Fatalf("Error saving client to %T: %s", store, err)
78 }
79 err = store.SaveClient(client)
80 if err != ErrClientAlreadyExists {
81 t.Fatalf("Expected ErrClientAlreadyExists, got %v from %T", err, store)
82 }
83 retrieved, err := store.GetClient(client.ID)
84 if err != nil {
85 t.Fatalf("Error retrieving client from %T: %s", store, err)
86 }
87 success, field, expectation, result := compareClients(client, retrieved)
88 if !success {
89 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
90 }
91 clients, err := store.ListClientsByOwner(client.OwnerID, 25, 0)
92 if err != nil {
93 t.Fatalf("Error retrieving clients by owner from %T: %s", store, err)
94 }
95 if len(clients) != 1 {
96 t.Fatalf("Expected 1 client in response from %T, got %+v", store, clients)
97 }
98 success, field, expectation, result = compareClients(client, clients[0])
99 if !success {
100 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
101 }
102 err = store.DeleteClient(client.ID)
103 if err != nil {
104 t.Fatalf("Error deleting client from %T: %s", store, err)
105 }
106 err = store.DeleteClient(client.ID)
107 if err != ErrClientNotFound {
108 t.Fatalf("Expected ErrClientNotFound, got %s from %T", err, store)
109 }
110 retrieved, err = store.GetClient(client.ID)
111 if err != ErrClientNotFound {
112 t.Fatalf("Expected ErrClientNotFound from %T, got %+v and %s", store, retrieved, err)
113 }
114 clients, err = store.ListClientsByOwner(client.OwnerID, 25, 0)
115 if err != nil {
116 t.Fatalf("Error listing clients by owner from %T: %s", store, err)
117 }
118 if len(clients) != 0 {
119 t.Fatalf("Expected 0 clients in response from %T, got %+v", store, clients)
120 }
121 }
122 }
124 func TestEndpointStoreSuccess(t *testing.T) {
125 t.Parallel()
126 client := Client{
127 ID: uuid.NewID(),
128 Secret: "secret",
129 OwnerID: uuid.NewID(),
130 Name: "name",
131 Logo: "logo",
132 Website: "website",
133 }
134 uri1, _ := url.Parse("https://www.example.com/")
135 uri2, _ := url.Parse("https://www.example.com/my/full/path")
136 endpoint1 := Endpoint{
137 ID: uuid.NewID(),
138 ClientID: client.ID,
139 Added: time.Now(),
140 URI: *uri1,
141 }
142 endpoint2 := Endpoint{
143 ID: uuid.NewID(),
144 ClientID: client.ID,
145 Added: time.Now(),
146 URI: *uri2,
147 }
148 for _, store := range clientStores {
149 err := store.SaveClient(client)
150 if err != nil {
151 t.Fatalf("Error saving client to %T: %s", store, err)
152 }
153 err = store.AddEndpoint(client.ID, endpoint1)
154 if err != nil {
155 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
156 }
157 endpoints, err := store.ListEndpoints(client.ID, 10, 0)
158 if err != nil {
159 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
160 }
161 if len(endpoints) != 1 {
162 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
163 }
164 success, field, expectation, result := compareEndpoints(endpoint1, endpoints[0])
165 if !success {
166 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
167 }
168 err = store.AddEndpoint(client.ID, endpoint2)
169 if err != nil {
170 t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
171 }
172 endpoints, err = store.ListEndpoints(client.ID, 10, 0)
173 if err != nil {
174 t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
175 }
176 if len(endpoints) != 2 {
177 t.Fatalf("Expected %d endpoints, got %+v from %T", 2, endpoints, store)
178 }
179 sortedEnd := sortedEndpoints(endpoints)
180 sort.Sort(sortedEnd)
181 endpoints = []Endpoint(sortedEnd)
182 success, field, expectation, result = compareEndpoints(endpoint1, endpoints[0])
183 if !success {
184 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
185 }
186 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[1])
187 if !success {
188 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
189 }
190 err = store.RemoveEndpoint(client.ID, endpoint1.ID)
191 if err != nil {
192 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
193 }
194 endpoints, err = store.ListEndpoints(client.ID, 10, 0)
195 if err != nil {
196 t.Fatalf("Error listing endpoints in %T: %s", store, err)
197 }
198 if len(endpoints) != 1 {
199 t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
200 }
201 success, field, expectation, result = compareEndpoints(endpoint2, endpoints[0])
202 if !success {
203 t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
204 }
205 err = store.RemoveEndpoint(client.ID, endpoint2.ID)
206 if err != nil {
207 t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
208 }
209 endpoints, err = store.ListEndpoints(client.ID, 10, 0)
210 if err != nil {
211 t.Fatalf("Error listing endpoints in %T: %s", store, err)
212 }
213 if len(endpoints) != 0 {
214 t.Fatalf("Expected %d endpoints, got %+v from %T", 0, endpoints, store)
215 }
216 }
217 }
219 func TestClientUpdates(t *testing.T) {
220 t.Parallel()
221 variations := 1 << 5
222 client := Client{
223 ID: uuid.NewID(),
224 Secret: "secret",
225 OwnerID: uuid.NewID(),
226 Name: "name",
227 Logo: "logo",
228 Website: "website",
229 }
230 for i := 0; i < variations; i++ {
231 var secret, name, logo, website string
232 change := ClientChange{}
233 expectation := client
234 result := client
235 if i&clientChangeSecret != 0 {
236 secret = fmt.Sprintf("secret-%d", i)
237 change.Secret = &secret
238 expectation.Secret = secret
239 }
240 if i&clientChangeOwnerID != 0 {
241 change.OwnerID = uuid.NewID()
242 expectation.OwnerID = change.OwnerID
243 }
244 if i&clientChangeName != 0 {
245 name = fmt.Sprintf("name-%d", i)
246 change.Name = &name
247 expectation.Name = name
248 }
249 if i&clientChangeLogo != 0 {
250 logo = fmt.Sprintf("logo-%d", i)
251 change.Logo = &logo
252 expectation.Logo = logo
253 }
254 if i&clientChangeWebsite != 0 {
255 website = fmt.Sprintf("website-%d", i)
256 change.Website = &website
257 expectation.Website = website
258 }
259 result.ApplyChange(change)
260 match, field, expected, got := compareClients(expectation, result)
261 if !match {
262 t.Fatalf("Expected field `%s` to be `%v`, got `%v`", field, expected, got)
263 }
264 for _, store := range clientStores {
265 err := store.SaveClient(client)
266 if err != nil {
267 t.Fatalf("Error saving client in %T: %s", store, err)
268 }
269 err = store.UpdateClient(client.ID, change)
270 if err != nil {
271 t.Fatalf("Error updating client in %T: %s", store, err)
272 }
273 retrieved, err := store.GetClient(client.ID)
274 if err != nil {
275 t.Fatalf("Error getting profile from %T: %s", store, err)
276 }
277 match, field, expected, got = compareClients(expectation, retrieved)
278 if !match {
279 t.Fatalf("Expected field `%s` to be `%v`, got `%v` from %T", field, expected, got, store)
280 }
281 err = store.DeleteClient(client.ID)
282 if err != nil {
283 t.Fatalf("Error deleting client from %T: %s", store, err)
284 }
285 err = store.UpdateClient(client.ID, change)
286 if err != ErrClientNotFound {
287 t.Fatalf("Expected ErrClientNotFound, got %v from %T", err, store)
288 }
289 }
290 }
291 }
293 func TestClientEndpointChecks(t *testing.T) {
294 t.Parallel()
295 client := Client{
296 ID: uuid.NewID(),
297 Secret: "secret",
298 OwnerID: uuid.NewID(),
299 Name: "name",
300 Logo: "logo",
301 Website: "website",
302 }
303 uri1, _ := url.Parse("https://www.example.com/first")
304 uri2, _ := url.Parse("https://www.example.com/my/full/path")
305 endpoint1 := Endpoint{
306 ID: uuid.NewID(),
307 ClientID: client.ID,
308 Added: time.Now(),
309 URI: *uri1,
310 }
311 endpoint2 := Endpoint{
312 ID: uuid.NewID(),
313 ClientID: client.ID,
314 Added: time.Now(),
315 URI: *uri2,
316 }
317 candidates := map[string]bool{
318 "https://www.example.com/": false,
319 "https://www.example.com/first": true,
320 "https://www.example.com/first/extra/path": true,
321 "https://www.example.com/my": false,
322 "https://www.example.com/my/full/path": true,
323 }
324 for _, store := range clientStores {
325 err := store.SaveClient(client)
326 if err != nil {
327 t.Fatalf("Error saving client in %T: %s", store, err)
328 }
329 err = store.AddEndpoint(client.ID, endpoint1)
330 if err != nil {
331 t.Fatalf("Error saving endpoint in %T: %s", store, err)
332 }
333 err = store.AddEndpoint(client.ID, endpoint2)
334 if err != nil {
335 t.Fatalf("Error saving endpoint in %T: %s", store, err)
336 }
337 for candidate, expectation := range candidates {
338 result, err := store.CheckEndpoint(client.ID, candidate)
339 if err != nil {
340 t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
341 }
342 if result != expectation {
343 expectStr := "no"
344 resultStr := "a"
345 if expectation {
346 expectStr = "a"
347 resultStr = "no"
348 }
349 t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
350 }
351 }
352 }
353 }