Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Feb 4, 2025
1 parent a96cae5 commit 2097717
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions test/stateful_dataloader/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,11 +1443,12 @@ def test_fast_state_dict_request_skip_steps(self) -> None:

class TestMultiEpochState(TestCase):
def get_map_dl(self, data_size=100, num_workers=0, batch_size=1, shuffle=False):
dataset = list(range(data_size))
dataset = DummyMapDataset(data_size, shuffle=False)
return StatefulDataLoader(
dataset=dataset,
num_workers=num_workers,
batch_size=batch_size,
shuffle=shuffle,
multiprocessing_context=("forkserver" if IS_MACOS and num_workers else None),
)

Expand All @@ -1463,7 +1464,7 @@ def _run(self, data_size, num_workers, batch_size, shuffle=False):
# Save the state dict
state_dict = dl1.state_dict()
# Create a new StatefulDataLoader instance and load the state dict
new_dl1 = self.get_map_dl(num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
new_dl1 = self.get_map_dl(data_size=data_size, num_workers=num_workers, batch_size=batch_size, shuffle=shuffle)
new_dl1.load_state_dict(state_dict)
# Run through the new dataloader for another 2 epochs and count the number of items yielded
additional_num_items_yielded = 0
Expand Down

0 comments on commit 2097717

Please sign in to comment.