package subscriptions

import (
	"os"
	"strconv"
	"testing"
	"time"

	"code.secondbit.org/uuid.hg"
)

// Bit flags selecting which SubscriptionChange fields TestUpdateSubscription
// sets in a given iteration. Each flag owns a distinct bit, so testing an int
// against these constants decodes one combination of field changes.
const (
	subscriptionChangeStripeSubscription   = 1 << 0
	subscriptionChangePlan                 = 1 << 1
	subscriptionChangeStatus               = 1 << 2
	subscriptionChangeCanceling            = 1 << 3
	subscriptionChangeTrialStart           = 1 << 4
	subscriptionChangeTrialEnd             = 1 << 5
	subscriptionChangePeriodStart          = 1 << 6
	subscriptionChangePeriodEnd            = 1 << 7
	subscriptionChangeCanceledAt           = 1 << 8
	subscriptionChangeFailedChargeAttempts = 1 << 9
	subscriptionChangeLastFailedCharge     = 1 << 10
	subscriptionChangeLastNotified         = 1 << 11
)

// init adds a Postgres-backed store to testSubscriptionStores when the
// PG_TEST_DB environment variable is non-empty; its value is handed to
// NewPostgres. If NewPostgres returns an error, init panics so the test
// binary fails loudly instead of silently skipping Postgres.
func init() {
	dsn := os.Getenv("PG_TEST_DB")
	if dsn == "" {
		return
	}
	pg, err := NewPostgres(dsn)
	if err != nil {
		panic(err)
	}
	testSubscriptionStores = append(testSubscriptionStores, pg)
}

// testSubscriptionStores is the set of SubscriptionStore implementations every
// test in this file runs against. It always contains an in-memory store; init
// appends a Postgres store when PG_TEST_DB is set.
var testSubscriptionStores = []SubscriptionStore{NewMemstore()}

// compareSubscriptions reports whether sub1 and sub2 carry the same data.
// Fields are checked in a fixed order. On the first mismatch it returns false,
// the field's name, and the two differing values (sub1's value first). If all
// fields agree it returns true, "", nil, nil. UserID and the time fields are
// compared with their Equal methods rather than ==.
func compareSubscriptions(sub1, sub2 Subscription) (bool, string, interface{}, interface{}) {
	switch {
	case !sub1.UserID.Equal(sub2.UserID):
		return false, "UserID", sub1.UserID, sub2.UserID
	case sub1.StripeSubscription != sub2.StripeSubscription:
		return false, "StripeSubscription", sub1.StripeSubscription, sub2.StripeSubscription
	case sub1.Plan != sub2.Plan:
		return false, "Plan", sub1.Plan, sub2.Plan
	case sub1.Status != sub2.Status:
		return false, "Status", sub1.Status, sub2.Status
	case sub1.Canceling != sub2.Canceling:
		return false, "Canceling", sub1.Canceling, sub2.Canceling
	case !sub1.Created.Equal(sub2.Created):
		return false, "Created", sub1.Created, sub2.Created
	case !sub1.TrialStart.Equal(sub2.TrialStart):
		return false, "TrialStart", sub1.TrialStart, sub2.TrialStart
	case !sub1.TrialEnd.Equal(sub2.TrialEnd):
		return false, "TrialEnd", sub1.TrialEnd, sub2.TrialEnd
	case !sub1.PeriodStart.Equal(sub2.PeriodStart):
		return false, "PeriodStart", sub1.PeriodStart, sub2.PeriodStart
	case !sub1.PeriodEnd.Equal(sub2.PeriodEnd):
		return false, "PeriodEnd", sub1.PeriodEnd, sub2.PeriodEnd
	case !sub1.CanceledAt.Equal(sub2.CanceledAt):
		return false, "CanceledAt", sub1.CanceledAt, sub2.CanceledAt
	case sub1.FailedChargeAttempts != sub2.FailedChargeAttempts:
		return false, "FailedChargeAttempts", sub1.FailedChargeAttempts, sub2.FailedChargeAttempts
	case !sub1.LastFailedCharge.Equal(sub2.LastFailedCharge):
		return false, "LastFailedCharge", sub1.LastFailedCharge, sub2.LastFailedCharge
	case !sub1.LastNotified.Equal(sub2.LastNotified):
		return false, "LastNotified", sub1.LastNotified, sub2.LastNotified
	}
	return true, "", nil, nil
}

// subscriptionMapContains checks whether every given subscription has an entry
// in subscriptionMap, keyed by UserID.String(). It returns true and a nil
// slice when all are present; otherwise false plus the subscriptions that had
// no entry, in the order they were passed. Only key presence is checked, not
// the stored values.
func subscriptionMapContains(subscriptionMap map[string]Subscription, subscriptions ...Subscription) (bool, []Subscription) {
	var absent []Subscription
	for i := range subscriptions {
		key := subscriptions[i].UserID.String()
		if _, found := subscriptionMap[key]; found {
			continue
		}
		absent = append(absent, subscriptions[i])
	}
	return len(absent) == 0, absent
}

// compareSubscriptionStats reports whether two SubscriptionStats values agree.
// It checks Number, Canceling, Failing, then Plans. On the first mismatch it
// returns false, the field name, and both values (stat1's first). Plans is
// equal only when both maps have the same length and identical counts for
// every key. A full match returns true, "", nil, nil.
func compareSubscriptionStats(stat1, stat2 SubscriptionStats) (bool, string, interface{}, interface{}) {
	switch {
	case stat1.Number != stat2.Number:
		return false, "Number", stat1.Number, stat2.Number
	case stat1.Canceling != stat2.Canceling:
		return false, "Canceling", stat1.Canceling, stat2.Canceling
	case stat1.Failing != stat2.Failing:
		return false, "Failing", stat1.Failing, stat2.Failing
	case len(stat1.Plans) != len(stat2.Plans):
		return false, "Plans", stat1.Plans, stat2.Plans
	}
	for plan, want := range stat1.Plans {
		if got, found := stat2.Plans[plan]; !found || got != want {
			return false, "Plans", stat1.Plans, stat2.Plans
		}
	}
	return true, "", nil, nil
}

// TestCreateSubscription checks, for every store, that:
//   - a new subscription can be created and read back unchanged,
//   - creating the same subscription again fails with ErrSubscriptionAlreadyExists,
//   - reusing its StripeSubscription under a new UserID fails with
//     ErrStripeSubscriptionAlreadyExists,
//   - a subscription with both a new UserID and StripeSubscription is accepted.
func TestCreateSubscription(t *testing.T) {
	for _, store := range testSubscriptionStores {
		err := store.Reset()
		if err != nil {
			t.Fatalf("Error resetting %T: %+v\n", store, err)
		}
		customerID := uuid.NewID()
		sub := Subscription{
			UserID:             customerID,
			StripeSubscription: "stripeSubscription1",
			Created:            time.Now().Round(time.Millisecond),
			TrialStart:         time.Now().Round(time.Millisecond),
			TrialEnd:           time.Now().Round(time.Millisecond).Add(time.Hour * 24 * 31),
		}
		err = store.CreateSubscription(sub)
		if err != nil {
			t.Errorf("Error creating subscription in %T: %+v\n", store, err)
		}
		retrieved, err := store.GetSubscriptions([]uuid.ID{sub.UserID})
		if err != nil {
			t.Errorf("Error retrieving subscription from %T: %+v\n", store, err)
		}
		// Only compare fields when the subscription actually came back;
		// comparing against the zero value would just restate the missing
		// result as a misleading field mismatch.
		got, returned := retrieved[sub.UserID.String()]
		if !returned {
			t.Errorf("Error retrieving subscription from %T: %s wasn't in the results.", store, sub.UserID)
		} else if ok, field, expected, result := compareSubscriptions(sub, got); !ok {
			t.Errorf("Expected %s to be %v, got %v from %T\n", field, expected, result, store)
		}
		err = store.CreateSubscription(sub)
		if err != ErrSubscriptionAlreadyExists {
			t.Errorf("Unexpected error creating subscription in %T (wanted %+v): %+v\n", store, ErrSubscriptionAlreadyExists, err)
		}
		// Same StripeSubscription, different user: must be rejected.
		sub.UserID = uuid.NewID()
		err = store.CreateSubscription(sub)
		if err != ErrStripeSubscriptionAlreadyExists {
			t.Errorf("Unexpected error creating subscription in %T (wanted %+v): %#+v\n", store, ErrStripeSubscriptionAlreadyExists, err)
		}
		// Both identifiers fresh: must succeed.
		sub.StripeSubscription = "stripeSubscription2"
		err = store.CreateSubscription(sub)
		if err != nil {
			t.Errorf("Error creating subscription in %T: %+v\n", store, err)
		}
	}
}

func TestUpdateSubscription(t *testing.T) {
	variations := 1 << 12
	sub := Subscription{
		UserID:             uuid.NewID(),
		StripeSubscription: "default",
		Created:            time.Now().Round(time.Millisecond).Add(time.Hour * -24 * -32),
		TrialStart:         time.Now().Round(time.Millisecond).Add(time.Hour * -24 * -32),
		TrialEnd:           time.Now().Round(time.Millisecond).Add(time.Hour * -24),
		LastNotified:       time.Now().Round(time.Millisecond).Add(time.Hour * -24),
	}
	sub2 := Subscription{
		UserID:             uuid.NewID(),
		StripeSubscription: "stripeSubscription2",
		Created:            time.Now().Round(time.Millisecond),
		TrialStart:         time.Now().Round(time.Millisecond),
		TrialEnd:           time.Now().Round(time.Millisecond),
		LastNotified:       time.Now().Round(time.Millisecond),
	}

	for _, store := range testSubscriptionStores {
		err := store.Reset()
		if err != nil {
			t.Fatalf("Error resetting %T: %+v\n", store, err)
		}
		err = store.CreateSubscription(sub)
		if err != nil {
			t.Fatalf("Error saving subscription in %T: %s\n", store, err)
		}
		for i := 1; i < variations; i++ {
			var stripeSubscription, plan, status string
			var canceling bool
			var failedChargeAttempts int
			var trialStart, trialEnd, periodStart, periodEnd, canceledAt, lastFailedCharge, lastNotified time.Time

			change := SubscriptionChange{}
			empty := change.IsEmpty()
			if !empty {
				t.Errorf("Expected empty to be %t, was %t\n", true, empty)
			}
			result := sub
			strI := strconv.Itoa(i)

			if i&subscriptionChangeStripeSubscription != 0 {
				stripeSubscription = "stripeSubscription-" + strI
				change.StripeSubscription = &stripeSubscription
				sub.StripeSubscription = stripeSubscription
			}

			if i&subscriptionChangePlan != 0 {
				plan = "plan-" + strI
				change.Plan = &plan
				sub.Plan = plan
			}

			if i&subscriptionChangeStatus != 0 {
				status = "status-" + strI
				change.Status = &status
				sub.Status = status
			}

			if i&subscriptionChangeCanceling != 0 {
				canceling = i%2 == 0
				change.Canceling = &canceling
				sub.Canceling = canceling
			}

			if i&subscriptionChangeTrialStart != 0 {
				trialStart = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.TrialStart = &trialStart
				sub.TrialStart = trialStart
			}

			if i&subscriptionChangeTrialEnd != 0 {
				trialEnd = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.TrialEnd = &trialEnd
				sub.TrialEnd = trialEnd
			}

			if i&subscriptionChangePeriodStart != 0 {
				periodStart = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.PeriodStart = &periodStart
				sub.PeriodStart = periodStart
			}

			if i&subscriptionChangePeriodEnd != 0 {
				periodEnd = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.PeriodEnd = &periodEnd
				sub.PeriodEnd = periodEnd
			}

			if i&subscriptionChangeCanceledAt != 0 {
				canceledAt = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.CanceledAt = &canceledAt
				sub.CanceledAt = canceledAt
			}

			if i&subscriptionChangeFailedChargeAttempts != 0 {
				failedChargeAttempts = i
				change.FailedChargeAttempts = &failedChargeAttempts
				sub.FailedChargeAttempts = failedChargeAttempts
			}

			if i&subscriptionChangeLastFailedCharge != 0 {
				lastFailedCharge = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.LastFailedCharge = &lastFailedCharge
				sub.LastFailedCharge = lastFailedCharge
			}

			if i&subscriptionChangeLastNotified != 0 {
				lastNotified = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
				change.LastNotified = &lastNotified
				sub.LastNotified = lastNotified
			}

			empty = change.IsEmpty()
			if empty {
				t.Errorf("Expected empty to be %t, was %t\n", false, empty)
			}

			result.ApplyChange(change)
			match, field, expected, got := compareSubscriptions(sub, result)
			if !match {
				t.Errorf("Expected field `%s` to be `%v`, got `%v`\n", field, expected, got)
			}
			err = store.UpdateSubscription(sub.UserID, change)
			if err != nil {
				t.Logf("Change %d: %+v\n", i, change)
				if p, ok := store.(Postgres); ok {
					query := p.updateSubscriptionSQL(sub.UserID, change)
					t.Log(query.String())
					t.Log(query.Args...)
				}
				t.Errorf("Error updating subscription in %T: %s\n", store, err)
			}
			retrieved, err := store.GetSubscriptions([]uuid.ID{sub.UserID})
			if err != nil {
				t.Errorf("Error getting subscription from %T: %s\n", store, err)
			}
			ok, missing := subscriptionMapContains(retrieved, sub)
			if !ok {
				t.Errorf("Expected to retrieve %s from %T, but missing was %+v\n", sub.UserID.String(), store, missing)
			}
			match, field, expected, got = compareSubscriptions(sub, retrieved[sub.UserID.String()])
			if !match {
				t.Errorf("Expected field `%s` to be `%v`, got `%v` from %T\n", field, expected, got, store)
			}
			sub = result
		}

		err = store.CreateSubscription(sub2)
		if err != nil {
			t.Fatalf("Error saving subscription in %T: %+v\n", store, err)
		}
		change := SubscriptionChange{}
		err = store.UpdateSubscription(sub.UserID, change)
		if err != ErrSubscriptionChangeEmpty {
			t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionChangeEmpty, err, store)
		}
		stripeSubscription := sub2.StripeSubscription
		change.StripeSubscription = &stripeSubscription
		err = store.UpdateSubscription(uuid.NewID(), change)
		if err != ErrSubscriptionNotFound {
			t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionNotFound, err, store)
		}
		err = store.UpdateSubscription(sub.UserID, change)
		if err != ErrStripeSubscriptionAlreadyExists {
			t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrStripeSubscriptionAlreadyExists, err, store)
		}
	}
}

// TestDeleteSubscription checks, for every store, that:
//   - deleting one of two subscriptions removes only that one,
//   - deleting it a second time fails with ErrSubscriptionNotFound.
func TestDeleteSubscription(t *testing.T) {
	for _, store := range testSubscriptionStores {
		err := store.Reset()
		if err != nil {
			t.Fatalf("Error resetting %T: %+v\n", store, err)
		}
		sub1 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription1",
		}
		sub2 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription2",
		}
		err = store.CreateSubscription(sub1)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
		}
		err = store.CreateSubscription(sub2)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub2, store, err)
		}
		err = store.DeleteSubscription(sub1.UserID)
		if err != nil {
			t.Fatalf("Error deleting %+v in %T: %+v\n", sub1, store, err)
		}
		retrieved, err := store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID})
		if err != nil {
			t.Errorf("Error retrieving subscriptions from %T: %+v\n", store, err)
		}
		// When sub1 is found, the "missing" list is necessarily empty.
		// Report the results instead, so the failure shows what came back.
		if found, _ := subscriptionMapContains(retrieved, sub1); found {
			t.Errorf("Expected not to retrieve %s from %T, but it was in the results: %+v\n", sub1.UserID.String(), store, retrieved)
		}
		ok, missing := subscriptionMapContains(retrieved, sub2)
		if !ok {
			t.Errorf("Expected to retrieve %s from %T, but missing was %+v\n", sub2.UserID.String(), store, missing)
		}
		err = store.DeleteSubscription(sub1.UserID)
		if err != ErrSubscriptionNotFound {
			t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionNotFound, err, store)
		}
	}
}

// TestGetSubscriptions checks, for every store, that GetSubscriptions:
//   - rejects an empty ID list with ErrNoSubscriptionID,
//   - returns exactly the requested subscriptions for every subset of three
//     stored ones,
//   - returns nothing for an unknown ID,
//   - silently skips unknown IDs mixed in with known ones.
func TestGetSubscriptions(t *testing.T) {
	for _, store := range testSubscriptionStores {
		err := store.Reset()
		if err != nil {
			t.Fatalf("Error resetting %T: %+v\n", store, err)
		}
		sub1 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription1",
			Plan:               "plan1",
			Created:            time.Now().Round(time.Millisecond),
			TrialStart:         time.Now().Round(time.Millisecond),
			TrialEnd:           time.Now().Round(time.Millisecond).Add(time.Hour * 24 * 32),
		}
		sub2 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription2",
			Plan:               "plan2",
			Created:            time.Now().Round(time.Millisecond).Add(time.Hour * -720),
			TrialStart:         time.Now().Round(time.Millisecond).Add(time.Hour * -720),
			TrialEnd:           time.Now().Round(time.Millisecond),
		}
		sub3 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription3",
			Plan:               "plan3",
			Created:            time.Now().Round(time.Millisecond).Add(time.Hour * -1440),
			TrialStart:         time.Now().Round(time.Millisecond).Add(time.Hour * -1440),
			TrialEnd:           time.Now().Round(time.Millisecond).Add(time.Hour * -720),
			PeriodStart:        time.Now().Round(time.Millisecond).Add(time.Hour * -720),
			PeriodEnd:          time.Now().Round(time.Millisecond),
			Status:             "unpaid",
		}
		err = store.CreateSubscription(sub1)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
		}
		err = store.CreateSubscription(sub2)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub2, store, err)
		}
		err = store.CreateSubscription(sub3)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub3, store, err)
		}
		// Only the error matters for the empty-ID case.
		_, err = store.GetSubscriptions([]uuid.ID{})
		if err != ErrNoSubscriptionID {
			t.Errorf("Error retrieving no subscriptions from %T. Expected %+v, got %+v\n", store, ErrNoSubscriptionID, err)
		}
		retrieved, err := store.GetSubscriptions([]uuid.ID{sub1.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s from %T: %+v\n", sub1.UserID, store, err)
		}
		ok, missing := subscriptionMapContains(retrieved, sub1)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub1.UserID, sub2.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub1, sub2)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub3.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub1.UserID, sub3.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub1, sub3)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID, sub3.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s, %s, and %s from %T: %+v\n", sub1.UserID, sub2.UserID, sub3.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub1, sub2, sub3)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub2.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s from %T: %+v\n", sub2.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub2)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub2.UserID, sub3.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub2.UserID, sub3.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub2, sub3)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub3.UserID})
		if err != nil {
			t.Errorf("Error retrieving %s from %T: %+v\n", sub3.UserID, store, err)
		}
		ok, missing = subscriptionMapContains(retrieved, sub3)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
		// An unknown ID is not an error; it just yields no results.
		retrieved, err = store.GetSubscriptions([]uuid.ID{uuid.NewID()})
		if err != nil {
			t.Errorf("Error retrieving non-existent ID from %T: %+v\n", store, err)
		}
		if len(retrieved) != 0 {
			t.Errorf("Expected no results, %T returned %+v\n", store, retrieved)
		}
		// Unknown IDs mixed with known ones are skipped.
		retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID, uuid.NewID(), sub3.UserID})
		if err != nil {
			t.Errorf("Error retrieving existing and non-existent IDs from %T: %+v\n", store, err)
		}
		if len(retrieved) != 3 {
			t.Errorf("Expected 3 results, %T returned %+v\n", store, retrieved)
		}
		ok, missing = subscriptionMapContains(retrieved, sub1, sub2, sub3)
		if !ok {
			t.Logf("Results: %+v\n", retrieved)
			t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
		}
	}
}

// TestGetSubscriptionStats checks, for every store, that GetSubscriptionStats
// tracks the total count, the canceling count, the failing count, and
// per-plan counts. It checks them as subscriptions are added and removed:
//   - sub1 is canceling,
//   - sub2 has Status "past_due" and is expected to count as failing.
func TestGetSubscriptionStats(t *testing.T) {
	for _, store := range testSubscriptionStores {
		err := store.Reset()
		if err != nil {
			t.Fatalf("Error resetting %T: %+v\n", store, err)
		}
		sub1 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription1",
			Plan:               "plan1",
			Canceling:          true,
		}
		sub2 := Subscription{
			UserID:             uuid.NewID(),
			StripeSubscription: "stripeSubscription2",
			Plan:               "plan2",
			Status:             "past_due",
		}
		err = store.CreateSubscription(sub1)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
		}
		stats, err := store.GetSubscriptionStats()
		if err != nil {
			t.Errorf("Error getting stats from %T: %+v\n", store, err)
		}
		ok, field, expected, results := compareSubscriptionStats(SubscriptionStats{
			Number:    1,
			Canceling: 1,
			Failing:   0,
			Plans: map[string]int64{
				"plan1": 1,
			},
		}, stats)
		if !ok {
			t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)
		}
		err = store.CreateSubscription(sub2)
		if err != nil {
			t.Fatalf("Error creating %+v in %T: %+v\n", sub2, store, err)
		}
		stats, err = store.GetSubscriptionStats()
		if err != nil {
			t.Errorf("Error getting stats from %T: %+v\n", store, err)
		}
		ok, field, expected, results = compareSubscriptionStats(SubscriptionStats{
			Number:    2,
			Canceling: 1,
			Failing:   1,
			Plans: map[string]int64{
				"plan1": 1,
				"plan2": 1,
			},
		}, stats)
		if !ok {
			t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)
		}
		// Removing sub1 should drop both its canceling count and its plan.
		err = store.DeleteSubscription(sub1.UserID)
		if err != nil {
			t.Errorf("Error deleting subscription from %T: %+v\n", store, err)
		}
		stats, err = store.GetSubscriptionStats()
		if err != nil {
			t.Errorf("Error getting stats from %T: %+v\n", store, err)
		}
		ok, field, expected, results = compareSubscriptionStats(SubscriptionStats{
			Number:    1,
			Canceling: 0,
			Failing:   1,
			Plans: map[string]int64{
				"plan2": 1,
			},
		}, stats)
		if !ok {
			t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)
		}
	}
}
