Commit
remove unnecessary repetition of methods
ramanishsingh committed Aug 22, 2024
1 parent 0dce976 commit fb0a187
Showing 4 changed files with 306 additions and 179 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/stateful_dataloader_ci.yml
@@ -78,6 +78,9 @@ jobs:
- name: Run StatefulDataLoader tests with pytest - dataloader
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_dataloader.py
- name: Run StatefulDataSampler tests with pytest - datasampler
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_sampler.py
- name: Run StatefulDataLoader tests with pytest - state_dict 0
if: ${{ ! contains(github.event.pull_request.labels.*.name, 'ciflow/slow') }}
run: pytest --durations=0 --no-header -v test/stateful_dataloader/test_state_dict.py -k _shard0
123 changes: 0 additions & 123 deletions test/stateful_dataloader/test_dataloader.py
@@ -1075,18 +1075,6 @@ def filter_len(row):
return len(row) == 4


class MockDataset(Dataset):
def __init__(self, size):
self.size = size
self.data = torch.arange(size) # Simple data that is easy to verify

def __len__(self):
return self.size

def __getitem__(self, idx):
return self.data[idx]


@unittest.skipIf(
TEST_WITH_TSAN,
"Fails with TSAN with the following error: starting new threads after multi-threaded "
@@ -1100,7 +1088,6 @@ def setUp(self):
self.labels = torch.randperm(50).repeat(2)
self.dataset = TensorDataset(self.data, self.labels)
self.persistent_workers = False
self.mockdataset = MockDataset(100)

def _get_data_loader(self, dataset, **kwargs):
persistent_workers = kwargs.get("persistent_workers", self.persistent_workers)
@@ -1960,116 +1947,6 @@ def test_sampler_reproducibility(self):
ls[i].append(next(its[i]))
self.assertEqual(ls[0], ls[1])

def test_initialization_StatefulDistributedSampler(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False)
self.assertEqual(sampler.dataset, dataset)
self.assertEqual(sampler.num_replicas, 10)
self.assertEqual(sampler.rank, 0)
self.assertFalse(sampler.shuffle)
self.assertEqual(sampler.seed, 42)
self.assertFalse(sampler.drop_last)
self.assertEqual(sampler.yielded, 0)
self.assertIsNone(sampler.next_yielded)

def test_state_dict(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0)
sampler.yielded = 5
state_dict = sampler.state_dict()
self.assertEqual(state_dict["yielded"], 5)

def test_load_state_dict(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0)
sampler.load_state_dict({"yielded": 3})
self.assertEqual(sampler.next_yielded, 3)
with self.assertRaises(ValueError):
sampler.load_state_dict({"yielded": -1})

def test_next_yielded(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler = StatefulDistributedSampler(self.dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
iterator = iter(sampler)
next(iterator) # advance the iterator
self.assertEqual(sampler.yielded, 1)
self.assertIsNone(sampler.next_yielded)
sampler.load_state_dict({StatefulDistributedSampler._YIELDED: 5})
self.assertEqual(sampler.next_yielded, 5)
next(iterator) # advance the iterator again
self.assertEqual(sampler.yielded, 6)

def test_drop_last_effect(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler_with_drop = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True)
sampler_without_drop = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=False)
indices_with_drop = list(iter(sampler_with_drop))
indices_without_drop = list(iter(sampler_without_drop))
self.assertTrue(
len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
)

def test_data_order_with_shuffle(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler = StatefulDistributedSampler(self.mockdataset, num_replicas=1, rank=0, shuffle=True, seed=42)
indices = list(iter(sampler))
data_sampled = [self.mockdataset[i] for i in indices]
self.assertNotEqual(data_sampled, list(range(100)), "Data should be shuffled")

def test_data_order_without_shuffle(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler = StatefulDistributedSampler(self.mockdataset, num_replicas=1, rank=0, shuffle=False)
indices = list(iter(sampler))
data_sampled = [self.mockdataset[i] for i in indices]
self.assertEqual(data_sampled, list(range(100)), "Data should be in sequential order when shuffle is False")

def test_data_distribution_across_replicas(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

num_replicas = 5
all_data = []
for rank in range(num_replicas):
sampler = StatefulDistributedSampler(self.mockdataset, num_replicas=num_replicas, rank=rank, shuffle=False)
indices = list(iter(sampler))
data_sampled = [int(self.mockdataset[i].numpy().astype(int)) for i in indices]
all_data.extend(data_sampled)
self.assertEqual(
sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
)

def test_consistency_across_epochs(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

num_replicas = 3
rank = 1
sampler = StatefulDistributedSampler(self.dataset, num_replicas=num_replicas, rank=rank, shuffle=True, seed=42)
indices_epoch1 = list(iter(sampler))
data_epoch1 = [self.dataset[i] for i in indices_epoch1]
sampler.set_epoch(1) # Move to the next epoch
indices_epoch2 = list(iter(sampler))
data_epoch2 = [self.dataset[i] for i in indices_epoch2]
self.assertNotEqual(data_epoch1, data_epoch2, "Data order should change with different epochs due to shuffling")

def test_no_data_loss_with_drop_last(self):
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

sampler = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True)
indices = list(iter(sampler))
expected_length = (len(self.dataset) // 3) * 3 // 3
self.assertEqual(
len(indices), expected_length, "Length of indices should match expected length with drop_last=True"
)

def _test_sampler(self, **kwargs):
indices = range(2, 12) # using a regular iterable
dl = self._get_data_loader(self.dataset, sampler=indices, batch_size=2, **kwargs)
(The remaining 2 changed files are not shown here.)
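
For reference, here is a minimal sketch of the checkpoint/resume pattern the removed tests exercise. It uses only the StatefulDistributedSampler calls that appear in the diff above; the dataset, replica count, seed, and variable names are illustrative, not part of this commit.

import torch
from torch.utils.data import TensorDataset
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler

# Dataset mirroring the one built in setUp() above (shapes and values are illustrative).
data = torch.randn(100, 2)
labels = torch.randperm(50).repeat(2)
dataset = TensorDataset(data, labels)

sampler = StatefulDistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
sampler.set_epoch(0)  # reshuffles per epoch, as exercised by test_consistency_across_epochs
iterator = iter(sampler)
next(iterator)  # consume one index; sampler.yielded advances to 1 (see test_next_yielded)

state = sampler.state_dict()  # contains the "yielded" counter, as checked in test_state_dict
# ... save `state` alongside the rest of the training checkpoint ...

restored = StatefulDistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
restored.load_state_dict(state)  # sets restored.next_yielded, as checked in test_load_state_dict
# A negative "yielded" value raises ValueError, per test_load_state_dict.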
