@@ -233,7 +233,7 @@ def test_sample_combine_fn_missing_samples(self):
233
233
]
234
234
235
235
with beam .Pipeline () as pipeline :
236
- _ = (
236
+ result = (
237
237
pipeline
238
238
| 'Create' >> beam .Create (samples , reshuffle = False )
239
239
| 'CombineSamplesPerKey' >> beam .CombinePerKey (
@@ -242,19 +242,31 @@ def test_sample_combine_fn_missing_samples(self):
242
242
full_sample_id = _FULL_SAMPLE_ID ,
243
243
skip_ci_metric_keys = [example_count_key ])))
244
244
245
- result = pipeline .run ()
245
+ def check_result (got_pcoll ):
246
+ self .assertLen (got_pcoll , 2 )
247
+ slice2_metrics = None
248
+ for slice_key , metrics in got_pcoll :
249
+ if slice_key == slice_key2 :
250
+ slice2_metrics = metrics
251
+ break
252
+ self .assertIsNotNone (slice2_metrics )
253
+ self .assertIn (metric_types .MetricKey ('__ERROR__' ), slice2_metrics )
254
+
255
+ util .assert_that (result , check_result )
256
+
257
+ runner_result = pipeline .run ()
246
258
# we expect one missing samples counter increment for slice2, since we
247
259
# expected 2 samples, but only saw 1.
248
260
metric_filter = beam .metrics .metric .MetricsFilter ().with_name (
249
261
'num_slices_missing_samples' )
250
- counters = result .metrics ().query (filter = metric_filter )['counters' ]
262
+ counters = runner_result .metrics ().query (filter = metric_filter )['counters' ]
251
263
self .assertLen (counters , 1 )
252
264
self .assertEqual (1 , counters [0 ].committed )
253
265
254
266
# verify total slice counter
255
267
metric_filter = beam .metrics .metric .MetricsFilter ().with_name (
256
268
'num_slices' )
257
- counters = result .metrics ().query (filter = metric_filter )['counters' ]
269
+ counters = runner_result .metrics ().query (filter = metric_filter )['counters' ]
258
270
self .assertLen (counters , 1 )
259
271
self .assertEqual (2 , counters [0 ].committed )
260
272
0 commit comments