auth

Paddy 2014-09-18 Parent:690561c6619a Child:2a6d722a2b4b

41:113ccb15b919 Go to Latest

auth/client_test.go

Added validation for clients, split endpoints out. Split endpoints out into their own type and added associated methods to the ClientStores, so now each client can have more than one redirect endpoint. Added unit testing for endpoint methods. Added validation code to validate client changes.

History
     1.1 --- a/client_test.go	Thu Sep 18 19:34:18 2014 -0400
     1.2 +++ b/client_test.go	Thu Sep 18 22:13:22 2014 -0400
     1.3 @@ -2,14 +2,16 @@
     1.4  
     1.5  import (
     1.6  	"fmt"
     1.7 +	"net/url"
     1.8  	"testing"
     1.9 +	"time"
    1.10  
    1.11 +	"sort"
    1.12  	"secondbit.org/uuid"
    1.13  )
    1.14  
    1.15  const (
    1.16  	clientChangeSecret = 1 << iota
    1.17 -	clientChangeRedirectURI
    1.18  	clientChangeOwnerID
    1.19  	clientChangeName
    1.20  	clientChangeLogo
    1.21 @@ -25,9 +27,6 @@
    1.22  	if client1.Secret != client2.Secret {
    1.23  		return false, "secret", client1.Secret, client2.Secret
    1.24  	}
    1.25 -	if client1.RedirectURI != client2.RedirectURI {
    1.26 -		return false, "redirect URI", client1.RedirectURI, client2.RedirectURI
    1.27 -	}
    1.28  	if !client1.OwnerID.Equal(client2.OwnerID) {
    1.29  		return false, "owner ID", client1.OwnerID, client2.OwnerID
    1.30  	}
    1.31 @@ -40,84 +39,196 @@
    1.32  	if client1.Website != client2.Website {
    1.33  		return false, "website", client1.Website, client2.Website
    1.34  	}
    1.35 +	if client1.Type != client2.Type {
    1.36 +		return false, "type", client1.Type, client2.Type
    1.37 +	}
    1.38 +	return true, "", nil, nil
    1.39 +}
    1.40 +
    1.41 +func compareEndpoints(endpoint1, endpoint2 Endpoint) (success bool, field string, val1, val2 interface{}) {
    1.42 +	if !endpoint1.ID.Equal(endpoint2.ID) {
    1.43 +		return false, "ID", endpoint1.ID, endpoint2.ID
    1.44 +	}
    1.45 +	if !endpoint1.ClientID.Equal(endpoint2.ClientID) {
    1.46 +		return false, "OwnerID", endpoint1.ClientID, endpoint2.ClientID
    1.47 +	}
    1.48 +	if !endpoint1.Added.Equal(endpoint2.Added) {
    1.49 +		return false, "Added", endpoint1.Added, endpoint2.Added
    1.50 +	}
    1.51 +	if endpoint1.URI.String() != endpoint2.URI.String() {
    1.52 +		return false, "URI", endpoint1.URI, endpoint2.URI
    1.53 +	}
    1.54  	return true, "", nil, nil
    1.55  }
    1.56  
    1.57  func TestClientStoreSuccess(t *testing.T) {
    1.58  	t.Parallel()
    1.59  	client := Client{
    1.60 -		ID:          uuid.NewID(),
    1.61 -		Secret:      "secret",
    1.62 -		RedirectURI: "redirectURI",
    1.63 -		OwnerID:     uuid.NewID(),
    1.64 -		Name:        "name",
    1.65 -		Logo:        "logo",
    1.66 -		Website:     "website",
    1.67 +		ID:      uuid.NewID(),
    1.68 +		Secret:  "secret",
    1.69 +		OwnerID: uuid.NewID(),
    1.70 +		Name:    "name",
    1.71 +		Logo:    "logo",
    1.72 +		Website: "website",
    1.73  	}
    1.74  	for _, store := range clientStores {
    1.75  		err := store.SaveClient(client)
    1.76  		if err != nil {
    1.77 -			t.Errorf("Error saving client to %T: %s", store, err)
    1.78 +			t.Fatalf("Error saving client to %T: %s", store, err)
    1.79  		}
    1.80  		err = store.SaveClient(client)
    1.81  		if err != ErrClientAlreadyExists {
    1.82 -			t.Errorf("Expected ErrClientAlreadyExists, got %v from %T", err, store)
    1.83 +			t.Fatalf("Expected ErrClientAlreadyExists, got %v from %T", err, store)
    1.84  		}
    1.85  		retrieved, err := store.GetClient(client.ID)
    1.86  		if err != nil {
    1.87 -			t.Errorf("Error retrieving client from %T: %s", store, err)
    1.88 +			t.Fatalf("Error retrieving client from %T: %s", store, err)
    1.89  		}
    1.90  		success, field, expectation, result := compareClients(client, retrieved)
    1.91  		if !success {
    1.92 -			t.Errorf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
    1.93 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
    1.94  		}
    1.95  		clients, err := store.ListClientsByOwner(client.OwnerID, 25, 0)
    1.96  		if err != nil {
    1.97 -			t.Errorf("Error retrieving clients by owner from %T: %s", store, err)
    1.98 +			t.Fatalf("Error retrieving clients by owner from %T: %s", store, err)
    1.99  		}
   1.100  		if len(clients) != 1 {
   1.101 -			t.Errorf("Expected 1 client in response from %T, got %+v", store, clients)
   1.102 +			t.Fatalf("Expected 1 client in response from %T, got %+v", store, clients)
   1.103  		}
   1.104  		success, field, expectation, result = compareClients(client, clients[0])
   1.105  		if !success {
   1.106 -			t.Errorf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.107 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.108  		}
   1.109  		err = store.DeleteClient(client.ID)
   1.110  		if err != nil {
   1.111 -			t.Errorf("Error deleting client from %T: %s", store, err)
   1.112 +			t.Fatalf("Error deleting client from %T: %s", store, err)
   1.113  		}
   1.114  		err = store.DeleteClient(client.ID)
   1.115  		if err != ErrClientNotFound {
   1.116 -			t.Errorf("Expected ErrClientNotFound, got %s from %T", err, store)
   1.117 +			t.Fatalf("Expected ErrClientNotFound, got %s from %T", err, store)
   1.118  		}
   1.119  		retrieved, err = store.GetClient(client.ID)
   1.120  		if err != ErrClientNotFound {
   1.121 -			t.Errorf("Expected ErrClientNotFound from %T, got %+v and %s", store, retrieved, err)
   1.122 +			t.Fatalf("Expected ErrClientNotFound from %T, got %+v and %s", store, retrieved, err)
   1.123  		}
   1.124  		clients, err = store.ListClientsByOwner(client.OwnerID, 25, 0)
   1.125  		if err != nil {
   1.126 -			t.Errorf("Error listing clients by owner from %T: %s", store, err)
   1.127 +			t.Fatalf("Error listing clients by owner from %T: %s", store, err)
   1.128  		}
   1.129  		if len(clients) != 0 {
   1.130 -			t.Errorf("Expected 0 clients in response from %T, got %+v", store, clients)
   1.131 +			t.Fatalf("Expected 0 clients in response from %T, got %+v", store, clients)
   1.132 +		}
   1.133 +	}
   1.134 +}
   1.135 +
   1.136 +func TestEndpointStoreSuccess(t *testing.T) {
   1.137 +	t.Parallel()
   1.138 +	client := Client{
   1.139 +		ID:      uuid.NewID(),
   1.140 +		Secret:  "secret",
   1.141 +		OwnerID: uuid.NewID(),
   1.142 +		Name:    "name",
   1.143 +		Logo:    "logo",
   1.144 +		Website: "website",
   1.145 +	}
   1.146 +	uri1, _ := url.Parse("https://www.example.com/")
   1.147 +	uri2, _ := url.Parse("https://www.example.com/my/full/path")
   1.148 +	endpoint1 := Endpoint{
   1.149 +		ID:       uuid.NewID(),
   1.150 +		ClientID: client.ID,
   1.151 +		Added:    time.Now(),
   1.152 +		URI:      *uri1,
   1.153 +	}
   1.154 +	endpoint2 := Endpoint{
   1.155 +		ID:       uuid.NewID(),
   1.156 +		ClientID: client.ID,
   1.157 +		Added:    time.Now(),
   1.158 +		URI:      *uri2,
   1.159 +	}
   1.160 +	for _, store := range clientStores {
   1.161 +		err := store.SaveClient(client)
   1.162 +		if err != nil {
   1.163 +			t.Fatalf("Error saving client to %T: %s", store, err)
   1.164 +		}
   1.165 +		err = store.AddEndpoint(client.ID, endpoint1)
   1.166 +		if err != nil {
   1.167 +			t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
   1.168 +		}
   1.169 +		endpoints, err := store.ListEndpoints(client.ID, 10, 0)
   1.170 +		if err != nil {
   1.171 +			t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
   1.172 +		}
   1.173 +		if len(endpoints) != 1 {
   1.174 +			t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
   1.175 +		}
   1.176 +		success, field, expectation, result := compareEndpoints(endpoint1, endpoints[0])
   1.177 +		if !success {
   1.178 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.179 +		}
   1.180 +		err = store.AddEndpoint(client.ID, endpoint2)
   1.181 +		if err != nil {
   1.182 +			t.Fatalf("Error adding endpoint to client in %T: %s", store, err)
   1.183 +		}
   1.184 +		endpoints, err = store.ListEndpoints(client.ID, 10, 0)
   1.185 +		if err != nil {
   1.186 +			t.Fatalf("Error retrieving endpoints from %T: %s", store, err)
   1.187 +		}
   1.188 +		if len(endpoints) != 2 {
   1.189 +			t.Fatalf("Expected %d endpoints, got %+v from %T", 2, endpoints, store)
   1.190 +		}
   1.191 +		sortedEnd := sortedEndpoints(endpoints)
   1.192 +		sort.Sort(sortedEnd)
   1.193 +		endpoints = []Endpoint(sortedEnd)
   1.194 +		success, field, expectation, result = compareEndpoints(endpoint1, endpoints[0])
   1.195 +		if !success {
   1.196 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.197 +		}
   1.198 +		success, field, expectation, result = compareEndpoints(endpoint2, endpoints[1])
   1.199 +		if !success {
   1.200 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.201 +		}
   1.202 +		err = store.RemoveEndpoint(client.ID, endpoint1.ID)
   1.203 +		if err != nil {
   1.204 +			t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
   1.205 +		}
   1.206 +		endpoints, err = store.ListEndpoints(client.ID, 10, 0)
   1.207 +		if err != nil {
   1.208 +			t.Fatalf("Error listing endpoints in %T: %s", store, err)
   1.209 +		}
   1.210 +		if len(endpoints) != 1 {
   1.211 +			t.Fatalf("Expected %d endpoints, got %+v from %T", 1, endpoints, store)
   1.212 +		}
   1.213 +		success, field, expectation, result = compareEndpoints(endpoint2, endpoints[0])
   1.214 +		if !success {
   1.215 +			t.Fatalf("Expected field %s to be %v, but %T returned %v", field, expectation, store, result)
   1.216 +		}
   1.217 +		err = store.RemoveEndpoint(client.ID, endpoint2.ID)
   1.218 +		if err != nil {
   1.219 +			t.Fatalf("Error removing endpoint from client in %T: %s", store, err)
   1.220 +		}
   1.221 +		endpoints, err = store.ListEndpoints(client.ID, 10, 0)
   1.222 +		if err != nil {
   1.223 +			t.Fatalf("Error listing endpoints in %T: %s", store, err)
   1.224 +		}
   1.225 +		if len(endpoints) != 0 {
   1.226 +			t.Fatalf("Expected %d endpoints, got %+v from %T", 0, endpoints, store)
   1.227  		}
   1.228  	}
   1.229  }
   1.230  
   1.231  func TestClientUpdates(t *testing.T) {
   1.232  	t.Parallel()
   1.233 -	variations := 1 << 10
   1.234 +	variations := 1 << 5
   1.235  	client := Client{
   1.236 -		ID:          uuid.NewID(),
   1.237 -		Secret:      "secret",
   1.238 -		RedirectURI: "redirectURI",
   1.239 -		OwnerID:     uuid.NewID(),
   1.240 -		Name:        "name",
   1.241 -		Logo:        "logo",
   1.242 -		Website:     "website",
   1.243 +		ID:      uuid.NewID(),
   1.244 +		Secret:  "secret",
   1.245 +		OwnerID: uuid.NewID(),
   1.246 +		Name:    "name",
   1.247 +		Logo:    "logo",
   1.248 +		Website: "website",
   1.249  	}
   1.250  	for i := 0; i < variations; i++ {
   1.251 -		var secret, redirectURI, name, logo, website string
   1.252 +		var secret, name, logo, website string
   1.253  		change := ClientChange{}
   1.254  		expectation := client
   1.255  		result := client
   1.256 @@ -126,11 +237,6 @@
   1.257  			change.Secret = &secret
   1.258  			expectation.Secret = secret
   1.259  		}
   1.260 -		if i&clientChangeRedirectURI != 0 {
   1.261 -			redirectURI = fmt.Sprintf("redirect-uri-%d", i)
   1.262 -			change.RedirectURI = &redirectURI
   1.263 -			expectation.RedirectURI = redirectURI
   1.264 -		}
   1.265  		if i&clientChangeOwnerID != 0 {
   1.266  			change.OwnerID = uuid.NewID()
   1.267  			expectation.OwnerID = change.OwnerID
   1.268 @@ -153,33 +259,95 @@
   1.269  		result.ApplyChange(change)
   1.270  		match, field, expected, got := compareClients(expectation, result)
   1.271  		if !match {
   1.272 -			t.Errorf("Expected field `%s` to be `%v`, got `%v`", field, expected, got)
   1.273 +			t.Fatalf("Expected field `%s` to be `%v`, got `%v`", field, expected, got)
   1.274  		}
   1.275  		for _, store := range clientStores {
   1.276  			err := store.SaveClient(client)
   1.277  			if err != nil {
   1.278 -				t.Errorf("Error saving client in %T: %s", store, err)
   1.279 +				t.Fatalf("Error saving client in %T: %s", store, err)
   1.280  			}
   1.281  			err = store.UpdateClient(client.ID, change)
   1.282  			if err != nil {
   1.283 -				t.Errorf("Error updating client in %T: %s", store, err)
   1.284 +				t.Fatalf("Error updating client in %T: %s", store, err)
   1.285  			}
   1.286  			retrieved, err := store.GetClient(client.ID)
   1.287  			if err != nil {
   1.288 -				t.Errorf("Error getting profile from %T: %s", store, err)
   1.289 +				t.Fatalf("Error getting profile from %T: %s", store, err)
   1.290  			}
   1.291  			match, field, expected, got = compareClients(expectation, retrieved)
   1.292  			if !match {
   1.293 -				t.Errorf("Expected field `%s` to be `%v`, got `%v` from %T", field, expected, got, store)
   1.294 +				t.Fatalf("Expected field `%s` to be `%v`, got `%v` from %T", field, expected, got, store)
   1.295  			}
   1.296  			err = store.DeleteClient(client.ID)
   1.297  			if err != nil {
   1.298 -				t.Errorf("Error deleting client from %T: %s", store, err)
   1.299 +				t.Fatalf("Error deleting client from %T: %s", store, err)
   1.300  			}
   1.301  			err = store.UpdateClient(client.ID, change)
   1.302  			if err != ErrClientNotFound {
   1.303 -				t.Errorf("Expected ErrClientNotFound, got %v from %T", err, store)
   1.304 +				t.Fatalf("Expected ErrClientNotFound, got %v from %T", err, store)
   1.305  			}
   1.306  		}
   1.307  	}
   1.308  }
   1.309 +
   1.310 +func TestClientEndpointChecks(t *testing.T) {
   1.311 +	t.Parallel()
   1.312 +	client := Client{
   1.313 +		ID:      uuid.NewID(),
   1.314 +		Secret:  "secret",
   1.315 +		OwnerID: uuid.NewID(),
   1.316 +		Name:    "name",
   1.317 +		Logo:    "logo",
   1.318 +		Website: "website",
   1.319 +	}
   1.320 +	uri1, _ := url.Parse("https://www.example.com/first")
   1.321 +	uri2, _ := url.Parse("https://www.example.com/my/full/path")
   1.322 +	endpoint1 := Endpoint{
   1.323 +		ID:       uuid.NewID(),
   1.324 +		ClientID: client.ID,
   1.325 +		Added:    time.Now(),
   1.326 +		URI:      *uri1,
   1.327 +	}
   1.328 +	endpoint2 := Endpoint{
   1.329 +		ID:       uuid.NewID(),
   1.330 +		ClientID: client.ID,
   1.331 +		Added:    time.Now(),
   1.332 +		URI:      *uri2,
   1.333 +	}
   1.334 +	candidates := map[string]bool{
   1.335 +		"https://www.example.com/":                 false,
   1.336 +		"https://www.example.com/first":            true,
   1.337 +		"https://www.example.com/first/extra/path": true,
   1.338 +		"https://www.example.com/my":               false,
   1.339 +		"https://www.example.com/my/full/path":     true,
   1.340 +	}
   1.341 +	for _, store := range clientStores {
   1.342 +		err := store.SaveClient(client)
   1.343 +		if err != nil {
   1.344 +			t.Fatalf("Error saving client in %T: %s", store, err)
   1.345 +		}
   1.346 +		err = store.AddEndpoint(client.ID, endpoint1)
   1.347 +		if err != nil {
   1.348 +			t.Fatalf("Error saving endpoint in %T: %s", store, err)
   1.349 +		}
   1.350 +		err = store.AddEndpoint(client.ID, endpoint2)
   1.351 +		if err != nil {
   1.352 +			t.Fatalf("Error saving endpoint in %T: %s", store, err)
   1.353 +		}
   1.354 +		for candidate, expectation := range candidates {
   1.355 +			result, err := store.CheckEndpoint(client.ID, candidate)
   1.356 +			if err != nil {
   1.357 +				t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
   1.358 +			}
   1.359 +			if result != expectation {
   1.360 +				expectStr := "no"
   1.361 +				resultStr := "a"
   1.362 +				if expectation {
   1.363 +					expectStr = "a"
   1.364 +					resultStr = "no"
   1.365 +				}
   1.366 +				t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
   1.367 +			}
   1.368 +		}
   1.369 +	}
   1.370 +}