Implement a Postgres version of the tokenStore.
Create a Postgres implementation of the tokenStore. Note that because pq
doesn't support Postgres's array types (see https://github.com/lib/pq/issues/49),
we couldn't simply store the token.Scopes field as a Postgres array of varchars,
which would have been ideal. Instead, we use a many-to-many table that maps
tokens to scopes. This meant adding a special tokenScope type for our database
mapping and complicating the token storage/retrieval functions a little. It's
kind of ugly, I'm not a huge fan of it, and I'd much rather be using Postgres
array types, but... well, here we are.
We also added the Postgres tokenStore to the slice of tokenStores under test,
so it is tested when the required environment variables are present.
We wrote initialization SQL for the tables the Postgres tokenStore requires.
We also added a helper script for emptying the test database, because I got
tired of doing it by hand. Ideally the tests themselves would empty the database
automatically, but that would mean extending the *Store interfaces.
4 "code.secondbit.org/uuid.hg"
6 "github.com/secondbit/pan"
9 func (c Client) GetSQLTableName() string {
13 func (e Endpoint) GetSQLTableName() string {
17 func (p *postgres) getClientSQL(id uuid.ID) *pan.Query {
19 fields, _ := pan.GetFields(client)
20 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
22 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", id, false)
23 return query.FlushExpressions(" ")
26 func (p *postgres) getClient(id uuid.ID) (Client, error) {
27 query := p.getClientSQL(id)
28 rows, err := p.db.Query(query.String(), query.Args...)
35 err := pan.Unmarshal(rows, &client)
41 if err = rows.Err(); err != nil {
45 return client, ErrClientNotFound
50 func (p *postgres) saveClientSQL(client Client) *pan.Query {
51 fields, values := pan.GetFields(client)
52 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(client))
53 query.Include("(" + pan.QueryList(fields) + ")")
54 query.Include("VALUES")
55 query.Include("("+pan.VariableList(len(values))+")", values...)
56 return query.FlushExpressions(" ")
59 func (p *postgres) saveClient(client Client) error {
60 query := p.saveClientSQL(client)
61 _, err := p.db.Exec(query.String(), query.Args...)
62 if e, ok := err.(*pq.Error); ok && e.Constraint == "clients_pkey" {
63 err = ErrClientAlreadyExists
68 func (p *postgres) updateClientSQL(id uuid.ID, change ClientChange) *pan.Query {
70 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(client)+" SET ")
71 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Secret")+" = ?", change.Secret)
72 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "OwnerID")+" = ?", change.OwnerID)
73 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Name")+" = ?", change.Name)
74 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Logo")+" = ?", change.Logo)
75 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Website")+" = ?", change.Website)
76 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Deleted")+" = ?", change.Deleted)
77 query.FlushExpressions(", ")
79 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ?", id)
80 return query.FlushExpressions(" ")
83 func (p *postgres) updateClient(id uuid.ID, change ClientChange) error {
87 query := p.updateClientSQL(id, change)
88 _, err := p.db.Exec(query.String(), query.Args...)
92 func (p *postgres) listClientsByOwnerSQL(ownerID uuid.ID, num, offset int) *pan.Query {
94 fields, _ := pan.GetFields(client)
95 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
97 query.Include(pan.GetUnquotedColumn(client, "OwnerID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", ownerID, false)
98 query.IncludeLimit(int64(num))
99 query.IncludeOffset(int64(offset))
100 return query.FlushExpressions(" ")
103 func (p *postgres) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
104 query := p.listClientsByOwnerSQL(ownerID, num, offset)
105 rows, err := p.db.Query(query.String(), query.Args...)
107 return []Client{}, err
112 err = pan.Unmarshal(rows, &client)
116 clients = append(clients, client)
118 if err = rows.Err(); err != nil {
124 func (p *postgres) addEndpointsSQL(endpoints []Endpoint) *pan.Query {
125 fields, _ := pan.GetFields(endpoints[0])
126 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(endpoints[0]))
127 query.Include("(" + pan.QueryList(fields) + ")")
128 query.Include("VALUES")
129 query.FlushExpressions(" ")
130 for _, endpoint := range endpoints {
131 _, values := pan.GetFields(endpoint)
132 query.Include("("+pan.VariableList(len(values))+")", values...)
134 return query.FlushExpressions(", ")
137 func (p *postgres) addEndpoints(endpoints []Endpoint) error {
138 if len(endpoints) < 1 {
141 query := p.addEndpointsSQL(endpoints)
142 _, err := p.db.Exec(query.String(), query.Args...)
143 if e, ok := err.(*pq.Error); ok && e.Constraint == "endpoints_pkey" {
144 return ErrEndpointAlreadyExists
149 func (p *postgres) removeEndpointSQL(client, endpoint uuid.ID) *pan.Query {
151 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(e))
153 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ? AND "+pan.GetUnquotedColumn(e, "ClientID")+" = ?", endpoint, client)
154 return query.FlushExpressions(" ")
157 func (p *postgres) removeEndpoint(client, endpoint uuid.ID) error {
158 query := p.removeEndpointSQL(client, endpoint)
159 res, err := p.db.Exec(query.String(), query.Args...)
163 rows, err := res.RowsAffected()
168 return ErrEndpointNotFound
173 func (p *postgres) getEndpointSQL(client, endpoint uuid.ID) *pan.Query {
175 fields, _ := pan.GetFields(e)
176 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
178 query.FlushExpressions(" ")
179 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ?", endpoint)
180 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
181 return query.FlushExpressions(" AND ")
184 func (p *postgres) getEndpoint(client, endpoint uuid.ID) (Endpoint, error) {
185 query := p.getEndpointSQL(client, endpoint)
186 rows, err := p.db.Query(query.String(), query.Args...)
188 return Endpoint{}, err
193 err := pan.Unmarshal(rows, &e)
199 if err = rows.Err(); err != nil {
203 return e, ErrEndpointNotFound
208 func (p *postgres) checkEndpointSQL(client uuid.ID, endpoint string) *pan.Query {
210 fields, _ := pan.GetFields(e)
211 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
213 query.FlushExpressions(" ")
214 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
215 query.Include(pan.GetUnquotedColumn(e, "NormalizedURI")+" = ?", endpoint)
216 return query.FlushExpressions(" AND ")
219 func (p *postgres) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
220 query := p.checkEndpointSQL(client, endpoint)
221 rows, err := p.db.Query(query.String(), query.Args...)
229 if err = rows.Err(); err != nil {
235 func (p *postgres) listEndpointsSQL(client uuid.ID, num, offset int) *pan.Query {
236 var endpoint Endpoint
237 fields, _ := pan.GetFields(endpoint)
238 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(endpoint))
240 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
241 query.IncludeLimit(int64(num))
242 query.IncludeOffset(int64(offset))
243 return query.FlushExpressions(" ")
246 func (p *postgres) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
247 query := p.listEndpointsSQL(client, num, offset)
248 rows, err := p.db.Query(query.String(), query.Args...)
250 return []Endpoint{}, err
252 var endpoints []Endpoint
254 var endpoint Endpoint
255 err = pan.Unmarshal(rows, &endpoint)
257 return endpoints, err
259 endpoints = append(endpoints, endpoint)
261 if err = rows.Err(); err != nil {
262 return endpoints, err
264 return endpoints, nil
267 func (p *postgres) countEndpointsSQL(client uuid.ID) *pan.Query {
268 var endpoint Endpoint
269 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM "+pan.GetTableName(endpoint))
271 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
272 return query.FlushExpressions(" ")
275 func (p *postgres) countEndpoints(client uuid.ID) (int64, error) {
276 query := p.countEndpointsSQL(client)
277 rows, err := p.db.Query(query.String(), query.Args...)
283 err = pan.Unmarshal(rows, &results)
288 if err = rows.Err(); err != nil {