package auth

import (
	"code.secondbit.org/uuid.hg"

	"github.com/lib/pq"
	"github.com/secondbit/pan"
)

// tokenScope is one row of the scopes_tokens table (see its GetSQLTableName),
// linking a token to a single scope granted to it. saveToken fills Token with
// the owning Token's AccessToken, and getTokensByProfileID groups rows by it.
//
// NOTE(review): pan.GetFields presumably derives the column list from the
// field order here, so reorder or add fields with care — confirm against pan.
type tokenScope struct {
	Token string // access token this scope belongs to
	Scope string // the granted scope
}

// GetSQLTableName names the table that holds tokenScope rows; presumably
// consulted by pan.GetTableName when building queries.
func (tokenScope) GetSQLTableName() string { return "scopes_tokens" }

// GetSQLTableName names the table that holds Token rows; presumably
// consulted by pan.GetTableName when building queries.
func (Token) GetSQLTableName() string { return "tokens" }

// getTokenSQL builds a SELECT of every Token column from the tokens table.
// The lookup is on the AccessToken column, or on the RefreshToken column
// when refresh is true, with token bound as the only argument.
func (p *postgres) getTokenSQL(token string, refresh bool) *pan.Query {
	var model Token
	lookupColumn := "AccessToken"
	if refresh {
		lookupColumn = "RefreshToken"
	}
	columns, _ := pan.GetFields(model)
	q := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(columns)+" FROM "+pan.GetTableName(model))
	q.IncludeWhere()
	q.Include(pan.GetUnquotedColumn(model, lookupColumn)+" = ?", token)
	return q.FlushExpressions(" ")
}

// getToken loads a single Token and fills in its Scopes from the
// scopes_tokens table. The token is matched on its access token, or on its
// refresh token when refresh is true. If several rows match, the last one
// scanned is kept.
//
// It returns ErrTokenNotFound when no token row matches. On any error the
// returned Token is the zero value. Both result sets are closed via defer,
// so an early return (e.g. an unmarshal failure) does not leak the rows or
// their database connection.
func (p *postgres) getToken(token string, refresh bool) (Token, error) {
	query := p.getTokenSQL(token, refresh)
	tokenRows, err := p.db.Query(query.String(), query.Args...)
	if err != nil {
		return Token{}, err
	}
	defer tokenRows.Close()

	var t Token
	var found bool
	for tokenRows.Next() {
		if err := pan.Unmarshal(tokenRows, &t); err != nil {
			return Token{}, err
		}
		found = true
	}
	if err := tokenRows.Err(); err != nil {
		return Token{}, err
	}
	if !found {
		return Token{}, ErrTokenNotFound
	}

	query = p.getTokenScopesSQL([]string{t.AccessToken})
	scopeRows, err := p.db.Query(query.String(), query.Args...)
	if err != nil {
		return Token{}, err
	}
	defer scopeRows.Close()

	for scopeRows.Next() {
		var ts tokenScope
		if err := pan.Unmarshal(scopeRows, &ts); err != nil {
			return Token{}, err
		}
		t.Scopes = append(t.Scopes, ts.Scope)
	}
	if err := scopeRows.Err(); err != nil {
		return Token{}, err
	}
	return t, nil
}

// saveTokenSQL builds an INSERT of every column of token into its table,
// binding one placeholder per value returned by pan.GetFields.
func (p *postgres) saveTokenSQL(token Token) *pan.Query {
	columns, args := pan.GetFields(token)
	columnList := "(" + pan.QueryList(columns) + ")"
	placeholders := "(" + pan.VariableList(len(args)) + ")"

	q := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(token))
	q.Include(columnList)
	q.Include("VALUES")
	q.Include(placeholders, args...)
	return q.FlushExpressions(" ")
}

// saveTokenScopesSQL builds a single multi-row INSERT for all of ts.
//
// ts must be non-empty: its first element supplies the column list and table
// name, and indexing an empty slice panics (saveToken guards against this).
func (p *postgres) saveTokenScopesSQL(ts []tokenScope) *pan.Query {
	head := ts[0]
	columns, _ := pan.GetFields(head)

	q := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(head))
	q.Include("(" + pan.QueryList(columns) + ")")
	q.Include("VALUES")
	// The header is flushed with a space separator; the value tuples added
	// below are then flushed together with ", " between them.
	q.FlushExpressions(" ")

	for i := range ts {
		_, args := pan.GetFields(ts[i])
		q.Include("("+pan.VariableList(len(args))+")", args...)
	}
	return q.FlushExpressions(", ")
}

// saveToken inserts token into the tokens table and then, if it has any
// scopes, inserts one scopes_tokens row per scope. A violation of the
// tokens_pkey constraint is reported as ErrTokenAlreadyExists.
//
// NOTE(review): the two inserts are separate Exec calls, not a transaction,
// so if the scope insert fails the token row has already been written.
func (p *postgres) saveToken(token Token) error {
	insert := p.saveTokenSQL(token)
	_, err := p.db.Exec(insert.String(), insert.Args...)
	if pqErr, ok := err.(*pq.Error); ok && pqErr.Constraint == "tokens_pkey" {
		return ErrTokenAlreadyExists
	}
	if err != nil {
		return err
	}
	if len(token.Scopes) == 0 {
		return nil
	}

	links := make([]tokenScope, 0, len(token.Scopes))
	for _, scope := range token.Scopes {
		links = append(links, tokenScope{Token: token.AccessToken, Scope: scope})
	}
	insert = p.saveTokenScopesSQL(links)
	_, err = p.db.Exec(insert.String(), insert.Args...)
	return err
}

// revokeTokenSQL builds an UPDATE that sets Revoked to true on the token
// matching token. The match is on AccessToken, or on RefreshToken when
// refresh is true.
func (p *postgres) revokeTokenSQL(token string, refresh bool) *pan.Query {
	var model Token
	lookupColumn := "AccessToken"
	if refresh {
		lookupColumn = "RefreshToken"
	}
	q := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(model)+" SET ")
	q.Include(pan.GetUnquotedColumn(model, "Revoked")+" = ?", true)
	q.IncludeWhere()
	q.Include(pan.GetUnquotedColumn(model, lookupColumn)+" = ?", token)
	return q.FlushExpressions(" ")
}

// revokeToken marks the matching token as revoked. The match is on the
// access token, or on the refresh token when refresh is true. It returns
// ErrTokenNotFound when the UPDATE affects no rows.
func (p *postgres) revokeToken(token string, refresh bool) error {
	q := p.revokeTokenSQL(token, refresh)
	result, err := p.db.Exec(q.String(), q.Args...)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return err
	case affected == 0:
		return ErrTokenNotFound
	default:
		return nil
	}
}

// getTokensByProfileIDSQL builds a paginated SELECT of every Token column for
// tokens owned by profileID, returning at most num rows after skipping offset.
//
// NOTE(review): no ORDER BY is emitted alongside LIMIT/OFFSET, so PostgreSQL
// does not guarantee stable pages across calls.
func (p *postgres) getTokensByProfileIDSQL(profileID uuid.ID, num, offset int) *pan.Query {
	var model Token
	columns, _ := pan.GetFields(model)
	selectClause := "SELECT " + pan.QueryList(columns) + " FROM " + pan.GetTableName(model)

	q := pan.New(pan.POSTGRES, selectClause)
	q.IncludeWhere()
	q.Include(pan.GetUnquotedColumn(model, "ProfileID")+" = ?", profileID)
	q.IncludeLimit(int64(num))
	q.IncludeOffset(int64(offset))
	return q.FlushExpressions(" ")
}

// getTokenScopesSQL builds a SELECT of the scopes_tokens rows whose Token
// column is any of tokens. Each token is bound as a separate placeholder in
// the IN list. Callers pass a non-empty slice; with none, the IN list would
// presumably be empty, which PostgreSQL rejects.
func (p *postgres) getTokenScopesSQL(tokens []string) *pan.Query {
	var model tokenScope
	args := make([]interface{}, 0, len(tokens))
	for _, tok := range tokens {
		args = append(args, tok)
	}

	columns, _ := pan.GetFields(model)
	q := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(columns)+" FROM "+pan.GetTableName(model))
	q.IncludeWhere()
	q.Include(pan.GetUnquotedColumn(model, "Token")+" IN ("+pan.VariableList(len(args))+")", args...)
	return q.FlushExpressions(" ")
}

// getTokensByProfileID returns one page of the tokens owned by profileID (at
// most num, skipping offset) with each token's Scopes filled in. Scopes are
// loaded with a single query for the whole page.
//
// On any error it returns a nil slice. Both result sets are closed via
// defer, so an early return (e.g. an unmarshal failure) does not leak the
// rows or their database connection.
func (p *postgres) getTokensByProfileID(profileID uuid.ID, num, offset int) ([]Token, error) {
	query := p.getTokensByProfileIDSQL(profileID, num, offset)
	tokenRows, err := p.db.Query(query.String(), query.Args...)
	if err != nil {
		return nil, err
	}
	defer tokenRows.Close()

	var tokens []Token
	for tokenRows.Next() {
		var token Token
		if err := pan.Unmarshal(tokenRows, &token); err != nil {
			return nil, err
		}
		tokens = append(tokens, token)
	}
	if err := tokenRows.Err(); err != nil {
		return nil, err
	}
	if len(tokens) == 0 {
		// Skip the scopes query: with no IDs its IN list would be empty.
		return tokens, nil
	}

	tokenIDs := make([]string, len(tokens))
	for i := range tokens {
		tokenIDs[i] = tokens[i].AccessToken
	}

	query = p.getTokenScopesSQL(tokenIDs)
	scopeRows, err := p.db.Query(query.String(), query.Args...)
	if err != nil {
		return nil, err
	}
	defer scopeRows.Close()

	scopes := make(map[string][]string, len(tokens))
	for scopeRows.Next() {
		var ts tokenScope
		if err := pan.Unmarshal(scopeRows, &ts); err != nil {
			return nil, err
		}
		scopes[ts.Token] = append(scopes[ts.Token], ts.Scope)
	}
	if err := scopeRows.Err(); err != nil {
		return nil, err
	}

	for i := range tokens {
		tokens[i].Scopes = scopes[tokens[i].AccessToken]
	}
	return tokens, nil
}
