package oauth2

import (
	"encoding/base64"
	"errors"
	"net/http"
	"strings"
)

// BasicAuthNotSetError is returned by CheckBasicAuth when the request has no
// Authorization header, or the header value is empty.
var BasicAuthNotSetError = errors.New("Authorization header not set.")

// InvalidBasicAuthTypeError is returned by CheckBasicAuth when the
// Authorization header is not of the form "Basic <credentials>".
var InvalidBasicAuthTypeError = errors.New("Invalid basic auth type.")

// InvalidBasicAuthMessage is returned by CheckBasicAuth when the decoded
// credentials contain no ':' separating the username from the password.
var InvalidBasicAuthMessage = errors.New("Invalid basic auth format.")

// BasicAuth holds a username/password pair. CheckBasicAuth fills it from the
// decoded "username:password" value of an HTTP Basic Authorization header,
// splitting at the first colon. getClientAuth also uses it to carry a
// client_id/client_secret pair read from request parameters.
type BasicAuth struct {
	Username, Password string
}

// CheckBasicAuth parses the HTTP Basic credentials in the request's
// Authorization header, which is expected to look like
// "Basic base64(username:password)".
//
// The scheme name is matched case-insensitively, as required for HTTP
// authentication schemes. Whitespace around the encoded credentials is
// ignored. The decoded value is split at its first colon, so the password may
// itself contain colons.
//
// Errors:
//   - BasicAuthNotSetError: the Authorization header is missing or empty.
//   - InvalidBasicAuthTypeError: the scheme is not Basic, or no credentials
//     follow it.
//   - the base64 decoding error: the credentials are not valid standard
//     (padded) base64.
//   - InvalidBasicAuthMessage: the decoded credentials contain no ':'.
func CheckBasicAuth(r *http.Request) (BasicAuth, error) {
	header := r.Header.Get("Authorization")
	if header == "" {
		return BasicAuth{}, BasicAuthNotSetError
	}

	parts := strings.SplitN(header, " ", 2)
	// Auth scheme names are case-insensitive (RFC 7235 §2.1), so "basic" and
	// "BASIC" must be accepted just like "Basic".
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Basic") {
		return BasicAuth{}, InvalidBasicAuthTypeError
	}

	// The grammar allows one or more spaces between scheme and credentials;
	// trim them so they do not corrupt the base64 payload.
	decoded, err := base64.StdEncoding.DecodeString(strings.TrimSpace(parts[1]))
	if err != nil {
		return BasicAuth{}, err
	}

	pair := strings.SplitN(string(decoded), ":", 2)
	if len(pair) != 2 {
		return BasicAuth{}, InvalidBasicAuthMessage
	}

	return BasicAuth{Username: pair[0], Password: pair[1]}, nil
}

// getClientAuth returns the client credentials for a request.
//
// When allowQueryParams is true and r.Form has a "client_secret" key plus a
// non-empty "client_id", those two values are returned as the username and
// password. Only the presence of "client_secret" is required; its value may
// be empty, so clients without a secret can still authenticate. In every
// other case the credentials come from the Authorization header via
// CheckBasicAuth.
//
// NOTE(review): r.Form is only populated after r.ParseForm (or
// ParseMultipartForm) has run. Presumably callers do that before calling
// here; confirm.
func getClientAuth(r *http.Request, allowQueryParams bool) (BasicAuth, error) {
	if !allowQueryParams {
		return CheckBasicAuth(r)
	}

	// Check for the key itself, not a non-empty value (see doc above).
	if _, present := r.Form["client_secret"]; !present {
		return CheckBasicAuth(r)
	}

	clientID := r.Form.Get("client_id")
	if clientID == "" {
		return CheckBasicAuth(r)
	}

	return BasicAuth{Username: clientID, Password: r.Form.Get("client_secret")}, nil
}
