package subscriptions

import (
	"database/sql"
	"time"

	"code.secondbit.org/uuid.hg"
	"github.com/lib/pq"
	"github.com/secondbit/pan"
)

// GetSQLTableName implements pan.SQLTableNamer, telling pan which table
// stores Subscription records so queries can be built against it.
func (s Subscription) GetSQLTableName() string {
	const table = "subscriptions"
	return table
}

// resetSQL builds a TRUNCATE statement for the table that holds
// Subscription records.
func (p Postgres) resetSQL() *pan.Query {
	table := pan.GetTableName(Subscription{})
	stmt := pan.New(pan.POSTGRES, "TRUNCATE "+table)
	return stmt.FlushExpressions(" ")
}

// reset truncates the subscriptions table, discarding every stored
// Subscription. Any error from executing the statement is returned as-is.
func (p Postgres) reset() error {
	stmt := p.resetSQL()
	_, err := p.Exec(stmt.String(), stmt.Args...)
	return err
}

// createSubscriptionSQL builds an INSERT statement that writes every
// pan-mapped field of sub into the subscriptions table, binding the field
// values as positional placeholders.
func (p Postgres) createSubscriptionSQL(sub Subscription) *pan.Query {
	columns, args := pan.GetFields(sub)
	table := pan.GetTableName(sub)
	columnList := "(" + pan.QueryList(columns) + ")"
	placeholders := "(" + pan.VariableList(len(args)) + ")"

	stmt := pan.New(pan.POSTGRES, "INSERT INTO "+table)
	stmt.Include(columnList)
	stmt.Include("VALUES")
	stmt.Include(placeholders, args...)
	return stmt.FlushExpressions(" ")
}

// createSubscription inserts sub into the subscriptions table.
//
// When the database returns a *pq.Error naming a known constraint, that
// error is translated into a package error:
//   - subscriptions_pkey                -> ErrSubscriptionAlreadyExists
//   - subscriptions_stripe_customer_key -> ErrStripeCustomerAlreadyExists
//
// Every other error (or nil) is returned unchanged.
func (p Postgres) createSubscription(sub Subscription) error {
	stmt := p.createSubscriptionSQL(sub)
	_, err := p.Exec(stmt.String(), stmt.Args...)

	pqErr, isPQ := err.(*pq.Error)
	if !isPQ {
		return err
	}
	switch pqErr.Constraint {
	case "subscriptions_pkey":
		return ErrSubscriptionAlreadyExists
	case "subscriptions_stripe_customer_key":
		return ErrStripeCustomerAlreadyExists
	default:
		return err
	}
}

// updateSubscriptionSQL builds an UPDATE statement for the subscription row
// whose UserID equals id. Each field of change is passed to pan's
// IncludeIfNotNil, so presumably only the non-nil fields end up in the SET
// clause. The assignments are joined with ", " before the WHERE clause is
// appended.
func (p Postgres) updateSubscriptionSQL(id uuid.ID, change SubscriptionChange) *pan.Query {
	var sub Subscription

	// Candidate assignments, in the order they appear in the SET clause.
	assignments := []struct {
		field string
		value interface{}
	}{
		{"StripeCustomer", change.StripeCustomer},
		{"Amount", change.Amount},
		{"Period", change.Period},
		{"BeginCharging", change.BeginCharging},
		{"LastCharged", change.LastCharged},
		{"LastNotified", change.LastNotified},
		{"InLockout", change.InLockout},
	}

	stmt := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(sub)+" SET")
	for _, a := range assignments {
		stmt.IncludeIfNotNil(pan.GetUnquotedColumn(sub, a.field)+" = ?", a.value)
	}
	// Join the SET assignments with commas before moving on to WHERE.
	stmt.FlushExpressions(", ")

	stmt.IncludeWhere()
	stmt.Include(pan.GetUnquotedColumn(sub, "UserID")+" = ?", id)
	return stmt.FlushExpressions(" ")
}

// updateSubscription applies change to the subscription whose UserID is id.
//
// It returns:
//   - ErrSubscriptionChangeEmpty if change.IsEmpty() reports no updates;
//   - ErrStripeCustomerAlreadyExists if the database returns a *pq.Error
//     naming the subscriptions_stripe_customer_key constraint;
//   - ErrSubscriptionNotFound if the statement affected no rows;
//   - any other execution or RowsAffected error unchanged.
func (p Postgres) updateSubscription(id uuid.ID, change SubscriptionChange) error {
	if change.IsEmpty() {
		return ErrSubscriptionChangeEmpty
	}

	stmt := p.updateSubscriptionSQL(id, change)
	result, err := p.Exec(stmt.String(), stmt.Args...)
	if err != nil {
		if pqErr, isPQ := err.(*pq.Error); isPQ && pqErr.Constraint == "subscriptions_stripe_customer_key" {
			return ErrStripeCustomerAlreadyExists
		}
		return err
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected < 1:
		return ErrSubscriptionNotFound
	default:
		return nil
	}
}

// deleteSubscriptionSQL builds a DELETE statement targeting the subscription
// row whose UserID equals id.
func (p Postgres) deleteSubscriptionSQL(id uuid.ID) *pan.Query {
	var sub Subscription
	table := pan.GetTableName(sub)
	userIDColumn := pan.GetUnquotedColumn(sub, "UserID")

	stmt := pan.New(pan.POSTGRES, "DELETE FROM "+table)
	stmt.IncludeWhere()
	stmt.Include(userIDColumn+" = ?", id)
	return stmt.FlushExpressions(" ")
}

// deleteSubscription removes the subscription whose UserID is id. It returns
// ErrSubscriptionNotFound when the statement affected no rows, and passes
// any execution or RowsAffected error through unchanged.
func (p Postgres) deleteSubscription(id uuid.ID) error {
	stmt := p.deleteSubscriptionSQL(id)
	result, execErr := p.Exec(stmt.String(), stmt.Args...)
	if execErr != nil {
		return execErr
	}

	affected, countErr := result.RowsAffected()
	switch {
	case countErr != nil:
		return countErr
	case affected < 1:
		return ErrSubscriptionNotFound
	default:
		return nil
	}
}

// listSubscriptionsLastChargedBeforeSQL builds a SELECT of every pan-mapped
// Subscription column for rows whose LastCharged is strictly earlier than
// cutoff. Results are ordered by LastCharged ascending, oldest first.
func (p Postgres) listSubscriptionsLastChargedBeforeSQL(cutoff time.Time) *pan.Query {
	var sub Subscription
	columns, _ := pan.GetFields(sub)
	lastCharged := pan.GetUnquotedColumn(sub, "LastCharged")
	selectClause := "SELECT " + pan.QueryList(columns) + " FROM " + pan.GetTableName(sub)

	stmt := pan.New(pan.POSTGRES, selectClause)
	stmt.IncludeWhere()
	stmt.Include(lastCharged+" < ?", cutoff)
	stmt.IncludeOrder(lastCharged + " ASC")
	return stmt.FlushExpressions(" ")
}

// listSubscriptionsLastChargedBefore returns every subscription whose
// LastCharged is strictly earlier than cutoff, ordered oldest first.
//
// If an error occurs while decoding or iterating rows, the subscriptions
// decoded so far are returned alongside the error.
func (p Postgres) listSubscriptionsLastChargedBefore(cutoff time.Time) ([]Subscription, error) {
	var results []Subscription
	query := p.listSubscriptionsLastChargedBeforeSQL(cutoff)
	rows, err := p.Query(query.String(), query.Args...)
	if err != nil {
		return results, err
	}
	// rows.Next only auto-closes once the result set is exhausted. Closing
	// explicitly releases the connection even when the loop exits early on
	// an Unmarshal error.
	defer rows.Close()
	for rows.Next() {
		var sub Subscription
		if err := pan.Unmarshal(rows, &sub); err != nil {
			return results, err
		}
		results = append(results, sub)
	}
	if err := rows.Err(); err != nil {
		return results, err
	}
	return results, nil
}

// getSubscriptionsSQL builds a SELECT of every pan-mapped Subscription column
// for rows whose UserID is one of ids. Each ID is bound as its own
// placeholder inside an IN (...) list.
//
// NOTE(review): with an empty ids slice the IN list has zero placeholders,
// which is presumably invalid SQL. Callers are expected to guard against it.
func (p Postgres) getSubscriptionsSQL(ids []uuid.ID) *pan.Query {
	var sub Subscription
	columns, _ := pan.GetFields(sub)

	// pan takes placeholder arguments as variadic interface{} values.
	args := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		args = append(args, id)
	}

	stmt := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(columns)+" FROM "+pan.GetTableName(sub))
	stmt.IncludeWhere()
	stmt.Include(pan.GetUnquotedColumn(sub, "UserID") + " IN")
	stmt.Include("("+pan.VariableList(len(args))+")", args...)
	return stmt.FlushExpressions(" ")
}

// getSubscriptions loads the subscriptions whose UserID is in ids. The result
// is keyed by the UserID's String() form. IDs with no matching row are
// simply absent from the map.
//
// It returns ErrNoSubscriptionID (with an empty, non-nil map) when ids is
// empty. If an error occurs while decoding or iterating rows, the entries
// decoded so far are returned alongside the error.
func (p Postgres) getSubscriptions(ids []uuid.ID) (map[string]Subscription, error) {
	results := make(map[string]Subscription, len(ids))
	if len(ids) < 1 {
		return results, ErrNoSubscriptionID
	}
	query := p.getSubscriptionsSQL(ids)
	rows, err := p.Query(query.String(), query.Args...)
	if err != nil {
		return results, err
	}
	// Release the connection even if we return early on an Unmarshal error.
	defer rows.Close()
	for rows.Next() {
		var sub Subscription
		if err := pan.Unmarshal(rows, &sub); err != nil {
			return results, err
		}
		results[sub.UserID.String()] = sub
	}
	if err := rows.Err(); err != nil {
		return results, err
	}
	return results, nil
}

// getSubscriptionStatsSQL builds a query returning, in this column order,
// COUNT(*), SUM(Amount) and AVG(Amount) over the whole subscriptions table.
func (p Postgres) getSubscriptionStatsSQL() *pan.Query {
	var sub Subscription
	table := pan.GetTableName(sub)
	amount := pan.GetUnquotedColumn(sub, "Amount")
	aggregates := "COUNT(*), SUM(" + amount + "), AVG(" + amount + ")"

	stmt := pan.New(pan.POSTGRES, "SELECT")
	stmt.Include(aggregates)
	stmt.Include("FROM " + table)
	return stmt.FlushExpressions(" ")
}

// getSubscriptionStats returns aggregate figures over all stored
// subscriptions: the row count, the sum of Amount, and the mean Amount.
//
// Columns that come back NULL leave the corresponding stats field at its
// zero value. In PostgreSQL, SUM and AVG over an empty table are NULL.
func (p Postgres) getSubscriptionStats() (SubscriptionStats, error) {
	query := p.getSubscriptionStatsSQL()
	rows, err := p.Query(query.String(), query.Args...)
	if err != nil {
		return SubscriptionStats{}, err
	}
	defer rows.Close()
	var stats SubscriptionStats
	for rows.Next() {
		var number, total sql.NullInt64
		var mean sql.NullFloat64
		// Scan requires pointer destinations. Passing the Null* values
		// directly makes Scan fail on every call.
		if err := rows.Scan(&number, &total, &mean); err != nil {
			return stats, err
		}
		if number.Valid {
			stats.Number = number.Int64
		}
		if total.Valid {
			stats.TotalAmount = total.Int64
		}
		if mean.Valid {
			stats.MeanAmount = mean.Float64
		}
	}
	if err := rows.Err(); err != nil {
		return stats, err
	}
	return stats, nil
}
