Skip to content

Commit 6ba1eef

Browse files
embr tf-model-analysis-team
authored and
tf-model-analysis-team
committed
Fix case where error metric key is of type tuple.
PiperOrigin-RevId: 387391727
1 parent 708c195 commit 6ba1eef

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

tensorflow_model_analysis/evaluators/confidence_intervals_util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def extract_output(
171171
accumulator.num_samples < self._num_samples):
172172
self._missing_samples_counter.inc(1)
173173
missing_samples = True
174-
error_metric_key = metric_types.MetricKey(metric_keys.ERROR_METRIC),
174+
error_metric_key = metric_types.MetricKey(metric_keys.ERROR_METRIC)
175175
result[error_metric_key] = (
176176
f'CI not computed because only {accumulator.num_samples} samples '
177177
f'were non-empty. Expected {self._num_samples}.')
@@ -203,5 +203,4 @@ def extract_output(
203203
sample_standard_deviation=standard_error,
204204
sample_degrees_of_freedom=dof,
205205
unsampled_value=unsampled_value)
206-
# TODO(b/194750790): remove this once the typing issue is resolved.
207-
return result # pytype: disable=bad-return-type
206+
return result

tensorflow_model_analysis/evaluators/confidence_intervals_util_test.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def test_sample_combine_fn_missing_samples(self):
233233
]
234234

235235
with beam.Pipeline() as pipeline:
236-
_ = (
236+
result = (
237237
pipeline
238238
| 'Create' >> beam.Create(samples, reshuffle=False)
239239
| 'CombineSamplesPerKey' >> beam.CombinePerKey(
@@ -242,19 +242,31 @@ def test_sample_combine_fn_missing_samples(self):
242242
full_sample_id=_FULL_SAMPLE_ID,
243243
skip_ci_metric_keys=[example_count_key])))
244244

245-
result = pipeline.run()
245+
def check_result(got_pcoll):
246+
self.assertLen(got_pcoll, 2)
247+
slice2_metrics = None
248+
for slice_key, metrics in got_pcoll:
249+
if slice_key == slice_key2:
250+
slice2_metrics = metrics
251+
break
252+
self.assertIsNotNone(slice2_metrics)
253+
self.assertIn(metric_types.MetricKey('__ERROR__'), slice2_metrics)
254+
255+
util.assert_that(result, check_result)
256+
257+
runner_result = pipeline.run()
246258
# we expect one missing samples counter increment for slice2, since we
247259
# expected 2 samples, but only saw 1.
248260
metric_filter = beam.metrics.metric.MetricsFilter().with_name(
249261
'num_slices_missing_samples')
250-
counters = result.metrics().query(filter=metric_filter)['counters']
262+
counters = runner_result.metrics().query(filter=metric_filter)['counters']
251263
self.assertLen(counters, 1)
252264
self.assertEqual(1, counters[0].committed)
253265

254266
# verify total slice counter
255267
metric_filter = beam.metrics.metric.MetricsFilter().with_name(
256268
'num_slices')
257-
counters = result.metrics().query(filter=metric_filter)['counters']
269+
counters = runner_result.metrics().query(filter=metric_filter)['counters']
258270
self.assertLen(counters, 1)
259271
self.assertEqual(2, counters[0].committed)
260272

0 commit comments

Comments
 (0)