package auth

import (
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"strings"
	"time"

	"code.secondbit.org/uuid.hg"

	"github.com/dgrijalva/jwt-go"
)

const (
	// defaultTokenExpiration is the lifetime of a generated access token, in
	// seconds. It is multiplied by time.Second when computing the "exp" claim.
	defaultTokenExpiration = 15 * 60 // fifteen minutes
)

func init() {
	RegisterGrantType("refresh_token", GrantType{
		Validate:      refreshTokenValidate,
		Invalidate:    refreshTokenInvalidate,
		IssuesRefresh: true,
		ReturnToken:   RenderJSONToken,
		AuditString:   refreshTokenAuditString,
	})
}

// ErrNoTokenStore is returned when a Context is asked to operate on Tokens
// before a tokenStore has been configured for it.
var ErrNoTokenStore = errors.New("no tokenStore was specified for the Context")

// ErrTokenNotFound is returned when a requested Token does not exist in the
// tokenStore.
var ErrTokenNotFound = errors.New("token not found in tokenStore")

// ErrTokenAlreadyExists is returned when saving a Token whose AccessToken is
// already present in the tokenStore.
var ErrTokenAlreadyExists = errors.New("token already exists in tokenStore")

// Token represents an access and/or refresh token that the Client can use to access user data
// or obtain a new access token.
type Token struct {
	// AccessToken is the key under which the memstore stores this Token;
	// saving a second Token with the same AccessToken fails.
	AccessToken string
	// RefreshToken is optional; an empty string means no refresh token was
	// issued. When set, the memstore indexes it so the Token can be looked up by it.
	RefreshToken string
	// Created is the issue time. GenerateAccessToken derives the iat, nbf and
	// exp claims from it.
	Created time.Time
	// CreatedFrom records where the Token originated.
	// NOTE(review): format/meaning is not visible in this file — confirm.
	CreatedFrom string
	// ExpiresIn is the advertised lifetime (presumably seconds — confirm).
	// GenerateAccessToken does not read it; it uses defaultTokenExpiration.
	ExpiresIn int32
	// TokenType is the OAuth2 token type (presumably "bearer" — confirm).
	TokenType string
	// Scopes are the granted scopes, emitted as the space-separated "scope" claim.
	Scopes Scopes
	// ProfileID is the profile the Token acts for; emitted as the "sub" claim
	// and used to index Tokens by profile.
	ProfileID uuid.ID
	// ClientID is the client the Token was issued to; emitted as the "iss" claim
	// and required to match the requesting client when refreshing.
	ClientID uuid.ID
	// Revoked is set by the tokenStore revoke* methods; a revoked Token cannot
	// be used for the refresh_token grant.
	Revoked bool
}

// GenerateAccessToken builds a JWT describing t, signs it with HS256 using
// privateKey as the HMAC secret, and returns the compact serialization.
//
// Claims written:
//   - iss:   t.ClientID
//   - sub:   t.ProfileID
//   - iat:   t.Created
//   - nbf:   two minutes before t.Created (presumably to tolerate clock skew)
//   - exp:   defaultTokenExpiration seconds after t.Created; t.ExpiresIn is not consulted
//   - scope: t.Scopes joined with single spaces
//
// Any error comes from the signing step.
func (t Token) GenerateAccessToken(privateKey []byte) (string, error) {
	issuedAt := t.Created
	lifetime := defaultTokenExpiration * time.Second

	signed := jwt.New(jwt.SigningMethodHS256)
	claims := signed.Claims
	claims["iss"] = t.ClientID
	claims["sub"] = t.ProfileID
	claims["iat"] = issuedAt.Unix()
	claims["nbf"] = issuedAt.Add(-2 * time.Minute).Unix()
	claims["exp"] = issuedAt.Add(lifetime).Unix()
	claims["scope"] = strings.Join(t.Scopes.Strings(), " ")

	return signed.SignedString(privateKey)
}

// BUG(paddy): Now that access tokens are generated and carry meaning, the refresh token (not the access token) should be the primary key for stored Tokens.

// tokenStore persists Tokens and supports lookup and revocation.
type tokenStore interface {
	// getToken returns the Token identified by token. When refresh is true,
	// token is treated as a refresh token rather than an access token.
	getToken(token string, refresh bool) (Token, error)
	// getTokensByProfileID returns up to num Tokens for profileID, skipping
	// the first offset.
	getTokensByProfileID(profileID uuid.ID, num, offset int) ([]Token, error)

	// saveToken stores a new Token.
	saveToken(token Token) error

	// revokeToken marks a single Token revoked. The memstore implementation
	// interprets token as a refresh token.
	revokeToken(token string) error
	// revokeTokensByProfileID marks every Token for profileID revoked.
	revokeTokensByProfileID(profileID uuid.ID) error
	// revokeTokensByClientID marks every Token issued to clientID revoked.
	revokeTokensByClientID(clientID uuid.ID) error
}

// getToken returns a copy of the stored Token. If refresh is true, token is a
// refresh token and is first translated to its access token via
// lookupTokenByRefresh, whose error is returned unchanged. A missing access
// token yields ErrTokenNotFound.
func (m *memstore) getToken(token string, refresh bool) (Token, error) {
	key := token
	if refresh {
		access, err := m.lookupTokenByRefresh(token)
		if err != nil {
			return Token{}, err
		}
		key = access
	}

	m.tokenLock.RLock()
	defer m.tokenLock.RUnlock()
	if found, ok := m.tokens[key]; ok {
		return found, nil
	}
	return Token{}, ErrTokenNotFound
}

// saveToken stores token keyed by its AccessToken and updates the secondary
// indexes: refresh token -> access token (only when RefreshToken is non-empty)
// and profile ID -> list of access tokens.
//
// It returns ErrTokenAlreadyExists, without modifying anything, if a Token with
// the same AccessToken is already stored. All writes happen under tokenLock.
func (m *memstore) saveToken(token Token) error {
	m.tokenLock.Lock()
	defer m.tokenLock.Unlock()
	if _, exists := m.tokens[token.AccessToken]; exists {
		return ErrTokenAlreadyExists
	}
	m.tokens[token.AccessToken] = token
	if token.RefreshToken != "" {
		m.refreshTokenLookup[token.RefreshToken] = token.AccessToken
	}
	// Appending to a missing key appends to a nil slice, so no existence
	// check is needed before updating the per-profile index.
	profile := token.ProfileID.String()
	m.profileTokenLookup[profile] = append(m.profileTokenLookup[profile], token.AccessToken)
	return nil
}

// revokeToken marks a stored Token revoked. The argument is a refresh token:
// it is translated to an access token via lookupTokenByRefresh, whose error is
// returned unchanged. ErrTokenNotFound is returned if the resolved access
// token is not stored.
func (m *memstore) revokeToken(token string) error {
	access, err := m.lookupTokenByRefresh(token)
	if err != nil {
		return err
	}

	m.tokenLock.Lock()
	defer m.tokenLock.Unlock()
	existing, found := m.tokens[access]
	if !found {
		return ErrTokenNotFound
	}
	existing.Revoked = true
	m.tokens[access] = existing
	return nil
}

// revokeTokensByProfileID marks every Token indexed under profileID revoked.
//
// Errors from lookupTokensByProfileID are returned unchanged. If the profile
// has no indexed Tokens, ErrProfileNotFound is returned.
func (m *memstore) revokeTokensByProfileID(profileID uuid.ID) error {
	ids, err := m.lookupTokensByProfileID(profileID.String())
	if err != nil {
		return err
	}
	if len(ids) < 1 {
		return ErrProfileNotFound
	}
	m.tokenLock.Lock()
	defer m.tokenLock.Unlock()
	for _, id := range ids {
		token, ok := m.tokens[id]
		if !ok {
			// Skip index entries with no stored Token. Writing back the
			// zero value here would fabricate a revoked Token under that key.
			continue
		}
		token.Revoked = true
		m.tokens[id] = token
	}
	return nil
}

// revokeTokensByClientID marks every stored Token whose ClientID equals
// clientID revoked. It scans all Tokens under tokenLock and always returns nil.
func (m *memstore) revokeTokensByClientID(clientID uuid.ID) error {
	m.tokenLock.Lock()
	defer m.tokenLock.Unlock()
	// Overwriting existing keys while ranging over a map is permitted.
	for key := range m.tokens {
		if t := m.tokens[key]; t.ClientID.Equal(clientID) {
			t.Revoked = true
			m.tokens[key] = t
		}
	}
	return nil
}

// getTokensByProfileID returns one page of the Tokens indexed under profileID:
// at most num Tokens, starting at index offset of the profile's token list.
//
// If num is not positive, offset is negative, or offset is at or beyond the end
// of the list, it returns an empty (non-nil) slice and a nil error. These cases
// previously could panic on an invalid slice expression. Errors from
// lookupTokensByProfileID or getToken are returned unchanged, with an empty slice.
func (m *memstore) getTokensByProfileID(profileID uuid.ID, num, offset int) ([]Token, error) {
	ids, err := m.lookupTokensByProfileID(profileID.String())
	if err != nil {
		return []Token{}, err
	}
	if num <= 0 || offset < 0 || offset >= len(ids) {
		return []Token{}, nil
	}
	// Compare num with the remaining length rather than computing offset+num
	// up front, which could overflow for a very large num.
	end := len(ids)
	if num < end-offset {
		end = offset + num
	}
	ids = ids[offset:end]

	tokens := make([]Token, 0, len(ids))
	for _, id := range ids {
		token, err := m.getToken(id, false)
		if err != nil {
			return []Token{}, err
		}
		tokens = append(tokens, token)
	}
	return tokens, nil
}

// refreshTokenValidate is the Validate hook for the "refresh_token" grant type.
//
// It reads the refresh_token POST form value and checks, in order:
//   - the value is present, else 400 with "invalid_request";
//   - the requesting client resolves via getClientAuth. On failure it returns
//     immediately; presumably getClientAuth has already written the response.
//   - the refresh token exists, else 400 "invalid_grant". Other store errors
//     are logged and answered with 500 "server_error";
//   - the Token was issued to the requesting client and is not revoked, else
//     400 "invalid_grant".
//
// On success it returns the Token's scopes and profile ID with valid == true.
// On any failure an error response has been written and valid is false.
func refreshTokenValidate(w http.ResponseWriter, r *http.Request, context Context) (scopes Scopes, profileID uuid.ID, valid bool) {
	enc := json.NewEncoder(w)
	refresh := r.PostFormValue("refresh_token")
	if refresh == "" {
		w.WriteHeader(http.StatusBadRequest)
		renderJSONError(enc, "invalid_request")
		return
	}
	// Resolve the client before consulting the token store. Otherwise a caller
	// without client credentials could tell unknown refresh tokens
	// (invalid_grant) from existing ones (client-auth failure).
	clientID, _, ok := getClientAuth(w, r, true)
	if !ok {
		return
	}
	token, err := context.GetToken(refresh, true)
	if err != nil {
		if err == ErrTokenNotFound {
			w.WriteHeader(http.StatusBadRequest)
			renderJSONError(enc, "invalid_grant")
			return
		}
		log.Println("Error exchanging refresh token:", err)
		w.WriteHeader(http.StatusInternalServerError)
		renderJSONError(enc, "server_error")
		return
	}
	// Both cases get the same response, so the caller learns nothing about
	// which check failed.
	if !token.ClientID.Equal(clientID) || token.Revoked {
		w.WriteHeader(http.StatusBadRequest)
		renderJSONError(enc, "invalid_grant")
		return
	}
	return token.Scopes, token.ProfileID, true
}

// refreshTokenInvalidate is the Invalidate hook for the "refresh_token" grant
// type. It revokes the Token identified by the request's refresh_token POST
// form value via context.RevokeToken, returning that call's error. It returns
// ErrTokenNotFound when the form value is missing or empty.
func refreshTokenInvalidate(r *http.Request, context Context) error {
	if refresh := r.PostFormValue("refresh_token"); refresh != "" {
		return context.RevokeToken(refresh)
	}
	return ErrTokenNotFound
}

// refreshTokenAuditString returns the audit label for a refresh_token grant
// request: "refresh_token:" followed by the raw refresh_token POST form value.
//
// NOTE(review): this embeds the bearer refresh token verbatim in the audit
// string. Consider whether audit storage should hold a hash instead; confirm
// what downstream audit consumers rely on before changing the format.
func refreshTokenAuditString(r *http.Request) string {
	presented := r.PostFormValue("refresh_token")
	return "refresh_token:" + presented
}
