Skip to content

Commit b28b673

Browse files
authored
Make StreamingDataLoader shuffle to set shuffle to datasets. (#19481)
1 parent 5c9a6fa commit b28b673

File tree

4 files changed

+29
-0
lines changed

4 files changed

+29
-0
lines changed

src/lightning/data/streaming/combined.py

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def set_epoch(self, current_epoch: int) -> None:
6666
for dataset in self._datasets:
6767
dataset.set_epoch(current_epoch)
6868

69+
def set_shuffle(self, shuffle: bool) -> None:
    """Forward the shuffle flag to every wrapped dataset.

    Args:
        shuffle: Value passed unchanged to each dataset's ``set_shuffle``.

    """
    for dataset in self._datasets:
        dataset.set_shuffle(shuffle)
73+
6974
def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
7075
if any(not isinstance(d, StreamingDataset) for d in datasets):
7176
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

src/lightning/data/streaming/dataloader.py

+6
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def __init__(
541541
profile_batches: Union[bool, int] = False,
542542
profile_dir: Optional[str] = None,
543543
prefetch_factor: Optional[int] = None,
544+
shuffle: Optional[bool] = None,
544545
**kwargs: Any,
545546
) -> None: # pyright: ignore
546547
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -549,6 +550,11 @@ def __init__(
549550
f" Found {dataset}."
550551
)
551552

553+
if shuffle is not None:
554+
dataset.set_shuffle(shuffle)
555+
556+
shuffle = None
557+
552558
if profile_batches and not _VIZ_TRACKER_AVAILABLE:
553559
raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
554560

src/lightning/data/streaming/dataset.py

+3
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def __init__(
107107
self.serializers = serializers
108108
self._state_dict: Optional[Dict[str, Any]] = None
109109

110+
def set_shuffle(self, shuffle: bool) -> None:
    """Store the shuffle flag on this dataset.

    Args:
        shuffle: New value assigned to ``self.shuffle``.

    """
    self.shuffle = shuffle
112+
110113
def set_epoch(self, current_epoch: int) -> None:
111114
"""Set the current epoch to the dataset on epoch starts.
112115

tests/tests_data/streaming/test_dataloader.py

+15
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ def __init__(self, size, step):
1212
self.size = size
1313
self.step = step
1414
self.counter = 0
15+
self.shuffle = None
16+
17+
def set_shuffle(self, shuffle):
    """Store the given shuffle flag on the instance so tests can inspect it."""
    self.shuffle = shuffle
1519

1620
def __len__(self):
    """Return the size this dataset was configured with."""
    return self.size
@@ -92,3 +96,14 @@ def test_dataloader_profiling(profile, tmpdir, monkeypatch):
9296
batches.append(batch)
9397

9498
assert os.path.exists(os.path.join(tmpdir, "result.json"))
99+
100+
101+
def test_dataloader_shuffle():
    """Check that ``StreamingDataLoader(shuffle=True)`` forwards the flag to each wrapped dataset."""
    dataset = TestCombinedStreamingDataset(
        [TestStatefulDataset(10, 1), TestStatefulDataset(10, -1)], 42, weights=(0.5, 0.5)
    )
    # Before the dataloader is built, no shuffle value has been forwarded.
    assert dataset._datasets[0].shuffle is None
    assert dataset._datasets[1].shuffle is None
    StreamingDataLoader(dataset, batch_size=2, num_workers=1, shuffle=True)
    # Identity checks: the exact boolean must arrive, not merely a truthy value.
    assert dataset._datasets[0].shuffle is True
    assert dataset._datasets[1].shuffle is True

0 commit comments

Comments
 (0)