Skip to content

Commit dc500b5

Browse files
thomas authored and thomas committed
update
1 parent 265025b commit dc500b5

File tree

3 files changed: +14 -0 lines changed

src/lightning/data/streaming/combined.py

+5
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ def set_epoch(self, current_epoch: int) -> None:
6666
for dataset in self._datasets:
6767
dataset.set_epoch(current_epoch)
6868

69+
def set_shuffle(self, shuffle: bool) -> None:
70+
"""Set the current shuffle to the datasets."""
71+
for dataset in self._datasets:
72+
dataset.set_shuffle(shuffle)
73+
6974
def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
7075
if any(not isinstance(d, StreamingDataset) for d in datasets):
7176
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

src/lightning/data/streaming/dataloader.py

+6
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ def __init__(
541541
profile_batches: Union[bool, int] = False,
542542
profile_dir: Optional[str] = None,
543543
prefetch_factor: Optional[int] = None,
544+
shuffle: Optional[bool] = None,
544545
**kwargs: Any,
545546
) -> None: # pyright: ignore
546547
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -549,6 +550,11 @@ def __init__(
549550
f" Found {dataset}."
550551
)
551552

553+
if shuffle is not None:
554+
dataset.set_shuffle(shuffle)
555+
556+
shuffle = None
557+
552558
if profile_batches and not _VIZ_TRACKER_AVAILABLE:
553559
raise ModuleNotFoundError("To use profile_batches, viztracer is required. Run `pip install viztracer`")
554560

src/lightning/data/streaming/dataset.py

+3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def __init__(
9090
self.serializers = serializers
9191
self._state_dict: Optional[Dict[str, Any]] = None
9292

93+
def set_shuffle(self, shuffle: bool) -> None:
94+
self.shuffle = shuffle
95+
9396
def set_epoch(self, current_epoch: int) -> None:
9497
"""Set the current epoch to the dataset on epoch starts.
9598

0 commit comments

Comments
 (0)