Skip to content

Commit

Permalink
Fix end of epoch StatefulDataLoader restart (#1439)
Browse files Browse the repository at this point in the history
* add test for end of epoch state dict check

* run precommit

update stateful_dataloader

run precommit

local changes

update test to test the order of batches

update test

update tests

revert changes in SDL

revert changes in SDL

update tests

run precommit

* update sampler

* run precommit

* remove unnecessary comment

* add test for statedict before and after endofepoch

* run precommit

* check if _sampler_iter is exhausted

* run precommit

* remove commented lines

* remove default values

* only exhaust sampler_iter if present in sd

* update _StatefulRandomSamplerIterator

update state dict if the iterator has finished

add comment about why were updating state dict

run precommit

* update randomsampleriter state_dict fully

* run precommit

* fork torch.utils.data RandomSampler

reverse changes to sdl.py

generator to iterator

run precommit

update generator usage

* update class name

* run precommit

* add a method to generate permutations

* update return type

* update next logic

* add comment

* update tests to include non stateful samplers

* add comments
  • Loading branch information
ramanishsingh authored Feb 18, 2025
1 parent fe6b405 commit f15fd3a
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 75 deletions.
209 changes: 207 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import torch
import torch.utils.data

from parameterized import parameterized
from torch.testing._internal.common_utils import IS_MACOS, TEST_CUDA, TestCase
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader

Expand Down Expand Up @@ -1314,7 +1316,7 @@ def test(self):
dataset=dataset,
num_workers=num_workers,
collate_fn=identity,
multiprocessing_context="forkserver" if IS_MACOS and num_workers else None,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)
it = iter(dl)
# Fetch at least one batch from each worker
Expand All @@ -1325,7 +1327,10 @@ def test(self):
if num_workers > 0:
for i in range(num_workers):
# Ensure worker state is stored only once if the dataset is also the iterator
self.assertEqual(state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"], None)
self.assertEqual(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
None,
)
self.assertTrue(
state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
"dataset_iter_state"
Expand Down Expand Up @@ -1441,6 +1446,206 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
self._run_test(17, 19)


class TestMultiEpochSDL_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, shuffle):
dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 2 epochs and count the number of items yielded
num_items_yielded = 0
dataloader1_items = []
for _ in range(2):
for batch in dataloader1:
dataloader1_items.append(batch)
num_items_yielded += 1
# Save the state dict
state_dict = dataloader1.state_dict()
# Create a new StatefulDataLoader instance and load the state dict
new_dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
new_dataloader1.load_state_dict(state_dict)
# Run through the new dataloader for another 2 epochs and count the number of items yielded
additional_num_items_yielded = 0
for i in range(2):
epoch_num_items_yielded = 0
for batch in new_dataloader1:
dataloader1_items.append(batch)
epoch_num_items_yielded += 1
additional_num_items_yielded += epoch_num_items_yielded
# Check that the total number of items yielded is correct
self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)

# now run a second dataloder for 4 epochs and check if the order is same.
dataloader2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader2_items = []
for _ in range(4):
for batch in dataloader2:
dataloader2_items.append(batch)

self.assertEqual(dataloader1_items, dataloader2_items)

@parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
def test_multi_epoch_sdl(self, datasize, num_workers, batch_size, shuffle):
self._run(datasize, num_workers, batch_size, shuffle)


class TestEndOfEpochBehavior_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, shuffle):
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
num_items_yielded = 0
for batch in data_loader:
num_items_yielded += 1
return num_items_yielded

def _run(self, data_size, num_workers, batch_size, shuffle):
dataloader = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
# Run through the dataloader for 1 epoch and count the number of items yielded
num_items_yielded = 0

for batch in dataloader:
num_items_yielded += 1
sd_in = dataloader.state_dict()
sd_out = dataloader.state_dict()

self.assertEqual(num_items_yielded, data_size)

# Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
dataloader_sd_in = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader_sd_in.load_state_dict(sd_in)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be 0 since the state dict was saved before the end of epoch
num_items_yielded = self._count_items_yielded(dataloader_sd_in)
self.assertEqual(num_items_yielded, 0)

# Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
dataloader_sd_out = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
)
dataloader_sd_out.load_state_dict(sd_out)

# Run through the new dataloader for 1 epoch and count the number of items yielded
# num_items_yielded should be data_size since the state dict was saved after the end of epoch
num_items_yielded = self._count_items_yielded(dataloader_sd_out)
self.assertEqual(num_items_yielded, data_size)

@parameterized.expand(itertools.product([100], [0, 2], [1], [False, True]))
def test_end_of_epoch_behavior(self, datasize, num_workers, batch_size, shuffle):
self._run(datasize, num_workers, batch_size, shuffle)


class TestNotStatefulSamplerSDL_shard0(TestCase):
def get_map_dl(self, data_size, num_workers, batch_size, sampler_cls):
dataset = DummyMapDataset(data_size, shuffle=False)
sampler = sampler_cls(dataset)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
sampler=sampler,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

def _run(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
torch.manual_seed(0) # Fixing seed for deterministic results
dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
# interrupt the dataloader after interrupt batches and save the state dict
results_dataloader1 = []
for i, batch in enumerate(dataloader1):
results_dataloader1.append(batch)
if i == interrupt:
break
state_dict = dataloader1.state_dict()

torch.manual_seed(
0
) # We need to fix seed again so that before fast forwarding we are at the same state of gen as before
resumed_dataloader1 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
resumed_dataloader1.load_state_dict(state_dict)

for batch in resumed_dataloader1:
results_dataloader1.append(batch)

# now start a completely new dataloader and get all the batches
torch.manual_seed(0)
dataloader2 = self.get_map_dl(
data_size=data_size,
num_workers=num_workers,
batch_size=batch_size,
sampler_cls=sampler_cls,
)
results_dataloader2 = []
for batch in dataloader2:
results_dataloader2.append(batch)
self.assertEqual(results_dataloader1, results_dataloader2)

@parameterized.expand(
itertools.product(
[100],
[0, 2],
[1],
[10, 50, 80],
[torch.utils.data.RandomSampler, torch.utils.data.SequentialSampler],
)
)
def test_notstatefulSDL(self, data_size, num_workers, batch_size, interrupt, sampler_cls):
self._run(100, 0, 1, interrupt, sampler_cls)


class TestMultiEpochState_shard0(TestCase):
def get_iterable_dl(self, pw, num_workers):
data_size = [25, 50, 100, 75]
Expand Down
Loading

0 comments on commit f15fd3a

Please sign in to comment.