
Commit 339b416

Merge pull request #890 from epfml/tokenization_tests
2 more tokenization tests
2 parents 42c9d3f + 5b23528 commit 339b416

File tree

1 file changed: +152 -83 lines

discojs/src/processing/text.spec.ts

+152 -83
@@ -1,84 +1,153 @@
-import { expect } from "chai";
-
-import { tokenize } from "./text.js";
-import { AutoTokenizer } from "@xenova/transformers";
-import { Repeat } from "immutable";
-
-describe("text processing", () => {
-  const text = [
-    "Hello world, a bc 1 2345, '? 976. Wikipedia is a free content online encyclopedia",
-    "written and maintained by a community \n of volunteers, known as Wikipedians.",
-    "Founded by Jimmy Wales and Larry Sanger on January 15, 2001, Wikipedia is hosted by the",
-    "Wikimedia Foundation, an American nonprofit organization that employs a staff of over 700 people.[7]"
-  ].join(" ");
-
-  const expectedTokens = [
-    15496, 995, 11, 257, 47125, 352, 2242, 2231, 11, 705, 30, 860, 4304, 13,
-    15312, 318, 257, 1479, 2695, 2691, 45352, 3194, 290, 9456, 416, 257, 2055,
-    220, 198, 286, 11661, 11, 1900, 355, 11145, 46647, 1547, 13, 4062, 276, 416,
-    12963, 11769, 290, 13633, 311, 2564, 319, 3269, 1315, 11, 5878, 11, 15312,
-    318, 12007, 416, 262, 44877, 5693, 11, 281, 1605, 15346, 4009, 326, 24803,
-    257, 3085, 286, 625, 13037, 661, 3693, 22, 60,
-  ];
-
-  const shortText = 'import { AutoTokenizer } from "@xenova/transformers";'
-  // with GPT 2 tokenizer
-  const shortExpectedTokens = [
-    11748, 1391, 11160, 30642, 7509, 1782, 422,
-    44212, 87, 268, 10071, 14, 35636, 364, 8172
-  ]
-
-  it("can tokenize text with the Llama 3 tokenizer", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-3-tokenizer");
-    // Tokenizer playgrounds aren't consistent: https://github.com/huggingface/transformers.js/issues/1019
-    // Tokenization with python:
-    // from transformers import AutoTokenizer
-    // tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
-    // tokenizer.encode(text, add_special_tokens=False)
-    const expectedTokens = [
-      9906, 1917, 11, 264, 18399, 220, 16, 220, 11727, 20, 11, 32167,
-      220, 25208, 13, 27685, 374, 264, 1949, 2262, 2930, 83708, 5439, 323, 18908,
-      555, 264, 4029, 720, 315, 23872, 11, 3967, 439, 119234, 291, 5493, 13, 78811,
-      555, 28933, 23782, 323, 30390, 328, 4091, 389, 6186, 220, 868, 11, 220, 1049,
-      16, 11, 27685, 374, 21685, 555, 279, 90940, 5114, 11, 459, 3778, 33184, 7471,
-      430, 51242, 264, 5687, 315, 927, 220, 7007, 1274, 8032, 22, 60
-    ]
-    const tokens = tokenize(tokenizer, text);
-    expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
-  });
-
-  it("can tokenize text with the GPT2 tokenizer", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
-
-    const tokens = tokenize(tokenizer, text);
-    expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
-  });
-
-  it("truncates until expected length", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
-
-    const tokens = tokenize(tokenizer, text, {truncation: true, max_length: 10});
-    expect(tokens.toArray()).to.be.deep.equal(expectedTokens.slice(0, 10));
-  });
-
-  it("pads sequence until enough token are generated", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
-    const max_length = 20
-
-    const tokens = tokenize(tokenizer, shortText, {padding: true, max_length});
-    const paddedSequence = Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length)
-      .concat(shortExpectedTokens).toArray();
-    expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
-  });
-
-  it("can pad on right side", async () => {
-    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
-    const max_length = 20
-
-    const tokens = tokenize(tokenizer, shortText, {padding: true, padding_side: 'right', max_length});
-    const paddedSequence = shortExpectedTokens.concat(
-      Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length).toArray()
-    );
-    expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
-  });
+import { expect } from "chai";
+
+import { tokenize } from "./text.js";
+import { AutoTokenizer } from "@xenova/transformers";
+import { Repeat } from "immutable";
+import { PreTrainedTokenizer } from "@xenova/transformers";
+
+
+interface TokenizerOutput {
+  input_ids: number[];
+}
+
+/**
+ * Encodes the text into token IDs and then decodes them back to text
+ * Special tokens are skipped during decoding
+ *
+ * @param tokenizer - An instance of a PreTrainedTokenizer
+ * @param text - The text to process
+ * @returns The decoded text obtained after encoding and then decoding
+ */
+export function encodeDecode(tokenizer: PreTrainedTokenizer, text: string): string {
+  // Encode the text using the tokenizer.
+  const encoding = tokenizer(text, { return_tensor: false }) as TokenizerOutput;
+  // Decode the token IDs back into text while skipping special tokens.
+  return tokenizer.decode(encoding.input_ids, { skip_special_tokens: true });
+}
+
+
+describe("text processing", () => {
+  const text = [
+    "Hello world, a bc 1 2345, '? 976. Wikipedia is a free content online encyclopedia",
+    "written and maintained by a community \n of volunteers, known as Wikipedians.",
+    "Founded by Jimmy Wales and Larry Sanger on January 15, 2001, Wikipedia is hosted by the",
+    "Wikimedia Foundation, an American nonprofit organization that employs a staff of over 700 people.[7]"
+  ].join(" ");
+
+  const expectedTokens = [
+    15496, 995, 11, 257, 47125, 352, 2242, 2231, 11, 705, 30, 860, 4304, 13,
+    15312, 318, 257, 1479, 2695, 2691, 45352, 3194, 290, 9456, 416, 257, 2055,
+    220, 198, 286, 11661, 11, 1900, 355, 11145, 46647, 1547, 13, 4062, 276, 416,
+    12963, 11769, 290, 13633, 311, 2564, 319, 3269, 1315, 11, 5878, 11, 15312,
+    318, 12007, 416, 262, 44877, 5693, 11, 281, 1605, 15346, 4009, 326, 24803,
+    257, 3085, 286, 625, 13037, 661, 3693, 22, 60,
+  ];
+
+  const shortText = 'import { AutoTokenizer } from "@xenova/transformers";'
+  // with GPT 2 tokenizer
+  const shortExpectedTokens = [
+    11748, 1391, 11160, 30642, 7509, 1782, 422,
+    44212, 87, 268, 10071, 14, 35636, 364, 8172
+  ]
+
+  it("can tokenize text with the Llama 3 tokenizer", async () => {
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-3-tokenizer");
+    // Tokenizer playgrounds aren't consistent: https://github.com/huggingface/transformers.js/issues/1019
+    // Tokenization with python:
+    // from transformers import AutoTokenizer
+    // tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
+    // tokenizer.encode(text, add_special_tokens=False)
+    const expectedTokens = [
+      9906, 1917, 11, 264, 18399, 220, 16, 220, 11727, 20, 11, 32167,
+      220, 25208, 13, 27685, 374, 264, 1949, 2262, 2930, 83708, 5439, 323, 18908,
+      555, 264, 4029, 720, 315, 23872, 11, 3967, 439, 119234, 291, 5493, 13, 78811,
+      555, 28933, 23782, 323, 30390, 328, 4091, 389, 6186, 220, 868, 11, 220, 1049,
+      16, 11, 27685, 374, 21685, 555, 279, 90940, 5114, 11, 459, 3778, 33184, 7471,
+      430, 51242, 264, 5687, 315, 927, 220, 7007, 1274, 8032, 22, 60
+    ]
+    const tokens = tokenize(tokenizer, text);
+    expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
+  });
+
+  it("can tokenize text with the GPT2 tokenizer", async () => {
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
+
+    const tokens = tokenize(tokenizer, text);
+    expect(tokens.toArray()).to.be.deep.equal(expectedTokens);
+  });
+
+  it("truncates until expected length", async () => {
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
+
+    const tokens = tokenize(tokenizer, text, {truncation: true, max_length: 10});
+    expect(tokens.toArray()).to.be.deep.equal(expectedTokens.slice(0, 10));
+  });
+
+  it("pads sequence until enough token are generated", async () => {
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
+    const max_length = 20
+
+    const tokens = tokenize(tokenizer, shortText, {padding: true, max_length});
+    const paddedSequence = Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length)
+      .concat(shortExpectedTokens).toArray();
+    expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
+  });
+
+  it("can pad on right side", async () => {
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
+    const max_length = 20
+
+    const tokens = tokenize(tokenizer, shortText, {padding: true, padding_side: 'right', max_length});
+    const paddedSequence = shortExpectedTokens.concat(
+      Repeat(tokenizer.pad_token_id, max_length - shortExpectedTokens.length).toArray()
+    );
+    expect(tokens.toArray()).to.be.deep.equal(paddedSequence);
+  });
+});
+
+
+describe("Multi-Tokenizer Tests", function () {
+  this.timeout(20000);
+
+  const sampleText = "Hello, world! This is a test string to check tokenization.";
+
+  // List of tokenizer names to test
+  const tokenizerNames = [
+    "Xenova/gpt2",
+    "Xenova/llama-3-tokenizer",
+    // "Xenova/bert-base-uncased", // takes too long
+    "Xenova/roberta-base",
+    "Xenova/distilbert-base-uncased"
+  ];
+
+  tokenizerNames.forEach((name) => {
+    it(`should tokenize text using tokenizer "${name}"`, async () => {
+      const tokenizer = await AutoTokenizer.from_pretrained(name);
+      const tokens = tokenize(tokenizer, sampleText);
+      const tokenArray = tokens.toArray();
+
+      // Checks that we got a non-empty array of tokens and that each token is a number.
+      expect(tokenArray).to.be.an("array").that.is.not.empty;
+      tokenArray.forEach((token) => {
+        expect(token).to.be.a("number");
+      });
+    });
+  });
+});
+
+
+describe("Encode-Decode tokenization", function () {
+  this.timeout(20000);
+
+  it("should return text close to the original after encode-decode tokenization using GPT2 tokenizer", async function () {
+    // Load the GPT2 tokenizer
+    const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
+    const originalText = "Hello, world! This is a test for encode-decode tokenization.";
+
+    // Perform round-trip tokenization
+    const decodedText = encodeDecode(tokenizer, originalText);
+
+    // Check that the decoded text is almost equal to the original text
+    expect(decodedText).to.equal(originalText);
+  });
 });
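
The truncation and padding tests above exercise the options accepted by the tokenize helper in discojs/src/processing/text.ts. The snippet below is a minimal usage sketch, not part of the commit: it assumes the option names shown in the spec ({ truncation, padding, padding_side, max_length }), that tokenize returns an immutable sequence with a toArray() method, and that the "./text.js" import path only applies next to the spec file.

import { AutoTokenizer } from "@xenova/transformers";
import { tokenize } from "./text.js"; // path as used in the spec; adjust for other locations

async function tokenizeDemo(): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const text = "Wikipedia is a free content online encyclopedia";

  // Plain tokenization: one GPT-2 token ID per subword.
  const ids = tokenize(tokenizer, text).toArray();

  // Keep at most 10 tokens, mirroring the truncation test.
  const truncated = tokenize(tokenizer, text, { truncation: true, max_length: 10 }).toArray();

  // Pad to a fixed length; per the spec, padding is on the left by default
  // and on the right when padding_side is 'right'.
  const padded = tokenize(tokenizer, text, { padding: true, padding_side: "right", max_length: 32 }).toArray();

  console.log(ids.length, truncated.length, padded.length);
}

void tokenizeDemo();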

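The new encodeDecode helper performs a round trip: encode to token IDs, then decode with special tokens skipped. GPT-2 uses a lossless byte-level BPE, so the round trip reproduces plain text exactly, which is what the last test asserts. The sketch below makes the same assumptions as above; the import path is hypothetical, since the helper is exported from the spec file itself in this commit.

import { AutoTokenizer } from "@xenova/transformers";
import { encodeDecode } from "./text.spec.js"; // hypothetical import; the helper lives in the spec file

async function roundTripDemo(): Promise<void> {
  const tokenizer = await AutoTokenizer.from_pretrained("Xenova/gpt2");
  const original = "Hello, world! This is a test for encode-decode tokenization.";

  // Encode to token IDs, then decode while skipping special tokens.
  const decoded = encodeDecode(tokenizer, original);

  console.log(decoded === original); // expected: true for plain ASCII input
}

void roundTripDemo();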