-
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtemplating.ts
135 lines (114 loc) · 4.21 KB
/
templating.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// @ts-types="@/types/jinja.d.ts"
import {
ArrayLiteral,
Environment,
Identifier,
Interpreter,
Literal,
SetStatement,
Template,
} from "@huggingface/jinja";
import * as z from "@/common/myZod.ts";
import * as Path from "@std/path";
// ts-types="@types/strftime"
import strftime from "strftime";
/**
 * Python-style `range()`, exposed to templates as the `range` global.
 *
 * Adapted from `@huggingface/jinja`'s helper, extended to match Python's
 * semantics for non-positive steps: a negative step counts downwards
 * (`range(5, 0, -1)` -> `[5, 4, 3, 2, 1]`) and a zero step is rejected
 * instead of looping forever.
 *
 * @param start - First value, or the exclusive upper bound when `stop` is omitted.
 * @param stop - Exclusive bound; when omitted the range is `[0, start)`.
 * @param step - Increment between values; may be negative, must not be zero.
 * @returns The generated numbers; empty when `start` is already past `stop`
 *   in the direction of `step`.
 * @throws Error when `step` is `0`.
 */
export function range(start: number, stop?: number, step = 1): number[] {
    // Single-argument form: range(n) === range(0, n).
    let from = start;
    let to = stop;
    if (to === undefined) {
        to = start;
        from = 0;
    }

    if (step === 0) {
        // Mirrors Python's ValueError; without this the loop never terminates.
        throw new Error("range() arg 3 must not be zero");
    }

    const result: number[] = [];
    if (step > 0) {
        for (let i = from; i < to; i += step) {
            result.push(i);
        }
    } else {
        for (let i = from; i > to; i += step) {
            result.push(i);
        }
    }
    return result;
}
/**
 * Schema for metadata a template may declare through top-level
 * `{% set <key> = <literal> %}` statements; `PromptTemplate.extractMetadata`
 * validates each such value against the matching field below.
 * Every field is `.optional()`, so parsing `{}` yields an empty metadata object.
 */
const TemplateMetadataSchema = z.object({
// Presumably strings at which generation should stop — confirm with consumers.
stop_strings: z.array(z.string()).optional(),
// NOTE(review): presumably the text that opens a tool call — TODO confirm.
tool_start: z.string().optional(),
// NOTE(review): presumably the token id corresponding to `tool_start` — TODO confirm.
tool_start_token: z.number().optional(),
});
// Static type of a parsed metadata object (all fields optional).
type TemplateMetadata = z.infer<typeof TemplateMetadataSchema>;
/**
 * A named Jinja prompt template together with the metadata it declares.
 *
 * Metadata is extracted once, in the constructor, from top-level
 * `{% set <key> = <literal> %}` statements whose `<key>` is a field of
 * `TemplateMetadataSchema` (for example `{% set stop_strings = ["</s>"] %}`).
 */
export class PromptTemplate {
    /**
     * AST node types whose names end in "Literal" but whose `.value` holds
     * child nodes rather than a plain JS value, so they are not scalars.
     */
    private static readonly compoundLiteralTypes = new Set([
        "ArrayLiteral",
        "TupleLiteral",
        "ObjectLiteral",
    ]);

    name: string;
    rawTemplate: string;
    template: Template;
    metadata: TemplateMetadata;

    /**
     * Parses `rawTemplate` and extracts its metadata.
     *
     * Errors raised while parsing `rawTemplate` (by `new Template`) are not
     * caught here and propagate to the caller.
     *
     * @param name - Identifier for this template.
     * @param rawTemplate - Jinja source text.
     */
    public constructor(
        name: string,
        rawTemplate: string,
    ) {
        this.name = name;
        this.rawTemplate = rawTemplate;
        this.template = new Template(rawTemplate);
        this.metadata = this.extractMetadata(this.template);
    }

    /**
     * Renders the parsed template with `context` exposed as globals.
     *
     * Used instead of `Template.render` so the interpreter environment can
     * be built here and extended with extra globals:
     * - `true` / `false`;
     * - `raise_exception(message)`: throws an `Error` carrying `message`,
     *   aborting the render;
     * - `strftime_now(format)`: the current time formatted by the `strftime`
     *   package;
     * - `range(...)`: the Python-style `range` from this module.
     *
     * Context entries are set after the helpers, so a context key with the
     * same name replaces the corresponding helper.
     *
     * @param context - Variables to expose to the template.
     * @returns The rendered text.
     */
    public render(context: Record<string, unknown> = {}): string {
        const env = new Environment();

        // Boolean globals — presumably needed because the parser resolves
        // `true`/`false` as identifiers; confirm against the jinja version.
        env.set("false", false);
        env.set("true", true);

        // Helper functions callable from the template.
        env.set("raise_exception", (message: string) => {
            throw new Error(message);
        });
        env.set("strftime_now", (format: string) => {
            return strftime(format);
        });
        env.set("range", range);

        // Caller-supplied variables; set last so they take precedence.
        for (const [key, value] of Object.entries(context)) {
            env.set(key, value);
        }

        const interpreter = new Interpreter(env);
        const response = interpreter.run(this.template.parsed);

        // Value is always a string here
        return response.value as string;
    }

    /**
     * Collects metadata from top-level `{% set name = value %}` statements.
     *
     * An assignment is used only when `name` is a key of
     * `TemplateMetadataSchema`, `value` is a plain literal (or an array of
     * plain literals), and the value passes that field's schema. Anything
     * else is skipped, so it can neither inject identifier names nor
     * overwrite an earlier valid value with `undefined`. Later valid
     * assignments to the same key win. Invalid metadata is ignored rather
     * than reported (best-effort).
     */
    private extractMetadata(template: Template): TemplateMetadata {
        const metadata: TemplateMetadata = TemplateMetadataSchema.parse({});
        const shape = TemplateMetadataSchema.shape;
        const metaKeys = Object.keys(shape);

        for (const statement of template.parsed.body) {
            if (statement.type !== "Set") {
                continue;
            }
            const setStatement = statement as SetStatement;

            // Only `{% set identifier = ... %}` can declare metadata; tuple
            // and member-expression assignees are ignored.
            if (setStatement.assignee.type !== "Identifier") {
                continue;
            }
            const assigneeName = (setStatement.assignee as Identifier).value;
            const foundMetaKey = metaKeys.find(
                (key) => key === assigneeName,
            ) as keyof TemplateMetadata | undefined;
            if (!foundMetaKey) {
                continue;
            }

            // `value` may be absent (e.g. block-form `{% set %}`) or a
            // non-literal expression; neither can be evaluated statically.
            const extracted = PromptTemplate.literalValue(setStatement.value);
            if (!extracted) {
                continue;
            }

            const parsedValue = shape[foundMetaKey].safeParse(extracted.value);
            if (parsedValue.success) {
                // deno-lint-ignore no-explicit-any
                metadata[foundMetaKey] = parsedValue.data as any;
            }
        }
        return metadata;
    }

    /**
     * Statically evaluates an AST node that is a scalar literal or an array
     * of scalar literals.
     *
     * @returns `{ value }` on success, or `undefined` when the node is
     *   missing or contains anything that is not a plain literal (so a
     *   legitimately `undefined` result is never confused with "skip").
     */
    private static literalValue(
        node: { type: string } | null | undefined,
    ): { value: unknown } | undefined {
        if (!node) {
            return undefined;
        }
        if (node.type === "ArrayLiteral") {
            const values: unknown[] = [];
            for (const element of (node as ArrayLiteral).value) {
                const item = PromptTemplate.scalarLiteralValue(element);
                if (!item) {
                    return undefined;
                }
                values.push(item.value);
            }
            return { value: values };
        }
        return PromptTemplate.scalarLiteralValue(node);
    }

    /** Returns `{ value }` for a scalar `*Literal` node, else `undefined`. */
    private static scalarLiteralValue(
        node: { type: string },
    ): { value: unknown } | undefined {
        if (
            !node.type.endsWith("Literal") ||
            PromptTemplate.compoundLiteralTypes.has(node.type)
        ) {
            return undefined;
        }
        return { value: (node as Literal<unknown>).value };
    }

    /**
     * Loads a template from disk.
     *
     * The extension of `templatePath` is replaced with `.jinja` (or added if
     * missing), and the template is named after the file's base name without
     * extension: `templates/chatml` and `templates/chatml.jinja` both read
     * `templates/chatml.jinja` and produce the name `chatml`.
     *
     * NOTE(review): the last dot is treated as the extension separator, so a
     * dotted name such as `templates/llama3.1` reads `templates/llama3.jinja`
     * with name `llama3`; confirm callers never pass such names.
     *
     * Errors from `Deno.readTextFile` (e.g. a missing file) and from parsing
     * the template propagate to the caller.
     */
    static async fromFile(templatePath: string): Promise<PromptTemplate> {
        const parsedPath = Path.parse(templatePath);
        parsedPath.ext = ".jinja";
        // `base` must be cleared, otherwise `format` would ignore name + ext.
        const formattedPath = Path.format({ ...parsedPath, base: undefined });
        const rawTemplate = await Deno.readTextFile(formattedPath);
        return new PromptTemplate(parsedPath.name, rawTemplate);
    }
}