From f366b978991bdbd39f41f51168942bd9582b8e59 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 4 Feb 2025 18:38:00 -0500 Subject: [PATCH] collab: Use `billing_customers.has_overdue_invoices` to gate subscription access (#24240) This PR updates the check that prevents subscribing with overdue subscriptions to use the `billing_customers.has_overdue_invoices` field instead. This will allow us to set the value of `has_overdue_invoices` to `false` when the invoices have been paid. Release Notes: - N/A --- crates/collab/src/api/billing.rs | 42 +++---- .../src/db/queries/billing_subscriptions.rs | 36 ------ .../db/tests/billing_subscription_tests.rs | 112 +----------------- 3 files changed, 23 insertions(+), 167 deletions(-) diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 2e5a4a925b3980..0a1a544483a0c5 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -249,29 +249,31 @@ async fn create_billing_subscription( )); } - if app.db.has_overdue_billing_subscriptions(user.id).await? { - return Err(Error::http( - StatusCode::PAYMENT_REQUIRED, - "user has overdue billing subscriptions".into(), - )); + let existing_billing_customer = app.db.get_billing_customer_by_user_id(user.id).await?; + if let Some(existing_billing_customer) = &existing_billing_customer { + if existing_billing_customer.has_overdue_invoices { + return Err(Error::http( + StatusCode::PAYMENT_REQUIRED, + "user has overdue invoices".into(), + )); + } } - let customer_id = - if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? { - CustomerId::from_str(&existing_customer.stripe_customer_id) - .context("failed to parse customer ID")? - } else { - let customer = Customer::create( - &stripe_client, - CreateCustomer { - email: user.email_address.as_deref(), - ..Default::default() - }, - ) - .await?; + let customer_id = if let Some(existing_customer) = existing_billing_customer { + CustomerId::from_str(&existing_customer.stripe_customer_id) + .context("failed to parse customer ID")? + } else { + let customer = Customer::create( + &stripe_client, + CreateCustomer { + email: user.email_address.as_deref(), + ..Default::default() + }, + ) + .await?; - customer.id - }; + customer.id + }; let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?; let stripe_model = stripe_billing.register_model(default_model).await?; diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index d2762e2e8f45ba..4d2fce8c782d5a 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -170,40 +170,4 @@ impl Database { }) .await } - - /// Returns whether the user has any overdue billing subscriptions. - pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result { - Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0) - } - - /// Returns the count of the overdue billing subscriptions for the user with the specified ID. - /// - /// This includes subscriptions: - /// - Whose status is `past_due` - /// - Whose status is `canceled` and the cancellation reason is `payment_failed` - pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result { - self.transaction(|tx| async move { - let past_due = billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::PastDue); - let payment_failed = billing_subscription::Column::StripeSubscriptionStatus - .eq(StripeSubscriptionStatus::Canceled) - .and( - billing_subscription::Column::StripeCancellationReason - .eq(StripeCancellationReason::PaymentFailed), - ); - - let count = billing_subscription::Entity::find() - .inner_join(billing_customer::Entity) - .filter( - billing_customer::Column::UserId - .eq(user_id) - .and(past_due.or(payment_failed)), - ) - .count(&*tx) - .await?; - - Ok(count as usize) - }) - .await - } } diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs index d2368b72b3301a..4c9e0e77ec7240 100644 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ b/crates/collab/src/db/tests/billing_subscription_tests.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus}; +use crate::db::billing_subscription::StripeSubscriptionStatus; use crate::db::tests::new_test_user; use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; use crate::test_both_dbs; @@ -88,113 +88,3 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { assert_eq!(subscription_count, 0); } } - -test_both_dbs!( - test_count_overdue_billing_subscriptions, - test_count_overdue_billing_subscriptions_postgres, - test_count_overdue_billing_subscriptions_sqlite -); - -async fn test_count_overdue_billing_subscriptions(db: &Arc) { - // A user with no subscription has no overdue billing subscriptions. - { - let user_id = new_test_user(db, "no-subscription-user@example.com").await; - let subscription_count = db - .count_overdue_billing_subscriptions(user_id) - .await - .unwrap(); - - assert_eq!(subscription_count, 0); - } - - // A user with a past-due subscription has an overdue billing subscription. - { - let user_id = new_test_user(db, "past-due-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_past_due_user".into(), - }) - .await - .unwrap(); - assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - stripe_subscription_id: "sub_past_due_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::PastDue, - stripe_cancellation_reason: None, - }) - .await - .unwrap(); - - let subscription_count = db - .count_overdue_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 1); - } - - // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription. - { - let user_id = - new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(), - }) - .await - .unwrap(); - assert_eq!( - customer.stripe_customer_id, - "cus_canceled_subscription_payment_failed_user".to_string() - ); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::Canceled, - stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed), - }) - .await - .unwrap(); - - let subscription_count = db - .count_overdue_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 1); - } - - // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions. - { - let user_id = new_test_user(db, "canceled-subscription-user@example.com").await; - let customer = db - .create_billing_customer(&CreateBillingCustomerParams { - user_id, - stripe_customer_id: "cus_canceled_subscription_user".into(), - }) - .await - .unwrap(); - assert_eq!( - customer.stripe_customer_id, - "cus_canceled_subscription_user".to_string() - ); - - db.create_billing_subscription(&CreateBillingSubscriptionParams { - billing_customer_id: customer.id, - stripe_subscription_id: "sub_canceled_subscription_user".into(), - stripe_subscription_status: StripeSubscriptionStatus::Canceled, - stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested), - }) - .await - .unwrap(); - - let subscription_count = db - .count_overdue_billing_subscriptions(user_id) - .await - .unwrap(); - assert_eq!(subscription_count, 0); - } -}