Export all subscriptionStore methods.
We're not going to wrap all our subscriptionStore interactions in a Context
type, so we need to export all of its methods so other packages can call them.
Also, rename the subscriptionStore type to SubscriptionStore.
Also, add a SubscriptionRequest type that is used to create a new
Subscription instance, along with a Validate method on it that detects errors
in a SubscriptionRequest before it is used to create a Subscription.
Rename our CreateSubscription function to New, and have it take a
SubscriptionRequest, a Stripe instance, and a SubscriptionStore as arguments.
New creates the Subscription in the SubscriptionStore, creates a customer in
Stripe, creates the subscription in Stripe, and then associates the Stripe
subscription with the created Subscription.
Also, fix our StripeSubscriptionChange function so that a missing or zero
Unix timestamp from Stripe correctly maps to a missing (zero) time.Time.
9 "code.secondbit.org/uuid.hg"
13 subscriptionChangeStripeSubscription = 1 << iota
14 subscriptionChangePlan
15 subscriptionChangeStatus
16 subscriptionChangeCanceling
17 subscriptionChangeTrialStart
18 subscriptionChangeTrialEnd
19 subscriptionChangePeriodStart
20 subscriptionChangePeriodEnd
21 subscriptionChangeCanceledAt
22 subscriptionChangeFailedChargeAttempts
23 subscriptionChangeLastFailedCharge
24 subscriptionChangeLastNotified
28 if os.Getenv("PG_TEST_DB") != "" {
29 p, err := NewPostgres(os.Getenv("PG_TEST_DB"))
33 testSubscriptionStores = append(testSubscriptionStores, p)
37 var testSubscriptionStores = []SubscriptionStore{
41 func compareSubscriptions(sub1, sub2 Subscription) (bool, string, interface{}, interface{}) {
42 if !sub1.UserID.Equal(sub2.UserID) {
43 return false, "UserID", sub1.UserID, sub2.UserID
45 if sub1.StripeSubscription != sub2.StripeSubscription {
46 return false, "StripeSubscription", sub1.StripeSubscription, sub2.StripeSubscription
48 if sub1.Plan != sub2.Plan {
49 return false, "Plan", sub1.Plan, sub2.Plan
51 if sub1.Status != sub2.Status {
52 return false, "Status", sub1.Status, sub2.Status
54 if sub1.Canceling != sub2.Canceling {
55 return false, "Canceling", sub1.Canceling, sub2.Canceling
57 if !sub1.Created.Equal(sub2.Created) {
58 return false, "Created", sub1.Created, sub2.Created
60 if !sub1.TrialStart.Equal(sub2.TrialStart) {
61 return false, "TrialStart", sub1.TrialStart, sub2.TrialStart
63 if !sub1.TrialEnd.Equal(sub2.TrialEnd) {
64 return false, "TrialEnd", sub1.TrialEnd, sub2.TrialEnd
66 if !sub1.PeriodStart.Equal(sub2.PeriodStart) {
67 return false, "PeriodStart", sub1.PeriodStart, sub2.PeriodStart
69 if !sub1.PeriodEnd.Equal(sub2.PeriodEnd) {
70 return false, "PeriodEnd", sub1.PeriodEnd, sub2.PeriodEnd
72 if !sub1.CanceledAt.Equal(sub2.CanceledAt) {
73 return false, "CanceledAt", sub1.CanceledAt, sub2.CanceledAt
75 if sub1.FailedChargeAttempts != sub2.FailedChargeAttempts {
76 return false, "FailedChargeAttempts", sub1.FailedChargeAttempts, sub2.FailedChargeAttempts
78 if !sub1.LastFailedCharge.Equal(sub2.LastFailedCharge) {
79 return false, "LastFailedCharge", sub1.LastFailedCharge, sub2.LastFailedCharge
81 if !sub1.LastNotified.Equal(sub2.LastNotified) {
82 return false, "LastNotified", sub1.LastNotified, sub2.LastNotified
84 return true, "", nil, nil
87 func subscriptionMapContains(subscriptionMap map[string]Subscription, subscriptions ...Subscription) (bool, []Subscription) {
88 var missing []Subscription
89 for _, sub := range subscriptions {
90 if _, ok := subscriptionMap[sub.UserID.String()]; !ok {
91 missing = append(missing, sub)
100 func compareSubscriptionStats(stat1, stat2 SubscriptionStats) (bool, string, interface{}, interface{}) {
101 if stat1.Number != stat2.Number {
102 return false, "Number", stat1.Number, stat2.Number
104 if stat1.Canceling != stat2.Canceling {
105 return false, "Canceling", stat1.Canceling, stat2.Canceling
107 if stat1.Failing != stat2.Failing {
108 return false, "Failing", stat1.Failing, stat2.Failing
110 if len(stat1.Plans) != len(stat2.Plans) {
111 return false, "Plans", stat1.Plans, stat2.Plans
113 for key, count := range stat1.Plans {
114 count2, ok := stat2.Plans[key]
116 return false, "Plans", stat1.Plans, stat2.Plans
119 return false, "Plans", stat1.Plans, stat2.Plans
122 return true, "", nil, nil
125 func TestCreateSubscription(t *testing.T) {
126 for _, store := range testSubscriptionStores {
129 t.Fatalf("Error resetting %T: %+v\n", store, err)
131 customerID := uuid.NewID()
134 StripeSubscription: "stripeSubscription1",
135 Created: time.Now().Round(time.Millisecond),
136 TrialStart: time.Now().Round(time.Millisecond),
137 TrialEnd: time.Now().Round(time.Millisecond).Add(time.Hour * 24 * 31),
139 err = store.CreateSubscription(sub)
141 t.Errorf("Error creating subscription in %T: %+v\n", store, err)
143 retrieved, err := store.GetSubscriptions([]uuid.ID{sub.UserID})
145 t.Errorf("Error retrieving subscription from %T: %+v\n", store, err)
147 if _, returned := retrieved[sub.UserID.String()]; !returned {
148 t.Errorf("Error retrieving subscription from %T: %s wasn't in the results.", store, sub.UserID)
150 ok, field, expected, result := compareSubscriptions(sub, retrieved[sub.UserID.String()])
152 t.Errorf("Expected %s to be %v, got %v from %T\n", field, expected, result, store)
154 err = store.CreateSubscription(sub)
155 if err != ErrSubscriptionAlreadyExists {
156 t.Errorf("Unexpected error creating subscription in %T (wanted %+v): %+v\n", store, ErrSubscriptionAlreadyExists, err)
158 sub.UserID = uuid.NewID()
159 err = store.CreateSubscription(sub)
160 if err != ErrStripeSubscriptionAlreadyExists {
161 t.Errorf("Unexpected error creating subscription in %T (wanted %+v): %#+v\n", store, ErrStripeSubscriptionAlreadyExists, err)
163 sub.StripeSubscription = "stripeSubscription2"
164 err = store.CreateSubscription(sub)
166 t.Errorf("Error creating subscription in %T: %+v\n", store, err)
171 func TestUpdateSubscription(t *testing.T) {
172 variations := 1 << 12
174 UserID: uuid.NewID(),
175 StripeSubscription: "default",
176 Created: time.Now().Round(time.Millisecond).Add(time.Hour * -24 * -32),
177 TrialStart: time.Now().Round(time.Millisecond).Add(time.Hour * -24 * -32),
178 TrialEnd: time.Now().Round(time.Millisecond).Add(time.Hour * -24),
179 LastNotified: time.Now().Round(time.Millisecond).Add(time.Hour * -24),
181 sub2 := Subscription{
182 UserID: uuid.NewID(),
183 StripeSubscription: "stripeSubscription2",
184 Created: time.Now().Round(time.Millisecond),
185 TrialStart: time.Now().Round(time.Millisecond),
186 TrialEnd: time.Now().Round(time.Millisecond),
187 LastNotified: time.Now().Round(time.Millisecond),
190 for _, store := range testSubscriptionStores {
193 t.Fatalf("Error resetting %T: %+v\n", store, err)
195 err = store.CreateSubscription(sub)
197 t.Fatalf("Error saving subscription in %T: %s\n", store, err)
199 for i := 1; i < variations; i++ {
200 var stripeSubscription, plan, status string
202 var failedChargeAttempts int
203 var trialStart, trialEnd, periodStart, periodEnd, canceledAt, lastFailedCharge, lastNotified time.Time
205 change := SubscriptionChange{}
206 empty := change.IsEmpty()
208 t.Errorf("Expected empty to be %t, was %t\n", true, empty)
211 strI := strconv.Itoa(i)
213 if i&subscriptionChangeStripeSubscription != 0 {
214 stripeSubscription = "stripeSubscription-" + strI
215 change.StripeSubscription = &stripeSubscription
216 sub.StripeSubscription = stripeSubscription
219 if i&subscriptionChangePlan != 0 {
220 plan = "plan-" + strI
225 if i&subscriptionChangeStatus != 0 {
226 status = "status-" + strI
227 change.Status = &status
231 if i&subscriptionChangeCanceling != 0 {
233 change.Canceling = &canceling
234 sub.Canceling = canceling
237 if i&subscriptionChangeTrialStart != 0 {
238 trialStart = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
239 change.TrialStart = &trialStart
240 sub.TrialStart = trialStart
243 if i&subscriptionChangeTrialEnd != 0 {
244 trialEnd = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
245 change.TrialEnd = &trialEnd
246 sub.TrialEnd = trialEnd
249 if i&subscriptionChangePeriodStart != 0 {
250 periodStart = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
251 change.PeriodStart = &periodStart
252 sub.PeriodStart = periodStart
255 if i&subscriptionChangePeriodEnd != 0 {
256 periodEnd = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
257 change.PeriodEnd = &periodEnd
258 sub.PeriodEnd = periodEnd
261 if i&subscriptionChangeCanceledAt != 0 {
262 canceledAt = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
263 change.CanceledAt = &canceledAt
264 sub.CanceledAt = canceledAt
267 if i&subscriptionChangeFailedChargeAttempts != 0 {
268 failedChargeAttempts = i
269 change.FailedChargeAttempts = &failedChargeAttempts
270 sub.FailedChargeAttempts = failedChargeAttempts
273 if i&subscriptionChangeLastFailedCharge != 0 {
274 lastFailedCharge = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
275 change.LastFailedCharge = &lastFailedCharge
276 sub.LastFailedCharge = lastFailedCharge
279 if i&subscriptionChangeLastNotified != 0 {
280 lastNotified = time.Now().Round(time.Millisecond).Add(time.Hour * time.Duration(i))
281 change.LastNotified = &lastNotified
282 sub.LastNotified = lastNotified
285 empty = change.IsEmpty()
287 t.Errorf("Expected empty to be %t, was %t\n", false, empty)
290 result.ApplyChange(change)
291 match, field, expected, got := compareSubscriptions(sub, result)
293 t.Errorf("Expected field `%s` to be `%v`, got `%v`\n", field, expected, got)
295 err = store.UpdateSubscription(sub.UserID, change)
297 t.Errorf("Error updating subscription in %T: %s\n", store, err)
299 retrieved, err := store.GetSubscriptions([]uuid.ID{sub.UserID})
301 t.Errorf("Error getting subscription from %T: %s\n", store, err)
303 ok, missing := subscriptionMapContains(retrieved, sub)
305 t.Errorf("Expected to retrieve %s from %T, but missing was %+v\n", sub.UserID.String(), store, missing)
307 match, field, expected, got = compareSubscriptions(sub, retrieved[sub.UserID.String()])
309 t.Errorf("Expected field `%s` to be `%v`, got `%v` from %T\n", field, expected, got, store)
314 err = store.CreateSubscription(sub2)
316 t.Fatalf("Error saving subscription in %T: %+v\n", store, err)
318 change := SubscriptionChange{}
319 err = store.UpdateSubscription(sub.UserID, change)
320 if err != ErrSubscriptionChangeEmpty {
321 t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionChangeEmpty, err, store)
323 stripeSubscription := sub2.StripeSubscription
324 change.StripeSubscription = &stripeSubscription
325 err = store.UpdateSubscription(uuid.NewID(), change)
326 if err != ErrSubscriptionNotFound {
327 t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionNotFound, err, store)
329 err = store.UpdateSubscription(sub.UserID, change)
330 if err != ErrStripeSubscriptionAlreadyExists {
331 t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrStripeSubscriptionAlreadyExists, err, store)
336 func TestDeleteSubscription(t *testing.T) {
337 for _, store := range testSubscriptionStores {
340 t.Fatalf("Error resetting %T: %+v\n", store, err)
342 sub1 := Subscription{
343 UserID: uuid.NewID(),
344 StripeSubscription: "stripeSubscription1",
346 sub2 := Subscription{
347 UserID: uuid.NewID(),
348 StripeSubscription: "stripeSubscription2",
350 err = store.CreateSubscription(sub1)
352 t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
354 err = store.CreateSubscription(sub2)
356 t.Fatalf("Error creating %+v in %T: %+v\n", sub2, store, err)
358 err = store.DeleteSubscription(sub1.UserID)
360 t.Fatalf("Error deleting %+v in %T: %+v\n", sub1, store, err)
362 retrieved, err := store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID})
364 t.Errorf("Error retrieving subscriptions from %T: %+v\n", store, err)
366 ok, missing := subscriptionMapContains(retrieved, sub1)
368 t.Errorf("Expected not to retrieve %s from %T, but missing was %+v\n", sub1.UserID.String(), store, missing)
370 ok, missing = subscriptionMapContains(retrieved, sub2)
372 t.Errorf("Expected to retrieve %s from %T, but missing was %+v\n", sub2.UserID.String(), store, missing)
374 err = store.DeleteSubscription(sub1.UserID)
375 if err != ErrSubscriptionNotFound {
376 t.Errorf("Expected err to be %+v, but got %+v from %T\n", ErrSubscriptionNotFound, err, store)
381 func TestGetSubscriptions(t *testing.T) {
382 for _, store := range testSubscriptionStores {
385 t.Fatalf("Error resetting %T: %+v\n", store, err)
387 sub1 := Subscription{
388 UserID: uuid.NewID(),
389 StripeSubscription: "stripeSubscription1",
391 Created: time.Now().Round(time.Millisecond),
392 TrialStart: time.Now().Round(time.Millisecond),
393 TrialEnd: time.Now().Round(time.Millisecond).Add(time.Hour * 24 * 32),
395 sub2 := Subscription{
396 UserID: uuid.NewID(),
397 StripeSubscription: "stripeSubscription2",
399 Created: time.Now().Round(time.Millisecond).Add(time.Hour * -720),
400 TrialStart: time.Now().Round(time.Millisecond).Add(time.Hour * -720),
401 TrialEnd: time.Now().Round(time.Millisecond),
403 sub3 := Subscription{
404 UserID: uuid.NewID(),
405 StripeSubscription: "stripeSubscription3",
407 Created: time.Now().Round(time.Millisecond).Add(time.Hour * -1440),
408 TrialStart: time.Now().Round(time.Millisecond).Add(time.Hour * -1440),
409 TrialEnd: time.Now().Round(time.Millisecond).Add(time.Hour * -720),
410 PeriodStart: time.Now().Round(time.Millisecond).Add(time.Hour * -720),
411 PeriodEnd: time.Now().Round(time.Millisecond),
414 err = store.CreateSubscription(sub1)
416 t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
418 err = store.CreateSubscription(sub2)
420 t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
422 err = store.CreateSubscription(sub3)
424 t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
426 retrieved, err := store.GetSubscriptions([]uuid.ID{})
427 if err != ErrNoSubscriptionID {
428 t.Errorf("Error retrieving no subscriptions from %T. Expected %+v, got %+v\n", store, ErrNoSubscriptionID, err)
430 retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID})
432 t.Errorf("Error retrieving %s from %T: %+v\n", sub1.UserID, store, err)
434 ok, missing := subscriptionMapContains(retrieved, sub1)
436 t.Logf("Results: %+v\n", retrieved)
437 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
439 retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID})
441 t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub1.UserID, sub2.UserID, store, err)
443 ok, missing = subscriptionMapContains(retrieved, sub1, sub2)
445 t.Logf("Results: %+v\n", retrieved)
446 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
448 retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub3.UserID})
450 t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub1.UserID, sub3.UserID, store, err)
452 ok, missing = subscriptionMapContains(retrieved, sub1, sub3)
454 t.Logf("Results: %+v\n", retrieved)
455 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
457 retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID, sub3.UserID})
459 t.Errorf("Error retrieving %s, %s, and %s from %T: %+v\n", sub1.UserID, sub2.UserID, sub3.UserID, store, err)
461 ok, missing = subscriptionMapContains(retrieved, sub1, sub2, sub3)
463 t.Logf("Results: %+v\n", retrieved)
464 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
466 retrieved, err = store.GetSubscriptions([]uuid.ID{sub2.UserID})
468 t.Errorf("Error retrieving %s from %T: %+v\n", sub2.UserID, store, err)
470 ok, missing = subscriptionMapContains(retrieved, sub2)
472 t.Logf("Results: %+v\n", retrieved)
473 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
475 retrieved, err = store.GetSubscriptions([]uuid.ID{sub2.UserID, sub3.UserID})
477 t.Errorf("Error retrieving %s and %s from %T: %+v\n", sub2.UserID, sub3.UserID, store, err)
479 ok, missing = subscriptionMapContains(retrieved, sub2, sub3)
481 t.Logf("Results: %+v\n", retrieved)
482 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
484 retrieved, err = store.GetSubscriptions([]uuid.ID{sub3.UserID})
486 t.Errorf("Error retrieving %s from %T: %+v\n", sub3.UserID, store, err)
488 ok, missing = subscriptionMapContains(retrieved, sub3)
490 t.Logf("Results: %+v\n", retrieved)
491 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
493 retrieved, err = store.GetSubscriptions([]uuid.ID{uuid.NewID()})
495 t.Errorf("Error retrieving non-existent ID from %T: %+v\n", store, err)
497 if len(retrieved) != 0 {
498 t.Errorf("Expected no results, %T returned %+v\n", store, retrieved)
500 retrieved, err = store.GetSubscriptions([]uuid.ID{sub1.UserID, sub2.UserID, uuid.NewID(), sub3.UserID})
502 t.Errorf("Error retrieving non-existent ID from %T: %+v\n", store, err)
504 if len(retrieved) != 3 {
505 t.Errorf("Expected 3 results, %T returned %+v\n", store, retrieved)
507 ok, missing = subscriptionMapContains(retrieved, sub1, sub2, sub3)
509 t.Logf("Results: %+v\n", retrieved)
510 t.Errorf("Expected %+v to be in the results, was not for %T.\n", missing, store)
515 func TestGetSubscriptionStats(t *testing.T) {
516 for _, store := range testSubscriptionStores {
519 t.Fatalf("Error resetting %T: %+v\n", store, err)
521 sub1 := Subscription{
522 UserID: uuid.NewID(),
523 StripeSubscription: "stripeSubscription1",
527 sub2 := Subscription{
528 UserID: uuid.NewID(),
529 StripeSubscription: "stripeSubscription2",
533 err = store.CreateSubscription(sub1)
535 t.Fatalf("Error creating %+v in %T: %+v\n", sub1, store, err)
537 stats, err := store.GetSubscriptionStats()
539 t.Errorf("Error getting stats from %T: %+v\n", store, err)
541 ok, field, expected, results := compareSubscriptionStats(SubscriptionStats{
545 Plans: map[string]int64{
550 t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)
552 err = store.CreateSubscription(sub2)
554 t.Fatalf("Error creating %+v in %T: %+v\n", sub2, store, err)
556 stats, err = store.GetSubscriptionStats()
558 t.Errorf("Error getting status from %T: %+v\n", store, err)
560 ok, field, expected, results = compareSubscriptionStats(SubscriptionStats{
564 Plans: map[string]int64{
570 t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)
572 err = store.DeleteSubscription(sub1.UserID)
574 t.Errorf("Error deleting subscription from %T: %+v\n", store, err)
576 stats, err = store.GetSubscriptionStats()
578 t.Errorf("Error getting status from %T: %+v\n", store, err)
580 ok, field, expected, results = compareSubscriptionStats(SubscriptionStats{
584 Plans: map[string]int64{
589 t.Errorf("Expected %s to be %+v, got %+v from %T\n", field, expected, results, store)