auth
auth/token_test.go
Clean up sessions and tokens after a Profile is deleted. Add a terminateSessionsByProfile method to our sessionStore to mark Sessions associated with a Profile as inactive. Implement memstore and postgres implementations of the terminateSessionsByProfile method. Add a TerminateSessionsByProfile wrapper method to Context. Add a revokeTokensByProfileID method to our tokenStore to mark Tokens associated with a Profile as revoked. Implement memstore and postgres implementations of the revokeTokensByProfileID method. Add a RevokeTokensByProfileID wrapper method to Context. Call our RevokeTokensByProfileID and TerminateSessionsByProfile methods after a Profile is deleted, to clean up the Tokens and Sessions associated with it.
| paddy@28 | 1 package auth |
| paddy@28 | 2 |
| paddy@28 | 3 import ( |
| paddy@155 | 4 "os" |
| paddy@28 | 5 "testing" |
| paddy@28 | 6 "time" |
| paddy@28 | 7 |
| paddy@107 | 8 "code.secondbit.org/uuid.hg" |
| paddy@28 | 9 ) |
| paddy@28 | 10 |
| paddy@155 | 11 func init() { |
| paddy@155 | 12 if os.Getenv("PG_TEST_DB") != "" { |
| paddy@155 | 13 p, err := NewPostgres(os.Getenv("PG_TEST_DB")) |
| paddy@155 | 14 if err != nil { |
| paddy@155 | 15 panic(err) |
| paddy@155 | 16 } |
| paddy@155 | 17 tokenStores = append(tokenStores, &p) |
| paddy@155 | 18 } |
| paddy@155 | 19 } |
| paddy@155 | 20 |
| paddy@57 | 21 var tokenStores = []tokenStore{NewMemstore()} |
| paddy@28 | 22 |
| paddy@35 | 23 func compareTokens(token1, token2 Token) (success bool, field string, val1, val2 interface{}) { |
| paddy@35 | 24 if token1.AccessToken != token2.AccessToken { |
| paddy@35 | 25 return false, "access token", token1.AccessToken, token2.AccessToken |
| paddy@35 | 26 } |
| paddy@35 | 27 if token1.RefreshToken != token2.RefreshToken { |
| paddy@35 | 28 return false, "refresh token", token1.RefreshToken, token2.RefreshToken |
| paddy@35 | 29 } |
| paddy@35 | 30 if !token1.Created.Equal(token2.Created) { |
| paddy@35 | 31 return false, "created", token1.Created, token2.Created |
| paddy@35 | 32 } |
| paddy@97 | 33 if token1.CreatedFrom != token2.CreatedFrom { |
| paddy@97 | 34 return false, "created from", token1.CreatedFrom, token2.CreatedFrom |
| paddy@97 | 35 } |
| paddy@35 | 36 if token1.ExpiresIn != token2.ExpiresIn { |
| paddy@35 | 37 return false, "expires in", token1.ExpiresIn, token2.ExpiresIn |
| paddy@35 | 38 } |
| paddy@35 | 39 if token1.TokenType != token2.TokenType { |
| paddy@35 | 40 return false, "token type", token1.TokenType, token2.TokenType |
| paddy@35 | 41 } |
| paddy@135 | 42 if len(token1.Scopes) != len(token2.Scopes) { |
| paddy@135 | 43 return false, "scopes", token1.Scopes, token2.Scopes |
| paddy@135 | 44 } |
| paddy@135 | 45 for pos, scope := range token1.Scopes { |
| paddy@135 | 46 if scope != token2.Scopes[pos] { |
| paddy@135 | 47 return false, "scopes", token1.Scopes, token2.Scopes |
| paddy@135 | 48 } |
| paddy@35 | 49 } |
| paddy@35 | 50 if !token1.ProfileID.Equal(token2.ProfileID) { |
| paddy@35 | 51 return false, "profile ID", token1.ProfileID, token2.ProfileID |
| paddy@35 | 52 } |
| paddy@97 | 53 if token1.Revoked != token2.Revoked { |
| paddy@97 | 54 return false, "revoked", token1.Revoked, token2.Revoked |
| paddy@97 | 55 } |
| paddy@35 | 56 return true, "", nil, nil |
| paddy@35 | 57 } |
| paddy@35 | 58 |
| paddy@28 | 59 func TestTokenStoreSuccess(t *testing.T) { |
| paddy@37 | 60 t.Parallel() |
| paddy@28 | 61 token := Token{ |
| paddy@28 | 62 AccessToken: "access", |
| paddy@28 | 63 RefreshToken: "refresh", |
| paddy@149 | 64 Created: time.Now().Round(time.Millisecond), |
| paddy@28 | 65 ExpiresIn: 3600, |
| paddy@28 | 66 TokenType: "bearer", |
| paddy@135 | 67 Scopes: []string{"scope"}, |
| paddy@28 | 68 ProfileID: uuid.NewID(), |
| paddy@28 | 69 } |
| paddy@35 | 70 for _, store := range tokenStores { |
| paddy@116 | 71 context := Context{tokens: store} |
| paddy@127 | 72 retrievedAccess, err := context.GetToken(token.AccessToken, false) |
| paddy@127 | 73 if err == nil { |
| paddy@127 | 74 t.Errorf("Expected ErrTokenNotFound from %T, got %+v", store, retrievedAccess) |
| paddy@127 | 75 } else if err != ErrTokenNotFound { |
| paddy@127 | 76 t.Errorf("Expected ErrTokenNotFound from %T, got %s", store, err) |
| paddy@127 | 77 } |
| paddy@127 | 78 retrievedRefresh, err := context.GetToken(token.RefreshToken, true) |
| paddy@127 | 79 if err == nil { |
| paddy@127 | 80 t.Errorf("Expected ErrTokenNotFound from %T, got %+v", store, retrievedRefresh) |
| paddy@127 | 81 } else if err != ErrTokenNotFound { |
| paddy@127 | 82 t.Errorf("Expected ErrTokenNotFound from %T, got %s", store, err) |
| paddy@127 | 83 } |
| paddy@127 | 84 err = context.RevokeToken(token.AccessToken, false) |
| paddy@127 | 85 if err != ErrTokenNotFound { |
| paddy@127 | 86 t.Errorf("Expected ErrTokenNotFound from %T, got %s", store, err) |
| paddy@127 | 87 } |
| paddy@127 | 88 err = context.RevokeToken(token.RefreshToken, true) |
| paddy@127 | 89 if err != ErrTokenNotFound { |
| paddy@127 | 90 t.Errorf("Expected ErrTokenNotFound from %T, got %s", store, err) |
| paddy@127 | 91 } |
| paddy@127 | 92 err = context.SaveToken(token) |
| paddy@28 | 93 if err != nil { |
| paddy@37 | 94 t.Errorf("Error saving token to %T: %s", store, err) |
| paddy@37 | 95 } |
| paddy@116 | 96 err = context.SaveToken(token) |
| paddy@37 | 97 if err != ErrTokenAlreadyExists { |
| paddy@37 | 98 t.Errorf("Expected ErrTokenAlreadyExists from %T, got %s", store, err) |
| paddy@28 | 99 } |
| paddy@127 | 100 retrievedAccess, err = context.GetToken(token.AccessToken, false) |
| paddy@28 | 101 if err != nil { |
| paddy@35 | 102 t.Errorf("Error retrieving token from %T: %s", store, err) |
| paddy@28 | 103 } |
| paddy@35 | 104 success, field, expectation, result := compareTokens(token, retrievedAccess) |
| paddy@35 | 105 if !success { |
| paddy@35 | 106 t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store) |
| paddy@35 | 107 } |
| paddy@127 | 108 retrievedRefresh, err = context.GetToken(token.RefreshToken, true) |
| paddy@28 | 109 if err != nil { |
| paddy@35 | 110 t.Errorf("Error retrieving refresh token from %T: %s", store, err) |
| paddy@28 | 111 } |
| paddy@35 | 112 success, field, expectation, result = compareTokens(token, retrievedRefresh) |
| paddy@35 | 113 if !success { |
| paddy@35 | 114 t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store) |
| paddy@35 | 115 } |
| paddy@116 | 116 retrievedProfile, err := context.GetTokensByProfileID(token.ProfileID, 25, 0) |
| paddy@28 | 117 if err != nil { |
| paddy@35 | 118 t.Errorf("Error retrieving token by profile from %T: %s", store, err) |
| paddy@28 | 119 } |
| paddy@28 | 120 if len(retrievedProfile) != 1 { |
| paddy@35 | 121 t.Errorf("Expected 1 token retrieved by profile ID from %T, got %+v", store, retrievedProfile) |
| paddy@28 | 122 } |
| paddy@35 | 123 success, field, expectation, result = compareTokens(token, retrievedProfile[0]) |
| paddy@35 | 124 if !success { |
| paddy@35 | 125 t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store) |
| paddy@35 | 126 } |
| paddy@116 | 127 err = context.RevokeToken(token.AccessToken, false) |
| paddy@97 | 128 if err != nil { |
| paddy@97 | 129 t.Errorf("Error revoking token in %T: %s", store, err) |
| paddy@97 | 130 } |
| paddy@116 | 131 retrievedRevoked, err := context.GetToken(token.AccessToken, false) |
| paddy@97 | 132 if err != nil { |
| paddy@97 | 133 t.Errorf("Error retrieving token from %T: %s", store, err) |
| paddy@97 | 134 } |
| paddy@97 | 135 token.Revoked = true |
| paddy@97 | 136 success, field, expectation, result = compareTokens(token, retrievedRevoked) |
| paddy@97 | 137 if !success { |
| paddy@97 | 138 t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store) |
| paddy@97 | 139 } |
| paddy@127 | 140 err = context.RevokeToken(token.RefreshToken, true) |
| paddy@28 | 141 if err != nil { |
| paddy@127 | 142 t.Errorf("Error revoking token in %T: %s", store, err) |
| paddy@28 | 143 } |
| paddy@127 | 144 retrievedRevoked, err = context.GetToken(token.RefreshToken, true) |
| paddy@127 | 145 if err != nil { |
| paddy@127 | 146 t.Errorf("Error retrieving token from %T: %s", store, err) |
| paddy@28 | 147 } |
| paddy@127 | 148 token.RefreshRevoked = true |
| paddy@127 | 149 success, field, expectation, result = compareTokens(token, retrievedRevoked) |
| paddy@127 | 150 if !success { |
| paddy@127 | 151 t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store) |
| paddy@97 | 152 } |
| paddy@28 | 153 } |
| paddy@28 | 154 } |
| paddy@128 | 155 |
| paddy@128 | 156 // BUG(paddy): We need to test the refreshTokenValidate function. |
| paddy@128 | 157 // BUG(paddy): We need to test the refreshTokenInvalidate function. |