Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix end of epoch StatefulDataLoader restart #1439

Merged
merged 24 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ def test(self):
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)
it = iter(dl)
# Fetch at least one batch from each worker
Expand All @@ -1325,7 +1325,10 @@ def test(self):
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertEqual(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
None,
)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
Expand Down Expand Up @@ -1441,6 +1444,79 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, shuffle=False):
dl1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 2 epochs and count the number of items yielded
num_items_yielded = 0
dl1_items = []
for _ in range(2):
for batch in dl1:
dl1_items.append(batch)
num_items_yielded += 1
# Save the state dict
state_dict = dl1.state_dict()
# Create a new StatefulDataLoader instance and load the state dict
new_dl1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
new_dl1.load_state_dict(state_dict)
# Run through the new dataloader for another 2 epochs and count the number of items yielded
additional_num_items_yielded = 0
for i in range(2):
epoch_num_items_yielded = 0
for batch in new_dl1:
dl1_items.append(batch)
epoch_num_items_yielded += 1
additional_num_items_yielded += epoch_num_items_yielded
# Check that the total number of items yielded is correct
self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

# now run a second dataloder for 4 epochs and check if the order is same.
dl2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dl2_items = []
for _ in range(4):
for batch in dl2:
dl2_items.append(batch)

self.assertEqual(dl1_items, dl2_items)

def test_main_process(self):
self._run(100, 0, 1, False)

def test_multiprocess(self):
self._run(100, 2, 1, False)

def test_main_process_shuffle(self):
self._run(100, 0, 1, True)

def test_multiprocess_shuffle(self):
self._run(100, 2, 1, True)


class TestMultiEpochState_shard0(TestCase):
def get_iterable_dl(self, pw, num_workers):
data_size = [25, 50, 100, 75]
Expand Down
91 changes: 47 additions & 44 deletions torchdata/stateful_dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def state_dict(self) -> Dict[str, Any]:

class RandomSampler(torch.utils.data.sampler.RandomSampler):
def __init__(
self, data_source: Sized, replacement: bool = False, num_samples: Optional[int] = None, generator=None
self,
data_source: Sized,
replacement: bool = False,
num_samples: Optional[int] = None,
generator=None,
):
if generator is None:
# Ensure that underlying sampler has something repeatable
Expand All @@ -60,16 +64,30 @@ def __iter__(self):
return _StatefulRandomSamplerIterator(self, super().__iter__())


class BatchSampler(torch.utils.data.sampler.BatchSampler, Stateful):
class _BatchSamplerIterator(Iterator[list[int]], Stateful):
_SAMPLES_YIELDED = "samples_yielded"
_SAMPLER_STATE = "sampler_state"
_SAMPLER_ITER_STATE = "sampler_iter_state"

def __init__(self, sampler, batch_size, drop_last):
super().__init__(sampler, batch_size, drop_last)
def __init__(self, sampler, batch_size: int, drop_last: bool):
self.sampler = sampler
self.sampler_iter = iter(self.sampler)
self.batch_size = batch_size
self.drop_last = drop_last
self.samples_yielded = 0
self.next_yielded = None
self.sampler_iter = iter(sampler)

def __next__(self) -> list[int]:
batch = []
try:
for _ in range(self.batch_size):
batch.append(next(self.sampler_iter))
self.samples_yielded += 1
return batch
except StopIteration:
if self.drop_last or len(batch) == 0:
raise StopIteration
else:
return batch

def state_dict(self) -> Dict[str, Any]:
sd: Dict[str, Any] = {self._SAMPLES_YIELDED: self.samples_yielded}
Expand All @@ -80,7 +98,7 @@ def state_dict(self) -> Dict[str, Any]:
return sd

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.next_yielded = state_dict[self._SAMPLES_YIELDED]
self.samples_yielded = state_dict[self._SAMPLES_YIELDED]
if self._SAMPLER_STATE in state_dict:
assert isinstance(self.sampler, Stateful)
self.sampler.load_state_dict(state_dict[self._SAMPLER_STATE])
Expand All @@ -89,44 +107,29 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
assert isinstance(self.sampler_iter, Stateful)
self.sampler_iter.load_state_dict(state_dict[self._SAMPLER_ITER_STATE])

if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.samples_yielded):
next(self.sampler_iter)

# Skip one epoch if we were at the end of the last epoch
if hasattr(self.sampler, "__len__") and self.samples_yielded == len(self.sampler):
for _ in self.sampler_iter:
pass


class BatchSampler(torch.utils.data.sampler.BatchSampler):
def __init__(self, sampler, batch_size, drop_last):
super().__init__(sampler, batch_size, drop_last)

def __iter__(self):
if self.next_yielded is not None:
self.samples_yielded = self.next_yielded
if not (isinstance(self.sampler, Stateful) or isinstance(self.sampler_iter, Stateful)) and not isinstance(
self.sampler, _InfiniteConstantSampler
):
# We skip x samples if underlying sampler is not stateful
for _ in range(self.next_yielded):
next(self.sampler_iter)
self.next_yielded = None
elif self.samples_yielded > 0:
# don't re-create sampler_iter unless necessary, we may already have one from init
self.sampler_iter = iter(self.sampler)
self.samples_yielded = 0

if self.drop_last:
while True:
try:
batch = []
for _ in range(self.batch_size):
batch.append(next(self.sampler_iter))
self.samples_yielded += 1
yield batch
except StopIteration:
break
else:
batch = [0] * self.batch_size
idx_in_batch = 0
for idx in self.sampler_iter:
self.samples_yielded += 1
batch[idx_in_batch] = idx
idx_in_batch += 1
if idx_in_batch == self.batch_size:
yield batch
idx_in_batch = 0
batch = [0] * self.batch_size
if idx_in_batch > 0:
yield batch[:idx_in_batch]
return _BatchSamplerIterator(
sampler=self.sampler,
batch_size=self.batch_size,
drop_last=self.drop_last,
)


class StatefulDistributedSampler(torch.utils.data.distributed.DistributedSampler):
Expand Down