text.spec.ts
import { expect } from "chai";
import { tokenize } from "./text.js";
import { AutoTokenizer, PreTrainedTokenizer } from "@xenova/transformers";
import { Repeat } from "immutable";

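// Minimal shape of the tokenizer call result that encodeDecode relies on:
// only the input_ids field is read.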
interface TokenizerOutput {
  input_ids: number[];
}

/**
 * Encodes the text into token IDs and then decodes them back to text.
 * Special tokens are skipped during decoding.
 *
 * @param tokenizer - An instance of a PreTrainedTokenizer
 * @param text - The text to process
 * @returns The decoded text obtained after encoding and then decoding
 */
export function encodeDecode(tokenizer: PreTrainedTokenizer, text: string): string {
  // Encode the text using the tokenizer.
  const encoding = tokenizer(text, { return_tensor: false }) as TokenizerOutput;
  // Decode the token IDs back into text while skipping special tokens.
  return tokenizer.decode(encoding.input_ids, { skip_special_tokens: true });
}
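
// Example usage (a sketch; assumes the Xenova/gpt2 checkpoint used in the tests below):
//   const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
//   const roundTripped = encodeDecode(tokenizer, "Hello, world!");
//   // GPT-2's byte-level BPE has no unknown tokens, so for plain text the
//   // round trip is expected to return the original string unchanged.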
describe("text processing", () => {
const text = [
"Hello world, a bc 1 2345, '? 976. Wikipedia is a free content online encyclopedia",
"written and maintained by a community \n of volunteers, known as Wikipedians.",
"Founded by Jimmy Wales and Larry Sanger on January 15, 2001, Wikipedia is hosted by the",
"Wikimedia Foundation, an American nonprofit organization that employs a staff of over 700 people.[7]"
].join(" ");
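  // Reference token IDs for the GPT-2 tokenizer ("Xenova/gpt2"). As an
  // assumption, mirroring the Llama 3 test below, they should be reproducible
  // in Python with:
  //   from transformers import AutoTokenizer
  //   tokenizer = AutoTokenizer.from_pretrained("gpt2")
  //   tokenizer.encode(text, add_special_tokens=False)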
  const expectedTokens = [
    15496, 995, 11, 257, 47125, 352, 2242, 2231, 11, 705, 30, 860, 4304, 13,
    15312, 318, 257, 1479, 2695, 2691, 45352, 3194, 290, 9456, 416, 257, 2055,
    220, 198, 286, 11661, 11, 1900, 355, 11145, 46647, 1547, 13, 4062, 276, 416,
    12963, 11769, 290, 13633, 311, 2564, 319, 3269, 1315, 11, 5878, 11, 15312,
    318, 12007, 416, 262, 44877, 5693, 11, 281, 1605, 15346, 4009, 326, 24803,
    257, 3085, 286, 625, 13037, 661, 3693, 22, 60,
  ];

  const shortText = 'import { AutoTokenizer } from "@xenova/transformers";';
  // Expected token IDs for shortText with the GPT-2 tokenizer
  const shortExpectedTokens = [
    11748, 1391, 11160, 30642, 7509, 1782, 422,
    44212, 87, 268, 10071, 14, 35636, 364, 8172
  ];

it("can tokenize text with the Llama 3 tokenizer", async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-3-tokenizer");
// Tokenizer playgrounds aren't consistent: https://github.com/huggingface/transformers.js/issues/1019
// Tokenization with python:
// from transformers import AutoTokenizer
// tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
// tokenizer.encode(text, add_special_tokens=False)
const expectedTokens = [
9906, 1917, 11, 264, 18399, 220, 16, 220, 11727, 20, 11, 32167,
220, 25208, 13, 27685, 374, 264, 1949, 2262, 2930, 83708, 5439, 323, 18908,
555, 264, 4029, 720, 315, 23872, 11, 3967, 439, 119234, 291, 5493, 13, 78811,
555, 28933, 23782, 323, 30390, 328, 4091, 389, 6186, 220, 868, 11, 220, 1049,
16, 11, 27685, 374, 21685, 555, 279, 90940, 5114, 11, 459, 3778, 33184, 7471,
430, 51242, 264, 5687, 315, 927, 220, 7007, 1274, 8032, 22, 60
]
const tokens = tokenize(tokenizer, text);
expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
});
it("can tokenize text with the GPT2 tokenizer", async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const tokens = tokenize(tokenizer, text);
expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
});
it("truncates until expected length", async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const tokens = tokenize(tokenizer, text, {truncation: true, max_length: 10});
expect(tokens.toArray()).to.be.deep.equal(expectedTokens.slice(0, 10));
});
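  // The two tests below assume that tokenize pads on the left by default,
  // which is the usual convention for decoder-only models such as GPT-2,
  // where generation continues from the right end of the sequence.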
it("pads sequence until enough token are generated", async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const max_length = 20
const tokens = tokenize(tokenizer, shortText, {padding: true, max_length});
const paddedSequence = Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length)
.concat(shortExpectedTokens).toArray();
expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
});
it("can pad on right side", async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const max_length = 20
const tokens = tokenize(tokenizer, shortText, {padding: true, padding_side: 'right', max_length});
const paddedSequence = shortExpectedTokens.concat(
Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length).toArray()
);
expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
});
});
describe("Multi-Tokenizer Tests", function () {
this.timeout(20000);
const sampleText = "Hello, world! This is a test string to check tokenization.";
// List of tokenizer names to test
const tokenizerNames = [
"Xenova/gpt2",
"Xenova/llama-3-tokenizer",
// "Xenova/bert-base-uncased", // takes too long
"Xenova/roberta-base",
"Xenova/distilbert-base-uncased"
];
tokenizerNames.forEach((name) => {
it(`should tokenize text using tokenizer "${name}"`, async () => {
const tokenizer = await AutoTokenizer.from_pretrained(name);
const tokens = tokenize(tokenizer, sampleText);
const tokenArray = tokens.toArray();
// Checks that we got a non-empty array of tokens and that each token is a number.
expect(tokenArray).to.be.an("array").that.is.not.empty;
tokenArray.forEach((token) => {
expect(token).to.be.a("number");
});
});
});
});
describe("Encode-Decode tokenization", function () {
this.timeout(20000);
it("should return text close to the original after encode-decode tokenization using GPT2 tokenizer", async function () {
// Load the GPT2 tokenizer
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
const originalText = "Hello, world! This is a test for encode-decode tokenization.";
// Perform round-trip tokenization
const decodedText = encodeDecode(tokenizer, originalText);
// Check that the decoded text is almost equal to the original text
expect(decodedText).to.equal(originalText);
});
});