From 6c6ef787ade40f9b989b9a5eb6a0d62bb6299cfd Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Wed, 21 Aug 2024 09:53:23 -0700 Subject: [PATCH] add tests --- test/stateful_dataloader/test_dataloader.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/test/stateful_dataloader/test_dataloader.py b/test/stateful_dataloader/test_dataloader.py index 9eb4c8895..1ed1fcbc1 100644 --- a/test/stateful_dataloader/test_dataloader.py +++ b/test/stateful_dataloader/test_dataloader.py @@ -1960,7 +1960,7 @@ def test_sampler_reproducibility(self): ls[i].append(next(its[i])) self.assertEqual(ls[0], ls[1]) - def test_initialization_TestStatefulDistributedSampler(self): + def test_initialization_StatefulDistributedSampler(self): from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler dataset = self.dataset @@ -1971,6 +1971,8 @@ def test_initialization_TestStatefulDistributedSampler(self): self.assertFalse(sampler.shuffle) self.assertEqual(sampler.seed, 42) self.assertFalse(sampler.drop_last) + self.assertEqual(sampler.yielded, 0) + self.assertIsNone(sampler.next_yielded) def test_state_dict(self): from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler @@ -1991,6 +1993,19 @@ def test_load_state_dict(self): with self.assertRaises(ValueError): sampler.load_state_dict({"yielded": -1}) + def test_next_yielded(self): + from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler + + sampler = StatefulDistributedSampler(self.dataset, num_replicas=2, rank=0, shuffle=True, seed=42) + iterator = iter(sampler) + next(iterator) # advance the iterator + self.assertEqual(sampler.yielded, 1) + self.assertIsNone(sampler.next_yielded) + sampler.load_state_dict({StatefulDistributedSampler._YIELDED: 5}) + self.assertEqual(sampler.next_yielded, 5) + next(iterator) # advance the iterator again + self.assertEqual(sampler.yielded, 6) + def test_drop_last_effect(self): from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler @@ -2050,7 +2065,7 @@ def test_no_data_loss_with_drop_last(self): sampler = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True) indices = list(iter(sampler)) - expected_length = (len(self.dataset) // 3) * 3 // 3 # Calculate expected length considering drop_last + expected_length = (len(self.dataset) // 3) * 3 // 3 self.assertEqual( len(indices), expected_length, "Length of indices should match expected length with drop_last=True" )