Skip to content

Commit fba7b93

Browse files
thomasthomas
thomas
authored and
thomas
committed
updarte
1 parent dc500b5 commit fba7b93

File tree

1 file changed

+15
-0
lines changed

1 file changed

+15
-0
lines changed

tests/tests_data/streaming/test_dataloader.py

+15
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ def __init__(self, size, step):
1212
self.size = size
1313
self.step = step
1414
self.counter = 0
15+
self.shuffle = None
16+
17+
def set_shuffle(self, shuffle):
18+
self.shuffle = shuffle
1519

1620
def __len__(self):
1721
return self.size
@@ -92,3 +96,14 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
9296
batches.append(batch)
9397

9498
assert os.path.exists(os.path.join(tmpdir, "result.json"))
99+
100+
101+
def test_dataloader_shuffle():
102+
dataset = TestCombinedStreamingDataset(
103+
[TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
104+
)
105+
assert dataset._datasets[0].shuffle is None
106+
assert dataset._datasets[1].shuffle is None
107+
StreamingDataLoader(dataset, batch_size=2, num_workers=1, shuffle=True)
108+
assert dataset._datasets[0].shuffle
109+
assert dataset._datasets[1].shuffle

0 commit comments

Comments
 (0)