ducky/subscriptions
ducky/subscriptions/subscription_postgres.go
Use Stripe's built-in subscriptions. We're switching to Stripe's built-in subscriptions to manage our subscriptions, which required a lot of changes. We now track stripe_subscription instead of stripe_customer, and we also need to track the plan, the status, and whether the user is canceling at the end of the current period. We no longer need to know when to begin charging users (Stripe handles that), but we do track when their trial starts and ends, when their current billing period starts and ends, when they canceled (if they have), the number of failed charge attempts they've had, the time of the last failed charge, and the last time we notified them about billing (to avoid spamming users). We get to delete all the old code about periods, which is nice. We updated our SubscriptionChange type to match. Notably, many of its fields are not user-modifiable, but our Stripe webhook will use them to update our database records and keep them in sync with Stripe. We no longer need to deal with sorting, which is also nice. Our SubscriptionStats have been reworked to be actually useful: we can now track how many users we have, how many of them have failing credit cards, how many are canceling at the end of their current payment period, and how many users are on each plan. We also restructured the TestUpdateSubscription loops to avoid resetting more than necessary. Previously, we had to call store.reset() after every single change iteration; now we call it only when switching stores, which significantly reduces the time it takes to run the tests. Finally, we added a test case for retrieving subscription stats. It's minimal, but it works.
| paddy@1 | 1 package subscriptions |
| paddy@1 | 2 |
| paddy@1 | 3 import ( |
| paddy@2 | 4 "log" |
| paddy@1 | 5 |
| paddy@1 | 6 "code.secondbit.org/uuid.hg" |
| paddy@2 | 7 |
| paddy@1 | 8 "github.com/lib/pq" |
| paddy@1 | 9 "github.com/secondbit/pan" |
| paddy@1 | 10 ) |
| paddy@1 | 11 |
| paddy@1 | 12 // GetSQLTableName fulfills the pan.SQLTableNamer interface, allowing |
| paddy@1 | 13 // us to manipulate Subscriptions with pan. |
| paddy@1 | 14 func (s Subscription) GetSQLTableName() string { |
| paddy@1 | 15 return "subscriptions" |
| paddy@1 | 16 } |
| paddy@1 | 17 |
| paddy@1 | 18 func (p Postgres) resetSQL() *pan.Query { |
| paddy@2 | 19 var subscription Subscription |
| paddy@2 | 20 query := pan.New(pan.POSTGRES, "TRUNCATE "+pan.GetTableName(subscription)) |
| paddy@1 | 21 return query.FlushExpressions(" ") |
| paddy@1 | 22 } |
| paddy@1 | 23 |
| paddy@1 | 24 func (p Postgres) reset() error { |
| paddy@1 | 25 query := p.resetSQL() |
| paddy@1 | 26 _, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 27 if err != nil { |
| paddy@1 | 28 return err |
| paddy@1 | 29 } |
| paddy@1 | 30 return nil |
| paddy@1 | 31 } |
| paddy@1 | 32 |
| paddy@2 | 33 func (p Postgres) createSubscriptionSQL(subscription Subscription) *pan.Query { |
| paddy@2 | 34 fields, values := pan.GetFields(subscription) |
| paddy@2 | 35 query := pan.New(pan.POSTGRES, "INSERT INTO "+pan.GetTableName(subscription)) |
| paddy@1 | 36 query.Include("(" + pan.QueryList(fields) + ")") |
| paddy@1 | 37 query.Include("VALUES") |
| paddy@1 | 38 query.Include("("+pan.VariableList(len(values))+")", values...) |
| paddy@1 | 39 return query.FlushExpressions(" ") |
| paddy@1 | 40 } |
| paddy@1 | 41 |
| paddy@1 | 42 func (p Postgres) createSubscription(sub Subscription) error { |
| paddy@1 | 43 query := p.createSubscriptionSQL(sub) |
| paddy@1 | 44 _, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 45 if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_pkey" { |
| paddy@1 | 46 err = ErrSubscriptionAlreadyExists |
| paddy@2 | 47 } else if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_stripe_subscription_key" { |
| paddy@2 | 48 err = ErrStripeSubscriptionAlreadyExists |
| paddy@1 | 49 } |
| paddy@1 | 50 return err |
| paddy@1 | 51 } |
| paddy@1 | 52 |
| paddy@1 | 53 func (p Postgres) updateSubscriptionSQL(id uuid.ID, change SubscriptionChange) *pan.Query { |
| paddy@2 | 54 var subscription Subscription |
| paddy@2 | 55 query := pan.New(pan.POSTGRES, "UPDATE "+pan.GetTableName(subscription)+" SET") |
| paddy@2 | 56 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "StripeSubscription")+" = ?", change.StripeSubscription) |
| paddy@2 | 57 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Plan")+" = ?", change.Plan) |
| paddy@2 | 58 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Status")+" = ?", change.Status) |
| paddy@2 | 59 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "Canceling")+" = ?", change.Canceling) |
| paddy@2 | 60 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "TrialStart")+" = ?", change.TrialStart) |
| paddy@2 | 61 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "TrialEnd")+" = ?", change.TrialEnd) |
| paddy@2 | 62 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "PeriodStart")+" = ?", change.PeriodStart) |
| paddy@2 | 63 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "PeriodEnd")+" = ?", change.PeriodEnd) |
| paddy@2 | 64 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "CanceledAt")+" = ?", change.CanceledAt) |
| paddy@2 | 65 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "FailedChargeAttempts")+" = ?", change.FailedChargeAttempts) |
| paddy@2 | 66 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "LastFailedCharge")+" = ?", change.LastFailedCharge) |
| paddy@2 | 67 query.IncludeIfNotNil(pan.GetUnquotedColumn(subscription, "LastNotified")+" = ?", change.LastNotified) |
| paddy@1 | 68 query.FlushExpressions(", ") |
| paddy@1 | 69 query.IncludeWhere() |
| paddy@2 | 70 query.Include(pan.GetUnquotedColumn(subscription, "UserID")+" = ?", id) |
| paddy@1 | 71 return query.FlushExpressions(" ") |
| paddy@1 | 72 } |
| paddy@1 | 73 |
| paddy@1 | 74 func (p Postgres) updateSubscription(id uuid.ID, change SubscriptionChange) error { |
| paddy@1 | 75 if change.IsEmpty() { |
| paddy@1 | 76 return ErrSubscriptionChangeEmpty |
| paddy@1 | 77 } |
| paddy@1 | 78 |
| paddy@1 | 79 query := p.updateSubscriptionSQL(id, change) |
| paddy@1 | 80 res, err := p.Exec(query.String(), query.Args...) |
| paddy@2 | 81 if e, ok := err.(*pq.Error); ok && e.Constraint == "subscriptions_stripe_subscription_key" { |
| paddy@2 | 82 return ErrStripeSubscriptionAlreadyExists |
| paddy@1 | 83 } else if err != nil { |
| paddy@1 | 84 return err |
| paddy@1 | 85 } |
| paddy@1 | 86 rows, err := res.RowsAffected() |
| paddy@1 | 87 if err != nil { |
| paddy@1 | 88 return err |
| paddy@1 | 89 } |
| paddy@1 | 90 if rows < 1 { |
| paddy@1 | 91 return ErrSubscriptionNotFound |
| paddy@1 | 92 } |
| paddy@1 | 93 return nil |
| paddy@1 | 94 } |
| paddy@1 | 95 |
| paddy@1 | 96 func (p Postgres) deleteSubscriptionSQL(id uuid.ID) *pan.Query { |
| paddy@2 | 97 var subscription Subscription |
| paddy@2 | 98 query := pan.New(pan.POSTGRES, "DELETE FROM "+pan.GetTableName(subscription)) |
| paddy@1 | 99 query.IncludeWhere() |
| paddy@2 | 100 query.Include(pan.GetUnquotedColumn(subscription, "UserID")+" = ?", id) |
| paddy@1 | 101 return query.FlushExpressions(" ") |
| paddy@1 | 102 } |
| paddy@1 | 103 |
| paddy@1 | 104 func (p Postgres) deleteSubscription(id uuid.ID) error { |
| paddy@1 | 105 query := p.deleteSubscriptionSQL(id) |
| paddy@1 | 106 res, err := p.Exec(query.String(), query.Args...) |
| paddy@1 | 107 if err != nil { |
| paddy@1 | 108 return err |
| paddy@1 | 109 } |
| paddy@1 | 110 rows, err := res.RowsAffected() |
| paddy@1 | 111 if err != nil { |
| paddy@1 | 112 return err |
| paddy@1 | 113 } |
| paddy@1 | 114 if rows < 1 { |
| paddy@1 | 115 return ErrSubscriptionNotFound |
| paddy@1 | 116 } |
| paddy@1 | 117 return nil |
| paddy@1 | 118 } |
| paddy@1 | 119 |
| paddy@1 | 120 func (p Postgres) getSubscriptionsSQL(ids []uuid.ID) *pan.Query { |
| paddy@2 | 121 var subscription Subscription |
| paddy@2 | 122 fields, _ := pan.GetFields(subscription) |
| paddy@1 | 123 intIDs := make([]interface{}, len(ids)) |
| paddy@1 | 124 for pos, id := range ids { |
| paddy@1 | 125 intIDs[pos] = id |
| paddy@1 | 126 } |
| paddy@2 | 127 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM "+pan.GetTableName(subscription)) |
| paddy@1 | 128 query.IncludeWhere() |
| paddy@2 | 129 query.Include(pan.GetUnquotedColumn(subscription, "UserID") + " IN") |
| paddy@1 | 130 query.Include("("+pan.VariableList(len(intIDs))+")", intIDs...) |
| paddy@1 | 131 return query.FlushExpressions(" ") |
| paddy@1 | 132 } |
| paddy@1 | 133 |
| paddy@1 | 134 func (p Postgres) getSubscriptions(ids []uuid.ID) (map[string]Subscription, error) { |
| paddy@1 | 135 results := map[string]Subscription{} |
| paddy@1 | 136 if len(ids) < 1 { |
| paddy@1 | 137 return results, ErrNoSubscriptionID |
| paddy@1 | 138 } |
| paddy@1 | 139 query := p.getSubscriptionsSQL(ids) |
| paddy@1 | 140 rows, err := p.Query(query.String(), query.Args...) |
| paddy@1 | 141 if err != nil { |
| paddy@1 | 142 return results, err |
| paddy@1 | 143 } |
| paddy@1 | 144 for rows.Next() { |
| paddy@2 | 145 var subscription Subscription |
| paddy@2 | 146 err := pan.Unmarshal(rows, &subscription) |
| paddy@1 | 147 if err != nil { |
| paddy@1 | 148 return results, err |
| paddy@1 | 149 } |
| paddy@2 | 150 results[subscription.UserID.String()] = subscription |
| paddy@1 | 151 } |
| paddy@1 | 152 if err := rows.Err(); err != nil { |
| paddy@1 | 153 return results, err |
| paddy@1 | 154 } |
| paddy@1 | 155 return results, nil |
| paddy@1 | 156 } |
| paddy@1 | 157 |
| paddy@2 | 158 func (p Postgres) getSubscriptionStatsCountSQL() *pan.Query { |
| paddy@2 | 159 var subscription Subscription |
| paddy@2 | 160 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 161 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 162 return query.FlushExpressions(" ") |
| paddy@2 | 163 } |
| paddy@2 | 164 |
| paddy@2 | 165 func (p Postgres) getSubscriptionStatsCancelingSQL() *pan.Query { |
| paddy@2 | 166 var subscription Subscription |
| paddy@2 | 167 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 168 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 169 query.IncludeWhere() |
| paddy@2 | 170 query.Include(pan.GetUnquotedColumn(subscription, "Canceling")+" = ?", true) |
| paddy@2 | 171 return query.FlushExpressions(" ") |
| paddy@2 | 172 } |
| paddy@2 | 173 |
| paddy@2 | 174 func (p Postgres) getSubscriptionStatsFailingSQL() *pan.Query { |
| paddy@2 | 175 var subscription Subscription |
| paddy@2 | 176 query := pan.New(pan.POSTGRES, "SELECT COUNT(*) FROM") |
| paddy@2 | 177 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 178 query.IncludeWhere() |
| paddy@2 | 179 statuses := []interface{}{"past_due", "unpaid"} |
| paddy@2 | 180 query.Include(pan.GetUnquotedColumn(subscription, "Status")+" IN ("+pan.VariableList(len(statuses))+")", statuses...) |
| paddy@2 | 181 return query.FlushExpressions(" ") |
| paddy@2 | 182 } |
| paddy@2 | 183 |
| paddy@2 | 184 func (p Postgres) getSubscriptionStatsPlansSQL() *pan.Query { |
| paddy@2 | 185 var subscription Subscription |
| paddy@2 | 186 fields := []interface{}{pan.GetUnquotedColumn(subscription, "Plan"), "COUNT(*)"} |
| paddy@2 | 187 query := pan.New(pan.POSTGRES, "SELECT "+pan.QueryList(fields)+" FROM") |
| paddy@2 | 188 query.Include(pan.GetTableName(subscription)) |
| paddy@2 | 189 query.Include("GROUP BY " + pan.GetUnquotedColumn(subscription, "Plan")) |
| paddy@1 | 190 return query.FlushExpressions(" ") |
| paddy@1 | 191 } |
| paddy@1 | 192 |
| paddy@1 | 193 func (p Postgres) getSubscriptionStats() (SubscriptionStats, error) { |
| paddy@2 | 194 stats := SubscriptionStats{ |
| paddy@2 | 195 Plans: map[string]int64{}, |
| paddy@2 | 196 } |
| paddy@2 | 197 query := p.getSubscriptionStatsCountSQL() |
| paddy@2 | 198 err := p.QueryRow(query.String(), query.Args...).Scan(&stats.Number) |
| paddy@2 | 199 if err != nil { |
| paddy@2 | 200 log.Printf("Error querying for total subscriptions: %+v\n", err) |
| paddy@2 | 201 return stats, err |
| paddy@2 | 202 } |
| paddy@2 | 203 query = p.getSubscriptionStatsCancelingSQL() |
| paddy@2 | 204 err = p.QueryRow(query.String(), query.Args...).Scan(&stats.Canceling) |
| paddy@2 | 205 if err != nil { |
| paddy@2 | 206 log.Printf("Error querying for canceling subscriptions: %+v\n", err) |
| paddy@2 | 207 return stats, err |
| paddy@2 | 208 } |
| paddy@2 | 209 query = p.getSubscriptionStatsFailingSQL() |
| paddy@2 | 210 err = p.QueryRow(query.String(), query.Args...).Scan(&stats.Failing) |
| paddy@2 | 211 if err != nil { |
| paddy@2 | 212 log.Printf("Error querying for failing subscriptions: %+v\n", err) |
| paddy@2 | 213 return stats, err |
| paddy@2 | 214 } |
| paddy@2 | 215 query = p.getSubscriptionStatsPlansSQL() |
| paddy@1 | 216 rows, err := p.Query(query.String(), query.Args...) |
| paddy@1 | 217 if err != nil { |
| paddy@2 | 218 log.Printf("Error querying for plans: %+v\n", err) |
| paddy@2 | 219 return stats, err |
| paddy@1 | 220 } |
| paddy@1 | 221 for rows.Next() { |
| paddy@2 | 222 var plan string |
| paddy@2 | 223 var count int64 |
| paddy@2 | 224 err := rows.Scan(&plan, &count) |
| paddy@2 | 225 if err != nil { |
| paddy@2 | 226 log.Printf("Error scanning database row for plans: %+v\n", err) |
| paddy@2 | 227 continue |
| paddy@1 | 228 } |
| paddy@2 | 229 stats.Plans[plan] = count |
| paddy@1 | 230 } |
| paddy@1 | 231 if err := rows.Err(); err != nil { |
| paddy@2 | 232 log.Printf("Error querying for plans: %+v\n", err) |
| paddy@1 | 233 return stats, err |
| paddy@1 | 234 } |
| paddy@1 | 235 return stats, nil |
| paddy@1 | 236 } |