diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index 327b97a5d..01bb17ed5 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -15,6 +15,8 @@
 import torch
 import torch.utils.data
+
+from parameterized import parameterized
 from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
 
 from torchdata.stateful_dataloader import Stateful, StatefulDataLoader
 
@@ -1314,7 +1316,7 @@ def test(self):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
         )
         it = iter(dl)
         # Fetch at least one batch from each worker
@@ -1325,7 +1327,10 @@ def test(self):
         if num_workers > 0:
             for i in range(num_workers):
                 # Ensure worker state is stored only once if the dataset is also the iterator
-                self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
+                self.assertEqual(
+                    state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                    None,
+                )
                 self.assertTrue(
                     state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
                         "dataset_iter_state"
@@ -1441,6 +1446,206 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
         self._run_test(17, 19)
 
 
+class TestMultiEpochSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 2 epochs and count the number of items yielded
+        num_items_yielded = 0
+        dataloader1_items = []
+        for _ in range(2):
+            for batch in dataloader1:
+                dataloader1_items.append(batch)
+                num_items_yielded += 1
+        # Save the state dict
+        state_dict = dataloader1.state_dict()
+        # Create a new StatefulDataLoader instance and load the state dict
+        new_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        new_dataloader1.load_state_dict(state_dict)
+        # Run through the new dataloader for another 2 epochs and count the number of items yielded
+        additional_num_items_yielded = 0
+        for i in range(2):
+            epoch_num_items_yielded = 0
+            for batch in new_dataloader1:
+                dataloader1_items.append(batch)
+                epoch_num_items_yielded += 1
+            additional_num_items_yielded += epoch_num_items_yielded
+        # Check that the total number of items yielded is correct
+        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+
+        # Now run a second dataloader for 4 epochs and check that the order is the same
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader2_items = []
+        for _ in range(4):
+            for batch in dataloader2:
+                dataloader2_items.append(batch)
+
+        self.assertEqual(dataloader1_items, dataloader2_items)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for batch in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle):
+        dataloader = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded
+        num_items_yielded = 0
+
+        for batch in dataloader:
+            num_items_yielded += 1
+            sd_in = dataloader.state_dict()
+        sd_out = dataloader.state_dict()
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of the epoch
+        dataloader_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be 0 since the state dict was saved before the end of the epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of the epoch
+        dataloader_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dataloader_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be data_size since the state dict was saved after the end of the epoch
+        num_items_yielded = self._count_items_yielded(dataloader_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
+    def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
+        self._run(datasize, num_workers, batch_size, shuffle)
+
+
+class TestNotStatefulSamplerSDL_shard0(TestCase):
+    def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        sampler = sampler_cls(dataset)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler=sampler,
+            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+        )
+
+    def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        torch.manual_seed(0)  # Fixing seed for deterministic results
+        dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        # Interrupt the dataloader after `interrupt` batches and save the state dict
+        results_dataloader1 = []
+        for i, batch in enumerate(dataloader1):
+            results_dataloader1.append(batch)
+            if i == interrupt:
+                break
+        state_dict = dataloader1.state_dict()
+
+        torch.manual_seed(
+            0
+        )  # Re-seed so that the generator is in the same state it was in before fast-forwarding
+        resumed_dataloader1 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        resumed_dataloader1.load_state_dict(state_dict)
+
+        for batch in resumed_dataloader1:
+            results_dataloader1.append(batch)
+
+        # Now start a completely new dataloader and get all the batches
+        torch.manual_seed(0)
+        dataloader2 = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            sampler_cls=sampler_cls,
+        )
+        results_dataloader2 = []
+        for batch in dataloader2:
+            results_dataloader2.append(batch)
+        self.assertEqual(results_dataloader1, results_dataloader2)
+
+    @parameterized.expand(
+        itertools.product(
+            [100],
+            [0, 2],
+            [1],
+            [10, 50, 80],
+            [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
+        )
+    )
+    def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
+        self._run(data_size, num_workers, batch_size, interrupt, sampler_cls)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]
diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py
index 4effec1d1..d0ae1a3eb 100644
--- a/torchdata/stateful_dataloader/sampler.py
+++ b/torchdata/stateful_dataloader/sampler.py
@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Any, Dict, Iterator, Optional, Sized
+from typing import Any, Dict, Iterator, List, Optional, Sized
 
 import torch.utils.data.sampler
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _InfiniteConstantSampler
+from torch.utils.data.sampler import Sampler
 
 from .stateful import Stateful
 
@@ -18,58 +19,122 @@ class _StatefulRandomSamplerIterator(Iterator[int], Stateful):
     _GENERATOR = "generator"
     _YIELDED = "yielded"
 
-    def __init__(self, sampler, parent_iterator: Iterator[int]):
+    def __init__(self, sampler):
         self.sampler = sampler
-        self.parent_iterator = parent_iterator
+        self.generator_state = self.sampler.generator.get_state()
         self.yielded = 0
         self.next_yielded = None
-        self.generator_state = sampler.generator.get_state()
+        self.n = len(sampler.data_source)
+        self.replacement = sampler.replacement
+        self.num_samples = sampler.num_samples
+        self.chunk_size = 32
+        self.perm: List[int] = self._get_perm()
+        self.perm_index = 0
+        self.chunk_index = 0
 
-    def __next__(self) -> int:
-        if self.next_yielded is not None:
-            for _ in range(self.next_yielded):
-                next(self.parent_iterator)
-
-            self.yielded = self.next_yielded
-            self.next_yielded = None
-
-        val = next(self.parent_iterator)
+    def __iter__(self):
+        return self
+
+    def _get_perm(self) -> List[int]:
+        if self.replacement:
+            return torch.randint(
+                high=self.n,
+                size=(self.chunk_size,),
+                dtype=torch.int64,
+                generator=self.sampler.generator,
+            ).tolist()
+        else:
+            return torch.randperm(self.n, generator=self.sampler.generator).tolist()
+
+    def __next__(self):
+        if self.yielded == self.num_samples:
+            raise StopIteration()
+        if self.perm_index == len(self.perm):
+            self.perm = self._get_perm()
+            self.perm_index = 0
+        val = self.perm[self.perm_index]
+        self.perm_index += 1
         self.yielded += 1
         return val
 
-    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        self.generator_state = state_dict[self._GENERATOR]
-        self.sampler.generator.set_state(state_dict[self._GENERATOR])
+    def state_dict(self) -> dict:
+        return {
+            self._YIELDED: self.yielded,
+            self._GENERATOR: self.generator_state,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
         self.next_yielded = state_dict[self._YIELDED]
+        self.generator_state = state_dict[self._GENERATOR]
+        self.sampler.generator.set_state(self.generator_state)
 
-    def state_dict(self) -> Dict[str, Any]:
-        return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
+        if self.next_yielded is not None:
+            self.perm = self._get_perm()  # Recompute the permutation from the generator state that was just loaded
+            for _ in range(self.next_yielded):
+                next(self)
+            self.yielded = self.next_yielded
+            self.next_yielded = None
 
 
-class RandomSampler(torch.utils.data.sampler.RandomSampler):
+class RandomSampler(Sampler[int]):
     def __init__(
-        self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
-    ):
+        self,
+        data_source: Sized,
+        replacement: bool = False,
+        num_samples: Optional[int] = None,
+        generator=None,
+    ) -> None:
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
         if generator is None:
             # Ensure that underlying sampler has something repeatable
             generator = torch.Generator()
             generator.manual_seed(1)
-        super().__init__(data_source, replacement, num_samples, generator)
+        self.generator = generator
+        if not isinstance(self.replacement, bool):
+            raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
 
-    def __iter__(self):
-        return _StatefulRandomSamplerIterator(self, super().__iter__())
+    @property
+    def num_samples(self) -> int:
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
 
+    def __iter__(self) -> Iterator[int]:
+        return _StatefulRandomSamplerIterator(self)
 
-class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful):
+    def __len__(self) -> int:
+        return self.num_samples
+
+
+class _BatchSamplerIterator(Iterator[list[int]], Stateful):
     _SAMPLES_YIELDED = "samples_yielded"
     _SAMPLER_STATE = "sampler_state"
     _SAMPLER_ITER_STATE = "sampler_iter_state"
 
-    def __init__(self, sampler, batch_size, drop_last):
-        super().__init__(sampler, batch_size, drop_last)
+    def __init__(self, sampler, batch_size: int, drop_last: bool):
+        self.sampler = sampler
+        self.sampler_iter = iter(self.sampler)
+        self.batch_size = batch_size
+        self.drop_last = drop_last
         self.samples_yielded = 0
-        self.next_yielded = None
-        self.sampler_iter = iter(sampler)
+
+    def __next__(self) -> list[int]:
+        batch = []
+        try:
+            for _ in range(self.batch_size):
+                batch.append(next(self.sampler_iter))
+                self.samples_yielded += 1
+            return batch
+        except StopIteration:
+            if self.drop_last or len(batch) == 0:
+                raise StopIteration
+            else:
+                return batch
 
     def state_dict(self) -> Dict[str, Any]:
         sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded}
@@ -80,7 +145,7 @@ def state_dict(self) -> Dict[str, Any]:
         return sd
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        self.next_yielded = state_dict[self._SAMPLES_YIELDED]
+        self.samples_yielded = state_dict[self._SAMPLES_YIELDED]
         if self._SAMPLER_STATE in state_dict:
             assert isinstance(self.sampler, Stateful)
             self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE])
@@ -89,44 +154,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             assert isinstance(self.sampler_iter, Stateful)
             self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])
 
+        if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
+            self.sampler, _InfiniteConstantSampler
+        ):
+            # Fast-forward by samples_yielded if the underlying sampler is not stateful
+            for _ in range(self.samples_yielded):
+                next(self.sampler_iter)
+
+    def update_state_dict(self) -> None:
+        if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
+            self.sampler_iter.update_state_dict()
+
+
+class BatchSampler(torch.utils.data.sampler.BatchSampler):
+    def __init__(self, sampler, batch_size, drop_last):
+        super().__init__(sampler, batch_size, drop_last)
+
     def __iter__(self):
-        if self.next_yielded is not None:
-            self.samples_yielded = self.next_yielded
-            if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
-                self.sampler, _InfiniteConstantSampler
-            ):
-                # We skip x samples if underlying sampler is not stateful
-                for _ in range(self.next_yielded):
-                    next(self.sampler_iter)
-            self.next_yielded = None
-        elif self.samples_yielded > 0:
-            # don't re-create sampler_iter unless necessary, we may already have one from init
-            self.sampler_iter = iter(self.sampler)
-            self.samples_yielded = 0
-
-        if self.drop_last:
-            while True:
-                try:
-                    batch = []
-                    for _ in range(self.batch_size):
-                        batch.append(next(self.sampler_iter))
-                        self.samples_yielded += 1
-                    yield batch
-                except StopIteration:
-                    break
-        else:
-            batch = [0] * self.batch_size
-            idx_in_batch = 0
-            for idx in self.sampler_iter:
-                self.samples_yielded += 1
-                batch[idx_in_batch] = idx
-                idx_in_batch += 1
-                if idx_in_batch == self.batch_size:
-                    yield batch
-                    idx_in_batch = 0
-                    batch = [0] * self.batch_size
-            if idx_in_batch > 0:
-                yield batch[:idx_in_batch]
+        return _BatchSamplerIterator(
+            sampler=self.sampler,
+            batch_size=self.batch_size,
+            drop_last=self.drop_last,
+        )
 
 
 class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 078b378ee..1ffeec298 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -479,7 +479,11 @@ def __init__(self, loader, next_iter_state=None):
             self.load_state_dict(next_iter_state)
         else:
             self._dataset_fetcher = _DatasetKind.create_fetcher(
-                self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last
+                self._dataset_kind,
+                self._dataset,
+                self._auto_collation,
+                self._collate_fn,
+                self._drop_last,
             )
 
     def _next_data(self):
@@ -528,7 +532,10 @@ def load_state_dict(self, state_dict):
         if state_dict[_SAMPLER_ITER_STATE] is not None:
             self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
         else:
-            if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler):
+            if not isinstance(
+                self._index_sampler,
+                torch.utils.data.dataloader._InfiniteConstantSampler,
+            ):
                 # Fallback to fastforward
                 self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
         self._num_yielded = state_dict[self._NUM_YIELDED]
@@ -542,7 +549,11 @@ def load_state_dict(self, state_dict):
         if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful):
             self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE])
         self._dataset_fetcher = _DatasetKind.create_fetcher(
-            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last
+            self._dataset_kind,
+            self._dataset,
+            self._auto_collation,
+            self._collate_fn,
+            self._drop_last,
         )
         if self._dataset_kind == _DatasetKind.Iterable:
             # If either dataset or it's iter is stateful, we don't fast-forward
@@ -907,7 +918,10 @@ def __init__(self, loader, next_iter_state):
         # Additional worker init function will take care of sharding in MP and Distributed
         if isinstance(self._dataset, (IterDataPipe, MapDataPipe)):
             self._worker_init_fn = functools.partial(
-                _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank
+                _sharding_worker_init_fn,
+                self._worker_init_fn,
+                self._world_size,
+                self._rank,
             )
 
         # No certainty which module multiprocessing_context is
@@ -1462,7 +1476,10 @@ def _restore_main_state(self, state_dict):
         if state_dict[_SAMPLER_ITER_STATE] is not None:
             self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE])
         else:
-            if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler):
+            if not isinstance(
+                self._index_sampler,
+                torch.utils.data.dataloader._InfiniteConstantSampler,
+            ):
                 # Fallback to fastforward
                 self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None)
         self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED]
@@ -1540,7 +1557,10 @@ def _take_snapshot(self):
                 # in_order is False and no main snapshot is available as we're ahead of rcvd_idx
                 # we can't take a snapshot with the current implementation
                 return
-        assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1)
+        assert main_snapshot_idx == self._rcvd_idx - 1, (
+            main_snapshot_idx,
+            self._rcvd_idx - 1,
+        )
         self._update_snapshot(
             self._num_yielded + 1,
             self._last_yielded_worker_id,
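
A minimal usage sketch (not part of the diff) of the mid-epoch checkpoint/resume flow the new tests exercise. It assumes only the public API already shown above (StatefulDataLoader, state_dict, load_state_dict); the toy list dataset and variable names are illustrative.

from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(100))  # any map-style dataset (supports __getitem__/__len__)
loader = StatefulDataLoader(dataset, batch_size=4, num_workers=0)

it = iter(loader)
for _ in range(5):
    next(it)                    # consume 5 of the 25 batches in the epoch
state = loader.state_dict()     # snapshot taken mid-epoch

# A fresh loader resumes from the snapshot instead of restarting the epoch.
resumed = StatefulDataLoader(dataset, batch_size=4, num_workers=0)
resumed.load_state_dict(state)
remaining = list(resumed)
assert len(remaining) == 20     # 25 batches per epoch minus the 5 already seen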