auth
auth/authcode_postgres.go
Clean up sessions and tokens after a Profile is deleted. Add a terminateSessionsByProfile method to our sessionStore to mark Sessions associated with a Profile as inactive. Implement memstore and postgres implementations of the terminateSessionsByProfile method. Add a TerminateSessionsByProfile wrapper method to Context. Add a revokeTokensByProfileID method to our tokenStore to mark Tokens associated with a Profile as revoked. Implement memstore and postgres implementations of the revokeTokensByProfileID method. Add a RevokeTokensByProfileID wrapper method to Context. Call our RevokeTokensByProfileID and TerminateSessionsByProfile methods after a Profile is deleted, to clean up the Tokens and Sessions associated with it.
| paddy@156 | 1 package auth |
| paddy@156 | 2 |
| paddy@156 | 3 import ( |
| paddy@156 | 4 "github.com/lib/pq" |
| paddy@156 | 5 "github.com/secondbit/pan" |
| paddy@156 | 6 ) |
| paddy@156 | 7 |
// authCodeScope is a single row of the authorization_codes_scopes table:
// it pairs an authorization code with one scope granted to it.
type authCodeScope struct {
	Code, Scope string
}

// GetSQLTableName names the table authCodeScope rows live in; it is
// presumably what pan.GetTableName consults when building queries.
func (authCodeScope) GetSQLTableName() string {
	const table = "authorization_codes_scopes"
	return table
}
| paddy@156 | 16 |
| paddy@156 | 17 func (ac AuthorizationCode) GetSQLTableName() string { |
| paddy@156 | 18 return "authorization_codes" |
| paddy@156 | 19 } |
| paddy@156 | 20 |
| paddy@156 | 21 func (p *postgres) getAuthorizationCodeSQL(code string) *pan.Query { |
| paddy@156 | 22 var ac AuthorizationCode |
| paddy@156 | 23 fields, _ := pan.GetFields(ac) |
| paddy@156 | 24 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(ac)) |
| paddy@156 | 25 query.IncludeWhere() |
| paddy@156 | 26 query.Include(pan.GetUnquotedColumn(ac, "Code")+" = ?", code) |
| paddy@156 | 27 return query.FlushExpressions(" ") |
| paddy@156 | 28 } |
| paddy@156 | 29 |
| paddy@156 | 30 func (p *postgres) getAuthorizationCodeScopesSQL(codes []string) *pan.Query { |
| paddy@156 | 31 var acs authCodeScope |
| paddy@156 | 32 fields, _ := pan.GetFields(acs) |
| paddy@156 | 33 codesI := make([]interface{}, len(codes)) |
| paddy@156 | 34 for pos, code := range codes { |
| paddy@156 | 35 codesI[pos] = code |
| paddy@156 | 36 } |
| paddy@156 | 37 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(acs)) |
| paddy@156 | 38 query.IncludeWhere() |
| paddy@156 | 39 query.Include(pan.GetUnquotedColumn(acs, "Code")+" IN ("+pan.VariableList(len(codesI))+")", codesI...) |
| paddy@156 | 40 return query.FlushExpressions(" ") |
| paddy@156 | 41 } |
| paddy@156 | 42 |
| paddy@156 | 43 func (p *postgres) getAuthorizationCode(code string) (AuthorizationCode, error) { |
| paddy@156 | 44 query := p.getAuthorizationCodeSQL(code) |
| paddy@156 | 45 rows, err := p.db.Query(query.String(), query.Args...) |
| paddy@156 | 46 if err != nil { |
| paddy@156 | 47 return AuthorizationCode{}, err |
| paddy@156 | 48 } |
| paddy@156 | 49 var ac AuthorizationCode |
| paddy@156 | 50 var found bool |
| paddy@156 | 51 for rows.Next() { |
| paddy@156 | 52 err := pan.Unmarshal(rows, &ac) |
| paddy@156 | 53 if err != nil { |
| paddy@156 | 54 return ac, err |
| paddy@156 | 55 } |
| paddy@156 | 56 found = true |
| paddy@156 | 57 } |
| paddy@156 | 58 if err = rows.Err(); err != nil { |
| paddy@156 | 59 return ac, err |
| paddy@156 | 60 } |
| paddy@156 | 61 if !found { |
| paddy@156 | 62 return ac, ErrAuthorizationCodeNotFound |
| paddy@156 | 63 } |
| paddy@156 | 64 query = p.getAuthorizationCodeScopesSQL([]string{code}) |
| paddy@156 | 65 rows, err = p.db.Query(query.String(), query.Args...) |
| paddy@156 | 66 if err != nil { |
| paddy@156 | 67 return ac, err |
| paddy@156 | 68 } |
| paddy@156 | 69 for rows.Next() { |
| paddy@156 | 70 var acs authCodeScope |
| paddy@156 | 71 err = pan.Unmarshal(rows, &acs) |
| paddy@156 | 72 if err != nil { |
| paddy@156 | 73 return ac, err |
| paddy@156 | 74 } |
| paddy@156 | 75 ac.Scopes = append(ac.Scopes, acs.Scope) |
| paddy@156 | 76 } |
| paddy@156 | 77 if err = rows.Err(); err != nil { |
| paddy@156 | 78 return ac, err |
| paddy@156 | 79 } |
| paddy@156 | 80 return ac, nil |
| paddy@156 | 81 } |
| paddy@156 | 82 |
| paddy@156 | 83 func (p *postgres) saveAuthorizationCodeSQL(authCode AuthorizationCode) *pan.Query { |
| paddy@156 | 84 fields, values := pan.GetFields(authCode) |
| paddy@156 | 85 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(authCode)) |
| paddy@156 | 86 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@156 | 87 query.Include("VALUES") |
| paddy@156 | 88 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@156 | 89 return query.FlushExpressions(" ") |
| paddy@156 | 90 } |
| paddy@156 | 91 |
| paddy@156 | 92 func (p *postgres) saveAuthorizationCodeScopesSQL(authCodeScopes []authCodeScope) *pan.Query { |
| paddy@156 | 93 fields, _ := pan.GetFields(authCodeScopes[0]) |
| paddy@156 | 94 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(authCodeScopes[0])) |
| paddy@156 | 95 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@156 | 96 query.Include("VALUES") |
| paddy@156 | 97 query.FlushExpressions(" ") |
| paddy@156 | 98 for _, acs := range authCodeScopes { |
| paddy@156 | 99 _, values := pan.GetFields(acs) |
| paddy@156 | 100 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@156 | 101 } |
| paddy@156 | 102 return query.FlushExpressions(", ") |
| paddy@156 | 103 } |
| paddy@156 | 104 |
| paddy@156 | 105 func (p *postgres) saveAuthorizationCode(authCode AuthorizationCode) error { |
| paddy@156 | 106 query := p.saveAuthorizationCodeSQL(authCode) |
| paddy@156 | 107 _, err := p.db.Exec(query.String(), query.Args...) |
| paddy@156 | 108 if e, ok := err.(*pq.Error); ok && e.Constraint == "authorization_codes_pkey" { |
| paddy@156 | 109 err = ErrAuthorizationCodeAlreadyExists |
| paddy@156 | 110 } |
| paddy@156 | 111 if err != nil || len(authCode.Scopes) < 1 { |
| paddy@156 | 112 return err |
| paddy@156 | 113 } |
| paddy@156 | 114 var acs []authCodeScope |
| paddy@156 | 115 for _, scope := range authCode.Scopes { |
| paddy@156 | 116 acs = append(acs, authCodeScope{Code: authCode.Code, Scope: scope}) |
| paddy@156 | 117 } |
| paddy@156 | 118 query = p.saveAuthorizationCodeScopesSQL(acs) |
| paddy@156 | 119 _, err = p.db.Exec(query.String(), query.Args...) |
| paddy@156 | 120 return err |
| paddy@156 | 121 } |
| paddy@156 | 122 |
| paddy@156 | 123 func (p *postgres) deleteAuthorizationCodeSQL(code string) *pan.Query { |
| paddy@156 | 124 var authCode AuthorizationCode |
| paddy@156 | 125 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(authCode)) |
| paddy@156 | 126 query.IncludeWhere() |
| paddy@156 | 127 query.Include(pan.GetUnquotedColumn(authCode, "Code")+" = ?", code) |
| paddy@156 | 128 return query.FlushExpressions(" ") |
| paddy@156 | 129 } |
| paddy@156 | 130 |
| paddy@156 | 131 func (p *postgres) deleteAuthorizationCodeScopesSQL(code string) *pan.Query { |
| paddy@156 | 132 var acs authCodeScope |
| paddy@156 | 133 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(acs)) |
| paddy@156 | 134 query.IncludeWhere() |
| paddy@156 | 135 query.Include(pan.GetUnquotedColumn(acs, "Code")+" = ?", code) |
| paddy@156 | 136 return query.FlushExpressions(" ") |
| paddy@156 | 137 } |
| paddy@156 | 138 |
| paddy@156 | 139 func (p *postgres) deleteAuthorizationCode(code string) error { |
| paddy@156 | 140 query := p.deleteAuthorizationCodeSQL(code) |
| paddy@156 | 141 res, err := p.db.Exec(query.String(), query.Args...) |
| paddy@156 | 142 if err != nil { |
| paddy@156 | 143 return err |
| paddy@156 | 144 } |
| paddy@156 | 145 rows, err := res.RowsAffected() |
| paddy@156 | 146 if err != nil { |
| paddy@156 | 147 return err |
| paddy@156 | 148 } |
| paddy@156 | 149 if rows == 0 { |
| paddy@156 | 150 return ErrAuthorizationCodeNotFound |
| paddy@156 | 151 } |
| paddy@156 | 152 query = p.deleteAuthorizationCodeScopesSQL(code) |
| paddy@156 | 153 _, err = p.db.Exec(query.String(), query.Args...) |
| paddy@156 | 154 return err |
| paddy@156 | 155 } |
| paddy@156 | 156 |
| paddy@156 | 157 func (p *postgres) useAuthorizationCodeSQL(code string) *pan.Query { |
| paddy@156 | 158 var authCode AuthorizationCode |
| paddy@156 | 159 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(authCode)+" SET ") |
| paddy@156 | 160 query.Include(pan.GetUnquotedColumn(authCode, "Used")+" = ?", true) |
| paddy@156 | 161 query.IncludeWhere() |
| paddy@156 | 162 query.Include(pan.GetUnquotedColumn(authCode, "Code")+" = ?", code) |
| paddy@156 | 163 return query.FlushExpressions(" ") |
| paddy@156 | 164 } |
| paddy@156 | 165 |
| paddy@156 | 166 func (p *postgres) useAuthorizationCode(code string) error { |
| paddy@156 | 167 query := p.useAuthorizationCodeSQL(code) |
| paddy@156 | 168 res, err := p.db.Exec(query.String(), query.Args...) |
| paddy@156 | 169 if err != nil { |
| paddy@156 | 170 return err |
| paddy@156 | 171 } |
| paddy@156 | 172 rows, err := res.RowsAffected() |
| paddy@156 | 173 if err != nil { |
| paddy@156 | 174 return err |
| paddy@156 | 175 } |
| paddy@156 | 176 if rows == 0 { |
| paddy@156 | 177 return ErrAuthorizationCodeNotFound |
| paddy@156 | 178 } |
| paddy@156 | 179 return nil |
| paddy@156 | 180 } |