diff --git a/src/litdata/streaming/combined.py b/src/litdata/streaming/combined.py
index 010601ee..9eab2cab 100644
--- a/src/litdata/streaming/combined.py
+++ b/src/litdata/streaming/combined.py
@@ -13,7 +13,7 @@

 import random
 from copy import deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence

 from torch.utils.data import IterableDataset

@@ -51,7 +51,6 @@ def __init__(
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
-
         """
         self._check_datasets(datasets)

@@ -84,6 +83,7 @@
         self._current_epoch = 0
         self.num_workers = 1
        self.batch_size = 1
+        self.batching_method = "stratified"

     def get_len(self, num_workers: int, batch_size: int) -> Optional[int]:
         self.num_workers = num_workers
@@ -125,6 +125,14 @@ def set_batch_size(self, batch_size: int) -> None:
         for dataset in self._datasets:
             dataset.set_batch_size(batch_size)

+    def set_batching_method(self, batching_method: Literal["stratified", "per_stream"]) -> None:
+        """Set the batching method used when iterating over the combined dataset.
+        When batching_method is "stratified" (default), batches consist of samples from all datasets.
+        When batching_method is "per_stream", batches consist of samples from one dataset,
+        which is selected at random.
+        """
+        self.batching_method = batching_method
+
     def set_num_workers(self, num_workers: int) -> None:
         """Set the current number of workers to the datasets."""
         for dataset in self._datasets:
             dataset.set_num_workers(num_workers)
@@ -164,8 +172,10 @@ def __iter__(self) -> Iterator[Any]:
             self._weights,
             self._use_streaming_dataloader,
             num_samples_yielded,
+            self.batching_method,
             self._iterate_over_all,
         )
+        print(f"Creating new iterator: {id(self._iterator)}")
         return self._iterator

     def state_dict(
@@ -203,6 +213,7 @@ def __init__(
         weights: Sequence[Optional[float]],
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
+        batching_method: Literal["stratified", "per_stream"],
         iterate_over_all: bool = False,
     ) -> None:
         self._datasets = datasets
@@ -213,6 +224,8 @@ def __init__(
         self._weights = deepcopy(weights)
         self._rng = random.Random(seed)  # noqa: S311
         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method
+        self._cur_dataset_index = -1
         self._is_done = False

         if num_samples_yielded is not None:
@@ -234,6 +247,7 @@ def __next__(self) -> Any:
                        dataset_index = self._get_dataset_index()
                    elif len(indexes_left) == 1:
                        dataset_index = indexes_left[0]
+                    self._cur_dataset_index = dataset_index
                    return self._get_sample(dataset_index)
                except StopIteration as e:
                    if len(indexes_left) == 1:
@@ -250,11 +264,23 @@ def __next__(self) -> Any:
         return self._get_sample(self._get_dataset_index())

     def _get_dataset_index(self) -> int:
+        if self._batching_method == "stratified":
+            # randomly select a dataset index
+            self._set_new_dataset_index()
+        elif self._batching_method == "per_stream":
+            # randomly select a dataset index, if no previous dataset index exists
+            if self._cur_dataset_index == -1:
+                self._set_new_dataset_index()
+        else:
+            raise ValueError(f"Invalid batching method: {self._batching_method}")
+        return self._cur_dataset_index
+
+    def _set_new_dataset_index(self) -> None:
         # randomly select a dataset index
         indexes = [index for index in self._dataset_indexes if index is not None]
         weights = [w for w in self._weights if w is not None]
         (dataset_index,) = self._rng.choices(indexes, weights=weights, k=1)
-        return dataset_index
+        self._cur_dataset_index = dataset_index

     def _get_sample(self, dataset_index: int) -> Any:
         # get the sample
diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py
index 67bdfe4e..d7de1406 100644
--- a/src/litdata/streaming/dataloader.py
+++ b/src/litdata/streaming/dataloader.py
@@ -18,7 +18,7 @@
 from copy import deepcopy
 from importlib import reload
 from itertools import cycle
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 import torch
 from torch.utils.data import Dataset, IterableDataset
@@ -191,6 +191,25 @@ def __call__(
         create_fetcher = _DatasetKind.create_fetcher
         fetcher = None

+        # Create a wrapper around the original index_queue to intercept commands.
+        # This allows us to intercept the "SET_NEW_DATASET_INDEX" command and call
+        # the _set_new_dataset_index method on the iterator, if we're using a
+        # CombinedStreamingDataset with per_stream batching.
+        original_get = index_queue.get
+
+        def wrapped_get(*args: Any, **kwargs: Any) -> Any:
+            item = original_get(*args, **kwargs)
+            if isinstance(item, tuple) and item[0] == "SET_NEW_DATASET_INDEX":
+                print(f"Worker {worker_id} received SET_NEW_DATASET_INDEX command")
+                if hasattr(dataset, "_iterator") and dataset._iterator is not None:
+                    print(f"Worker {worker_id} is picking a new dataset index ...")
+                    dataset._iterator._set_new_dataset_index()
+                # Get the next item since we handled this command
+                return original_get(*args, **kwargs)
+            return item
+
+        index_queue.get = wrapped_get
+
         def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher":
             nonlocal fetcher
             fetcher = create_fetcher(*args, **kwargs)
@@ -455,6 +474,21 @@ def __init__(self, loader: DataLoader) -> None:

         super().__init__(loader)

+    def _next_data(self) -> Any:
+        # Get data as normal
+        data = super()._next_data()
+
+        # If we're using per_stream batching, send a command to switch datasets on batch boundaries
+        if (
+            isinstance(self._loader.dataset, CombinedStreamingDataset)
+            and self._loader.dataset.batching_method == "per_stream"
+            and self._rcvd_idx % self._loader.batch_size == 0
+        ):
+            print(f"Batch {self._rcvd_idx // self._loader.batch_size}: Sending SET_NEW_DATASET_INDEX command to worker")
+            self._index_queues[self._loader._latest_worker_idx].put(("SET_NEW_DATASET_INDEX", None))
+
+        return data
+
     def _try_put_index(self) -> None:
         # Used to restart on the right DataLoader worker
         if self._loader.restore and self._indexes:
@@ -511,6 +545,10 @@ class StreamingDataLoader(DataLoader):
         collate_fn (Callable, optional): merges a list of samples to form a
             mini-batch of Tensor(s). Used when using batched loading from a
             map-style dataset.
+        batching_method (str, optional): When batching_method is "stratified" (default),
+            batches consist of samples from all datasets. When batching_method is "per_stream",
+            batches consist of samples from one dataset, which is selected at random. Note that this
+            parameter is only applicable to CombinedStreamingDataset.
         pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
             into device/CUDA pinned memory before returning them.
             If your data elements are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
@@ -555,6 +593,7 @@ def __init__(
         shuffle: Optional[bool] = None,
         drop_last: Optional[bool] = None,
         collate_fn: Optional[Callable] = None,
+        batching_method: Literal["stratified", "per_stream"] = "stratified",
         **kwargs: Any,
     ) -> None:  # pyright: ignore
         if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -572,6 +611,9 @@ def __init__(
         dataset.set_batch_size(batch_size)
         dataset.set_num_workers(num_workers)

+        if isinstance(dataset, CombinedStreamingDataset):
+            dataset.set_batching_method(batching_method)
+
         shuffle = None

         if profile_batches and not _VIZ_TRACKER_AVAILABLE:
@@ -595,6 +637,7 @@ def __init__(
         self._worker_idx_iter: Optional[Any] = None
         self._latest_worker_idx = 0
         self.restore = False
+
         super().__init__(
             dataset,
             *args,
@@ -624,17 +667,18 @@ def __iter__(self) -> Any:
                 self._num_samples_yielded_streaming += self.batch_size
                 yield batch
         else:
+            # Assume this is a CombinedStreamingDataset.
             self.dataset._set_use_streaming_dataloader(True)
             assert self.batch_size
-            # TODO: Inject a custom collate function to avoid collating the __NUM_SAMPLES_YIELDED__ key
+
             for batch in super().__iter__():
+                print("Fetched a batch ...")
                 self._latest_worker_idx = next(self._worker_idx_iter)  # type: ignore
                 if isinstance(batch, dict) and __NUM_SAMPLES_YIELDED_KEY__ in batch:
                     self._num_samples_yielded_combined[self._latest_worker_idx] = [
                         sample[-1].item() if self.batch_size > 1 else sample.item()
                         for sample in batch[__NUM_SAMPLES_YIELDED_KEY__]
                     ]
-
                     yield batch[__SAMPLES_KEY__]
                 else:
                     yield batch
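
For reference, a minimal usage sketch of the new `batching_method` option. It is not part of the patch; the bucket paths are placeholders and it assumes two already-optimized litdata datasets:

```python
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Two independently optimized datasets (hypothetical locations).
dataset_a = StreamingDataset("s3://my-bucket/dataset_a")
dataset_b = StreamingDataset("s3://my-bucket/dataset_b")

# Sample from both streams with a 70/30 ratio.
combined = CombinedStreamingDataset(datasets=[dataset_a, dataset_b], weights=[0.7, 0.3], iterate_over_all=False)

# "stratified" (default): every batch mixes samples from all datasets.
# "per_stream": every batch is drawn from a single dataset, re-picked at random per batch.
loader = StreamingDataLoader(combined, batch_size=8, num_workers=2, batching_method="per_stream")

for batch in loader:
    ...  # each batch now contains samples from exactly one of the two datasets
```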
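
And a self-contained toy sketch (plain `random`, not the library classes) of the selection logic the patch adds to `_CombinedDatasetIterator`: under `per_stream` the iterator keeps reusing one cached dataset index, and the dataloader's `SET_NEW_DATASET_INDEX` command re-draws it at each batch boundary, so the sampling weights hold across batches rather than within them:

```python
import random

rng = random.Random(42)
dataset_indexes = [0, 1]
weights = [0.7, 0.3]

# Draw an initial dataset index, weighted like _set_new_dataset_index does.
cur_dataset_index = rng.choices(dataset_indexes, weights=weights, k=1)[0]

batch_size, num_batches = 4, 3
for _ in range(num_batches):
    # Within a batch, "per_stream" keeps returning the cached index for every sample.
    print([cur_dataset_index for _ in range(batch_size)])  # e.g. [0, 0, 0, 0]
    # At the batch boundary, the SET_NEW_DATASET_INDEX command triggers a weighted re-draw.
    cur_dataset_index = rng.choices(dataset_indexes, weights=weights, k=1)[0]
```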