auth

Paddy 2015-03-24 Parent:77db7c65216c Child:cf1aef6eb81f

153:3e8964a914ef Go to Latest

auth/client_postgres.go

Fix tests for scopeStore. Update all our tests to use the PG_TEST_DB environment variable if set, and use that to control whether or not the postgres tests get run. testing.Short() just wasn't working. Update ErrScopeNotFound and ErrScopeAlreadyExists to be variables instead of types, because PostgreSQL (annoyingly) offers no way to determine which specific row insertion caused the problem, and I anticipate this being a problem that is ongoing. So it's really just not worth it. Stop getScopes from returning an ErrScopeNotFound. Let's return what we find, and let the absence of what we didn't find speak for itself. Fix an error with generating the SQL for the postgres.createScopes call. We used to generate it in a way that was invalid (not joining values with ",") when more than one set of values was supplied. Hooray, testing! Update the postgres scopeStore to return ErrScopeNotFound and ErrScopeAlreadyExists errors, as appropriate. Update our tests to reflect that ErrScopeNotFound and ErrScopeAlreadyExists are now variables, not types.

History
paddy@151 1 package auth
paddy@151 2
paddy@151 3 import (
paddy@151 4 "code.secondbit.org/uuid.hg"
paddy@151 5 "github.com/lib/pq"
paddy@151 6 "github.com/secondbit/pan"
paddy@151 7 )
paddy@151 8
paddy@151 9 func (c Client) GetSQLTableName() string {
paddy@151 10 return "clients"
paddy@151 11 }
paddy@151 12
paddy@151 13 func (e Endpoint) GetSQLTableName() string {
paddy@151 14 return "endpoints"
paddy@151 15 }
paddy@151 16
paddy@151 17 func (p *postgres) getClientSQL(id uuid.ID) *pan.Query {
paddy@151 18 var client Client
paddy@151 19 fields, _ := pan.GetFields(client)
paddy@151 20 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
paddy@151 21 query.IncludeWhere()
paddy@151 22 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", id, false)
paddy@151 23 return query.FlushExpressions(" ")
paddy@151 24 }
paddy@151 25
paddy@151 26 func (p *postgres) getClient(id uuid.ID) (Client, error) {
paddy@151 27 query := p.getClientSQL(id)
paddy@151 28 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 29 if err != nil {
paddy@151 30 return Client{}, err
paddy@151 31 }
paddy@151 32 var client Client
paddy@151 33 var found bool
paddy@151 34 for rows.Next() {
paddy@151 35 err := pan.Unmarshal(rows, &client)
paddy@151 36 if err != nil {
paddy@151 37 return client, err
paddy@151 38 }
paddy@151 39 found = true
paddy@151 40 }
paddy@151 41 if err = rows.Err(); err != nil {
paddy@151 42 return client, err
paddy@151 43 }
paddy@151 44 if !found {
paddy@151 45 return client, ErrClientNotFound
paddy@151 46 }
paddy@151 47 return client, nil
paddy@151 48 }
paddy@151 49
paddy@151 50 func (p *postgres) saveClientSQL(client Client) *pan.Query {
paddy@151 51 fields, values := pan.GetFields(client)
paddy@151 52 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(client))
paddy@151 53 query.Include("(" + pan.QueryList(fields) + ")")
paddy@151 54 query.Include("VALUES")
paddy@151 55 query.Include("("+pan.VariableList(len(values))+")", values...)
paddy@151 56 return query.FlushExpressions(" ")
paddy@151 57 }
paddy@151 58
paddy@151 59 func (p *postgres) saveClient(client Client) error {
paddy@151 60 query := p.saveClientSQL(client)
paddy@151 61 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 62 if e, ok := err.(*pq.Error); ok && e.Constraint == "clients_pkey" {
paddy@151 63 err = ErrClientAlreadyExists
paddy@151 64 }
paddy@151 65 return err
paddy@151 66 }
paddy@151 67
paddy@151 68 func (p *postgres) updateClientSQL(id uuid.ID, change ClientChange) *pan.Query {
paddy@151 69 var client Client
paddy@151 70 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(client)+" SET ")
paddy@151 71 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Secret")+" = ?", change.Secret)
paddy@151 72 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "OwnerID")+" = ?", change.OwnerID)
paddy@151 73 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Name")+" = ?", change.Name)
paddy@151 74 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Logo")+" = ?", change.Logo)
paddy@151 75 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Website")+" = ?", change.Website)
paddy@151 76 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Deleted")+" = ?", change.Deleted)
paddy@151 77 query.FlushExpressions(", ")
paddy@151 78 query.IncludeWhere()
paddy@151 79 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ?", id)
paddy@151 80 return query.FlushExpressions(" ")
paddy@151 81 }
paddy@151 82
paddy@151 83 func (p *postgres) updateClient(id uuid.ID, change ClientChange) error {
paddy@151 84 if change.Empty() {
paddy@151 85 return nil
paddy@151 86 }
paddy@151 87 query := p.updateClientSQL(id, change)
paddy@151 88 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 89 return err
paddy@151 90 }
paddy@151 91
paddy@151 92 func (p *postgres) listClientsByOwnerSQL(ownerID uuid.ID, num, offset int) *pan.Query {
paddy@151 93 var client Client
paddy@151 94 fields, _ := pan.GetFields(client)
paddy@151 95 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
paddy@151 96 query.IncludeWhere()
paddy@151 97 query.Include(pan.GetUnquotedColumn(client, "OwnerID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", ownerID, false)
paddy@151 98 query.IncludeLimit(int64(num))
paddy@151 99 query.IncludeOffset(int64(offset))
paddy@151 100 return query.FlushExpressions(" ")
paddy@151 101 }
paddy@151 102
paddy@151 103 func (p *postgres) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
paddy@151 104 query := p.listClientsByOwnerSQL(ownerID, num, offset)
paddy@151 105 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 106 if err != nil {
paddy@151 107 return []Client{}, err
paddy@151 108 }
paddy@151 109 var clients []Client
paddy@151 110 for rows.Next() {
paddy@151 111 var client Client
paddy@151 112 err = pan.Unmarshal(rows, &client)
paddy@151 113 if err != nil {
paddy@151 114 return clients, err
paddy@151 115 }
paddy@151 116 clients = append(clients, client)
paddy@151 117 }
paddy@151 118 if err = rows.Err(); err != nil {
paddy@151 119 return clients, err
paddy@151 120 }
paddy@151 121 return clients, nil
paddy@151 122 }
paddy@151 123
paddy@151 124 func (p *postgres) addEndpointsSQL(endpoints []Endpoint) *pan.Query {
paddy@151 125 fields, _ := pan.GetFields(endpoints[0])
paddy@151 126 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(endpoints[0]))
paddy@151 127 query.Include("(" + pan.QueryList(fields) + ")")
paddy@151 128 query.Include("VALUES")
paddy@151 129 query.FlushExpressions(" ")
paddy@151 130 for _, endpoint := range endpoints {
paddy@151 131 _, values := pan.GetFields(endpoint)
paddy@151 132 query.Include("("+pan.VariableList(len(values))+")", values...)
paddy@151 133 }
paddy@151 134 return query.FlushExpressions(", ")
paddy@151 135 }
paddy@151 136
paddy@151 137 func (p *postgres) addEndpoints(endpoints []Endpoint) error {
paddy@151 138 if len(endpoints) < 1 {
paddy@151 139 return nil
paddy@151 140 }
paddy@151 141 query := p.addEndpointsSQL(endpoints)
paddy@151 142 _, err := p.db.Exec(query.String(), query.Args...)
paddy@151 143 if e, ok := err.(*pq.Error); ok && e.Constraint == "endpoints_pkey" {
paddy@151 144 return ErrEndpointAlreadyExists
paddy@151 145 }
paddy@151 146 return err
paddy@151 147 }
paddy@151 148
paddy@151 149 func (p *postgres) removeEndpointSQL(client, endpoint uuid.ID) *pan.Query {
paddy@151 150 var e Endpoint
paddy@151 151 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(e))
paddy@151 152 query.IncludeWhere()
paddy@151 153 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ? AND "+pan.GetUnquotedColumn(e, "ClientID")+" = ?", endpoint, client)
paddy@151 154 return query.FlushExpressions(" ")
paddy@151 155 }
paddy@151 156
paddy@151 157 func (p *postgres) removeEndpoint(client, endpoint uuid.ID) error {
paddy@151 158 query := p.removeEndpointSQL(client, endpoint)
paddy@151 159 res, err := p.db.Exec(query.String(), query.Args...)
paddy@151 160 if err != nil {
paddy@151 161 return err
paddy@151 162 }
paddy@151 163 rows, err := res.RowsAffected()
paddy@151 164 if err != nil {
paddy@151 165 return err
paddy@151 166 }
paddy@151 167 if rows == 0 {
paddy@151 168 return ErrEndpointNotFound
paddy@151 169 }
paddy@151 170 return nil
paddy@151 171 }
paddy@151 172
paddy@151 173 func (p *postgres) getEndpointSQL(client, endpoint uuid.ID) *pan.Query {
paddy@151 174 var e Endpoint
paddy@151 175 fields, _ := pan.GetFields(e)
paddy@151 176 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
paddy@151 177 query.IncludeWhere()
paddy@151 178 query.FlushExpressions(" ")
paddy@151 179 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ?", endpoint)
paddy@151 180 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
paddy@151 181 return query.FlushExpressions(" AND ")
paddy@151 182 }
paddy@151 183
paddy@151 184 func (p *postgres) getEndpoint(client, endpoint uuid.ID) (Endpoint, error) {
paddy@151 185 query := p.getEndpointSQL(client, endpoint)
paddy@151 186 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 187 if err != nil {
paddy@151 188 return Endpoint{}, err
paddy@151 189 }
paddy@151 190 var e Endpoint
paddy@151 191 var found bool
paddy@151 192 for rows.Next() {
paddy@151 193 err := pan.Unmarshal(rows, &e)
paddy@151 194 if err != nil {
paddy@151 195 return e, err
paddy@151 196 }
paddy@151 197 found = true
paddy@151 198 }
paddy@151 199 if err = rows.Err(); err != nil {
paddy@151 200 return e, err
paddy@151 201 }
paddy@151 202 if !found {
paddy@151 203 return e, ErrEndpointNotFound
paddy@151 204 }
paddy@151 205 return e, nil
paddy@151 206 }
paddy@151 207
paddy@151 208 func (p *postgres) checkEndpointSQL(client uuid.ID, endpoint string) *pan.Query {
paddy@151 209 var e Endpoint
paddy@151 210 fields, _ := pan.GetFields(e)
paddy@151 211 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
paddy@151 212 query.IncludeWhere()
paddy@151 213 query.FlushExpressions(" ")
paddy@151 214 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
paddy@151 215 query.Include(pan.GetUnquotedColumn(e, "NormalizedURI")+" = ?", endpoint)
paddy@151 216 return query.FlushExpressions(" AND ")
paddy@151 217 }
paddy@151 218
paddy@151 219 func (p *postgres) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
paddy@151 220 query := p.checkEndpointSQL(client, endpoint)
paddy@151 221 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 222 if err != nil {
paddy@151 223 return false, err
paddy@151 224 }
paddy@151 225 var found bool
paddy@151 226 for rows.Next() {
paddy@151 227 found = true
paddy@151 228 }
paddy@151 229 if err = rows.Err(); err != nil {
paddy@151 230 return found, err
paddy@151 231 }
paddy@151 232 return found, nil
paddy@151 233 }
paddy@151 234
paddy@151 235 func (p *postgres) listEndpointsSQL(client uuid.ID, num, offset int) *pan.Query {
paddy@151 236 var endpoint Endpoint
paddy@151 237 fields, _ := pan.GetFields(endpoint)
paddy@151 238 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(endpoint))
paddy@151 239 query.IncludeWhere()
paddy@151 240 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
paddy@151 241 query.IncludeLimit(int64(num))
paddy@151 242 query.IncludeOffset(int64(offset))
paddy@151 243 return query.FlushExpressions(" ")
paddy@151 244 }
paddy@151 245
paddy@151 246 func (p *postgres) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
paddy@151 247 query := p.listEndpointsSQL(client, num, offset)
paddy@151 248 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 249 if err != nil {
paddy@151 250 return []Endpoint{}, err
paddy@151 251 }
paddy@151 252 var endpoints []Endpoint
paddy@151 253 for rows.Next() {
paddy@151 254 var endpoint Endpoint
paddy@151 255 err = pan.Unmarshal(rows, &endpoint)
paddy@151 256 if err != nil {
paddy@151 257 return endpoints, err
paddy@151 258 }
paddy@151 259 endpoints = append(endpoints, endpoint)
paddy@151 260 }
paddy@151 261 if err = rows.Err(); err != nil {
paddy@151 262 return endpoints, err
paddy@151 263 }
paddy@151 264 return endpoints, nil
paddy@151 265 }
paddy@151 266
paddy@151 267 func (p *postgres) countEndpointsSQL(client uuid.ID) *pan.Query {
paddy@151 268 var endpoint Endpoint
paddy@151 269 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM "+pan.GetTableName(endpoint))
paddy@151 270 query.IncludeWhere()
paddy@151 271 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
paddy@151 272 return query.FlushExpressions(" ")
paddy@151 273 }
paddy@151 274
paddy@151 275 func (p *postgres) countEndpoints(client uuid.ID) (int64, error) {
paddy@151 276 query := p.countEndpointsSQL(client)
paddy@151 277 rows, err := p.db.Query(query.String(), query.Args...)
paddy@151 278 if err != nil {
paddy@151 279 return 0, err
paddy@151 280 }
paddy@151 281 var results int64
paddy@151 282 for rows.Next() {
paddy@151 283 err = pan.Unmarshal(rows, &results)
paddy@151 284 if err != nil {
paddy@151 285 return results, err
paddy@151 286 }
paddy@151 287 }
paddy@151 288 if err = rows.Err(); err != nil {
paddy@151 289 return results, err
paddy@151 290 }
paddy@151 291 return results, nil
paddy@151 292 }