From 3f6af6a5c93c00af560f6398cf7519f69435bc49 Mon Sep 17 00:00:00 2001
From: Michael Diggin
Date: Tue, 28 Jan 2025 18:35:54 +0000
Subject: [PATCH] Add tests for out of order with checkpointing

---
 test/stateful_dataloader/test_state_dict.py | 97 +++++++++++++++++++
 .../stateful_dataloader.py                  | 15 ---
 2 files changed, 97 insertions(+), 15 deletions(-)

diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index 639a46310..bdc431dc0 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -6,6 +6,8 @@
 
 import itertools
 import json
+import math
+import time
 import unittest
 from copy import deepcopy
 
@@ -1632,5 +1634,100 @@ def test_mp(self):
         self._run_test(2, CountIterCallsIter(100))
 
 
+class _TestSlowIndexDataset(torch.utils.data.Dataset):
+    def __init__(self, end: int, slow_index: int):
+        self.end = end
+        self.slow_index = slow_index
+        self._worker_id = None
+
+    def __getitem__(self, idx):
+        if idx == self.slow_index:
+            time.sleep(1.0)
+        return idx
+
+    def __len__(self):
+        return self.end
+
+
+class _TestSlowIterableDataset(torch.utils.data.IterableDataset):
+    def __init__(self, start: int, end: int):
+        self.start = start
+        self.end = end
+        self.mid = math.ceil((self.end - self.start) / 2)
+
+    def give_data(self, iter_start, iter_end):
+        for i in range(iter_start, iter_end):
+            if i == self.mid:
+                time.sleep(1.0)
+            yield i
+
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+        per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
+        worker_id = worker_info.id
+        iter_start = self.start + worker_id * per_worker
+        iter_end = min(iter_start + per_worker, self.end)
+        return self.give_data(iter_start, iter_end)
+
+
+class TestOutOfOrderWithCheckpointing(TestCase):
+    def test_out_of_order_index_ds(self):
+        dataset = _TestSlowIndexDataset(end=10, slow_index=0)
+        dataloader = StatefulDataLoader(
+            dataset,
+            num_workers=2,
+            in_order=False,
+        )
+
+        # worker_id = 0 gets 'stuck' on 0 and also has 2 in its queue
+        # due to prefetch_factor being 2
+        output = []
+        for i, data in enumerate(dataloader):
+            output.append(data)
+            if i == 5:
+                state_dict = dataloader.state_dict()
+                break
+
+        new_dataloader = StatefulDataLoader(dataset, num_workers=2, in_order=False)
+        new_dataloader.load_state_dict(state_dict)
+        for i, data in enumerate(new_dataloader):
+            output.append(data)
+
+        self.assertEqual(len(output), 10)
+        self.assertNotEqual(output, list(range(10)))
+        self.assertEqual(sorted(output), list(range(10)))
+
+    def test_out_of_order_iterable_ds(self):
+        dataset = _TestSlowIterableDataset(start=0, end=10)
+        dataloader = StatefulDataLoader(
+            dataset,
+            num_workers=2,
+            prefetch_factor=2,
+            in_order=False,
+        )
+
+        # break later on, as one of the workers will be finished
+        output = []
+        for i, data in enumerate(dataloader):
+            output.append(data)
+            if i == 7:
+                state_dict = dataloader.state_dict()
+                break
+
+        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
+        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
+        self.assertTrue(worker_0_ended)
+        self.assertFalse(worker_1_ended)
+
+        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
+        new_dataloader.load_state_dict(state_dict)
+        for i, data in enumerate(new_dataloader):
+            output.append(data)
+
+        self.assertEqual(len(output), 10)
+        self.assertEqual(output, list(range(10)))
+        self.assertNotEqual(output, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9])
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchdata/stateful_dataloader/stateful_dataloader.py b/torchdata/stateful_dataloader/stateful_dataloader.py
index 078b378ee..271996b56 100644
--- a/torchdata/stateful_dataloader/stateful_dataloader.py
+++ b/torchdata/stateful_dataloader/stateful_dataloader.py
@@ -181,8 +181,6 @@ class StatefulDataLoader(DataLoader[_T_co]):
     .. warning:: Setting `in_order` to `False` can harm reproducibility and may lead to a skewed data
                  distribution being fed to the trainer in cases with imbalanced data.
 
-    .. warning:: Setting `in_order` to `False` currently has no guarantees for state management.
-
     .. _multiprocessing context:
         https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods
     """
@@ -234,13 +232,6 @@ def __init__(
         if persistent_workers and num_workers == 0:
             raise ValueError("persistent_workers option needs num_workers > 0")
 
-        if num_workers > 0 and not in_order:
-            # TODO: remove warning log when state management is supported with in_order=False
-            logger.warning(
-                "using in_order=False with multiple workers does not give any guarantees for state management "
-                "and loading from a checkpoint may not work as expected."
-            )
-
         self.dataset = dataset
         self.num_workers = num_workers
         self.prefetch_factor = prefetch_factor
@@ -1170,12 +1161,6 @@ def _update_worker_snapshot(self, worker_key, state_dict):
         self._worker_snapshots[worker_key].apply_delta(state_dict)
 
     def state_dict(self):
-        if not self._in_order:
-            # TODO: remove warning log when state management is supported with in_order=False
-            logger.warning(
-                "using in_order=False with multiple workers does not give any guarantees for state management "
-                "and loading from a checkpoint may not work as expected."
-            )
         steps_since_snapshot = self._num_yielded - self._snapshot[self._SNAPSHOT_STEP]
         state_dict = {
             self._SNAPSHOT: self._snapshot,
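
The tests in this patch use the existing StatefulDataLoader checkpointing API rather than adding new surface area: call state_dict() mid-iteration, then load_state_dict() on a freshly constructed loader with the same configuration. A minimal sketch of that save/resume pattern, assuming a placeholder in-memory dataset and an arbitrary break point (neither is part of the patch):

```python
# Sketch only: the toy dataset and the step at which we checkpoint are illustrative.
from torchdata.stateful_dataloader import StatefulDataLoader

dataset = list(range(10))  # placeholder map-style dataset

loader = StatefulDataLoader(dataset, num_workers=2, in_order=False)

state_dict = None
for step, batch in enumerate(loader):
    # ... training step ...
    if step == 5:
        # Snapshot mid-epoch; with in_order=False the order yielded so far may be shuffled.
        state_dict = loader.state_dict()
        break

# Resume with a fresh loader built with the same configuration.
resumed = StatefulDataLoader(dataset, num_workers=2, in_order=False)
resumed.load_state_dict(state_dict)
remaining = list(resumed)
# Together with the batches seen before the checkpoint, every sample is produced exactly once,
# which is what TestOutOfOrderWithCheckpointing asserts.
```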