package auth

import (
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"sort"
	"time"

	"code.secondbit.org/pass.hg"
	"code.secondbit.org/uuid.hg"
	"github.com/gorilla/mux"
)

// loginTemplateName is the name of the template CreateSessionHandler renders
// for the login page.
const loginTemplateName = "login"

func init() {
	RegisterGrantType("password", GrantType{
		Validate:      credentialsValidate,
		Invalidate:    nil,
		IssuesRefresh: true,
		ReturnToken:   RenderJSONToken,
	})
}

var (
	// ErrNoSessionStore is returned when a Context tries to act on a sessionStore without setting one first.
	ErrNoSessionStore = errors.New("no sessionStore was specified for the Context")
	// ErrSessionNotFound is returned when a Session is requested but not found in the sessionStore.
	ErrSessionNotFound = errors.New("session not found in sessionStore")
	// ErrInvalidSession is returned when a Session is specified but is not valid.
	ErrInvalidSession = errors.New("session is not valid")
	// ErrSessionAlreadyExists is returned when a sessionStore tries to store a Session with an ID that already exists in the sessionStore.
	ErrSessionAlreadyExists = errors.New("session already exists")

	// passphraseSchemes maps a Profile's PassphraseScheme value to the
	// functions implementing that scheme. authenticate rejects a profile whose
	// scheme has no entry here with ErrInvalidPassphraseScheme. Scheme 1 is
	// the pbkdf2sha256* family with hex-encoded hash and salt.
	passphraseSchemes = map[int]passphraseScheme{
		1: {
			check:               pbkdf2sha256check,
			create:              pbkdf2sha256create,
			calculateIterations: pbkdf2sha256calc,
		},
	}
)

// passphraseScheme bundles the functions that implement one way of hashing
// and verifying passphrases. Instances are registered in passphraseSchemes,
// keyed by the Profile.PassphraseScheme value that selects them.
//
//   - check reports whether passphrase matches the hash stored on profile.
//   - create hashes passphrase with the given iteration count and returns the
//     hash and salt in the form stored on a Profile (hex for scheme 1).
//   - calculateIterations returns an iteration count; presumably the value
//     to pass to create for new hashes (TODO confirm against callers).
type passphraseScheme struct {
	check               func(profile Profile, passphrase string) (bool, error)
	create              func(passphrase string, iterations int) (result, salt string, err error)
	calculateIterations func() (int, error)
}

// Session represents a user's authenticated session, associating it with a profile
// and some audit data.
//
// CreateSessionHandler fills every field at login:
//   - ID is a fresh UUID string, also stored as the auth cookie value.
//   - IP is the X-Forwarded-For header, falling back to RemoteAddr.
//   - UserAgent is the request's User-Agent.
//   - ProfileID is the authenticated profile.
//   - Login is the submitted login name.
//   - Created is the login time.
//   - Active is true.
//
// checkCookie rejects sessions whose Active flag is false.
type Session struct {
	ID        string
	IP        string
	UserAgent string
	ProfileID uuid.ID
	Login     string
	Created   time.Time
	Active    bool
}

// sortedSessions implements sort.Interface over Sessions, ordering them
// newest first by their Created timestamp.
type sortedSessions []Session

// Len returns the number of Sessions in the slice.
func (s sortedSessions) Len() int { return len(s) }

// Less reports whether the Session at i was created later than the one at j,
// which places more recent Sessions first.
func (s sortedSessions) Less(i, j int) bool { return s[j].Created.Before(s[i].Created) }

// Swap exchanges the Sessions at positions i and j.
func (s sortedSessions) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

// sessionStore abstracts persistence of Sessions. The in-memory memstore
// implementation returns ErrSessionAlreadyExists from createSession when the
// ID is already stored, and ErrSessionNotFound from getSession and
// removeSession when it is not.
type sessionStore interface {
	createSession(session Session) error
	getSession(id string) (Session, error)
	removeSession(id string) error
	// listSessions returns at most num Sessions, newest first in the memstore
	// implementation. A non-nil profile restricts results to that profile, and
	// a non-zero before restricts them to Sessions created no later than it.
	listSessions(profile uuid.ID, before time.Time, num int64) ([]Session, error)
}

// createSession stores session under its ID. It never overwrites: if the ID
// is already present the store is left untouched and ErrSessionAlreadyExists
// is returned.
func (m *memstore) createSession(session Session) error {
	m.sessionLock.Lock()
	defer m.sessionLock.Unlock()

	_, exists := m.sessions[session.ID]
	if !exists {
		m.sessions[session.ID] = session
		return nil
	}
	return ErrSessionAlreadyExists
}

// getSession returns the stored Session with the given ID, or
// ErrSessionNotFound when no such Session exists. It holds the read lock for
// the duration of the lookup.
func (m *memstore) getSession(id string) (Session, error) {
	m.sessionLock.RLock()
	defer m.sessionLock.RUnlock()
	// A single comma-ok lookup both tests presence and fetches the value.
	session, ok := m.sessions[id]
	if !ok {
		return Session{}, ErrSessionNotFound
	}
	return session, nil
}

// removeSession deletes the Session with the given ID, returning
// ErrSessionNotFound when there is nothing to delete.
func (m *memstore) removeSession(id string) error {
	m.sessionLock.Lock()
	defer m.sessionLock.Unlock()

	if _, found := m.sessions[id]; found {
		delete(m.sessions, id)
		return nil
	}
	return ErrSessionNotFound
}

// listSessions returns up to num Sessions, newest first.
//
// When profile is non-nil only that profile's Sessions are considered. When
// before is non-zero only Sessions created at or before that instant are
// considered. A non-positive num yields an empty (non-nil) slice. The error
// is always nil for the in-memory store.
func (m *memstore) listSessions(profile uuid.ID, before time.Time, num int64) ([]Session, error) {
	m.sessionLock.RLock()
	defer m.sessionLock.RUnlock()
	res := []Session{}
	if num <= 0 {
		return res, nil
	}
	for _, session := range m.sessions {
		if profile != nil && !profile.Equal(session.ProfileID) {
			continue
		}
		if !before.IsZero() && session.Created.After(before) {
			continue
		}
		res = append(res, session)
	}
	// Map iteration order is random, so the limit must be applied only after
	// sorting. Stopping the loop at num matches would return an arbitrary
	// subset instead of the num most recent Sessions.
	sort.Sort(sortedSessions(res))
	if int64(len(res)) > num {
		res = res[:num]
	}
	return res, nil
}

// RegisterSessionHandlers adds the session endpoints to the passed router.
// Currently this is only "/login", served by CreateSessionHandler wrapped
// with the given Context.
func RegisterSessionHandlers(r *mux.Router, context Context) {
	loginHandler := wrap(context, CreateSessionHandler)
	r.Handle("/login", loginHandler)
}

// checkCookie resolves the Session referenced by the request's auth cookie.
//
// It returns ErrNoSession when the cookie is absent, and ErrInvalidSession
// when the referenced Session is unknown or no longer Active. Any other
// error, from reading the cookie (which is also logged) or from
// context.GetSession, is returned unchanged.
func checkCookie(r *http.Request, context Context) (Session, error) {
	cookie, err := r.Cookie(authCookieName)
	switch {
	case err == http.ErrNoCookie:
		return Session{}, ErrNoSession
	case err != nil:
		log.Println(err)
		return Session{}, err
	}

	sess, err := context.GetSession(cookie.Value)
	switch {
	case err == ErrSessionNotFound:
		return Session{}, ErrInvalidSession
	case err != nil:
		return Session{}, err
	case !sess.Active:
		return Session{}, ErrInvalidSession
	}
	return sess, nil
}

// buildLoginRedirect returns the configured login URI with a "from" query
// parameter set to the current request's URL. CreateSessionHandler reads
// "from" to pick its post-login redirect. It returns "" when no login URI is
// configured.
func buildLoginRedirect(r *http.Request, context Context) string {
	base := context.loginURI
	if base == nil {
		return ""
	}
	// Work on a copy so the shared configured URI is never modified.
	target := *base
	params := target.Query()
	params.Set("from", r.URL.String())
	target.RawQuery = params.Encode()
	return target.String()
}

// pbkdf2sha256check is passphrase scheme 1's check function. It hex-decodes
// the profile's stored Passphrase and Salt, recomputes the hash of passphrase
// via pass.Check with SHA-256 and profile.Iterations, and compares it with
// pass.Compare. Decoding errors are returned as-is. A mismatch yields
// (false, ErrIncorrectAuth).
func pbkdf2sha256check(profile Profile, passphrase string) (bool, error) {
	want, err := hex.DecodeString(profile.Passphrase)
	if err != nil {
		return false, err
	}
	salt, err := hex.DecodeString(profile.Salt)
	if err != nil {
		return false, err
	}
	got := pass.Check(sha256.New, profile.Iterations, []byte(passphrase), salt)
	if pass.Compare(got, want) {
		return true, nil
	}
	return false, ErrIncorrectAuth
}

// pbkdf2sha256create is passphrase scheme 1's create function. It hashes
// passphrase with pass.Create using SHA-256 and iters iterations, and returns
// the hash and the salt returned by pass.Create, both hex-encoded (the form
// pbkdf2sha256check decodes). On failure both strings are empty.
func pbkdf2sha256create(passphrase string, iters int) (result, salt string, err error) {
	hash, rawSalt, createErr := pass.Create(sha256.New, iters, []byte(passphrase))
	if createErr != nil {
		return "", "", createErr
	}
	return hex.EncodeToString(hash), hex.EncodeToString(rawSalt), nil
}

// pbkdf2sha256calc is passphrase scheme 1's calculateIterations function. It
// delegates to pass.CalculateIterations with SHA-256 and passes its result
// and error through unchanged.
func pbkdf2sha256calc() (int, error) {
	iterations, err := pass.CalculateIterations(sha256.New)
	return iterations, err
}

// authenticate verifies passphrase for the profile registered under the
// login user, returning that Profile on success.
//
// Errors:
//   - ErrIncorrectAuth for an unknown login (ErrProfileNotFound or
//     ErrLoginNotFound from the lookup) or a passphrase that does not match.
//   - ErrProfileCompromised when the profile is flagged Compromised.
//   - ErrProfileLocked, returned together with the profile, while
//     LockedUntil is in the future.
//   - ErrInvalidPassphraseScheme when the profile's PassphraseScheme has no
//     entry in passphraseSchemes.
//   - Any other error from the profile lookup or the scheme's check function.
//
// NOTE(review): an unknown login returns before any hashing is done, so
// response timing may reveal whether a login exists.
func authenticate(user, passphrase string, context Context) (Profile, error) {
	profile, err := context.GetProfileByLogin(user)
	if err != nil {
		if err == ErrProfileNotFound || err == ErrLoginNotFound {
			return Profile{}, ErrIncorrectAuth
		}
		return Profile{}, err
	}
	if profile.Compromised {
		return Profile{}, ErrProfileCompromised
	}
	if !profile.LockedUntil.IsZero() && profile.LockedUntil.After(time.Now()) {
		return profile, ErrProfileLocked
	}
	scheme, ok := passphraseSchemes[profile.PassphraseScheme]
	if !ok {
		return Profile{}, ErrInvalidPassphraseScheme
	}
	matched, err := scheme.check(profile, passphrase)
	if err != nil {
		return Profile{}, err
	}
	if !matched {
		// Fail closed: a scheme that reports a mismatch without an error must
		// not yield (Profile{}, nil), which callers treat as a successful login.
		return Profile{}, ErrIncorrectAuth
	}
	return profile, nil
}

// CreateSessionHandler allows the user to log into their account and create their session.
//
// Non-POST requests just render the login template. A POST authenticates the
// "login" and "passphrase" form values. On success it:
//   - stores a new active Session,
//   - sets the Session ID in the auth cookie (expiring after one week),
//   - redirects with 302 Found to the "from" query parameter when it stays on
//     this site, or to "/" otherwise.
//
// ErrIncorrectAuth, ErrProfileCompromised and ErrProfileLocked re-render the
// login template with the error listed under "errors". Any other failure
// responds 500 with the error text.
func CreateSessionHandler(w http.ResponseWriter, r *http.Request, context Context) {
	// BUG(paddy): Creating a session needs CSRF protection, right? This whole thing should get a security audit

	// Not called "errors", which would shadow the errors package.
	loginErrs := []error{}
	if r.Method == "POST" {
		login := r.PostFormValue("login")
		profile, err := authenticate(login, r.PostFormValue("passphrase"), context)
		switch err {
		case nil:
			// NOTE(review): X-Forwarded-For is a client-supplied header and can be
			// forged unless a trusted proxy overwrites it; treat IP as advisory.
			ip := r.Header.Get("X-Forwarded-For")
			if ip == "" {
				ip = r.RemoteAddr
			}
			session := Session{
				ID:        uuid.NewID().String(),
				IP:        ip,
				UserAgent: r.UserAgent(),
				ProfileID: profile.ID,
				Login:     login,
				Created:   time.Now(),
				Active:    true,
			}
			if err := context.CreateSession(session); err != nil {
				w.WriteHeader(http.StatusInternalServerError)
				w.Write([]byte(err.Error()))
				return
			}
			// BUG(paddy): really need to do a security audit on our cookie
			cookie := http.Cookie{
				Name:     authCookieName,
				Value:    session.ID,
				Expires:  time.Now().Add(24 * 7 * time.Hour),
				HttpOnly: true,
			}
			http.SetCookie(w, &cookie)
			http.Redirect(w, r, loginRedirectTarget(r), http.StatusFound)
			return
		case ErrIncorrectAuth, ErrProfileCompromised, ErrProfileLocked:
			loginErrs = append(loginErrs, err)
		default:
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte(err.Error()))
			return
		}
	}
	context.Render(w, loginTemplateName, map[string]interface{}{
		"errors": loginErrs,
	})
}

// loginRedirectTarget picks where to send the user after a successful login.
// It returns the untrusted "from" query parameter only when it resolves to the
// same scheme and host as the current request, and "/" otherwise. This stops
// the login page from acting as an open redirect to another site.
func loginRedirectTarget(r *http.Request) string {
	from := r.URL.Query().Get("from")
	if from == "" {
		return "/"
	}
	// Protocol-relative references ("//host/...") leave the site.
	if len(from) > 1 && from[0] == '/' && from[1] == '/' {
		return "/"
	}
	// Browsers treat '\' like '/', so "/\evil.example" would be followed as the
	// protocol-relative "//evil.example". Reject backslashes outright.
	for _, c := range from {
		if c == '\\' {
			return "/"
		}
	}
	// Resolve against the request URL. Absolute references to another scheme
	// or host show up as a differing Scheme/Host. Parse errors (including
	// control characters) are rejected too.
	target, err := r.URL.Parse(from)
	if err != nil || target.Scheme != r.URL.Scheme || target.Host != r.URL.Host {
		return "/"
	}
	return from
}

// credentialsValidate validates a "password" grant request. It authenticates
// the "username" and "password" form values and always reports the requested
// "scope" form value.
//
// On success it returns the authenticated profile's ID and valid == true.
// ErrIncorrectAuth, ErrProfileCompromised and ErrProfileLocked produce a 400
// response with an "invalid_grant" error written by renderJSONError. Any other
// failure produces a 500 response carrying the error text. In both failure
// cases profileID is nil and valid is false.
func credentialsValidate(w http.ResponseWriter, r *http.Request, context Context) (scope string, profileID uuid.ID, valid bool) {
	requestedScope := r.PostFormValue("scope")
	profile, err := authenticate(r.PostFormValue("username"), r.PostFormValue("password"), context)
	switch {
	case err == nil:
		return requestedScope, profile.ID, true
	case err == ErrIncorrectAuth || err == ErrProfileCompromised || err == ErrProfileLocked:
		w.WriteHeader(http.StatusBadRequest)
		renderJSONError(json.NewEncoder(w), "invalid_grant")
	default:
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(err.Error()))
	}
	return requestedScope, nil, false
}
