Add tests for out of order with checkpointing
michael-diggin committed Jan 28, 2025
1 parent e25df94 commit 3f6af6a
Showing 2 changed files with 97 additions and 15 deletions.
test/stateful_dataloader/test_state_dict.py (97 additions, 0 deletions)
@@ -6,6 +6,8 @@

import itertools
import json
import math
import time
import unittest
from copy import deepcopy

@@ -1632,5 +1634,100 @@ def test_mp(self):
self._run_test(2, CountIterCallsIter(100))


class _TestSlowIndexDataset(torch.utils.data.Dataset):
def __init__(self, end: int, slow_index: int):
self.end = end
self.slow_index = slow_index
self._worker_id = None

def __getitem__(self, idx):
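        # simulate a straggler: fetching slow_index blocks its worker for a second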
if idx == self.slow_index:
time.sleep(1.0)
return idx

def __len__(self):
return self.end


class _TestSlowIterableDataset(torch.utils.data.IterableDataset):
def __init__(self, start: int, end: int):
self.start = start
self.end = end
self.mid = math.ceil((self.end - self.start) / 2)
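        # give_data stalls on this midpoint element, so the worker that owns it lags behind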

def give_data(self, iter_start, iter_end):
for i in range(iter_start, iter_end):
if i == self.mid:
time.sleep(1.0)
yield i

def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
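        # contiguous sharding: worker i yields [start + i * per_worker, start + (i + 1) * per_worker), clamped to end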
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return self.give_data(iter_start, iter_end)


class TestOutOfOrderWithCheckpointing(TestCase):
def test_out_of_order_index_ds(self):
dataset = _TestSlowIndexDataset(end=10, slow_index=0)
dataloader = StatefulDataLoader(
dataset,
num_workers=2,
in_order=False,
)

        # worker_id = 0 gets 'stuck' on index 0 and also has index 2 in its queue,
        # since the default prefetch_factor is 2
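        # checkpoint mid-epoch, then resume with a fresh loader below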
output = []
for i, data in enumerate(dataloader):
output.append(data)
if i == 5:
state_dict = dataloader.state_dict()
break

new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
new_dataloader.load_state_dict(state_dict)
for i, data in enumerate(new_dataloader):
output.append(data)

self.assertEqual(len(output), 10)
self.assertNotEqual(output, list(range(10)))
self.assertEqual(sorted(output), list(range(10)))

def test_out_of_order_iterable_ds(self):
dataset = _TestSlowIterableDataset(start=0, end=10)
dataloader = StatefulDataLoader(
dataset,
num_workers=2,
prefetch_factor=2,
in_order=False,
)

        # break later in the epoch, as one of the workers will already have finished its shard
output = []
for i, data in enumerate(dataloader):
output.append(data)
if i == 7:
state_dict = dataloader.state_dict()
break

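        # worker 0's shard [0, 5) finishes before the break, while worker 1 is
        # delayed by the slow midpoint, so only worker 0's fetcher has ended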
worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
self.assertTrue(worker_0_ended)
self.assertFalse(worker_1_ended)

new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
new_dataloader.load_state_dict(state_dict)
for i, data in enumerate(new_dataloader):
output.append(data)

self.assertEqual(len(output), 10)
self.assertEqual(output, list(range(10)))
self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])


if __name__ == "__main__":
unittest.main()
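
Distilled, the pattern these tests exercise is: checkpoint a StatefulDataLoader mid-epoch with in_order=False, restore the snapshot into a fresh loader, and drain the remainder of the epoch. A minimal sketch, assuming any map-style dataset and the usual torchdata import path:

    from torchdata.stateful_dataloader import StatefulDataLoader

    loader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
    it = iter(loader)
    partial = [next(it) for _ in range(6)]  # consume part of the epoch
    snapshot = loader.state_dict()          # capture mid-epoch state

    resumed = StatefulDataLoader(dataset, num_workers=2, in_order=False)
    resumed.load_state_dict(snapshot)       # restore before iterating
    rest = list(resumed)                    # yields only the remaining items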
torchdata/stateful_dataloader/stateful_dataloader.py (0 additions, 15 deletions)
@@ -181,8 +181,6 @@ class StatefulDataLoader(DataLoader[_T_co]):
.. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data distribution being fed to the trainer in cases with imbalanced data.
.. warning:: Setting `in_order` to `False` currently has no guarantees for state management.
.. _multiprocessing context:
https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
"""
@@ -234,13 +232,6 @@ def __init__(
if persistent_workers and num_workers == 0:
raise ValueError("persistent_workers option needs num_workers > 0")

if num_workers > 0 and not in_order:
# TODO: remove warning log when state management is supported with in_order=False
logger.warning(
"using in_order=False with multiple workers does not give any guarantees for state management "
"and loading from a checkpoint may not work as expected."
)

self.dataset = dataset
self.num_workers = num_workers
self.prefetch_factor = prefetch_factor
@@ -1170,12 +1161,6 @@ def _update_worker_snapshot(self, worker_key, state_dict):
self._worker_snapshots[worker_key].apply_delta(state_dict)

def state_dict(self):
if not self._in_order:
# TODO: remove warning log when state management is supported with in_order=False
logger.warning(
"using in_order=False with multiple workers does not give any guarantees for state management "
"and loading from a checkpoint may not work as expected."
)
steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP]
state_dict = {
self._SNAPSHOT: self._snapshot,
