
Commit 16fe7ac

update epoch numbers and add test cases for handling resume in the combined streaming dataset
1 parent a9f15f6 commit 16fe7ac

File tree

1 file changed: +82 -7 lines changed


tests/streaming/test_combined.py (+82 -7)
@@ -324,7 +324,7 @@ def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
             "0": {"num_samples_yielded": 9, "num_workers": 1, "batch_size": batch_size},
             "1": {"num_samples_yielded": 3, "num_workers": 1, "batch_size": batch_size},
         },
-        "current_epoch": 0,
+        "current_epoch": 1,
         "latest_worker_idx": 0,
         "num_samples_yielded": {0: [9, 3]},
     }
@@ -374,7 +374,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "num_samples_yielded": 0,
                 "num_workers": 3,
                 "batch_size": 2,
-                "current_epoch": 0,
+                "current_epoch": 1,
                 "input_dir_path": ANY,
                 "input_dir_url": ANY,
                 "cache_dir_path": None,
@@ -390,7 +390,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "num_samples_yielded": 0,
                 "num_workers": 3,
                 "batch_size": 2,
-                "current_epoch": 0,
+                "current_epoch": 1,
                 "input_dir_path": ANY,
                 "input_dir_url": ANY,
                 "cache_dir_path": None,
@@ -403,7 +403,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "region_of_interest": ANY,
             },
         },
-        "current_epoch": 0,
+        "current_epoch": 1,
         "latest_worker_idx": 0,
         "num_samples_yielded": {},
     }
@@ -417,7 +417,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
         {0: [4, 1], 1: [3, 1], 2: [2, 1]},
         {0: [4, 1], 1: [4, 1], 2: [2, 1]},
     ]
-    expected_current_epoch = [0, 0, 0, 0, 0, 0, 0, 0]
+    expected_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     dataset_1_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     dataset_2_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0, 1]
@@ -459,7 +459,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
     ]
     dataset_1_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
     dataset_2_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
-    expected_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
+    expected_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
     expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0, 1]
     expected_dataset0_samples_yielded = [2, 4, 6, 7, 8, 8, 9, 10]
     expected_dataset1_samples_yielded = [0, 0, 0, 1, 2, 3, 3, 3]
@@ -497,6 +497,81 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
         states_23.append(dataloader.state_dict())
 
     assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0
-    assert states_23[0]["current_epoch"] == 1
+    assert states_23[0]["current_epoch"] == 2
+
+    assert not dataloader.restore
+
+
+def test_combined_dataset_dataloader_states_without_any_iterations(combined_dataset):
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4)
+    assert not dataloader.restore
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize("num_workers", [0, 2, 4])
+def test_combined_dataset_dataloader_states_complete_iterations(combined_dataset, num_workers):
+    print(f"Testing with num_workers={num_workers}")
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"
+
+    # Verify dataloader state after complete last iteration
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        pass
+
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        pass
 
     assert not dataloader.restore
+
+    del dataloader
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.parametrize(("num_workers", "break_at"), [(0, 10), (0, 15), (2, 10), (2, 15), (4, 10), (4, 15)])
+def test_combined_dataset_dataloader_states_partial_iterations(combined_dataset, num_workers, break_at):
+    print(f"Testing with num_workers={num_workers}, break_at={break_at}")
+
+    # Verify dataloader state after partial last iteration
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
+
+    total_batches = len(dataloader)
+    assert total_batches == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    assert not dataloader.restore, "Dataloader should not be in restore state initially."
+
+    # Partial iteration up to 'break_at'
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1 during first iteration"
+        if batch_idx == break_at:
+            break
+
+    assert not dataloader.restore, (
+        "Dataloader should not be in restore state after partial iteration, before loading state."
+    )
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert dataloader.restore, "Dataloader should be in restore state after loading the state from a partial iteration."
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1 during restore"
+        count += 1
+    expected_batches = total_batches - break_at - 1
+    assert count >= expected_batches, (
f"There should be at least{expected_batches} remaining batches in the first epoch."
569+
)
570+
assert not dataloader.restore, "Dataloader should not be in restore state after completing first epoch."
571+
572+
# Verify batches in the second epoch
573+
samples_yielded = 0
574+
for batch in dataloader:
575+
assert dataloader.current_epoch == 2, "Current epoch should be 2 in the second iteration"
576+
samples_yielded += len(batch)
577+
assert samples_yielded == len(combined_dataset), "All samples should be yielded in the second epoch."

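For context on what the new partial-iteration tests exercise: a StreamingDataLoader can be checkpointed mid-epoch with state_dict() and resumed with load_state_dict(); the restore flag then stays set until the interrupted epoch has been replayed, after which the next pass starts epoch 2. The sketch below walks through that flow outside of pytest. It is a minimal illustration only: the dataset directories ("data/part_a", "data/part_b") are placeholders, and the combined dataset is assumed to mirror the tests' combined_dataset fixture (two streaming datasets of 50 items each, per the "50+50 items / batch size 4" assertion); only StreamingDataLoader, state_dict(), load_state_dict(), restore and current_epoch are taken from the diff above.

from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

# Hypothetical pre-optimized dataset directories; any two StreamingDataset inputs work.
combined = CombinedStreamingDataset(
    datasets=[StreamingDataset("data/part_a"), StreamingDataset("data/part_b")]
)
dataloader = StreamingDataLoader(combined, batch_size=4, num_workers=2)

# First epoch, interrupted part-way through: capture the loader state at the break.
for batch_idx, batch in enumerate(dataloader):
    if batch_idx == 10:
        state = dataloader.state_dict()  # in practice, persist this alongside the model checkpoint
        break

# Loading a state taken from a partial epoch sets `restore`, so iteration continues where it stopped.
dataloader.load_state_dict(state)
assert dataloader.restore
for batch in dataloader:
    pass  # yields only the batches not consumed before the interruption (still epoch 1)

# Once the interrupted epoch finishes, `restore` clears and the next loop is epoch 2.
assert not dataloader.restore
for batch in dataloader:
    assert dataloader.current_epoch == 2

The complete-iteration test covers the complementary case: a state captured after a fully consumed epoch leaves restore unset, so the next iteration starts the following epoch from the beginning rather than replaying anything.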