auth
auth/token_postgres.go
Clean up sessions and tokens after a Profile is deleted. Add a terminateSessionsByProfile method to our sessionStore to mark Sessions associated with a Profile as inactive. Implement memstore and postgres implementations of the terminateSessionsByProfile method. Add a TerminateSessionsByProfile wrapper method to Context. Add a revokeTokensByProfileID method to our tokenStore to mark Tokens associated with a Profile as revoked. Implement memstore and postgres implementations of the revokeTokensByProfileID method. Add a RevokeTokensByProfileID wrapper method to Context. Call our RevokeTokensByProfileID and TerminateSessionsByProfile methods after a Profile is deleted, to clean up the Tokens and Sessions associated with it.
| paddy@155 | 1 package auth |
| paddy@155 | 2 |
| paddy@155 | 3 import ( |
| paddy@155 | 4 "code.secondbit.org/uuid.hg" |
| paddy@155 | 5 |
| paddy@155 | 6 "github.com/lib/pq" |
| paddy@155 | 7 "github.com/secondbit/pan" |
| paddy@155 | 8 ) |
| paddy@155 | 9 |
// tokenScope is a single row of the scopes_tokens join table. It pairs a
// token (keyed by its access token) with one scope granted to that token.
type tokenScope struct {
	Token, Scope string
}

// GetSQLTableName names the table that tokenScope rows are stored in.
func (ts tokenScope) GetSQLTableName() string {
	const table = "scopes_tokens"
	return table
}
| paddy@155 | 18 |
| paddy@155 | 19 func (t Token) GetSQLTableName() string { |
| paddy@155 | 20 return "tokens" |
| paddy@155 | 21 } |
| paddy@155 | 22 |
| paddy@155 | 23 func (p *postgres) getTokenSQL(token string, refresh bool) *pan.Query { |
| paddy@155 | 24 var t Token |
| paddy@155 | 25 fields, _ := pan.GetFields(t) |
| paddy@155 | 26 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(t)) |
| paddy@155 | 27 query.IncludeWhere() |
| paddy@155 | 28 if !refresh { |
| paddy@155 | 29 query.Include(pan.GetUnquotedColumn(t, "AccessToken")+" = ?", token) |
| paddy@155 | 30 } else { |
| paddy@155 | 31 query.Include(pan.GetUnquotedColumn(t, "RefreshToken")+" = ?", token) |
| paddy@155 | 32 } |
| paddy@155 | 33 return query.FlushExpressions(" ") |
| paddy@155 | 34 } |
| paddy@155 | 35 |
| paddy@155 | 36 func (p *postgres) getToken(token string, refresh bool) (Token, error) { |
| paddy@155 | 37 query := p.getTokenSQL(token, refresh) |
| paddy@155 | 38 rows, err := p.db.Query(query.String(), query.Args...) |
| paddy@155 | 39 if err != nil { |
| paddy@155 | 40 return Token{}, err |
| paddy@155 | 41 } |
| paddy@155 | 42 var t Token |
| paddy@155 | 43 var found bool |
| paddy@155 | 44 for rows.Next() { |
| paddy@155 | 45 err := pan.Unmarshal(rows, &t) |
| paddy@155 | 46 if err != nil { |
| paddy@155 | 47 return t, err |
| paddy@155 | 48 } |
| paddy@155 | 49 found = true |
| paddy@155 | 50 } |
| paddy@155 | 51 if err = rows.Err(); err != nil { |
| paddy@155 | 52 return t, err |
| paddy@155 | 53 } |
| paddy@155 | 54 if !found { |
| paddy@155 | 55 return t, ErrTokenNotFound |
| paddy@155 | 56 } |
| paddy@155 | 57 query = p.getTokenScopesSQL([]string{t.AccessToken}) |
| paddy@155 | 58 rows, err = p.db.Query(query.String(), query.Args...) |
| paddy@155 | 59 if err != nil { |
| paddy@155 | 60 return t, err |
| paddy@155 | 61 } |
| paddy@155 | 62 for rows.Next() { |
| paddy@155 | 63 var ts tokenScope |
| paddy@155 | 64 err = pan.Unmarshal(rows, &ts) |
| paddy@155 | 65 if err != nil { |
| paddy@155 | 66 return t, err |
| paddy@155 | 67 } |
| paddy@155 | 68 t.Scopes = append(t.Scopes, ts.Scope) |
| paddy@155 | 69 } |
| paddy@155 | 70 if err = rows.Err(); err != nil { |
| paddy@155 | 71 return t, err |
| paddy@155 | 72 } |
| paddy@155 | 73 return t, nil |
| paddy@155 | 74 } |
| paddy@155 | 75 |
| paddy@155 | 76 func (p *postgres) saveTokenSQL(token Token) *pan.Query { |
| paddy@155 | 77 fields, values := pan.GetFields(token) |
| paddy@155 | 78 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(token)) |
| paddy@155 | 79 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@155 | 80 query.Include("VALUES") |
| paddy@155 | 81 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@155 | 82 return query.FlushExpressions(" ") |
| paddy@155 | 83 } |
| paddy@155 | 84 |
| paddy@155 | 85 func (p *postgres) saveTokenScopesSQL(ts []tokenScope) *pan.Query { |
| paddy@155 | 86 fields, _ := pan.GetFields(ts[0]) |
| paddy@155 | 87 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(ts[0])) |
| paddy@155 | 88 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@155 | 89 query.Include("VALUES") |
| paddy@155 | 90 query.FlushExpressions(" ") |
| paddy@155 | 91 for _, t := range ts { |
| paddy@155 | 92 _, values := pan.GetFields(t) |
| paddy@155 | 93 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@155 | 94 } |
| paddy@155 | 95 return query.FlushExpressions(", ") |
| paddy@155 | 96 } |
| paddy@155 | 97 |
| paddy@155 | 98 func (p *postgres) saveToken(token Token) error { |
| paddy@155 | 99 query := p.saveTokenSQL(token) |
| paddy@155 | 100 _, err := p.db.Exec(query.String(), query.Args...) |
| paddy@155 | 101 if e, ok := err.(*pq.Error); ok && e.Constraint == "tokens_pkey" { |
| paddy@155 | 102 err = ErrTokenAlreadyExists |
| paddy@155 | 103 } |
| paddy@155 | 104 if err != nil || len(token.Scopes) < 1 { |
| paddy@155 | 105 return err |
| paddy@155 | 106 } |
| paddy@155 | 107 var ts []tokenScope |
| paddy@155 | 108 for _, scope := range token.Scopes { |
| paddy@155 | 109 ts = append(ts, tokenScope{Token: token.AccessToken, Scope: scope}) |
| paddy@155 | 110 } |
| paddy@155 | 111 query = p.saveTokenScopesSQL(ts) |
| paddy@155 | 112 _, err = p.db.Exec(query.String(), query.Args...) |
| paddy@155 | 113 return err |
| paddy@155 | 114 } |
| paddy@155 | 115 |
| paddy@155 | 116 func (p *postgres) revokeTokenSQL(token string, refresh bool) *pan.Query { |
| paddy@155 | 117 var t Token |
| paddy@155 | 118 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(t)+" SET ") |
| paddy@155 | 119 query.Include(pan.GetUnquotedColumn(t, "Revoked")+" = ?", true) |
| paddy@155 | 120 query.IncludeWhere() |
| paddy@155 | 121 if !refresh { |
| paddy@155 | 122 query.Include(pan.GetUnquotedColumn(t, "AccessToken")+" = ?", token) |
| paddy@155 | 123 } else { |
| paddy@155 | 124 query.Include(pan.GetUnquotedColumn(t, "RefreshToken")+" = ?", token) |
| paddy@155 | 125 } |
| paddy@155 | 126 return query.FlushExpressions(" ") |
| paddy@155 | 127 } |
| paddy@155 | 128 |
| paddy@155 | 129 func (p *postgres) revokeToken(token string, refresh bool) error { |
| paddy@155 | 130 query := p.revokeTokenSQL(token, refresh) |
| paddy@155 | 131 res, err := p.db.Exec(query.String(), query.Args...) |
| paddy@155 | 132 if err != nil { |
| paddy@155 | 133 return err |
| paddy@155 | 134 } |
| paddy@155 | 135 rows, err := res.RowsAffected() |
| paddy@155 | 136 if err != nil { |
| paddy@155 | 137 return err |
| paddy@155 | 138 } |
| paddy@155 | 139 if rows == 0 { |
| paddy@155 | 140 return ErrTokenNotFound |
| paddy@155 | 141 } |
| paddy@155 | 142 return nil |
| paddy@155 | 143 } |
| paddy@155 | 144 |
| paddy@162 | 145 func (p *postgres) revokeTokensByProfileIDSQL(profileID uuid.ID) *pan.Query { |
| paddy@162 | 146 var t Token |
| paddy@162 | 147 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(t)+" SET ") |
| paddy@162 | 148 query.Include(pan.GetUnquotedColumn(t, "Revoked")+" = ?", true) |
| paddy@162 | 149 query.IncludeWhere() |
| paddy@162 | 150 query.Include(pan.GetUnquotedColumn(t, "ProfileID")+" = ?", profileID) |
| paddy@162 | 151 return query.FlushExpressions(" ") |
| paddy@162 | 152 } |
| paddy@162 | 153 |
| paddy@162 | 154 func (p *postgres) revokeTokensByProfileID(profileID uuid.ID) error { |
| paddy@162 | 155 query := p.revokeTokensByProfileIDSQL(profileID) |
| paddy@162 | 156 res, err := p.db.Exec(query.String(), query.Args...) |
| paddy@162 | 157 if err != nil { |
| paddy@162 | 158 return err |
| paddy@162 | 159 } |
| paddy@162 | 160 rows, err := res.RowsAffected() |
| paddy@162 | 161 if err != nil { |
| paddy@162 | 162 return err |
| paddy@162 | 163 } |
| paddy@162 | 164 if rows == 0 { |
| paddy@162 | 165 return ErrProfileNotFound |
| paddy@162 | 166 } |
| paddy@162 | 167 return nil |
| paddy@162 | 168 } |
| paddy@162 | 169 |
| paddy@155 | 170 func (p *postgres) getTokensByProfileIDSQL(profileID uuid.ID, num, offset int) *pan.Query { |
| paddy@155 | 171 var token Token |
| paddy@155 | 172 fields, _ := pan.GetFields(token) |
| paddy@155 | 173 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(token)) |
| paddy@155 | 174 query.IncludeWhere() |
| paddy@155 | 175 query.Include(pan.GetUnquotedColumn(token, "ProfileID")+" = ?", profileID) |
| paddy@155 | 176 query.IncludeLimit(int64(num)) |
| paddy@155 | 177 query.IncludeOffset(int64(offset)) |
| paddy@155 | 178 return query.FlushExpressions(" ") |
| paddy@155 | 179 } |
| paddy@155 | 180 |
| paddy@155 | 181 func (p *postgres) getTokenScopesSQL(tokens []string) *pan.Query { |
| paddy@155 | 182 var t tokenScope |
| paddy@155 | 183 fields, _ := pan.GetFields(t) |
| paddy@155 | 184 tokensI := make([]interface{}, len(tokens)) |
| paddy@155 | 185 for pos, token := range tokens { |
| paddy@155 | 186 tokensI[pos] = token |
| paddy@155 | 187 } |
| paddy@155 | 188 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(t)) |
| paddy@155 | 189 query.IncludeWhere() |
| paddy@155 | 190 query.Include(pan.GetUnquotedColumn(t, "Token")+" IN ("+pan.VariableList(len(tokensI))+")", tokensI...) |
| paddy@155 | 191 return query.FlushExpressions(" ") |
| paddy@155 | 192 } |
| paddy@155 | 193 |
| paddy@155 | 194 func (p *postgres) getTokensByProfileID(profileID uuid.ID, num, offset int) ([]Token, error) { |
| paddy@155 | 195 query := p.getTokensByProfileIDSQL(profileID, num, offset) |
| paddy@155 | 196 rows, err := p.db.Query(query.String(), query.Args...) |
| paddy@155 | 197 if err != nil { |
| paddy@155 | 198 return []Token{}, err |
| paddy@155 | 199 } |
| paddy@155 | 200 var tokens []Token |
| paddy@155 | 201 var tokenIDs []string |
| paddy@155 | 202 for rows.Next() { |
| paddy@155 | 203 var token Token |
| paddy@155 | 204 err = pan.Unmarshal(rows, &token) |
| paddy@155 | 205 if err != nil { |
| paddy@155 | 206 return tokens, err |
| paddy@155 | 207 } |
| paddy@155 | 208 tokens = append(tokens, token) |
| paddy@155 | 209 tokenIDs = append(tokenIDs, token.AccessToken) |
| paddy@155 | 210 } |
| paddy@155 | 211 if err = rows.Err(); err != nil { |
| paddy@155 | 212 return tokens, err |
| paddy@155 | 213 } |
| paddy@155 | 214 if len(tokenIDs) < 1 { |
| paddy@155 | 215 return tokens, nil |
| paddy@155 | 216 } |
| paddy@155 | 217 scopes := map[string][]string{} |
| paddy@155 | 218 query = p.getTokenScopesSQL(tokenIDs) |
| paddy@155 | 219 rows, err = p.db.Query(query.String(), query.Args...) |
| paddy@155 | 220 if err != nil { |
| paddy@155 | 221 return tokens, err |
| paddy@155 | 222 } |
| paddy@155 | 223 for rows.Next() { |
| paddy@155 | 224 var t tokenScope |
| paddy@155 | 225 err = pan.Unmarshal(rows, &t) |
| paddy@155 | 226 if err != nil { |
| paddy@155 | 227 return tokens, err |
| paddy@155 | 228 } |
| paddy@155 | 229 scopes[t.Token] = append(scopes[t.Token], t.Scope) |
| paddy@155 | 230 } |
| paddy@155 | 231 if err = rows.Err(); err != nil { |
| paddy@155 | 232 return tokens, err |
| paddy@155 | 233 } |
| paddy@155 | 234 for pos, token := range tokens { |
| paddy@155 | 235 token.Scopes = scopes[token.AccessToken] |
| paddy@155 | 236 tokens[pos] = token |
| paddy@155 | 237 } |
| paddy@155 | 238 return tokens, nil |
| paddy@155 | 239 } |