|
16 | 16 | import copy
|
17 | 17 | import datetime
|
18 | 18 | import numbers
|
19 |
| -from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, Union |
| 19 | +from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union |
20 | 20 | import apache_beam as beam
|
21 | 21 | import numpy as np
|
22 | 22 |
|
|
41 | 41 | from tfx_bsl.tfxio import tensor_adapter
|
42 | 42 | from tensorflow_metadata.proto.v0 import schema_pb2
|
43 | 43 |
|
| 44 | +SliceKeyTypeVar = TypeVar('SliceKeyTypeVar', slicer.SliceKeyType, |
| 45 | + slicer.CrossSliceKeyType) |
| 46 | + |
44 | 47 | _COMBINER_INPUTS_KEY = '_combiner_inputs'
|
45 | 48 | _DEFAULT_COMBINER_INPUT_KEY = '_default_combiner_input'
|
46 | 49 | _DEFAULT_NUM_JACKKNIFE_BUCKETS = 20
|
@@ -381,19 +384,20 @@ def _is_private_metrics(metric_key: metric_types.MetricKey):
|
381 | 384 |
|
382 | 385 |
|
383 | 386 | def _remove_private_metrics(
|
384 |
| - slice_key: slicer.SliceKeyOrCrossSliceKeyType, |
385 |
| - metrics: metric_types.MetricsDict |
386 |
| -) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]: |
| 387 | + slice_key: SliceKeyTypeVar, metrics: metric_types.MetricsDict |
| 388 | +) -> Tuple[SliceKeyTypeVar, metric_types.MetricsDict]: |
387 | 389 | return (slice_key,
|
388 | 390 | {k: v for (k, v) in metrics.items() if not _is_private_metrics(k)})
|
389 | 391 |
|
390 | 392 |
|
391 | 393 | @beam.ptransform_fn
|
392 | 394 | def _AddCrossSliceMetrics( # pylint: disable=invalid-name
|
393 |
| - sliced_combiner_outputs: beam.pvalue.PCollection, |
| 395 | + sliced_combiner_outputs: beam.pvalue.PCollection[Tuple[ |
| 396 | + slicer.SliceKeyType, metric_types.MetricsDict]], |
394 | 397 | cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]],
|
395 | 398 | cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
|
396 |
| -) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]: |
| 399 | +) -> beam.pvalue.PCollection[Tuple[slicer.SliceKeyOrCrossSliceKeyType, |
| 400 | + metric_types.MetricsDict]]: |
397 | 401 | """Generates CrossSlice metrics from SingleSlices."""
|
398 | 402 |
|
399 | 403 | def is_slice_applicable(
|
@@ -495,8 +499,8 @@ def _AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=invalid-name
|
495 | 499 | derived_computations: List[metric_types.DerivedMetricComputation],
|
496 | 500 | cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
|
497 | 501 | cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]] = None,
|
498 |
| - baseline_model_name: Optional[str] = None |
499 |
| -) -> beam.PCollection[Tuple[slicer.SliceKeyType, metric_types.MetricsDict]]: |
| 502 | + baseline_model_name: Optional[str] = None) -> beam.PCollection[Tuple[ |
| 503 | + slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]]: |
500 | 504 | """A PTransform for adding cross slice and derived metrics.
|
501 | 505 |
|
502 | 506 | This PTransform uses the input PCollection of sliced metrics to compute
|
@@ -564,11 +568,11 @@ def add_diff_metrics(
|
564 | 568 |
|
565 | 569 |
|
566 | 570 | def _filter_by_key_type(
|
567 |
| - sliced_metrics_plots_attributions: Tuple[slicer.SliceKeyType, |
| 571 | + sliced_metrics_plots_attributions: Tuple[SliceKeyTypeVar, |
568 | 572 | Dict[metric_types.MetricKey, Any]],
|
569 | 573 | key_type: Type[Union[metric_types.MetricKey, metric_types.PlotKey,
|
570 | 574 | metric_types.AttributionsKey]]
|
571 |
| -) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]: |
| 575 | +) -> Tuple[SliceKeyTypeVar, Dict[metric_types.MetricKey, Any]]: |
572 | 576 | """Filters metrics and plots by key type."""
|
573 | 577 | slice_value, metrics_plots_attributions = sliced_metrics_plots_attributions
|
574 | 578 | output = {}
|
|
0 commit comments