@@ -324,7 +324,7 @@ def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
             "0": {"num_samples_yielded": 9, "num_workers": 1, "batch_size": batch_size},
             "1": {"num_samples_yielded": 3, "num_workers": 1, "batch_size": batch_size},
         },
-        "current_epoch": 0,
+        "current_epoch": 1,
         "latest_worker_idx": 0,
         "num_samples_yielded": {0: [9, 3]},
     }
@@ -374,7 +374,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "num_samples_yielded": 0,
                 "num_workers": 3,
                 "batch_size": 2,
-                "current_epoch": 0,
+                "current_epoch": 1,
                 "input_dir_path": ANY,
                 "input_dir_url": ANY,
                 "cache_dir_path": None,
@@ -390,7 +390,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "num_samples_yielded": 0,
                 "num_workers": 3,
                 "batch_size": 2,
-                "current_epoch": 0,
+                "current_epoch": 1,
                 "input_dir_path": ANY,
                 "input_dir_url": ANY,
                 "cache_dir_path": None,
@@ -403,7 +403,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
                 "region_of_interest": ANY,
             },
         },
-        "current_epoch": 0,
+        "current_epoch": 1,
         "latest_worker_idx": 0,
         "num_samples_yielded": {},
     }
@@ -417,7 +417,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
         {0: [4, 1], 1: [3, 1], 2: [2, 1]},
         {0: [4, 1], 1: [4, 1], 2: [2, 1]},
     ]
-    expected_current_epoch = [0, 0, 0, 0, 0, 0, 0, 0]
+    expected_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     dataset_1_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     dataset_2_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
     expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0, 1]
@@ -459,7 +459,7 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
     ]
     dataset_1_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
     dataset_2_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
-    expected_current_epoch = [1, 1, 1, 1, 1, 1, 1, 1]
+    expected_current_epoch = [2, 2, 2, 2, 2, 2, 2, 2]
     expected_latest_worker_idx = [0, 1, 2, 0, 1, 2, 0, 1]
     expected_dataset0_samples_yielded = [2, 4, 6, 7, 8, 8, 9, 10]
     expected_dataset1_samples_yielded = [0, 0, 0, 1, 2, 3, 3, 3]
@@ -497,6 +497,81 @@ def test_combined_dataset_with_dataloader_2_epochs(tmpdir):
         states_23.append(dataloader.state_dict())
 
     assert sum(not torch.equal(b1, b2) for b1, b2 in zip(batches_2[2:], batches_23)) == 0
-    assert states_23[0]["current_epoch"] == 1
+    assert states_23[0]["current_epoch"] == 2
+
+    assert not dataloader.restore
+
+
+def test_combined_dataset_dataloader_states_without_any_iterations(combined_dataset):
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4)
+    assert not dataloader.restore
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+
+
+@pytest.mark.timeout(120)
+@pytest.mark.parametrize("num_workers", [0, 2, 4])
+def test_combined_dataset_dataloader_states_complete_iterations(combined_dataset, num_workers):
+    print(f"Testing with num_workers={num_workers}")
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
+    assert len(dataloader) == 25, "Dataloader length should be 25 (50+50 items / batch size 4)"
+
+    # Verify dataloader state after complete last iteration
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1"
+        pass
+
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert not dataloader.restore
+
+    for _ in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2"
+        pass
 
     assert not dataloader.restore
+
+    del dataloader
+
+
+@pytest.mark.timeout(300)
+@pytest.mark.parametrize(("num_workers", "break_at"), [(0, 10), (0, 15), (2, 10), (2, 15), (4, 10), (4, 15)])
+def test_combined_dataset_dataloader_states_partial_iterations(combined_dataset, num_workers, break_at):
+    print(f"Testing with num_workers={num_workers}, break_at={break_at}")
+
+    # Verify dataloader state after partial last iteration
+    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=num_workers)
+
+    total_batches = len(dataloader)
+    assert total_batches == 25, "Dataloader length should be 25 (100 items / batch size 4)"
+
+    assert not dataloader.restore, "Dataloader should not be in restore state initially."
+
+    # Partial iteration up to 'break_at'
+    for batch_idx, batch in enumerate(dataloader):
+        assert dataloader.current_epoch == 1, "Current epoch should be 1 during first iteration"
+        if batch_idx == break_at:
+            break
+
+    assert not dataloader.restore, (
+        "Dataloader should not be in restore state after partial iteration, before loading state."
+    )
+    dataloader.load_state_dict(dataloader.state_dict())
+    assert dataloader.restore, "Dataloader should be in restore state after loading the state from a partial iteration."
+
+    # Verify remaining batches in the first epoch
+    count = 0
+    for _ in dataloader:
+        assert dataloader.current_epoch == 1, "Current epoch should be 1 during restore"
+        count += 1
+    expected_batches = total_batches - break_at - 1
+    assert count >= expected_batches, (
+        f"There should be at least {expected_batches} remaining batches in the first epoch."
+    )
+    assert not dataloader.restore, "Dataloader should not be in restore state after completing first epoch."
+
+    # Verify batches in the second epoch
+    samples_yielded = 0
+    for batch in dataloader:
+        assert dataloader.current_epoch == 2, "Current epoch should be 2 in the second iteration"
+        samples_yielded += len(batch)
+    assert samples_yielded == len(combined_dataset), "All samples should be yielded in the second epoch."
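For context, the new tests exercise a checkpoint/resume loop like the one below. This is a minimal sketch, not part of the diff: the import path, the `run_epoch` helper, and the training-step placeholder are assumptions, while `StreamingDataLoader(...)`, `state_dict()`, `load_state_dict()`, and `restore` are the calls used in the tests above.

# Minimal sketch of the save/resume pattern exercised by the new tests (illustrative only).
from litdata.streaming import StreamingDataLoader  # assumed import path

def run_epoch(combined_dataset, checkpoint=None):
    dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=2)
    if checkpoint is not None:
        # Per the tests: loading a mid-epoch state sets `restore`; a state saved after
        # a complete epoch leaves `restore` False and the next epoch starts fresh.
        dataloader.load_state_dict(checkpoint)
    for batch in dataloader:
        ...  # training step (placeholder)
        checkpoint = dataloader.state_dict()  # capture progress after each batch
    return checkpoint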