diff --git a/crates/collab/src/api/billing.rs b/crates/collab/src/api/billing.rs index 9d720147ff3918..eceb01ee231b4f 100644 --- a/crates/collab/src/api/billing.rs +++ b/crates/collab/src/api/billing.rs @@ -249,6 +249,13 @@ async fn create_billing_subscription( )); } + if app.db.has_overdue_billing_subscriptions(user.id).await? { + return Err(Error::http( + StatusCode::PAYMENT_REQUIRED, + "user has overdue billing subscriptions".into(), + )); + } + let customer_id = if let Some(existing_customer) = app.db.get_billing_customer_by_user_id(user.id).await? { CustomerId::from_str(&existing_customer.stripe_customer_id) @@ -719,6 +726,10 @@ async fn handle_customer_subscription_event( billing_customer_id: billing_customer.id, stripe_subscription_id: subscription.id.to_string(), stripe_subscription_status: subscription.status.into(), + stripe_cancellation_reason: subscription + .cancellation_details + .and_then(|details| details.reason) + .map(|reason| reason.into()), }) .await?; } diff --git a/crates/collab/src/db/queries/billing_subscriptions.rs b/crates/collab/src/db/queries/billing_subscriptions.rs index 027f46f6b78141..d2762e2e8f45ba 100644 --- a/crates/collab/src/db/queries/billing_subscriptions.rs +++ b/crates/collab/src/db/queries/billing_subscriptions.rs @@ -7,6 +7,7 @@ pub struct CreateBillingSubscriptionParams { pub billing_customer_id: BillingCustomerId, pub stripe_subscription_id: String, pub stripe_subscription_status: StripeSubscriptionStatus, + pub stripe_cancellation_reason: Option, } #[derive(Debug, Default)] @@ -29,6 +30,7 @@ impl Database { billing_customer_id: ActiveValue::set(params.billing_customer_id), stripe_subscription_id: ActiveValue::set(params.stripe_subscription_id.clone()), stripe_subscription_status: ActiveValue::set(params.stripe_subscription_status), + stripe_cancellation_reason: ActiveValue::set(params.stripe_cancellation_reason), ..Default::default() }) .exec_without_returning(&*tx) @@ -168,4 +170,40 @@ impl Database { }) .await } + + /// Returns whether the user has any overdue billing subscriptions. + pub async fn has_overdue_billing_subscriptions(&self, user_id: UserId) -> Result { + Ok(self.count_overdue_billing_subscriptions(user_id).await? > 0) + } + + /// Returns the count of the overdue billing subscriptions for the user with the specified ID. + /// + /// This includes subscriptions: + /// - Whose status is `past_due` + /// - Whose status is `canceled` and the cancellation reason is `payment_failed` + pub async fn count_overdue_billing_subscriptions(&self, user_id: UserId) -> Result { + self.transaction(|tx| async move { + let past_due = billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::PastDue); + let payment_failed = billing_subscription::Column::StripeSubscriptionStatus + .eq(StripeSubscriptionStatus::Canceled) + .and( + billing_subscription::Column::StripeCancellationReason + .eq(StripeCancellationReason::PaymentFailed), + ); + + let count = billing_subscription::Entity::find() + .inner_join(billing_customer::Entity) + .filter( + billing_customer::Column::UserId + .eq(user_id) + .and(past_due.or(payment_failed)), + ) + .count(&*tx) + .await?; + + Ok(count as usize) + }) + .await + } } diff --git a/crates/collab/src/db/tests/billing_subscription_tests.rs b/crates/collab/src/db/tests/billing_subscription_tests.rs index a1973e3fbbf2cb..d2368b72b3301a 100644 --- a/crates/collab/src/db/tests/billing_subscription_tests.rs +++ b/crates/collab/src/db/tests/billing_subscription_tests.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::db::billing_subscription::StripeSubscriptionStatus; +use crate::db::billing_subscription::{StripeCancellationReason, StripeSubscriptionStatus}; use crate::db::tests::new_test_user; use crate::db::{CreateBillingCustomerParams, CreateBillingSubscriptionParams}; use crate::test_both_dbs; @@ -41,6 +41,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { billing_customer_id: customer.id, stripe_subscription_id: "sub_active_user".into(), stripe_subscription_status: StripeSubscriptionStatus::Active, + stripe_cancellation_reason: None, }) .await .unwrap(); @@ -75,6 +76,7 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { billing_customer_id: customer.id, stripe_subscription_id: "sub_past_due_user".into(), stripe_subscription_status: StripeSubscriptionStatus::PastDue, + stripe_cancellation_reason: None, }) .await .unwrap(); @@ -86,3 +88,113 @@ async fn test_get_active_billing_subscriptions(db: &Arc) { assert_eq!(subscription_count, 0); } } + +test_both_dbs!( + test_count_overdue_billing_subscriptions, + test_count_overdue_billing_subscriptions_postgres, + test_count_overdue_billing_subscriptions_sqlite +); + +async fn test_count_overdue_billing_subscriptions(db: &Arc) { + // A user with no subscription has no overdue billing subscriptions. + { + let user_id = new_test_user(db, "no-subscription-user@example.com").await; + let subscription_count = db + .count_overdue_billing_subscriptions(user_id) + .await + .unwrap(); + + assert_eq!(subscription_count, 0); + } + + // A user with a past-due subscription has an overdue billing subscription. + { + let user_id = new_test_user(db, "past-due-user@example.com").await; + let customer = db + .create_billing_customer(&CreateBillingCustomerParams { + user_id, + stripe_customer_id: "cus_past_due_user".into(), + }) + .await + .unwrap(); + assert_eq!(customer.stripe_customer_id, "cus_past_due_user".to_string()); + + db.create_billing_subscription(&CreateBillingSubscriptionParams { + billing_customer_id: customer.id, + stripe_subscription_id: "sub_past_due_user".into(), + stripe_subscription_status: StripeSubscriptionStatus::PastDue, + stripe_cancellation_reason: None, + }) + .await + .unwrap(); + + let subscription_count = db + .count_overdue_billing_subscriptions(user_id) + .await + .unwrap(); + assert_eq!(subscription_count, 1); + } + + // A user with a canceled subscription with a reason of `payment_failed` has an overdue billing subscription. + { + let user_id = + new_test_user(db, "canceled-subscription-payment-failed-user@example.com").await; + let customer = db + .create_billing_customer(&CreateBillingCustomerParams { + user_id, + stripe_customer_id: "cus_canceled_subscription_payment_failed_user".into(), + }) + .await + .unwrap(); + assert_eq!( + customer.stripe_customer_id, + "cus_canceled_subscription_payment_failed_user".to_string() + ); + + db.create_billing_subscription(&CreateBillingSubscriptionParams { + billing_customer_id: customer.id, + stripe_subscription_id: "sub_canceled_subscription_payment_failed_user".into(), + stripe_subscription_status: StripeSubscriptionStatus::Canceled, + stripe_cancellation_reason: Some(StripeCancellationReason::PaymentFailed), + }) + .await + .unwrap(); + + let subscription_count = db + .count_overdue_billing_subscriptions(user_id) + .await + .unwrap(); + assert_eq!(subscription_count, 1); + } + + // A user with a canceled subscription with a reason of `cancellation_requested` has no overdue billing subscriptions. + { + let user_id = new_test_user(db, "canceled-subscription-user@example.com").await; + let customer = db + .create_billing_customer(&CreateBillingCustomerParams { + user_id, + stripe_customer_id: "cus_canceled_subscription_user".into(), + }) + .await + .unwrap(); + assert_eq!( + customer.stripe_customer_id, + "cus_canceled_subscription_user".to_string() + ); + + db.create_billing_subscription(&CreateBillingSubscriptionParams { + billing_customer_id: customer.id, + stripe_subscription_id: "sub_canceled_subscription_user".into(), + stripe_subscription_status: StripeSubscriptionStatus::Canceled, + stripe_cancellation_reason: Some(StripeCancellationReason::CancellationRequested), + }) + .await + .unwrap(); + + let subscription_count = db + .count_overdue_billing_subscriptions(user_id) + .await + .unwrap(); + assert_eq!(subscription_count, 0); + } +}