Skip to content

Commit

Permalink
define methods in _StatefulDistributedSamplerIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Aug 21, 2024
1 parent fe08bfc commit 0dce976
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.generator_state = state_dict[self._GENERATOR]
self.sampler.generator.set_state(state_dict[self._GENERATOR])
self.next_yielded = state_dict[self._YIELDED]
return None

def state_dict(self) -> Dict[str, Any]:
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
Expand Down Expand Up @@ -175,6 +174,12 @@ def __next__(self) -> int:
self.sampler.yielded += 1
return val

def state_dict(self) -> Dict[str, Any]:
return self.sampler.state_dict()

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.sampler.load_state_dict(state_dict)


class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
_YIELDED = "yielded"
Expand Down

0 comments on commit 0dce976

Please sign in to comment.