From 99fa7e7f272e6a0ca5a4031ae43267bda1376cd8 Mon Sep 17 00:00:00 2001 From: Ramanish Singh Date: Tue, 21 Jan 2025 20:30:43 -0800 Subject: [PATCH] update test --- test/nodes/test_multi_node_weighted_sampler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/test/nodes/test_multi_node_weighted_sampler.py b/test/nodes/test_multi_node_weighted_sampler.py index c89630138..014f696c5 100644 --- a/test/nodes/test_multi_node_weighted_sampler.py +++ b/test/nodes/test_multi_node_weighted_sampler.py @@ -100,7 +100,7 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None: seed=self._seed ) - for _ in range(self._num_epochs): + for _ in range(1): #only running for one epoch as the number of samples taken from each dataset is stochastic and is epoch dependent results = list(mixer) datasets_in_results = [result["name"] for result in results] @@ -109,10 +109,17 @@ def test_multi_node_weighted_sampler_first_exhausted(self) -> None: # Check max item count for dataset is exactly _num_samples self.assertEqual(max(dataset_counts_in_results), self._num_samples) - # Check only one dataset has been exhausted + # Check that the max number of samples (10) have been taken from two datasets (ds2 and ds3) self.assertEqual(dataset_counts_in_results.count(self._num_samples), 2) + # The number of datasets from which max number of samples is taken is 2 because StopIteration is called after next is called + # on the dataset node which is on its last element. We do not have a way to preemptively tell if a node is at its last element without + # calling next on it. Thus, during multi dataset sampling, multiple dataset nodes can be at their last element and when next is called + # on any one of them, it raises StopIteration. Thus, multiple datasets can yield max number of elements. 
+ # Check the exact number of samples taken from each dataset + self.assertEqual(dataset_counts_in_results, [4, 8, 10, 10]) mixer.reset() + def test_multi_node_weighted_sampler_all_dataset_exhausted(self) -> None: """Test MultiNodeWeightedSampler with stop criteria ALL_DATASETS_EXHAUSTED""" mixer = self._setup_multi_node_weighted_sampler(