From 297d7bfd508dd256c1dadc97aa8ec5e8deab6d89 Mon Sep 17 00:00:00 2001
From: Ramanish Singh
Date: Wed, 5 Feb 2025 09:03:30 -0800
Subject: [PATCH] add test for statedict before and after endofepoch

---
 test/stateful_dataloader/test_state_dict.py | 298 ++++++++++++++++----
 1 file changed, 240 insertions(+), 58 deletions(-)

diff --git a/test/stateful_dataloader/test_state_dict.py b/test/stateful_dataloader/test_state_dict.py
index 15ecbe037..f8dc86e60 100644
--- a/test/stateful_dataloader/test_state_dict.py
+++ b/test/stateful_dataloader/test_state_dict.py
@@ -83,7 +83,9 @@ def __len__(self):
         return self.size
 
 
-class DummyIteratorIterableDataset(torch.utils.data.IterableDataset, Iterator, Stateful):
+class DummyIteratorIterableDataset(
+    torch.utils.data.IterableDataset, Iterator, Stateful
+):
     def __init__(self, samples, shuffle, include_generator):
         self.samples = samples
         self.shuffle = shuffle
@@ -139,7 +141,10 @@ def __iter__(self):
 class DummyMapDataset(torch.utils.data.Dataset):
     def __init__(self, size, shuffle, include_generator=True):
         self.size = size
-        self.data = [{"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]} for i in range(size)]
+        self.data = [
+            {"id": i, "strcol": f"strcol_{i}", "listcol": [i, i + 1, i + 2]}
+            for i in range(size)
+        ]
         self.shuffle = shuffle
         self.include_generator = include_generator
 
@@ -202,7 +207,9 @@ class TestStatefulDataLoaderIterable_shard0(TestCase):
     def _get_dataset(self, shuffle):
         return DummyIterableDataset([0, 100, 37], shuffle=shuffle)
 
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+    def _run_and_checkpoint(
+        self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False
+    ):
         dataset = self._get_dataset(shuffle)
         dl = StatefulDataLoader(
             dataset=dataset,
@@ -270,7 +277,9 @@ def test_mp_pw(self):
     def test_mp_every_n_steps(self):
         batch_size = 7
         for every_n_steps, interrupt in itertools.product([2, 5], [0, 1, 10]):
-            with self.subTest(every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt):
+            with self.subTest(
+                every_n_steps=every_n_steps, batch_size=batch_size, interrupt=interrupt
+            ):
                 self._run_and_checkpoint(
                     num_workers=3,
                     batch_size=batch_size,
@@ -291,7 +300,9 @@ def test_random_state(self):
 
 
 class TestStatefulDataLoaderMap_shard1(TestStatefulDataLoaderIterable_shard0):
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+    def _run_and_checkpoint(
+        self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False
+    ):
         if num_workers == 0:
             return
         dataset = DummyMapDataset(100, shuffle=shuffle)
@@ -344,7 +355,9 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
 
 
 class TestStatefulSampler_shard1(TestStatefulDataLoaderIterable_shard0):
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+    def _run_and_checkpoint(
+        self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False
+    ):
         dataset = DummyMapDataset(100, shuffle=shuffle)
         sampler = DummySampler(len(dataset))
         dl = StatefulDataLoader(
@@ -472,7 +485,9 @@ def load_state_dict(self, state):
 
 
 class TestStatefulDataLoaderGenerator_shard2(TestStatefulDataLoaderIterable_shard0):
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+    def _run_and_checkpoint(
+        self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False
+    ):
         dataset = GeneratorIterable([0, 100, 37])
         dl = StatefulDataLoader(
             dataset=dataset,
@@ -521,8 +536,12 @@ def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_st
         self.assertEqual(batches, exp)
 
 
-class TestStatefulDataLoaderGeneratorNoState_shard2(TestStatefulDataLoaderIterable_shard0):
-    def _run_and_checkpoint(self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False):
+class TestStatefulDataLoaderGeneratorNoState_shard2(
+    TestStatefulDataLoaderIterable_shard0
+):
+    def _run_and_checkpoint(
+        self, num_workers, batch_size, pw, interrupt, every_n_steps=1, shuffle=False
+    ):
         dataset = GeneratorIterableNoState([0, 100, 37])
         dl = StatefulDataLoader(
             dataset=dataset,
@@ -582,7 +601,9 @@ def test_generator(self):
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
 
@@ -605,7 +626,9 @@ def test_iterable(self):
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
 
@@ -628,7 +651,9 @@ def test_map(self):
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
            )
             it = iter(dl)
 
@@ -652,7 +677,9 @@ def test_map_shuffle(self):
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
 
@@ -667,7 +694,9 @@ def test_map_shuffle(self):
     def test_map_iterrupted_shuffle(self):
         every_n_steps = 10
 
-        for pw, num_workers, every_n_steps in itertools.product([False, True], [0, 2], [1, 15]):
+        for pw, num_workers, every_n_steps in itertools.product(
+            [False, True], [0, 2], [1, 15]
+        ):
             dataset = DummyMapDataset(10, shuffle=True)
             dl = StatefulDataLoader(
                 dataset=dataset,
@@ -676,7 +705,9 @@ def test_map_iterrupted_shuffle(self):
                 collate_fn=identity,
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw if num_workers > 0 else False,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
 
@@ -712,7 +743,9 @@ def test_generator(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             exp = list(dl)
             state_end = dl.state_dict()
@@ -728,7 +761,9 @@ def test_generator(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
             for _ in range(2):
@@ -750,7 +785,9 @@ def test_generator_no_state(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             exp = list(dl)
             state_end = dl.state_dict()
@@ -766,7 +803,9 @@ def test_generator_no_state(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
            )
             it = iter(dl)
             for _ in range(2):
@@ -791,7 +830,9 @@ def test_iterable(self):
                 persistent_workers=pw,
                 batch_size=bs,
                 generator=g,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             list(dl)
             state_end = dl.state_dict()
@@ -806,7 +847,9 @@ def test_iterable(self):
                 persistent_workers=pw,
                 batch_size=bs,
                 generator=g,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             dl.load_state_dict(state_end)
             batches = list(dl)
@@ -828,7 +871,9 @@ def test_map(self):
                 persistent_workers=pw,
                 batch_size=bs,
                 generator=generator,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             list(dl)
             state_end = dl.state_dict()
@@ -843,7 +888,9 @@ def test_map(self):
                 persistent_workers=pw,
                 batch_size=bs,
                 generator=generator,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             dl.load_state_dict(state_end)
             batches = list(dl)
@@ -863,7 +910,9 @@ def test_map_shuffle(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             list(dl)
             state_end = dl.state_dict()
@@ -878,7 +927,9 @@ def test_map_shuffle(self):
                 snapshot_every_n_steps=every_n_steps,
                 persistent_workers=pw,
                 batch_size=bs,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             dl.load_state_dict(state_end)
             batches = list(dl)
@@ -896,7 +947,9 @@ def test_num_workers_mismatch(self):
                 dataset=dataset,
                 num_workers=initial_num_workers,
                 collate_fn=identity,
-                multiprocessing_context=("forkserver" if IS_MACOS and initial_num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and initial_num_workers else None
+                ),
             )
 
             state = dl.state_dict()
@@ -908,7 +961,9 @@ def test_num_workers_mismatch(self):
                 dataset=dataset,
                 num_workers=num_workers,
                 collate_fn=identity,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             dl.load_state_dict(state)
             try:
@@ -994,7 +1049,9 @@ def test_fast_state_dict_request_skip_steps(self) -> None:
 class TestJsonSerDe_shard3(TestCase):
     def _run_test_iterable(self, num_workers):
         interrupt = 4
-        dataset = DummyIterableDataset([0, 100, 37], shuffle=False, include_generator=False)
+        dataset = DummyIterableDataset(
+            [0, 100, 37], shuffle=False, include_generator=False
+        )
         dl = StatefulDataLoader(
             dataset=dataset,
             num_workers=num_workers,
@@ -1256,7 +1313,9 @@ def test_load_then_state(self):
 class TestStatefulDataLoaderIterable2_shard0(TestStatefulDataLoaderIterable_shard0):
     # Perform sanity test checks with the iterable dataset that is also an iterator
     def _get_dataset(self, shuffle):
-        return DummyIteratorIterableDataset(list(range(100)), shuffle=shuffle, include_generator=True)
+        return DummyIteratorIterableDataset(
+            list(range(100)), shuffle=shuffle, include_generator=True
+        )
 
 
 class TestDynamicStateIterableDataset_shard0(TestCase):
@@ -1274,7 +1333,9 @@ def test(self):
         for _ in range((num_workers + 1) * 2):
             next(it)
         state_dict = dl.state_dict()
-        worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"]
+        worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][
+            "fetcher_state"
+        ]["dataset_iter_state"]
         self.assertEqual(len(worker_state), 7)
 
         deep_copy_state_dict = deepcopy(state_dict)
@@ -1284,9 +1345,9 @@ def test(self):
         next_state_dict = dl.state_dict()
         self.assertEqual(state_dict, deep_copy_state_dict)
         self.assertFalse(state_dict == next_state_dict)
-        worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"][
-            "dataset_iter_state"
-        ]
+        worker_state = next_state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][
+            "fetcher_state"
+        ]["dataset_iter_state"]
         self.assertEqual(len(worker_state), 11)
 
         dl = StatefulDataLoader(
@@ -1302,19 +1363,25 @@ def test(self):
             exp.extend(next(it))
         state_dict = dl.state_dict()
         self.assertEqual(exp, [3, 3])
-        worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["dataset_iter_state"]
+        worker_state = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][
+            "fetcher_state"
+        ]["dataset_iter_state"]
         self.assertEqual(len(worker_state), 9)
 
 
 class TestDatasetIteratorStateDuplication_shard0(TestCase):
     def test(self):
-        dataset = DummyIteratorIterableDataset(list(range(100)), shuffle=True, include_generator=True)
+        dataset = DummyIteratorIterableDataset(
+            list(range(100)), shuffle=True, include_generator=True
+        )
         for num_workers in (0, 2):
             dl = StatefulDataLoader(
                 dataset=dataset,
                 num_workers=num_workers,
                 collate_fn=identity,
-                multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+                multiprocessing_context=(
+                    "forkserver" if IS_MACOS and num_workers else None
+                ),
             )
             it = iter(dl)
             # Fetch at least one batch from each worker
@@ -1326,13 +1393,15 @@ def test(self):
                 for i in range(num_workers):
                     # Ensure worker state is stored only once if the dataset is also the iterator
                     self.assertEqual(
-                        state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["dataset_state"],
+                        state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][
+                            "dataset_state"
+                        ],
                         None,
                     )
                     self.assertTrue(
-                        state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"]["fetcher_state"][
-                            "dataset_iter_state"
-                        ]
+                        state_dict["_snapshot"]["_worker_snapshots"][f"worker_{i}"][
+                            "fetcher_state"
+                        ]["dataset_iter_state"]
                     )
             else:
                 self.assertEqual(state_dict["dataset_state"], None)
@@ -1452,7 +1521,9 @@ def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
             num_workers=num_workers,
             batch_size=batch_size,
             shuffle=shuffle,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
 
     def _run(self, data_size, num_workers, batch_size, shuffle=False):
@@ -1488,7 +1559,9 @@ def _run(self, data_size, num_workers, batch_size, shuffle=False):
                 epoch_num_items_yielded += 1
             additional_num_items_yielded += epoch_num_items_yielded
         # Check that the total number of items yielded is correct
-        self.assertEqual(num_items_yielded + additional_num_items_yielded, data_size * 4)
+        self.assertEqual(
+            num_items_yielded + additional_num_items_yielded, data_size * 4
+        )
 
         # now run a second dataloder for 4 epochs and check if the order is same.
         dl2 = self.get_map_dl(
@@ -1517,6 +1590,83 @@ def test_multiprocess_shuffle(self):
         self._run(100, 2, 1, True)
 
 
+class TestEndOfEpochBehavior_shard0(TestCase):
+    def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
+        dataset = DummyMapDataset(data_size, shuffle=False)
+        return StatefulDataLoader(
+            dataset=dataset,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
+        )
+
+    def _count_items_yielded(self, data_loader: StatefulDataLoader) -> int:
+        num_items_yielded = 0
+        for batch in data_loader:
+            num_items_yielded += 1
+        return num_items_yielded
+
+    def _run(self, data_size, num_workers, batch_size, shuffle=False):
+        dl = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        # Run through the dataloader for 1 epoch and count the number of items yielded
+        num_items_yielded = 0
+
+        for batch in dl:
+            num_items_yielded += 1
+            sd_in = dl.state_dict()
+        sd_out = dl.state_dict()
+
+        self.assertEqual(num_items_yielded, data_size)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved before the end of epoch
+        dl_sd_in = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dl_sd_in.load_state_dict(sd_in)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be 0 since the state dict was saved before the end of epoch
+        num_items_yielded = self._count_items_yielded(dl_sd_in)
+        self.assertEqual(num_items_yielded, 0)
+
+        # Create a new StatefulDataLoader instance and load the state dict saved after the end of epoch
+        dl_sd_out = self.get_map_dl(
+            data_size=data_size,
+            num_workers=num_workers,
+            batch_size=batch_size,
+            shuffle=shuffle,
+        )
+        dl_sd_out.load_state_dict(sd_out)
+
+        # Run through the new dataloader for 1 epoch and count the number of items yielded
+        # num_items_yielded should be data_size since the state dict was saved after the end of epoch
+        num_items_yielded = self._count_items_yielded(dl_sd_out)
+        self.assertEqual(num_items_yielded, data_size)
+
+    def test_main_process(self):
+        self._run(100, 0, 1, False)
+
+    def test_multiprocess(self):
+        self._run(100, 2, 1, False)
+
+    def test_main_process_shuffle(self):
+        self._run(100, 0, 1, True)
+
+    def test_multiprocess_shuffle(self):
+        self._run(100, 2, 1, True)
+
+
 class TestMultiEpochState_shard0(TestCase):
     def get_iterable_dl(self, pw, num_workers):
         data_size = [25, 50, 100, 75]
@@ -1528,7 +1678,9 @@ def get_iterable_dl(self, pw, num_workers):
             num_workers=num_workers,
             persistent_workers=pw,
             collate_fn=identity,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
 
     def _run(self, pw: bool, num_workers: int):
@@ -1589,7 +1741,9 @@ def __iter__(self):
         num_workers = torch.utils.data.get_worker_info().num_workers
         num_samples = (int)(self.length / num_workers)
 
-        self.iter_state = IterationState(num_samples * worker_id, num_samples * (worker_id + 1))
+        self.iter_state = IterationState(
+            num_samples * worker_id, num_samples * (worker_id + 1)
+        )
         return self
 
     def __next__(self):
@@ -1615,29 +1769,39 @@ def _get_iter_calls(self, state):
 
         if w_states[0]["dataset_state"] is not None:
             return [x["dataset_state"]["iter_calls"] for x in w_states]
-        return [x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states]
+        return [
+            x["fetcher_state"]["dataset_iter_state"]["iter_calls"] for x in w_states
+        ]
 
     def _run_test(self, num_workers, dataset, expected_iter_calls):
         dl = StatefulDataLoader(
             dataset=dataset,
             num_workers=num_workers,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
         iter(dl)
         state = dl.state_dict()
         # Ensure iter is called only once per worker
-        self.assertEqual(self._get_iter_calls(state), [expected_iter_calls[0]] * max(1, num_workers))
+        self.assertEqual(
+            self._get_iter_calls(state), [expected_iter_calls[0]] * max(1, num_workers)
+        )
 
         dl2 = StatefulDataLoader(
             dataset=dataset,
             num_workers=num_workers,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
         dl2.load_state_dict(state)
         iter(dl2)
         state2 = dl2.state_dict()
         # Ensure that iter is called only once per worker even when dataloader resumes from a state
-        self.assertEqual(self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers))
+        self.assertEqual(
+            self._get_iter_calls(state2), [expected_iter_calls[1]] * max(1, num_workers)
+        )
 
     def test_inline(self):
         self._run_test(0, CountIterCalls(100), [1, 2])
@@ -1678,7 +1842,9 @@ def _run_test(self, num_workers, dataset):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
         it = iter(dl)
         data = []
@@ -1692,7 +1858,9 @@ def _run_test(self, num_workers, dataset):
             dataset=dataset,
             num_workers=num_workers,
             collate_fn=identity,
-            multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
+            multiprocessing_context=(
+                "forkserver" if IS_MACOS and num_workers else None
+            ),
         )
         dl2.load_state_dict(state)
         it = iter(dl2)
@@ -1739,7 +1907,9 @@ def give_data(self, iter_start, iter_end):
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
-        per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
+        per_worker = int(
+            math.ceil((self.end - self.start) / float(worker_info.num_workers))
+        )
         worker_id = worker_info.id
         iter_start = self.start + worker_id * per_worker
         iter_end = min(iter_start + per_worker, self.end)
@@ -1793,12 +1963,18 @@ def test_out_of_order_iterable_ds_one_completed_worker(self):
                 state_dict = dataloader.state_dict()
                 break
 
-        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
-        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
+        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][
+            "fetcher_state"
+        ]["fetcher_ended"]
+        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][
+            "fetcher_state"
+        ]["fetcher_ended"]
         self.assertTrue(worker_0_ended)
         self.assertFalse(worker_1_ended)
 
-        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
+        new_dataloader = StatefulDataLoader(
+            dataset, batch_size=1, num_workers=2, in_order=False
+        )
         new_dataloader.load_state_dict(state_dict)
         for i, data in enumerate(new_dataloader):
             output.append(data)
@@ -1824,12 +2000,18 @@ def test_out_of_order_iterable_ds_no_completed_workers(self):
                 state_dict = dataloader.state_dict()
                 break
 
-        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"]["fetcher_state"]["fetcher_ended"]
-        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"]["fetcher_state"]["fetcher_ended"]
+        worker_0_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_0"][
+            "fetcher_state"
+        ]["fetcher_ended"]
+        worker_1_ended = state_dict["_snapshot"]["_worker_snapshots"]["worker_1"][
+            "fetcher_state"
+        ]["fetcher_ended"]
         self.assertFalse(worker_0_ended)
         self.assertFalse(worker_1_ended)
 
-        new_dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, in_order=False)
+        new_dataloader = StatefulDataLoader(
+            dataset, batch_size=1, num_workers=2, in_order=False
+        )
         new_dataloader.load_state_dict(state_dict)
         for i, data in enumerate(new_dataloader):
             output.append(data)