package auth

import (
	"bytes"
	"html/template"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"code.secondbit.org/uuid"
)

// Bit flags selecting which optional authorization-request parameters
// TestGetGrantCodeSuccess includes on a given iteration. Each flag owns a
// distinct bit, so every combination of flags is a distinct integer.
const (
	scopeSet = 1 << 0 // include the "scope" parameter
	stateSet = 1 << 1 // include the "state" parameter
	uriSet   = 1 << 2 // include the "redirect_uri" parameter
)

// stripParam removes every value of the query parameter named param from u,
// rewriting u.RawQuery in place with the remaining parameters re-encoded
// (url.Values.Encode sorts keys).
func stripParam(param string, u *url.URL) {
	remaining := u.Query()
	delete(remaining, param)
	u.RawQuery = remaining.Encode()
}

// TestGetGrantCodeSuccess drives GetGrantHandler through the happy path of the
// authorization-code flow (OAuth 2.0 spec, section 4.1.1) for every
// combination of the optional redirect_uri, scope, and state parameters. For
// each combination it:
//
//   - GETs the grant page and expects 200 OK with the rendered template body;
//   - POSTs grant=approved and expects a 302 to the registered endpoint whose
//     query carries a code that GetGrant can retrieve and echoes the state.
//
// Every handler call gets its own *http.Request: net/http caches parsed
// parameters in Request.Form/PostForm, so reusing one request would let a
// later call observe an earlier call's query or body.
func TestGetGrantCodeSuccess(t *testing.T) {
	t.Parallel()
	store := NewMemstore()
	testContext := Context{
		template: template.Must(template.New(getGrantTemplateName).Parse("Get auth grant")),
		clients:  store,
		grants:   store,
		profiles: store,
		tokens:   store,
	}
	client := Client{
		ID:      uuid.NewID(),
		Secret:  "super secret!",
		OwnerID: uuid.NewID(),
		Name:    "My test client",
		Logo:    "https://secondbit.org/logo.png",
		Website: "https://secondbit.org",
		Type:    "public",
	}
	uri, err := url.Parse("https://test.secondbit.org/redirect")
	if err != nil {
		t.Fatal("Can't parse URL:", err)
	}
	endpoint := Endpoint{
		ID:       uuid.NewID(),
		ClientID: client.ID,
		URI:      *uri,
		Added:    time.Now(),
	}
	err = testContext.SaveClient(client)
	if err != nil {
		t.Fatal("Can't store client:", err)
	}
	err = testContext.AddEndpoint(client.ID, endpoint)
	if err != nil {
		t.Fatal("Can't store endpoint:", err)
	}
	const grantURL = "https://test.auth.secondbit.org/oauth2/grant"
	// i enumerates every subset of the optional-parameter flags.
	for i := 0; i <= scopeSet|stateSet|uriSet; i++ {
		params := url.Values{}
		// see OAuth 2.0 spec, section 4.1.1
		params.Set("response_type", "code")
		params.Set("client_id", client.ID.String())
		if i&uriSet != 0 {
			params.Set("redirect_uri", endpoint.URI.String())
		}
		if i&scopeSet != 0 {
			params.Set("scope", "testscope")
		}
		if i&stateSet != 0 {
			params.Set("state", "my super secure state string")
		}

		// Step 1: render the grant page.
		req, err := http.NewRequest("GET", grantURL, nil)
		if err != nil {
			t.Fatal("Can't build request:", err)
		}
		req.URL.RawQuery = params.Encode()
		w := httptest.NewRecorder()
		GetGrantHandler(w, req, testContext)
		if w.Code != http.StatusOK {
			t.Errorf("Expected status code to be %d, got %d for %s", http.StatusOK, w.Code, req.URL.String())
		}
		if w.Body.String() != "Get auth grant" {
			t.Errorf("Expected body to be `%s`, got `%s` for %s", "Get auth grant", w.Body.String(), req.URL.String())
		}

		// Step 2: approve the grant with a form-encoded POST to the same URL.
		req, err = http.NewRequest("POST", grantURL, nil)
		if err != nil {
			t.Fatal("Can't build request:", err)
		}
		req.URL.RawQuery = params.Encode()
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		data := url.Values{}
		data.Set("grant", "approved")
		req.Body = ioutil.NopCloser(bytes.NewBufferString(data.Encode()))
		w = httptest.NewRecorder()
		GetGrantHandler(w, req, testContext)
		if w.Code != http.StatusFound {
			t.Errorf("Expected status code to be %d, got %d for %s", http.StatusFound, w.Code, req.URL.String())
		}
		redirectedTo := w.Header().Get("Location")
		red, err := url.Parse(redirectedTo)
		if err != nil {
			t.Fatalf(`Being redirected to a non-URL "%s" threw error "%s" for "%s"`, redirectedTo, err, req.URL.String())
		}
		t.Log("Redirected to", redirectedTo)
		code := red.Query().Get("code")
		if code == "" {
			t.Fatalf(`Expected code param in redirect URL to be set, but it wasn't for %s`, req.URL.String())
		}
		if _, err := testContext.GetGrant(code); err != nil {
			t.Fatalf(`Unexpected error "%s" retrieving the grant "%s" supplied in the redirect URL for %s`, err, code, req.URL.String())
		}
		// Clean up the issued grant. A failure is only logged: the assertions
		// about the grant have already run above.
		if err := testContext.DeleteGrant(code); err != nil {
			t.Logf(`Unexpected error "%s" deleting grant "%s" for %s`, err, code, req.URL.String())
		}
		stripParam("code", red)
		if red.Query().Get("state") != params.Get("state") {
			t.Errorf(`Expected state param in redirect URL to be "%s", got "%s" for %s`, params.Get("state"), red.Query().Get("state"), req.URL.String())
		}
		stripParam("state", red)
		// With code and state removed, nothing but the registered endpoint
		// should remain.
		if red.String() != endpoint.URI.String() {
			t.Errorf(`Expected redirect URL to be "%s", got "%s" for %s`, endpoint.URI.String(), red.String(), req.URL.String())
		}
	}
}

// TestGetGrantCodeInvalidClient checks that GetGrantHandler rejects an
// authorization request with 400 Bad Request and a specific rendered error
// message when client_id is missing, malformed, or names an unknown client.
//
// Each case is sent in a fresh *http.Request so that net/http's cached
// Request.Form from one case cannot leak into the next.
func TestGetGrantCodeInvalidClient(t *testing.T) {
	t.Parallel()
	store := NewMemstore()
	testContext := Context{
		template: template.Must(template.New(getGrantTemplateName).Parse("{{ .error }}")),
		clients:  store,
		grants:   store,
		profiles: store,
		tokens:   store,
	}
	client := Client{
		ID:      uuid.NewID(),
		Secret:  "super secret!",
		OwnerID: uuid.NewID(),
		Name:    "My test client",
		Type:    "public",
	}
	if err := testContext.SaveClient(client); err != nil {
		t.Fatal("Can't store client:", err)
	}
	cases := []struct {
		clientID string // "" means the client_id parameter is omitted
		want     string // expected rendered error message
	}{
		{"", "Client ID must be specified in the request."},
		{"Not an ID", "client_id is not a valid Client ID."},
		// A well-formed ID that was never stored.
		{uuid.NewID().String(), "The specified Client couldn&rsquo;t be found."},
	}
	for _, tc := range cases {
		params := url.Values{}
		params.Set("response_type", "code")
		params.Set("redirect_uri", "https://test.secondbit.org/")
		if tc.clientID != "" {
			params.Set("client_id", tc.clientID)
		}
		req, err := http.NewRequest("GET", "https://test.auth.secondbit.org/oauth2/grant", nil)
		if err != nil {
			t.Fatal("Can't build request:", err)
		}
		req.URL.RawQuery = params.Encode()
		w := httptest.NewRecorder()
		GetGrantHandler(w, req, testContext)
		if w.Code != http.StatusBadRequest {
			t.Errorf("Expected status code to be %d, got %d for client_id %q", http.StatusBadRequest, w.Code, tc.clientID)
		}
		if w.Body.String() != tc.want {
			t.Errorf(`Expected output to be "%s", got "%s" instead for client_id %q.`, tc.want, w.Body.String(), tc.clientID)
		}
	}
}

// TestGetGrantCodeInvalidURI checks that GetGrantHandler responds with
// 400 Bad Request and "The redirect_uri specified is not valid." when:
//
//   - the client has no registered endpoints and no redirect_uri is given;
//   - redirect_uri does not match the client's registered endpoint;
//   - the client has two registered endpoints and redirect_uri is empty;
//   - redirect_uri cannot be parsed as a URL.
//
// Each check is sent in a fresh *http.Request so that net/http's cached
// Request.Form from one check cannot leak into the next.
func TestGetGrantCodeInvalidURI(t *testing.T) {
	t.Parallel()
	store := NewMemstore()
	testContext := Context{
		template: template.Must(template.New(getGrantTemplateName).Parse("{{ .error }}")),
		clients:  store,
		grants:   store,
		profiles: store,
		tokens:   store,
	}
	client := Client{
		ID:      uuid.NewID(),
		Secret:  "super secret!",
		OwnerID: uuid.NewID(),
		Name:    "My test client",
		Type:    "public",
	}
	uri, err := url.Parse("https://test.secondbit.org/redirect")
	if err != nil {
		t.Fatal("Can't parse URL:", err)
	}
	err = testContext.SaveClient(client)
	if err != nil {
		t.Fatal("Can't store client:", err)
	}

	params := url.Values{}
	params.Set("response_type", "code")
	params.Set("client_id", client.ID.String())

	const invalidURI = "The redirect_uri specified is not valid."
	// expectInvalid sends the current params to GetGrantHandler in a fresh
	// request and asserts the redirect_uri was rejected. desc identifies the
	// scenario in failure output.
	expectInvalid := func(desc string) {
		req, err := http.NewRequest("GET", "https://test.auth.secondbit.org/oauth2/grant", nil)
		if err != nil {
			t.Fatal("Can't build request:", err)
		}
		req.URL.RawQuery = params.Encode()
		w := httptest.NewRecorder()
		GetGrantHandler(w, req, testContext)
		if w.Code != http.StatusBadRequest {
			t.Errorf("Expected status code to be %d, got %d (%s)", http.StatusBadRequest, w.Code, desc)
		}
		if w.Body.String() != invalidURI {
			t.Errorf(`Expected output to be "%s", got "%s" instead (%s).`, invalidURI, w.Body.String(), desc)
		}
	}

	expectInvalid("no registered endpoints, no redirect_uri")

	endpoint := Endpoint{
		ID:       uuid.NewID(),
		ClientID: client.ID,
		URI:      *uri,
		Added:    time.Now(),
	}
	err = testContext.AddEndpoint(client.ID, endpoint)
	if err != nil {
		t.Fatal("Can't store endpoint:", err)
	}
	params.Set("redirect_uri", "https://test.secondbit.org/wrong")
	expectInvalid("redirect_uri doesn't match the registered endpoint")

	// NOTE(review): endpoint2 reuses endpoint's URI; confirm the handler treats
	// two registered endpoints as ambiguous even when their URIs are equal.
	endpoint2 := Endpoint{
		ID:       uuid.NewID(),
		ClientID: client.ID,
		URI:      *uri,
		Added:    time.Now(),
	}
	err = testContext.AddEndpoint(client.ID, endpoint2)
	if err != nil {
		t.Fatal("Can't store endpoint:", err)
	}
	params.Set("redirect_uri", "")
	expectInvalid("two registered endpoints, empty redirect_uri")

	params.Set("redirect_uri", "://not a URL")
	expectInvalid("unparseable redirect_uri")
}

// TestGetGrantCodeInvalidResponseType checks that when response_type is
// unsupported or missing, GetGrantHandler redirects (302) to the client's
// registered endpoint with error=invalid_request and the original state
// echoed back, and no other query parameters.
//
// Each case is sent in a fresh *http.Request so that net/http's cached
// Request.Form from one case cannot leak into the next.
func TestGetGrantCodeInvalidResponseType(t *testing.T) {
	t.Parallel()
	store := NewMemstore()
	testContext := Context{
		template: template.Must(template.New(getGrantTemplateName).Parse("{{ .error }}")),
		clients:  store,
		grants:   store,
		profiles: store,
		tokens:   store,
	}
	client := Client{
		ID:      uuid.NewID(),
		Secret:  "super secret!",
		OwnerID: uuid.NewID(),
		Name:    "My test client",
		Logo:    "https://secondbit.org/logo.png",
		Website: "https://secondbit.org",
		Type:    "public",
	}
	uri, err := url.Parse("https://test.secondbit.org/redirect")
	if err != nil {
		t.Fatal("Can't parse URL:", err)
	}
	endpoint := Endpoint{
		ID:       uuid.NewID(),
		ClientID: client.ID,
		URI:      *uri,
		Added:    time.Now(),
	}
	err = testContext.SaveClient(client)
	if err != nil {
		t.Fatal("Can't store client:", err)
	}
	err = testContext.AddEndpoint(client.ID, endpoint)
	if err != nil {
		t.Fatal("Can't store endpoint:", err)
	}
	// "" means the response_type parameter is omitted entirely.
	for _, responseType := range []string{"totally not code", ""} {
		params := url.Values{}
		if responseType != "" {
			params.Set("response_type", responseType)
		}
		params.Set("client_id", client.ID.String())
		params.Set("redirect_uri", endpoint.URI.String())
		params.Set("scope", "testscope")
		params.Set("state", "my super secure state string")
		req, err := http.NewRequest("GET", "https://test.auth.secondbit.org/oauth2/grant", nil)
		if err != nil {
			t.Fatal("Can't build request:", err)
		}
		req.URL.RawQuery = params.Encode()
		w := httptest.NewRecorder()
		GetGrantHandler(w, req, testContext)
		if w.Code != http.StatusFound {
			t.Errorf("Expected status code to be %d, got %d for response_type %q", http.StatusFound, w.Code, responseType)
		}
		redirectedTo := w.Header().Get("Location")
		red, err := url.Parse(redirectedTo)
		if err != nil {
			t.Fatalf("Being redirected to a non-URL (%s) threw error: %s", redirectedTo, err)
		}
		if red.Query().Get("error") != "invalid_request" {
			t.Errorf(`Expected error param in redirect URL to be "%s", got "%s" for response_type %q`, "invalid_request", red.Query().Get("error"), responseType)
		}
		stripParam("error", red)
		if red.Query().Get("state") != params.Get("state") {
			t.Errorf(`Expected state param in redirect URL to be "%s", got "%s" for response_type %q`, params.Get("state"), red.Query().Get("state"), responseType)
		}
		stripParam("state", red)
		// With error and state removed, nothing but the registered endpoint
		// should remain.
		if red.String() != endpoint.URI.String() {
			t.Errorf(`Expected redirect URL to be "%s", got "%s" for response_type %q`, endpoint.URI.String(), red.String(), responseType)
		}
	}
}

// TestGetGrantCodeDenied checks that when the resource owner POSTs
// grant=denied, GetGrantHandler redirects (302) to the client's registered
// endpoint with error=access_denied and the original state echoed back, and
// no other query parameters.
func TestGetGrantCodeDenied(t *testing.T) {
	t.Parallel()
	store := NewMemstore()
	testContext := Context{
		template: template.Must(template.New(getGrantTemplateName).Parse("{{ .error }}")),
		clients:  store,
		grants:   store,
		profiles: store,
		tokens:   store,
	}
	client := Client{
		ID:      uuid.NewID(),
		Secret:  "super secret!",
		OwnerID: uuid.NewID(),
		Name:    "My test client",
		Logo:    "https://secondbit.org/logo.png",
		Website: "https://secondbit.org",
		Type:    "public",
	}
	uri, err := url.Parse("https://test.secondbit.org/redirect")
	if err != nil {
		t.Fatal("Can't parse URL:", err)
	}
	endpoint := Endpoint{
		ID:       uuid.NewID(),
		ClientID: client.ID,
		URI:      *uri,
		Added:    time.Now(),
	}
	err = testContext.SaveClient(client)
	if err != nil {
		t.Fatal("Can't store client:", err)
	}
	err = testContext.AddEndpoint(client.ID, endpoint)
	if err != nil {
		t.Fatal("Can't store endpoint:", err)
	}
	req, err := http.NewRequest("POST", "https://test.auth.secondbit.org/oauth2/grant", nil)
	if err != nil {
		t.Fatal("Can't build request:", err)
	}
	params := url.Values{}
	params.Set("response_type", "code")
	params.Set("client_id", client.ID.String())
	params.Set("redirect_uri", endpoint.URI.String())
	params.Set("scope", "testscope")
	params.Set("state", "my super secure state string")
	data := url.Values{}
	data.Set("grant", "denied")
	req.URL.RawQuery = params.Encode()
	// Without this header net/http's form parsing ignores the request body,
	// so the handler would never actually see grant=denied.
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Body = ioutil.NopCloser(bytes.NewBufferString(data.Encode()))
	w := httptest.NewRecorder()
	GetGrantHandler(w, req, testContext)
	if w.Code != http.StatusFound {
		t.Errorf("Expected status code to be %d, got %d", http.StatusFound, w.Code)
	}
	redirectedTo := w.Header().Get("Location")
	red, err := url.Parse(redirectedTo)
	if err != nil {
		t.Fatalf("Being redirected to a non-URL (%s) threw error: %s", redirectedTo, err)
	}
	if red.Query().Get("error") != "access_denied" {
		t.Errorf(`Expected error param in redirect URL to be "%s", got "%s"`, "access_denied", red.Query().Get("error"))
	}
	stripParam("error", red)
	if red.Query().Get("state") != params.Get("state") {
		t.Errorf(`Expected state param in redirect URL to be "%s", got "%s"`, params.Get("state"), red.Query().Get("state"))
	}
	stripParam("state", red)
	// With error and state removed, nothing but the registered endpoint
	// should remain.
	if red.String() != endpoint.URI.String() {
		t.Errorf(`Expected redirect URL to be "%s", got "%s"`, endpoint.URI.String(), red.String())
	}
}
