ducky/subscriptions
ducky/subscriptions/subscription_postgres.go
Export all subscriptionStore methods. We're not going to wrap all our subscriptionStore interactions in a Context type, so we need to expose all the functions so other packages can call them. Also, it's now SubscriptionStore. Also creates a SubscriptionRequest type that is used to create a new Subscription instance, along with a Validate method on it, to detect errors when creating a SubscriptionRequest. Update our CreateSubscription function to be New, and have it take a SubscriptionRequest, a Stripe instance, and a SubscriptionStore as arguments. It'll create the Subscription in the SubscriptionStore, create a customer in Stripe, and create the subscription in Stripe, then associate the Stripe subscription with the created Subscription. Also, fix our StripeSubscriptionChange function to correctly equate a missing/zero Unix timestamp from Stripe with a missing/zero time.Time.
| paddy@1 | 1 package subscriptions |
| paddy@1 | 2 |
| paddy@1 | 3 import ( |
| paddy@2 | 4 "log" |
| paddy@1 | 5 |
| paddy@1 | 6 "code.secondbit.org/uuid.hg" |
| paddy@2 | 7 |
| paddy@1 | 8 "github.com/lib/pq" |
| paddy@1 | 9 "github.com/secondbit/pan" |
| paddy@1 | 10 ) |
| paddy@1 | 11 |
| paddy@1 | 12 // GetSQLTableName fulfills the pan.SQLTableNamer interface, allowing |
| paddy@1 | 13 // us to manipulate Subscriptions with pan. |
| paddy@1 | 14 func (s Subscription) GetSQLTableName() string { |
| paddy@1 | 15 return "subscriptions" |
| paddy@1 | 16 } |
| paddy@1 | 17 |
| paddy@1 | 18 func (p Postgres) resetSQL() *pan.Query { |
| paddy@2 | 19 var subscription Subscription |
| paddy@2 | 20 query := pan.New(pan.POSTGRES, "TRUNCATE "+pan.GetTableName(subscription)) |
| paddy@1 | 21 return query.FlushExpressions(" ") |
| paddy@1 | 22 } |
| paddy@1 | 23 |
| paddy@3 | 24 func (p Postgres) Reset() error { |
| paddy@1 | 25 query := p.resetSQL() |
| paddy@1 | 26 _, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 27 if err != nil { |
| paddy@1 | 28 return err |
| paddy@1 | 29 } |
| paddy@1 | 30 return nil |
| paddy@1 | 31 } |
| paddy@1 | 32 |
| paddy@2 | 33 func (p Postgres) createSubscriptionSQL(subscription Subscription) *pan.Query { |
| paddy@2 | 34 fields, values := pan.GetFields(subscription) |
| paddy@2 | 35 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(subscription)) |
| paddy@1 | 36 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@1 | 37 query.Include("VALUES") |
| paddy@1 | 38 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@1 | 39 return query.FlushExpressions(" ") |
| paddy@1 | 40 } |
| paddy@1 | 41 |
| paddy@3 | 42 func (p Postgres) CreateSubscription(sub Subscription) error { |
| paddy@1 | 43 query := p.createSubscriptionSQL(sub) |
| paddy@1 | 44 _, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 45 if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_pkey" { |
| paddy@1 | 46 err = ErrSubscriptionAlreadyExists |
| paddy@2 | 47 } else if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_stripe_subscription_key" { |
| paddy@2 | 48 err = ErrStripeSubscriptionAlreadyExists |
| paddy@1 | 49 } |
| paddy@1 | 50 return err |
| paddy@1 | 51 } |
| paddy@1 | 52 |
| paddy@1 | 53 func (p Postgres) updateSubscriptionSQL(id uuid.ID, change SubscriptionChange) *pan.Query { |
| paddy@2 | 54 var subscription Subscription |
| paddy@2 | 55 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(subscription)+" SET") |
| paddy@2 | 56 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "StripeSubscription")+" = ?", change.StripeSubscription) |
| paddy@2 | 57 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Plan")+" = ?", change.Plan) |
| paddy@2 | 58 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Status")+" = ?", change.Status) |
| paddy@2 | 59 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Canceling")+" = ?", change.Canceling) |
| paddy@2 | 60 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "TrialStart")+" = ?", change.TrialStart) |
| paddy@2 | 61 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "TrialEnd")+" = ?", change.TrialEnd) |
| paddy@2 | 62 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "PeriodStart")+" = ?", change.PeriodStart) |
| paddy@2 | 63 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "PeriodEnd")+" = ?", change.PeriodEnd) |
| paddy@2 | 64 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "CanceledAt")+" = ?", change.CanceledAt) |
| paddy@2 | 65 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "FailedChargeAttempts")+" = ?", change.FailedChargeAttempts) |
| paddy@2 | 66 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "LastFailedCharge")+" = ?", change.LastFailedCharge) |
| paddy@2 | 67 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "LastNotified")+" = ?", change.LastNotified) |
| paddy@1 | 68 query.FlushExpressions(", ") |
| paddy@1 | 69 query.IncludeWhere() |
| paddy@2 | 70 query.Include(pan.GetUnquotedColumn(subscription, "UserID")+" = ?", id) |
| paddy@1 | 71 return query.FlushExpressions(" ") |
| paddy@1 | 72 } |
| paddy@1 | 73 |
| paddy@3 | 74 func (p Postgres) UpdateSubscription(id uuid.ID, change SubscriptionChange) error { |
| paddy@1 | 75 if change.IsEmpty() { |
| paddy@1 | 76 return ErrSubscriptionChangeEmpty |
| paddy@1 | 77 } |
| paddy@1 | 78 |
| paddy@1 | 79 query := p.updateSubscriptionSQL(id, change) |
| paddy@1 | 80 res, err := p.Exec(query.String(), query.Args...) |
| paddy@2 | 81 if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_stripe_subscription_key" { |
| paddy@2 | 82 return ErrStripeSubscriptionAlreadyExists |
| paddy@1 | 83 } else if err != nil { |
| paddy@1 | 84 return err |
| paddy@1 | 85 } |
| paddy@1 | 86 rows, err := res.RowsAffected() |
| paddy@1 | 87 if err != nil { |
| paddy@1 | 88 return err |
| paddy@1 | 89 } |
| paddy@1 | 90 if rows < 1 { |
| paddy@1 | 91 return ErrSubscriptionNotFound |
| paddy@1 | 92 } |
| paddy@1 | 93 return nil |
| paddy@1 | 94 } |
| paddy@1 | 95 |
| paddy@1 | 96 func (p Postgres) deleteSubscriptionSQL(id uuid.ID) *pan.Query { |
| paddy@2 | 97 var subscription Subscription |
| paddy@2 | 98 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(subscription)) |
| paddy@1 | 99 query.IncludeWhere() |
| paddy@2 | 100 query.Include(pan.GetUnquotedColumn(subscription, "UserID")+" = ?", id) |
| paddy@1 | 101 return query.FlushExpressions(" ") |
| paddy@1 | 102 } |
| paddy@1 | 103 |
| paddy@3 | 104 func (p Postgres) DeleteSubscription(id uuid.ID) error { |
| paddy@1 | 105 query := p.deleteSubscriptionSQL(id) |
| paddy@1 | 106 res, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 107 if err != nil { |
| paddy@1 | 108 return err |
| paddy@1 | 109 } |
| paddy@1 | 110 rows, err := res.RowsAffected() |
| paddy@1 | 111 if err != nil { |
| paddy@1 | 112 return err |
| paddy@1 | 113 } |
| paddy@1 | 114 if rows < 1 { |
| paddy@1 | 115 return ErrSubscriptionNotFound |
| paddy@1 | 116 } |
| paddy@1 | 117 return nil |
| paddy@1 | 118 } |
| paddy@1 | 119 |
| paddy@1 | 120 func (p Postgres) getSubscriptionsSQL(ids []uuid.ID) *pan.Query { |
| paddy@2 | 121 var subscription Subscription |
| paddy@2 | 122 fields, _ := pan.GetFields(subscription) |
| paddy@1 | 123 intIDs := make([]interface{}, len(ids)) |
| paddy@1 | 124 for pos, id := range ids { |
| paddy@1 | 125 intIDs[pos] = id |
| paddy@1 | 126 } |
| paddy@2 | 127 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(subscription)) |
| paddy@1 | 128 query.IncludeWhere() |
| paddy@2 | 129 query.Include(pan.GetUnquotedColumn(subscription, "UserID") + " IN") |
| paddy@1 | 130 query.Include("("+pan.VariableList(len(intIDs))+")", intIDs...) |
| paddy@1 | 131 return query.FlushExpressions(" ") |
| paddy@1 | 132 } |
| paddy@1 | 133 |
| paddy@3 | 134 func (p Postgres) GetSubscriptions(ids []uuid.ID) (map[string]Subscription, error) { |
| paddy@1 | 135 results := map[string]Subscription{} |
| paddy@1 | 136 if len(ids) < 1 { |
| paddy@1 | 137 return results, ErrNoSubscriptionID |
| paddy@1 | 138 } |
| paddy@1 | 139 query := p.getSubscriptionsSQL(ids) |
| paddy@1 | 140 rows, err := p.Query(query.String(), query.Args...) |
| paddy@1 | 141 if err != nil { |
| paddy@1 | 142 return results, err |
| paddy@1 | 143 } |
| paddy@1 | 144 for rows.Next() { |
| paddy@2 | 145 var subscription Subscription |
| paddy@2 | 146 err := pan.Unmarshal(rows, &subscription) |
| paddy@1 | 147 if err != nil { |
| paddy@1 | 148 return results, err |
| paddy@1 | 149 } |
| paddy@2 | 150 results[subscription.UserID.String()] = subscription |
| paddy@1 | 151 } |
| paddy@1 | 152 if err := rows.Err(); err != nil { |
| paddy@1 | 153 return results, err |
| paddy@1 | 154 } |
| paddy@1 | 155 return results, nil |
| paddy@1 | 156 } |
| paddy@1 | 157 |
| paddy@2 | 158 func (p Postgres) getSubscriptionStatsCountSQL() *pan.Query { |
| paddy@2 | 159 var subscription Subscription |
| paddy@2 | 160 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 161 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 162 return query.FlushExpressions(" ") |
| paddy@2 | 163 } |
| paddy@2 | 164 |
| paddy@2 | 165 func (p Postgres) getSubscriptionStatsCancelingSQL() *pan.Query { |
| paddy@2 | 166 var subscription Subscription |
| paddy@2 | 167 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 168 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 169 query.IncludeWhere() |
| paddy@2 | 170 query.Include(pan.GetUnquotedColumn(subscription, "Canceling")+" = ?", true) |
| paddy@2 | 171 return query.FlushExpressions(" ") |
| paddy@2 | 172 } |
| paddy@2 | 173 |
| paddy@2 | 174 func (p Postgres) getSubscriptionStatsFailingSQL() *pan.Query { |
| paddy@2 | 175 var subscription Subscription |
| paddy@2 | 176 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 177 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 178 query.IncludeWhere() |
| paddy@2 | 179 statuses := []interface{}{"past_due", "unpaid"} |
| paddy@2 | 180 query.Include(pan.GetUnquotedColumn(subscription, "Status")+" IN ("+pan.VariableList(len(statuses))+")", statuses...) |
| paddy@2 | 181 return query.FlushExpressions(" ") |
| paddy@2 | 182 } |
| paddy@2 | 183 |
| paddy@2 | 184 func (p Postgres) getSubscriptionStatsPlansSQL() *pan.Query { |
| paddy@2 | 185 var subscription Subscription |
| paddy@2 | 186 fields := []interface{}{pan.GetUnquotedColumn(subscription, "Plan"), "COUNT(*)"} |
| paddy@2 | 187 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM") |
| paddy@2 | 188 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 189 query.Include("GROUP BY " + pan.GetUnquotedColumn(subscription, "Plan")) |
| paddy@1 | 190 return query.FlushExpressions(" ") |
| paddy@1 | 191 } |
| paddy@1 | 192 |
| paddy@3 | 193 func (p Postgres) GetSubscriptionStats() (SubscriptionStats, error) { |
| paddy@2 | 194 stats := SubscriptionStats{ |
| paddy@2 | 195 Plans: map[string]int64{}, |
| paddy@2 | 196 } |
| paddy@2 | 197 query := p.getSubscriptionStatsCountSQL() |
| paddy@2 | 198 err := p.QueryRow(query.String(), query.Args...).Scan(&stats.Number) |
| paddy@2 | 199 if err != nil { |
| paddy@2 | 200 log.Printf("Error querying for total subscriptions: %+v\n", err) |
| paddy@2 | 201 return stats, err |
| paddy@2 | 202 } |
| paddy@2 | 203 query = p.getSubscriptionStatsCancelingSQL() |
| paddy@2 | 204 err = p.QueryRow(query.String(), query.Args...).Scan(&stats.Canceling) |
| paddy@2 | 205 if err != nil { |
| paddy@2 | 206 log.Printf("Error querying for canceling subscriptions: %+v\n", err) |
| paddy@2 | 207 return stats, err |
| paddy@2 | 208 } |
| paddy@2 | 209 query = p.getSubscriptionStatsFailingSQL() |
| paddy@2 | 210 err = p.QueryRow(query.String(), query.Args...).Scan(&stats.Failing) |
| paddy@2 | 211 if err != nil { |
| paddy@2 | 212 log.Printf("Error querying for failing subscriptions: %+v\n", err) |
| paddy@2 | 213 return stats, err |
| paddy@2 | 214 } |
| paddy@2 | 215 query = p.getSubscriptionStatsPlansSQL() |
| paddy@1 | 216 rows, err := p.Query(query.String(), query.Args...) |
| paddy@1 | 217 if err != nil { |
| paddy@2 | 218 log.Printf("Error querying for plans: %+v\n", err) |
| paddy@2 | 219 return stats, err |
| paddy@1 | 220 } |
| paddy@1 | 221 for rows.Next() { |
| paddy@2 | 222 var plan string |
| paddy@2 | 223 var count int64 |
| paddy@2 | 224 err := rows.Scan(&plan, &count) |
| paddy@2 | 225 if err != nil { |
| paddy@2 | 226 log.Printf("Error scanning database row for plans: %+v\n", err) |
| paddy@2 | 227 continue |
| paddy@1 | 228 } |
| paddy@2 | 229 stats.Plans[plan] = count |
| paddy@1 | 230 } |
| paddy@1 | 231 if err := rows.Err(); err != nil { |
| paddy@2 | 232 log.Printf("Error querying for plans: %+v\n", err) |
| paddy@1 | 233 return stats, err |
| paddy@1 | 234 } |
| paddy@1 | 235 return stats, nil |
| paddy@1 | 236 } |