auth

Paddy 2015-04-11 Parent:77db7c65216c Child:cf1aef6eb81f

162:6f473576c6ae Go to Latest

auth/client_postgres.go

Clean up sessions and tokens after Profile is deleted. Add a terminateSessionsByProfile method to our sessionStore to mark Sessions associated with a Profile as inactive. Implement memstore and postgres implementations of the terminateSessionsByProfile method. Add a TerminateSessionsByProfile wrapper method to Context. Add a revokeTokensByProfileID method to our tokenStore to mark Tokens associated with a Profile as revoked. Implement memstore and postgres implementations of the revokeTokensByProfileID method. Add a RevokeTokensByProfileID wrapper method to Context. Call our RevokeTokensByProfileID and TerminateSessionsByProfile methods after a Profile is deleted, to clean up the Tokens and Sessions associated with it.

History
paddy@151 1 package auth
paddy@151 2
paddy@151 3 import (
paddy@151 4 "code.secondbit.org/uuid.hg"
paddy@151 5 "github.com/lib/pq"
paddy@151 6 "github.com/secondbit/pan"
paddy@151 7 )
paddy@151 8
paddy@151 9 func (c Client) GetSQLTableName() string {
paddy@151 10 return "clients"
paddy@151 11 }
paddy@151 12
paddy@151 13 func (e Endpoint) GetSQLTableName() string {
paddy@151 14 return "endpoints"
paddy@151 15 }
paddy@151 16
paddy@151 17 func (p *postgres) getClientSQL(id uuid.ID) *pan.Query {
paddy@151 18 var client Client
paddy@151 19 fields, _ := pan.GetFields(client)
paddy@151 20 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
paddy@151 21 query.IncludeWhere()
paddy@151 22 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", id, false)
paddy@151 23 return query.FlushExpressions(" ")
paddy@151 24 }
paddy@151 25
paddy@151 26 func (p *postgres) getClient(id uuid.ID) (Client, error) {
paddy@151 27 query := p.getClientSQL(id)
paddy@151 28 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 29 if err != nil {
paddy@151 30 return Client{}, err
paddy@151 31 }
paddy@151 32 var client Client
paddy@151 33 var found bool
paddy@151 34 for rows.Next() {
paddy@151 35 err := pan.Unmarshal(rows, &client)
paddy@151 36 if err != nil {
paddy@151 37 return client, err
paddy@151 38 }
paddy@151 39 found = true
paddy@151 40 }
paddy@151 41 if err = rows.Err(); err != nil {
paddy@151 42 return client, err
paddy@151 43 }
paddy@151 44 if !found {
paddy@151 45 return client, ErrClientNotFound
paddy@151 46 }
paddy@151 47 return client, nil
paddy@151 48 }
paddy@151 49
paddy@151 50 func (p *postgres) saveClientSQL(client Client) *pan.Query {
paddy@151 51 fields, values := pan.GetFields(client)
paddy@151 52 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(client))
paddy@151 53 query.Include("(" + pan.QueryList(fields) + ")")
paddy@151 54 query.Include("VALUES")
paddy@151 55 query.Include("("+pan.VariableList(len(values))+")", values...)
paddy@151 56 return query.FlushExpressions(" ")
paddy@151 57 }
paddy@151 58
paddy@151 59 func (p *postgres) saveClient(client Client) error {
paddy@151 60 query := p.saveClientSQL(client)
paddy@151 61 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 62 if e, ok := err.(*pq.Error); ok && e.Constraint == "clients_pkey" {
paddy@151 63 err = ErrClientAlreadyExists
paddy@151 64 }
paddy@151 65 return err
paddy@151 66 }
paddy@151 67
paddy@151 68 func (p *postgres) updateClientSQL(id uuid.ID, change ClientChange) *pan.Query {
paddy@151 69 var client Client
paddy@151 70 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(client)+" SET ")
paddy@151 71 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Secret")+" = ?", change.Secret)
paddy@151 72 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "OwnerID")+" = ?", change.OwnerID)
paddy@151 73 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Name")+" = ?", change.Name)
paddy@151 74 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Logo")+" = ?", change.Logo)
paddy@151 75 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Website")+" = ?", change.Website)
paddy@151 76 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Deleted")+" = ?", change.Deleted)
paddy@151 77 query.FlushExpressions(", ")
paddy@151 78 query.IncludeWhere()
paddy@151 79 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ?", id)
paddy@151 80 return query.FlushExpressions(" ")
paddy@151 81 }
paddy@151 82
paddy@151 83 func (p *postgres) updateClient(id uuid.ID, change ClientChange) error {
paddy@151 84 if change.Empty() {
paddy@151 85 return nil
paddy@151 86 }
paddy@151 87 query := p.updateClientSQL(id, change)
paddy@151 88 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 89 return err
paddy@151 90 }
paddy@151 91
paddy@151 92 func (p *postgres) listClientsByOwnerSQL(ownerID uuid.ID, num, offset int) *pan.Query {
paddy@151 93 var client Client
paddy@151 94 fields, _ := pan.GetFields(client)
paddy@151 95 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
paddy@151 96 query.IncludeWhere()
paddy@151 97 query.Include(pan.GetUnquotedColumn(client, "OwnerID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", ownerID, false)
paddy@151 98 query.IncludeLimit(int64(num))
paddy@151 99 query.IncludeOffset(int64(offset))
paddy@151 100 return query.FlushExpressions(" ")
paddy@151 101 }
paddy@151 102
paddy@151 103 func (p *postgres) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
paddy@151 104 query := p.listClientsByOwnerSQL(ownerID, num, offset)
paddy@151 105 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 106 if err != nil {
paddy@151 107 return []Client{}, err
paddy@151 108 }
paddy@151 109 var clients []Client
paddy@151 110 for rows.Next() {
paddy@151 111 var client Client
paddy@151 112 err = pan.Unmarshal(rows, &client)
paddy@151 113 if err != nil {
paddy@151 114 return clients, err
paddy@151 115 }
paddy@151 116 clients = append(clients, client)
paddy@151 117 }
paddy@151 118 if err = rows.Err(); err != nil {
paddy@151 119 return clients, err
paddy@151 120 }
paddy@151 121 return clients, nil
paddy@151 122 }
paddy@151 123
paddy@151 124 func (p *postgres) addEndpointsSQL(endpoints []Endpoint) *pan.Query {
paddy@151 125 fields, _ := pan.GetFields(endpoints[0])
paddy@151 126 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(endpoints[0]))
paddy@151 127 query.Include("(" + pan.QueryList(fields) + ")")
paddy@151 128 query.Include("VALUES")
paddy@151 129 query.FlushExpressions(" ")
paddy@151 130 for _, endpoint := range endpoints {
paddy@151 131 _, values := pan.GetFields(endpoint)
paddy@151 132 query.Include("("+pan.VariableList(len(values))+")", values...)
paddy@151 133 }
paddy@151 134 return query.FlushExpressions(", ")
paddy@151 135 }
paddy@151 136
paddy@151 137 func (p *postgres) addEndpoints(endpoints []Endpoint) error {
paddy@151 138 if len(endpoints) < 1 {
paddy@151 139 return nil
paddy@151 140 }
paddy@151 141 query := p.addEndpointsSQL(endpoints)
paddy@151 142 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 143 if e, ok := err.(*pq.Error); ok && e.Constraint == "endpoints_pkey" {
paddy@151 144 return ErrEndpointAlreadyExists
paddy@151 145 }
paddy@151 146 return err
paddy@151 147 }
paddy@151 148
paddy@151 149 func (p *postgres) removeEndpointSQL(client, endpoint uuid.ID) *pan.Query {
paddy@151 150 var e Endpoint
paddy@151 151 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(e))
paddy@151 152 query.IncludeWhere()
paddy@151 153 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ? AND "+pan.GetUnquotedColumn(e, "ClientID")+" = ?", endpoint, client)
paddy@151 154 return query.FlushExpressions(" ")
paddy@151 155 }
paddy@151 156
paddy@151 157 func (p *postgres) removeEndpoint(client, endpoint uuid.ID) error {
paddy@151 158 query := p.removeEndpointSQL(client, endpoint)
paddy@151 159 res, err := p.db.Exec(query.String(), query.Args...)
paddy@151 160 if err != nil {
paddy@151 161 return err
paddy@151 162 }
paddy@151 163 rows, err := res.RowsAffected()
paddy@151 164 if err != nil {
paddy@151 165 return err
paddy@151 166 }
paddy@151 167 if rows == 0 {
paddy@151 168 return ErrEndpointNotFound
paddy@151 169 }
paddy@151 170 return nil
paddy@151 171 }
paddy@151 172
paddy@151 173 func (p *postgres) getEndpointSQL(client, endpoint uuid.ID) *pan.Query {
paddy@151 174 var e Endpoint
paddy@151 175 fields, _ := pan.GetFields(e)
paddy@151 176 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
paddy@151 177 query.IncludeWhere()
paddy@151 178 query.FlushExpressions(" ")
paddy@151 179 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ?", endpoint)
paddy@151 180 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
paddy@151 181 return query.FlushExpressions(" AND ")
paddy@151 182 }
paddy@151 183
paddy@151 184 func (p *postgres) getEndpoint(client, endpoint uuid.ID) (Endpoint, error) {
paddy@151 185 query := p.getEndpointSQL(client, endpoint)
paddy@151 186 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 187 if err != nil {
paddy@151 188 return Endpoint{}, err
paddy@151 189 }
paddy@151 190 var e Endpoint
paddy@151 191 var found bool
paddy@151 192 for rows.Next() {
paddy@151 193 err := pan.Unmarshal(rows, &e)
paddy@151 194 if err != nil {
paddy@151 195 return e, err
paddy@151 196 }
paddy@151 197 found = true
paddy@151 198 }
paddy@151 199 if err = rows.Err(); err != nil {
paddy@151 200 return e, err
paddy@151 201 }
paddy@151 202 if !found {
paddy@151 203 return e, ErrEndpointNotFound
paddy@151 204 }
paddy@151 205 return e, nil
paddy@151 206 }
paddy@151 207
paddy@151 208 func (p *postgres) checkEndpointSQL(client uuid.ID, endpoint string) *pan.Query {
paddy@151 209 var e Endpoint
paddy@151 210 fields, _ := pan.GetFields(e)
paddy@151 211 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
paddy@151 212 query.IncludeWhere()
paddy@151 213 query.FlushExpressions(" ")
paddy@151 214 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
paddy@151 215 query.Include(pan.GetUnquotedColumn(e, "NormalizedURI")+" = ?", endpoint)
paddy@151 216 return query.FlushExpressions(" AND ")
paddy@151 217 }
paddy@151 218
paddy@151 219 func (p *postgres) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
paddy@151 220 query := p.checkEndpointSQL(client, endpoint)
paddy@151 221 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 222 if err != nil {
paddy@151 223 return false, err
paddy@151 224 }
paddy@151 225 var found bool
paddy@151 226 for rows.Next() {
paddy@151 227 found = true
paddy@151 228 }
paddy@151 229 if err = rows.Err(); err != nil {
paddy@151 230 return found, err
paddy@151 231 }
paddy@151 232 return found, nil
paddy@151 233 }
paddy@151 234
paddy@151 235 func (p *postgres) listEndpointsSQL(client uuid.ID, num, offset int) *pan.Query {
paddy@151 236 var endpoint Endpoint
paddy@151 237 fields, _ := pan.GetFields(endpoint)
paddy@151 238 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(endpoint))
paddy@151 239 query.IncludeWhere()
paddy@151 240 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
paddy@151 241 query.IncludeLimit(int64(num))
paddy@151 242 query.IncludeOffset(int64(offset))
paddy@151 243 return query.FlushExpressions(" ")
paddy@151 244 }
paddy@151 245
paddy@151 246 func (p *postgres) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
paddy@151 247 query := p.listEndpointsSQL(client, num, offset)
paddy@151 248 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 249 if err != nil {
paddy@151 250 return []Endpoint{}, err
paddy@151 251 }
paddy@151 252 var endpoints []Endpoint
paddy@151 253 for rows.Next() {
paddy@151 254 var endpoint Endpoint
paddy@151 255 err = pan.Unmarshal(rows, &endpoint)
paddy@151 256 if err != nil {
paddy@151 257 return endpoints, err
paddy@151 258 }
paddy@151 259 endpoints = append(endpoints, endpoint)
paddy@151 260 }
paddy@151 261 if err = rows.Err(); err != nil {
paddy@151 262 return endpoints, err
paddy@151 263 }
paddy@151 264 return endpoints, nil
paddy@151 265 }
paddy@151 266
paddy@151 267 func (p *postgres) countEndpointsSQL(client uuid.ID) *pan.Query {
paddy@151 268 var endpoint Endpoint
paddy@151 269 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM "+pan.GetTableName(endpoint))
paddy@151 270 query.IncludeWhere()
paddy@151 271 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
paddy@151 272 return query.FlushExpressions(" ")
paddy@151 273 }
paddy@151 274
paddy@151 275 func (p *postgres) countEndpoints(client uuid.ID) (int64, error) {
paddy@151 276 query := p.countEndpointsSQL(client)
paddy@151 277 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 278 if err != nil {
paddy@151 279 return 0, err
paddy@151 280 }
paddy@151 281 var results int64
paddy@151 282 for rows.Next() {
paddy@151 283 err = pan.Unmarshal(rows, &results)
paddy@151 284 if err != nil {
paddy@151 285 return results, err
paddy@151 286 }
paddy@151 287 }
paddy@151 288 if err = rows.Err(); err != nil {
paddy@151 289 return results, err
paddy@151 290 }
paddy@151 291 return results, nil
paddy@151 292 }