def test_map_iterrupted_shuffle(self):
    """Resuming from a snapshot must reproduce the uninterrupted run.

    Takes a snapshot before any batch (state0) and after 4 batches
    (state1), then checks that (a) resuming from state1 yields the rest
    of the epoch and (b) restarting from state0 replays the exact same
    shuffled batches end to end.

    NOTE(review): method name has a typo ("iterrupted"); kept as-is so
    existing test filters/CI selections keep matching.
    """
    # `every_n_steps` is swept by the product() below; a fixed
    # pre-assignment here was dead code and has been removed.
    for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 5, 10, 15]):
        dataset = DummyMapDataset(10, shuffle=False)
        dl = StatefulDataLoader(
            dataset=dataset,
            shuffle=True,  # exercise the stateful RandomSampler path
            num_workers=num_workers,
            collate_fn=identity,
            snapshot_every_n_steps=every_n_steps,
            # persistent_workers is only legal with worker processes
            persistent_workers=pw if num_workers > 0 else False,
        )

        it = iter(dl)
        state0 = dl.state_dict()  # snapshot before any batch is drawn
        exp = []
        for _ in range(4):
            exp.append(next(it))
        state1 = dl.state_dict()  # snapshot mid-epoch, after 4 batches

        # Resume from the mid-epoch snapshot and drain the remainder.
        dl.load_state_dict(state1)
        it = iter(dl)
        for data in it:
            exp.append(data)

        # Restart from the very beginning; must reproduce `exp` exactly.
        dl.load_state_dict(state0)
        batches = []
        for data in iter(dl):
            batches.append(data)

        self.assertEqual(batches, exp)
class _StatefulRandomSamplerIterator(Iterator[int], Stateful):
    """Iterator over RandomSampler's index stream that can checkpoint and
    restore its position.

    A snapshot records the generator state captured when the parent
    iterator was created plus the number of indices already yielded.
    Restoring rewinds the sampler's generator to that state and, on the
    next ``__next__`` call, fast-forwards past the indices that were
    consumed before the checkpoint.
    """

    def __init__(self, sampler, parent_iterator: Iterator[int]):
        self.sampler = sampler
        self.parent_iterator = parent_iterator
        # Indices handed out since `generator_state` was captured.
        self.yielded = 0
        # Pending fast-forward target set by load_state_dict();
        # None when not resuming.
        self.next_yielded = None
        # Generator state at iterator creation; snapshots replay from here.
        self.generator_state = sampler.generator.get_state()

    def __next__(self) -> int:
        if self.next_yielded is not None:
            # Resume path: skip the indices that were already consumed
            # before the checkpoint was taken.
            for _ in range(self.next_yielded):
                next(self.parent_iterator)
            self.yielded = self.next_yielded
            self.next_yielded = None

        val = next(self.parent_iterator)
        self.yielded += 1
        return val

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Rewind the generator and arm the fast-forward for ``__next__``."""
        self.generator_state = state_dict["generator"]
        self.sampler.generator.set_state(state_dict["generator"])
        self.next_yielded = state_dict["yielded"]

    def state_dict(self) -> Dict[str, Any]:
        # Bug fix: if state_dict() is called after load_state_dict() but
        # before the next __next__(), `self.yielded` still holds the stale
        # pre-restore count; report the pending restore target instead so
        # a snapshot taken immediately after a restore round-trips.
        yielded = self.yielded if self.next_yielded is None else self.next_yielded
        return {"generator": self.generator_state, "yielded": yielded}
class RandomSampler(torch.utils.data.sampler.RandomSampler):
    """Drop-in replacement for torch's RandomSampler whose iterator
    supports ``state_dict()`` / ``load_state_dict()`` checkpointing.

    When no generator is supplied, a deterministically seeded one is
    created so the shuffle order is reproducible across restores.
    """

    def __init__(
        self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
    ):
        if generator is None:
            # Ensure the underlying sampler draws from a repeatable stream.
            seeded = torch.Generator()
            seeded.manual_seed(1)
            generator = seeded
        super().__init__(data_source, replacement, num_samples, generator)

    def __iter__(self):
        # Wrap the stock permutation iterator so its position can be
        # snapshotted and restored.
        return _StatefulRandomSamplerIterator(self, super().__iter__())


# Patch torch so any code resolving RandomSampler through
# torch.utils.data.sampler picks up the stateful variant.
torch.utils.data.sampler.RandomSampler = RandomSampler  # type: ignore[misc]