package auth

import (
	"os"
	"testing"
	"time"

	"code.secondbit.org/uuid.hg"
)

// sessionStores lists every sessionStore implementation the tests in this
// file run against. The in-memory store is always included; init may append
// a Postgres-backed store as well.
var sessionStores = []sessionStore{NewMemstore()}

// init adds a Postgres-backed store to sessionStores when the PG_TEST_DB
// environment variable is non-empty, passing its value to NewPostgres.
// A failure to construct that store aborts the test binary with a panic.
func init() {
	dsn := os.Getenv("PG_TEST_DB")
	if dsn == "" {
		return
	}
	pg, err := NewPostgres(dsn)
	if err != nil {
		panic(err)
	}
	sessionStores = append(sessionStores, &pg)
}

// compareSessions checks session1 and session2 field by field, in the order
// ID, IP, UserAgent, ProfileID, Created, Expires, Login, Active, CSRFToken.
// At the first field that differs it returns false, the field's name, and the
// two differing values (session1's first). If every field matches it returns
// true, "", nil, nil. ProfileID, Created and Expires are compared with their
// Equal methods; the remaining fields use !=.
func compareSessions(session1, session2 Session) (success bool, field string, val1, val2 interface{}) {
	switch {
	case session1.ID != session2.ID:
		return false, "ID", session1.ID, session2.ID
	case session1.IP != session2.IP:
		return false, "IP", session1.IP, session2.IP
	case session1.UserAgent != session2.UserAgent:
		return false, "UserAgent", session1.UserAgent, session2.UserAgent
	case !session1.ProfileID.Equal(session2.ProfileID):
		return false, "ProfileID", session1.ProfileID, session2.ProfileID
	case !session1.Created.Equal(session2.Created):
		return false, "Created", session1.Created, session2.Created
	case !session1.Expires.Equal(session2.Expires):
		return false, "Expires", session1.Expires, session2.Expires
	case session1.Login != session2.Login:
		return false, "Login", session1.Login, session2.Login
	case session1.Active != session2.Active:
		return false, "Active", session1.Active, session2.Active
	case session1.CSRFToken != session2.CSRFToken:
		return false, "CSRFToken", session1.CSRFToken, session2.CSRFToken
	}
	return true, "", nil, nil
}

// TestSessionStoreSuccess runs a full session lifecycle against every store in
// sessionStores:
//   - create the session, and expect a duplicate create to fail with
//     ErrSessionAlreadyExists;
//   - read it back by ID and by profile;
//   - terminate it, after which it must still be readable but inactive;
//   - remove it, after which get, remove and terminate must all report
//     ErrSessionNotFound and listing by profile must return nothing.
//
// Failures are reported with t.Errorf so that every store is exercised even if
// an earlier one misbehaves.
func TestSessionStoreSuccess(t *testing.T) {
	t.Parallel()
	// Created is rounded to milliseconds, presumably so it compares equal after
	// a round trip through a store with coarser timestamp precision — TODO confirm.
	session := Session{
		ID:        uuid.NewID().String() + uuid.NewID().String(),
		IP:        "127.0.0.1",
		UserAgent: "TestRunner",
		ProfileID: uuid.NewID(),
		Created:   time.Now().Round(time.Millisecond),
		Login:     "test@example.com",
		Active:    true,
	}
	for _, store := range sessionStores {
		context := Context{sessions: store}

		// Create, then confirm a second create with the same ID is rejected.
		err := context.CreateSession(session)
		if err != nil {
			t.Errorf("Error saving session to %T: %s", store, err)
		}
		err = context.CreateSession(session)
		if err != ErrSessionAlreadyExists {
			t.Errorf("Expected ErrSessionAlreadyExists from %T, got %v", store, err)
		}

		// The stored session must read back unchanged, by ID and by profile.
		retrieved, err := context.GetSession(session.ID)
		if err != nil {
			t.Errorf("Error retrieving session from %T: %s", store, err)
		} else if success, field, expectation, result := compareSessions(session, retrieved); !success {
			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
		}
		retrievedList, err := context.ListSessions(session.ProfileID, time.Time{}, 10)
		if err != nil {
			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
		}
		if len(retrievedList) != 1 {
			t.Errorf("Expected 1 session retrieved by profile from %T, got %d", store, len(retrievedList))
		}
		// Guard the index: an empty result must fail the test, not panic it
		// and skip the remaining stores.
		if len(retrievedList) > 0 {
			if success, field, expectation, result := compareSessions(session, retrievedList[0]); !success {
				t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
			}
		}

		// After termination the session is expected to remain stored but
		// with Active == false.
		err = context.TerminateSession(session.ID)
		if err != nil {
			t.Errorf("Error terminating session in %T: %s", store, err)
		}
		expected := session
		expected.Active = false
		retrieved, err = context.GetSession(session.ID)
		if err != nil {
			t.Errorf("Error retrieving session from %T: %s", store, err)
		} else if success, field, expectation, result := compareSessions(expected, retrieved); !success {
			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
		}
		retrievedList, err = context.ListSessions(session.ProfileID, time.Time{}, 10)
		if err != nil {
			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
		}
		if len(retrievedList) != 1 {
			t.Errorf("Expected 1 session retrieved by profile from %T, got %d", store, len(retrievedList))
		}
		if len(retrievedList) > 0 {
			if success, field, expectation, result := compareSessions(expected, retrievedList[0]); !success {
				t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
			}
		}

		// After removal the session is gone: every lookup or mutation must
		// report ErrSessionNotFound, and listing by profile returns nothing.
		err = context.RemoveSession(session.ID)
		if err != nil {
			t.Errorf("Error removing session from %T: %s", store, err)
		}
		_, err = context.GetSession(session.ID)
		if err != ErrSessionNotFound {
			t.Errorf("Expected ErrSessionNotFound from %T, got %v", store, err)
		}
		retrievedList, err = context.ListSessions(session.ProfileID, time.Time{}, 10)
		if err != nil {
			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
		}
		if len(retrievedList) != 0 {
			t.Errorf("Expected 0 sessions retrieved by profile from %T, got %d", store, len(retrievedList))
		}
		err = context.RemoveSession(session.ID)
		if err != ErrSessionNotFound {
			t.Errorf("Expected ErrSessionNotFound from %T, got %v", store, err)
		}
		err = context.TerminateSession(session.ID)
		if err != ErrSessionNotFound {
			t.Errorf("Expected ErrSessionNotFound from %T, got %v", store, err)
		}
	}
}

// BUG(paddy): We need to test the CreateSessionHandler.
// BUG(paddy): We need to test the credentialsValidate function.
