From 0dce97679c93b1c0671f906791003a51ca88e740 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 21 Aug 2024 11:07:02 -0700 Subject: [PATCH] define methods in _StatefulDistributedSamplerIterator --- torchdata/stateful_dataloader/sampler.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchdata/stateful_dataloader/sampler.py b/torchdata/stateful_dataloader/sampler.py index 21fa74a85..6e9ef7220 100644 --- a/torchdata/stateful_dataloader/sampler.py +++ b/torchdata/stateful_dataloader/sampler.py @@ -41,7 +41,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.generator_state = state_dict[self._GENERATOR] self.sampler.generator.set_state(state_dict[self._GENERATOR]) self.next_yielded = state_dict[self._YIELDED] - return None def state_dict(self) -> Dict[str, Any]: return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded} @@ -175,6 +174,12 @@ def __next__(self) -> int: self.sampler.yielded += 1 return val + def state_dict(self) -> Dict[str, Any]: + return self.sampler.state_dict() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.sampler.load_state_dict(state_dict) + class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler): _YIELDED = "yielded"