package auth

import (
	"testing"
	"time"

	"code.secondbit.org/uuid.hg"
)

// sessionStores holds every sessionStore implementation that the tests in
// this file iterate over; each one is run through the same assertions.
var sessionStores = []sessionStore{
	NewMemstore(),
}

// compareSessions checks session1 and session2 field by field, in the order
// ID, IP, UserAgent, ProfileID, Created, Login, Active. At the first field
// that differs it returns false, that field's name, and the two values
// (session1's value first). If every field matches it returns true, an empty
// field name and two nil values. ProfileID and Created are compared with
// their Equal methods rather than ==.
func compareSessions(session1, session2 Session) (success bool, field string, val1, val2 interface{}) {
	switch {
	case session1.ID != session2.ID:
		return false, "ID", session1.ID, session2.ID
	case session1.IP != session2.IP:
		return false, "IP", session1.IP, session2.IP
	case session1.UserAgent != session2.UserAgent:
		return false, "UserAgent", session1.UserAgent, session2.UserAgent
	case !session1.ProfileID.Equal(session2.ProfileID):
		return false, "ProfileID", session1.ProfileID, session2.ProfileID
	case !session1.Created.Equal(session2.Created):
		return false, "Created", session1.Created, session2.Created
	case session1.Login != session2.Login:
		return false, "Login", session1.Login, session2.Login
	case session1.Active != session2.Active:
		return false, "Active", session1.Active, session2.Active
	}
	return true, "", nil, nil
}

// TestSessionStoreSuccess runs every store in sessionStores through the
// lifecycle of a single session:
//   - create it;
//   - check that a duplicate create returns ErrSessionAlreadyExists;
//   - read it back by ID and by profile;
//   - remove it;
//   - check that getSession and a second removeSession both return
//     ErrSessionNotFound and that listing by profile returns nothing.
func TestSessionStoreSuccess(t *testing.T) {
	t.Parallel()
	session := Session{
		ID:        uuid.NewID().String() + uuid.NewID().String(),
		IP:        "127.0.0.1",
		UserAgent: "TestRunner",
		ProfileID: uuid.NewID(),
		Created:   time.Now(),
		Login:     "test@example.com",
		Active:    true,
	}
	for _, store := range sessionStores {
		err := store.createSession(session)
		if err != nil {
			// Every later check depends on the session existing, so report
			// once and move on rather than emitting a cascade of failures.
			t.Errorf("Error saving session to %T: %s", store, err)
			continue
		}
		err = store.createSession(session)
		if err != ErrSessionAlreadyExists {
			// %v, not %s: err may legitimately be nil here.
			t.Errorf("Expected ErrSessionAlreadyExists from %T, got %v", store, err)
		}
		retrieved, err := store.getSession(session.ID)
		if err != nil {
			t.Errorf("Error retrieving session from %T: %s", store, err)
		} else if success, field, expectation, result := compareSessions(session, retrieved); !success {
			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
		}
		retrievedList, err := store.listSessions(session.ProfileID, time.Time{}, 10)
		if err != nil {
			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
		}
		// Guard the index: t.Errorf does not stop execution, so indexing
		// an empty list would panic and abort the whole test binary.
		if len(retrievedList) != 1 {
			t.Errorf("Expected 1 session retrieved by profile from %T, got %d", store, len(retrievedList))
		} else if success, field, expectation, result := compareSessions(session, retrievedList[0]); !success {
			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
		}
		err = store.removeSession(session.ID)
		if err != nil {
			t.Errorf("Error removing session from %T: %s", store, err)
		}
		_, err = store.getSession(session.ID)
		if err != ErrSessionNotFound {
			t.Errorf("Expected ErrSessionNotFound from %T, got %v", store, err)
		}
		retrievedList, err = store.listSessions(session.ProfileID, time.Time{}, 10)
		if err != nil {
			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
		}
		if len(retrievedList) != 0 {
			t.Errorf("Expected 0 sessions retrieved by profile from %T, got %d", store, len(retrievedList))
		}
		// Removing an already-removed session must report ErrSessionNotFound.
		err = store.removeSession(session.ID)
		if err != ErrSessionNotFound {
			t.Errorf("Expected ErrSessionNotFound from %T, got %v", store, err)
		}
	}
}
