auth

Paddy 2015-04-11 Parent:3223a8e679db Child:48200d8c4036

159:cf6c1f05eb21 Browse Files

Enable terminating sessions through the API. Add a terminateSession method to the sessionStore that sets the Active property of the Session to false. Create a Context.TerminateSession wrapper for the terminateSession method on the sessionStore. Add a Sessions property to our response type so we can return a []Session in API responses. Use the URL-safe encoding when base64 encoding our session ID and CSRFToken, so the ID can be passed in the URL and so our encodings are consistent. Add a TerminateSessionHandler function that will extract a Session ID from the request URL, authenticate the user, check that the authenticated user owns the session in question, and terminate the session. Add implementations for our new terminateSession method for the memstore and postgres types. Test both the memstore and postgres implementation of our terminateSession helper in session_test.go.

context.go request.go session.go session_postgres.go session_test.go

     1.1 --- a/context.go	Sat Apr 11 14:39:51 2015 -0400
     1.2 +++ b/context.go	Sat Apr 11 15:37:41 2015 -0400
     1.3 @@ -336,6 +336,15 @@
     1.4  	return c.sessions.getSession(id)
     1.5  }
     1.6  
     1.7 +// TerminateSession sets the Session identified by the passed ID as inactive in the sessionStore assocated
     1.8 +// with the Context.
     1.9 +func (c Context) TerminateSession(id string) error {
    1.10 +	if c.sessions == nil {
    1.11 +		return ErrNoSessionStore
    1.12 +	}
    1.13 +	return c.sessions.terminateSession(id)
    1.14 +}
    1.15 +
    1.16  // RemoveSession removes the Session identified by the passed ID from the sessionStore associated with
    1.17  // the Context.
    1.18  func (c Context) RemoveSession(id string) error {
     2.1 --- a/request.go	Sat Apr 11 14:39:51 2015 -0400
     2.2 +++ b/request.go	Sat Apr 11 15:37:41 2015 -0400
     2.3 @@ -33,6 +33,7 @@
     2.4  	Profiles  []Profile      `json:"profiles,omitempty"`
     2.5  	Clients   []Client       `json:"clients,omitempty"`
     2.6  	Endpoints []Endpoint     `json:"endpoints,omitempty"`
     2.7 +	Sessions  []Session      `json:"sessions,omitempty"`
     2.8  }
     2.9  
    2.10  type requestError struct {
     3.1 --- a/session.go	Sat Apr 11 14:39:51 2015 -0400
     3.2 +++ b/session.go	Sat Apr 11 15:37:41 2015 -0400
     3.3 @@ -91,6 +91,7 @@
     3.4  type sessionStore interface {
     3.5  	createSession(session Session) error
     3.6  	getSession(id string) (Session, error)
     3.7 +	terminateSession(id string) error
     3.8  	removeSession(id string) error
     3.9  	listSessions(profile uuid.ID, before time.Time, num int64) ([]Session, error)
    3.10  }
    3.11 @@ -114,6 +115,18 @@
    3.12  	return m.sessions[id], nil
    3.13  }
    3.14  
    3.15 +func (m *memstore) terminateSession(id string) error {
    3.16 +	m.sessionLock.RLock()
    3.17 +	defer m.sessionLock.RUnlock()
    3.18 +	sess, ok := m.sessions[id]
    3.19 +	if !ok {
    3.20 +		return ErrSessionNotFound
    3.21 +	}
    3.22 +	sess.Active = false
    3.23 +	m.sessions[id] = sess
    3.24 +	return nil
    3.25 +}
    3.26 +
    3.27  func (m *memstore) removeSession(id string) error {
    3.28  	m.sessionLock.Lock()
    3.29  	defer m.sessionLock.Unlock()
    3.30 @@ -150,7 +163,7 @@
    3.31  func RegisterSessionHandlers(r *mux.Router, context Context) {
    3.32  	r.Handle("/login", wrap(context, CreateSessionHandler))
    3.33  	// BUG(paddy): We need to implement a handler for listing sessions active on a profile.
    3.34 -	// BUG(paddy): We need to implement a handler for terminating sessions.
    3.35 +	r.Handle("/sessions/{id}", wrap(context, TerminateSessionHandler)).Methods("OPTIONS", "DELETE")
    3.36  }
    3.37  
    3.38  func checkCSRF(r *http.Request, s Session) error {
    3.39 @@ -280,7 +293,7 @@
    3.40  				return
    3.41  			}
    3.42  			session := Session{
    3.43 -				ID:        base64.StdEncoding.EncodeToString(sessionID),
    3.44 +				ID:        base64.URLEncoding.EncodeToString(sessionID),
    3.45  				IP:        ip,
    3.46  				UserAgent: r.UserAgent(),
    3.47  				ProfileID: profile.ID,
    3.48 @@ -288,7 +301,7 @@
    3.49  				Created:   time.Now(),
    3.50  				Expires:   time.Now().Add(time.Hour),
    3.51  				Active:    true,
    3.52 -				CSRFToken: base64.StdEncoding.EncodeToString(csrfToken),
    3.53 +				CSRFToken: base64.URLEncoding.EncodeToString(csrfToken),
    3.54  			}
    3.55  			err = context.CreateSession(session)
    3.56  			if err != nil {
    3.57 @@ -324,6 +337,64 @@
    3.58  	})
    3.59  }
    3.60  
    3.61 +// TerminateSessionHandler allows the user to end their session before it expires.
    3.62 +func TerminateSessionHandler(w http.ResponseWriter, r *http.Request, context Context) {
    3.63 +	var errors []requestError
    3.64 +	vars := mux.Vars(r)
    3.65 +	if vars["id"] == "" {
    3.66 +		errors = append(errors, requestError{Slug: requestErrMissing, Param: "id"})
    3.67 +		encode(w, r, http.StatusBadRequest, response{Errors: errors})
    3.68 +		return
    3.69 +	}
    3.70 +	id := vars["id"]
    3.71 +	un, pw, ok := r.BasicAuth()
    3.72 +	if !ok {
    3.73 +		errors = append(errors, requestError{Slug: requestErrAccessDenied})
    3.74 +		encode(w, r, http.StatusUnauthorized, response{Errors: errors})
    3.75 +		return
    3.76 +	}
    3.77 +	profile, err := authenticate(un, pw, context)
    3.78 +	if err != nil {
    3.79 +		if isAuthError(err) {
    3.80 +			errors = append(errors, requestError{Slug: requestErrAccessDenied})
    3.81 +			encode(w, r, http.StatusForbidden, response{Errors: errors})
    3.82 +			return
    3.83 +		}
    3.84 +		errors = append(errors, requestError{Slug: requestErrActOfGod})
    3.85 +		encode(w, r, http.StatusInternalServerError, response{Errors: errors})
    3.86 +		return
    3.87 +	}
    3.88 +	session, err := context.GetSession(id)
    3.89 +	if err != nil {
    3.90 +		if err == ErrSessionNotFound {
    3.91 +			errors = append(errors, requestError{Slug: requestErrNotFound, Param: "id"})
    3.92 +			encode(w, r, http.StatusNotFound, response{Errors: errors})
    3.93 +			return
    3.94 +		}
    3.95 +		errors = append(errors, requestError{Slug: requestErrActOfGod})
    3.96 +		encode(w, r, http.StatusInternalServerError, response{Errors: errors})
    3.97 +		return
    3.98 +	}
    3.99 +	if !session.ProfileID.Equal(profile.ID) {
   3.100 +		errors = append(errors, requestError{Slug: requestErrAccessDenied, Param: "id"})
   3.101 +		encode(w, r, http.StatusForbidden, response{Errors: errors})
   3.102 +		return
   3.103 +	}
   3.104 +	err = context.TerminateSession(id)
   3.105 +	if err != nil {
   3.106 +		if err == ErrSessionNotFound {
   3.107 +			errors = append(errors, requestError{Slug: requestErrNotFound, Param: "id"})
   3.108 +			encode(w, r, http.StatusNotFound, response{Errors: errors})
   3.109 +			return
   3.110 +		}
   3.111 +		errors = append(errors, requestError{Slug: requestErrActOfGod})
   3.112 +		encode(w, r, http.StatusInternalServerError, response{Errors: errors})
   3.113 +		return
   3.114 +	}
   3.115 +	session.Active = false
   3.116 +	encode(w, r, http.StatusOK, response{Sessions: []Session{session}, Errors: errors})
   3.117 +}
   3.118 +
   3.119  func credentialsValidate(w http.ResponseWriter, r *http.Request, context Context) (scopes []string, profileID uuid.ID, valid bool) {
   3.120  	enc := json.NewEncoder(w)
   3.121  	username := r.PostFormValue("username")
     4.1 --- a/session_postgres.go	Sat Apr 11 14:39:51 2015 -0400
     4.2 +++ b/session_postgres.go	Sat Apr 11 15:37:41 2015 -0400
     4.3 @@ -64,6 +64,31 @@
     4.4  	return session, nil
     4.5  }
     4.6  
     4.7 +func (p *postgres) terminateSessionSQL(id string) *pan.Query {
     4.8 +	var session Session
     4.9 +	query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(session)+" SET")
    4.10 +	query.Include(pan.GetUnquotedColumn(session, "Active")+" = ?", false)
    4.11 +	query.IncludeWhere()
    4.12 +	query.Include(pan.GetUnquotedColumn(session, "ID")+" = ?", id)
    4.13 +	return query.FlushExpressions(" ")
    4.14 +}
    4.15 +
    4.16 +func (p *postgres) terminateSession(id string) error {
    4.17 +	query := p.terminateSessionSQL(id)
    4.18 +	res, err := p.db.Exec(query.String(), query.Args...)
    4.19 +	if err != nil {
    4.20 +		return err
    4.21 +	}
    4.22 +	rows, err := res.RowsAffected()
    4.23 +	if err != nil {
    4.24 +		return err
    4.25 +	}
    4.26 +	if rows < 1 {
    4.27 +		return ErrSessionNotFound
    4.28 +	}
    4.29 +	return nil
    4.30 +}
    4.31 +
    4.32  func (p *postgres) removeSessionSQL(id string) *pan.Query {
    4.33  	var session Session
    4.34  	query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(session))
     5.1 --- a/session_test.go	Sat Apr 11 14:39:51 2015 -0400
     5.2 +++ b/session_test.go	Sat Apr 11 15:37:41 2015 -0400
     5.3 @@ -91,6 +91,31 @@
     5.4  		if !success {
     5.5  			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
     5.6  		}
     5.7 +		err = context.TerminateSession(session.ID)
     5.8 +		if err != nil {
     5.9 +			t.Errorf("Error terminating session in %T: %s", store, err)
    5.10 +		}
    5.11 +		retrieved, err = context.GetSession(session.ID)
    5.12 +		if err != nil {
    5.13 +			t.Errorf("Error retrieving session from %T: %s", store, err)
    5.14 +		}
    5.15 +		expected := session
    5.16 +		expected.Active = false
    5.17 +		success, field, expectation, result = compareSessions(expected, retrieved)
    5.18 +		if !success {
    5.19 +			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
    5.20 +		}
    5.21 +		retrievedList, err = context.ListSessions(session.ProfileID, time.Time{}, 10)
    5.22 +		if err != nil {
    5.23 +			t.Errorf("Error retrieving sessions by profile from %T: %s", store, err)
    5.24 +		}
    5.25 +		if len(retrievedList) != 1 {
    5.26 +			t.Errorf("Expected 1 session retrieved by profile from %T, got %d", store, len(retrievedList))
    5.27 +		}
    5.28 +		success, field, expectation, result = compareSessions(expected, retrievedList[0])
    5.29 +		if !success {
    5.30 +			t.Errorf("Expected field %s to be %v, but got %v from %T", field, expectation, result, store)
    5.31 +		}
    5.32  		err = context.RemoveSession(session.ID)
    5.33  		if err != nil {
    5.34  			t.Errorf("Error removing session from %T: %s", store, err)
    5.35 @@ -110,6 +135,10 @@
    5.36  		if err != ErrSessionNotFound {
    5.37  			t.Errorf("Expected ErrSessionNotFound from %T, got %s", store, err)
    5.38  		}
    5.39 +		err = context.TerminateSession(session.ID)
    5.40 +		if err != ErrSessionNotFound {
    5.41 +			t.Errorf("Expected ERrSessionNotFound from %T, got %s", store, err)
    5.42 +		}
    5.43  	}
    5.44  }
    5.45