Skip to content

Commit 4ba80a0

Browse files
authored
fix streaming packing test (#2454)
* fix streaming packing test * constrain amount of text generated
1 parent c496821 commit 4ba80a0

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/test_packed_pretraining.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,26 @@ def random_text(self):
2323
# seed with random.seed(0) for reproducibility
2424
random.seed(0)
2525

26-
# generate 20 rows of random text with "words" of between 2 and 10 characters and
26+
# generate a row of random text with "words" of between 2 and 10 characters and
2727
# between 400 and 1200 characters per line
28-
data = [
29-
"".join(random.choices(string.ascii_lowercase, k=random.randint(2, 10)))
30-
for _ in range(20)
31-
] + [
32-
" ".join(
33-
random.choices(string.ascii_lowercase, k=random.randint(400, 1200))
28+
def rand_txt():
29+
return " ".join(
30+
[
31+
"".join(
32+
random.choices(string.ascii_lowercase, k=random.randint(2, 10))
33+
)
34+
for _ in range(random.randint(50, 200))
35+
]
3436
)
35-
for _ in range(20)
36-
]
37+
38+
# Create a list of 500 random texts rather than just using it within the
39+
# generator so the test runs faster
40+
data = [rand_txt() for _ in range(500)]
3741

3842
# Create an IterableDataset
3943
def generator():
40-
for text in data:
41-
yield {"text": text}
44+
for row in data:
45+
yield {"text": row}
4246

4347
return IterableDataset.from_generator(generator)
4448

@@ -92,7 +96,7 @@ def test_packing_stream_dataset(self, tokenizer_huggyllama, random_text):
9296
)
9397
idx = 0
9498
for data in trainer_loader:
95-
if idx > 10:
99+
if idx > 3:
96100
break
97101
assert data["input_ids"].shape == torch.Size(
98102
[1, original_bsz * cfg.sequence_len]

0 commit comments

Comments
 (0)