Skip to content

Commit 65bb403

Browse files
tf-model-analysis-teamtfx-copybara
tf-model-analysis-team
authored andcommitted
This is an internal cleanup
PiperOrigin-RevId: 420896786
1 parent fcb15bd commit 65bb403

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

tensorflow_model_analysis/evaluators/metrics_plots_and_validations_evaluator.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import copy
1717
import datetime
1818
import numbers
19-
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, Union
19+
from typing import Any, Dict, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union
2020
import apache_beam as beam
2121
import numpy as np
2222

@@ -41,6 +41,9 @@
4141
from tfx_bsl.tfxio import tensor_adapter
4242
from tensorflow_metadata.proto.v0 import schema_pb2
4343

44+
SliceKeyTypeVar = TypeVar('SliceKeyTypeVar', slicer.SliceKeyType,
45+
slicer.CrossSliceKeyType)
46+
4447
_COMBINER_INPUTS_KEY = '_combiner_inputs'
4548
_DEFAULT_COMBINER_INPUT_KEY = '_default_combiner_input'
4649
_DEFAULT_NUM_JACKKNIFE_BUCKETS = 20
@@ -381,19 +384,20 @@ def _is_private_metrics(metric_key: metric_types.MetricKey):
381384

382385

383386
def _remove_private_metrics(
384-
slice_key: slicer.SliceKeyOrCrossSliceKeyType,
385-
metrics: metric_types.MetricsDict
386-
) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]:
387+
slice_key: SliceKeyTypeVar, metrics: metric_types.MetricsDict
388+
) -> Tuple[SliceKeyTypeVar, metric_types.MetricsDict]:
387389
return (slice_key,
388390
{k: v for (k, v) in metrics.items() if not _is_private_metrics(k)})
389391

390392

391393
@beam.ptransform_fn
392394
def _AddCrossSliceMetrics( # pylint: disable=invalid-name
393-
sliced_combiner_outputs: beam.pvalue.PCollection,
395+
sliced_combiner_outputs: beam.pvalue.PCollection[Tuple[
396+
slicer.SliceKeyType, metric_types.MetricsDict]],
394397
cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]],
395398
cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
396-
) -> Tuple[slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]:
399+
) -> beam.pvalue.PCollection[Tuple[slicer.SliceKeyOrCrossSliceKeyType,
400+
metric_types.MetricsDict]]:
397401
"""Generates CrossSlice metrics from SingleSlices."""
398402

399403
def is_slice_applicable(
@@ -495,8 +499,8 @@ def _AddDerivedCrossSliceAndDiffMetrics( # pylint: disable=invalid-name
495499
derived_computations: List[metric_types.DerivedMetricComputation],
496500
cross_slice_computations: List[metric_types.CrossSliceMetricComputation],
497501
cross_slice_specs: Optional[Iterable[config_pb2.CrossSlicingSpec]] = None,
498-
baseline_model_name: Optional[str] = None
499-
) -> beam.PCollection[Tuple[slicer.SliceKeyType, metric_types.MetricsDict]]:
502+
baseline_model_name: Optional[str] = None) -> beam.PCollection[Tuple[
503+
slicer.SliceKeyOrCrossSliceKeyType, metric_types.MetricsDict]]:
500504
"""A PTransform for adding cross slice and derived metrics.
501505
502506
This PTransform uses the input PCollection of sliced metrics to compute
@@ -564,11 +568,11 @@ def add_diff_metrics(
564568

565569

566570
def _filter_by_key_type(
567-
sliced_metrics_plots_attributions: Tuple[slicer.SliceKeyType,
571+
sliced_metrics_plots_attributions: Tuple[SliceKeyTypeVar,
568572
Dict[metric_types.MetricKey, Any]],
569573
key_type: Type[Union[metric_types.MetricKey, metric_types.PlotKey,
570574
metric_types.AttributionsKey]]
571-
) -> Tuple[slicer.SliceKeyType, Dict[metric_types.MetricKey, Any]]:
575+
) -> Tuple[SliceKeyTypeVar, Dict[metric_types.MetricKey, Any]]:
572576
"""Filters metrics and plots by key type."""
573577
slice_value, metrics_plots_attributions = sliced_metrics_plots_attributions
574578
output = {}

0 commit comments

Comments
 (0)