Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Aug 21, 2024
1 parent 2e33f9c commit 6c6ef78
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions test/stateful_dataloader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,7 @@ def test_sampler_reproducibility(self):
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def test_initialization_TestStatefulDistributedSampler(self):
def test_initialization_StatefulDistributedSampler(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
Expand All @@ -1971,6 +1971,8 @@ def test_initialization_TestStatefulDistributedSampler(self):
self.assertFalse(sampler.shuffle)
self.assertEqual(sampler.seed, 42)
self.assertFalse(sampler.drop_last)
self.assertEqual(sampler.yielded, 0)
self.assertIsNone(sampler.next_yielded)

def test_state_dict(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
Expand All @@ -1991,6 +1993,19 @@ def test_load_state_dict(self):
with self.assertRaises(ValueError):
sampler.load_state_dict({"yielded": -1})

def test_next_yielded(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler = StatefulDistributedSampler(self.dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
iterator = iter(sampler)
next(iterator) # advance the iterator
self.assertEqual(sampler.yielded, 1)
self.assertIsNone(sampler.next_yielded)
sampler.load_state_dict({StatefulDistributedSampler._YIELDED: 5})
self.assertEqual(sampler.next_yielded, 5)
next(iterator) # advance the iterator again
self.assertEqual(sampler.yielded, 6)

def test_drop_last_effect(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

Expand Down Expand Up @@ -2050,7 +2065,7 @@ def test_no_data_loss_with_drop_last(self):

sampler = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True)
indices = list(iter(sampler))
expected_length = (len(self.dataset) // 3) * 3 // 3 # Calculate expected length considering drop_last
expected_length = (len(self.dataset) // 3) * 3 // 3
self.assertEqual(
len(indices), expected_length, "Length of indices should match expected length with drop_last=True"
)
Expand Down

0 comments on commit 6c6ef78

Please sign in to comment.