From 0cd82341c9b87c7cc490bfe4d0b9db124d1a88d2 Mon Sep 17 00:00:00 2001 From: Ramanish Singh <48493179+ramanishsingh@users.noreply.github.com> Date: Tue, 18 Feb 2025 10:21:47 -0800 Subject: [PATCH 1/2] Fix end of epoch StatefulDataLoader restart (#1439) * add test for end of epoch state dict check * run precommit update stateful_dataloader run precommit local changes update test to test the order of batches update test update tests revert changes in SDL revert changes in SDL update tests run precommit * update sampler * run precommit * remove unnecessary comment * add test for statedict before and after endofepoch * run precommit * check if _sampler_iter is exhausted * run precommit * remove commented lines * remove default values * only exhaust sampler_iter if present in sd * update _StatefulRandomSamplerIterator update state dict if the iterator has finished add comment about why were updating state dict run precommit * update randomsampleriter state_dict fully * run precommit * fork torch.utils.data RandomSampler reverse changes to sdl.py generator to iterator run precommit update generator usage * update class name * run precommit * add a method to generate permutations * update return type * update next logic * add comment * update tests to include non stateful samplers * add comments --- test/stateful_dataloader/test_state_dict.py | 209 +++++++++++++++++- torchdata/stateful_dataloader/sampler.py | 183 +++++++++------ .../stateful_dataloader.py | 32 ++- 3 files changed, 349 insertions(+), 75 deletions(-) diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 327b97a5d..01bb17ed5 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -15,6 +15,8 @@ import torch import torch.utils.data + +from parameterized import parameterized from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase from torchdata.stateful_dataloader import Stateful, StatefulDataLoader @@ -1314,7 +1316,7 @@ def test(self): dataset=dataset, num_workers=num_workers, collate_fn=identity, - multiprocessing_context="forkserver" if IS_MACOS and num_workers else None, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), ) it = iter(dl) # Fetch at least one batch from each worker @@ -1325,7 +1327,10 @@ def test(self): if num_workers > 0: for i in range(num_workers): # Ensure worker state is stored only once if the dataset is also the iterator - self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None) + self.assertEqual( + state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], + None, + ) self.assertTrue( state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][ "dataset_iter_state" @@ -1441,6 +1446,206 @@ def test_fast_state_dict_request_skip_steps(self) -> None: self._run_test(17, 19) +class TestMultiEpochSDL_shard0(TestCase): + def get_map_dl(self, data_size, num_workers, batch_size, shuffle): + dataset = DummyMapDataset(data_size, shuffle=False) + return StatefulDataLoader( + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + ) + + def _run(self, data_size, num_workers, batch_size, shuffle): + dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + # Run through the dataloader for 2 epochs and 
count the number of items yielded + num_items_yielded = 0 + dataloader1_items = [] + for _ in range(2): + for batch in dataloader1: + dataloader1_items.append(batch) + num_items_yielded += 1 + # Save the state dict + state_dict = dataloader1.state_dict() + # Create a new StatefulDataLoader instance and load the state dict + new_dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + new_dataloader1.load_state_dict(state_dict) + # Run through the new dataloader for another 2 epochs and count the number of items yielded + additional_num_items_yielded = 0 + for i in range(2): + epoch_num_items_yielded = 0 + for batch in new_dataloader1: + dataloader1_items.append(batch) + epoch_num_items_yielded += 1 + additional_num_items_yielded += epoch_num_items_yielded + # Check that the total number of items yielded is correct + self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) + + # now run a second dataloder for 4 epochs and check if the order is same. + dataloader2 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dataloader2_items = [] + for _ in range(4): + for batch in dataloader2: + dataloader2_items.append(batch) + + self.assertEqual(dataloader1_items, dataloader2_items) + + @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True])) + def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle): + self._run(datasize, num_workers, batch_size, shuffle) + + +class TestEndOfEpochBehavior_shard0(TestCase): + def get_map_dl(self, data_size, num_workers, batch_size, shuffle): + dataset = DummyMapDataset(data_size, shuffle=False) + return StatefulDataLoader( + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + ) + + def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int: + num_items_yielded = 0 + for batch in data_loader: + num_items_yielded += 1 + return num_items_yielded + + def _run(self, data_size, num_workers, batch_size, shuffle): + dataloader = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + # Run through the dataloader for 1 epoch and count the number of items yielded + num_items_yielded = 0 + + for batch in dataloader: + num_items_yielded += 1 + sd_in = dataloader.state_dict() + sd_out = dataloader.state_dict() + + self.assertEqual(num_items_yielded, data_size) + + # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch + dataloader_sd_in = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dataloader_sd_in.load_state_dict(sd_in) + + # Run through the new dataloader for 1 epoch and count the number of items yielded + # num_items_yielded should be 0 since the state dict was saved before the end of epoch + num_items_yielded = self._count_items_yielded(dataloader_sd_in) + self.assertEqual(num_items_yielded, 0) + + # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch + dataloader_sd_out = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + shuffle=shuffle, + ) + dataloader_sd_out.load_state_dict(sd_out) + + # Run through the new dataloader for 1 epoch and count the number of items yielded + # 
num_items_yielded should be data_size since the state dict was saved after the end of epoch + num_items_yielded = self._count_items_yielded(dataloader_sd_out) + self.assertEqual(num_items_yielded, data_size) + + @parameterized.expand(itertools.product([100], [0, 2], [1], [False, True])) + def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle): + self._run(datasize, num_workers, batch_size, shuffle) + + +class TestNotStatefulSamplerSDL_shard0(TestCase): + def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls): + dataset = DummyMapDataset(data_size, shuffle=False) + sampler = sampler_cls(dataset) + return StatefulDataLoader( + dataset=dataset, + num_workers=num_workers, + batch_size=batch_size, + sampler=sampler, + multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None), + ) + + def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls): + torch.manual_seed(0) # Fixing seed for deterministic results + dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + # interrupt the dataloader after interrupt batches and save the state dict + results_dataloader1 = [] + for i, batch in enumerate(dataloader1): + results_dataloader1.append(batch) + if i == interrupt: + break + state_dict = dataloader1.state_dict() + + torch.manual_seed( + 0 + ) # We need to fix seed again so that before fast forwarding we are at the same state of gen as before + resumed_dataloader1 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + resumed_dataloader1.load_state_dict(state_dict) + + for batch in resumed_dataloader1: + results_dataloader1.append(batch) + + # now start a completely new dataloader and get all the batches + torch.manual_seed(0) + dataloader2 = self.get_map_dl( + data_size=data_size, + num_workers=num_workers, + batch_size=batch_size, + sampler_cls=sampler_cls, + ) + results_dataloader2 = [] + for batch in dataloader2: + results_dataloader2.append(batch) + self.assertEqual(results_dataloader1, results_dataloader2) + + @parameterized.expand( + itertools.product( + [100], + [0, 2], + [1], + [10, 50, 80], + [torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler], + ) + ) + def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls): + self._run(100, 0, 1, interrupt, sampler_cls) + + class TestMultiEpochState_shard0(TestCase): def get_iterable_dl(self, pw, num_workers): data_size = [25, 50, 100, 75] diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 4effec1d1..d0ae1a3eb 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. 
import itertools -from typing import Any, Dict, Iterator, Optional, Sized +from typing import Any, Dict, Iterator, List, Optional, Sized import torch.utils.data.sampler from torch.utils.data import Dataset from torch.utils.data.dataloader import _InfiniteConstantSampler +from torch.utils.data.sampler import Sampler from .stateful import Stateful @@ -18,58 +19,122 @@ class _StatefulRandomSamplerIterator(Iterator[int], Stateful): _GENERATOR = "generator" _YIELDED = "yielded" - def __init__(self, sampler, parent_iterator: Iterator[int]): + def __init__(self, sampler): self.sampler = sampler - self.parent_iterator = parent_iterator + self.generator_state = self.sampler.generator.get_state() self.yielded = 0 self.next_yielded = None - self.generator_state = sampler.generator.get_state() + self.n = len(sampler.data_source) + self.replacement = sampler.replacement + self.num_samples = sampler.num_samples + self.chunk_size = 32 + self.perm: List[int] = self._get_perm() + self.perm_index = 0 + self.chunk_index = 0 - def __next__(self) -> int: - if self.next_yielded is not None: - for _ in range(self.next_yielded): - next(self.parent_iterator) - - self.yielded = self.next_yielded - self.next_yielded = None - - val = next(self.parent_iterator) + def __iter__(self): + return self + + def _get_perm(self) -> List[int]: + if self.replacement: + return torch.randint( + high=self.n, + size=(self.chunk_size,), + dtype=torch.int64, + generator=self.sampler.generator, + ).tolist() + else: + return torch.randperm(self.n, generator=self.sampler.generator).tolist() + + def __next__(self): + if self.yielded == self.num_samples: + raise StopIteration() + if self.perm_index == len(self.perm): + self.perm = self._get_perm() + self.perm_index = 0 + val = self.perm[self.perm_index] + self.perm_index += 1 self.yielded += 1 return val - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self.generator_state = state_dict[self._GENERATOR] - self.sampler.generator.set_state(state_dict[self._GENERATOR]) + def state_dict(self) -> dict: + return { + self._YIELDED: self.yielded, + self._GENERATOR: self.generator_state, + } + + def load_state_dict(self, state_dict: dict) -> None: self.next_yielded = state_dict[self._YIELDED] + self.generator_state = state_dict[self._GENERATOR] + self.sampler.generator.set_state(self.generator_state) - def state_dict(self) -> Dict[str, Any]: - return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} + if self.next_yielded is not None: + self.perm = self._get_perm() # We want permutations from the latest generator state that's loaded + for _ in range(self.next_yielded): + next(self) + self.yielded = self.next_yielded + self.next_yielded = None -class RandomSampler(torch.utils.data.sampler.RandomSampler): +class RandomSampler(Sampler[int]): def __init__( - self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None - ): + self, + data_source: Sized, + replacement: bool = False, + num_samples: Optional[int] = None, + generator=None, + ) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples if generator is None: # Ensure that underlying sampler has something repeatable generator = torch.Generator() generator.manual_seed(1) - super().__init__(data_source, replacement, num_samples, generator) + self.generator = generator + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + if not 
isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") - def __iter__(self): - return _StatefulRandomSamplerIterator(self, super().__iter__()) + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + def __iter__(self) -> Iterator[int]: + return _StatefulRandomSamplerIterator(self) -class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful): + def __len__(self) -> int: + return self.num_samples + + +class _BatchSamplerIterator(Iterator[list[int]], Stateful): _SAMPLES_YIELDED = "samples_yielded" _SAMPLER_STATE = "sampler_state" _SAMPLER_ITER_STATE = "sampler_iter_state" - def __init__(self, sampler, batch_size, drop_last): - super().__init__(sampler, batch_size, drop_last) + def __init__(self, sampler, batch_size: int, drop_last: bool): + self.sampler = sampler + self.sampler_iter = iter(self.sampler) + self.batch_size = batch_size + self.drop_last = drop_last self.samples_yielded = 0 - self.next_yielded = None - self.sampler_iter = iter(sampler) + + def __next__(self) -> list[int]: + batch = [] + try: + for _ in range(self.batch_size): + batch.append(next(self.sampler_iter)) + self.samples_yielded += 1 + return batch + except StopIteration: + if self.drop_last or len(batch) == 0: + raise StopIteration + else: + return batch def state_dict(self) -> Dict[str, Any]: sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded} @@ -80,7 +145,7 @@ def state_dict(self) -> Dict[str, Any]: return sd def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - self.next_yielded = state_dict[self._SAMPLES_YIELDED] + self.samples_yielded = state_dict[self._SAMPLES_YIELDED] if self._SAMPLER_STATE in state_dict: assert isinstance(self.sampler, Stateful) self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE]) @@ -89,44 +154,28 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: assert isinstance(self.sampler_iter, Stateful) self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE]) + if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( + self.sampler, _InfiniteConstantSampler + ): + # We skip x samples if underlying sampler is not stateful + for _ in range(self.samples_yielded): + next(self.sampler_iter) + + def update_state_dict(self) -> None: + if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"): + self.sampler_iter.update_state_dict() + + +class BatchSampler(torch.utils.data.sampler.BatchSampler): + def __init__(self, sampler, batch_size, drop_last): + super().__init__(sampler, batch_size, drop_last) + def __iter__(self): - if self.next_yielded is not None: - self.samples_yielded = self.next_yielded - if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance( - self.sampler, _InfiniteConstantSampler - ): - # We skip x samples if underlying sampler is not stateful - for _ in range(self.next_yielded): - next(self.sampler_iter) - self.next_yielded = None - elif self.samples_yielded > 0: - # don't re-create sampler_iter unless necessary, we may already have one from init - self.sampler_iter = iter(self.sampler) - self.samples_yielded = 0 - - if self.drop_last: - while True: - try: - batch = [] - for _ in range(self.batch_size): - batch.append(next(self.sampler_iter)) - 
self.samples_yielded += 1 - yield batch - except StopIteration: - break - else: - batch = [0] * self.batch_size - idx_in_batch = 0 - for idx in self.sampler_iter: - self.samples_yielded += 1 - batch[idx_in_batch] = idx - idx_in_batch += 1 - if idx_in_batch == self.batch_size: - yield batch - idx_in_batch = 0 - batch = [0] * self.batch_size - if idx_in_batch > 0: - yield batch[:idx_in_batch] + return _BatchSamplerIterator( + sampler=self.sampler, + batch_size=self.batch_size, + drop_last=self.drop_last, + ) class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler): diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py index 078b378ee..1ffeec298 100644 --- a/torchdata/stateful_dataloader/stateful_dataloader.py +++ b/torchdata/stateful_dataloader/stateful_dataloader.py @@ -479,7 +479,11 @@ def __init__(self, loader, next_iter_state=None): self.load_state_dict(next_iter_state) else: self._dataset_fetcher = _DatasetKind.create_fetcher( - self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, ) def _next_data(self): @@ -528,7 +532,10 @@ def load_state_dict(self, state_dict): if state_dict[_SAMPLER_ITER_STATE] is not None: self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) else: - if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler): + if not isinstance( + self._index_sampler, + torch.utils.data.dataloader._InfiniteConstantSampler, + ): # Fallback to fastforward self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) self._num_yielded = state_dict[self._NUM_YIELDED] @@ -542,7 +549,11 @@ def load_state_dict(self, state_dict): if state_dict[_DATASET_STATE] is not None and isinstance(self._dataset, Stateful): self._dataset = try_to_deserialize(self._dataset, state_dict[_DATASET_STATE]) self._dataset_fetcher = _DatasetKind.create_fetcher( - self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last + self._dataset_kind, + self._dataset, + self._auto_collation, + self._collate_fn, + self._drop_last, ) if self._dataset_kind == _DatasetKind.Iterable: # If either dataset or it's iter is stateful, we don't fast-forward @@ -907,7 +918,10 @@ def __init__(self, loader, next_iter_state): # Additional worker init function will take care of sharding in MP and Distributed if isinstance(self._dataset, (IterDataPipe, MapDataPipe)): self._worker_init_fn = functools.partial( - _sharding_worker_init_fn, self._worker_init_fn, self._world_size, self._rank + _sharding_worker_init_fn, + self._worker_init_fn, + self._world_size, + self._rank, ) # No certainty which module multiprocessing_context is @@ -1462,7 +1476,10 @@ def _restore_main_state(self, state_dict): if state_dict[_SAMPLER_ITER_STATE] is not None: self._sampler_iter = try_to_deserialize(self._sampler_iter, state_dict[_SAMPLER_ITER_STATE]) else: - if not isinstance(self._index_sampler, torch.utils.data.dataloader._InfiniteConstantSampler): + if not isinstance( + self._index_sampler, + torch.utils.data.dataloader._InfiniteConstantSampler, + ): # Fallback to fastforward self._sampler_iter = itertools.islice(self._index_sampler, self._sampler_iter_yielded, None) self._IterableDataset_len_called = state_dict[_ITERABLEDATASET_LEN_CALLED] @@ -1540,7 +1557,10 @@ def _take_snapshot(self): # 
in_order is False and no main snapshot is available as we're ahead of rcvd_idx # we can't take a snapshot with the current implementation return - assert main_snapshot_idx == self._rcvd_idx - 1, (main_snapshot_idx, self._rcvd_idx - 1) + assert main_snapshot_idx == self._rcvd_idx - 1, ( + main_snapshot_idx, + self._rcvd_idx - 1, + ) self._update_snapshot( self._num_yielded + 1, self._last_yielded_worker_id, From 61ec72334bb793668d5b01b7b56d11e098a3529c Mon Sep 17 00:00:00 2001 From: Ramanish Singh <48493179+ramanishsingh@users.noreply.github.com> Date: Wed, 19 Feb 2025 10:13:46 -0800 Subject: [PATCH 2/2] Using system generated seed in RandomSampler (#1441) * add new sampler tests * update seed generation in sampler * run precommit * update seed generation * change variable name * update comment * add seed to tests * run precommit --- test/stateful_dataloader/test_sampler.py | 64 ++++++++++++++++++--- test/stateful_dataloader/test_state_dict.py | 4 ++ torchdata/stateful_dataloader/sampler.py | 5 +- 3 files changed, 63 insertions(+), 10 deletions(-) diff --git a/test/stateful_dataloader/test_sampler.py b/test/stateful_dataloader/test_sampler.py index 665cf5a36..7b172c8c0 100644 --- a/test/stateful_dataloader/test_sampler.py +++ b/test/stateful_dataloader/test_sampler.py @@ -14,7 +14,7 @@ from torch.utils.data import Dataset from torchdata.stateful_dataloader import StatefulDataLoader -from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler +from torchdata.stateful_dataloader.sampler import RandomSampler, StatefulDistributedSampler class MockDataset(Dataset): @@ -34,7 +34,10 @@ def __getitem__(self, idx): "Fails with TSAN with the following error: starting new threads after multi-threaded " "fork is not supported. Dying (set die_after_fork=0 to override)", ) -@unittest.skipIf(TEST_WITH_ASAN, "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223") +@unittest.skipIf( + TEST_WITH_ASAN, + "DataLoader tests hang in ASAN, see: https://github.com/pytorch/pytorch/issues/66223", +) class TestDataLoader(TestCase): def setUp(self): super().setUp() @@ -44,7 +47,12 @@ def setUp(self): def test_initialization_StatefulDistributedSampler(self): sampler = StatefulDistributedSampler( - self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False + self.dataset, + num_replicas=10, + rank=0, + shuffle=False, + seed=42, + drop_last=False, ) self.assertEqual(sampler.dataset, self.dataset) self.assertEqual(sampler.num_replicas, 10) @@ -139,7 +147,8 @@ def test_drop_last_effect(self): ) self.assertTrue( - len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices" + len(indices_with_drop) <= len(indices_without_drop), + "Drop last should result in fewer or equal indices", ) def test_data_order_with_shuffle(self): @@ -153,7 +162,11 @@ def test_data_order_with_shuffle(self): for batch in dataloader: data_loaded.extend(batch) self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded") - self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler") + self.assertEqual( + data_loaded, + data_sampled, + "Data loaded by DataLoader should match data sampled by sampler", + ) def test_data_order_without_shuffle(self): sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False) @@ -167,8 +180,16 @@ def test_data_order_without_shuffle(self): for batch in dataloader: data_loaded.extend(batch) 
self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded") - self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler") - self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order") + self.assertEqual( + data_loaded, + data_sampled, + "Data loaded by DataLoader should match data sampled by sampler", + ) + self.assertEqual( + data_loaded, + list(range(100)), + "Data loaded by DataLoader should be in original order", + ) def test_data_distribution_across_replicas(self): num_replicas = 5 @@ -181,9 +202,36 @@ def test_data_distribution_across_replicas(self): data_loaded.extend([int(x.item()) for x in batch]) all_data.extend(data_loaded) self.assertEqual( - sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas" + sorted(all_data), + list(range(100)), + "All data points should be covered exactly once across all replicas", ) + def test_seed_replicability(self): + # Test that the same seed will result in the same data order + # We first pick a random number as seed, then use it to initialize two dataloaders + min_seed, max_seed = 0, 1000 # [min_seed, max_seed) + seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item() + torch.manual_seed(seed) + + dataloader1 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True) + results1 = list(dataloader1) + + # Repeat the same process with the same seed + torch.manual_seed(seed) + dataloader2 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True) + results2 = list(dataloader2) + + # Repeat the same process with a different seed, making sure that the seed is different + min_seed, max_seed = 1000, 2000 # [min_seed, max_seed) + seed = torch.randint(min_seed, max_seed, (1,), dtype=torch.int64).item() + torch.manual_seed(seed) + dataloader3 = StatefulDataLoader(self.dataset, batch_size=1, shuffle=True) + results3 = list(dataloader3) + + self.assertEqual(results1, results2, "Data should be replicable with same seed") + self.assertNotEqual(results1, results3, "Data should not be replicable with different seed") + if __name__ == "__main__": run_tests() diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py index 01bb17ed5..ab7dbfb08 100644 --- a/test/stateful_dataloader/test_state_dict.py +++ b/test/stateful_dataloader/test_state_dict.py @@ -1458,6 +1458,8 @@ def get_map_dl(self, data_size, num_workers, batch_size, shuffle): ) def _run(self, data_size, num_workers, batch_size, shuffle): + # For reproducibility of testing, fixing the seed + torch.manual_seed(0) dataloader1 = self.get_map_dl( data_size=data_size, num_workers=num_workers, @@ -1493,6 +1495,8 @@ def _run(self, data_size, num_workers, batch_size, shuffle): self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4) # now run a second dataloder for 4 epochs and check if the order is same. 
+ # We need to fix the seed again to bring the initial conditions back to the state they were in when the first dataloader was instantiated + torch.manual_seed(0) dataloader2 = self.get_map_dl( data_size=data_size, num_workers=num_workers, diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index d0ae1a3eb..cacb1d12c 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -88,9 +88,10 @@ def __init__( self.replacement = replacement self._num_samples = num_samples if generator is None: - # Ensure that underlying sampler has something repeatable + # Previously the random seed was fixed at 1. We now draw a system-generated seed from the global RNG so that shuffling is not identical across instantiations, while remaining reproducible via torch.manual_seed. + seed = int(torch.empty((), dtype=torch.int64).random_().item()) generator = torch.Generator() - generator.manual_seed(1) + generator.manual_seed(seed) self.generator = generator if not isinstance(self.replacement, bool): raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
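
Usage note (not part of the patch): the sketch below illustrates the behavior these two commits target, assuming torchdata's StatefulDataLoader is available; the toy dataset, batch size, and seed are placeholders. A state dict captured after an epoch has finished should, once loaded into a fresh loader, produce a full next epoch rather than an empty one, and because the sampler seed is now drawn from the global RNG, fixing that RNG with torch.manual_seed keeps the shuffled order reproducible.

# Minimal sketch, not from the patch: exercises the end-of-epoch
# save/restore behavior and the system-generated sampler seed.
import torch
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(8))  # any map-style dataset works here

torch.manual_seed(0)  # sampler seed is now drawn from the global RNG
loader = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
for _ in loader:  # run exactly one full epoch
    pass
sd = loader.state_dict()  # captured *after* the epoch has ended

resumed = StatefulDataLoader(dataset, batch_size=2, shuffle=True)
resumed.load_state_dict(sd)  # also restores the sampler's generator state

# With the fix, resuming from a post-epoch state dict yields a fresh, full
# epoch (4 batches of 2 here) instead of stopping immediately.
assert sum(1 for _ in resumed) == len(dataset) // 2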