package auth

import (
	"testing"
	"time"

	"code.secondbit.org/uuid"
)

// tokenStores holds every tokenStore implementation that the token store
// tests in this file iterate over. Add new implementations here so they are
// exercised by the same checks.
var tokenStores = []tokenStore{
	NewMemstore(),
}

// compareTokens checks token1 against token2 field by field, in a fixed
// order, and stops at the first field that differs.
//
// If every compared field matches, it returns success == true, an empty
// field name, and nil values. Otherwise it returns success == false, a
// human-readable name for the mismatched field, and that field's value from
// token1 (val1) and from token2 (val2), suitable for a test failure message.
//
// Created is compared with time.Time.Equal and ProfileID with the ID's own
// Equal method; all other fields use ==.
func compareTokens(token1, token2 Token) (success bool, field string, val1, val2 interface{}) {
	switch {
	case token1.AccessToken != token2.AccessToken:
		return false, "access token", token1.AccessToken, token2.AccessToken
	case token1.RefreshToken != token2.RefreshToken:
		return false, "refresh token", token1.RefreshToken, token2.RefreshToken
	case !token1.Created.Equal(token2.Created):
		return false, "created", token1.Created, token2.Created
	case token1.CreatedFrom != token2.CreatedFrom:
		return false, "created from", token1.CreatedFrom, token2.CreatedFrom
	case token1.ExpiresIn != token2.ExpiresIn:
		return false, "expires in", token1.ExpiresIn, token2.ExpiresIn
	case token1.RefreshExpiresIn != token2.RefreshExpiresIn:
		return false, "refresh expires in", token1.RefreshExpiresIn, token2.RefreshExpiresIn
	case token1.TokenType != token2.TokenType:
		return false, "token type", token1.TokenType, token2.TokenType
	case token1.Scope != token2.Scope:
		return false, "scope", token1.Scope, token2.Scope
	case !token1.ProfileID.Equal(token2.ProfileID):
		return false, "profile ID", token1.ProfileID, token2.ProfileID
	case token1.Revoked != token2.Revoked:
		return false, "revoked", token1.Revoked, token2.Revoked
	}
	// No case fired: every compared field is equal.
	return true, "", nil, nil
}

// TestTokenStoreSuccess runs the happy path against every store in
// tokenStores. For each store it:
//   - saves a token and checks that saving it again returns ErrTokenAlreadyExists;
//   - looks the token up by access token, by refresh token, and by profile ID;
//   - revokes it by access token and checks Revoked is reported;
//   - removes it, then checks that lookups, a second removal, and revocation
//     by access or refresh token all return ErrTokenNotFound.
//
// NOTE(review): tokenStores is package-level and this test calls t.Parallel.
// Other parallel tests that reuse the same store instances and the fixed
// access token "access" could interfere with this one — confirm.
func TestTokenStoreSuccess(t *testing.T) {
	t.Parallel()
	token := Token{
		AccessToken:  "access",
		RefreshToken: "refresh",
		Created:      time.Now(),
		ExpiresIn:    3600,
		TokenType:    "bearer",
		Scope:        "scope",
		ProfileID:    uuid.NewID(),
	}
	for _, store := range tokenStores {
		// Each store gets its own copy. The revocation check below sets
		// Revoked on the expected token, and that must not leak into the
		// token saved to the next store.
		tok := token

		// expectToken reports the first field where got differs from the
		// current expectation (tok, including any later changes to it).
		expectToken := func(got Token) {
			if success, field, expectation, result := compareTokens(tok, got); !success {
				t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
			}
		}

		err := store.saveToken(tok)
		if err != nil {
			t.Errorf("Error saving token to %T: %s", store, err)
		}
		err = store.saveToken(tok)
		if err != ErrTokenAlreadyExists {
			t.Errorf("Expected ErrTokenAlreadyExists from %T, got %v", store, err)
		}

		// Only compare on a successful lookup. A failed lookup has already
		// been reported, and comparing its zero value would add noise.
		retrievedAccess, err := store.getToken(tok.AccessToken, false)
		if err != nil {
			t.Errorf("Error retrieving token from %T: %s", store, err)
		} else {
			expectToken(retrievedAccess)
		}
		retrievedRefresh, err := store.getToken(tok.RefreshToken, true)
		if err != nil {
			t.Errorf("Error retrieving refresh token from %T: %s", store, err)
		} else {
			expectToken(retrievedRefresh)
		}

		retrievedProfile, err := store.getTokensByProfileID(tok.ProfileID, 25, 0)
		if err != nil {
			t.Errorf("Error retrieving token by profile from %T: %s", store, err)
		}
		if len(retrievedProfile) != 1 {
			t.Errorf("Expected 1 token retrieved by profile ID from %T, got %+v", store, retrievedProfile)
		}
		// Guard the index. An empty result must fail the test, not panic
		// and take down the whole test binary.
		if len(retrievedProfile) > 0 {
			expectToken(retrievedProfile[0])
		}

		err = store.revokeToken(tok.AccessToken, false)
		if err != nil {
			t.Errorf("Error revoking token in %T: %s", store, err)
		}
		retrievedRevoked, err := store.getToken(tok.AccessToken, false)
		if err != nil {
			t.Errorf("Error retrieving token from %T: %s", store, err)
		} else {
			tok.Revoked = true
			expectToken(retrievedRevoked)
		}
		// TODO(paddy): test revoking by refresh token.

		err = store.removeToken(tok.AccessToken)
		if err != nil {
			t.Errorf("Error removing token from %T: %s", store, err)
		}
		_, err = store.getToken(tok.AccessToken, false)
		if err != ErrTokenNotFound {
			t.Errorf("Expected ErrTokenNotFound from %T, got %v", store, err)
		}
		_, err = store.getToken(tok.RefreshToken, true)
		if err != ErrTokenNotFound {
			t.Errorf("Expected ErrTokenNotFound from %T, got %v", store, err)
		}
		retrievedProfile, err = store.getTokensByProfileID(tok.ProfileID, 25, 0)
		if err != nil {
			t.Errorf("Error retrieving token by profile from %T: %s", store, err)
		}
		if len(retrievedProfile) != 0 {
			t.Errorf("Expected list of 0 tokens from %T, got %+v", store, retrievedProfile)
		}
		err = store.removeToken(tok.AccessToken)
		if err != ErrTokenNotFound {
			t.Errorf("Expected ErrTokenNotFound from %T, got %v", store, err)
		}
		err = store.revokeToken(tok.AccessToken, false)
		if err != ErrTokenNotFound {
			t.Errorf("Expected ErrTokenNotFound from %T, got %v", store, err)
		}
		err = store.revokeToken(tok.RefreshToken, true)
		if err != ErrTokenNotFound {
			t.Errorf("Expected ErrTokenNotFound from %T, got %v", store, err)
		}
	}
}
