auth
auth/client_postgres.go
Implement postgres clientStore. Stop requiring the client ID be passed to clientStore.addEndpoints and context.AddEndpoints. The Endpoints themselves contain the client ID. When using the authd server, set the log flags to include the file path and line number. Add an ErrEndpointAlreadyExists error, to return when creating an endpoint and its ID already exists in the database. Add a Deleted property to Clients and remove the clientStore.deleteClient and context.DeleteClient methods. We're not going to actually remove that data, and we want to be able to restore it, so include it in the ClientChange type and call it using UpdateClient. Create a ClientChange.Empty helper method that will return whether the ClientChange has any changes to perform. Return ErrClientNotFound from clientStore.getClient if the Client's Deleted property is set to true. This also requires us to ignore ErrClientNotFound errors when calling memstore.listClientsByOwner, as they should just be skipped instead of returning an error. Add the postgres type methods needed to implement clientStore. Include postgres as a clientStore if the testing.Short() flag is not set. Generate a new ID for the Client on every run in the tests, now that we can't actually remove it from the database/memstore in code. We really just need a *Store.Reset() function that erases all the data and starts over again, to give the tests a clean execution environment (and so they can clean up after themselves). Add the CREATE TABLE statements for the Clients table and the Endpoints table to sql/postgres_init.sql.
1.1 --- /dev/null Thu Jan 01 00:00:00 1970 +0000 1.2 +++ b/client_postgres.go Sun Mar 22 16:26:37 2015 -0400 1.3 @@ -0,0 +1,292 @@ 1.4 +package auth 1.5 + 1.6 +import ( 1.7 + "code.secondbit.org/uuid.hg" 1.8 + "github.com/lib/pq" 1.9 + "github.com/secondbit/pan" 1.10 +) 1.11 + 1.12 +func (c Client) GetSQLTableName() string { 1.13 + return "clients" 1.14 +} 1.15 + 1.16 +func (e Endpoint) GetSQLTableName() string { 1.17 + return "endpoints" 1.18 +} 1.19 + 1.20 +func (p *postgres) getClientSQL(id uuid.ID) *pan.Query { 1.21 + var client Client 1.22 + fields, _ := pan.GetFields(client) 1.23 + query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client)) 1.24 + query.IncludeWhere() 1.25 + query.Include(pan.GetUnquotedColumn(client, "ID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", id, false) 1.26 + return query.FlushExpressions(" ") 1.27 +} 1.28 + 1.29 +func (p *postgres) getClient(id uuid.ID) (Client, error) { 1.30 + query := p.getClientSQL(id) 1.31 + rows, err := p.db.Query(query.String(), query.Args...) 1.32 + if err != nil { 1.33 + return Client{}, err 1.34 + } 1.35 + var client Client 1.36 + var found bool 1.37 + for rows.Next() { 1.38 + err := pan.Unmarshal(rows, &client) 1.39 + if err != nil { 1.40 + return client, err 1.41 + } 1.42 + found = true 1.43 + } 1.44 + if err = rows.Err(); err != nil { 1.45 + return client, err 1.46 + } 1.47 + if !found { 1.48 + return client, ErrClientNotFound 1.49 + } 1.50 + return client, nil 1.51 +} 1.52 + 1.53 +func (p *postgres) saveClientSQL(client Client) *pan.Query { 1.54 + fields, values := pan.GetFields(client) 1.55 + query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(client)) 1.56 + query.Include("(" + pan.QueryList(fields) + ")") 1.57 + query.Include("VALUES") 1.58 + query.Include("("+pan.VariableList(len(values))+")", values...) 1.59 + return query.FlushExpressions(" ") 1.60 +} 1.61 + 1.62 +func (p *postgres) saveClient(client Client) error { 1.63 + query := p.saveClientSQL(client) 1.64 + _, err := p.db.Exec(query.String(), query.Args...) 1.65 + if e, ok := err.(*pq.Error); ok && e.Constraint == "clients_pkey" { 1.66 + err = ErrClientAlreadyExists 1.67 + } 1.68 + return err 1.69 +} 1.70 + 1.71 +func (p *postgres) updateClientSQL(id uuid.ID, change ClientChange) *pan.Query { 1.72 + var client Client 1.73 + query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(client)+" SET ") 1.74 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Secret")+" = ?", change.Secret) 1.75 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "OwnerID")+" = ?", change.OwnerID) 1.76 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Name")+" = ?", change.Name) 1.77 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Logo")+" = ?", change.Logo) 1.78 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Website")+" = ?", change.Website) 1.79 + query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Deleted")+" = ?", change.Deleted) 1.80 + query.FlushExpressions(", ") 1.81 + query.IncludeWhere() 1.82 + query.Include(pan.GetUnquotedColumn(client, "ID")+" = ?", id) 1.83 + return query.FlushExpressions(" ") 1.84 +} 1.85 + 1.86 +func (p *postgres) updateClient(id uuid.ID, change ClientChange) error { 1.87 + if change.Empty() { 1.88 + return nil 1.89 + } 1.90 + query := p.updateClientSQL(id, change) 1.91 + _, err := p.db.Exec(query.String(), query.Args...) 1.92 + return err 1.93 +} 1.94 + 1.95 +func (p *postgres) listClientsByOwnerSQL(ownerID uuid.ID, num, offset int) *pan.Query { 1.96 + var client Client 1.97 + fields, _ := pan.GetFields(client) 1.98 + query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client)) 1.99 + query.IncludeWhere() 1.100 + query.Include(pan.GetUnquotedColumn(client, "OwnerID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", ownerID, false) 1.101 + query.IncludeLimit(int64(num)) 1.102 + query.IncludeOffset(int64(offset)) 1.103 + return query.FlushExpressions(" ") 1.104 +} 1.105 + 1.106 +func (p *postgres) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) { 1.107 + query := p.listClientsByOwnerSQL(ownerID, num, offset) 1.108 + rows, err := p.db.Query(query.String(), query.Args...) 1.109 + if err != nil { 1.110 + return []Client{}, err 1.111 + } 1.112 + var clients []Client 1.113 + for rows.Next() { 1.114 + var client Client 1.115 + err = pan.Unmarshal(rows, &client) 1.116 + if err != nil { 1.117 + return clients, err 1.118 + } 1.119 + clients = append(clients, client) 1.120 + } 1.121 + if err = rows.Err(); err != nil { 1.122 + return clients, err 1.123 + } 1.124 + return clients, nil 1.125 +} 1.126 + 1.127 +func (p *postgres) addEndpointsSQL(endpoints []Endpoint) *pan.Query { 1.128 + fields, _ := pan.GetFields(endpoints[0]) 1.129 + query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(endpoints[0])) 1.130 + query.Include("(" + pan.QueryList(fields) + ")") 1.131 + query.Include("VALUES") 1.132 + query.FlushExpressions(" ") 1.133 + for _, endpoint := range endpoints { 1.134 + _, values := pan.GetFields(endpoint) 1.135 + query.Include("("+pan.VariableList(len(values))+")", values...) 1.136 + } 1.137 + return query.FlushExpressions(", ") 1.138 +} 1.139 + 1.140 +func (p *postgres) addEndpoints(endpoints []Endpoint) error { 1.141 + if len(endpoints) < 1 { 1.142 + return nil 1.143 + } 1.144 + query := p.addEndpointsSQL(endpoints) 1.145 + _, err := p.db.Exec(query.String(), query.Args...) 1.146 + if e, ok := err.(*pq.Error); ok && e.Constraint == "endpoints_pkey" { 1.147 + return ErrEndpointAlreadyExists 1.148 + } 1.149 + return err 1.150 +} 1.151 + 1.152 +func (p *postgres) removeEndpointSQL(client, endpoint uuid.ID) *pan.Query { 1.153 + var e Endpoint 1.154 + query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(e)) 1.155 + query.IncludeWhere() 1.156 + query.Include(pan.GetUnquotedColumn(e, "ID")+" = ? AND "+pan.GetUnquotedColumn(e, "ClientID")+" = ?", endpoint, client) 1.157 + return query.FlushExpressions(" ") 1.158 +} 1.159 + 1.160 +func (p *postgres) removeEndpoint(client, endpoint uuid.ID) error { 1.161 + query := p.removeEndpointSQL(client, endpoint) 1.162 + res, err := p.db.Exec(query.String(), query.Args...) 1.163 + if err != nil { 1.164 + return err 1.165 + } 1.166 + rows, err := res.RowsAffected() 1.167 + if err != nil { 1.168 + return err 1.169 + } 1.170 + if rows == 0 { 1.171 + return ErrEndpointNotFound 1.172 + } 1.173 + return nil 1.174 +} 1.175 + 1.176 +func (p *postgres) getEndpointSQL(client, endpoint uuid.ID) *pan.Query { 1.177 + var e Endpoint 1.178 + fields, _ := pan.GetFields(e) 1.179 + query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e)) 1.180 + query.IncludeWhere() 1.181 + query.FlushExpressions(" ") 1.182 + query.Include(pan.GetUnquotedColumn(e, "ID")+" = ?", endpoint) 1.183 + query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client) 1.184 + return query.FlushExpressions(" AND ") 1.185 +} 1.186 + 1.187 +func (p *postgres) getEndpoint(client, endpoint uuid.ID) (Endpoint, error) { 1.188 + query := p.getEndpointSQL(client, endpoint) 1.189 + rows, err := p.db.Query(query.String(), query.Args...) 1.190 + if err != nil { 1.191 + return Endpoint{}, err 1.192 + } 1.193 + var e Endpoint 1.194 + var found bool 1.195 + for rows.Next() { 1.196 + err := pan.Unmarshal(rows, &e) 1.197 + if err != nil { 1.198 + return e, err 1.199 + } 1.200 + found = true 1.201 + } 1.202 + if err = rows.Err(); err != nil { 1.203 + return e, err 1.204 + } 1.205 + if !found { 1.206 + return e, ErrEndpointNotFound 1.207 + } 1.208 + return e, nil 1.209 +} 1.210 + 1.211 +func (p *postgres) checkEndpointSQL(client uuid.ID, endpoint string) *pan.Query { 1.212 + var e Endpoint 1.213 + fields, _ := pan.GetFields(e) 1.214 + query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e)) 1.215 + query.IncludeWhere() 1.216 + query.FlushExpressions(" ") 1.217 + query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client) 1.218 + query.Include(pan.GetUnquotedColumn(e, "NormalizedURI")+" = ?", endpoint) 1.219 + return query.FlushExpressions(" AND ") 1.220 +} 1.221 + 1.222 +func (p *postgres) checkEndpoint(client uuid.ID, endpoint string) (bool, error) { 1.223 + query := p.checkEndpointSQL(client, endpoint) 1.224 + rows, err := p.db.Query(query.String(), query.Args...) 1.225 + if err != nil { 1.226 + return false, err 1.227 + } 1.228 + var found bool 1.229 + for rows.Next() { 1.230 + found = true 1.231 + } 1.232 + if err = rows.Err(); err != nil { 1.233 + return found, err 1.234 + } 1.235 + return found, nil 1.236 +} 1.237 + 1.238 +func (p *postgres) listEndpointsSQL(client uuid.ID, num, offset int) *pan.Query { 1.239 + var endpoint Endpoint 1.240 + fields, _ := pan.GetFields(endpoint) 1.241 + query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(endpoint)) 1.242 + query.IncludeWhere() 1.243 + query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client) 1.244 + query.IncludeLimit(int64(num)) 1.245 + query.IncludeOffset(int64(offset)) 1.246 + return query.FlushExpressions(" ") 1.247 +} 1.248 + 1.249 +func (p *postgres) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) { 1.250 + query := p.listEndpointsSQL(client, num, offset) 1.251 + rows, err := p.db.Query(query.String(), query.Args...) 1.252 + if err != nil { 1.253 + return []Endpoint{}, err 1.254 + } 1.255 + var endpoints []Endpoint 1.256 + for rows.Next() { 1.257 + var endpoint Endpoint 1.258 + err = pan.Unmarshal(rows, &endpoint) 1.259 + if err != nil { 1.260 + return endpoints, err 1.261 + } 1.262 + endpoints = append(endpoints, endpoint) 1.263 + } 1.264 + if err = rows.Err(); err != nil { 1.265 + return endpoints, err 1.266 + } 1.267 + return endpoints, nil 1.268 +} 1.269 + 1.270 +func (p *postgres) countEndpointsSQL(client uuid.ID) *pan.Query { 1.271 + var endpoint Endpoint 1.272 + query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM "+pan.GetTableName(endpoint)) 1.273 + query.IncludeWhere() 1.274 + query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client) 1.275 + return query.FlushExpressions(" ") 1.276 +} 1.277 + 1.278 +func (p *postgres) countEndpoints(client uuid.ID) (int64, error) { 1.279 + query := p.countEndpointsSQL(client) 1.280 + rows, err := p.db.Query(query.String(), query.Args...) 1.281 + if err != nil { 1.282 + return 0, err 1.283 + } 1.284 + var results int64 1.285 + for rows.Next() { 1.286 + err = pan.Unmarshal(rows, &results) 1.287 + if err != nil { 1.288 + return results, err 1.289 + } 1.290 + } 1.291 + if err = rows.Err(); err != nil { 1.292 + return results, err 1.293 + } 1.294 + return results, nil 1.295 +}