-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: inference.ts
47 lines (40 loc) · 1.25 KB
/
inference.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tokenizer from "../../public/tokenizer/tokenizer.json";
import tokenizer_config from "../../public/tokenizer/tokenizer_config.json";
import { BertTokenizer } from "@xenova/transformers";
import * as ort from "onnxruntime-web";
// import { AutoTokenizer } from "../tokenizer";
// Configure the onnxruntime-web WASM backend at module load, before any
// InferenceSession is created (settings are read at session creation time).
ort.env.wasm.numThreads = 3; // use 3 worker threads for WASM execution
ort.env.wasm.simd = true; // enable the SIMD WASM build when the browser supports it
/**
 * Creates an ONNX Runtime Web inference session for the given model.
 *
 * @param model - Path/URL of the ONNX model file (defaults to the bundled
 *   int8-quantized model).
 * @returns The initialized `InferenceSession`, ready for `run()` calls.
 */
export async function load_session(model = "model/model-int8.onnx") {
  // Run on the WASM execution provider with all graph optimizations enabled.
  return await ort.InferenceSession.create(model, {
    executionProviders: ["wasm"],
    graphOptimizationLevel: "all",
  });
}
export function Tokenizer() {
try {
const bert_tokenizer = new BertTokenizer(tokenizer, tokenizer_config);
return bert_tokenizer;
} catch (error) {
console.log(error);
return () => {};
}
}
// De-normalization constants mapping the model's single output logit back to
// the original score scale (score = logit * SCALE + OFFSET).
const INFERENCE_SCALE = 13.3627190349059;
const INFERENCE_OFFSET = 10.85810766787474;

/**
 * Tokenizes `text`, runs it through the ONNX session, and de-normalizes the
 * first output logit into a score.
 *
 * @param text - Input text to score.
 * @param tokenizer - Callable tokenizer producing `{ input_ids, attention_mask }`.
 * @param session - ONNX Runtime session exposing `run(feeds)`.
 * @returns The de-normalized score; `NaN` if the session yields no logits.
 */
export async function inference(
  text: string,
  tokenizer: any,
  session: any,
): Promise<number> {
  const { input_ids, attention_mask } = tokenizer(text, {
    padding: true,
    truncation: true,
    max_length: 512, // BERT-style models cap sequences at 512 tokens
  });
  const output = await session?.run({ input_ids, attention_mask });
  // Fall back to an *empty instance* (the original fell back to the
  // Float32Array constructor itself, which only worked by accident).
  const logits: Float32Array = output?.["logits"]?.data ?? new Float32Array(0);
  // Missing logit yields NaN, matching the original behavior.
  const firstLogit = logits.length > 0 ? logits[0] : Number.NaN;
  return firstLogit * INFERENCE_SCALE + INFERENCE_OFFSET;
}