/**
 * Records the credit cost of one AI generation against the user's credits.
 *
 * The cost in USD is looked up via `fetchGenerationCost` (retried, because the
 * generation data is sometimes not immediately available), converted to
 * credits at 0.001 USD per credit, and spent inside a transaction.
 *
 * This function never throws: any failure is logged and reported to Sentry,
 * because we don't want to crash the bot over usage tracking.
 *
 * @param pgAdapter - database adapter used for all billing queries
 * @param matrixUserId - matrix id of the user who triggered the generation
 * @param generationId - generation id to look the cost up for
 */
export async function saveUsageCost(
  pgAdapter: PgAdapter,
  matrixUserId: string,
  generationId: string,
) {
  const usdPerCredit = 0.001;

  try {
    // Generation data is sometimes not immediately available, so we retry a
    // couple of times until we are able to get the cost
    let costInUsd = await retry(() => fetchGenerationCost(generationId), {
      retries: 10,
      delayMs: 500,
    });

    let user = await getUserByMatrixUserId(pgAdapter, matrixUserId);

    // This guard has to run before anything reads `user.id`; otherwise a
    // missing user surfaces as an opaque TypeError instead of this message.
    if (!user) {
      throw new Error(
        `should not happen: user with matrix id ${matrixUserId} not found in the users table`,
      );
    }
    let userId = user.id;

    // This check is for the transition period where we don't have subscriptions fully rolled out yet.
    // When we have assurance that all users who use the bot have subscriptions, we can remove this subscription check.
    let subscription = await getCurrentActiveSubscription(pgAdapter, userId);
    if (!subscription) {
      log.info(
        `user ${matrixUserId} has no subscription, skipping credit usage tracking`,
      );
      return;
    }

    // A missing cost must not reach the ledger: `Math.round(undefined / 0.001)`
    // is NaN, and a NaN amount is not a meaningful number of credits to spend.
    if (
      typeof costInUsd !== 'number' ||
      !Number.isFinite(costInUsd) ||
      costInUsd < 0
    ) {
      throw new Error(
        `could not determine the cost of generation ${generationId} (got ${String(
          costInUsd,
        )})`,
      );
    }

    let creditsConsumed = Math.round(costInUsd / usdPerCredit);

    let txManager = new TransactionManager(pgAdapter);

    await txManager.withTransaction(async () => {
      await spendCredits(pgAdapter, userId, creditsConsumed);

      // TODO: send a signal to the host app to update credits balance displayed in the UI
    });
  } catch (err) {
    log.error(
      `Failed to track AI usage (matrixUserId: ${matrixUserId}, generationId: ${generationId}):`,
      err,
    );
    Sentry.captureException(err);
    // Don't throw, because we don't want to crash the bot over this
  }
}

/**
 * Returns the user's credit balance: the sum of all of their credits ledger
 * entries, with no credit type filter applied.
 *
 * @param pgAdapter - database adapter used for the lookups
 * @param matrixUserId - matrix id of the user
 * @throws Error when no user with that matrix id exists
 */
export async function getAvailableCredits(
  pgAdapter: PgAdapter,
  matrixUserId: string,
) {
  let user = await getUserByMatrixUserId(pgAdapter, matrixUserId);

  if (!user) {
    throw new Error(
      `should not happen: user with matrix id ${matrixUserId} not found in the users table`,
    );
  }

  return await sumUpCreditsLedger(pgAdapter, { userId: user.id });
}

/**
 * Looks up the total cost of a generation from the OpenRouter generation
 * endpoint.
 *
 * @param generationId - generation id reported by the completion stream
 * @returns the `total_cost` reported for the generation, or `null` when the
 *   generation is not found or no numeric cost is reported, so the caller can
 *   retry
 * @throws Error when the endpoint reports any other error
 */
async function fetchGenerationCost(
  generationId: string,
): Promise<number | null> {
  let httpResponse = await fetch(
    `https://openrouter.ai/api/v1/generation?id=${encodeURIComponent(
      generationId,
    )}`,
    {
      headers: {
        Authorization: `Bearer ${process.env.OPENROUTER_API_KEY}`,
      },
    },
  );
  let response = await httpResponse.json();

  let error = response?.error;
  if (error) {
    // NOTE(review): the error may arrive as a plain string or as an object
    // carrying a `message`; both are handled here - confirm against the API.
    let message =
      typeof error === 'string' ? error : String(error?.message ?? '');
    if (message.toLowerCase().includes('not found')) {
      return null;
    }
    throw new Error(
      `failed to fetch cost of generation ${generationId}: ${
        message || JSON.stringify(error)
      }`,
    );
  }

  let totalCost = response?.data?.total_cost;
  return typeof totalCost === 'number' ? totalCost : null;
}
/**
 * Starts saving the usage cost of a generation for the given user and records
 * the in-flight promise in `trackAiUsageCostPromises`, keyed by matrix user
 * id, so other code can wait for it before generating another response.
 *
 * If a save is already pending for this user, the new save is chained after
 * it instead of being dropped, so every generation gets its cost recorded.
 * The map entry is removed once the last chained save has settled.
 *
 * @param matrixUserId - matrix id of the user who triggered the generation
 * @param generationId - id of the generation whose cost should be saved
 */
async trackAiUsageCost(matrixUserId: string, generationId: string) {
  let pending = trackAiUsageCostPromises.get(matrixUserId);

  let tracking: Promise<void> = (pending ?? Promise.resolve())
    // a failed earlier save must not prevent this one from running
    .catch(() => undefined)
    .then(() => saveUsageCost(this.pgAdapter, matrixUserId, generationId))
    .finally(() => {
      // only clear the entry if no newer save has been chained after us
      if (trackAiUsageCostPromises.get(matrixUserId) === tracking) {
        trackAiUsageCostPromises.delete(matrixUserId);
      }
    });

  trackAiUsageCostPromises.set(matrixUserId, tracking);
}
DiscreteMatrixEvent[]) { @@ -133,6 +156,7 @@ Common issues are: async function (event, room, toStartOfTimeline) { try { let eventBody = event.getContent().body; + let senderMatrixUserId = event.getSender()!; if (!room) { return; } @@ -150,7 +174,7 @@ Common issues are: return; // don't respond to card fragments, we just gather these in our history } - if (event.getSender() === aiBotUserId) { + if (senderMatrixUserId === aiBotUserId) { return; } log.info( @@ -158,7 +182,7 @@ Common issues are: event.getType(), room?.name, room?.roomId, - event.getSender(), + senderMatrixUserId, eventBody, ); @@ -177,18 +201,47 @@ Common issues are: } const responder = new Responder(client, room.roomId); - await responder.initialize(); - if (historyError) { - responder.finalize( + return responder.finalize( 'There was an error processing chat history. Please open another session.', ); - return; } + await responder.initialize(); + + // Do not generate new responses if previous ones' cost is still being reported + let pendingCreditsConsumptionPromise = trackAiUsageCostPromises.get( + senderMatrixUserId!, + ); + if (pendingCreditsConsumptionPromise) { + try { + await pendingCreditsConsumptionPromise; + } catch (e) { + log.error(e); + return responder.onError( + 'There was an error saving your Boxel credits usage. Try again or contact support if the problem persists.', + ); + } + } + + let availableCredits = await getAvailableCredits( + assistant.pgAdapter, + senderMatrixUserId, + ); + + let minimumCredits = 10; + + if (availableCredits < minimumCredits) { + return responder.onError( + `You need a minimum of ${minimumCredits} credits to continue using the AI bot. 
Please upgrade to a larger plan, or top up your account.`, + ); + } + + let generationId: string | undefined; const runner = assistant .getResponse(history) .on('chunk', async (chunk, _snapshot) => { + generationId = chunk.id; await responder.onChunk(chunk); }) .on('content', async (_delta, snapshot) => { @@ -200,9 +253,18 @@ Common issues are: .on('error', async (error) => { await responder.onError(error); }); - // We also need to catch the error when getting the final content - let finalContent = await runner.finalContent().catch(responder.onError); - await responder.finalize(finalContent); + + let finalContent; + try { + finalContent = await runner.finalContent(); + await responder.finalize(finalContent); + } catch (error) { + await responder.onError(error); + } finally { + if (generationId) { + assistant.trackAiUsageCost(senderMatrixUserId, generationId); + } + } if (shouldSetRoomTitle(eventList, aiBotUserId, event)) { return await assistant.setTitle(room.roomId, history, event); diff --git a/packages/ai-bot/package.json b/packages/ai-bot/package.json index 9578581ce1..5eb5768711 100644 --- a/packages/ai-bot/package.json +++ b/packages/ai-bot/package.json @@ -1,7 +1,9 @@ { "name": "@cardstack/ai-bot", "dependencies": { - "@cardstack/runtime-common": "workspace:^", + "@cardstack/runtime-common": "workspace:*", + "@cardstack/postgres": "workspace:*", + "@cardstack/billing": "workspace:*", "@sentry/node": "^8.31.0", "@types/node": "^18.18.5", "@types/stream-chain": "^2.0.1", @@ -21,7 +23,7 @@ }, "scripts": { "lint": "eslint . 
/**
 * Spends credits for a user by adding negative entries to the credits ledger,
 * attached to the most recent cycle of the user's active subscription.
 *
 * Plan allowance credits are spent first. If they do not cover the amount,
 * the remainder is taken from extra credits. The spend is capped at what the
 * user has available, so the amount actually deducted can be lower than
 * `creditsToSpend`.
 *
 * @param dbAdapter - database adapter used for all queries
 * @param userId - id of the user whose credits are spent
 * @param creditsToSpend - number of credits to spend; must be a finite,
 *   non-negative number
 * @throws Error when `creditsToSpend` is not a finite non-negative number,
 *   when the user has no active subscription, or when the subscription has no
 *   subscription cycle
 */
export async function spendCredits(
  dbAdapter: DBAdapter,
  userId: string,
  creditsToSpend: number,
) {
  // NaN would fail every comparison below and end up spending the whole plan
  // allowance; a negative amount would write a positive "used" entry.
  if (!Number.isFinite(creditsToSpend) || creditsToSpend < 0) {
    throw new Error(
      `creditsToSpend must be a non-negative finite number, got ${creditsToSpend}`,
    );
  }

  let subscription = await getCurrentActiveSubscription(dbAdapter, userId);
  if (!subscription) {
    throw new Error('active subscription not found');
  }
  let subscriptionCycle = await getMostRecentSubscriptionCycle(
    dbAdapter,
    subscription.id,
  );
  if (!subscriptionCycle) {
    throw new Error('subscription cycle not found');
  }
  let subscriptionCycleId = subscriptionCycle.id;

  let availablePlanAllowanceCredits = await sumUpCreditsLedger(dbAdapter, {
    creditType: [
      'plan_allowance',
      'plan_allowance_used',
      'plan_allowance_expired',
    ],
    userId,
  });

  if (availablePlanAllowanceCredits >= creditsToSpend) {
    // The plan allowance alone covers the whole spend
    await addToCreditsLedger(dbAdapter, {
      userId,
      creditAmount: -creditsToSpend,
      creditType: 'plan_allowance_used',
      subscriptionCycleId,
    });
    return;
  }

  // If user does not have enough plan allowance credits to cover the spend, try to also use extra credits
  let availableExtraCredits = await sumUpCreditsLedger(dbAdapter, {
    creditType: ['extra_credit', 'extra_credit_used'],
    userId,
  });
  let planAllowanceToSpend = availablePlanAllowanceCredits; // Spend all plan allowance credits first
  let extraCreditsToSpend = creditsToSpend - planAllowanceToSpend;
  if (extraCreditsToSpend > availableExtraCredits) {
    // Never spend more extra credits than the user has
    extraCreditsToSpend = availableExtraCredits;
  }

  if (planAllowanceToSpend > 0) {
    await addToCreditsLedger(dbAdapter, {
      userId,
      creditAmount: -planAllowanceToSpend,
      creditType: 'plan_allowance_used',
      subscriptionCycleId,
    });
  }

  if (extraCreditsToSpend > 0) {
    await addToCreditsLedger(dbAdapter, {
      userId,
      creditAmount: -extraCreditsToSpend,
      creditType: 'extra_credit_used',
      subscriptionCycleId,
    });
  }
}
// Exercises spendCredits against a user on the Creator plan (2500 included
// credits) who has an active subscription with a single subscription cycle.
module('AI usage tracking', function (hooks) {
  let user: User;
  let creatorPlan: Plan;
  let subscription: Subscription;
  let subscriptionCycle: SubscriptionCycle;

  // Sum of all ledger entries of the test user, regardless of credit type
  const totalCredits = () => sumUpCreditsLedger(dbAdapter, { userId: user.id });

  // Grants the plan allowance and then uses up 2490 of it, leaving 10 credits
  const leaveTenPlanCredits = async () => {
    // User receives 2500 credits for the creator plan and spends 2490 credits
    await addToCreditsLedger(dbAdapter, {
      userId: user.id,
      creditAmount: creatorPlan.creditsIncluded,
      creditType: 'plan_allowance',
      subscriptionCycleId: subscriptionCycle.id,
    });
    await addToCreditsLedger(dbAdapter, {
      userId: user.id,
      creditAmount: -2490,
      creditType: 'plan_allowance_used',
      subscriptionCycleId: subscriptionCycle.id,
    });
  };

  hooks.beforeEach(async () => {
    user = await insertUser(dbAdapter, 'testuser', 'cus_123');
    creatorPlan = await insertPlan(
      dbAdapter,
      'Creator',
      12,
      2500,
      'prod_creator',
    );
    subscription = await insertSubscription(dbAdapter, {
      user_id: user.id,
      plan_id: creatorPlan.id,
      started_at: 1,
      status: 'active',
      stripe_subscription_id: 'sub_1234567890',
    });
    subscriptionCycle = await insertSubscriptionCycle(dbAdapter, {
      subscriptionId: subscription.id,
      periodStart: 1,
      periodEnd: 2,
    });
  });

  test('spends ai credits correctly when no extra credits are available', async (assert) => {
    await leaveTenPlanCredits();
    assert.strictEqual(await totalCredits(), 10);

    await spendCredits(dbAdapter, user.id, 2);
    assert.strictEqual(await totalCredits(), 8);

    await spendCredits(dbAdapter, user.id, 5);
    assert.strictEqual(await totalCredits(), 3);

    // Only 3 credits are left, so trying to spend 5 must stop at a balance of 0
    await spendCredits(dbAdapter, user.id, 5);
    assert.strictEqual(await totalCredits(), 0);
  });

  test('spends ai credits correctly when extra credits are available', async (assert) => {
    await leaveTenPlanCredits();
    assert.strictEqual(await totalCredits(), 10);

    // Top up with 5 extra credits
    await addToCreditsLedger(dbAdapter, {
      userId: user.id,
      creditAmount: 5,
      creditType: 'extra_credit',
      subscriptionCycleId: null,
    });

    // 10 plan allowance credits + 5 extra credits
    assert.strictEqual(await totalCredits(), 15);

    // Expected split: 10 from the plan allowance, 2 from the extra credits
    await spendCredits(dbAdapter, user.id, 12);

    // Plan allowance is used up, 3 extra credits remain
    assert.strictEqual(await totalCredits(), 3);

    // The remaining balance must come from extra credits, not the plan allowance
    let remainingPlanAllowance = await sumUpCreditsLedger(dbAdapter, {
      userId: user.id,
      creditType: ['plan_allowance', 'plan_allowance_used'],
    });
    assert.strictEqual(remainingPlanAllowance, 0);

    let remainingExtraCredits = await sumUpCreditsLedger(dbAdapter, {
      userId: user.id,
      creditType: ['extra_credit', 'extra_credit_used'],
    });
    assert.strictEqual(remainingExtraCredits, 3);
  });
});