Skip to content

Commit b06e87b

Browse files
iindyk
authored and
tf-transform-team
committed
Typing cleanup:
- Changing input types of `tf_utils.construct_and_lookup_table` to tf.Tensor since it's not currently used or tested with composite tensors (it is applied on flat values for composite inputs to mappers).
- `common_types.ConsistentTensorType` is not currently used anywhere, so renaming `common_types.ConsistentInputTensorType` to `common_types.ConsistentTensorType` and removing the unused type.
- Renaming `common_types.InputTensorType` to `common_types.TensorType`.

PiperOrigin-RevId: 404825689
1 parent 35df75a commit b06e87b

11 files changed

+140
-148
lines changed

tensorflow_transform/analyzers.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
def _apply_cacheable_combiner(
114114
combiner: analyzer_nodes.Combiner,
115-
*tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
115+
*tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
116116
"""Applies the combiner over the whole dataset possibly utilizing cache."""
117117
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
118118
tensor_inputs)
@@ -137,7 +137,7 @@ def _apply_cacheable_combiner(
137137

138138
def _apply_cacheable_combiner_per_key(
139139
combiner: analyzer_nodes.Combiner,
140-
*tensor_inputs: common_types.InputTensorType) -> Tuple[tf.Tensor, ...]:
140+
*tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]:
141141
"""Similar to _apply_cacheable_combiner but this is computed per key."""
142142
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
143143
tensor_inputs)
@@ -162,7 +162,7 @@ def _apply_cacheable_combiner_per_key(
162162

163163
def _apply_cacheable_combiner_per_key_large(
164164
combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str,
165-
*tensor_inputs: common_types.InputTensorType
165+
*tensor_inputs: common_types.TensorType
166166
) -> Union[tf.Tensor, common_types.Asset]:
167167
"""Similar to above but saves the combined result to a file."""
168168
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
@@ -382,7 +382,7 @@ def _numeric_combine(inputs: List[tf.Tensor],
382382

383383
@common.log_api_use(common.ANALYZER_COLLECTION)
384384
def min( # pylint: disable=redefined-builtin
385-
x: common_types.InputTensorType,
385+
x: common_types.TensorType,
386386
reduce_instance_dims: bool = True,
387387
name: Optional[str] = None) -> tf.Tensor:
388388
"""Computes the minimum of the values of a `Tensor` over the whole dataset.
@@ -409,7 +409,7 @@ def min( # pylint: disable=redefined-builtin
409409

410410
@common.log_api_use(common.ANALYZER_COLLECTION)
411411
def max( # pylint: disable=redefined-builtin
412-
x: common_types.InputTensorType,
412+
x: common_types.TensorType,
413413
reduce_instance_dims: bool = True,
414414
name: Optional[str] = None) -> tf.Tensor:
415415
"""Computes the maximum of the values of a `Tensor` over the whole dataset.
@@ -433,7 +433,7 @@ def max( # pylint: disable=redefined-builtin
433433
return _min_and_max(x, reduce_instance_dims, name)[1]
434434

435435

436-
def _min_and_max(x: common_types.InputTensorType,
436+
def _min_and_max(x: common_types.TensorType,
437437
reduce_instance_dims: bool = True,
438438
name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
439439
"""Computes the min and max of the values of a `Tensor` or `CompositeTensor`.
@@ -482,8 +482,8 @@ def _min_and_max(x: common_types.InputTensorType,
482482

483483

484484
def _min_and_max_per_key(
485-
x: common_types.InputTensorType,
486-
key: common_types.InputTensorType,
485+
x: common_types.TensorType,
486+
key: common_types.TensorType,
487487
reduce_instance_dims: bool = True,
488488
key_vocabulary_filename: Optional[str] = None,
489489
name: Optional[str] = None
@@ -577,7 +577,7 @@ def _sum_combine_fn_and_dtype(
577577

578578
@common.log_api_use(common.ANALYZER_COLLECTION)
579579
def sum( # pylint: disable=redefined-builtin
580-
x: common_types.InputTensorType,
580+
x: common_types.TensorType,
581581
reduce_instance_dims: bool = True,
582582
name: Optional[str] = None) -> tf.Tensor:
583583
"""Computes the sum of the values of a `Tensor` over the whole dataset.
@@ -629,7 +629,7 @@ def sum( # pylint: disable=redefined-builtin
629629

630630

631631
@common.log_api_use(common.ANALYZER_COLLECTION)
632-
def histogram(x: common_types.InputTensorType,
632+
def histogram(x: common_types.TensorType,
633633
boundaries: Optional[Union[tf.Tensor, int]] = None,
634634
categorical: Optional[bool] = False,
635635
name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -702,7 +702,7 @@ def histogram(x: common_types.InputTensorType,
702702

703703

704704
@common.log_api_use(common.ANALYZER_COLLECTION)
705-
def size(x: common_types.InputTensorType,
705+
def size(x: common_types.TensorType,
706706
reduce_instance_dims: bool = True,
707707
name: Optional[str] = None) -> tf.Tensor:
708708
"""Computes the total size of instances in a `Tensor` over the whole dataset.
@@ -730,7 +730,7 @@ def size(x: common_types.InputTensorType,
730730

731731

732732
@common.log_api_use(common.ANALYZER_COLLECTION)
733-
def count_per_key(key: common_types.InputTensorType,
733+
def count_per_key(key: common_types.TensorType,
734734
key_vocabulary_filename: Optional[str] = None,
735735
name: Optional[str] = None):
736736
"""Computes the count of each element of a `Tensor`.
@@ -779,7 +779,7 @@ def count_per_key(key: common_types.InputTensorType,
779779

780780

781781
@common.log_api_use(common.ANALYZER_COLLECTION)
782-
def mean(x: common_types.InputTensorType,
782+
def mean(x: common_types.TensorType,
783783
reduce_instance_dims: bool = True,
784784
name: Optional[str] = None,
785785
output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
@@ -807,7 +807,7 @@ def mean(x: common_types.InputTensorType,
807807

808808

809809
@common.log_api_use(common.ANALYZER_COLLECTION)
810-
def var(x: common_types.InputTensorType,
810+
def var(x: common_types.TensorType,
811811
reduce_instance_dims: bool = True,
812812
name: Optional[str] = None,
813813
output_dtype: Optional[tf.DType] = None) -> tf.Tensor:
@@ -837,7 +837,7 @@ def var(x: common_types.InputTensorType,
837837
return _mean_and_var(x, reduce_instance_dims, output_dtype)[1]
838838

839839

840-
def _mean_and_var(x: common_types.InputTensorType,
840+
def _mean_and_var(x: common_types.TensorType,
841841
reduce_instance_dims: bool = True,
842842
output_dtype: Optional[tf.DType] = None):
843843
"""More efficient combined `mean` and `var`. See `var`."""
@@ -876,7 +876,7 @@ def _mean_and_var(x: common_types.InputTensorType,
876876

877877

878878
@common.log_api_use(common.ANALYZER_COLLECTION)
879-
def tukey_location(x: common_types.InputTensorType,
879+
def tukey_location(x: common_types.TensorType,
880880
reduce_instance_dims: Optional[bool] = True,
881881
output_dtype: Optional[tf.DType] = None,
882882
name: Optional[str] = None) -> tf.Tensor:
@@ -913,7 +913,7 @@ def tukey_location(x: common_types.InputTensorType,
913913

914914

915915
@common.log_api_use(common.ANALYZER_COLLECTION)
916-
def tukey_scale(x: common_types.InputTensorType,
916+
def tukey_scale(x: common_types.TensorType,
917917
reduce_instance_dims: Optional[bool] = True,
918918
output_dtype: Optional[tf.DType] = None,
919919
name: Optional[str] = None) -> tf.Tensor:
@@ -951,7 +951,7 @@ def tukey_scale(x: common_types.InputTensorType,
951951

952952

953953
@common.log_api_use(common.ANALYZER_COLLECTION)
954-
def tukey_h_params(x: common_types.InputTensorType,
954+
def tukey_h_params(x: common_types.TensorType,
955955
reduce_instance_dims: bool = True,
956956
output_dtype: Optional[tf.DType] = None,
957957
name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -988,7 +988,7 @@ def tukey_h_params(x: common_types.InputTensorType,
988988

989989

990990
def _tukey_parameters(
991-
x: common_types.InputTensorType,
991+
x: common_types.TensorType,
992992
reduce_instance_dims: bool = True,
993993
output_dtype: Optional[tf.DType] = None
994994
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]:
@@ -1027,8 +1027,8 @@ def _tukey_parameters(
10271027

10281028

10291029
def _mean_and_var_per_key(
1030-
x: common_types.InputTensorType,
1031-
key: common_types.InputTensorType,
1030+
x: common_types.TensorType,
1031+
key: common_types.TensorType,
10321032
reduce_instance_dims: bool = True,
10331033
output_dtype: Optional[tf.DType] = None,
10341034
key_vocabulary_filename: Optional[str] = None
@@ -1652,7 +1652,7 @@ def _register_vocab(sanitized_filename: str,
16521652
# https://github.com/tensorflow/community/blob/master/rfcs/20190116-embedding-partitioned-variable.md#goals
16531653
@common.log_api_use(common.ANALYZER_COLLECTION)
16541654
def vocabulary(
1655-
x: common_types.InputTensorType,
1655+
x: common_types.TensorType,
16561656
top_k: Optional[int] = None,
16571657
frequency_threshold: Optional[int] = None,
16581658
vocab_filename: Optional[str] = None,

tensorflow_transform/common_types.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,14 @@
4545

4646
DomainType = Union[schema_pb2.IntDomain, schema_pb2.FloatDomain,
4747
schema_pb2.StringDomain]
48-
InputTensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
48+
TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]
49+
ConsistentTensorType = TypeVar('ConsistentTensorType', tf.Tensor,
50+
tf.SparseTensor, tf.RaggedTensor)
4951
SparseTensorValueType = Union[tf.SparseTensor, tf.compat.v1.SparseTensorValue]
5052
RaggedTensorValueType = Union[tf.RaggedTensor,
5153
tf.compat.v1.ragged.RaggedTensorValue]
5254
TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType,
5355
RaggedTensorValueType]
54-
ConsistentInputTensorType = TypeVar('ConsistentInputTensorType', tf.Tensor,
55-
tf.SparseTensor, tf.RaggedTensor)
56-
ConsistentTensorType = TypeVar('ConsistentTensorType', tf.Tensor,
57-
tf.SparseTensor)
5856
TemporaryAnalyzerOutputType = Union[tf.Tensor, Asset]
5957
VocabularyFileFormatType = Literal['text', 'tfrecord_gzip']
6058

tensorflow_transform/experimental/analyzers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535

3636
def _apply_analyzer(analyzer_def_cls: Type[analyzer_nodes.AnalyzerDef],
37-
*tensor_inputs: common_types.InputTensorType,
37+
*tensor_inputs: common_types.TensorType,
3838
**analyzer_def_kwargs: Any) -> Tuple[tf.Tensor, ...]:
3939
"""Applies the analyzer over the whole dataset.
4040

tensorflow_transform/graph_tools.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -903,8 +903,7 @@ def validate_value(self, value):
903903

904904

905905
def get_analyzers_fingerprint(
906-
graph: tf.Graph,
907-
structured_inputs: Mapping[str, common_types.InputTensorType]
906+
graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType]
908907
) -> Mapping[str, Set[bytes]]:
909908
"""Computes fingerprints for all analyzers in `graph`.
910909

tensorflow_transform/impl_helper.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ def _check_valid_sparse_tensor(indices: Union[_CompositeComponentType,
595595
# `preprocessing_fn` using tf.function as is and another that will return
596596
# specific outputs requested for.
597597
def get_traced_transform_fn(
598-
preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
599-
Mapping[str, common_types.InputTensorType]],
598+
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
599+
Mapping[str, common_types.TensorType]],
600600
input_signature: Mapping[str, tf.TypeSpec],
601601
tf_graph_context: graph_context.TFGraphContext,
602602
output_keys_to_name_map: Optional[Dict[str,
@@ -720,8 +720,8 @@ def trace_preprocessing_function(preprocessing_fn,
720720

721721
def _trace_and_write_transform_fn(
722722
saved_model_dir: str,
723-
preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
724-
Mapping[str, common_types.InputTensorType]],
723+
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
724+
Mapping[str, common_types.TensorType]],
725725
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
726726
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],
727727
output_keys_to_name_map: Optional[Dict[str,
@@ -743,9 +743,9 @@ def _trace_and_write_transform_fn(
743743

744744
def _trace_and_get_metadata(
745745
concrete_transform_fn: function.ConcreteFunction,
746-
structured_inputs: Mapping[str, common_types.InputTensorType],
747-
preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
748-
Mapping[str, common_types.InputTensorType]],
746+
structured_inputs: Mapping[str, common_types.TensorType],
747+
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
748+
Mapping[str, common_types.TensorType]],
749749
base_temp_dir: Optional[str],
750750
tensor_replacement_map: Optional[Dict[str, tf.Tensor]]
751751
) -> dataset_metadata.DatasetMetadata:
@@ -768,7 +768,7 @@ def _trace_and_get_metadata(
768768

769769
def _validate_analyzers_fingerprint(
770770
baseline_analyzers_fingerprint: Mapping[str, Set[bytes]], graph: tf.Graph,
771-
structured_inputs: Mapping[str, common_types.InputTensorType]):
771+
structured_inputs: Mapping[str, common_types.TensorType]):
772772
"""Validates analyzers fingerprint in `graph` is same as baseline."""
773773
analyzers_fingerprint = graph_tools.get_analyzers_fingerprint(
774774
graph, structured_inputs)
@@ -787,8 +787,8 @@ def _validate_analyzers_fingerprint(
787787

788788
def trace_and_write_v2_saved_model(
789789
saved_model_dir: str,
790-
preprocessing_fn: Callable[[Mapping[str, common_types.InputTensorType]],
791-
Mapping[str, common_types.InputTensorType]],
790+
preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]],
791+
Mapping[str, common_types.TensorType]],
792792
input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str],
793793
baseline_analyzers_fingerprint: Mapping[str, Set[bytes]],
794794
tensor_replacement_map: Optional[Dict[str, tf.Tensor]],

0 commit comments

Comments (0)