auth

Paddy 2015-04-11 Parent:77db7c65216c Child:cf1aef6eb81f

162:6f473576c6ae Go to Latest

auth/client_postgres.go

Clean up sessions and tokens after Profile is deleted. Add a terminateSessionsByProfile method to our sessionStore to mark Sessions associated with a Profile as inactive. Implement memstore and postgres implementations of the terminateSessionsByProfile method. Add a TerminateSessionsByProfile wrapper method to Context. Add a revokeTokensByProfileID method to our tokenStore to mark Tokens associated with a Profile as revoked. Implement memstore and postgres implementations of the revokeTokensByProfileID method. Add a RevokeTokensByProfileID wrapper method to Context. Call our RevokeTokensByProfileID and TerminateSessionsByProfile methods after a Profile is deleted, to clean up the Tokens and Sessions associated with it.

History
1 package auth
3 import (
4 "code.secondbit.org/uuid.hg"
5 "github.com/lib/pq"
6 "github.com/secondbit/pan"
7 )
9 func (c Client) GetSQLTableName() string {
10 return "clients"
11 }
13 func (e Endpoint) GetSQLTableName() string {
14 return "endpoints"
15 }
17 func (p *postgres) getClientSQL(id uuid.ID) *pan.Query {
18 var client Client
19 fields, _ := pan.GetFields(client)
20 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
21 query.IncludeWhere()
22 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", id, false)
23 return query.FlushExpressions(" ")
24 }
26 func (p *postgres) getClient(id uuid.ID) (Client, error) {
27 query := p.getClientSQL(id)
28 rows, err := p.db.Query(query.String(), query.Args...)
29 if err != nil {
30 return Client{}, err
31 }
32 var client Client
33 var found bool
34 for rows.Next() {
35 err := pan.Unmarshal(rows, &client)
36 if err != nil {
37 return client, err
38 }
39 found = true
40 }
41 if err = rows.Err(); err != nil {
42 return client, err
43 }
44 if !found {
45 return client, ErrClientNotFound
46 }
47 return client, nil
48 }
50 func (p *postgres) saveClientSQL(client Client) *pan.Query {
51 fields, values := pan.GetFields(client)
52 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(client))
53 query.Include("(" + pan.QueryList(fields) + ")")
54 query.Include("VALUES")
55 query.Include("("+pan.VariableList(len(values))+")", values...)
56 return query.FlushExpressions(" ")
57 }
59 func (p *postgres) saveClient(client Client) error {
60 query := p.saveClientSQL(client)
61 _, err := p.db.Exec(query.String(), query.Args...)
62 if e, ok := err.(*pq.Error); ok && e.Constraint == "clients_pkey" {
63 err = ErrClientAlreadyExists
64 }
65 return err
66 }
68 func (p *postgres) updateClientSQL(id uuid.ID, change ClientChange) *pan.Query {
69 var client Client
70 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(client)+" SET ")
71 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Secret")+" = ?", change.Secret)
72 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "OwnerID")+" = ?", change.OwnerID)
73 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Name")+" = ?", change.Name)
74 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Logo")+" = ?", change.Logo)
75 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Website")+" = ?", change.Website)
76 query.IncludeIfNotNil(pan.GetUnquotedColumn(client, "Deleted")+" = ?", change.Deleted)
77 query.FlushExpressions(", ")
78 query.IncludeWhere()
79 query.Include(pan.GetUnquotedColumn(client, "ID")+" = ?", id)
80 return query.FlushExpressions(" ")
81 }
83 func (p *postgres) updateClient(id uuid.ID, change ClientChange) error {
84 if change.Empty() {
85 return nil
86 }
87 query := p.updateClientSQL(id, change)
88 _, err := p.db.Exec(query.String(), query.Args...)
89 return err
90 }
92 func (p *postgres) listClientsByOwnerSQL(ownerID uuid.ID, num, offset int) *pan.Query {
93 var client Client
94 fields, _ := pan.GetFields(client)
95 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(client))
96 query.IncludeWhere()
97 query.Include(pan.GetUnquotedColumn(client, "OwnerID")+" = ? AND "+pan.GetUnquotedColumn(client, "Deleted")+" = ?", ownerID, false)
98 query.IncludeLimit(int64(num))
99 query.IncludeOffset(int64(offset))
100 return query.FlushExpressions(" ")
101 }
103 func (p *postgres) listClientsByOwner(ownerID uuid.ID, num, offset int) ([]Client, error) {
104 query := p.listClientsByOwnerSQL(ownerID, num, offset)
105 rows, err := p.db.Query(query.String(), query.Args...)
106 if err != nil {
107 return []Client{}, err
108 }
109 var clients []Client
110 for rows.Next() {
111 var client Client
112 err = pan.Unmarshal(rows, &client)
113 if err != nil {
114 return clients, err
115 }
116 clients = append(clients, client)
117 }
118 if err = rows.Err(); err != nil {
119 return clients, err
120 }
121 return clients, nil
122 }
124 func (p *postgres) addEndpointsSQL(endpoints []Endpoint) *pan.Query {
125 fields, _ := pan.GetFields(endpoints[0])
126 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(endpoints[0]))
127 query.Include("(" + pan.QueryList(fields) + ")")
128 query.Include("VALUES")
129 query.FlushExpressions(" ")
130 for _, endpoint := range endpoints {
131 _, values := pan.GetFields(endpoint)
132 query.Include("("+pan.VariableList(len(values))+")", values...)
133 }
134 return query.FlushExpressions(", ")
135 }
137 func (p *postgres) addEndpoints(endpoints []Endpoint) error {
138 if len(endpoints) < 1 {
139 return nil
140 }
141 query := p.addEndpointsSQL(endpoints)
142 _, err := p.db.Exec(query.String(), query.Args...)
143 if e, ok := err.(*pq.Error); ok && e.Constraint == "endpoints_pkey" {
144 return ErrEndpointAlreadyExists
145 }
146 return err
147 }
149 func (p *postgres) removeEndpointSQL(client, endpoint uuid.ID) *pan.Query {
150 var e Endpoint
151 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(e))
152 query.IncludeWhere()
153 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ? AND "+pan.GetUnquotedColumn(e, "ClientID")+" = ?", endpoint, client)
154 return query.FlushExpressions(" ")
155 }
157 func (p *postgres) removeEndpoint(client, endpoint uuid.ID) error {
158 query := p.removeEndpointSQL(client, endpoint)
159 res, err := p.db.Exec(query.String(), query.Args...)
160 if err != nil {
161 return err
162 }
163 rows, err := res.RowsAffected()
164 if err != nil {
165 return err
166 }
167 if rows == 0 {
168 return ErrEndpointNotFound
169 }
170 return nil
171 }
173 func (p *postgres) getEndpointSQL(client, endpoint uuid.ID) *pan.Query {
174 var e Endpoint
175 fields, _ := pan.GetFields(e)
176 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
177 query.IncludeWhere()
178 query.FlushExpressions(" ")
179 query.Include(pan.GetUnquotedColumn(e, "ID")+" = ?", endpoint)
180 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
181 return query.FlushExpressions(" AND ")
182 }
184 func (p *postgres) getEndpoint(client, endpoint uuid.ID) (Endpoint, error) {
185 query := p.getEndpointSQL(client, endpoint)
186 rows, err := p.db.Query(query.String(), query.Args...)
187 if err != nil {
188 return Endpoint{}, err
189 }
190 var e Endpoint
191 var found bool
192 for rows.Next() {
193 err := pan.Unmarshal(rows, &e)
194 if err != nil {
195 return e, err
196 }
197 found = true
198 }
199 if err = rows.Err(); err != nil {
200 return e, err
201 }
202 if !found {
203 return e, ErrEndpointNotFound
204 }
205 return e, nil
206 }
208 func (p *postgres) checkEndpointSQL(client uuid.ID, endpoint string) *pan.Query {
209 var e Endpoint
210 fields, _ := pan.GetFields(e)
211 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(e))
212 query.IncludeWhere()
213 query.FlushExpressions(" ")
214 query.Include(pan.GetUnquotedColumn(e, "ClientID")+" = ?", client)
215 query.Include(pan.GetUnquotedColumn(e, "NormalizedURI")+" = ?", endpoint)
216 return query.FlushExpressions(" AND ")
217 }
219 func (p *postgres) checkEndpoint(client uuid.ID, endpoint string) (bool, error) {
220 query := p.checkEndpointSQL(client, endpoint)
221 rows, err := p.db.Query(query.String(), query.Args...)
222 if err != nil {
223 return false, err
224 }
225 var found bool
226 for rows.Next() {
227 found = true
228 }
229 if err = rows.Err(); err != nil {
230 return found, err
231 }
232 return found, nil
233 }
235 func (p *postgres) listEndpointsSQL(client uuid.ID, num, offset int) *pan.Query {
236 var endpoint Endpoint
237 fields, _ := pan.GetFields(endpoint)
238 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(endpoint))
239 query.IncludeWhere()
240 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
241 query.IncludeLimit(int64(num))
242 query.IncludeOffset(int64(offset))
243 return query.FlushExpressions(" ")
244 }
246 func (p *postgres) listEndpoints(client uuid.ID, num, offset int) ([]Endpoint, error) {
247 query := p.listEndpointsSQL(client, num, offset)
248 rows, err := p.db.Query(query.String(), query.Args...)
249 if err != nil {
250 return []Endpoint{}, err
251 }
252 var endpoints []Endpoint
253 for rows.Next() {
254 var endpoint Endpoint
255 err = pan.Unmarshal(rows, &endpoint)
256 if err != nil {
257 return endpoints, err
258 }
259 endpoints = append(endpoints, endpoint)
260 }
261 if err = rows.Err(); err != nil {
262 return endpoints, err
263 }
264 return endpoints, nil
265 }
267 func (p *postgres) countEndpointsSQL(client uuid.ID) *pan.Query {
268 var endpoint Endpoint
269 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM "+pan.GetTableName(endpoint))
270 query.IncludeWhere()
271 query.Include(pan.GetUnquotedColumn(endpoint, "ClientID")+" = ?", client)
272 return query.FlushExpressions(" ")
273 }
275 func (p *postgres) countEndpoints(client uuid.ID) (int64, error) {
276 query := p.countEndpointsSQL(client)
277 rows, err := p.db.Query(query.String(), query.Args...)
278 if err != nil {
279 return 0, err
280 }
281 var results int64
282 for rows.Next() {
283 err = pan.Unmarshal(rows, &results)
284 if err != nil {
285 return results, err
286 }
287 }
288 if err = rows.Err(); err != nil {
289 return results, err
290 }
291 return results, nil
292 }