15
15
16
16
# TODO(b/149126671): Put ValidationResultsWriter in a separate file.
17
17
18
-
19
18
import os
20
19
import tempfile
21
20
from typing import Any , Dict , Iterable , Iterator , List , Optional , Set , Union
@@ -108,9 +107,8 @@ def _is_legacy_eval(
108
107
return ((config_version is not None and config_version == 1 ) or
109
108
(eval_shared_model and not isinstance (eval_shared_model , dict ) and
110
109
not isinstance (eval_shared_model , list ) and
111
- (not eval_shared_model .model_loader .tags or
112
- eval_constants .EVAL_TAG in eval_shared_model .model_loader .tags ) and
113
- not eval_config ))
110
+ (not eval_shared_model .model_loader .tags or eval_constants .EVAL_TAG
111
+ in eval_shared_model .model_loader .tags ) and not eval_config ))
114
112
115
113
116
114
def _default_eval_config (eval_shared_models : List [types .EvalSharedModel ],
@@ -441,6 +439,14 @@ def default_eval_shared_model(
441
439
is_baseline = is_baseline )
442
440
443
441
442
+ def _has_sql_slices (eval_config : Optional [config_pb2 .EvalConfig ]) -> bool :
443
+ if eval_config :
444
+ for spec in eval_config .slicing_specs :
445
+ if spec .slice_keys_sql :
446
+ return True
447
+ return False
448
+
449
+
444
450
def default_extractors ( # pylint: disable=invalid-name
445
451
eval_shared_model : Optional [types .MaybeMultipleEvalSharedModels ] = None ,
446
452
eval_config : Optional [config_pb2 .EvalConfig ] = None ,
@@ -495,7 +501,16 @@ def default_extractors( # pylint: disable=invalid-name
495
501
slice_key_extractor .SliceKeyExtractor (
496
502
eval_config = eval_config , materialize = materialize )
497
503
]
498
- elif eval_shared_model :
504
+ slicing_extractors = []
505
+ if _has_sql_slices (eval_config ):
506
+ slicing_extractors .append (
507
+ sql_slice_key_extractor .SqlSliceKeyExtractor (eval_config ))
508
+ slicing_extractors .extend ([
509
+ unbatch_extractor .UnbatchExtractor (),
510
+ slice_key_extractor .SliceKeyExtractor (
511
+ eval_config = eval_config , materialize = materialize )
512
+ ])
513
+ if eval_shared_model :
499
514
model_types = _model_types (eval_shared_model )
500
515
eval_shared_models = model_util .verify_and_update_eval_shared_models (
501
516
eval_shared_model )
@@ -520,11 +535,8 @@ def default_extractors( # pylint: disable=invalid-name
520
535
eval_config = eval_config ),
521
536
(custom_predict_extractor or
522
537
tflite_predict_extractor .TFLitePredictExtractor (
523
- eval_config = eval_config , eval_shared_model = eval_shared_model )),
524
- unbatch_extractor .UnbatchExtractor (),
525
- slice_key_extractor .SliceKeyExtractor (
526
- eval_config = eval_config , materialize = materialize )
527
- ]
538
+ eval_config = eval_config , eval_shared_model = eval_shared_model ))
539
+ ] + slicing_extractors
528
540
elif constants .TF_LITE in model_types :
529
541
raise NotImplementedError (
530
542
'support for mixing tf_lite and non-tf_lite models is not '
@@ -538,11 +550,8 @@ def default_extractors( # pylint: disable=invalid-name
538
550
eval_config = eval_config ),
539
551
(custom_predict_extractor or
540
552
tfjs_predict_extractor .TFJSPredictExtractor (
541
- eval_config = eval_config , eval_shared_model = eval_shared_model )),
542
- unbatch_extractor .UnbatchExtractor (),
543
- slice_key_extractor .SliceKeyExtractor (
544
- eval_config = eval_config , materialize = materialize )
545
- ]
553
+ eval_config = eval_config , eval_shared_model = eval_shared_model ))
554
+ ] + slicing_extractors
546
555
elif constants .TF_JS in model_types :
547
556
raise NotImplementedError (
548
557
'support for mixing tf_js and non-tf_js models is not '
@@ -555,11 +564,8 @@ def default_extractors( # pylint: disable=invalid-name
555
564
custom_predict_extractor or legacy_predict_extractor .PredictExtractor (
556
565
eval_shared_model ,
557
566
materialize = materialize ,
558
- eval_config = eval_config ),
559
- unbatch_extractor .UnbatchExtractor (),
560
- slice_key_extractor .SliceKeyExtractor (
561
- eval_config = eval_config , materialize = materialize )
562
- ]
567
+ eval_config = eval_config )
568
+ ] + slicing_extractors
563
569
elif (eval_config and constants .TF_ESTIMATOR in model_types and
564
570
any (eval_constants .EVAL_TAG in m .model_loader .tags
565
571
for m in eval_shared_models )):
@@ -585,23 +591,17 @@ def default_extractors( # pylint: disable=invalid-name
585
591
eval_config = eval_config ,
586
592
eval_shared_model = eval_shared_model ,
587
593
tensor_adapter_config = tensor_adapter_config )),
588
- unbatch_extractor .UnbatchExtractor (),
589
- slice_key_extractor .SliceKeyExtractor (
590
- eval_config = eval_config , materialize = materialize )
591
594
])
595
+ extractors .extend (slicing_extractors )
592
596
return extractors
593
597
else :
594
598
return [
595
599
features_extractor .FeaturesExtractor (eval_config = eval_config ),
596
600
labels_extractor .LabelsExtractor (eval_config = eval_config ),
597
601
example_weights_extractor .ExampleWeightsExtractor (
598
602
eval_config = eval_config ),
599
- predictions_extractor .PredictionsExtractor (eval_config = eval_config ),
600
- sql_slice_key_extractor .SqlSliceKeyExtractor (eval_config ),
601
- unbatch_extractor .UnbatchExtractor (),
602
- slice_key_extractor .SliceKeyExtractor (
603
- eval_config = eval_config , materialize = materialize )
604
- ]
603
+ predictions_extractor .PredictionsExtractor (eval_config = eval_config )
604
+ ] + slicing_extractors
605
605
606
606
607
607
def default_evaluators ( # pylint: disable=invalid-name
0 commit comments