auth

Paddy 2014-10-26 Parent:e45bfa2abc00 Child:03c9890f99c5

58:b3cd7765a7c8 Browse Files

Require full URLs for Endpoints. The spec says that we SHOULD require full URLs for redirection, but we _can_ offer the ability to set a URL as a "partial URL" if we really must. I see no particular reason to do this, so I've simplified the code by pulling that option out. This means that URLs (as long as they're normalized, which I've filed a bug in the codebase to do) can be checked using simple string comparison, which makes the likelihood of bugs across clientStorage implementations a lot lower.

client.go client_test.go context.go http.go http_test.go

     1.1 --- a/client.go	Sun Oct 26 00:53:36 2014 -0400
     1.2 +++ b/client.go	Sun Oct 26 03:22:41 2014 -0400
     1.3 @@ -3,7 +3,6 @@
     1.4  import (
     1.5  	"errors"
     1.6  	"net/url"
     1.7 -	"strings"
     1.8  	"time"
     1.9  
    1.10  	"code.secondbit.org/uuid"
    1.11 @@ -144,7 +143,7 @@
    1.12  
    1.13  	addEndpoint(client uuid.ID, endpoint Endpoint) error
    1.14  	removeEndpoint(client, endpoint uuid.ID) error
    1.15 -	checkEndpoint(client uuid.ID, endpoint string, strict bool) (bool, error)
    1.16 +	checkEndpoint(client uuid.ID, endpoint string) (bool, error)
    1.17  	listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error)
    1.18  	countEndpoints(client uuid.ID) (int64, error)
    1.19  }
    1.20 @@ -246,13 +245,11 @@
    1.21  	return nil
    1.22  }
    1.23  
    1.24 -func (m *memstore) checkEndpoint(client uuid.ID, endpoint string, strict bool) (bool, error) {
    1.25 +func (m *memstore) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
    1.26  	m.endpointLock.RLock()
    1.27  	defer m.endpointLock.RUnlock()
    1.28  	for _, candidate := range m.endpoints[client.String()] {
    1.29 -		if !strict && strings.HasPrefix(endpoint, candidate.URI.String()) {
    1.30 -			return true, nil
    1.31 -		} else if strict && endpoint == candidate.URI.String() {
    1.32 +		if endpoint == candidate.URI.String() {
    1.33  			return true, nil
    1.34  		}
    1.35  	}
     2.1 --- a/client_test.go	Sun Oct 26 00:53:36 2014 -0400
     2.2 +++ b/client_test.go	Sun Oct 26 03:22:41 2014 -0400
     2.3 @@ -317,7 +317,7 @@
     2.4  	candidates := map[string]bool{
     2.5  		"https://www.example.com/":                 false,
     2.6  		"https://www.example.com/first":            true,
     2.7 -		"https://www.example.com/first/extra/path": true,
     2.8 +		"https://www.example.com/first/extra/path": false,
     2.9  		"https://www.example.com/my":               false,
    2.10  		"https://www.example.com/my/full/path":     true,
    2.11  	}
    2.12 @@ -335,7 +335,7 @@
    2.13  			t.Fatalf("Error saving endpoint in %T: %s", store, err)
    2.14  		}
    2.15  		for candidate, expectation := range candidates {
    2.16 -			result, err := store.checkEndpoint(client.ID, candidate, false)
    2.17 +			result, err := store.checkEndpoint(client.ID, candidate)
    2.18  			if err != nil {
    2.19  				t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
    2.20  			}
    2.21 @@ -397,7 +397,7 @@
    2.22  			t.Fatalf("Error saving endpoint in %T: %s", store, err)
    2.23  		}
    2.24  		for candidate, expectation := range candidates {
    2.25 -			result, err := store.checkEndpoint(client.ID, candidate, true)
    2.26 +			result, err := store.checkEndpoint(client.ID, candidate)
    2.27  			if err != nil {
    2.28  				t.Fatalf("Error checking endpoint %s in %T: %s", candidate, store, err)
    2.29  			}
     3.1 --- a/context.go	Sun Oct 26 00:53:36 2014 -0400
     3.2 +++ b/context.go	Sun Oct 26 03:22:41 2014 -0400
     3.3 @@ -98,14 +98,13 @@
     3.4  }
     3.5  
     3.6  // CheckEndpoint finds Endpoints in the clientStore associated with the Context that belong
     3.7 -// to the Client specified by the passed ID and match the URI passed. If strict is true, only
     3.8 -// exact matches for the URI will be returned. If it is false, matches are performed according
     3.9 -// to RFC 3986 Section 6.
    3.10 -func (c Context) CheckEndpoint(client uuid.ID, URI string, strict bool) (bool, error) {
    3.11 +// to the Client specified by the passed ID and match the URI passed. URI matches must be
    3.12 +// performed according to RFC 3986 Section 6.
    3.13 +func (c Context) CheckEndpoint(client uuid.ID, URI string) (bool, error) {
    3.14  	if c.clients == nil {
    3.15  		return false, ErrNoClientStore
    3.16  	}
    3.17 -	return c.clients.checkEndpoint(client, URI, strict)
    3.18 +	return c.clients.checkEndpoint(client, URI)
    3.19  }
    3.20  
    3.21  // ListEndpoints finds Endpoints in the clientStore associated with the Context that belong
     4.1 --- a/http.go	Sun Oct 26 00:53:36 2014 -0400
     4.2 +++ b/http.go	Sun Oct 26 03:22:41 2014 -0400
     4.3 @@ -46,21 +46,9 @@
     4.4  	}
     4.5  	redirectURI := r.URL.Query().Get("redirect_uri")
     4.6  	var validURI bool
     4.7 -	if redirectURI != "" && numEndpoints > 1 {
     4.8 -		// if there's more than one registered endpoint, we need to match the
     4.9 -		// entire thing, character for character. So use strict checking.
    4.10 -		validURI, err = context.CheckEndpoint(clientID, redirectURI, true)
    4.11 -		if err != nil {
    4.12 -			w.WriteHeader(http.StatusInternalServerError)
    4.13 -			context.Render(w, getGrantTemplateName, map[string]interface{}{
    4.14 -				"internal_error": err,
    4.15 -			})
    4.16 -			return
    4.17 -		}
    4.18 -	} else if redirectURI != "" && numEndpoints == 1 {
    4.19 -		// if there's exactly one endpoint, we can match only the prefix of it,
    4.20 -		// so don't use strict checking.
    4.21 -		validURI, err = context.CheckEndpoint(clientID, redirectURI, false)
    4.22 +	if redirectURI != "" {
    4.23 +		// BUG(paddy): We really should normalize URIs before trying to compare them.
    4.24 +		validURI, err = context.CheckEndpoint(clientID, redirectURI)
    4.25  		if err != nil {
    4.26  			w.WriteHeader(http.StatusInternalServerError)
    4.27  			context.Render(w, getGrantTemplateName, map[string]interface{}{
     5.1 --- a/http_test.go	Sun Oct 26 00:53:36 2014 -0400
     5.2 +++ b/http_test.go	Sun Oct 26 03:22:41 2014 -0400
     5.3 @@ -15,7 +15,6 @@
     5.4  	scopeSet = 1 << iota
     5.5  	stateSet
     5.6  	uriSet
     5.7 -	uriExact
     5.8  )
     5.9  
    5.10  func TestGetGrantCodeSuccess(t *testing.T) {
    5.11 @@ -59,18 +58,14 @@
    5.12  	if err != nil {
    5.13  		t.Fatal("Can't build request:", err)
    5.14  	}
    5.15 -	for i := 0; i < 1<<4; i++ {
    5.16 +	for i := 0; i < 1<<3; i++ {
    5.17  		w := httptest.NewRecorder()
    5.18  		params := url.Values{}
    5.19  		// see OAuth 2.0 spec, section 4.1.1
    5.20  		params.Set("response_type", "code")
    5.21  		params.Set("client_id", client.ID.String())
    5.22  		if i&uriSet != 0 {
    5.23 -			if i&uriExact != 0 {
    5.24 -				params.Set("redirect_uri", endpoint.URI.String())
    5.25 -			} else {
    5.26 -				params.Set("redirect_uri", endpoint.URI.String()+"/inexact")
    5.27 -			}
    5.28 +			params.Set("redirect_uri", endpoint.URI.String())
    5.29  		}
    5.30  		if i&scopeSet != 0 {
    5.31  			params.Set("scope", "testscope")