From efa9858456ea14d8ffd735eb115ffbc8718df6ab Mon Sep 17 00:00:00 2001
From: Sahil
Date: Fri, 20 Dec 2024 15:28:20 -0800
Subject: [PATCH 1/9] Initial implementation of per_stream batching

---
 src/litdata/streaming/combined.py   | 25 +++++++++++++++++++++++--
 src/litdata/streaming/dataloader.py |  8 +++++++-
 2 files changed, 30 insertions(+), 3 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 22149589..9815954e 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -13,7 +13,7 @@
 import random
 from copy import deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence

 from torch.utils.data import IterableDataset

@@ -42,6 +42,7 @@ def __init__(
         self,
         datasets: List[StreamingDataset],
         seed: int = 42,
         weights: Optional[Sequence[float]] = None,
         iterate_over_all: bool = True,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         """Enable to stream data from multiple StreamingDataset with the sampling ratio of your choice.

@@ -51,6 +52,9 @@ def __init__(
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
+            batching_method: When batching_method is "stratified" (default), every sample in a batch is drawn randomly across all datasets.
+                When batching_method is "per_stream", every sample in a batch is drawn from the same dataset. After each batch, a dataset
+                is selected at random.

         """
         self._check_datasets(datasets)
@@ -67,6 +71,7 @@ def __init__(
         )

         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method

         if weights is None:
             # Weighted based on the dataset length
@@ -165,6 +170,7 @@ def __iter__(self) -> Iterator[Any]:
             self._use_streaming_dataloader,
             num_samples_yielded,
             self._iterate_over_all,
+            self._batching_method
         )
         return self._iterator

@@ -204,6 +210,7 @@ def __init__(
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
         iterate_over_all: bool = False,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         self._datasets = datasets
         self._dataset_iters = [iter(dataset) for dataset in datasets]
@@ -213,6 +220,8 @@ def __init__(
         self._weights = deepcopy(weights)
         self._rng = random.Random(seed)
         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method
+        self._cur_dataset_index = -1
         self._is_done = False

         if num_samples_yielded is not None:
@@ -234,6 +243,7 @@ def __next__(self) -> Any:
                     dataset_index = self._get_dataset_index()
                 elif len(indexes_left) == 1:
                     dataset_index = indexes_left[0]
+                self._cur_dataset_index = dataset_index
                 return self._get_sample(dataset_index)
             except StopIteration as e:
                 if len(indexes_left) == 1:
@@ -250,11 +260,22 @@ def __next__(self) -> Any:
         return self._get_sample(self._get_dataset_index())

     def _get_dataset_index(self) -> int:
+        if self._batching_method == "stratified":
+            # randomly select a dataset index
+            self._set_new_dataset_index()
+        elif self._batching_method == "per_stream":
+            # randomly select a dataset index, if no previous dataset index exists
+            if self._cur_dataset_index == -1:
+                self._set_new_dataset_index()
+        return self._cur_dataset_index
+
+    def _set_new_dataset_index(self):
         # randomly select a dataset index
         indexes = [index for index in self._dataset_indexes if index is not None]
         weights = [w for w in self._weights if w is not None]
         (dataset_index,) = self._rng.choices(indexes, weights=weights, k=1)
-        return dataset_index
+        self._cur_dataset_index = dataset_index
+

     def _get_sample(self, dataset_index: int) -> Any:
         # get the sample

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 1612f55d..892cec63 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -18,7 +18,7 @@
 from copy import deepcopy
 from importlib import reload
 from itertools import cycle
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -549,6 +549,7 @@ def __init__(
         self,
         dataset: Union[StreamingDataset, CombinedStreamingDataset],
         *args: Any,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
         batch_size: int = 1,
         num_workers: int = 0,
         profile_batches: Union[bool, int] = False,
@@ -626,10 +627,15 @@ def __iter__(self) -> Any:
                 self._num_samples_yielded_streaming += self.batch_size
                 yield batch
         else:
+            # Assume this is a CombinedStreamingDataset.
             self.dataset._set_use_streaming_dataloader(True)
             assert self.batch_size
             # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
             for batch in super().__iter__():
+                # Force selection of a new dataset on batch boundaries
+                # Note: samples may come from several datasets within a batch, depending
+                # on `CombinedStreamingDataset`'s `batching_method` value. 
+                self.dataset._set_new_dataset_index() 
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [

From da6d43d65e3993d4c8b3747f48f3174eae6ee376 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 22 Dec 2024 01:07:20 +0000
Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/streaming/combined.py   | 3 +--
 src/litdata/streaming/dataloader.py | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 9815954e..9d018127 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -170,7 +170,7 @@ def __iter__(self) -> Iterator[Any]:
             self._use_streaming_dataloader,
             num_samples_yielded,
             self._iterate_over_all,
-            self._batching_method
+            self._batching_method,
         )
         return self._iterator
@@ -276,7 +276,6 @@ def _set_new_dataset_index(self):
         (dataset_index,) = self._rng.choices(indexes, weights=weights, k=1)
         self._cur_dataset_index = dataset_index
-

     def _get_sample(self, dataset_index: int) -> Any:
         # get the sample

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 892cec63..0e22d55b 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -634,8 +634,8 @@ def __iter__(self) -> Any:
             for batch in super().__iter__():
                 # Force selection of a new dataset on batch boundaries
                 # Note: samples may come from several datasets within a batch, depending
-                # on `CombinedStreamingDataset`'s `batching_method` value. 
-                self.dataset._set_new_dataset_index() 
+                # on `CombinedStreamingDataset`'s `batching_method` value.
+                self.dataset._set_new_dataset_index()
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [

From 829ecded034ce8a426a282f68655a68a54eb1b58 Mon Sep 17 00:00:00 2001
From: Sahil
Date: Mon, 30 Dec 2024 10:56:16 -0800
Subject: [PATCH 3/9] Address pre-commit issues

---
 src/litdata/streaming/combined.py | 49 ++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 21 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 9815954e..3ea70a0c 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -35,7 +35,6 @@ class CombinedStreamingDataset(IterableDataset):
     of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is
     exhausted.
     """
-
     def __init__(
         self,
         datasets: List[StreamingDataset],
@@ -52,9 +51,9 @@ def __init__(
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
-            batching_method: When batching_method is "stratified" (default), every sample in a batch is drawn randomly across all datasets.
-                When batching_method is "per_stream", every sample in a batch is drawn from the same dataset. After each batch, a dataset
-                is selected at random.
+            batching_method: When batching_method is "stratified" (default), every sample in a batch is drawn randomly
+                across all datasets. When batching_method is "per_stream", every sample in a batch is drawn from the
+                same dataset. After each batch, a dataset is selected at random.

         """
         self._check_datasets(datasets)
@@ -147,9 +146,12 @@ def reset_state_dict(self) -> None:

     def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
         if any(not isinstance(d, StreamingDataset) for d in datasets):
-            raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
+            raise RuntimeError(
+                "The provided datasets should be instances of the StreamingDataset."
+            )

-    def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
+    def _set_use_streaming_dataloader(self,
+                                      use_streaming_dataloader: bool) -> None:
         # Used to prevent returning num_samples_yielded when using PyTorch DataLoader
         self._use_streaming_dataloader = use_streaming_dataloader

@@ -161,26 +163,25 @@ def __iter__(self) -> Iterator[Any]:
         num_samples_yielded = None

         if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
-            num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)
+            num_samples_yielded = self._num_samples_yielded.get(
+                worker_env.rank, 0)

         self._iterator = _CombinedDatasetIterator(
-            self._datasets,
-            self._seed,
-            self._weights,
-            self._use_streaming_dataloader,
-            num_samples_yielded,
-            self._iterate_over_all,
-            self._batching_method
-        )
+            self._datasets, self._seed, self._weights,
+            self._use_streaming_dataloader, num_samples_yielded,
+            self._iterate_over_all, self._batching_method)
         return self._iterator

     def state_dict(
-        self, num_workers: int, batch_size: int, num_samples_yielded: Optional[List[int]] = None
-    ) -> Dict[str, Any]:
+            self,
+            num_workers: int,
+            batch_size: int,
+            num_samples_yielded: Optional[List[int]] = None) -> Dict[str, Any]:
         if self._iterator is None:
             if num_samples_yielded is None:
                 return {}
-            return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size)
+            return _state_dict(self._datasets, num_samples_yielded,
+                               num_workers, batch_size)
         return self._iterator.state_dict(num_workers, batch_size)

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -188,11 +189,15 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             return

         if len(state_dict["dataset"]) != len(self._datasets):
-            raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.")
+            raise RuntimeError(
+                f"The provided state doesn't match the current number of datasets: {self._datasets}."
+            )

         for dataset_idx, dataset in enumerate(self._datasets):
             if str(dataset_idx) not in state_dict["dataset"]:
-                raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")
+                raise RuntimeError(
+                    f"The provided state doesn't contain the index {dataset_idx}."
+                )

         dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)])

@@ -265,8 +270,10 @@ def _get_dataset_index(self) -> int:
             self._set_new_dataset_index()
         elif self._batching_method == "per_stream":
             # randomly select a dataset index, if no previous dataset index exists
-            if self._cur_dataset_index == -1:
+            if self._cur_dataset_index == -1
                 self._set_new_dataset_index()
+        else:
+            raise ValueError(f"Invalid batching method: {self._batching_method}")
         return self._cur_dataset_index

From 31777b4f77aec467d048196f23fcbad5d78c8f05 Mon Sep 17 00:00:00 2001
From: Sahil
Date: Mon, 30 Dec 2024 10:58:50 -0800
Subject: [PATCH 4/9] Fix bug in combined.py

---
 src/litdata/streaming/combined.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 43054f00..34cc6673 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -275,7 +275,7 @@ def _get_dataset_index(self) -> int:
             self._set_new_dataset_index()
         elif self._batching_method == "per_stream":
             # randomly select a dataset index, if no previous dataset index exists
-            if self._cur_dataset_index == -1
+            if self._cur_dataset_index == -1:
                 self._set_new_dataset_index()
         else:
             raise ValueError(f"Invalid batching method: {self._batching_method}")

From 8b32f8efe389479e88d7b9d0b2509f0c052f668c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 30 Dec 2024 18:58:59 +0000
Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/streaming/combined.py | 28 +++++++++------------------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 34cc6673..01f35d08 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -35,6 +35,7 @@ class CombinedStreamingDataset(IterableDataset):
     of the given seed. The combined dataset will raise a StopIteration as soon as any of the datasets is
     exhausted.
     """
+
     def __init__(
         self,
         datasets: List[StreamingDataset],
@@ -146,12 +147,9 @@ def reset_state_dict(self) -> None:

     def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
         if any(not isinstance(d, StreamingDataset) for d in datasets):
-            raise RuntimeError(
-                "The provided datasets should be instances of the StreamingDataset."
-            )
+            raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

-    def _set_use_streaming_dataloader(self,
-                                      use_streaming_dataloader: bool) -> None:
+    def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
         # Used to prevent returning num_samples_yielded when using PyTorch DataLoader
         self._use_streaming_dataloader = use_streaming_dataloader

@@ -163,8 +161,7 @@ def __iter__(self) -> Iterator[Any]:
         num_samples_yielded = None

         if self._num_samples_yielded is not None and worker_env.rank in self._num_samples_yielded:
-            num_samples_yielded = self._num_samples_yielded.get(
-                worker_env.rank, 0)
+            num_samples_yielded = self._num_samples_yielded.get(worker_env.rank, 0)

         self._iterator = _CombinedDatasetIterator(
             self._datasets,
@@ -178,15 +175,12 @@ def __iter__(self) -> Iterator[Any]:
         return self._iterator

     def state_dict(
-            self,
-            num_workers: int,
-            batch_size: int,
-            num_samples_yielded: Optional[List[int]] = None) -> Dict[str, Any]:
+        self, num_workers: int, batch_size: int, num_samples_yielded: Optional[List[int]] = None
+    ) -> Dict[str, Any]:
         if self._iterator is None:
             if num_samples_yielded is None:
                 return {}
-            return _state_dict(self._datasets, num_samples_yielded,
-                               num_workers, batch_size)
+            return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size)
         return self._iterator.state_dict(num_workers, batch_size)

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
@@ -194,15 +188,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             return

         if len(state_dict["dataset"]) != len(self._datasets):
-            raise RuntimeError(
-                f"The provided state doesn't match the current number of datasets: {self._datasets}."
-            )
+            raise RuntimeError(f"The provided state doesn't match the current number of datasets: {self._datasets}.")

         for dataset_idx, dataset in enumerate(self._datasets):
             if str(dataset_idx) not in state_dict["dataset"]:
-                raise RuntimeError(
-                    f"The provided state doesn't contain the index {dataset_idx}."
-                )
+                raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")

         dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)])

From 39142347e8e5d8f32c4b8cb8c9a4e82cbcd74ec9 Mon Sep 17 00:00:00 2001
From: Sahil
Date: Mon, 30 Dec 2024 14:14:13 -0800
Subject: [PATCH 6/9] Expose _set_new_dataset_index() to dataloader

---
 src/litdata/streaming/combined.py   | 25 +++++++++++++++++--------
 src/litdata/streaming/dataloader.py | 21 ++++++++++++++-------
 2 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 01f35d08..94027c08 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -42,7 +42,6 @@ def __init__(
         seed: int = 42,
         weights: Optional[Sequence[float]] = None,
         iterate_over_all: bool = True,
-        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         """Enable to stream data from multiple StreamingDataset with the sampling ratio of your choice.

@@ -52,10 +51,6 @@ def __init__(
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
-            batching_method: When batching_method is "stratified" (default), every sample in a batch is drawn randomly
-                across all datasets. When batching_method is "per_stream", every sample in a batch is drawn from the
-                same dataset. After each batch, a dataset is selected at random.
-
         """
         self._check_datasets(datasets)
@@ -71,7 +66,6 @@ def __init__(
         )

         self._iterate_over_all = iterate_over_all
-        self._batching_method = batching_method

         if weights is None:
             # Weighted based on the dataset length
@@ -83,6 +77,7 @@ def __init__(
         self._current_epoch = 0
         self.num_workers = 1
         self.batch_size = 1
+        self.batching_method = "stratified"

     def get_len(self, num_workers: int, batch_size: int) -> Optional[int]:
         self.num_workers = num_workers
@@ -125,6 +120,15 @@ def set_batch_size(self, batch_size: int) -> None:
         for dataset in self._datasets:
             dataset.set_batch_size(batch_size)

+
+    def set_batching_method(self, batching_method: Literal["stratified", "per_stream"]) -> None:
+        """Set the current batching method to the datasets.
+        When batching_method is "stratified" (default), batches consist of samples from all datasets. 
+        When batching_method is "per_stream", batches consist of samples from one dataset, 
+        which is selected at random. 
+        """
+        self.batching_method = batching_method
+
     def set_num_workers(self, num_workers: int) -> None:
         """Set the current number of workers to the datasets."""
         for dataset in self._datasets:
             dataset.set_num_workers(num_workers)
@@ -169,8 +173,8 @@ def __iter__(self) -> Iterator[Any]:
             self._weights,
             self._use_streaming_dataloader,
             num_samples_yielded,
+            self.batching_method,
             self._iterate_over_all,
-            self._batching_method,
         )
         return self._iterator
@@ -200,6 +204,11 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         if self._use_streaming_dataloader:
             self._num_samples_yielded = state_dict["num_samples_yielded"]

+    def _set_new_dataset_index(self) -> None:
+        """Select a new dataset index randomly based on weights."""
+        if self._iterator is not None:
+            self._iterator._set_new_dataset_index()
+

 class _CombinedDatasetIterator(Iterator):
@@ -211,8 +220,8 @@ def __init__(
         weights: Sequence[Optional[float]],
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
+        batching_method: Literal["stratified", "per_stream"],
         iterate_over_all: bool = False,
-        batching_method: Literal["stratified", "per_stream"] = "stratified",
     ) -> None:
         self._datasets = datasets
         self._dataset_iters = [iter(dataset) for dataset in datasets]

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 0e22d55b..40ca770d 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -513,6 +513,10 @@ class StreamingDataLoader(DataLoader):
         collate_fn (Callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s). Used when using batched loading from a
             map-style dataset.
+        batching_method (str, optional): When batching_method is "stratified" (default),
+            batches consist of samples from all datasets. When batching_method is "per_stream",
+            batches consist of samples from one dataset, which is selected at random. Note that this
+            parameter is only applicable to CombinedStreamingDataset.
         pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
             into device/CUDA pinned memory before returning them.
             If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
             see the example below.
@@ -550,7 +554,6 @@ def __init__(
         self,
         dataset: Union[StreamingDataset, CombinedStreamingDataset],
         *args: Any,
-        batching_method: Literal["stratified", "per_stream"] = "stratified",
         batch_size: int = 1,
         num_workers: int = 0,
         profile_batches: Union[bool, int] = False,
@@ -559,6 +562,7 @@ def __init__(
         shuffle: Optional[bool] = None,
         drop_last: Optional[bool] = None,
         collate_fn: Optional[Callable] = None,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -575,6 +579,9 @@ def __init__(
             dataset.set_batch_size(batch_size)
             dataset.set_num_workers(num_workers)

+        if isinstance(dataset, CombinedStreamingDataset):
+            dataset.set_batching_method(batching_method)
+
         shuffle = None

         if profile_batches and not _VIZ_TRACKER_AVAILABLE:
@@ -630,19 +637,19 @@ def __iter__(self) -> Any:
             # Assume this is a CombinedStreamingDataset.
             self.dataset._set_use_streaming_dataloader(True)
             assert self.batch_size
-            # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
+
             for batch in super().__iter__():
-                # Force selection of a new dataset on batch boundaries
-                # Note: samples may come from several datasets within a batch, depending
-                # on `CombinedStreamingDataset`'s `batching_method` value.
-                self.dataset._set_new_dataset_index()
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
+
+                # Force selection of a new dataset for the next batch
+                if isinstance(self.dataset, CombinedStreamingDataset) and self.dataset.batching_method == "per_stream":
+                    self.dataset._set_new_dataset_index()
+
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [
                         sample[-1].item() if self.batch_size > 1 else sample.item()
                         for sample in batch[__NUM_SAMPLES_YIELDED_KEY__]
                     ]
-
                     yield batch[__SAMPLES_KEY__]
                 else:
                     yield batch

From bedd6fcbc38f4181cefc6f872c84de2e16482198 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 30 Dec 2024 22:14:25 +0000
Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/streaming/combined.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 94027c08..f5ccfef5 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -125,12 +125,11 @@ def set_batch_size(self, batch_size: int) -> None:
         for dataset in self._datasets:
             dataset.set_batch_size(batch_size)

-
     def set_batching_method(self, batching_method: Literal["stratified", "per_stream"]) -> None:
         """Set the current batching method to the datasets.
-        When batching_method is "stratified" (default), batches consist of samples from all datasets. 
-        When batching_method is "per_stream", batches consist of samples from one dataset, 
-        which is selected at random. 
+        When batching_method is "stratified" (default), batches consist of samples from all datasets.
+        When batching_method is "per_stream", batches consist of samples from one dataset,
+        which is selected at random.
""" self.batching_method = batching_method From b0a106cad459e1c7af3b5c792a2f1e17798b089d Mon Sep 17 00:00:00 2001 From: Sahil Date: Mon, 30 Dec 2024 17:16:28 -0800 Subject: [PATCH 8/9] WIP - communicate to worker loop --- src/litdata/streaming/combined.py | 6 +---- src/litdata/streaming/dataloader.py | 41 ++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py index f5ccfef5..c530068b 100644 --- a/src/litdata/streaming/combined.py +++ b/src/litdata/streaming/combined.py @@ -175,6 +175,7 @@ def __iter__(self) -> Iterator[Any]: self.batching_method, self._iterate_over_all, ) + print(f"Creating new iterator: {id(self._iterator)}") return self._iterator def state_dict( @@ -203,11 +204,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self._use_streaming_dataloader: self._num_samples_yielded = state_dict["num_samples_yielded"] - def _set_new_dataset_index(self) -> None: - """Select a new dataset index randomly based on weights.""" - if self._iterator is not None: - self._iterator._set_new_dataset_index() - class _CombinedDatasetIterator(Iterator): def __init__( diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 40ca770d..9f2a7fef 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -193,6 +193,24 @@ def __call__( create_fetcher = _DatasetKind.create_fetcher fetcher = None + # Create a wrapper around the original index_queue to intercept commands + # This allows us to intercept the "SET_NEW_DATASET_INDEX" command and call + # the _set_new_dataset_index method on the iterator, if we're using a + # CombinedStreamingDataset with per_stream batching. 
+        original_get = index_queue.get
+        def wrapped_get(*args, **kwargs):
+            item = original_get(*args, **kwargs)
+            if isinstance(item, tuple) and item[0] == "SET_NEW_DATASET_INDEX":
+                print(f"Worker {worker_id} received SET_NEW_DATASET_INDEX command")
+                if hasattr(dataset, '_iterator') and dataset._iterator is not None:
+                    print(f"Worker {worker_id} is picking a new dataset index ...")
+                    dataset._iterator._set_new_dataset_index()
+                # Get the next item since we handled this command
+                return original_get(*args, **kwargs)
+            return item
+        index_queue.get = wrapped_get 
+
+
         def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             nonlocal fetcher
@@ -452,11 +470,25 @@ def __init__(self, loader: DataLoader) -> None:
         if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
             from torch.utils.data._utils import worker
-
             worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
         super().__init__(loader)

+    def _next_data(self):
+        # Get data as normal
+        data = super()._next_data()
+        
+        # If we're using per_stream batching, send command to switch datasets on batch boundaries
+        if (isinstance(self._loader.dataset, CombinedStreamingDataset)
+            and self._loader.dataset.batching_method == "per_stream"
+            and self._rcvd_idx % self._loader.batch_size == 0):
+            print(f"Batch {self._rcvd_idx // self._loader.batch_size}: Sending SET_NEW_DATASET_INDEX command to worker")
+            self._index_queues[self._loader._latest_worker_idx].put(
+                ("SET_NEW_DATASET_INDEX", None)
+            )
+        
+        return data
+
     def _try_put_index(self) -> None:
         # Used to restart on the right DataLoader worker
         if self._loader.restore and self._indexes:
@@ -591,6 +623,7 @@ def __init__(
         self._worker_idx_iter: Optional[Any] = None
         self._latest_worker_idx = 0
         self.restore = False
+
         super().__init__(
             dataset,
             *args,
@@ -640,13 +673,9 @@ def __iter__(self) -> Any:
             self.dataset._set_use_streaming_dataloader(True)
             assert self.batch_size

             for batch in super().__iter__():
+                print("Fetched a batch ...")
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
-
-                # Force selection of a new dataset for the next batch
-                if isinstance(self.dataset, CombinedStreamingDataset) and self.dataset.batching_method == "per_stream":
-                    self.dataset._set_new_dataset_index()
-
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [
                         sample[-1].item() if self.batch_size > 1 else sample.item()

From 32abd7744c0a3fd6246226faed4b2e2ec7fb3b48 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 31 Dec 2024 01:16:38 +0000
Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/litdata/streaming/dataloader.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 9f2a7fef..f11aa601 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -198,18 +198,19 @@ def __call__(
         # Create a wrapper around the original index_queue to intercept commands
         # This allows us to intercept the "SET_NEW_DATASET_INDEX" command and call
         # the _set_new_dataset_index method on the iterator, if we're using a
         # CombinedStreamingDataset with per_stream batching.
         original_get = index_queue.get
+
         def wrapped_get(*args, **kwargs):
             item = original_get(*args, **kwargs)
             if isinstance(item, tuple) and item[0] == "SET_NEW_DATASET_INDEX":
                 print(f"Worker {worker_id} received SET_NEW_DATASET_INDEX command")
-                if hasattr(dataset, '_iterator') and dataset._iterator is not None:
+                if hasattr(dataset, "_iterator") and dataset._iterator is not None:
                     print(f"Worker {worker_id} is picking a new dataset index ...")
                     dataset._iterator._set_new_dataset_index()
                 # Get the next item since we handled this command
                 return original_get(*args, **kwargs)
             return item
-        index_queue.get = wrapped_get 
+        index_queue.get = wrapped_get

         def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             nonlocal fetcher
@@ -470,6 +471,7 @@ def __init__(self, loader: DataLoader) -> None:
         if self._loader._profile_batches and distributed_env.global_rank == 0 and _VIZ_TRACKER_AVAILABLE:
             from torch.utils.data._utils import worker
+
             worker._worker_loop = _ProfileWorkerLoop(self._loader._profile_batches, self._loader._profile_dir)
         super().__init__(loader)

@@ -477,16 +479,16 @@ def __init__(self, loader: DataLoader) -> None:
     def _next_data(self):
         # Get data as normal
         data = super()._next_data()
-        
+
        # If we're using per_stream batching, send command to switch datasets on batch boundaries
-        if (isinstance(self._loader.dataset, CombinedStreamingDataset)
+        if (
+            isinstance(self._loader.dataset, CombinedStreamingDataset)
             and self._loader.dataset.batching_method == "per_stream"
-            and self._rcvd_idx % self._loader.batch_size == 0):
+            and self._rcvd_idx % self._loader.batch_size == 0
+        ):
             print(f"Batch {self._rcvd_idx // self._loader.batch_size}: Sending SET_NEW_DATASET_INDEX command to worker")
-            self._index_queues[self._loader._latest_worker_idx].put(
-                ("SET_NEW_DATASET_INDEX", None)
-            )
-        
+            self._index_queues[self._loader._latest_worker_idx].put(("SET_NEW_DATASET_INDEX", None))
+
         return data

     def _try_put_index(self) -> None:
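
---

For reviewers who want to exercise the branch, here is a minimal usage sketch of the API this series introduces. It is illustrative only and not part of the patches: the dataset paths are hypothetical placeholders, and it assumes two pre-optimized litdata datasets plus this branch installed. The `batching_method` argument and `set_batching_method()` are the additions under review; everything else is the existing litdata API.

```python
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Two independently optimized streams; "data/stream_a" and "data/stream_b"
# are placeholder paths for this sketch.
dataset_a = StreamingDataset("data/stream_a")
dataset_b = StreamingDataset("data/stream_b")

# Draw from stream A 80% of the time and stream B 20% of the time.
combined = CombinedStreamingDataset(
    datasets=[dataset_a, dataset_b],
    weights=[0.8, 0.2],
    iterate_over_all=False,
)

# With batching_method="per_stream", every sample in a batch is drawn from a
# single stream, and a new stream is picked (by weight) at each batch
# boundary. The default, "stratified", mixes streams within a batch.
loader = StreamingDataLoader(combined, batch_size=32, batching_method="per_stream")

for batch in loader:
    ...  # under per_stream batching, each batch holds samples from one stream
```

Note that the `num_workers > 0` case is exactly what patches 8 and 9 are still wiring up: they move the dataset switch into the worker loop via the `SET_NEW_DATASET_INDEX` queue command, so with multiple workers the per-batch switching shown above should be treated as work in progress.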