Commit ed367ca

StreamingDataLoader: Resolve fault tolerance with the CombinedStreamingDataset and multiple workers (#19326)
1 parent e1a6dd9 commit ed367ca

File tree

10 files changed: +1002 −100 lines changed
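
This commit makes a `StreamingDataLoader` wrapping a `CombinedStreamingDataset` resumable when `num_workers > 0`: each worker batch carries a count of samples drawn per underlying dataset, and the loader folds those counts into its `state_dict()` so a later run can restart on the right worker and the right sample. A minimal usage sketch, not part of the commit — the `StreamingDataset` constructor arguments and the checkpoint path are illustrative:

import torch

from lightning.data.streaming.combined import CombinedStreamingDataset
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.dataset import StreamingDataset

# Mix two streaming datasets 80/20 (dataset locations are placeholders)
datasets = [StreamingDataset("s3://bucket/dataset_a"), StreamingDataset("s3://bucket/dataset_b")]
combined = CombinedStreamingDataset(datasets, seed=42, weights=[0.8, 0.2])
loader = StreamingDataLoader(combined, batch_size=8, num_workers=4)

for step, batch in enumerate(loader):
    if step == 10:
        # Capture mid-epoch progress: per-worker, per-dataset sample counts,
        # the current epoch, and the latest worker index.
        torch.save(loader.state_dict(), "loader.ckpt")
        break

# Later run: restore the loader state and continue from the next sample.
loader.load_state_dict(torch.load("loader.ckpt"))
for batch in loader:
    ...  # resumes on the next DataLoader worker, without repeating samples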

.github/workflows/ci-tests-data.yml (+1 −1)

@@ -87,7 +87,7 @@ jobs:
       # ls -lh $PYPI_CACHE_DIR

     - name: Install package & dependencies
-      timeout-minutes: 20
+      timeout-minutes: 30
       run: |
         pip install -e ".[data-dev]" -U --prefer-binary -f ${TORCH_URL}
         pip list

src/lightning/data/streaming/combined.py (+93 −12)

@@ -17,6 +17,10 @@
 from torch.utils.data import IterableDataset

 from lightning.data.streaming.dataset import StreamingDataset
+from lightning.data.utilities.env import _WorkerEnv
+
+__NUM_SAMPLES_YIELDED_KEY__ = "__NUM_SAMPLES_YIELDED__"
+__SAMPLES_KEY__ = "__SAMPLES__"


 class CombinedStreamingDataset(IterableDataset):
@@ -31,6 +35,8 @@ class CombinedStreamingDataset(IterableDataset):
     def __init__(
         self, datasets: List[StreamingDataset], seed: int = 42, weights: Optional[Sequence[float]] = None
     ) -> None:
+        self._check_datasets(datasets)
+
         self._seed = seed
         self._datasets = datasets
         self._weights = weights
@@ -43,53 +49,128 @@ def __init__(
             self._weights = [w / sum(weights) for w in weights]

         self._iterator: Optional[_CombinedDatasetIterator] = None
+        self._use_streaming_dataloader = False
+        self._num_samples_yielded: Optional[List[int]] = None
+        self._current_epoch = 0
+
+    def set_epoch(self, current_epoch: int) -> None:
+        """Set the current epoch to the datasets on epoch starts.
+
+        When using the StreamingDataLoader, this is done automatically
+
+        """
+        self._current_epoch = current_epoch
+        for dataset in self._datasets:
+            dataset.set_epoch(current_epoch)
+
+    def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
+        if any(not isinstance(d, StreamingDataset) for d in datasets):
+            raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
+
+    def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
+        # Used to prevent returning num_samples_yielded when using PyTorch DataLoader
+        self._use_streaming_dataloader = use_streaming_dataloader

     def __len__(self) -> int:
         assert self._weights
         return int(min([1 / w * len(d) for w, d in zip(self._weights, self._datasets) if w > 0]))

     def __iter__(self) -> Iterator[Any]:
         assert self._weights
-        self._iterator = _CombinedDatasetIterator(self._datasets, self._seed, self._weights)
+
+        worker_env = _WorkerEnv.detect()
+
+        num_samples_yielded = None
+
+        if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
+            num_samples_yielded = self._num_samples_yielded[worker_env.rank]
+
+        self._iterator = _CombinedDatasetIterator(
+            self._datasets,
+            self._seed,
+            self._weights,
+            self._use_streaming_dataloader,
+            num_samples_yielded,
+        )
         return self._iterator

-    def state_dict(self, num_workers: int, batch_size: int) -> Dict[str, Any]:
+    def state_dict(
+        self, num_workers: int, batch_size: int, num_samples_yielded: Optional[List[int]] = None
+    ) -> Dict[str, Any]:
         if self._iterator is None:
-            return {}
+            if num_samples_yielded is None:
+                return {}
+            return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size)
         return self._iterator.state_dict(num_workers, batch_size)

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        if len(state_dict) != len(self._datasets):
+        if not state_dict:
+            return
+
+        if len(state_dict["dataset"]) != len(self._datasets):
             raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.")

         for dataset_idx, dataset in enumerate(self._datasets):
-            if str(dataset_idx) not in state_dict:
+            if str(dataset_idx) not in state_dict["dataset"]:
                 raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")

-            dataset.load_state_dict(state_dict[str(dataset_idx)])
+            dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)])
+
+        # Used to iterate over the sampler to avoid sampling the same samples
+        if self._use_streaming_dataloader:
+            self._num_samples_yielded = state_dict["num_samples_yielded"]


 class _CombinedDatasetIterator(Iterator):
-    def __init__(self, datasets: List[StreamingDataset], seed: int, weights: Sequence[float]) -> None:
+    def __init__(
+        self,
+        datasets: List[StreamingDataset],
+        seed: int,
+        weights: Sequence[float],
+        use_streaming_dataloader: bool,
+        num_samples_yielded: Optional[Any] = None,
+    ) -> None:
         self._datasets = datasets
         self._dataset_iters = [iter(dataset) for dataset in datasets]
         self._dataset_indexes = list(range(len(datasets)))
         self._num_samples_yielded = [0 for _ in range(len(datasets))]
         self._weights = weights
         self._rng = random.Random(seed)

+        if num_samples_yielded is not None:
+            self._num_samples_yielded = num_samples_yielded
+            for _ in range(sum(num_samples_yielded)):
+                self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)
+
+        self._use_streaming_dataloader = use_streaming_dataloader
+
     def __next__(self) -> Any:
         # randomly select a dataset index
         (dataset_index,) = self._rng.choices(self._dataset_indexes, weights=self._weights, k=1)

         # keep track the sample was fetched
         self._num_samples_yielded[dataset_index] += 1

+        sample = next(self._dataset_iters[dataset_index])
+
         # return a new sample
-        return next(self._dataset_iters[dataset_index])
+        if self._use_streaming_dataloader:
+            return {
+                __SAMPLES_KEY__: sample,
+                __NUM_SAMPLES_YIELDED_KEY__: self._num_samples_yielded,
+            }
+        return sample

     def state_dict(self, num_workers: int = 0, batch_size: int = 1) -> Dict[str, Any]:
-        return {
-            str(dataset_idx): dataset.state_dict(self._num_samples_yielded[dataset_idx], num_workers, batch_size)
-            for dataset_idx, dataset in enumerate(self._datasets)
-        }
+        return _state_dict(self._datasets, self._num_samples_yielded, num_workers, batch_size)
+
+
+def _state_dict(
+    datasets: List[StreamingDataset], num_samples_yielded: List[int], num_workers: int = 0, batch_size: int = 1
+) -> Dict[str, Any]:
+    return {
+        str(dataset_idx): dataset.state_dict(
+            num_samples_yielded=num_samples_yielded[dataset_idx], num_workers=num_workers, batch_size=batch_size
        )
        for dataset_idx, dataset in enumerate(datasets)
    }
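
A note on the resumption strategy in `_CombinedDatasetIterator.__init__` above: instead of checkpointing the `random.Random` state itself, the iterator re-seeds a fresh generator and replays `sum(num_samples_yielded)` draws, so dataset selection continues exactly where the previous run stopped. A standalone sketch of that replay property (values are illustrative):

import random

seed, indexes, weights = 42, [0, 1], [0.5, 0.5]

# First run draws 5 dataset indexes, then stops (simulated interruption).
rng = random.Random(seed)
for _ in range(5):
    rng.choices(indexes, weights=weights, k=1)
next_expected = rng.choices(indexes, weights=weights, k=1)

# Resumed run: re-seed and burn through the 5 draws already consumed,
# exactly as _CombinedDatasetIterator does with num_samples_yielded.
resumed = random.Random(seed)
for _ in range(5):
    resumed.choices(indexes, weights=weights, k=1)

# The resumed generator picks the same next dataset index.
assert resumed.choices(indexes, weights=weights, k=1) == next_expected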

src/lightning/data/streaming/dataloader.py (+126 −9)

@@ -15,7 +15,9 @@
 import inspect
 import logging
 import os
+from copy import deepcopy
 from importlib import reload
+from itertools import cycle
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
@@ -32,7 +34,11 @@
 from torch.utils.data.sampler import BatchSampler, Sampler

 from lightning.data.streaming import Cache
-from lightning.data.streaming.combined import CombinedStreamingDataset
+from lightning.data.streaming.combined import (
+    __NUM_SAMPLES_YIELDED_KEY__,
+    __SAMPLES_KEY__,
+    CombinedStreamingDataset,
+)
 from lightning.data.streaming.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE
 from lightning.data.streaming.dataset import StreamingDataset
 from lightning.data.streaming.sampler import CacheBatchSampler
@@ -341,6 +347,35 @@ def _get_iterator(self) -> "_BaseDataLoaderIter":
         return _MultiProcessingDataLoaderIterPatch(self)


+class _StreamingMultiProcessingDataLoaderIter(_MultiProcessingDataLoaderIter):
+    def __init__(self, loader: DataLoader) -> None:
+        self._loader = loader
+        self._indexes = (
+            list(range(self._loader._latest_worker_idx, self._loader.num_workers))
+            if self._loader._latest_worker_idx > 0
+            else []
+        )
+        super().__init__(loader)
+
+    def _try_put_index(self) -> None:
+        # Used to restart on the right DataLoader worker
+        if self._loader.restore and self._indexes:
+            assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
+
+            try:
+                index = self._next_index()
+            except StopIteration:
+                return
+            worker_queue_idx = self._indexes.pop(0)
+
+            self._index_queues[worker_queue_idx].put((self._send_idx, index))
+            self._task_info[self._send_idx] = (worker_queue_idx,)
+            self._tasks_outstanding += 1
+            self._send_idx += 1
+        else:
+            super()._try_put_index()
+
+
 class StreamingDataLoader(DataLoader):
     """The `StreamingDataLoader` keeps track of the number of samples fetched in order to enable resumability of the
     dataset."""
@@ -355,27 +390,82 @@ def __init__(
         num_workers: int = 0,
         **kwargs: Any,
     ) -> None:  # pyright: ignore
+        if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
+            raise RuntimeError(
+                "The provided dataset should be either an instance of StreamingDataset or CombinedStreamingDataset."
+                f" Found {dataset}."
+            )
+
+        self.current_epoch = 0
         self.batch_size = batch_size
         self.num_workers = num_workers
-        self.num_samples_yielded = 0
+        self._num_samples_yielded_streaming = 0
+        self._num_samples_yielded_combined: Dict[int, List[Any]] = {}
+        self.rng_state: Optional[Any] = None
+        self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
+        self._worker_idx_iter: Optional[Any] = None
+        self._latest_worker_idx = 0
+        self.restore = False
         super().__init__(dataset, *args, batch_size=batch_size, num_workers=num_workers, **kwargs)  # type: ignore

     def __iter__(self) -> Any:
+        if not self.restore:
+            self._latest_worker_idx = 0
+            self._worker_idx = cycle(list(range(self.num_workers if self.num_workers > 0 else 1)))
+            self._worker_idx_iter = iter(self._worker_idx)
+            self.current_epoch += 1
+            self._num_samples_yielded_combined = {}
+            self._num_samples_yielded_streaming = 0
+
+        self.dataset.set_epoch(self.current_epoch)
+
         if isinstance(self.dataset, StreamingDataset):
             assert self.batch_size
-            self.num_samples_yielded = 0
             for batch in super().__iter__():
-                self.num_samples_yielded += self.batch_size
+                self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
+                self._num_samples_yielded_streaming += self.batch_size
                 yield batch
         else:
-            yield from super().__iter__()
+            self.dataset._set_use_streaming_dataloader(True)
+            assert self.batch_size
+            # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
+            for batch in super().__iter__():
+                self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
+                if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
+                    self._num_samples_yielded_combined[self._latest_worker_idx] = [
+                        sample[-1].item() if self.batch_size > 1 else sample.item()
+                        for sample in batch[__NUM_SAMPLES_YIELDED_KEY__]
+                    ]

+                    yield batch[__SAMPLES_KEY__]
+                else:
+                    yield batch
+
+        self.restore = False

     def state_dict(self) -> Dict[str, Any]:
         if isinstance(self.dataset, StreamingDataset):
             assert self.batch_size
-            num_samples = self.num_samples_yielded
-            return self.dataset.state_dict(num_samples, self.num_workers, self.batch_size)
-        return self.dataset.state_dict(self.num_workers, self.batch_size)
+            return {
+                "dataset": self.dataset.state_dict(
+                    self._num_samples_yielded_streaming, self.num_workers, self.batch_size
+                ),
+                "current_epoch": self.current_epoch,
+                "num_samples_yielded": self._num_samples_yielded_streaming,
+                "latest_worker_idx": self._latest_worker_idx,
+            }
+
+        num_samples_yieled = [0 for _ in range(len(list(self._num_samples_yielded_combined.values())[0]))]
+        for worker_idx in self._num_samples_yielded_combined:
+            for dataset_idx, samples_yieled in enumerate(self._num_samples_yielded_combined[worker_idx]):
+                num_samples_yieled[dataset_idx] += samples_yieled
+
+        return {
+            "dataset": self.dataset.state_dict(self.num_workers, self.batch_size, num_samples_yieled),
+            "current_epoch": self.current_epoch if self.restore else self.current_epoch - 1,
+            "latest_worker_idx": self._latest_worker_idx,
+            "num_samples_yielded": deepcopy(self._num_samples_yielded_combined),
+        }

     def load_state_dict(self, obj: Dict[str, Any]) -> None:
         """Load a dict containing training state (called from non-worker process).
@@ -386,7 +476,34 @@ def load_state_dict(self, obj: Dict[str, Any]) -> None:
             obj (Any): The state.

         """
-        if isinstance(self.dataset, (StreamingDataset, CombinedStreamingDataset)):
+        self.current_epoch = obj["current_epoch"]
+
+        if isinstance(self.dataset, StreamingDataset):
+            self._num_samples_yielded_streaming = obj["num_samples_yielded"]
+        else:
+            self._num_samples_yielded_combined = obj["num_samples_yielded"]
+
+        # Used to restart on the next DataLoader worker from the previous run.
+        self._latest_worker_idx = obj["latest_worker_idx"] + 1
+        self._worker_idx_iter = iter(self._worker_idx)
+        for _ in range(self._latest_worker_idx):
+            next(self._worker_idx_iter)
+
+        # Inform we are resuming and disable resetting the StreamingDataLoader state.
+        # This is toggle back to False when the `__iter__` method of the StreamingDataLoader completes.
+        self.restore = True
+
+        if isinstance(self.dataset, CombinedStreamingDataset):
+            self.dataset._set_use_streaming_dataloader(True)
             self.dataset.load_state_dict(obj)
+        elif isinstance(self.dataset, StreamingDataset):
+            self.dataset.load_state_dict(obj["dataset"])
         else:
             raise RuntimeError("The provided dataset should be a `StreamingDataset` or a `CombinedStreamingDataset`.")
+
+    def _get_iterator(self) -> "_BaseDataLoaderIter":
+        """Overriden to ensure the `Cache.done()` method is triggered on iteration done."""
+        if self.num_workers == 0:
+            return _SingleProcessDataLoaderIter(self)
+        self.check_worker_number_rationality()
+        return _StreamingMultiProcessingDataLoaderIter(self)
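
For context on the worker bookkeeping above (not part of the commit): a multiprocessing DataLoader hands out batches round-robin across workers, so `StreamingDataLoader` tracks the producer of the last yielded batch with an `itertools.cycle` over worker indexes. On restore, `load_state_dict` fast-forwards that cycle past the saved `latest_worker_idx`, and `_try_put_index` feeds the first indexes to the remaining workers so iteration resumes on the correct one. A sketch of the fast-forward logic (worker count and saved index are illustrative):

from itertools import cycle, islice

num_workers = 4
latest_worker_idx = 2  # worker that produced the last batch before the checkpoint

# As in load_state_dict: advance the cycle past every worker already consumed.
worker_idx_iter = iter(cycle(range(num_workers)))
for _ in range(latest_worker_idx + 1):
    next(worker_idx_iter)

# The next batch is expected from worker 3, then the cycle wraps around.
print(list(islice(worker_idx_iter, 6)))  # [3, 0, 1, 2, 3, 0]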
