auth

Paddy 2014-10-22 Parent:28d48fdb0dd1 Child:cfab12566289

54:0f80a3e391b8 Browse Files

Update CheckEndpoints for strict checking, add CountEndpoints. Create a "strict" mode for CheckEndpoints that will only return true on an exact match, and update the memstore implementation accordingly. Add tests to make sure that the strict mode is adhered to. We need this mode because in certain situations (e.g., the client has more than one endpoint registered), the spec demands a full-string comparison. Add a CountEndpoints method to the ClientStore that will return the number of endpoints registered for a specific client. As we just mentioned, the rules for how a redirect URI is validated depend upon the number of endpoints a client has registered, so we need to be able to get at that number.

client.go client_test.go

     1.1 --- a/client.go	Thu Oct 16 00:18:14 2014 -0400
     1.2 +++ b/client.go	Wed Oct 22 00:26:39 2014 -0400
     1.3 @@ -125,8 +125,9 @@
     1.4  
     1.5  	AddEndpoint(client uuid.ID, endpoint Endpoint) error
     1.6  	RemoveEndpoint(client, endpoint uuid.ID) error
     1.7 -	CheckEndpoint(client uuid.ID, endpoint string) (bool, error)
     1.8 +	CheckEndpoint(client uuid.ID, endpoint string, strict bool) (bool, error)
     1.9  	ListEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error)
    1.10 +	CountEndpoints(client uuid.ID) (int64, error)
    1.11  }
    1.12  
    1.13  func (m *Memstore) GetClient(id uuid.ID) (Client, error) {
    1.14 @@ -226,11 +227,13 @@
    1.15  	return nil
    1.16  }
    1.17  
    1.18 -func (m *Memstore) CheckEndpoint(client uuid.ID, endpoint string) (bool, error) {
    1.19 +func (m *Memstore) CheckEndpoint(client uuid.ID, endpoint string, strict bool) (bool, error) {
    1.20  	m.endpointLock.RLock()
    1.21  	defer m.endpointLock.RUnlock()
    1.22  	for _, candidate := range m.endpoints[client.String()] {
    1.23 -		if strings.HasPrefix(endpoint, candidate.URI.String()) {
    1.24 +		if !strict && strings.HasPrefix(endpoint, candidate.URI.String()) {
    1.25 +			return true, nil
    1.26 +		} else if strict && endpoint == candidate.URI.String() {
    1.27  			return true, nil
    1.28  		}
    1.29  	}
    1.30 @@ -242,3 +245,9 @@
    1.31  	defer m.endpointLock.RUnlock()
    1.32  	return m.endpoints[client.String()], nil
    1.33  }
    1.34 +
    1.35 +func (m *Memstore) CountEndpoints(client uuid.ID) (int64, error) {
    1.36 +	m.endpointLock.RLock()
    1.37 +	defer m.endpointLock.RUnlock()
    1.38 +	return int64(len(m.endpoints[client.String()])), nil
    1.39 +}
     2.1 --- a/client_test.go	Thu Oct 16 00:18:14 2014 -0400
     2.2 +++ b/client_test.go	Wed Oct 22 00:26:39 2014 -0400
     2.3 @@ -6,8 +6,8 @@
     2.4  	"testing"
     2.5  	"time"
     2.6  
     2.7 +	"sort"
     2.8  	"code.secondbit.org/uuid"
     2.9 -	"sort"
    2.10  )
    2.11  
    2.12  const (
    2.13 @@ -335,7 +335,69 @@
    2.14  			t.Fatalf("Error saving endpoint in %T: %s", store, err)
    2.15  		}
    2.16  		for candidate, expectation := range candidates {
    2.17 -			result, err := store.CheckEndpoint(client.ID, candidate)
    2.18 +			result, err := store.CheckEndpoint(client.ID, candidate, false)
    2.19 +			if err != nil {
    2.20 +				t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
    2.21 +			}
    2.22 +			if result != expectation {
    2.23 +				expectStr := "no"
    2.24 +				resultStr := "a"
    2.25 +				if expectation {
    2.26 +					expectStr = "a"
    2.27 +					resultStr = "no"
    2.28 +				}
    2.29 +				t.Errorf("Expected %s match for %s in %T, got %s match", expectStr, candidate, store, resultStr)
    2.30 +			}
    2.31 +		}
    2.32 +	}
    2.33 +}
    2.34 +
    2.35 +func TestClientEndpointChecksStrict(t *testing.T) {
    2.36 +	t.Parallel()
    2.37 +	client := Client{
    2.38 +		ID:      uuid.NewID(),
    2.39 +		Secret:  "secret",
    2.40 +		OwnerID: uuid.NewID(),
    2.41 +		Name:    "name",
    2.42 +		Logo:    "logo",
    2.43 +		Website: "website",
    2.44 +	}
    2.45 +	uri1, _ := url.Parse("https://www.example.com/first")
    2.46 +	uri2, _ := url.Parse("https://www.example.com/my/full/path")
    2.47 +	endpoint1 := Endpoint{
    2.48 +		ID:       uuid.NewID(),
    2.49 +		ClientID: client.ID,
    2.50 +		Added:    time.Now(),
    2.51 +		URI:      *uri1,
    2.52 +	}
    2.53 +	endpoint2 := Endpoint{
    2.54 +		ID:       uuid.NewID(),
    2.55 +		ClientID: client.ID,
    2.56 +		Added:    time.Now(),
    2.57 +		URI:      *uri2,
    2.58 +	}
    2.59 +	candidates := map[string]bool{
    2.60 +		"https://www.example.com/":                 false,
    2.61 +		"https://www.example.com/first":            true,
    2.62 +		"https://www.example.com/first/extra/path": false,
    2.63 +		"https://www.example.com/my":               false,
    2.64 +		"https://www.example.com/my/full/path":     true,
    2.65 +	}
    2.66 +	for _, store := range clientStores {
    2.67 +		err := store.SaveClient(client)
    2.68 +		if err != nil {
    2.69 +			t.Fatalf("Error saving client in %T: %s", store, err)
    2.70 +		}
    2.71 +		err = store.AddEndpoint(client.ID, endpoint1)
    2.72 +		if err != nil {
    2.73 +			t.Fatalf("Error saving endpoint in %T: %s", store, err)
    2.74 +		}
    2.75 +		err = store.AddEndpoint(client.ID, endpoint2)
    2.76 +		if err != nil {
    2.77 +			t.Fatalf("Error saving endpoint in %T: %s", store, err)
    2.78 +		}
    2.79 +		for candidate, expectation := range candidates {
    2.80 +			result, err := store.CheckEndpoint(client.ID, candidate, true)
    2.81  			if err != nil {
    2.82  				t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
    2.83  			}