Skip to content

Commit

Permalink
update state dict if the iterator has finished
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 7, 2025
1 parent 3a7499a commit 8a83c5f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
16 changes: 11 additions & 5 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def __init__(self, sampler, parent_iterator: Iterator[int]):
self.generator_state = sampler.generator.get_state()

def __next__(self) -> int:
if self.next_yielded is not None:
for _ in range(self.next_yielded):
next(self.parent_iterator)

self.yielded = self.next_yielded
self.next_yielded = None
val = next(self.parent_iterator)
self.yielded += 1
return val
Expand All @@ -35,12 +40,9 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.generator_state = state_dict[self._GENERATOR]
self.sampler.generator.set_state(state_dict[self._GENERATOR])
self.next_yielded = state_dict[self._YIELDED]
if self.next_yielded is not None:
for _ in range(self.next_yielded):
next(self.parent_iterator)

self.yielded = self.next_yielded
self.next_yielded = None
def update_state_dict(self) -> None:
self.generator_state = self.sampler.generator.get_state()

def state_dict(self) -> Dict[str, Any]:
return {self._GENERATOR: self.generator_state, self._YIELDED: self.yielded}
Expand Down Expand Up @@ -114,6 +116,10 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
for _ in range(self.samples_yielded):
next(self.sampler_iter)

def update_state_dict(self):
if isinstance(self.sampler_iter, Stateful) and hasattr(self.sampler_iter, "update_state_dict"):
self.sampler_iter.update_state_dict()


class BatchSampler(torch.utils.data.sampler.BatchSampler):
def __init__(self, sampler, batch_size, drop_last):
Expand Down
2 changes: 2 additions & 0 deletions torchdata/stateful_dataloader/stateful_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,8 @@ def __next__(self):
try:
return super().__next__()
except StopIteration:
if hasattr(self._sampler_iter, "update_state_dict"):
self._sampler_iter.update_state_dict()
self._finished = True
raise

Expand Down

0 comments on commit 8a83c5f

Please sign in to comment.