-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaisdk_client.ts
112 lines (100 loc) · 2.83 KB
/
aisdk_client.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import {
  CoreAssistantMessage,
  CoreMessage,
  CoreSystemMessage,
  CoreTool,
  CoreUserMessage,
  generateObject,
  generateText,
  ImagePart,
  jsonSchema,
  LanguageModel,
  TextPart,
} from "ai";
import { ChatCompletion } from "openai/resources/chat/completions";
import {
  CreateChatCompletionOptions,
  LLMClient,
  AvailableModel,
} from "@browserbasehq/stagehand";
/** One chat message as accepted by Stagehand's `createChatCompletion`. */
type StagehandMessage =
  CreateChatCompletionOptions["options"]["messages"][number];

/** One tool definition as accepted by Stagehand's `createChatCompletion`. */
type StagehandTool = NonNullable<
  CreateChatCompletionOptions["options"]["tools"]
>[number];

/**
 * Converts one Stagehand chat message into an AI SDK `CoreMessage`.
 *
 * - String content is passed through with its role unchanged.
 * - Multi-part system content is flattened into a single string joined by
 *   "\n"; parts without a `text` field contribute an empty line.
 * - Multi-part user content keeps images as `ImagePart`s (using the part's
 *   `image_url.url`) and text as `TextPart`s.
 * - Multi-part assistant content is built from text parts only; image parts
 *   become the placeholder text "[Image]" (presumably because AI SDK
 *   assistant messages do not accept image parts — confirm for your SDK).
 */
function toCoreMessage(message: StagehandMessage): CoreMessage {
  if (!Array.isArray(message.content)) {
    return { role: message.role, content: message.content };
  }

  if (message.role === "system") {
    const systemMessage: CoreSystemMessage = {
      role: "system",
      content: message.content
        .map((part) => ("text" in part ? part.text : ""))
        .join("\n"),
    };
    return systemMessage;
  }

  const parts = message.content.map((part): ImagePart | TextPart =>
    "image_url" in part
      ? { type: "image", image: part.image_url.url }
      : { type: "text", text: part.text },
  );

  if (message.role === "user") {
    const userMessage: CoreUserMessage = { role: "user", content: parts };
    return userMessage;
  }

  const assistantMessage: CoreAssistantMessage = {
    role: "assistant",
    content: parts.map((part) => ({
      type: "text" as const,
      text: part.type === "image" ? "[Image]" : part.text,
    })),
  };
  return assistantMessage;
}

/**
 * Adapts a tool's `parameters` for the AI SDK.
 *
 * Stagehand tools describe parameters as plain JSON Schema objects. The AI
 * SDK expects either a Zod schema or a `Schema` built by `jsonSchema()`, and
 * treats anything else as Zod, which fails for a raw JSON Schema object.
 * Values that already look like a Zod schema (they expose a `safeParse`
 * function) are passed through unchanged; everything else is wrapped.
 */
function toToolParameters(parameters: unknown): CoreTool["parameters"] {
  const looksLikeZod =
    typeof parameters === "object" &&
    parameters !== null &&
    typeof (parameters as { safeParse?: unknown }).safeParse === "function";
  if (looksLikeZod) {
    return parameters;
  }
  // Boundary cast: Stagehand types this as a loose record; it is JSON Schema.
  return jsonSchema(parameters as Parameters<typeof jsonSchema>[0]);
}

/** Builds the AI SDK tool map (keyed by tool name) from Stagehand tools. */
function toCoreTools(
  rawTools: readonly StagehandTool[],
): Record<string, CoreTool> {
  const tools: Record<string, CoreTool> = {};
  for (const rawTool of rawTools) {
    tools[rawTool.name] = {
      description: rawTool.description,
      parameters: toToolParameters(rawTool.parameters),
    };
  }
  return tools;
}

/**
 * Stagehand `LLMClient` backed by a Vercel AI SDK `LanguageModel`.
 *
 * The model's `modelId` is reported to Stagehand as the model name.
 */
export class AISdkClient extends LLMClient {
  public type = "aisdk" as const;
  private readonly model: LanguageModel;

  constructor({ model }: { model: LanguageModel }) {
    super(model.modelId as AvailableModel);
    this.model = model;
  }

  /**
   * Runs a chat completion through the AI SDK.
   *
   * When `options.response_model` is set, this calls `generateObject` with
   * its schema and resolves to the generated object. Otherwise it calls
   * `generateText` with any provided tools and resolves to the raw
   * `generateText` result. Both are cast to `T`; callers must expect these
   * shapes rather than an OpenAI `ChatCompletion`.
   */
  async createChatCompletion<T = ChatCompletion>({
    options,
  }: CreateChatCompletionOptions): Promise<T> {
    const messages = options.messages.map(toCoreMessage);

    if (options.response_model) {
      const { object } = await generateObject({
        model: this.model,
        messages,
        schema: options.response_model.schema,
      });
      return object as T;
    }

    const response = await generateText({
      model: this.model,
      messages,
      tools: toCoreTools(options.tools ?? []),
    });
    return response as T;
  }
}