Skip to content

Commit

Permalink
collab: Add usage-based billing for LLM interactions (#19081)
Browse files Browse the repository at this point in the history
This PR adds usage-based billing for LLM interactions in the Assistant.

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>
Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Richard <richard@zed.dev>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>
  • Loading branch information
5 people authored Oct 11, 2024
1 parent f1c45d9 commit 22ea7ce
Show file tree
Hide file tree
Showing 20 changed files with 920 additions and 282 deletions.
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ which = "6.0.0"
wit-component = "0.201"

[workspace.dependencies.async-stripe]
version = "0.39"
git = "https://github.com/zed-industries/async-stripe"
rev = "3672dd4efb7181aa597bf580bf5a2f5d23db6735"
default-features = false
features = [
"runtime-tokio-hyper-rustls",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
create table billing_events (
id serial primary key,
idempotency_key uuid not null default gen_random_uuid(),
user_id integer not null,
model_id integer not null references models (id) on delete cascade,
input_tokens bigint not null default 0,
input_cache_creation_tokens bigint not null default 0,
input_cache_read_tokens bigint not null default 0,
output_tokens bigint not null default 0
);

create index uix_billing_events_on_user_id_model_id on billing_events (user_id, model_id);
203 changes: 85 additions & 118 deletions crates/collab/src/api/billing.rs
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, bail, Context};
use axum::{
extract::{self, Query},
routing::{get, post},
Extension, Json, Router,
};
use chrono::{DateTime, SecondsFormat, Utc};
use collections::HashSet;
use reqwest::StatusCode;
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
use std::{str::FromStr, sync::Arc, time::Duration};
use stripe::{
BillingPortalSession, CheckoutSession, CreateBillingPortalSession,
CreateBillingPortalSessionFlowData, CreateBillingPortalSessionFlowDataAfterCompletion,
BillingPortalSession, CreateBillingPortalSession, CreateBillingPortalSessionFlowData,
CreateBillingPortalSessionFlowDataAfterCompletion,
CreateBillingPortalSessionFlowDataAfterCompletionRedirect,
CreateBillingPortalSessionFlowDataType, CreateCheckoutSession, CreateCheckoutSessionLineItems,
CreateCustomer, Customer, CustomerId, EventObject, EventType, Expandable, ListEvents,
Subscription, SubscriptionId, SubscriptionStatus,
CreateBillingPortalSessionFlowDataType, CreateCustomer, Customer, CustomerId, EventObject,
EventType, Expandable, ListEvents, Subscription, SubscriptionId, SubscriptionStatus,
};
use util::ResultExt;

use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::llm::DEFAULT_MAX_MONTHLY_SPEND;
use crate::rpc::ResultExt as _;
use crate::{
db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams,
UpdateBillingCustomerParams, UpdateBillingPreferencesParams,
UpdateBillingSubscriptionParams,
},
stripe_billing::StripeBilling,
};
use crate::{
db::{billing_subscription::StripeSubscriptionStatus, UserId},
llm::db::LlmDatabase,
};
use crate::{AppState, Error, Result};

pub fn router() -> Router {
Expand Down Expand Up @@ -87,6 +90,7 @@ struct UpdateBillingPreferencesBody {

async fn update_billing_preferences(
Extension(app): Extension<Arc<AppState>>,
Extension(rpc_server): Extension<Arc<crate::rpc::Server>>,
extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
) -> Result<Json<BillingPreferencesResponse>> {
let user = app
Expand Down Expand Up @@ -119,6 +123,8 @@ async fn update_billing_preferences(
.await?
};

rpc_server.refresh_llm_tokens_for_user(user.id).await;

Ok(Json(BillingPreferencesResponse {
max_monthly_llm_usage_spending_in_cents: billing_preferences
.max_monthly_llm_usage_spending_in_cents,
Expand Down Expand Up @@ -197,12 +203,15 @@ async fn create_billing_subscription(
.await?
.ok_or_else(|| anyhow!("user not found"))?;

let Some((stripe_client, stripe_access_price_id)) = app
.stripe_client
.clone()
.zip(app.config.stripe_llm_access_price_id.clone())
else {
log::error!("failed to retrieve Stripe client or price ID");
let Some(stripe_client) = app.stripe_client.clone() else {
log::error!("failed to retrieve Stripe client");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
))?
};
let Some(llm_db) = app.llm_db.clone() else {
log::error!("failed to retrieve LLM database");
Err(Error::http(
StatusCode::NOT_IMPLEMENTED,
"not supported".into(),
Expand All @@ -226,26 +235,15 @@ async fn create_billing_subscription(
customer.id
};

let checkout_session = {
let mut params = CreateCheckoutSession::new();
params.mode = Some(stripe::CheckoutSessionMode::Subscription);
params.customer = Some(customer_id);
params.client_reference_id = Some(user.github_login.as_str());
params.line_items = Some(vec![CreateCheckoutSessionLineItems {
price: Some(stripe_access_price_id.to_string()),
quantity: Some(1),
..Default::default()
}]);
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
params.success_url = Some(&success_url);

CheckoutSession::create(&stripe_client, params).await?
};

let default_model = llm_db.model(rpc::LanguageModelProvider::Anthropic, "claude-3-5-sonnet")?;
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;
let stripe_model = stripe_billing.register_model(default_model).await?;
let success_url = format!("{}/account", app.config.zed_dot_dev_url());
let checkout_session_url = stripe_billing
.checkout(customer_id, &user.github_login, &stripe_model, &success_url)
.await?;
Ok(Json(CreateBillingSubscriptionResponse {
checkout_session_url: checkout_session
.url
.ok_or_else(|| anyhow!("no checkout session URL"))?,
checkout_session_url,
}))
}

Expand Down Expand Up @@ -715,15 +713,15 @@ async fn find_or_create_billing_customer(
Ok(Some(billing_customer))
}

const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(24 * 60 * 60);
const SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL: Duration = Duration::from_secs(60);

pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDatabase) {
pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>) {
let Some(stripe_client) = app.stripe_client.clone() else {
log::warn!("failed to retrieve Stripe client");
return;
};
let Some(stripe_llm_usage_price_id) = app.config.stripe_llm_usage_price_id.clone() else {
log::warn!("failed to retrieve Stripe LLM usage price ID");
let Some(llm_db) = app.llm_db.clone() else {
log::warn!("failed to retrieve LLM database");
return;
};

Expand All @@ -732,15 +730,9 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa
let executor = executor.clone();
async move {
loop {
sync_with_stripe(
&app,
&llm_db,
&stripe_client,
stripe_llm_usage_price_id.clone(),
)
.await
.trace_err();

sync_with_stripe(&app, &llm_db, &stripe_client)
.await
.trace_err();
executor.sleep(SYNC_LLM_USAGE_WITH_STRIPE_INTERVAL).await;
}
}
Expand All @@ -749,71 +741,46 @@ pub fn sync_llm_usage_with_stripe_periodically(app: Arc<AppState>, llm_db: LlmDa

async fn sync_with_stripe(
app: &Arc<AppState>,
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: Arc<str>,
llm_db: &Arc<LlmDatabase>,
stripe_client: &Arc<stripe::Client>,
) -> anyhow::Result<()> {
let subscriptions = app.db.get_active_billing_subscriptions().await?;

for (customer, subscription) in subscriptions {
update_stripe_subscription(
llm_db,
stripe_client,
&stripe_llm_usage_price_id,
customer,
subscription,
)
.await
.log_err();
}

Ok(())
}

async fn update_stripe_subscription(
llm_db: &LlmDatabase,
stripe_client: &stripe::Client,
stripe_llm_usage_price_id: &Arc<str>,
customer: billing_customer::Model,
subscription: billing_subscription::Model,
) -> Result<(), anyhow::Error> {
let monthly_spending = llm_db
.get_user_spending_for_month(customer.user_id, Utc::now())
.await?;
let subscription_id = SubscriptionId::from_str(&subscription.stripe_subscription_id)
.context("failed to parse subscription ID")?;

let monthly_spending_over_free_tier =
monthly_spending.saturating_sub(FREE_TIER_MONTHLY_SPENDING_LIMIT);

let new_quantity = (monthly_spending_over_free_tier.0 as f32 / 100.).ceil();
let current_subscription = Subscription::retrieve(stripe_client, &subscription_id, &[]).await?;

let mut update_params = stripe::UpdateSubscription {
proration_behavior: Some(
stripe::generated::billing::subscription::SubscriptionProrationBehavior::None,
),
..Default::default()
};

if let Some(existing_item) = current_subscription.items.data.iter().find(|item| {
item.price.as_ref().map_or(false, |price| {
price.id == stripe_llm_usage_price_id.as_ref()
})
}) {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
id: Some(existing_item.id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
} else {
update_params.items = Some(vec![stripe::UpdateSubscriptionItems {
price: Some(stripe_llm_usage_price_id.to_string()),
quantity: Some(new_quantity as u64),
..Default::default()
}]);
let mut stripe_billing = StripeBilling::new(stripe_client.clone()).await?;

let events = llm_db.get_billing_events().await?;
let user_ids = events
.iter()
.map(|(event, _)| event.user_id)
.collect::<HashSet<UserId>>();
let stripe_subscriptions = app.db.get_active_billing_subscriptions(user_ids).await?;

for (event, model) in events {
let Some((stripe_db_customer, stripe_db_subscription)) =
stripe_subscriptions.get(&event.user_id)
else {
tracing::warn!(
user_id = event.user_id.0,
"Registered billing event for user who is not a Stripe customer. Billing events should only be created for users who are Stripe customers, so this is a mistake on our side."
);
continue;
};
let stripe_subscription_id: stripe::SubscriptionId = stripe_db_subscription
.stripe_subscription_id
.parse()
.context("failed to parse stripe subscription id from db")?;
let stripe_customer_id: stripe::CustomerId = stripe_db_customer
.stripe_customer_id
.parse()
.context("failed to parse stripe customer id from db")?;

let stripe_model = stripe_billing.register_model(&model).await?;
stripe_billing
.subscribe_to_model(&stripe_subscription_id, &stripe_model)
.await?;
stripe_billing
.bill_model_usage(&stripe_customer_id, &stripe_model, &event)
.await?;
llm_db.consume_billing_event(event.id).await?;
}

Subscription::update(stripe_client, &subscription_id, update_params).await?;
Ok(())
}
38 changes: 23 additions & 15 deletions crates/collab/src/db/queries/billing_subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,23 +114,31 @@ impl Database {

pub async fn get_active_billing_subscriptions(
&self,
) -> Result<Vec<(billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| async move {
let mut result = Vec::new();
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;

while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
result.push((customer, subscription));
user_ids: HashSet<UserId>,
) -> Result<HashMap<UserId, (billing_customer::Model, billing_subscription::Model)>> {
self.transaction(|tx| {
let user_ids = user_ids.clone();
async move {
let mut rows = billing_subscription::Entity::find()
.inner_join(billing_customer::Entity)
.select_also(billing_customer::Entity)
.filter(billing_customer::Column::UserId.is_in(user_ids))
.filter(
billing_subscription::Column::StripeSubscriptionStatus
.eq(StripeSubscriptionStatus::Active),
)
.order_by_asc(billing_subscription::Column::Id)
.stream(&*tx)
.await?;

let mut subscriptions = HashMap::default();
while let Some(row) = rows.next().await {
if let (subscription, Some(customer)) = row? {
subscriptions.insert(customer.user_id, (customer, subscription));
}
}
Ok(subscriptions)
}

Ok(result)
})
.await
}
Expand Down
Loading

0 comments on commit 22ea7ce

Please sign in to comment.