Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if user has enough credits before responding with AI generated content #1824

95 changes: 95 additions & 0 deletions packages/ai-bot/lib/ai-billing.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import {
getCurrentActiveSubscription,
getUserByMatrixUserId,
spendCredits,
sumUpCreditsLedger,
} from '@cardstack/billing/billing-queries';
import { PgAdapter, TransactionManager } from '@cardstack/postgres';
import { logger, retry } from '@cardstack/runtime-common';
import * as Sentry from '@sentry/node';

let log = logger('ai-bot');

export async function saveUsageCost(
pgAdapter: PgAdapter,
matrixUserId: string,
generationId: string,
) {
try {
// Generation data is sometimes not immediately available, so we retry a couple of times until we are able to get the cost
let costInUsd = await retry(() => fetchGenerationCost(generationId), {
retries: 10,
delayMs: 500,
});

let creditsConsumed = Math.round(costInUsd / 0.001);

let user = await getUserByMatrixUserId(pgAdapter, matrixUserId);

// This check is for the transition period where we don't have subscriptions fully rolled out yet.
// When we have assurance that all users who use the bot have subscriptions, we can remove this subscription check.
let subscription = await getCurrentActiveSubscription(pgAdapter, user!.id);
if (!subscription) {
log.info(
`user ${matrixUserId} has no subscription, skipping credit usage tracking`,
);
return Promise.resolve();
}

if (!user) {
throw new Error(
`should not happen: user with matrix id ${matrixUserId} not found in the users table`,
);
}

let txManager = new TransactionManager(pgAdapter);

await txManager.withTransaction(async () => {
await spendCredits(pgAdapter, user!.id, creditsConsumed);

// TODO: send a signal to the host app to update credits balance displayed in the UI
});
} catch (err) {
log.error(
`Failed to track AI usage (matrixUserId: ${matrixUserId}, generationId: ${generationId}):`,
err,
);
Sentry.captureException(err);
// Don't throw, because we don't want to crash the bot over this
}
}

export async function getAvailableCredits(
pgAdapter: PgAdapter,
matrixUserId: string,
) {
let user = await getUserByMatrixUserId(pgAdapter, matrixUserId);

if (!user) {
throw new Error(
`should not happen: user with matrix id ${matrixUserId} not found in the users table`,
);
}

let availableCredits = await sumUpCreditsLedger(pgAdapter, {
userId: user.id,
});

return availableCredits;
}

async function fetchGenerationCost(generationId: string) {
let response = await (
await fetch(`https://openrouter.ai/api/v1/generation?id=${generationId}`, {
headers: {
Authorization: `Bearer ${process.env.OPENROUTER_API_KEY}`,
},
})
).json();

if (response.error && response.error.includes('not found')) {
return null;
}

return response.data.total_cost;
}
4 changes: 2 additions & 2 deletions packages/ai-bot/lib/matrix.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function getErrorMessage(error: any): string {
return `OpenAI error: ${error.name} - ${error.message}`;
}
if (typeof error === 'string') {
return `Unknown error: ${error}`;
return error;
}
return `Unknown error`;
return 'Unknown error';
}
2 changes: 1 addition & 1 deletion packages/ai-bot/lib/send-response.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ export class Responder {
}
}

async onError(error: OpenAIError) {
async onError(error: OpenAIError | string) {
Sentry.captureException(error);
return await sendError(
this.client,
Expand Down
82 changes: 72 additions & 10 deletions packages/ai-bot/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,43 @@
import type { MatrixEvent as DiscreteMatrixEvent } from 'https://cardstack.com/base/matrix-event';
import * as Sentry from '@sentry/node';

import { getAvailableCredits, saveUsageCost } from './lib/ai-billing';
import { PgAdapter } from '@cardstack/postgres';
import {
getUserByMatrixUserId,

Check failure on line 31 in packages/ai-bot/main.ts

View workflow job for this annotation

GitHub Actions / Lint

'getUserByMatrixUserId' is defined but never used. Allowed unused vars must match /^_/u

Check failure on line 31 in packages/ai-bot/main.ts

View workflow job for this annotation

GitHub Actions / Lint

'getUserByMatrixUserId' is defined but never used. Allowed unused vars must match /^_/u
sumUpCreditsLedger,

Check failure on line 32 in packages/ai-bot/main.ts

View workflow job for this annotation

GitHub Actions / Lint

'sumUpCreditsLedger' is defined but never used. Allowed unused vars must match /^_/u

Check failure on line 32 in packages/ai-bot/main.ts

View workflow job for this annotation

GitHub Actions / Lint

'sumUpCreditsLedger' is defined but never used. Allowed unused vars must match /^_/u
} from '@cardstack/billing/billing-queries';

let log = logger('ai-bot');

let trackAiUsageCostPromises = new Map<string, Promise<void>>();

class Assistant {
private openai: OpenAI;
private client: MatrixClient;
pgAdapter: PgAdapter;
id: string;

constructor(client: MatrixClient, id: string) {
this.openai = new OpenAI({
baseURL: 'https://openrouter.ai/api/v1', // We use openrouter so that we can track usage cost in $
baseURL: 'https://openrouter.ai/api/v1',
apiKey: process.env.OPENROUTER_API_KEY,
});
this.id = id;
this.client = client;
this.pgAdapter = new PgAdapter();
}

async trackAiUsageCost(matrixUserId: string, generationId: string) {
if (trackAiUsageCostPromises.has(matrixUserId)) {
return;
}
trackAiUsageCostPromises.set(
matrixUserId,
saveUsageCost(this.pgAdapter, matrixUserId, generationId).finally(() => {
trackAiUsageCostPromises.delete(matrixUserId);
}),
);
}

getResponse(history: DiscreteMatrixEvent[]) {
Expand Down Expand Up @@ -133,6 +156,7 @@
async function (event, room, toStartOfTimeline) {
try {
let eventBody = event.getContent().body;
let senderMatrixUserId = event.getSender()!;
if (!room) {
return;
}
Expand All @@ -150,15 +174,15 @@
return; // don't respond to card fragments, we just gather these in our history
}

if (event.getSender() === aiBotUserId) {
if (senderMatrixUserId === aiBotUserId) {
return;
}
log.info(
'(%s) (Room: "%s" %s) (Message: %s %s)',
event.getType(),
room?.name,
room?.roomId,
event.getSender(),
senderMatrixUserId,
eventBody,
);

Expand All @@ -177,18 +201,47 @@
}

const responder = new Responder(client, room.roomId);
await responder.initialize();

if (historyError) {
responder.finalize(
return responder.finalize(
'There was an error processing chat history. Please open another session.',
);
return;
}

await responder.initialize();

// Do not generate new responses if previous ones' cost is still being reported
let pendingCreditsConsumptionPromise = trackAiUsageCostPromises.get(
senderMatrixUserId!,
);
if (pendingCreditsConsumptionPromise) {
try {
await pendingCreditsConsumptionPromise;
} catch (e) {
log.error(e);
return responder.onError(
'There was an error saving your Boxel credits usage. Try again or contact support if the problem persists.',
);
}
}

let availableCredits = await getAvailableCredits(
assistant.pgAdapter,
senderMatrixUserId,
);

let minimumCredits = 10;

if (availableCredits < minimumCredits) {
return responder.onError(
`You need a minimum of ${minimumCredits} credits to continue using the AI bot. Please upgrade to a larger plan, or top up your account.`,
);
}

let generationId: string | undefined;
const runner = assistant
.getResponse(history)
.on('chunk', async (chunk, _snapshot) => {
generationId = chunk.id;
await responder.onChunk(chunk);
})
.on('content', async (_delta, snapshot) => {
Expand All @@ -200,9 +253,18 @@
.on('error', async (error) => {
await responder.onError(error);
});
// We also need to catch the error when getting the final content
let finalContent = await runner.finalContent().catch(responder.onError);
await responder.finalize(finalContent);

let finalContent;
try {
finalContent = await runner.finalContent();
await responder.finalize(finalContent);
} catch (error) {
await responder.onError(error);
} finally {
if (generationId) {
assistant.trackAiUsageCost(senderMatrixUserId, generationId);
}
}

if (shouldSetRoomTitle(eventList, aiBotUserId, event)) {
return await assistant.setTitle(room.roomId, history, event);
Expand Down
6 changes: 4 additions & 2 deletions packages/ai-bot/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
{
"name": "@cardstack/ai-bot",
"dependencies": {
"@cardstack/runtime-common": "workspace:^",
"@cardstack/runtime-common": "workspace:*",
"@cardstack/postgres": "workspace:*",
"@cardstack/billing": "workspace:*",
"@sentry/node": "^8.31.0",
"@types/node": "^18.18.5",
"@types/stream-chain": "^2.0.1",
Expand All @@ -21,7 +23,7 @@
},
"scripts": {
"lint": "eslint . --cache --ext ts",
"start": "NODE_NO_WARNINGS=1 ts-node --transpileOnly main",
"start": "NODE_NO_WARNINGS=1 PGDATABASE=boxel PGPORT=5435 ts-node --transpileOnly main",
"test": "NODE_NO_WARNINGS=1 qunit --require ts-node/register/transpile-only tests/index.ts",
"get-chat": "NODE_NO_WARNINGS=1 ts-node --transpileOnly scripts/get_chat.ts"
},
Expand Down
64 changes: 64 additions & 0 deletions packages/billing/billing-queries.ts
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,67 @@ export async function expireRemainingPlanAllowanceInSubscriptionCycle(
subscriptionCycleId,
});
}

export async function spendCredits(
dbAdapter: DBAdapter,
userId: string,
creditsToSpend: number,
) {
let subscription = await getCurrentActiveSubscription(dbAdapter, userId);
if (!subscription) {
throw new Error('active subscription not found');
}
let subscriptionCycle = await getMostRecentSubscriptionCycle(
dbAdapter,
subscription.id,
);
if (!subscriptionCycle) {
throw new Error('subscription cycle not found');
}
let availablePlanAllowanceCredits = await sumUpCreditsLedger(dbAdapter, {
creditType: [
'plan_allowance',
'plan_allowance_used',
'plan_allowance_expired',
],
userId,
});

if (availablePlanAllowanceCredits >= creditsToSpend) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -creditsToSpend,
creditType: 'plan_allowance_used',
subscriptionCycleId: subscriptionCycle.id,
});
} else {
// If user does not have enough plan allowance credits to cover the spend, try to also use extra credits
let availableExtraCredits = await sumUpCreditsLedger(dbAdapter, {
creditType: ['extra_credit', 'extra_credit_used'],
userId,
});
let planAllowanceToSpend = availablePlanAllowanceCredits; // Spend all plan allowance credits first
let extraCreditsToSpend = creditsToSpend - planAllowanceToSpend;
if (extraCreditsToSpend > availableExtraCredits) {
extraCreditsToSpend = availableExtraCredits;
}

if (planAllowanceToSpend > 0) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -planAllowanceToSpend,
creditType: 'plan_allowance_used',
subscriptionCycleId: subscriptionCycle.id,
});
}

if (extraCreditsToSpend > 0) {
await addToCreditsLedger(dbAdapter, {
userId,
creditAmount: -extraCreditsToSpend,
creditType: 'extra_credit_used',
subscriptionCycleId: subscriptionCycle.id,
});
}
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { type DBAdapter } from '@cardstack/runtime-common';
import {
addToCreditsLedger,
getCurrentActiveSubscription,
getMostRecentSubscriptionCycle,
getUserByStripeId,
insertStripeEvent,
markStripeEventAsProcessed,
Expand Down Expand Up @@ -50,11 +52,27 @@ export async function handleCheckoutSessionCompleted(
);
}

let subscription = await getCurrentActiveSubscription(dbAdapter, user.id);
if (!subscription) {
throw new Error(
`User ${user.id} has no subscription, cannot add extra credits`,
);
}
let subscriptionCycle = await getMostRecentSubscriptionCycle(
dbAdapter,
subscription!.id,
);
if (!subscriptionCycle) {
throw new Error(
`User ${user.id} has no subscription cycle, cannot add extra credits`,
);
}

await addToCreditsLedger(dbAdapter, {
userId: user.id,
creditAmount: creditReloadAmount,
creditType: 'extra_credit',
subscriptionCycleId: null,
subscriptionCycleId: subscriptionCycle.id,
});
}
});
Expand Down
3 changes: 0 additions & 3 deletions packages/billing/stripe-webhook-handlers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ export default async function stripeWebhookHandler(

let type = event.type;

// For adding extra credits, we should listen for charge.succeeded, and for
// subsciptions, we should listen for invoice.payment_succeeded (I discovered this when I was
// testing which webhooks arrive for both types of payments)
switch (type) {
// These handlers should eventually become jobs which workers will process asynchronously
case 'invoice.payment_succeeded':
Expand Down
Loading
Loading