From 535469872289e4788b493a73fc636814633a8b5e Mon Sep 17 00:00:00 2001 From: andrewfulton9 Date: Tue, 13 May 2025 14:07:37 -0600 Subject: [PATCH 1/4] setup pre-commit --- .github/workflows/ci-lint.yml | 21 +++++++++ .pre-commit-config.yaml | 39 ++++++++++++++++ pyproject.toml | 85 +++++++++++++++++++++++++++++++++++ setup.py | 3 ++ 4 files changed, 148 insertions(+) create mode 100644 .github/workflows/ci-lint.yml create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.github/workflows/ci-lint.yml b/.github/workflows/ci-lint.yml new file mode 100644 index 0000000..dede434 --- /dev/null +++ b/.github/workflows/ci-lint.yml @@ -0,0 +1,21 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [master] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4.1.7 + with: + # Ensure the full history is fetched + # This is required to run pre-commit on a specific set of commits + # TODO: Remove this when all the pre-commit issues are fixed + fetch-depth: 0 + - uses: actions/setup-python@v5.1.1 + with: + python-version: 3.13 + - uses: pre-commit/action@v3.0.1 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..387a3ef --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +# pre-commit is a tool to perform a predefined set of tasks manually and/or +# automatically before git commits are made. +# +# Config reference: https://pre-commit.com/#pre-commit-configyaml---top-level +# +# Common tasks +# +# - Register git hooks: pre-commit install --install-hooks +# - Run on all files: pre-commit run --all-files +# +# These pre-commit hooks are run as CI. +# +# NOTE: if it can be avoided, add configs/args in pyproject.toml or below instead of creating a new `.config.file`. +# https://pre-commit.ci/#configuration +ci: + autoupdate_schedule: monthly + autofix_commit_msg: | + [pre-commit.ci] Apply automatic pre-commit fixes + +repos: + # general + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: end-of-file-fixer + exclude: '\.svg$' + - id: trailing-whitespace + exclude: '\.svg$' + - id: check-json + - id: check-yaml + args: [--allow-multiple-documents, --unsafe] + - id: check-toml + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.5.6 + hooks: + - id: ruff + args: ["--fix"] + - id: ruff-format diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..10cdddd --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,85 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +[build-system] +requires = [ + "setuptools", + "wheel", +] + +[tool.ruff] +line-length = 88 + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + "W", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # pep8 naming + "N", + # pydocstyle + "D", + # annotations + "ANN", + # debugger + "T10", + # flake8-pytest + "PT", + # flake8-return + "RET", + # flake8-unused-arguments + "ARG", + # flake8-fixme + "FIX", + # flake8-eradicate + "ERA", + # pandas-vet + "PD", + # numpy-specific rules + "NPY", +] + +ignore = [ + "D104", # Missing docstring in public package + "D100", # Missing docstring in public module + "D211", # No blank line before class + "PD901", # Avoid using 'df' for pandas dataframes. Perfectly fine in functions with limited scope + "ANN201", # Missing return type annotation for public function (makes no sense for NoneType return types...) + "ANN101", # Missing type annotation for `self` + "ANN204", # Missing return type annotation for special method + "ANN002", # Missing type annotation for `*args` + "ANN003", # Missing type annotation for `**kwargs` + "D105", # Missing docstring in magic method + "D203", # 1 blank line before after class docstring + "D204", # 1 blank line required after class docstring + "D413", # 1 blank line after parameters + "SIM108", # Simplify if/else to one line; not always clearer + "D206", # Docstrings should be indented with spaces; unnecessary when running ruff-format + "E501", # Line length too long; unnecessary when running ruff-format + "W191", # Indentation contains tabs; unnecessary when running ruff-format +] + + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/setup.py b/setup.py index 3f69292..fa75232 100644 --- a/setup.py +++ b/setup.py @@ -99,6 +99,9 @@ def _make_required_install_packages(): ], namespace_packages=[], install_requires=_make_required_install_packages(), + extras_require={ + "dev": ["pre-commit"], + }, python_requires='>=3.9,<4', packages=find_packages(), include_package_data=True, From dacf5bb190b80eb42ccb0141df271d80e8c11e61 Mon Sep 17 00:00:00 2001 From: andrewfulton9 Date: Wed, 14 May 2025 10:55:24 -0600 Subject: [PATCH 2/4] pre-commit auto fixes --- RELEASE.md | 3 +- docs/build_tft_beam_docs.py | 64 +- docs/build_tft_docs.py | 53 +- examples/census_example.py | 312 +- examples/census_example_v2.py | 898 +- examples/census_example_v2_test.py | 260 +- examples/dataset_tfxio_example.py | 80 +- examples/dataset_tfxio_example_test.py | 56 +- examples/local_model_server.py | 16 +- examples/sentiment_example_v2.py | 660 +- examples/sentiment_example_v2_test.py | 136 +- examples/simple_example.py | 76 +- examples/simple_example_test.py | 50 +- examples/simple_sequence_example.py | 269 +- examples/simple_sequence_example_test.py | 103 +- setup.py | 151 +- tensorflow_transform/__init__.py | 21 +- tensorflow_transform/analyzer_nodes.py | 1933 ++-- tensorflow_transform/analyzers.py | 4810 ++++---- tensorflow_transform/analyzers_test.py | 838 +- tensorflow_transform/annotators.py | 351 +- tensorflow_transform/annotators_test.py | 77 +- tensorflow_transform/beam/__init__.py | 20 +- .../beam/analysis_graph_builder.py | 1349 +-- .../beam/analysis_graph_builder_test.py | 546 +- tensorflow_transform/beam/analyzer_cache.py | 672 +- .../beam/analyzer_cache_test.py | 610 +- tensorflow_transform/beam/analyzer_impls.py | 2503 +++-- .../beam/analyzer_impls_test.py | 249 +- tensorflow_transform/beam/annotators_test.py | 256 +- 
tensorflow_transform/beam/beam_nodes.py | 264 +- .../beam/bucketize_integration_test.py | 1740 +-- tensorflow_transform/beam/cached_impl_test.py | 2785 ++--- .../beam/combiner_packing_util.py | 927 +- .../beam/combiner_packing_util_test.py | 174 +- tensorflow_transform/beam/common.py | 322 +- tensorflow_transform/beam/context.py | 337 +- tensorflow_transform/beam/context_test.py | 45 +- tensorflow_transform/beam/deep_copy.py | 446 +- tensorflow_transform/beam/deep_copy_test.py | 652 +- .../beam/experimental/analyzer_impls.py | 19 +- tensorflow_transform/beam/impl.py | 2848 ++--- .../beam/impl_output_record_batches_test.py | 336 +- tensorflow_transform/beam/impl_test.py | 9897 +++++++++-------- tensorflow_transform/beam/test_helpers.py | 7 +- .../beam/tft_beam_io/__init__.py | 6 +- .../beam/tft_beam_io/beam_metadata_io.py | 129 +- .../beam/tft_beam_io/beam_metadata_io_test.py | 140 +- .../beam/tft_beam_io/test_metadata.py | 14 +- .../beam/tft_beam_io/transform_fn_io.py | 229 +- .../beam/tft_beam_io/transform_fn_io_test.py | 274 +- tensorflow_transform/beam/tft_unit.py | 766 +- .../beam/tukey_hh_params_integration_test.py | 1562 ++- .../beam/vocabulary_integration_test.py | 4491 ++++---- ...cabulary_tfrecord_gzip_integration_test.py | 15 +- tensorflow_transform/coders/csv_coder.py | 476 +- tensorflow_transform/coders/csv_coder_test.py | 432 +- .../coders/example_proto_coder.py | 435 +- .../coders/example_proto_coder_test.py | 479 +- tensorflow_transform/common.py | 87 +- tensorflow_transform/common_test.py | 80 +- tensorflow_transform/common_types.py | 32 +- .../experimental/analyzers.py | 981 +- .../experimental/annotators.py | 200 +- tensorflow_transform/experimental/mappers.py | 696 +- tensorflow_transform/gaussianization.py | 673 +- tensorflow_transform/gaussianization_test.py | 466 +- tensorflow_transform/graph_context.py | 207 +- tensorflow_transform/graph_tools.py | 1729 +-- tensorflow_transform/graph_tools_test.py | 2384 ++-- tensorflow_transform/impl_helper.py | 1153 +- tensorflow_transform/impl_helper_test.py | 1619 +-- tensorflow_transform/info_theory.py | 203 +- tensorflow_transform/info_theory_test.py | 299 +- .../inspect_preprocessing_fn.py | 152 +- .../inspect_preprocessing_fn_test.py | 254 +- tensorflow_transform/keras_lib.py | 45 +- tensorflow_transform/mappers.py | 4115 +++---- tensorflow_transform/mappers_test.py | 2121 ++-- tensorflow_transform/nodes.py | 596 +- tensorflow_transform/nodes_test.py | 369 +- tensorflow_transform/output_wrapper.py | 1017 +- tensorflow_transform/pickle_helper.py | 42 +- tensorflow_transform/pretrained_models.py | 525 +- .../pretrained_models_test.py | 281 +- tensorflow_transform/py.typed | 2 +- tensorflow_transform/py_func/__init__.py | 4 +- tensorflow_transform/py_func/api.py | 89 +- tensorflow_transform/py_func/pyfunc_helper.py | 252 +- tensorflow_transform/saved/constants.py | 4 +- .../saved/saved_model_loader.py | 103 +- .../saved/saved_model_loader_test.py | 36 +- .../saved/saved_transform_io.py | 845 +- .../saved/saved_transform_io_test.py | 585 +- .../saved/saved_transform_io_v2.py | 996 +- .../saved/saved_transform_io_v2_test.py | 1231 +- tensorflow_transform/schema_inference.py | 1497 +-- tensorflow_transform/schema_inference_test.py | 602 +- tensorflow_transform/test_case.py | 569 +- tensorflow_transform/test_case_test.py | 116 +- tensorflow_transform/tf2_utils.py | 352 +- tensorflow_transform/tf2_utils_test.py | 157 +- .../tf_metadata/dataset_metadata.py | 67 +- .../tf_metadata/dataset_metadata_test.py | 18 +- 
.../tf_metadata/metadata_io.py | 175 +- .../tf_metadata/metadata_io_test.py | 85 +- .../tf_metadata/schema_utils.py | 986 +- .../tf_metadata/schema_utils_legacy.py | 15 +- .../tf_metadata/schema_utils_test.py | 139 +- .../tf_metadata/schema_utils_test_cases.py | 745 +- .../tf_metadata/test_common.py | 25 +- tensorflow_transform/tf_utils.py | 3789 ++++--- tensorflow_transform/tf_utils_test.py | 5473 +++++---- tensorflow_transform/version.py | 2 +- 114 files changed, 46058 insertions(+), 40896 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 816cb10..2733322 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -45,7 +45,7 @@ * Enable passing `tf.saved_model.SaveOptions` to model saving functionality. * Census and sentiment examples updated to only use Keras instead of estimator. -* Depends on `apache-beam[gcp]>=2.53.0,<3` for Python 3.11 and on +* Depends on `apache-beam[gcp]>=2.53.0,<3` for Python 3.11 and on `apache-beam[gcp]>=2.47.0,<3` for 3.9 and 3.10. * Depends on `protobuf>=4.25.2,<5` for Python 3.11 and on `protobuf>3.20.3,<5` for 3.9 and 3.10. @@ -1513,4 +1513,3 @@ the generated vocab_filename on a downstream component. * Update tensorflow_transform to use `tf.saved_model` APIs. * Add default values on example proto coder. * Various performance and stability improvements. - diff --git a/docs/build_tft_beam_docs.py b/docs/build_tft_beam_docs.py index 58d6ec1..f94dba0 100644 --- a/docs/build_tft_beam_docs.py +++ b/docs/build_tft_beam_docs.py @@ -25,51 +25,53 @@ ``` """ -from absl import app -from absl import flags -from tensorflow_docs.api_generator import doc_controls -from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api -import tensorflow_transform.beam as tft_beam +from absl import app, flags +from tensorflow_docs.api_generator import doc_controls, generate_lib, public_api +import tensorflow_transform.beam as tft_beam -flags.DEFINE_string('output_dir', '/tmp/tft_beam_api/', - 'The path to output the files to') +flags.DEFINE_string( + "output_dir", "/tmp/tft_beam_api/", "The path to output the files to" +) flags.DEFINE_string( - 'code_url_prefix', - 'https://github.com/tensorflow/transform/tree/master/tensorflow_transform', - 'The url prefix for links to code.') + "code_url_prefix", + "https://github.com/tensorflow/transform/tree/master/tensorflow_transform", + "The url prefix for links to code.", +) -flags.DEFINE_bool('search_hints', True, - 'Include metadata search hints in the generated files') +flags.DEFINE_bool( + "search_hints", True, "Include metadata search hints in the generated files" +) -flags.DEFINE_string('site_path', 'tfx/transform/api_docs/python', - 'Path prefix in the _toc.yaml') +flags.DEFINE_string( + "site_path", "tfx/transform/api_docs/python", "Path prefix in the _toc.yaml" +) FLAGS = flags.FLAGS def main(args): - if args[1:]: - raise ValueError('Unrecognized Command line args', args[1:]) + if args[1:]: + raise ValueError("Unrecognized Command line args", args[1:]) - doc_controls.do_not_generate_docs(tft_beam.analyzer_impls) + doc_controls.do_not_generate_docs(tft_beam.analyzer_impls) - doc_generator = generate_lib.DocGenerator( - root_title='TFT-Beam', - py_modules=[('tft_beam', tft_beam)], - code_url_prefix=FLAGS.code_url_prefix + '/beam', - search_hints=FLAGS.search_hints, - site_path=FLAGS.site_path, - callbacks=[ - public_api.explicit_package_contents_filter, - public_api.local_definitions_filter - ]) + doc_generator = generate_lib.DocGenerator( + root_title="TFT-Beam", + py_modules=[("tft_beam", 
tft_beam)], + code_url_prefix=FLAGS.code_url_prefix + "/beam", + search_hints=FLAGS.search_hints, + site_path=FLAGS.site_path, + callbacks=[ + public_api.explicit_package_contents_filter, + public_api.local_definitions_filter, + ], + ) - doc_generator.build(FLAGS.output_dir) + doc_generator.build(FLAGS.output_dir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/docs/build_tft_docs.py b/docs/build_tft_docs.py index 1685d77..61227e9 100644 --- a/docs/build_tft_docs.py +++ b/docs/build_tft_docs.py @@ -25,45 +25,46 @@ ``` """ -from absl import app -from absl import flags -from tensorflow_docs.api_generator import generate_lib -from tensorflow_docs.api_generator import public_api -import tensorflow_transform as transform +from absl import app, flags +from tensorflow_docs.api_generator import generate_lib, public_api +import tensorflow_transform as transform -flags.DEFINE_string('output_dir', '/tmp/tft_api/', - 'The path to output the files to') +flags.DEFINE_string("output_dir", "/tmp/tft_api/", "The path to output the files to") flags.DEFINE_string( - 'code_url_prefix', - 'https://github.com/tensorflow/transform/tree/master/tensorflow_transform', - 'The url prefix for links to code.') + "code_url_prefix", + "https://github.com/tensorflow/transform/tree/master/tensorflow_transform", + "The url prefix for links to code.", +) -flags.DEFINE_bool('search_hints', True, - 'Include metadata search hints in the generated files') +flags.DEFINE_bool( + "search_hints", True, "Include metadata search hints in the generated files" +) -flags.DEFINE_string('site_path', 'tfx/transform/api_docs/python', - 'Path prefix in the _toc.yaml') +flags.DEFINE_string( + "site_path", "tfx/transform/api_docs/python", "Path prefix in the _toc.yaml" +) FLAGS = flags.FLAGS def main(args): - if args[1:]: - raise ValueError('Unrecognized Command line args', args[1:]) + if args[1:]: + raise ValueError("Unrecognized Command line args", args[1:]) - doc_generator = generate_lib.DocGenerator( - root_title='TF-Transform', - py_modules=[('tft', transform)], - code_url_prefix=FLAGS.code_url_prefix, - search_hints=FLAGS.search_hints, - site_path=FLAGS.site_path, - callbacks=[public_api.explicit_package_contents_filter]) + doc_generator = generate_lib.DocGenerator( + root_title="TF-Transform", + py_modules=[("tft", transform)], + code_url_prefix=FLAGS.code_url_prefix, + search_hints=FLAGS.search_hints, + site_path=FLAGS.site_path, + callbacks=[public_api.explicit_package_contents_filter], + ) - doc_generator.build(FLAGS.output_dir) + doc_generator.build(FLAGS.output_dir) -if __name__ == '__main__': - app.run(main) +if __name__ == "__main__": + app.run(main) diff --git a/examples/census_example.py b/examples/census_example.py index 8a600ea..0d321bb 100644 --- a/examples/census_example.py +++ b/examples/census_example.py @@ -18,177 +18,201 @@ import pprint import tempfile +import census_example_common as common import tensorflow as tf from tensorflow import estimator as tf_estimator + import tensorflow_transform as tft -import census_example_common as common # Functions for training def _make_inputs_dense(transformed_features): - return { - k: tf.sparse.to_dense(v) if isinstance(v, tf.SparseTensor) else v - for k, v in transformed_features.items() - } + return { + k: tf.sparse.to_dense(v) if isinstance(v, tf.SparseTensor) else v + for k, v in transformed_features.items() + } + + # pylint: disable=g-deprecated-tf-checker -def _make_training_input_fn(tf_transform_output, 
transformed_examples, - batch_size): - """Creates an input function reading from transformed data. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - transformed_examples: Base filename of examples. - batch_size: Batch size. - - Returns: - The input function for training or eval. - """ - def input_fn(): - """Input function for training and eval.""" - dataset = tf.data.experimental.make_batched_features_dataset( - file_pattern=transformed_examples, - batch_size=batch_size, - features=tf_transform_output.transformed_feature_spec(), - reader=tf.data.TFRecordDataset, - shuffle=True) - - transformed_features = _make_inputs_dense( - tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() - ) +def _make_training_input_fn(tf_transform_output, transformed_examples, batch_size): + """Creates an input function reading from transformed data. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + transformed_examples: Base filename of examples. + batch_size: Batch size. + + Returns: + ------- + The input function for training or eval. + """ + + def input_fn(): + """Input function for training and eval.""" + dataset = tf.data.experimental.make_batched_features_dataset( + file_pattern=transformed_examples, + batch_size=batch_size, + features=tf_transform_output.transformed_feature_spec(), + reader=tf.data.TFRecordDataset, + shuffle=True, + ) - # Extract features and label from the transformed tensors. - # TODO(b/30367437): make transformed_labels a dict. - transformed_labels = tf.where( - tf.equal(transformed_features.pop(common.LABEL_KEY), 1)) + transformed_features = _make_inputs_dense( + tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() + ) - return transformed_features, transformed_labels[:, 1] + # Extract features and label from the transformed tensors. + # TODO(b/30367437): make transformed_labels a dict. + transformed_labels = tf.where( + tf.equal(transformed_features.pop(common.LABEL_KEY), 1) + ) - return input_fn + return transformed_features, transformed_labels[:, 1] + + return input_fn def _make_serving_input_fn(tf_transform_output): - """Creates an input function reading from raw data. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - - Returns: - The serving input function. - """ - raw_feature_spec = common.RAW_DATA_FEATURE_SPEC.copy() - # Remove label since it is not available during serving. - raw_feature_spec.pop(common.LABEL_KEY) - - def serving_input_fn(): - """Input function for serving.""" - # Get raw features by generating the basic serving input_fn and calling it. - # Here we generate an input_fn that expects a parsed Example proto to be fed - # to the model at serving time. See also - # tf.estimator.export.build_raw_serving_input_receiver_fn. - raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( - raw_feature_spec, default_batch_size=None) - serving_input_receiver = raw_input_fn() - - # Apply the transform function that was used to generate the materialized - # data. - raw_features = serving_input_receiver.features - transformed_features = _make_inputs_dense( - tf_transform_output.transform_raw_features(raw_features) + """Creates an input function reading from raw data. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + + Returns: + ------- + The serving input function. + """ + raw_feature_spec = common.RAW_DATA_FEATURE_SPEC.copy() + # Remove label since it is not available during serving. 
+ raw_feature_spec.pop(common.LABEL_KEY) + + def serving_input_fn(): + """Input function for serving.""" + # Get raw features by generating the basic serving input_fn and calling it. + # Here we generate an input_fn that expects a parsed Example proto to be fed + # to the model at serving time. See also + # tf.estimator.export.build_raw_serving_input_receiver_fn. + raw_input_fn = tf_estimator.export.build_parsing_serving_input_receiver_fn( + raw_feature_spec, default_batch_size=None + ) + serving_input_receiver = raw_input_fn() + + # Apply the transform function that was used to generate the materialized + # data. + raw_features = serving_input_receiver.features + transformed_features = _make_inputs_dense( + tf_transform_output.transform_raw_features(raw_features) + ) + + return tf_estimator.export.ServingInputReceiver( + transformed_features, serving_input_receiver.receiver_tensors + ) + + return serving_input_fn + + +def get_feature_columns(tf_transform_output): + """Returns the FeatureColumns for the model. + + Args: + ---- + tf_transform_output: A `TFTransformOutput` object. + + Returns: + ------- + A list of FeatureColumns. + """ + feature_spec = tf_transform_output.transformed_feature_spec() + + # Wrap scalars as real valued columns. + def get_shape(spec): + if isinstance(spec, tf.io.SparseFeature): + return spec.size + return spec.shape + + return [ + tf.feature_column.numeric_column(key, shape=get_shape(feature_spec[key])) + for key in (common.NUMERIC_FEATURE_KEYS + common.CATEGORICAL_FEATURE_KEYS) + ] + + +def train_and_evaluate( + working_dir, + num_train_instances=common.NUM_TRAIN_INSTANCES, + num_test_instances=common.NUM_TEST_INSTANCES, +): + """Train the model on training data and evaluate on test data. + + Args: + ---- + working_dir: Directory to read transformed data and metadata from and to + write exported model to. + num_train_instances: Number of instances in train set + num_test_instances: Number of instances in test set + + Returns: + ------- + The results from the estimator's 'evaluate' method + """ + tf_transform_output = tft.TFTransformOutput(working_dir) + + run_config = tf_estimator.RunConfig() + + estimator = tf_estimator.LinearClassifier( + feature_columns=get_feature_columns(tf_transform_output), + config=run_config, + loss_reduction=tf.losses.Reduction.SUM, ) - return tf_estimator.export.ServingInputReceiver( - transformed_features, serving_input_receiver.receiver_tensors) + # Fit the model using the default optimizer. + train_input_fn = _make_training_input_fn( + tf_transform_output, + os.path.join(working_dir, common.TRANSFORMED_TRAIN_DATA_FILEBASE + "*"), + batch_size=common.TRAIN_BATCH_SIZE, + ) + estimator.train( + input_fn=train_input_fn, + max_steps=common.TRAIN_NUM_EPOCHS + * num_train_instances + / common.TRAIN_BATCH_SIZE, + ) - return serving_input_fn + # Evaluate model on test dataset. + eval_input_fn = _make_training_input_fn( + tf_transform_output, + os.path.join(working_dir, common.TRANSFORMED_TEST_DATA_FILEBASE + "*"), + batch_size=1, + ) + # Export the model. + serving_input_fn = _make_serving_input_fn(tf_transform_output) + exported_model_dir = os.path.join(working_dir, common.EXPORTED_MODEL_DIR) + estimator.export_saved_model(exported_model_dir, serving_input_fn) -def get_feature_columns(tf_transform_output): - """Returns the FeatureColumns for the model. - - Args: - tf_transform_output: A `TFTransformOutput` object. - - Returns: - A list of FeatureColumns. 
- """ - feature_spec = tf_transform_output.transformed_feature_spec() - # Wrap scalars as real valued columns. - def get_shape(spec): - if isinstance(spec, tf.io.SparseFeature): - return spec.size - return spec.shape - - return [ - tf.feature_column.numeric_column(key, shape=get_shape(feature_spec[key])) - for key in (common.NUMERIC_FEATURE_KEYS + common.CATEGORICAL_FEATURE_KEYS) - ] - - -def train_and_evaluate(working_dir, - num_train_instances=common.NUM_TRAIN_INSTANCES, - num_test_instances=common.NUM_TEST_INSTANCES): - """Train the model on training data and evaluate on test data. - - Args: - working_dir: Directory to read transformed data and metadata from and to - write exported model to. - num_train_instances: Number of instances in train set - num_test_instances: Number of instances in test set - - Returns: - The results from the estimator's 'evaluate' method - """ - tf_transform_output = tft.TFTransformOutput(working_dir) - - run_config = tf_estimator.RunConfig() - - estimator = tf_estimator.LinearClassifier( - feature_columns=get_feature_columns(tf_transform_output), - config=run_config, - loss_reduction=tf.losses.Reduction.SUM) - - # Fit the model using the default optimizer. - train_input_fn = _make_training_input_fn( - tf_transform_output, - os.path.join(working_dir, common.TRANSFORMED_TRAIN_DATA_FILEBASE + '*'), - batch_size=common.TRAIN_BATCH_SIZE) - estimator.train( - input_fn=train_input_fn, - max_steps=common.TRAIN_NUM_EPOCHS * num_train_instances / - common.TRAIN_BATCH_SIZE) - - # Evaluate model on test dataset. - eval_input_fn = _make_training_input_fn( - tf_transform_output, - os.path.join(working_dir, common.TRANSFORMED_TEST_DATA_FILEBASE + '*'), - batch_size=1) - - # Export the model. - serving_input_fn = _make_serving_input_fn(tf_transform_output) - exported_model_dir = os.path.join(working_dir, common.EXPORTED_MODEL_DIR) - estimator.export_saved_model(exported_model_dir, serving_input_fn) - - return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances) + return estimator.evaluate(input_fn=eval_input_fn, steps=num_test_instances) def main(): - args = common.get_args() - if args.working_dir: - working_dir = args.working_dir - else: - working_dir = tempfile.mkdtemp(dir=args.input_data_dir) + args = common.get_args() + if args.working_dir: + working_dir = args.working_dir + else: + working_dir = tempfile.mkdtemp(dir=args.input_data_dir) + + train_data_file = os.path.join(args.input_data_dir, "adult.data") + test_data_file = os.path.join(args.input_data_dir, "adult.test") - train_data_file = os.path.join(args.input_data_dir, 'adult.data') - test_data_file = os.path.join(args.input_data_dir, 'adult.test') + common.transform_data(train_data_file, test_data_file, working_dir) - common.transform_data(train_data_file, test_data_file, working_dir) + results = train_and_evaluate(working_dir) - results = train_and_evaluate(working_dir) + pprint.pprint(results) - pprint.pprint(results) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/examples/census_example_v2.py b/examples/census_example_v2.py index 9992995..a469664 100644 --- a/examples/census_example_v2.py +++ b/examples/census_example_v2.py @@ -14,74 +14,68 @@ """Example using census data from UCI repository.""" # pylint: disable=g-bad-import-order +import argparse import math import os import pprint import tempfile -from absl import logging -import tensorflow as tf -import tensorflow_transform as tft -from tensorflow_transform.keras_lib import tf_keras -import argparse - 
import apache_beam as beam +import tensorflow as tf import tensorflow.compat.v2 as tf -import tensorflow_transform.beam as tft_beam +from absl import logging from tfx_bsl.public import tfxio +import tensorflow_transform as tft +import tensorflow_transform.beam as tft_beam +from tensorflow_transform.keras_lib import tf_keras + # Functions for training CATEGORICAL_FEATURE_KEYS = [ - 'workclass', - 'education', - 'marital-status', - 'occupation', - 'relationship', - 'race', - 'sex', - 'native-country', + "workclass", + "education", + "marital-status", + "occupation", + "relationship", + "race", + "sex", + "native-country", ] NUMERIC_FEATURE_KEYS = [ - 'age', - 'capital-gain', - 'capital-loss', - 'hours-per-week', + "age", + "capital-gain", + "capital-loss", + "hours-per-week", ] OPTIONAL_NUMERIC_FEATURE_KEYS = [ - 'education-num', + "education-num", ] -LABEL_KEY = 'label' +LABEL_KEY = "label" ORDERED_CSV_COLUMNS = [ - 'age', - 'workclass', - 'fnlwgt', - 'education', - 'education-num', - 'marital-status', - 'occupation', - 'relationship', - 'race', - 'sex', - 'capital-gain', - 'capital-loss', - 'hours-per-week', - 'native-country', - 'label', + "age", + "workclass", + "fnlwgt", + "education", + "education-num", + "marital-status", + "occupation", + "relationship", + "race", + "sex", + "capital-gain", + "capital-loss", + "hours-per-week", + "native-country", + "label", ] RAW_DATA_FEATURE_SPEC = dict( - [ - (name, tf.io.FixedLenFeature([], tf.string)) - for name in CATEGORICAL_FEATURE_KEYS - ] - + [ - (name, tf.io.FixedLenFeature([], tf.float32)) - for name in NUMERIC_FEATURE_KEYS - ] + [(name, tf.io.FixedLenFeature([], tf.string)) for name in CATEGORICAL_FEATURE_KEYS] + + [(name, tf.io.FixedLenFeature([], tf.float32)) for name in NUMERIC_FEATURE_KEYS] + [ ( name, # pylint: disable=g-complex-comprehension @@ -112,205 +106,200 @@ NUM_OOV_BUCKETS = 1 # Names of temp files -TRANSFORMED_TRAIN_DATA_FILEBASE = 'train_transformed' -TRANSFORMED_TEST_DATA_FILEBASE = 'test_transformed' -EXPORTED_MODEL_DIR = 'exported_model_dir' +TRANSFORMED_TRAIN_DATA_FILEBASE = "train_transformed" +TRANSFORMED_TEST_DATA_FILEBASE = "test_transformed" +EXPORTED_MODEL_DIR = "exported_model_dir" parser = argparse.ArgumentParser() +parser.add_argument("--input_data_dir", help="path to directory containing input data") parser.add_argument( - '--input_data_dir', help='path to directory containing input data' -) -parser.add_argument( - '--working_dir', help='optional, path to directory to hold transformed data' + "--working_dir", help="optional, path to directory to hold transformed data" ) def get_args(): - return parser.parse_args() + return parser.parse_args() # Functions for preprocessing def transform_data(train_data_file: str, test_data_file: str, working_dir: str): - """Transform the data and write out as a TFRecord of Example protos. - - Read in the data using the CSV reader, and transform it using a - preprocessing pipeline that scales numeric data and converts categorical data - from strings to int64 values indices, by creating a vocabulary for each - category. - - Args: - train_data_file: File containing training data - test_data_file: File containing test data - working_dir: Directory to write transformed data and metadata to - """ - - def preprocessing_fn(inputs): - """Preprocess input columns into transformed columns.""" - # Since we are modifying some features and leaving others unchanged, we - # start by setting `outputs` to a copy of `inputs. 
- outputs = inputs.copy() - - # Scale numeric columns to have range [0, 1]. - for key in NUMERIC_FEATURE_KEYS: - outputs[key] = tft.scale_to_0_1(inputs[key]) - - for key in OPTIONAL_NUMERIC_FEATURE_KEYS: - # This is a RaggedTensor because it is optional. Here we fill in a default - # value when it is missing, after scaling it. - outputs[key] = tft.scale_to_0_1(inputs[key]).to_tensor( - default_value=0.0, shape=[None, 1] - ) - - # For all categorical columns except the label column, we generate a - # vocabulary, and convert the string feature to a one-hot encoding. - for key in CATEGORICAL_FEATURE_KEYS: - integerized = tft.compute_and_apply_vocabulary( - tf.strings.strip(inputs[key]), - num_oov_buckets=NUM_OOV_BUCKETS, - vocab_filename=key, - ) - depth = ( - tft.experimental.get_vocabulary_size_by_name(key) + NUM_OOV_BUCKETS - ) - one_hot_encoded = tf.one_hot( - integerized, - depth=tf.cast(depth, tf.int32), - on_value=1, - off_value=0, - dtype=tf.int64, - ) - # Saving one-hot encoded outputs as sparse in order to avoid large dense - # (mostly empty) tensors. This is especially important when saving - # transformed data to disk. - outputs[key] = tf.sparse.from_dense( - tf.reshape(one_hot_encoded, [-1, depth]) - ) - tft.experimental.annotate_sparse_output_shape(outputs[key], depth) - - # For the label column we provide the mapping from string to index. - table_keys = ['>50K', '<=50K'] - with tf.init_scope(): - initializer = tf.lookup.KeyValueTensorInitializer( - keys=table_keys, - values=tf.cast(tf.range(len(table_keys)), tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64, - ) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - # Remove trailing periods for test data when the data is read with tf.data. - label_str = tf.strings.regex_replace(inputs[LABEL_KEY], r'\.', '') - label_str = tf.strings.strip(label_str) - data_labels = table.lookup(label_str) - transformed_label = tf.one_hot( - indices=data_labels, depth=len(table_keys), on_value=1.0, off_value=0.0 - ) - outputs[LABEL_KEY] = tf.reshape(transformed_label, [-1, len(table_keys)]) - - return outputs - - # The "with" block will create a pipeline, and run that pipeline at the exit - # of the block. - with beam.Pipeline() as pipeline: - with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - # Create a TFXIO to read the census data with the schema. To do this we - # need to list all columns in order since the schema doesn't specify the - # order of columns in the csv. - # We first read CSV files and use BeamRecordCsvTFXIO whose .BeamSource() - # accepts a PCollection[bytes] because we need to patch the records first - # (see "FixCommasTrainData" below). Otherwise, tfxio.CsvTFXIO can be used - # to both read the CSV files and parse them to TFT inputs: - # csv_tfxio = tfxio.CsvTFXIO(...) - # raw_data = (pipeline | 'ToRecordBatches' >> csv_tfxio.BeamSource()) - csv_tfxio = tfxio.BeamRecordCsvTFXIO( - physical_format='text', - column_names=ORDERED_CSV_COLUMNS, - schema=_SCHEMA, - ) - - # Read in raw data and convert using CSV TFXIO. Note that we apply - # some Beam transformations here, which will not be encoded in the TF - # graph since we don't do the from within tf.Transform's methods - # (AnalyzeDataset, TransformDataset etc.). These transformations are just - # to get data into a format that the CSV TFXIO can read, in particular - # removing spaces after commas. 
- raw_data = ( - pipeline - | 'ReadTrainData' - >> beam.io.ReadFromText( - train_data_file, coder=beam.coders.BytesCoder() - ) - | 'FixCommasTrainData' - >> beam.Map(lambda line: line.replace(b', ', b',')) - | 'DecodeTrainData' >> csv_tfxio.BeamSource() - ) - - # Combine data and schema into a dataset tuple. Note that we already used - # the schema to read the CSV data, but we also need it to interpret - # raw_data. - raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig()) - - # The TFXIO output format is chosen for improved performance. - transformed_dataset, transform_fn = ( - raw_dataset - | tft_beam.AnalyzeAndTransformDataset( - preprocessing_fn, output_record_batches=True - ) - ) - - # Extract transformed RecordBatches, encode and write them to the given - # directory. - _ = ( - transformed_dataset - | 'EncodeTrainData' >> tft_beam.EncodeTransformedDataset() - | 'WriteTrainData' - >> beam.io.WriteToTFRecord( - os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE) - ) - ) - - # Now apply transform function to test data. In this case we remove the - # trailing period at the end of each line, and also ignore the header line - # that is present in the test data file. - raw_test_data = ( - pipeline - | 'ReadTestData' - >> beam.io.ReadFromText( - test_data_file, - skip_header_lines=1, - coder=beam.coders.BytesCoder(), - ) - | 'FixCommasTestData' - >> beam.Map(lambda line: line.replace(b', ', b',')) - | 'RemoveTrailingPeriodsTestData' >> beam.Map(lambda line: line[:-1]) - | 'DecodeTestData' >> csv_tfxio.BeamSource() - ) - - raw_test_dataset = (raw_test_data, csv_tfxio.TensorAdapterConfig()) - - # The TFXIO output format is chosen for improved performance. - transformed_test_dataset = ( - raw_test_dataset, - transform_fn, - ) | tft_beam.TransformDataset(output_record_batches=True) - - # Extract transformed RecordBatches, encode and write them to the given - # directory. - _ = ( - transformed_test_dataset - | 'EncodeTestData' >> tft_beam.EncodeTransformedDataset() - | 'WriteTestData' - >> beam.io.WriteToTFRecord( - os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE) - ) - ) - - # Will write a SavedModel and metadata to working_dir, which can then - # be read by the tft.TFTransformOutput class. - _ = transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn( - working_dir - ) + """Transform the data and write out as a TFRecord of Example protos. + + Read in the data using the CSV reader, and transform it using a + preprocessing pipeline that scales numeric data and converts categorical data + from strings to int64 values indices, by creating a vocabulary for each + category. + + Args: + ---- + train_data_file: File containing training data + test_data_file: File containing test data + working_dir: Directory to write transformed data and metadata to + """ + + def preprocessing_fn(inputs): + """Preprocess input columns into transformed columns.""" + # Since we are modifying some features and leaving others unchanged, we + # start by setting `outputs` to a copy of `inputs. + outputs = inputs.copy() + + # Scale numeric columns to have range [0, 1]. + for key in NUMERIC_FEATURE_KEYS: + outputs[key] = tft.scale_to_0_1(inputs[key]) + + for key in OPTIONAL_NUMERIC_FEATURE_KEYS: + # This is a RaggedTensor because it is optional. Here we fill in a default + # value when it is missing, after scaling it. 
+ outputs[key] = tft.scale_to_0_1(inputs[key]).to_tensor( + default_value=0.0, shape=[None, 1] + ) + + # For all categorical columns except the label column, we generate a + # vocabulary, and convert the string feature to a one-hot encoding. + for key in CATEGORICAL_FEATURE_KEYS: + integerized = tft.compute_and_apply_vocabulary( + tf.strings.strip(inputs[key]), + num_oov_buckets=NUM_OOV_BUCKETS, + vocab_filename=key, + ) + depth = tft.experimental.get_vocabulary_size_by_name(key) + NUM_OOV_BUCKETS + one_hot_encoded = tf.one_hot( + integerized, + depth=tf.cast(depth, tf.int32), + on_value=1, + off_value=0, + dtype=tf.int64, + ) + # Saving one-hot encoded outputs as sparse in order to avoid large dense + # (mostly empty) tensors. This is especially important when saving + # transformed data to disk. + outputs[key] = tf.sparse.from_dense( + tf.reshape(one_hot_encoded, [-1, depth]) + ) + tft.experimental.annotate_sparse_output_shape(outputs[key], depth) + + # For the label column we provide the mapping from string to index. + table_keys = [">50K", "<=50K"] + with tf.init_scope(): + initializer = tf.lookup.KeyValueTensorInitializer( + keys=table_keys, + values=tf.cast(tf.range(len(table_keys)), tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + # Remove trailing periods for test data when the data is read with tf.data. + label_str = tf.strings.regex_replace(inputs[LABEL_KEY], r"\.", "") + label_str = tf.strings.strip(label_str) + data_labels = table.lookup(label_str) + transformed_label = tf.one_hot( + indices=data_labels, depth=len(table_keys), on_value=1.0, off_value=0.0 + ) + outputs[LABEL_KEY] = tf.reshape(transformed_label, [-1, len(table_keys)]) + + return outputs + + # The "with" block will create a pipeline, and run that pipeline at the exit + # of the block. + with beam.Pipeline() as pipeline: + with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + # Create a TFXIO to read the census data with the schema. To do this we + # need to list all columns in order since the schema doesn't specify the + # order of columns in the csv. + # We first read CSV files and use BeamRecordCsvTFXIO whose .BeamSource() + # accepts a PCollection[bytes] because we need to patch the records first + # (see "FixCommasTrainData" below). Otherwise, tfxio.CsvTFXIO can be used + # to both read the CSV files and parse them to TFT inputs: + # csv_tfxio = tfxio.CsvTFXIO(...) + # raw_data = (pipeline | 'ToRecordBatches' >> csv_tfxio.BeamSource()) + csv_tfxio = tfxio.BeamRecordCsvTFXIO( + physical_format="text", + column_names=ORDERED_CSV_COLUMNS, + schema=_SCHEMA, + ) + + # Read in raw data and convert using CSV TFXIO. Note that we apply + # some Beam transformations here, which will not be encoded in the TF + # graph since we don't do the from within tf.Transform's methods + # (AnalyzeDataset, TransformDataset etc.). These transformations are just + # to get data into a format that the CSV TFXIO can read, in particular + # removing spaces after commas. + raw_data = ( + pipeline + | "ReadTrainData" + >> beam.io.ReadFromText(train_data_file, coder=beam.coders.BytesCoder()) + | "FixCommasTrainData" + >> beam.Map(lambda line: line.replace(b", ", b",")) + | "DecodeTrainData" >> csv_tfxio.BeamSource() + ) + + # Combine data and schema into a dataset tuple. Note that we already used + # the schema to read the CSV data, but we also need it to interpret + # raw_data. 
+ raw_dataset = (raw_data, csv_tfxio.TensorAdapterConfig()) + + # The TFXIO output format is chosen for improved performance. + transformed_dataset, transform_fn = ( + raw_dataset + | tft_beam.AnalyzeAndTransformDataset( + preprocessing_fn, output_record_batches=True + ) + ) + + # Extract transformed RecordBatches, encode and write them to the given + # directory. + _ = ( + transformed_dataset + | "EncodeTrainData" >> tft_beam.EncodeTransformedDataset() + | "WriteTrainData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE) + ) + ) + + # Now apply transform function to test data. In this case we remove the + # trailing period at the end of each line, and also ignore the header line + # that is present in the test data file. + raw_test_data = ( + pipeline + | "ReadTestData" + >> beam.io.ReadFromText( + test_data_file, + skip_header_lines=1, + coder=beam.coders.BytesCoder(), + ) + | "FixCommasTestData" + >> beam.Map(lambda line: line.replace(b", ", b",")) + | "RemoveTrailingPeriodsTestData" >> beam.Map(lambda line: line[:-1]) + | "DecodeTestData" >> csv_tfxio.BeamSource() + ) + + raw_test_dataset = (raw_test_data, csv_tfxio.TensorAdapterConfig()) + + # The TFXIO output format is chosen for improved performance. + transformed_test_dataset = ( + raw_test_dataset, + transform_fn, + ) | tft_beam.TransformDataset(output_record_batches=True) + + # Extract transformed RecordBatches, encode and write them to the given + # directory. + _ = ( + transformed_test_dataset + | "EncodeTestData" >> tft_beam.EncodeTransformedDataset() + | "WriteTestData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE) + ) + ) + + # Will write a SavedModel and metadata to working_dir, which can then + # be read by the tft.TFTransformOutput class. + _ = transform_fn | "WriteTransformFn" >> tft_beam.WriteTransformFn( + working_dir + ) def input_fn( @@ -318,24 +307,26 @@ def input_fn( transformed_examples_pattern: str, batch_size: int, ): - """An input function reading from transformed data, converting to model input. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - transformed_examples_pattern: Base filename of examples. - batch_size: Batch size. - - Returns: - The input data for training or eval, in the form of k. - """ - return tf.data.experimental.make_batched_features_dataset( - file_pattern=transformed_examples_pattern, - batch_size=batch_size, - features=tf_transform_output.transformed_feature_spec(), - reader=tf.data.TFRecordDataset, - label_key=LABEL_KEY, - shuffle=True, - ).prefetch(tf.data.experimental.AUTOTUNE) + """An input function reading from transformed data, converting to model input. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + transformed_examples_pattern: Base filename of examples. + batch_size: Batch size. + + Returns: + ------- + The input data for training or eval, in the form of k. + """ + return tf.data.experimental.make_batched_features_dataset( + file_pattern=transformed_examples_pattern, + batch_size=batch_size, + features=tf_transform_output.transformed_feature_spec(), + reader=tf.data.TFRecordDataset, + label_key=LABEL_KEY, + shuffle=True, + ).prefetch(tf.data.experimental.AUTOTUNE) def input_fn_raw( @@ -343,60 +334,62 @@ def input_fn_raw( raw_examples_pattern: str, batch_size: int, ): - """An input function reading from raw data, converting to model input. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. 
- raw_examples_pattern: Base filename of examples. - batch_size: Batch size. - - Returns: - The input data for training or eval, in the form of k. - """ - - def get_ordered_raw_data_dtypes(): - result = [] - for col in ORDERED_CSV_COLUMNS: - if col not in RAW_DATA_FEATURE_SPEC: - result.append(0.0) - continue - spec = RAW_DATA_FEATURE_SPEC[col] - if isinstance(spec, tf.io.FixedLenFeature): - result.append(spec.dtype) - else: - result.append(0.0) - return result - - dataset = tf.data.experimental.make_csv_dataset( - file_pattern=raw_examples_pattern, - batch_size=batch_size, - column_names=ORDERED_CSV_COLUMNS, - column_defaults=get_ordered_raw_data_dtypes(), - prefetch_buffer_size=0, - ignore_errors=True, - ) - - tft_layer = tf_transform_output.transform_features_layer() - - def transform_dataset(data): - raw_features = {} - for key, val in data.items(): - if key not in RAW_DATA_FEATURE_SPEC: - continue - if isinstance(RAW_DATA_FEATURE_SPEC[key], tf.io.RaggedFeature): - # make_csv_dataset will set the value to 0 when it's missing. - raw_features[key] = tf.RaggedTensor.from_tensor( - tf.expand_dims(val, axis=-1), padding=0) - continue - raw_features[key] = val - transformed_features = tft_layer(raw_features) - data_labels = transformed_features.pop(LABEL_KEY) - return (transformed_features, data_labels) - - return dataset.map( - transform_dataset, - num_parallel_calls=tf.data.experimental.AUTOTUNE).prefetch( - tf.data.experimental.AUTOTUNE) + """An input function reading from raw data, converting to model input. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + raw_examples_pattern: Base filename of examples. + batch_size: Batch size. + + Returns: + ------- + The input data for training or eval, in the form of k. + """ + + def get_ordered_raw_data_dtypes(): + result = [] + for col in ORDERED_CSV_COLUMNS: + if col not in RAW_DATA_FEATURE_SPEC: + result.append(0.0) + continue + spec = RAW_DATA_FEATURE_SPEC[col] + if isinstance(spec, tf.io.FixedLenFeature): + result.append(spec.dtype) + else: + result.append(0.0) + return result + + dataset = tf.data.experimental.make_csv_dataset( + file_pattern=raw_examples_pattern, + batch_size=batch_size, + column_names=ORDERED_CSV_COLUMNS, + column_defaults=get_ordered_raw_data_dtypes(), + prefetch_buffer_size=0, + ignore_errors=True, + ) + + tft_layer = tf_transform_output.transform_features_layer() + + def transform_dataset(data): + raw_features = {} + for key, val in data.items(): + if key not in RAW_DATA_FEATURE_SPEC: + continue + if isinstance(RAW_DATA_FEATURE_SPEC[key], tf.io.RaggedFeature): + # make_csv_dataset will set the value to 0 when it's missing. + raw_features[key] = tf.RaggedTensor.from_tensor( + tf.expand_dims(val, axis=-1), padding=0 + ) + continue + raw_features[key] = val + transformed_features = tft_layer(raw_features) + data_labels = transformed_features.pop(LABEL_KEY) + return (transformed_features, data_labels) + + return dataset.map( + transform_dataset, num_parallel_calls=tf.data.experimental.AUTOTUNE + ).prefetch(tf.data.experimental.AUTOTUNE) def export_serving_model( @@ -404,35 +397,37 @@ def export_serving_model( model: tf_keras.Model, output_dir: str, ): - """Exports a keras model for serving. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - model: A keras model to export for serving. - output_dir: A directory where the model will be exported to. - """ - # The layer has to be saved to the model for keras tracking purpases. 
- model.tft_layer = tf_transform_output.transform_features_layer() - - @tf.function - def serve_tf_examples_fn(serialized_tf_examples): - """Serving tf.function model wrapper.""" - feature_spec = RAW_DATA_FEATURE_SPEC.copy() - feature_spec.pop(LABEL_KEY) - parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec) - transformed_features = model.tft_layer(parsed_features) - outputs = model(transformed_features) - classes_names = tf.constant([['0', '1']]) - classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1]) - return {'classes': classes, 'scores': outputs} - - concrete_serving_fn = serve_tf_examples_fn.get_concrete_function( - tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs')) - signatures = {'serving_default': concrete_serving_fn} + """Exports a keras model for serving. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + model: A keras model to export for serving. + output_dir: A directory where the model will be exported to. + """ + # The layer has to be saved to the model for keras tracking purpases. + model.tft_layer = tf_transform_output.transform_features_layer() + + @tf.function + def serve_tf_examples_fn(serialized_tf_examples): + """Serving tf.function model wrapper.""" + feature_spec = RAW_DATA_FEATURE_SPEC.copy() + feature_spec.pop(LABEL_KEY) + parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec) + transformed_features = model.tft_layer(parsed_features) + outputs = model(transformed_features) + classes_names = tf.constant([["0", "1"]]) + classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1]) + return {"classes": classes, "scores": outputs} + + concrete_serving_fn = serve_tf_examples_fn.get_concrete_function( + tf.TensorSpec(shape=[None], dtype=tf.string, name="inputs") + ) + signatures = {"serving_default": concrete_serving_fn} - # This is required in order to make this model servable with model_server. - versioned_output_dir = os.path.join(output_dir, '1') - model.save(versioned_output_dir, save_format='tf', signatures=signatures) + # This is required in order to make this model servable with model_server. + versioned_output_dir = os.path.join(output_dir, "1") + model.save(versioned_output_dir, save_format="tf", signatures=signatures) def train_and_evaluate( @@ -443,96 +438,102 @@ def train_and_evaluate( num_train_instances: int = NUM_TRAIN_INSTANCES, num_test_instances: int = NUM_TEST_INSTANCES, ): - """Train the model on training data and evaluate on test data. - - Args: - raw_train_eval_data_path_pattern: A pair of patterns of raw - (train data file paths, eval data file paths) in CSV format. - transformed_train_eval_data_path_pattern: A pair of patterns of transformed - (train data file paths, eval data file paths) in TFRecord format. - output_dir: A directory where the output should be exported to. - transform_output_dir: The location of the Transform output. 
- num_train_instances: Number of instances in train set - num_test_instances: Number of instances in test set - - Returns: - The results from the estimator's 'evaluate' method - """ - if not ((raw_train_eval_data_path_pattern is None) ^ - (transformed_train_eval_data_path_pattern is None)): - raise ValueError( - 'Exactly one of raw_train_eval_data_path_pattern and ' - 'transformed_train_eval_data_path_pattern should be provided') - tf_transform_output = tft.TFTransformOutput(transform_output_dir) - - if raw_train_eval_data_path_pattern is not None: - selected_input_fn = input_fn_raw - (train_data_path_pattern, - eval_data_path_pattern) = raw_train_eval_data_path_pattern - else: - selected_input_fn = input_fn - (train_data_path_pattern, - eval_data_path_pattern) = transformed_train_eval_data_path_pattern - - train_dataset = selected_input_fn( - tf_transform_output, train_data_path_pattern, batch_size=TRAIN_BATCH_SIZE - ) - - # Evaluate model on test dataset. - validation_dataset = selected_input_fn( - tf_transform_output, eval_data_path_pattern, batch_size=TRAIN_BATCH_SIZE - ) - - feature_spec = tf_transform_output.transformed_feature_spec().copy() - feature_spec.pop(LABEL_KEY) - - inputs = {} - sparse_inputs = {} - dense_inputs = {} - for key, spec in feature_spec.items(): - if isinstance(spec, tf.io.FixedLenFeature): - # TODO(b/208879020): Move into schema such that spec.shape is [1] and not - # [] for scalars. - inputs[key] = tf_keras.layers.Input( - shape=spec.shape or [1], name=key, dtype=spec.dtype) - dense_inputs[key] = inputs[key] - elif isinstance(spec, tf.io.SparseFeature): - inputs[key] = tf_keras.layers.Input( - shape=spec.size, name=key, dtype=spec.dtype, sparse=True - ) - sparse_inputs[key] = inputs[key] + """Train the model on training data and evaluate on test data. + + Args: + ---- + raw_train_eval_data_path_pattern: A pair of patterns of raw + (train data file paths, eval data file paths) in CSV format. + transformed_train_eval_data_path_pattern: A pair of patterns of transformed + (train data file paths, eval data file paths) in TFRecord format. + output_dir: A directory where the output should be exported to. + transform_output_dir: The location of the Transform output. 
+ num_train_instances: Number of instances in train set + num_test_instances: Number of instances in test set + + Returns: + ------- + The results from the estimator's 'evaluate' method + """ + if not ( + (raw_train_eval_data_path_pattern is None) + ^ (transformed_train_eval_data_path_pattern is None) + ): + raise ValueError( + "Exactly one of raw_train_eval_data_path_pattern and " + "transformed_train_eval_data_path_pattern should be provided" + ) + tf_transform_output = tft.TFTransformOutput(transform_output_dir) + + if raw_train_eval_data_path_pattern is not None: + selected_input_fn = input_fn_raw + (train_data_path_pattern, eval_data_path_pattern) = ( + raw_train_eval_data_path_pattern + ) else: - raise ValueError('Spec type is not supported: ', key, spec) - - outputs = [ - tf_keras.layers.Dense(10, activation='relu')(x) - for x in tf.nest.flatten(sparse_inputs) - ] - stacked_inputs = tf.concat(tf.nest.flatten(dense_inputs) + outputs, axis=1) - output = tf_keras.layers.Dense(100, activation='relu')(stacked_inputs) - output = tf_keras.layers.Dense(70, activation='relu')(output) - output = tf_keras.layers.Dense(50, activation='relu')(output) - output = tf_keras.layers.Dense(20, activation='relu')(output) - output = tf_keras.layers.Dense(2, activation='sigmoid')(output) - model = tf_keras.Model(inputs=inputs, outputs=output) - - model.compile(optimizer='adam', - loss='binary_crossentropy', - metrics=['accuracy']) - logging.info(model.summary()) - - model.fit( - train_dataset, - validation_data=validation_dataset, - epochs=TRAIN_NUM_EPOCHS, - steps_per_epoch=math.ceil(num_train_instances / TRAIN_BATCH_SIZE), - validation_steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE), - ) - - # Export the model. - export_serving_model(tf_transform_output, model, output_dir) - - return model.evaluate(validation_dataset, steps=num_test_instances) + selected_input_fn = input_fn + (train_data_path_pattern, eval_data_path_pattern) = ( + transformed_train_eval_data_path_pattern + ) + + train_dataset = selected_input_fn( + tf_transform_output, train_data_path_pattern, batch_size=TRAIN_BATCH_SIZE + ) + + # Evaluate model on test dataset. + validation_dataset = selected_input_fn( + tf_transform_output, eval_data_path_pattern, batch_size=TRAIN_BATCH_SIZE + ) + + feature_spec = tf_transform_output.transformed_feature_spec().copy() + feature_spec.pop(LABEL_KEY) + + inputs = {} + sparse_inputs = {} + dense_inputs = {} + for key, spec in feature_spec.items(): + if isinstance(spec, tf.io.FixedLenFeature): + # TODO(b/208879020): Move into schema such that spec.shape is [1] and not + # [] for scalars. 
+ inputs[key] = tf_keras.layers.Input( + shape=spec.shape or [1], name=key, dtype=spec.dtype + ) + dense_inputs[key] = inputs[key] + elif isinstance(spec, tf.io.SparseFeature): + inputs[key] = tf_keras.layers.Input( + shape=spec.size, name=key, dtype=spec.dtype, sparse=True + ) + sparse_inputs[key] = inputs[key] + else: + raise ValueError("Spec type is not supported: ", key, spec) + + outputs = [ + tf_keras.layers.Dense(10, activation="relu")(x) + for x in tf.nest.flatten(sparse_inputs) + ] + stacked_inputs = tf.concat(tf.nest.flatten(dense_inputs) + outputs, axis=1) + output = tf_keras.layers.Dense(100, activation="relu")(stacked_inputs) + output = tf_keras.layers.Dense(70, activation="relu")(output) + output = tf_keras.layers.Dense(50, activation="relu")(output) + output = tf_keras.layers.Dense(20, activation="relu")(output) + output = tf_keras.layers.Dense(2, activation="sigmoid")(output) + model = tf_keras.Model(inputs=inputs, outputs=output) + + model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]) + logging.info(model.summary()) + + model.fit( + train_dataset, + validation_data=validation_dataset, + epochs=TRAIN_NUM_EPOCHS, + steps_per_epoch=math.ceil(num_train_instances / TRAIN_BATCH_SIZE), + validation_steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE), + ) + + # Export the model. + export_serving_model(tf_transform_output, model, output_dir) + + return model.evaluate(validation_dataset, steps=num_test_instances) def main( @@ -542,38 +543,35 @@ def main( num_train_instances: int = NUM_TRAIN_INSTANCES, num_test_instances: int = NUM_TEST_INSTANCES, ): - if not working_dir: - working_dir = tempfile.mkdtemp(dir=input_data_dir) + if not working_dir: + working_dir = tempfile.mkdtemp(dir=input_data_dir) - train_data_file = os.path.join(input_data_dir, 'adult.data') - test_data_file = os.path.join(input_data_dir, 'adult.test') + train_data_file = os.path.join(input_data_dir, "adult.data") + test_data_file = os.path.join(input_data_dir, "adult.test") - transform_data(train_data_file, test_data_file, working_dir) + transform_data(train_data_file, test_data_file, working_dir) - if read_raw_data_for_training: - raw_train_and_eval_patterns = (train_data_file, test_data_file) - transformed_train_and_eval_patterns = None - else: - train_pattern = os.path.join( - working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + '*' - ) - eval_pattern = os.path.join( - working_dir, TRANSFORMED_TEST_DATA_FILEBASE + '*' + if read_raw_data_for_training: + raw_train_and_eval_patterns = (train_data_file, test_data_file) + transformed_train_and_eval_patterns = None + else: + train_pattern = os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + "*") + eval_pattern = os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE + "*") + raw_train_and_eval_patterns = None + transformed_train_and_eval_patterns = (train_pattern, eval_pattern) + output_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR) + results = train_and_evaluate( + raw_train_and_eval_patterns, + transformed_train_and_eval_patterns, + output_dir, + working_dir, + num_train_instances=num_train_instances, + num_test_instances=num_test_instances, ) - raw_train_and_eval_patterns = None - transformed_train_and_eval_patterns = (train_pattern, eval_pattern) - output_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR) - results = train_and_evaluate( - raw_train_and_eval_patterns, - transformed_train_and_eval_patterns, - output_dir, - working_dir, - num_train_instances=num_train_instances, - num_test_instances=num_test_instances) - - 
pprint.pprint(results) - - -if __name__ == '__main__': - args = get_args() - main(args.input_data_dir, args.working_dir) + + pprint.pprint(results) + + +if __name__ == "__main__": + args = get_args() + main(args.input_data_dir, args.working_dir) diff --git a/examples/census_example_v2_test.py b/examples/census_example_v2_test.py index ac9f57f..36634d6 100644 --- a/examples/census_example_v2_test.py +++ b/examples/census_example_v2_test.py @@ -15,15 +15,15 @@ import os import shutil -from packaging import version -import tensorflow.compat.v2 as tf import census_example_v2 -from tensorflow_transform import test_case as tft_test_case import local_model_server -from tensorflow_transform.keras_lib import tf_keras - +import tensorflow.compat.v2 as tf from google.protobuf import text_format +from packaging import version + +from tensorflow_transform import test_case as tft_test_case +from tensorflow_transform.keras_lib import tf_keras # Use first row of test data set, which has high probability on label 1 (which # corresponds to '<=50K'). @@ -84,7 +84,7 @@ } """ -_MODEL_NAME = 'my_model' +_MODEL_NAME = "my_model" _CLASSIFICATION_REQUEST_TEXT_PB = """model_spec { name: "%s" } input { @@ -97,124 +97,132 @@ class CensusExampleV2Test(tft_test_case.TransformTestCase): - - def setUp(self): - super().setUp() - if tft_test_case.is_external_environment() and version.parse( - tf.version.VERSION - ) < version.parse('2.3'): - raise tft_test_case.SkipTest('This test requires TF version >= 2.3') - - def _get_data_dir(self): - return os.path.join(os.path.dirname(__file__), 'testdata/census') - - def _get_working_dir(self): - return os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - def _should_saved_model_load_work(self): - return version.parse(tf.__version__) >= version.parse('2.2') - - @tft_test_case.named_parameters([ - dict( - testcase_name='_read_raw_data_for_training', - read_raw_data_for_training=True), - dict( - testcase_name='_read_transformed_data_for_training', - read_raw_data_for_training=False), - ]) - def testCensusExampleAccuracy(self, read_raw_data_for_training): - - if not self._should_saved_model_load_work(): - self.skipTest('The generated SavedModel cannot be read with TF<2.2') - raw_data_dir = self._get_data_dir() - working_dir = self._get_working_dir() - - train_data_file = os.path.join(raw_data_dir, 'adult.data') - test_data_file = os.path.join(raw_data_dir, 'adult.test') - - census_example_v2.transform_data( - train_data_file, test_data_file, working_dir + def setUp(self): + super().setUp() + if tft_test_case.is_external_environment() and version.parse( + tf.version.VERSION + ) < version.parse("2.3"): + raise tft_test_case.SkipTest("This test requires TF version >= 2.3") + + def _get_data_dir(self): + return os.path.join(os.path.dirname(__file__), "testdata/census") + + def _get_working_dir(self): + return os.path.join( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", self.get_temp_dir()), + self._testMethodName, + ) + + def _should_saved_model_load_work(self): + return version.parse(tf.__version__) >= version.parse("2.2") + + @tft_test_case.named_parameters( + [ + dict( + testcase_name="_read_raw_data_for_training", + read_raw_data_for_training=True, + ), + dict( + testcase_name="_read_transformed_data_for_training", + read_raw_data_for_training=False, + ), + ] ) - - if read_raw_data_for_training: - raw_train_and_eval_patterns = (train_data_file, test_data_file) - transformed_train_and_eval_patterns = None - else: - 
train_pattern = os.path.join( - working_dir, census_example_v2.TRANSFORMED_TRAIN_DATA_FILEBASE + '*' - ) - eval_pattern = os.path.join( - working_dir, census_example_v2.TRANSFORMED_TEST_DATA_FILEBASE + '*' - ) - raw_train_and_eval_patterns = None - transformed_train_and_eval_patterns = (train_pattern, eval_pattern) - output_dir = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) - results = census_example_v2.train_and_evaluate( - raw_train_and_eval_patterns, - transformed_train_and_eval_patterns, - output_dir, - working_dir, - num_train_instances=1000, - num_test_instances=1000) - self.assertGreaterEqual(results[1], 0.7) - - # Removing the tf.Transform output directory in order to show that the - # exported model is hermetic. - shutil.rmtree(os.path.join(working_dir, 'transform_fn')) - - model_path = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) - - actual_model_path = os.path.join(model_path, '1') - tf_keras.backend.clear_session() - model = tf_keras.models.load_model(actual_model_path) - model.summary() - - example = text_format.Parse(_PREDICT_TF_EXAMPLE_TEXT_PB, tf.train.Example()) - prediction = model.signatures['serving_default']( - tf.constant([example.SerializeToString()], tf.string)) - self.assertAllEqual([['0', '1']], prediction['classes']) - self.assertAllClose([[0, 1]], prediction['scores'], atol=0.01) - - # This is required in order to support the classify API for this Keras - # model. - updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater( - actual_model_path) - updater.replace_method_name( - signature_key='serving_default', - method_name='tensorflow/serving/classify', - tags=['serve']) - updater.save() - - if local_model_server.local_model_server_supported(): - with local_model_server.start_server(_MODEL_NAME, model_path) as address: - ascii_classification_request = _CLASSIFICATION_REQUEST_TEXT_PB - results = local_model_server.make_classification_request( - address, ascii_classification_request) - self.assertLen(results, 1) - self.assertLen(results[0].classes, 2) - self.assertEqual(results[0].classes[0].label, '0') - self.assertLess(results[0].classes[0].score, 0.01) - self.assertEqual(results[0].classes[1].label, '1') - self.assertGreater(results[0].classes[1].score, 0.99) - - def test_main_runs(self): - census_example_v2.main( - self._get_data_dir(), - self._get_working_dir(), - read_raw_data_for_training=False, - num_train_instances=10, - num_test_instances=10) - - def test_main_runs_raw_data(self): - census_example_v2.main( - self._get_data_dir(), - self._get_working_dir(), - read_raw_data_for_training=True, - num_train_instances=10, - num_test_instances=10) - - -if __name__ == '__main__': - tf.test.main() + def testCensusExampleAccuracy(self, read_raw_data_for_training): + if not self._should_saved_model_load_work(): + self.skipTest("The generated SavedModel cannot be read with TF<2.2") + raw_data_dir = self._get_data_dir() + working_dir = self._get_working_dir() + + train_data_file = os.path.join(raw_data_dir, "adult.data") + test_data_file = os.path.join(raw_data_dir, "adult.test") + + census_example_v2.transform_data(train_data_file, test_data_file, working_dir) + + if read_raw_data_for_training: + raw_train_and_eval_patterns = (train_data_file, test_data_file) + transformed_train_and_eval_patterns = None + else: + train_pattern = os.path.join( + working_dir, census_example_v2.TRANSFORMED_TRAIN_DATA_FILEBASE + "*" + ) + eval_pattern = os.path.join( + working_dir, census_example_v2.TRANSFORMED_TEST_DATA_FILEBASE + 
"*" + ) + raw_train_and_eval_patterns = None + transformed_train_and_eval_patterns = (train_pattern, eval_pattern) + output_dir = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) + results = census_example_v2.train_and_evaluate( + raw_train_and_eval_patterns, + transformed_train_and_eval_patterns, + output_dir, + working_dir, + num_train_instances=1000, + num_test_instances=1000, + ) + self.assertGreaterEqual(results[1], 0.7) + + # Removing the tf.Transform output directory in order to show that the + # exported model is hermetic. + shutil.rmtree(os.path.join(working_dir, "transform_fn")) + + model_path = os.path.join(working_dir, census_example_v2.EXPORTED_MODEL_DIR) + + actual_model_path = os.path.join(model_path, "1") + tf_keras.backend.clear_session() + model = tf_keras.models.load_model(actual_model_path) + model.summary() + + example = text_format.Parse(_PREDICT_TF_EXAMPLE_TEXT_PB, tf.train.Example()) + prediction = model.signatures["serving_default"]( + tf.constant([example.SerializeToString()], tf.string) + ) + self.assertAllEqual([["0", "1"]], prediction["classes"]) + self.assertAllClose([[0, 1]], prediction["scores"], atol=0.01) + + # This is required in order to support the classify API for this Keras + # model. + updater = tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater( + actual_model_path + ) + updater.replace_method_name( + signature_key="serving_default", + method_name="tensorflow/serving/classify", + tags=["serve"], + ) + updater.save() + + if local_model_server.local_model_server_supported(): + with local_model_server.start_server(_MODEL_NAME, model_path) as address: + ascii_classification_request = _CLASSIFICATION_REQUEST_TEXT_PB + results = local_model_server.make_classification_request( + address, ascii_classification_request + ) + self.assertLen(results, 1) + self.assertLen(results[0].classes, 2) + self.assertEqual(results[0].classes[0].label, "0") + self.assertLess(results[0].classes[0].score, 0.01) + self.assertEqual(results[0].classes[1].label, "1") + self.assertGreater(results[0].classes[1].score, 0.99) + + def test_main_runs(self): + census_example_v2.main( + self._get_data_dir(), + self._get_working_dir(), + read_raw_data_for_training=False, + num_train_instances=10, + num_test_instances=10, + ) + + def test_main_runs_raw_data(self): + census_example_v2.main( + self._get_data_dir(), + self._get_working_dir(), + read_raw_data_for_training=True, + num_train_instances=10, + num_test_instances=10, + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/examples/dataset_tfxio_example.py b/examples/dataset_tfxio_example.py index b385f5b..f35006f 100644 --- a/examples/dataset_tfxio_example.py +++ b/examples/dataset_tfxio_example.py @@ -16,60 +16,60 @@ import pprint import tempfile -from absl import app import apache_beam as beam import tensorflow as tf +from absl import app +from tfx_bsl.tfxio import dataset_tfxio + import tensorflow_transform as tft import tensorflow_transform.beam.impl as tft_beam -from tfx_bsl.tfxio import dataset_tfxio def _print_record_batch(data): - pprint.pprint(data.to_pydict()) + pprint.pprint(data.to_pydict()) def _preprocessing_fn(inputs): - return { - 'x_centered': tf.cast(inputs['feature0'], tf.float32) - tft.mean( - inputs['feature0'] - ), - 'x_scaled': tft.scale_by_min_max(inputs['feature0']), - } + return { + "x_centered": tf.cast(inputs["feature0"], tf.float32) + - tft.mean(inputs["feature0"]), + "x_scaled": tft.scale_by_min_max(inputs["feature0"]), + } def _make_tfxio() -> 
dataset_tfxio.DatasetTFXIO: - """Make DatasetTFXIO.""" - num_elements = 9 - batch_size = 2 - dataset = tf.data.Dataset.range(num_elements).batch(batch_size) + """Make DatasetTFXIO.""" + num_elements = 9 + batch_size = 2 + dataset = tf.data.Dataset.range(num_elements).batch(batch_size) - return dataset_tfxio.DatasetTFXIO(dataset=dataset) + return dataset_tfxio.DatasetTFXIO(dataset=dataset) def main(args): - del args - - input_tfxio = _make_tfxio() - - # User-Defined Processing Pipeline - with beam.Pipeline() as pipeline: - with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - raw_dataset = ( - pipeline | 'ReadRecordBatch' >> input_tfxio.BeamSource(batch_size=5), - input_tfxio.TensorAdapterConfig(), - ) - (transformed_data, _), _ = ( - raw_dataset - | 'AnalyzeAndTransform' - >> tft_beam.AnalyzeAndTransformDataset( - _preprocessing_fn, output_record_batches=True - ) - ) - transformed_data = transformed_data | 'ExtractRecordBatch' >> beam.Keys() - _ = transformed_data | 'PrintTransformedData' >> beam.Map( - _print_record_batch - ) - - -if __name__ == '__main__': - app.run(main) + del args + + input_tfxio = _make_tfxio() + + # User-Defined Processing Pipeline + with beam.Pipeline() as pipeline: + with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + raw_dataset = ( + pipeline | "ReadRecordBatch" >> input_tfxio.BeamSource(batch_size=5), + input_tfxio.TensorAdapterConfig(), + ) + (transformed_data, _), _ = ( + raw_dataset + | "AnalyzeAndTransform" + >> tft_beam.AnalyzeAndTransformDataset( + _preprocessing_fn, output_record_batches=True + ) + ) + transformed_data = transformed_data | "ExtractRecordBatch" >> beam.Keys() + _ = transformed_data | "PrintTransformedData" >> beam.Map( + _print_record_batch + ) + + +if __name__ == "__main__": + app.run(main) diff --git a/examples/dataset_tfxio_example_test.py b/examples/dataset_tfxio_example_test.py index f9d6915..3963a6d 100644 --- a/examples/dataset_tfxio_example_test.py +++ b/examples/dataset_tfxio_example_test.py @@ -13,43 +13,41 @@ # limitations under the License. """Tests for dataset_tfxio.""" -import tensorflow as tf import dataset_tfxio_example -from tensorflow_transform.beam import tft_unit +import tensorflow as tf +from tensorflow_transform.beam import tft_unit _EXPECTED_TRANSFORMED_OUTPUT = [ - {'x_scaled': 0.0, 'x_centered': -4.0}, - {'x_scaled': 0.125, 'x_centered': -3.0}, - {'x_scaled': 0.25, 'x_centered': -2.0}, - {'x_scaled': 0.375, 'x_centered': -1.0}, - {'x_scaled': 0.5, 'x_centered': 0.0}, - {'x_scaled': 0.625, 'x_centered': 1.0}, - {'x_scaled': 0.75, 'x_centered': 2.0}, - {'x_scaled': 0.875, 'x_centered': 3.0}, - {'x_scaled': 1.0, 'x_centered': 4.0}, + {"x_scaled": 0.0, "x_centered": -4.0}, + {"x_scaled": 0.125, "x_centered": -3.0}, + {"x_scaled": 0.25, "x_centered": -2.0}, + {"x_scaled": 0.375, "x_centered": -1.0}, + {"x_scaled": 0.5, "x_centered": 0.0}, + {"x_scaled": 0.625, "x_centered": 1.0}, + {"x_scaled": 0.75, "x_centered": 2.0}, + {"x_scaled": 0.875, "x_centered": 3.0}, + {"x_scaled": 1.0, "x_centered": 4.0}, ] class SimpleMainTest(tf.test.TestCase): - - def testMainDoesNotCrash(self): - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - dataset_tfxio_example.main('') + def testMainDoesNotCrash(self): + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + dataset_tfxio_example.main("") class SimpleProcessingTest(tft_unit.TransformTestCase): - - # Asserts equal for each element. (Does not check batchwise.) 
- def test_preprocessing_fn(self): - tfxio = dataset_tfxio_example._make_tfxio() - self.assertAnalyzeAndTransformResults( - tfxio.BeamSource(), - tfxio.TensorAdapterConfig(), - dataset_tfxio_example._preprocessing_fn, - _EXPECTED_TRANSFORMED_OUTPUT, - ) - - -if __name__ == '__main__': - tf.test.main() + # Asserts equal for each element. (Does not check batchwise.) + def test_preprocessing_fn(self): + tfxio = dataset_tfxio_example._make_tfxio() + self.assertAnalyzeAndTransformResults( + tfxio.BeamSource(), + tfxio.TensorAdapterConfig(), + dataset_tfxio_example._preprocessing_fn, + _EXPECTED_TRANSFORMED_OUTPUT, + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/examples/local_model_server.py b/examples/local_model_server.py index b78530e..373799b 100644 --- a/examples/local_model_server.py +++ b/examples/local_model_server.py @@ -17,21 +17,21 @@ def local_model_server_supported(): - return False + return False @contextlib.contextmanager def start_server(model_name, model_path): - del model_name # unused - del model_path # unused - raise NotImplementedError + del model_name # unused + del model_path # unused + raise NotImplementedError # TODO(KesterTong): Change the input of make_classification_request to not be a # string. This will require adding a test-only dependency on # tensorflow_serving.apis. def make_classification_request(address, ascii_classification_request): - """Makes a classify request to a local server.""" - del address # unused - del ascii_classification_request # unused - raise NotImplementedError + """Makes a classify request to a local server.""" + del address # unused + del ascii_classification_request # unused + raise NotImplementedError diff --git a/examples/sentiment_example_v2.py b/examples/sentiment_example_v2.py index ad11dc1..380401d 100644 --- a/examples/sentiment_example_v2.py +++ b/examples/sentiment_example_v2.py @@ -19,15 +19,15 @@ import os import pprint import tempfile -from absl import logging import apache_beam as beam import tensorflow as tf +from absl import logging +from tfx_bsl.public import tfxio + import tensorflow_transform as tft import tensorflow_transform.beam as tft_beam from tensorflow_transform.keras_lib import tf_keras -from tfx_bsl.public import tfxio - VOCAB_SIZE = 20000 TRAIN_BATCH_SIZE = 128 @@ -35,26 +35,26 @@ NUM_TRAIN_INSTANCES = 25000 NUM_TEST_INSTANCES = 25000 -REVIEW_KEY = 'review' -REVIEW_WEIGHT_KEY = 'review_weight' -LABEL_KEY = 'label' +REVIEW_KEY = "review" +REVIEW_WEIGHT_KEY = "review_weight" +LABEL_KEY = "label" RAW_DATA_FEATURE_SPEC = { REVIEW_KEY: tf.io.FixedLenFeature([], tf.string), - LABEL_KEY: tf.io.FixedLenFeature([], tf.int64) + LABEL_KEY: tf.io.FixedLenFeature([], tf.int64), } SCHEMA = tft.DatasetMetadata.from_feature_spec(RAW_DATA_FEATURE_SPEC).schema -DELIMITERS = '.,!?() ' +DELIMITERS = ".,!?() " # Names of temp files -SHUFFLED_TRAIN_DATA_FILEBASE = 'train_shuffled' -SHUFFLED_TEST_DATA_FILEBASE = 'test_shuffled' -TRANSFORMED_TRAIN_DATA_FILEBASE = 'train_transformed' -TRANSFORMED_TEST_DATA_FILEBASE = 'test_transformed' -TRANSFORM_TEMP_DIR = 'tft_temp' -EXPORTED_MODEL_DIR = 'exported_model_dir' +SHUFFLED_TRAIN_DATA_FILEBASE = "train_shuffled" +SHUFFLED_TEST_DATA_FILEBASE = "test_shuffled" +TRANSFORMED_TRAIN_DATA_FILEBASE = "train_transformed" +TRANSFORMED_TEST_DATA_FILEBASE = "test_transformed" +TRANSFORM_TEMP_DIR = "tft_temp" +EXPORTED_MODEL_DIR = "exported_model_dir" # Functions for preprocessing @@ -62,48 +62,50 @@ # pylint: disable=invalid-name @beam.ptransform_fn def Shuffle(pcoll): - """Shuffles 
a PCollection. Collection should not contain duplicates.""" - return (pcoll - | 'PairWithHash' >> beam.Map(lambda x: (hash(x), x)) - | 'GroupByHash' >> beam.GroupByKey() - | 'DropHash' >> beam.FlatMap( - lambda hash_and_values: hash_and_values[1])) + """Shuffles a PCollection. Collection should not contain duplicates.""" + return ( + pcoll + | "PairWithHash" >> beam.Map(lambda x: (hash(x), x)) + | "GroupByHash" >> beam.GroupByKey() + | "DropHash" >> beam.FlatMap(lambda hash_and_values: hash_and_values[1]) + ) # pylint: disable=invalid-name @beam.ptransform_fn def ReadAndShuffleData(pcoll, filepatterns): - """Read a train or test dataset from disk and shuffle it.""" - # NOTE: we pass filepatterns as a tuple instead of two args, as the current - # version of beam assumes that if the first arg to a ptransfrom_fn is a - # string, then that string is the label. - neg_filepattern, pos_filepattern = filepatterns - - # Read from each file pattern and create a tuple of the review text and the - # correct label. - negative_examples = ( - pcoll - | 'ReadNegativeExamples' >> beam.io.ReadFromText(neg_filepattern) - | 'PairWithZero' >> beam.Map(lambda review: (review, 0))) - positive_examples = ( - pcoll - | 'ReadPositiveExamples' >> beam.io.ReadFromText(pos_filepattern) - | 'PairWithOne' >> beam.Map(lambda review: (review, 1))) - all_examples = ( - [negative_examples, positive_examples] | 'Merge' >> beam.Flatten()) - - # Shuffle the data. Note that the data does in fact contain duplicate reviews - # for reasons that are unclear. This means that NUM_TRAIN_INSTANCES and - # NUM_TRAIN_INSTANCES are slightly wrong for the preprocessed data. - # pylint: disable=no-value-for-parameter - shuffled_examples = ( - all_examples - | 'Distinct' >> beam.Distinct() - | 'Shuffle' >> Shuffle()) - - # Put the data in the format that can be accepted directly by tf.Transform. - return shuffled_examples | 'MakeInstances' >> beam.Map( - lambda p: {REVIEW_KEY: p[0], LABEL_KEY: p[1]}) + """Read a train or test dataset from disk and shuffle it.""" + # NOTE: we pass filepatterns as a tuple instead of two args, as the current + # version of beam assumes that if the first arg to a ptransfrom_fn is a + # string, then that string is the label. + neg_filepattern, pos_filepattern = filepatterns + + # Read from each file pattern and create a tuple of the review text and the + # correct label. + negative_examples = ( + pcoll + | "ReadNegativeExamples" >> beam.io.ReadFromText(neg_filepattern) + | "PairWithZero" >> beam.Map(lambda review: (review, 0)) + ) + positive_examples = ( + pcoll + | "ReadPositiveExamples" >> beam.io.ReadFromText(pos_filepattern) + | "PairWithOne" >> beam.Map(lambda review: (review, 1)) + ) + all_examples = [negative_examples, positive_examples] | "Merge" >> beam.Flatten() + + # Shuffle the data. Note that the data does in fact contain duplicate reviews + # for reasons that are unclear. This means that NUM_TRAIN_INSTANCES and + # NUM_TRAIN_INSTANCES are slightly wrong for the preprocessed data. + # pylint: disable=no-value-for-parameter + shuffled_examples = ( + all_examples | "Distinct" >> beam.Distinct() | "Shuffle" >> Shuffle() + ) + + # Put the data in the format that can be accepted directly by tf.Transform. 
+ return shuffled_examples | "MakeInstances" >> beam.Map( + lambda p: {REVIEW_KEY: p[0], LABEL_KEY: p[1]} + ) def read_and_shuffle_data( @@ -113,120 +115,138 @@ def read_and_shuffle_data( test_pos_filepattern: str, working_dir: str, ): - """Read and shuffle the data and write out as a TFRecord of Example protos. - - Read in the data from the positive and negative examples on disk, shuffle it - and write it out in TFRecord format. - transform it using a preprocessing pipeline that removes punctuation, - tokenizes and maps tokens to int64 values indices. - - Args: - train_neg_filepattern: Filepattern for training data negative examples - train_pos_filepattern: Filepattern for training data positive examples - test_neg_filepattern: Filepattern for test data negative examples - test_pos_filepattern: Filepattern for test data positive examples - working_dir: Directory to write shuffled data to - """ - with beam.Pipeline() as pipeline: - coder = tft.coders.ExampleProtoCoder(SCHEMA) - - # pylint: disable=no-value-for-parameter - _ = ( - pipeline - | 'ReadAndShuffleTrain' >> ReadAndShuffleData( - (train_neg_filepattern, train_pos_filepattern)) - | 'EncodeTrainData' >> beam.Map(coder.encode) - | 'WriteTrainData' >> beam.io.WriteToTFRecord( - os.path.join(working_dir, SHUFFLED_TRAIN_DATA_FILEBASE))) - - _ = ( - pipeline - | 'ReadAndShuffleTest' >> ReadAndShuffleData( - (test_neg_filepattern, test_pos_filepattern)) - | 'EncodeTestData' >> beam.Map(coder.encode) - | 'WriteTestData' >> beam.io.WriteToTFRecord( - os.path.join(working_dir, SHUFFLED_TEST_DATA_FILEBASE))) - # pylint: enable=no-value-for-parameter + """Read and shuffle the data and write out as a TFRecord of Example protos. + + Read in the data from the positive and negative examples on disk, shuffle it + and write it out in TFRecord format. + transform it using a preprocessing pipeline that removes punctuation, + tokenizes and maps tokens to int64 values indices. + + Args: + ---- + train_neg_filepattern: Filepattern for training data negative examples + train_pos_filepattern: Filepattern for training data positive examples + test_neg_filepattern: Filepattern for test data negative examples + test_pos_filepattern: Filepattern for test data positive examples + working_dir: Directory to write shuffled data to + """ + with beam.Pipeline() as pipeline: + coder = tft.coders.ExampleProtoCoder(SCHEMA) + + # pylint: disable=no-value-for-parameter + _ = ( + pipeline + | "ReadAndShuffleTrain" + >> ReadAndShuffleData((train_neg_filepattern, train_pos_filepattern)) + | "EncodeTrainData" >> beam.Map(coder.encode) + | "WriteTrainData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, SHUFFLED_TRAIN_DATA_FILEBASE) + ) + ) + + _ = ( + pipeline + | "ReadAndShuffleTest" + >> ReadAndShuffleData((test_neg_filepattern, test_pos_filepattern)) + | "EncodeTestData" >> beam.Map(coder.encode) + | "WriteTestData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, SHUFFLED_TEST_DATA_FILEBASE) + ) + ) + # pylint: enable=no-value-for-parameter def transform_data(working_dir: str): - """Transform the data and write out as a TFRecord of Example protos. - - Read in the data from the positive and negative examples on disk, and - transform it using a preprocessing pipeline that removes punctuation, - tokenizes and maps tokens to int64 values indices. - - Args: - working_dir: Directory to read shuffled data from and write transformed data - and metadata to. 
- """ - - with beam.Pipeline() as pipeline: - with tft_beam.Context( - temp_dir=os.path.join(working_dir, TRANSFORM_TEMP_DIR)): - tfxio_train_data = tfxio.TFExampleRecord( - file_pattern=os.path.join(working_dir, - SHUFFLED_TRAIN_DATA_FILEBASE + '*'), - schema=SCHEMA) - train_data = ( - pipeline | 'TFXIORead[Train]' >> tfxio_train_data.BeamSource()) - - tfxio_test_data = tfxio.TFExampleRecord( - file_pattern=os.path.join(working_dir, - SHUFFLED_TEST_DATA_FILEBASE + '*'), - schema=SCHEMA) - test_data = (pipeline | 'TFXIORead[Test]' >> tfxio_test_data.BeamSource()) - - def preprocessing_fn(inputs): - """Preprocess input columns into transformed columns.""" - review = inputs[REVIEW_KEY] - - # Here tf.compat.v1.string_split behaves differently from - # tf.strings.split. - review_tokens = tf.compat.v1.string_split(review, DELIMITERS) - review_indices = tft.compute_and_apply_vocabulary( - review_tokens, top_k=VOCAB_SIZE) - # Add one for the oov bucket created by compute_and_apply_vocabulary. - review_bow_indices, review_weight = tft.tfidf(review_indices, - VOCAB_SIZE + 1) - return { - REVIEW_KEY: review_bow_indices, - REVIEW_WEIGHT_KEY: review_weight, - LABEL_KEY: tf.one_hot(inputs[LABEL_KEY], 2), - } - - # The TFXIO output format is chosen for improved performance. - transformed_train_data, transform_fn = ( - (train_data, tfxio_train_data.TensorAdapterConfig()) - | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset( - preprocessing_fn, output_record_batches=True)) - - transformed_test_data = ( - ((test_data, tfxio_test_data.TensorAdapterConfig()), transform_fn) - | - 'Transform' >> tft_beam.TransformDataset(output_record_batches=True)) - - # Extract transformed RecordBatches, encode and write them to the given - # directory. - _ = ( - transformed_train_data - | 'EncodeTrainData' >> tft_beam.EncodeTransformedDataset() - | 'WriteTrainData' >> beam.io.WriteToTFRecord( - os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE))) - - _ = ( - transformed_test_data - | 'EncodeTestData' >> tft_beam.EncodeTransformedDataset() - | 'WriteTestData' >> beam.io.WriteToTFRecord( - os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE))) - - # Will write a SavedModel and metadata to two subdirectories of - # working_dir, given by tft.TRANSFORM_FN_DIR and - # tft.TRANSFORMED_METADATA_DIR respectively. - _ = ( - transform_fn - | 'WriteTransformFn' >> - tft_beam.WriteTransformFn(working_dir)) + """Transform the data and write out as a TFRecord of Example protos. + + Read in the data from the positive and negative examples on disk, and + transform it using a preprocessing pipeline that removes punctuation, + tokenizes and maps tokens to int64 values indices. + + Args: + ---- + working_dir: Directory to read shuffled data from and write transformed data + and metadata to. 
+ """ + with beam.Pipeline() as pipeline: + with tft_beam.Context(temp_dir=os.path.join(working_dir, TRANSFORM_TEMP_DIR)): + tfxio_train_data = tfxio.TFExampleRecord( + file_pattern=os.path.join( + working_dir, SHUFFLED_TRAIN_DATA_FILEBASE + "*" + ), + schema=SCHEMA, + ) + train_data = pipeline | "TFXIORead[Train]" >> tfxio_train_data.BeamSource() + + tfxio_test_data = tfxio.TFExampleRecord( + file_pattern=os.path.join( + working_dir, SHUFFLED_TEST_DATA_FILEBASE + "*" + ), + schema=SCHEMA, + ) + test_data = pipeline | "TFXIORead[Test]" >> tfxio_test_data.BeamSource() + + def preprocessing_fn(inputs): + """Preprocess input columns into transformed columns.""" + review = inputs[REVIEW_KEY] + + # Here tf.compat.v1.string_split behaves differently from + # tf.strings.split. + review_tokens = tf.compat.v1.string_split(review, DELIMITERS) + review_indices = tft.compute_and_apply_vocabulary( + review_tokens, top_k=VOCAB_SIZE + ) + # Add one for the oov bucket created by compute_and_apply_vocabulary. + review_bow_indices, review_weight = tft.tfidf( + review_indices, VOCAB_SIZE + 1 + ) + return { + REVIEW_KEY: review_bow_indices, + REVIEW_WEIGHT_KEY: review_weight, + LABEL_KEY: tf.one_hot(inputs[LABEL_KEY], 2), + } + + # The TFXIO output format is chosen for improved performance. + transformed_train_data, transform_fn = ( + train_data, + tfxio_train_data.TensorAdapterConfig(), + ) | "AnalyzeAndTransform" >> tft_beam.AnalyzeAndTransformDataset( + preprocessing_fn, output_record_batches=True + ) + + transformed_test_data = ( + (test_data, tfxio_test_data.TensorAdapterConfig()), + transform_fn, + ) | "Transform" >> tft_beam.TransformDataset(output_record_batches=True) + + # Extract transformed RecordBatches, encode and write them to the given + # directory. + _ = ( + transformed_train_data + | "EncodeTrainData" >> tft_beam.EncodeTransformedDataset() + | "WriteTrainData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE) + ) + ) + + _ = ( + transformed_test_data + | "EncodeTestData" >> tft_beam.EncodeTransformedDataset() + | "WriteTestData" + >> beam.io.WriteToTFRecord( + os.path.join(working_dir, TRANSFORMED_TEST_DATA_FILEBASE) + ) + ) + + # Will write a SavedModel and metadata to two subdirectories of + # working_dir, given by tft.TRANSFORM_FN_DIR and + # tft.TRANSFORMED_METADATA_DIR respectively. + _ = transform_fn | "WriteTransformFn" >> tft_beam.WriteTransformFn( + working_dir + ) # Functions for training @@ -237,24 +257,26 @@ def _input_fn( transformed_examples: str, batch_size: int, ): - """Creates an input function reading from transformed data. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - transformed_examples: Base filename of examples. - batch_size: Batch size. - - Returns: - The input function for training or eval. - """ - return tf.data.experimental.make_batched_features_dataset( - file_pattern=transformed_examples, - batch_size=batch_size, - features=tf_transform_output.transformed_feature_spec(), - reader=tf.data.TFRecordDataset, - label_key=LABEL_KEY, - shuffle=True, - ).prefetch(tf.data.experimental.AUTOTUNE) + """Creates an input function reading from transformed data. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + transformed_examples: Base filename of examples. + batch_size: Batch size. + + Returns: + ------- + The input function for training or eval. 
+ """ + return tf.data.experimental.make_batched_features_dataset( + file_pattern=transformed_examples, + batch_size=batch_size, + features=tf_transform_output.transformed_feature_spec(), + reader=tf.data.TFRecordDataset, + label_key=LABEL_KEY, + shuffle=True, + ).prefetch(tf.data.experimental.AUTOTUNE) def export_serving_model( @@ -262,43 +284,43 @@ def export_serving_model( model: tf_keras.Model, output_dir: str, ): - """Creates an input function reading from raw data. - - Args: - tf_transform_output: Wrapper around output of tf.Transform. - model: The keras model to export. - output_dir: A path to export the model to. - - Returns: - The serving input function. - """ - # The layer has to be saved to the model for keras tracking purpases. - model.tft_layer = tf_transform_output.transform_features_layer() - - raw_feature_spec = RAW_DATA_FEATURE_SPEC.copy() - # Remove label since it is not available during serving. - raw_feature_spec.pop(LABEL_KEY) - - @tf.function - def serve_tf_examples_fn(serialized_tf_examples): - """Serving tf.function model wrapper.""" - parsed_features = tf.io.parse_example( - serialized_tf_examples, raw_feature_spec + """Creates an input function reading from raw data. + + Args: + ---- + tf_transform_output: Wrapper around output of tf.Transform. + model: The keras model to export. + output_dir: A path to export the model to. + + Returns: + ------- + The serving input function. + """ + # The layer has to be saved to the model for keras tracking purpases. + model.tft_layer = tf_transform_output.transform_features_layer() + + raw_feature_spec = RAW_DATA_FEATURE_SPEC.copy() + # Remove label since it is not available during serving. + raw_feature_spec.pop(LABEL_KEY) + + @tf.function + def serve_tf_examples_fn(serialized_tf_examples): + """Serving tf.function model wrapper.""" + parsed_features = tf.io.parse_example(serialized_tf_examples, raw_feature_spec) + transformed_features = model.tft_layer(parsed_features) + outputs = model(transformed_features) + classes_names = tf.constant([["0", "1"]]) + classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1]) + return {"classes": classes, "scores": outputs} + + concrete_serving_fn = serve_tf_examples_fn.get_concrete_function( + tf.TensorSpec(shape=[None], dtype=tf.string, name="inputs") ) - transformed_features = model.tft_layer(parsed_features) - outputs = model(transformed_features) - classes_names = tf.constant([['0', '1']]) - classes = tf.tile(classes_names, [tf.shape(outputs)[0], 1]) - return {'classes': classes, 'scores': outputs} - - concrete_serving_fn = serve_tf_examples_fn.get_concrete_function( - tf.TensorSpec(shape=[None], dtype=tf.string, name='inputs') - ) - signatures = {'serving_default': concrete_serving_fn} + signatures = {"serving_default": concrete_serving_fn} - # This is required in order to make this model servable with model_server. - versioned_output_dir = os.path.join(output_dir, '1') - model.save(versioned_output_dir, save_format='tf', signatures=signatures) + # This is required in order to make this model servable with model_server. + versioned_output_dir = os.path.join(output_dir, "1") + model.save(versioned_output_dir, save_format="tf", signatures=signatures) def train_and_evaluate( @@ -307,107 +329,113 @@ def train_and_evaluate( num_train_instances: int = NUM_TRAIN_INSTANCES, num_test_instances: int = NUM_TEST_INSTANCES, ): - """Train the model on training data and evaluate on test data. - - Args: - working_dir: Directory to read transformed data and metadata from. 
- output_dir: A directory where the output should be exported to. - num_train_instances: Number of instances in train set - num_test_instances: Number of instances in test set - - Returns: - The results from the estimator's 'evaluate' method - """ - tf_transform_output = tft.TFTransformOutput(working_dir) - train_data_path_pattern = os.path.join( - working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + '*' - ) - test_data_path_pattern = os.path.join( - working_dir, TRANSFORMED_TEST_DATA_FILEBASE + '*' - ) - - train_dataset = _input_fn( - tf_transform_output, train_data_path_pattern, batch_size=TRAIN_BATCH_SIZE - ) - validation_dataset = _input_fn( - tf_transform_output, test_data_path_pattern, batch_size=TRAIN_BATCH_SIZE - ) - - feature_spec = tf_transform_output.transformed_feature_spec().copy() - feature_spec.pop(LABEL_KEY) - - review_input = tf_keras.layers.Input( - shape=[None], name=REVIEW_KEY, dtype=tf.int64, sparse=True - ) - review_weight_input = tf_keras.layers.Input( - shape=[None], name=REVIEW_WEIGHT_KEY, dtype=tf.float32, sparse=True - ) - count_layer = tf.keras.layers.CategoryEncoding( - num_tokens=VOCAB_SIZE + 1, output_mode='count' - ) - embedding_layer = tf.keras.layers.Dense(4, use_bias=False) - embedding = embedding_layer( - count_layer(review_input, count_weights=review_weight_input) - ) - output = tf_keras.layers.Dense(100, activation='relu')(embedding) - output = tf_keras.layers.Dense(70, activation='relu')(output) - output = tf_keras.layers.Dense(50, activation='relu')(output) - output = tf_keras.layers.Dense(20, activation='relu')(output) - output = tf_keras.layers.Dense(2, activation='sigmoid')(output) - model = tf_keras.Model( - inputs={ - REVIEW_KEY: review_input, - REVIEW_WEIGHT_KEY: review_weight_input, - }, - outputs=output, - ) - - model.compile( - optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'] - ) - logging.info(model.summary()) - - model.fit( - train_dataset, - validation_data=validation_dataset, - epochs=TRAIN_NUM_EPOCHS, - steps_per_epoch=math.ceil(num_train_instances / TRAIN_BATCH_SIZE), - validation_steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE), - ) - - # Export the model. - export_serving_model(tf_transform_output, model, output_dir) - - return model.evaluate(validation_dataset, steps=num_test_instances) + """Train the model on training data and evaluate on test data. + + Args: + ---- + working_dir: Directory to read transformed data and metadata from. + output_dir: A directory where the output should be exported to. 
+ num_train_instances: Number of instances in train set + num_test_instances: Number of instances in test set + + Returns: + ------- + The results from the estimator's 'evaluate' method + """ + tf_transform_output = tft.TFTransformOutput(working_dir) + train_data_path_pattern = os.path.join( + working_dir, TRANSFORMED_TRAIN_DATA_FILEBASE + "*" + ) + test_data_path_pattern = os.path.join( + working_dir, TRANSFORMED_TEST_DATA_FILEBASE + "*" + ) + + train_dataset = _input_fn( + tf_transform_output, train_data_path_pattern, batch_size=TRAIN_BATCH_SIZE + ) + validation_dataset = _input_fn( + tf_transform_output, test_data_path_pattern, batch_size=TRAIN_BATCH_SIZE + ) + + feature_spec = tf_transform_output.transformed_feature_spec().copy() + feature_spec.pop(LABEL_KEY) + + review_input = tf_keras.layers.Input( + shape=[None], name=REVIEW_KEY, dtype=tf.int64, sparse=True + ) + review_weight_input = tf_keras.layers.Input( + shape=[None], name=REVIEW_WEIGHT_KEY, dtype=tf.float32, sparse=True + ) + count_layer = tf.keras.layers.CategoryEncoding( + num_tokens=VOCAB_SIZE + 1, output_mode="count" + ) + embedding_layer = tf.keras.layers.Dense(4, use_bias=False) + embedding = embedding_layer( + count_layer(review_input, count_weights=review_weight_input) + ) + output = tf_keras.layers.Dense(100, activation="relu")(embedding) + output = tf_keras.layers.Dense(70, activation="relu")(output) + output = tf_keras.layers.Dense(50, activation="relu")(output) + output = tf_keras.layers.Dense(20, activation="relu")(output) + output = tf_keras.layers.Dense(2, activation="sigmoid")(output) + model = tf_keras.Model( + inputs={ + REVIEW_KEY: review_input, + REVIEW_WEIGHT_KEY: review_weight_input, + }, + outputs=output, + ) + + model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]) + logging.info(model.summary()) + + model.fit( + train_dataset, + validation_data=validation_dataset, + epochs=TRAIN_NUM_EPOCHS, + steps_per_epoch=math.ceil(num_train_instances / TRAIN_BATCH_SIZE), + validation_steps=math.ceil(num_test_instances / TRAIN_BATCH_SIZE), + ) + + # Export the model. 
+ export_serving_model(tf_transform_output, model, output_dir) + + return model.evaluate(validation_dataset, steps=num_test_instances) def main(): - parser = argparse.ArgumentParser() - parser.add_argument('input_data_dir', - help='path to directory containing input data') - parser.add_argument('--working_dir', - help='path to directory to hold transformed data') - args = parser.parse_args() - - if args.working_dir: - working_dir = args.working_dir - else: - working_dir = tempfile.mkdtemp(dir=args.input_data_dir) - - train_neg_filepattern = os.path.join(args.input_data_dir, 'train/neg/*') - train_pos_filepattern = os.path.join(args.input_data_dir, 'train/pos/*') - test_neg_filepattern = os.path.join(args.input_data_dir, 'test/neg/*') - test_pos_filepattern = os.path.join(args.input_data_dir, 'test/pos/*') - - read_and_shuffle_data(train_neg_filepattern, train_pos_filepattern, - test_neg_filepattern, test_pos_filepattern, - working_dir) - transform_data(working_dir) - exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR) - results = train_and_evaluate(working_dir, exported_model_dir) - - pprint.pprint(results) - - -if __name__ == '__main__': - main() + parser = argparse.ArgumentParser() + parser.add_argument( + "input_data_dir", help="path to directory containing input data" + ) + parser.add_argument( + "--working_dir", help="path to directory to hold transformed data" + ) + args = parser.parse_args() + + if args.working_dir: + working_dir = args.working_dir + else: + working_dir = tempfile.mkdtemp(dir=args.input_data_dir) + + train_neg_filepattern = os.path.join(args.input_data_dir, "train/neg/*") + train_pos_filepattern = os.path.join(args.input_data_dir, "train/pos/*") + test_neg_filepattern = os.path.join(args.input_data_dir, "test/neg/*") + test_pos_filepattern = os.path.join(args.input_data_dir, "test/pos/*") + + read_and_shuffle_data( + train_neg_filepattern, + train_pos_filepattern, + test_neg_filepattern, + test_pos_filepattern, + working_dir, + ) + transform_data(working_dir) + exported_model_dir = os.path.join(working_dir, EXPORTED_MODEL_DIR) + results = train_and_evaluate(working_dir, exported_model_dir) + + pprint.pprint(results) + + +if __name__ == "__main__": + main() diff --git a/examples/sentiment_example_v2_test.py b/examples/sentiment_example_v2_test.py index cd0dcd8..816a56d 100644 --- a/examples/sentiment_example_v2_test.py +++ b/examples/sentiment_example_v2_test.py @@ -16,73 +16,76 @@ import os import shutil +import local_model_server +import sentiment_example_v2 import tensorflow as tf + import tensorflow_transform as tft -import sentiment_example_v2 from tensorflow_transform import test_case -import local_model_server class SentimentExampleTest(test_case.TransformTestCase): + def testSentimentExampleAccuracy(self): + raw_data_dir = os.path.join(os.path.dirname(__file__), "testdata/sentiment") + working_dir = self.get_temp_dir() - def testSentimentExampleAccuracy(self): - raw_data_dir = os.path.join(os.path.dirname(__file__), 'testdata/sentiment') - working_dir = self.get_temp_dir() - - # Copy data from raw data directory to `working_dir` - try: - for filename in ['test_shuffled-00000-of-00001', - 'train_shuffled-00000-of-00001']: - shutil.copy(os.path.join(raw_data_dir, filename), working_dir) - except FileNotFoundError: - # We only use a small sample of the data for testing purposes. 
- train_neg_filepattern = os.path.join(raw_data_dir, 'train/neg/10000*') - train_pos_filepattern = os.path.join(raw_data_dir, 'train/pos/10000*') - test_neg_filepattern = os.path.join(raw_data_dir, 'test/neg/10000*') - test_pos_filepattern = os.path.join(raw_data_dir, 'test/pos/10000*') + # Copy data from raw data directory to `working_dir` + try: + for filename in [ + "test_shuffled-00000-of-00001", + "train_shuffled-00000-of-00001", + ]: + shutil.copy(os.path.join(raw_data_dir, filename), working_dir) + except FileNotFoundError: + # We only use a small sample of the data for testing purposes. + train_neg_filepattern = os.path.join(raw_data_dir, "train/neg/10000*") + train_pos_filepattern = os.path.join(raw_data_dir, "train/pos/10000*") + test_neg_filepattern = os.path.join(raw_data_dir, "test/neg/10000*") + test_pos_filepattern = os.path.join(raw_data_dir, "test/pos/10000*") - # Writes the shuffled data under working_dir in TFRecord format. - sentiment_example_v2.read_and_shuffle_data( - train_neg_filepattern, - train_pos_filepattern, - test_neg_filepattern, - test_pos_filepattern, - working_dir, - ) + # Writes the shuffled data under working_dir in TFRecord format. + sentiment_example_v2.read_and_shuffle_data( + train_neg_filepattern, + train_pos_filepattern, + test_neg_filepattern, + test_pos_filepattern, + working_dir, + ) - sentiment_example_v2.transform_data(working_dir) - # TODO: b/323209255 - Remove this if clause once TF pulls the latest keras - # nightly version. - if not test_case.is_external_environment(): - model_path = os.path.join( - working_dir, sentiment_example_v2.EXPORTED_MODEL_DIR - ) - results = sentiment_example_v2.train_and_evaluate( - working_dir, - model_path, - num_train_instances=1000, - num_test_instances=1000, - ) - if not test_case.is_external_environment(): - # Assert expected accuracy. - self.assertGreaterEqual(results[1], 0.7) + sentiment_example_v2.transform_data(working_dir) + # TODO: b/323209255 - Remove this if clause once TF pulls the latest keras + # nightly version. + if not test_case.is_external_environment(): + model_path = os.path.join( + working_dir, sentiment_example_v2.EXPORTED_MODEL_DIR + ) + results = sentiment_example_v2.train_and_evaluate( + working_dir, + model_path, + num_train_instances=1000, + num_test_instances=1000, + ) + if not test_case.is_external_environment(): + # Assert expected accuracy. + self.assertGreaterEqual(results[1], 0.7) - # Delete temp directory and transform_fn directory. This ensures that the - # test of serving the model below will only pass if the SavedModel saved - # to sentiment_example_v2.EXPORTED_MODEL_DIR is hermetic, i.e does not - # contain references to tft_temp and transform_fn. - shutil.rmtree( - os.path.join(working_dir, sentiment_example_v2.TRANSFORM_TEMP_DIR) - ) - shutil.rmtree( - os.path.join(working_dir, tft.TFTransformOutput.TRANSFORM_FN_DIR)) + # Delete temp directory and transform_fn directory. This ensures that the + # test of serving the model below will only pass if the SavedModel saved + # to sentiment_example_v2.EXPORTED_MODEL_DIR is hermetic, i.e does not + # contain references to tft_temp and transform_fn. 
+ shutil.rmtree( + os.path.join(working_dir, sentiment_example_v2.TRANSFORM_TEMP_DIR) + ) + shutil.rmtree( + os.path.join(working_dir, tft.TFTransformOutput.TRANSFORM_FN_DIR) + ) - if local_model_server.local_model_server_supported(): - model_name = 'my_model' - with local_model_server.start_server(model_name, model_path) as address: - # Use made up data chosen to give high probability of negative - # sentiment. - ascii_classification_request = """model_spec { name: "my_model" } + if local_model_server.local_model_server_supported(): + model_name = "my_model" + with local_model_server.start_server(model_name, model_path) as address: + # Use made up data chosen to give high probability of negative + # sentiment. + ascii_classification_request = """model_spec { name: "my_model" } input { example_list { examples { @@ -99,15 +102,16 @@ def testSentimentExampleAccuracy(self): } } }""" - results = local_model_server.make_classification_request( - address, ascii_classification_request) - self.assertLen(results, 1) - self.assertLen(results[0].classes, 2) - self.assertEqual(results[0].classes[0].label, '0') - self.assertGreater(results[0].classes[0].score, 0.8) - self.assertEqual(results[0].classes[1].label, '1') - self.assertLess(results[0].classes[1].score, 0.2) + results = local_model_server.make_classification_request( + address, ascii_classification_request + ) + self.assertLen(results, 1) + self.assertLen(results[0].classes, 2) + self.assertEqual(results[0].classes[0].label, "0") + self.assertGreater(results[0].classes[0].score, 0.8) + self.assertEqual(results[0].classes[1].label, "1") + self.assertLess(results[0].classes[1].score, 0.2) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/examples/simple_example.py b/examples/simple_example.py index 39c9e58..138c029 100644 --- a/examples/simple_example.py +++ b/examples/simple_example.py @@ -17,57 +17,53 @@ import tempfile import tensorflow as tf + import tensorflow_transform as tft import tensorflow_transform.beam as tft_beam -_RAW_DATA_METADATA = tft.DatasetMetadata.from_feature_spec({ - 's': tf.io.FixedLenFeature([], tf.string), - 'y': tf.io.FixedLenFeature([], tf.float32), - 'x': tf.io.FixedLenFeature([], tf.float32), -}) +_RAW_DATA_METADATA = tft.DatasetMetadata.from_feature_spec( + { + "s": tf.io.FixedLenFeature([], tf.string), + "y": tf.io.FixedLenFeature([], tf.float32), + "x": tf.io.FixedLenFeature([], tf.float32), + } +) -_RAW_DATA = [{ - 'x': 1, - 'y': 1, - 's': 'hello' -}, { - 'x': 2, - 'y': 2, - 's': 'world' -}, { - 'x': 3, - 'y': 3, - 's': 'hello' -}] +_RAW_DATA = [ + {"x": 1, "y": 1, "s": "hello"}, + {"x": 2, "y": 2, "s": "world"}, + {"x": 3, "y": 3, "s": "hello"}, +] def _preprocessing_fn(inputs): - """Preprocess input columns into transformed columns.""" - x = inputs['x'] - y = inputs['y'] - s = inputs['s'] - x_centered = x - tft.mean(x) - y_normalized = tft.scale_to_0_1(y) - s_integerized = tft.compute_and_apply_vocabulary(s) - x_centered_times_y_normalized = (x_centered * y_normalized) - return { - 'x_centered': x_centered, - 'y_normalized': y_normalized, - 'x_centered_times_y_normalized': x_centered_times_y_normalized, - 's_integerized': s_integerized - } + """Preprocess input columns into transformed columns.""" + x = inputs["x"] + y = inputs["y"] + s = inputs["s"] + x_centered = x - tft.mean(x) + y_normalized = tft.scale_to_0_1(y) + s_integerized = tft.compute_and_apply_vocabulary(s) + x_centered_times_y_normalized = x_centered * y_normalized + return { + "x_centered": 
x_centered, + "y_normalized": y_normalized, + "x_centered_times_y_normalized": x_centered_times_y_normalized, + "s_integerized": s_integerized, + } def main(): + with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + transformed_dataset, transform_fn = ( # pylint: disable=unused-variable + (_RAW_DATA, _RAW_DATA_METADATA) + | tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn) + ) - with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - transformed_dataset, transform_fn = ( # pylint: disable=unused-variable - (_RAW_DATA, _RAW_DATA_METADATA) - | tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn)) + transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable - transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable + pprint.pprint(transformed_data) - pprint.pprint(transformed_data) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/examples/simple_example_test.py b/examples/simple_example_test.py index 90f230c..d71944f 100644 --- a/examples/simple_example_test.py +++ b/examples/simple_example_test.py @@ -13,47 +13,47 @@ # limitations under the License. """Tests for simple_example.""" -import tensorflow as tf import simple_example -from tensorflow_transform.beam import tft_unit +import tensorflow as tf +from tensorflow_transform.beam import tft_unit _EXPECTED_TRANSFORMED_OUTPUT = [ { - 'x_centered': 1.0, - 'y_normalized': 1.0, - 'x_centered_times_y_normalized': 1.0, - 's_integerized': 0, + "x_centered": 1.0, + "y_normalized": 1.0, + "x_centered_times_y_normalized": 1.0, + "s_integerized": 0, }, { - 'x_centered': 0.0, - 'y_normalized': 0.5, - 'x_centered_times_y_normalized': 0.0, - 's_integerized': 1, + "x_centered": 0.0, + "y_normalized": 0.5, + "x_centered_times_y_normalized": 0.0, + "s_integerized": 1, }, { - 'x_centered': -1.0, - 'y_normalized': 0.0, - 'x_centered_times_y_normalized': -0.0, - 's_integerized': 0, + "x_centered": -1.0, + "y_normalized": 0.0, + "x_centered_times_y_normalized": -0.0, + "s_integerized": 0, }, ] class SimpleExampleTest(tft_unit.TransformTestCase): - - def test_preprocessing_fn(self): - self.assertAnalyzeAndTransformResults(simple_example._RAW_DATA, - simple_example._RAW_DATA_METADATA, - simple_example._preprocessing_fn, - _EXPECTED_TRANSFORMED_OUTPUT) + def test_preprocessing_fn(self): + self.assertAnalyzeAndTransformResults( + simple_example._RAW_DATA, + simple_example._RAW_DATA_METADATA, + simple_example._preprocessing_fn, + _EXPECTED_TRANSFORMED_OUTPUT, + ) class SimpleMainTest(tf.test.TestCase): - - def testMainDoesNotCrash(self): - simple_example.main() + def testMainDoesNotCrash(self): + simple_example.main() -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/examples/simple_sequence_example.py b/examples/simple_sequence_example.py index 5148eb4..1966ea0 100644 --- a/examples/simple_sequence_example.py +++ b/examples/simple_sequence_example.py @@ -16,16 +16,16 @@ import os import tempfile -from absl import logging import apache_beam as beam -import tensorflow_transform as tft -import tensorflow_transform.beam as tft_beam +from absl import logging +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 from tfx_bsl.public import tfxio -from tensorflow_metadata.proto.v0 import schema_pb2 -from google.protobuf import text_format +import tensorflow_transform as tft +import tensorflow_transform.beam as tft_beam -_TRANSFORM_TEMP_DIR = 'tft_temp' 
+_TRANSFORM_TEMP_DIR = "tft_temp" _SCHEMA = text_format.Parse( """ feature { @@ -89,149 +89,160 @@ } } } -""", schema_pb2.Schema()) +""", + schema_pb2.Schema(), +) -_TELEMETRY_DESCRIPTORS = ['TFT', 'SequenceExample'] +_TELEMETRY_DESCRIPTORS = ["TFT", "SequenceExample"] def _print_record_batch(data): - logging.info(data.to_pydict()) + logging.info(data.to_pydict()) def _make_tfxio(schema): - """Creates TFXIO for SequenceExample. - - Args: - schema: A TFMD Schema describing the dataset. - - Returns: - TFSequenceExampleRecord TFXIO Instance. - - The data_tfrecord.gz file holds Serialized SequenceExample as below: - context { - feature { key: "int_feature" value { int64_list { value: [0] } } } - feature { - key: "float_feature" - value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } - } - } - feature_lists { - feature_list { - key: "int_feature" - value { - feature { int64_list { value: [1, 2] } } - feature { int64_list { value: [3, 4] } } + """Creates TFXIO for SequenceExample. + + Args: + ---- + schema: A TFMD Schema describing the dataset. + + Returns: + ------- + TFSequenceExampleRecord TFXIO Instance. + + The data_tfrecord.gz file holds Serialized SequenceExample as below: + context { + feature { key: "int_feature" value { int64_list { value: [0] } } } + feature { + key: "float_feature" + value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } } - feature_list { - key: "string_feature" - value { - feature { bytes_list { value: ["Hello", "World"] } } - feature { bytes_list { value: [] } } + feature_lists { + feature_list { + key: "int_feature" + value { + feature { int64_list { value: [1, 2] } } + feature { int64_list { value: [3, 4] } } + } + } + feature_list { + key: "string_feature" + value { + feature { bytes_list { value: ["Hello", "World"] } } + feature { bytes_list { value: [] } } + } } } - } - """ - sequence_example_file = os.path.join( - os.path.dirname(__file__), 'testdata/sequence_example/data_tfrecord.gz') - return tfxio.TFSequenceExampleRecord( - sequence_example_file, - schema=schema, - telemetry_descriptors=_TELEMETRY_DESCRIPTORS) + """ + sequence_example_file = os.path.join( + os.path.dirname(__file__), "testdata/sequence_example/data_tfrecord.gz" + ) + return tfxio.TFSequenceExampleRecord( + sequence_example_file, + schema=schema, + telemetry_descriptors=_TELEMETRY_DESCRIPTORS, + ) def _preprocessing_fn(inputs): - """Preprocess input columns into transformed columns. - - Args: - inputs: Input Tensors. - - Returns: - Dictionary of respective transformed inputs - - Example: - `int_features`: tft.scale_to_0_1(...) - Input: [[[0]], [[1]], [[2]]] - Output: [[[0]], [[0.5]], [[1]]] - - `float_features`: tft.scale_to_0_1(.., elementwise = True) - Input: [ - [[1.0, 2.0, 3.0, 4.0]], - [[2.0, 3.0, 4.0, 5.0]], - [[3.0, 4.0, 0.0, 0.0]] - ] - Output: [ - [[0.0, 0.0, 0.75, 0.8]], - [[0.5, 0.5, 1.0, 1.0]], - [[1.0, 1.0, 0.0, 0.0]] - ] - - `seq_int_feature`: tft.scale_by_min_max(...) - Input: [ - [ [1, 2], [3, 4]], - [ [5, 6], [7, 8]], - [[9, 10], [11, 12]] - ] - Output: [ - [[ 0.0, 0.0909], [0.1818, 0.2727]], - [[0.3636, 0.4545], [0.5454, 0.6363]], - [[0.7272, 0.8181], [0.9090, 1.0]] - ] - - `seq_string_feature`: tft.compute_and_apply_vocabulary(...) 
- Input: [ - [[ b'Hello', b'World'], []], - [[ b'foo', b'bar'], []], - [[b'tensor', b'flow'], []] - ] - Output: [ - [[[5, 4], []]], - [[[1, 3], []]], - [[[0, 2], []]] - ] - """ - return { - 'transformed_seq_int_feature': - tft.scale_by_min_max(inputs['seq_int_feature']), - 'transformed_seq_string_feature': - tft.compute_and_apply_vocabulary(inputs['seq_string_feature']), - 'transformed_float_feature': - tft.scale_to_0_1(inputs['float_feature'], elementwise=True), - 'transformed_int_feature': - tft.scale_to_0_1(inputs['int_feature']), - } + """Preprocess input columns into transformed columns. + + Args: + ---- + inputs: Input Tensors. + + Returns: + ------- + Dictionary of respective transformed inputs + + Example: + ------- + `int_features`: tft.scale_to_0_1(...) + Input: [[[0]], [[1]], [[2]]] + Output: [[[0]], [[0.5]], [[1]]] + + `float_features`: tft.scale_to_0_1(.., elementwise = True) + Input: [ + [[1.0, 2.0, 3.0, 4.0]], + [[2.0, 3.0, 4.0, 5.0]], + [[3.0, 4.0, 0.0, 0.0]] + ] + Output: [ + [[0.0, 0.0, 0.75, 0.8]], + [[0.5, 0.5, 1.0, 1.0]], + [[1.0, 1.0, 0.0, 0.0]] + ] + + `seq_int_feature`: tft.scale_by_min_max(...) + Input: [ + [ [1, 2], [3, 4]], + [ [5, 6], [7, 8]], + [[9, 10], [11, 12]] + ] + Output: [ + [[ 0.0, 0.0909], [0.1818, 0.2727]], + [[0.3636, 0.4545], [0.5454, 0.6363]], + [[0.7272, 0.8181], [0.9090, 1.0]] + ] + + `seq_string_feature`: tft.compute_and_apply_vocabulary(...) + Input: [ + [[ b'Hello', b'World'], []], + [[ b'foo', b'bar'], []], + [[b'tensor', b'flow'], []] + ] + Output: [ + [[[5, 4], []]], + [[[1, 3], []]], + [[[0, 2], []]] + ] + """ + return { + "transformed_seq_int_feature": tft.scale_by_min_max(inputs["seq_int_feature"]), + "transformed_seq_string_feature": tft.compute_and_apply_vocabulary( + inputs["seq_string_feature"] + ), + "transformed_float_feature": tft.scale_to_0_1( + inputs["float_feature"], elementwise=True + ), + "transformed_int_feature": tft.scale_to_0_1(inputs["int_feature"]), + } def _transform_data(sequence_example_tfxio): - """Transform the data and output transformed values. - - Args: - sequence_example_tfxio: tfxio.TFSequenceExampleRecord Object - """ - - with beam.Pipeline() as pipeline: - with tft_beam.Context( - temp_dir=os.path.join(tempfile.mkdtemp(), _TRANSFORM_TEMP_DIR)): + """Transform the data and output transformed values. - raw_data = pipeline | 'ReadAndDecode' >> sequence_example_tfxio.BeamSource( - ) - _ = raw_data | 'PrintInputData' >> beam.Map(_print_record_batch) - - (transformed_data, - _), _ = ((raw_data, sequence_example_tfxio.TensorAdapterConfig()) - | 'AnalyzeAndTransform' >> tft_beam.AnalyzeAndTransformDataset( - _preprocessing_fn, output_record_batches=True)) - - # Drop empty pass-through features dictionary that is not relevant - # for this example. 
- transformed_data = transformed_data | 'ExtractRecordBatch' >> beam.Keys() - _ = transformed_data | 'PrintTransformedData' >> beam.Map( - _print_record_batch) + Args: + ---- + sequence_example_tfxio: tfxio.TFSequenceExampleRecord Object + """ + with beam.Pipeline() as pipeline: + with tft_beam.Context( + temp_dir=os.path.join(tempfile.mkdtemp(), _TRANSFORM_TEMP_DIR) + ): + raw_data = pipeline | "ReadAndDecode" >> sequence_example_tfxio.BeamSource() + _ = raw_data | "PrintInputData" >> beam.Map(_print_record_batch) + + (transformed_data, _), _ = ( + raw_data, + sequence_example_tfxio.TensorAdapterConfig(), + ) | "AnalyzeAndTransform" >> tft_beam.AnalyzeAndTransformDataset( + _preprocessing_fn, output_record_batches=True + ) + + # Drop empty pass-through features dictionary that is not relevant + # for this example. + transformed_data = transformed_data | "ExtractRecordBatch" >> beam.Keys() + _ = transformed_data | "PrintTransformedData" >> beam.Map( + _print_record_batch + ) def main(): - _transform_data(_make_tfxio(_SCHEMA)) + _transform_data(_make_tfxio(_SCHEMA)) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/examples/simple_sequence_example_test.py b/examples/simple_sequence_example_test.py index 41ea983..204c8d2 100644 --- a/examples/simple_sequence_example_test.py +++ b/examples/simple_sequence_example_test.py @@ -13,59 +13,72 @@ # limitations under the License. """Tests for simple_example.""" -import tensorflow as tf import simple_sequence_example +import tensorflow as tf + from tensorflow_transform.beam import tft_unit -_EXPECTED_TRANSFORMED_OUTPUT = [{ - 'transformed_seq_int_feature$ragged_values': [ - 0.0, 0.09090909, 0.18181818, 0.27272727 - ], - 'transformed_seq_int_feature$row_lengths_1': [2, 2], - 'transformed_seq_string_feature$ragged_values': [5, 4], - 'transformed_seq_string_feature$row_lengths_1': [2, 0], - 'transformed_float_feature': [0.0, 0.0, 0.75, 0.8], - 'transformed_int_feature': [0], -}, { - 'transformed_seq_int_feature$ragged_values': [ - 0.36363636, 0.45454545, 0.54545454, 0.63636363 - ], - 'transformed_seq_int_feature$row_lengths_1': [2, 2], - 'transformed_seq_string_feature$ragged_values': [1, 3], - 'transformed_seq_string_feature$row_lengths_1': [2, 0], - 'transformed_float_feature': [0.5, 0.5, 1.0, 1.0], - 'transformed_int_feature': [0.5], -}, { - 'transformed_seq_int_feature$ragged_values': [ - 0.72727272, 0.81818181, 0.90909090, 1.0 - ], - 'transformed_seq_int_feature$row_lengths_1': [2, 2], - 'transformed_seq_string_feature$ragged_values': [0, 2], - 'transformed_seq_string_feature$row_lengths_1': [2, 0], - 'transformed_float_feature': [1.0, 1.0, 0.0, 0.0], - 'transformed_int_feature': [1], -}] +_EXPECTED_TRANSFORMED_OUTPUT = [ + { + "transformed_seq_int_feature$ragged_values": [ + 0.0, + 0.09090909, + 0.18181818, + 0.27272727, + ], + "transformed_seq_int_feature$row_lengths_1": [2, 2], + "transformed_seq_string_feature$ragged_values": [5, 4], + "transformed_seq_string_feature$row_lengths_1": [2, 0], + "transformed_float_feature": [0.0, 0.0, 0.75, 0.8], + "transformed_int_feature": [0], + }, + { + "transformed_seq_int_feature$ragged_values": [ + 0.36363636, + 0.45454545, + 0.54545454, + 0.63636363, + ], + "transformed_seq_int_feature$row_lengths_1": [2, 2], + "transformed_seq_string_feature$ragged_values": [1, 3], + "transformed_seq_string_feature$row_lengths_1": [2, 0], + "transformed_float_feature": [0.5, 0.5, 1.0, 1.0], + "transformed_int_feature": [0.5], + }, + { + "transformed_seq_int_feature$ragged_values": [ + 
0.72727272, + 0.81818181, + 0.90909090, + 1.0, + ], + "transformed_seq_int_feature$row_lengths_1": [2, 2], + "transformed_seq_string_feature$ragged_values": [0, 2], + "transformed_seq_string_feature$row_lengths_1": [2, 0], + "transformed_float_feature": [1.0, 1.0, 0.0, 0.0], + "transformed_int_feature": [1], + }, +] class SimpleMainTest(tf.test.TestCase): - - def testMainDoesNotCrash(self): - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - simple_sequence_example.main() + def testMainDoesNotCrash(self): + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + simple_sequence_example.main() class SimpleSequenceExampleTest(tft_unit.TransformTestCase): - - def testPreprocessingFn(self): - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - tfxio = simple_sequence_example._make_tfxio(simple_sequence_example._SCHEMA) - self.assertAnalyzeAndTransformResults( - tfxio.BeamSource(), - tfxio.TensorAdapterConfig(), - simple_sequence_example._preprocessing_fn, - output_record_batches=True, - expected_data=_EXPECTED_TRANSFORMED_OUTPUT) + def testPreprocessingFn(self): + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + tfxio = simple_sequence_example._make_tfxio(simple_sequence_example._SCHEMA) + self.assertAnalyzeAndTransformResults( + tfxio.BeamSource(), + tfxio.TensorAdapterConfig(), + simple_sequence_example._preprocessing_fn, + output_record_batches=True, + expected_data=_EXPECTED_TRANSFORMED_OUTPUT, + ) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/setup.py b/setup.py index fa75232..35066fb 100644 --- a/setup.py +++ b/setup.py @@ -12,104 +12,105 @@ # See the License for the specific language governing permissions and # limitations under the License. """Package Setup script for tf.Transform.""" + import os -from setuptools import find_packages -from setuptools import setup +from setuptools import find_packages, setup def select_constraint(default, nightly=None, git_master=None): - """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" - selector = os.environ.get('TFX_DEPENDENCY_SELECTOR') - if selector == 'UNCONSTRAINED': - return '' - elif selector == 'NIGHTLY' and nightly is not None: - return nightly - elif selector == 'GIT_MASTER' and git_master is not None: - return git_master - else: - return default + """Select dependency constraint based on TFX_DEPENDENCY_SELECTOR env var.""" + selector = os.environ.get("TFX_DEPENDENCY_SELECTOR") + if selector == "UNCONSTRAINED": + return "" + elif selector == "NIGHTLY" and nightly is not None: + return nightly + elif selector == "GIT_MASTER" and git_master is not None: + return git_master + else: + return default # Get version from version module. -with open('tensorflow_transform/version.py') as fp: - globals_dict = {} - exec(fp.read(), globals_dict) # pylint: disable=exec-used -__version__ = globals_dict['__version__'] +with open("tensorflow_transform/version.py") as fp: + globals_dict = {} + exec(fp.read(), globals_dict) # pylint: disable=exec-used +__version__ = globals_dict["__version__"] def _make_required_install_packages(): - # Make sure to sync the versions of common dependencies (absl-py, numpy, and - # protobuf) with TF and pyarrow version with tfx-bsl. 
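# Illustrative aside, not part of the diff above: a minimal sketch of how the
# `select_constraint` helper shown earlier in this setup.py resolves a
# dependency pin from the TFX_DEPENDENCY_SELECTOR environment variable. The
# constraint strings and selector values below are examples only, and the
# sketch assumes `select_constraint` from this file is in scope.
import os

os.environ["TFX_DEPENDENCY_SELECTOR"] = "NIGHTLY"
assert (
    select_constraint(default=">=1.16.1,<1.17.0", nightly=">=1.17.0.dev")
    == ">=1.17.0.dev"
)

os.environ["TFX_DEPENDENCY_SELECTOR"] = "UNCONSTRAINED"
assert select_constraint(default=">=1.16.1,<1.17.0") == ""  # left unpinned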
- return [ - 'absl-py>=0.9,<2.0.0', - 'apache-beam[gcp]>=2.53,<3;python_version>="3.11"', - 'apache-beam[gcp]>=2.47,<3;python_version<"3.11"', - 'numpy>=1.22.0', - 'protobuf>=4.25.2,<6;python_version>="3.11"', - 'protobuf>=3.20.3,<5;python_version<"3.11"', - 'pyarrow>=10,<11', - 'pydot>=1.2,<2', - 'tensorflow>=2.17,<2.18', - 'tensorflow-metadata' - + select_constraint( - default='>=1.16.1,<1.17.0', - nightly='>=1.17.0.dev', - git_master='@git+https://github.com/tensorflow/metadata@master', - ), - 'tf_keras>=2', - 'tfx-bsl' - + select_constraint( - default='>=1.16.1,<1.17.0', - nightly='>=1.17.0.dev', - git_master='@git+https://github.com/tensorflow/tfx-bsl@master', - ), - ] + # Make sure to sync the versions of common dependencies (absl-py, numpy, and + # protobuf) with TF and pyarrow version with tfx-bsl. + return [ + "absl-py>=0.9,<2.0.0", + 'apache-beam[gcp]>=2.53,<3;python_version>="3.11"', + 'apache-beam[gcp]>=2.47,<3;python_version<"3.11"', + "numpy>=1.22.0", + 'protobuf>=4.25.2,<6;python_version>="3.11"', + 'protobuf>=3.20.3,<5;python_version<"3.11"', + "pyarrow>=10,<11", + "pydot>=1.2,<2", + "tensorflow>=2.17,<2.18", + "tensorflow-metadata" + + select_constraint( + default=">=1.16.1,<1.17.0", + nightly=">=1.17.0.dev", + git_master="@git+https://github.com/tensorflow/metadata@master", + ), + "tf_keras>=2", + "tfx-bsl" + + select_constraint( + default=">=1.16.1,<1.17.0", + nightly=">=1.17.0.dev", + git_master="@git+https://github.com/tensorflow/tfx-bsl@master", + ), + ] # Get the long description from the README file. -with open('README.md') as fp: - _LONG_DESCRIPTION = fp.read() +with open("README.md") as fp: + _LONG_DESCRIPTION = fp.read() setup( - name='tensorflow-transform', + name="tensorflow-transform", version=__version__, - author='Google Inc.', - author_email='tensorflow-extended-dev@googlegroups.com', - license='Apache 2.0', + author="Google Inc.", + author_email="tensorflow-extended-dev@googlegroups.com", + license="Apache 2.0", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'Intended Audience :: Education', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', - 'Programming Language :: Python :: 3 :: Only', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Software Development", + "Topic :: Software Development 
:: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", ], namespace_packages=[], install_requires=_make_required_install_packages(), extras_require={ - "dev": ["pre-commit"], + "dev": ["pre-commit"], }, - python_requires='>=3.9,<4', + python_requires=">=3.9,<4", packages=find_packages(), include_package_data=True, - package_data={'tensorflow_transform': ['py.typed']}, - description='A library for data preprocessing with TensorFlow', + package_data={"tensorflow_transform": ["py.typed"]}, + description="A library for data preprocessing with TensorFlow", long_description=_LONG_DESCRIPTION, - long_description_content_type='text/markdown', - keywords='tensorflow transform tfx', - url='https://www.tensorflow.org/tfx/transform/get_started', - download_url='https://github.com/tensorflow/transform/tags', - requires=[]) + long_description_content_type="text/markdown", + keywords="tensorflow transform tfx", + url="https://www.tensorflow.org/tfx/transform/get_started", + download_url="https://github.com/tensorflow/transform/tags", + requires=[], +) diff --git a/tensorflow_transform/__init__.py b/tensorflow_transform/__init__.py index deb53eb..0c24bb2 100644 --- a/tensorflow_transform/__init__.py +++ b/tensorflow_transform/__init__.py @@ -14,18 +14,19 @@ """Init module for TF.Transform.""" # pylint: disable=wildcard-import -from tensorflow_transform import coders -from tensorflow_transform import experimental +from tensorflow_transform import coders, experimental from tensorflow_transform.analyzers import * from tensorflow_transform.annotators import * from tensorflow_transform.inspect_preprocessing_fn import * from tensorflow_transform.mappers import * -from tensorflow_transform.output_wrapper import TFTransformOutput -from tensorflow_transform.output_wrapper import TransformFeaturesLayer +from tensorflow_transform.output_wrapper import ( + TFTransformOutput, + TransformFeaturesLayer, +) from tensorflow_transform.py_func.api import apply_pyfunc from tensorflow_transform.tf_metadata.dataset_metadata import DatasetMetadata -# pylint: enable=wildcard-import +# pylint: enable=wildcard-import # Import version string. from tensorflow_transform.version import __version__ @@ -33,11 +34,13 @@ # `tensorflow_io` package. Hence, this import is needed wherever we touch the # filesystem. 
try: - import tensorflow_io as _ # pytype: disable=import-error # pylint: disable=g-import-not-at-top + import tensorflow_io as _ # pytype: disable=import-error # pylint: disable=g-import-not-at-top except ModuleNotFoundError: - pass + pass try: - from tensorflow_transform import google # pytype: disable=import-error # pylint: disable=g-import-not-at-top + from tensorflow_transform import ( + google, # pytype: disable=import-error # pylint: disable=g-import-not-at-top + ) except ImportError: - pass + pass diff --git a/tensorflow_transform/analyzer_nodes.py b/tensorflow_transform/analyzer_nodes.py index 8846451..8685648 100644 --- a/tensorflow_transform/analyzer_nodes.py +++ b/tensorflow_transform/analyzer_nodes.py @@ -25,450 +25,473 @@ import json import os import struct -from typing import Any, Optional, Sequence, Type import uuid +from typing import Any, Optional, Sequence, Type import numpy as np import tensorflow as tf -from tensorflow_transform import common_types -from tensorflow_transform import nodes -from tensorflow_transform import tf2_utils -from tensorflow_transform import tf_utils -from tensorflow_transform.graph_context import TFGraphContext + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ops + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.framework import ops +from tensorflow_transform import common_types, nodes, tf2_utils, tf_utils +from tensorflow_transform.graph_context import TFGraphContext + # pylint: disable=g-enable-tensorflow-import # Key for graph collection containing `TensorSink` objects representing TFT # analyzers. -TENSOR_REPLACEMENTS = 'tft_tensor_replacements' +TENSOR_REPLACEMENTS = "tft_tensor_replacements" # Key for graph collection containing `TensorSink` objects representing TFT # analyzers irrespective of whether they have been evaluated or not. -ALL_REPLACEMENTS = 'tft_all_replacements' +ALL_REPLACEMENTS = "tft_all_replacements" def sanitize_label(label: str) -> str: - return label.replace('/', '#') + return label.replace("/", "#") -def _make_label(cls: Type[nodes.OperationDef], - label: Optional[str] = None) -> str: - if label is None: - scope = tf.compat.v1.get_default_graph().get_name_scope() - label = '{}[{}]'.format(cls.__name__, scope) - return sanitize_label(label) +def _make_label(cls: Type[nodes.OperationDef], label: Optional[str] = None) -> str: + if label is None: + scope = tf.compat.v1.get_default_graph().get_name_scope() + label = f"{cls.__name__}[{scope}]" + return sanitize_label(label) -TemporaryAssetInfo = tfx_namedtuple.namedtuple('TemporaryAssetInfo', - ['value', 'file_format']) +TemporaryAssetInfo = tfx_namedtuple.namedtuple( + "TemporaryAssetInfo", ["value", "file_format"] +) class TensorInfo( - tfx_namedtuple.namedtuple('TensorInfo', - ['dtype', 'shape', 'temporary_asset_info'])): - """A container for attributes of output tensors from analyzers. - - Fields: - dtype: The TensorFlow dtype. - shape: The shape of the tensor. - temporary_asset_info: A named tuple containing information about the - temporary asset file to write out while tracing the TF graph. 
- """ - - def __new__( - cls: Type['TensorInfo'], dtype: tf.dtypes.DType, - shape: Optional[Sequence[Optional[int]]], - temporary_asset_info: Optional[TemporaryAssetInfo]) -> 'TensorInfo': - if not isinstance(dtype, tf.DType): - raise TypeError('dtype must be a TensorFlow dtype, got {}'.format(dtype)) - if temporary_asset_info is not None and not isinstance( - temporary_asset_info, TemporaryAssetInfo): - raise TypeError( - 'temporary_asset_info should be an instance of TemporaryAssetInfo or ' - f'None, got {temporary_asset_info}') - return super(TensorInfo, cls).__new__( - cls, - dtype=dtype, - shape=shape, - temporary_asset_info=temporary_asset_info) + tfx_namedtuple.namedtuple("TensorInfo", ["dtype", "shape", "temporary_asset_info"]) +): + """A container for attributes of output tensors from analyzers. + + Fields: + dtype: The TensorFlow dtype. + shape: The shape of the tensor. + temporary_asset_info: A named tuple containing information about the + temporary asset file to write out while tracing the TF graph. + """ + + def __new__( + cls: Type["TensorInfo"], + dtype: tf.dtypes.DType, + shape: Optional[Sequence[Optional[int]]], + temporary_asset_info: Optional[TemporaryAssetInfo], + ) -> "TensorInfo": + if not isinstance(dtype, tf.DType): + raise TypeError(f"dtype must be a TensorFlow dtype, got {dtype}") + if temporary_asset_info is not None and not isinstance( + temporary_asset_info, TemporaryAssetInfo + ): + raise TypeError( + "temporary_asset_info should be an instance of TemporaryAssetInfo or " + f"None, got {temporary_asset_info}" + ) + return super(TensorInfo, cls).__new__( + cls, dtype=dtype, shape=shape, temporary_asset_info=temporary_asset_info + ) class TensorSource( - tfx_namedtuple.namedtuple('TensorSource', ['tensors', 'label']), - nodes.OperationDef): - """An `OperationDef` that defines extracting a tuple of tensor values. - - This `OperationDef` defines an operation that extracts the values of the given - tensors into a PCollection of tuples of values. It is used as a source for - analyzers, which further transform - - This OperationDef accepts zero inputs and return a single output representing - the PCollection of tuples of values. It will be converted in - tensorflow_transform.beam.analysis_graph_builder.build to an operation that - extracts the tensors for a dictionary of tensors, after running a beam.ParDo - to produce tensor values by running the graph on its inputs. - - Fields: - tensors: The tensors whose values should be extracted. - label: A unique label for this operation. - """ - - def __new__(cls, tensors): - for tensor in tensors: - if not isinstance(tensor, tf.Tensor): - raise TypeError('tensor must be a Tensor, got {} of type {}'.format( - tensor, type(tensor))) - return super(TensorSource, cls).__new__( - cls, tensors=tensors, label=_make_label(cls)) + tfx_namedtuple.namedtuple("TensorSource", ["tensors", "label"]), nodes.OperationDef +): + """An `OperationDef` that defines extracting a tuple of tensor values. + This `OperationDef` defines an operation that extracts the values of the given + tensors into a PCollection of tuples of values. It is used as a source for + analyzers, which further transform -def get_input_tensors_value_nodes(tensor_inputs): - return nodes.apply_operation(TensorSource, tensors=tensor_inputs) + This OperationDef accepts zero inputs and return a single output representing + the PCollection of tuples of values. 
It will be converted in + tensorflow_transform.beam.analysis_graph_builder.build to an operation that + extracts the tensors for a dictionary of tensors, after running a beam.ParDo + to produce tensor values by running the graph on its inputs. + Fields: + tensors: The tensors whose values should be extracted. + label: A unique label for this operation. + """ + + def __new__(cls, tensors): + for tensor in tensors: + if not isinstance(tensor, tf.Tensor): + raise TypeError( + f"tensor must be a Tensor, got {tensor} of type {type(tensor)}" + ) + return super(TensorSource, cls).__new__( + cls, tensors=tensors, label=_make_label(cls) + ) -TensorSink = tfx_namedtuple.namedtuple( - 'TensorSink', ['tensor', 'future', 'is_asset_filepath']) + +def get_input_tensors_value_nodes(tensor_inputs): + return nodes.apply_operation(TensorSource, tensors=tensor_inputs) -def _bind_future_as_tensor_v1(future: nodes.ValueNode, - tensor_info: TensorInfo, - name: Optional[str] = None) -> tf.Tensor: - """Bind a future value as a tensor to a TF1 graph.""" - result = tf.compat.v1.placeholder(tensor_info.dtype, tensor_info.shape, name) - is_asset_filepath = tensor_info.temporary_asset_info is not None - tf.compat.v1.add_to_collection(TENSOR_REPLACEMENTS, - TensorSink(result, future, is_asset_filepath)) - return result +TensorSink = tfx_namedtuple.namedtuple( + "TensorSink", ["tensor", "future", "is_asset_filepath"] +) + + +def _bind_future_as_tensor_v1( + future: nodes.ValueNode, tensor_info: TensorInfo, name: Optional[str] = None +) -> tf.Tensor: + """Bind a future value as a tensor to a TF1 graph.""" + result = tf.compat.v1.placeholder(tensor_info.dtype, tensor_info.shape, name) + is_asset_filepath = tensor_info.temporary_asset_info is not None + tf.compat.v1.add_to_collection( + TENSOR_REPLACEMENTS, TensorSink(result, future, is_asset_filepath) + ) + return result _TemporaryAnalyzerOutputWrapper = tfx_namedtuple.namedtuple( - '_TemporaryAnalyzerOutputWrapper', ['eager_asset_path', 'graph_tensor']) + "_TemporaryAnalyzerOutputWrapper", ["eager_asset_path", "graph_tensor"] +) def _write_to_temporary_asset_file( - temp_dir: str, temporary_asset_info: TemporaryAssetInfo) -> str: - """Returns path to temporary asset file created during tracing.""" - # TODO(b/170111921): This temporary file should have a unique name to - # avoid namespace collisions between temporary files that contain data - # of different dtypes. - base_filename = uuid.uuid4().hex - if temporary_asset_info.file_format == 'text': - result = os.path.join(temp_dir, base_filename) - with tf.io.gfile.GFile(result, 'w') as f: - f.write(temporary_asset_info.value) - elif temporary_asset_info.file_format == 'tfrecord_gzip': - result = os.path.join(temp_dir, '{}.tfrecord.gz'.format(base_filename)) - with tf.io.TFRecordWriter(result, 'GZIP') as f: - f.write(temporary_asset_info.value) - else: - raise ValueError( - 'File format should be one of \'text\' or \'tfrecord_gzip\'. Received ' - f'{temporary_asset_info.file_format}') - return result + temp_dir: str, temporary_asset_info: TemporaryAssetInfo +) -> str: + """Returns path to temporary asset file created during tracing.""" + # TODO(b/170111921): This temporary file should have a unique name to + # avoid namespace collisions between temporary files that contain data + # of different dtypes. 
+ base_filename = uuid.uuid4().hex + if temporary_asset_info.file_format == "text": + result = os.path.join(temp_dir, base_filename) + with tf.io.gfile.GFile(result, "w") as f: + f.write(temporary_asset_info.value) + elif temporary_asset_info.file_format == "tfrecord_gzip": + result = os.path.join(temp_dir, f"{base_filename}.tfrecord.gz") + with tf.io.TFRecordWriter(result, "GZIP") as f: + f.write(temporary_asset_info.value) + else: + raise ValueError( + "File format should be one of 'text' or 'tfrecord_gzip'. Received " + f"{temporary_asset_info.file_format}" + ) + return result def _get_temporary_analyzer_output( - temp_dir: str, - tensor_info: TensorInfo, - name: Optional[str] = None) -> _TemporaryAnalyzerOutputWrapper: - """Create a temporary graph tensor using attributes in `tensor_info`. - - Args: - temp_dir: Path to a directory to write out any temporary asset files to. - tensor_info: A `TensorInfo` object containing attributes to create the graph - tensor. - name: A string (or None). The created graph tensor uses this name. - - Returns: - A named tuple `_TemporaryAnalyzerOutputWrapper` with: - eager_asset_path: If the analyzer output is an asset file, an eager tensor - pointing to the file path. Else, None. - graph_tensor: The graph tensor representing the analyzer output. - """ - asset = None - with tf.name_scope('temporary_analyzer_output'): - temporary_asset_info = tensor_info.temporary_asset_info - is_asset_filepath = temporary_asset_info is not None - if is_asset_filepath: - # Placeholders cannot be used for assets, if this graph will be serialized - # to a SavedModel, as they will be initialized with the init op. If a - # `temp_dir` is provided, it is assumed that this graph will be - # serialized and a temporary asset file is written out. Else, a - # placeholder is returned. - # TODO(b/149997088): Reduce number of temporary files written out. - if temp_dir: - with tf.init_scope(): - temporary_asset_filepath = _write_to_temporary_asset_file( - temp_dir, temporary_asset_info) - asset = tf.constant(temporary_asset_filepath) - graph_tensor = tf.constant( - temporary_asset_filepath, - dtype=tensor_info.dtype, - shape=tensor_info.shape, - name=name) - else: - graph_tensor = tf.raw_ops.Placeholder( - dtype=tensor_info.dtype, shape=tensor_info.shape, name=name) - else: - # Using a placeholder with no default value causes tracing to fail if - # there is any control flow dependent on a child tensor of this - # placeholder. Hence, provide a temporary default value for it. - # If dtype is string, we want a tensor that contains '0's instead of b'[] - # to allow string to numeric conversion ops to trace successfully. - temporary_dtype = ( - tf.int64 if tensor_info.dtype == tf.string else tensor_info.dtype) - temporary_tensor = tf2_utils.supply_missing_tensor( - 1, tf.TensorShape(tensor_info.shape), temporary_dtype) - if tensor_info.dtype == tf.string: - temporary_tensor = tf.strings.as_string(temporary_tensor) - graph_tensor = tf.raw_ops.PlaceholderWithDefault( - input=temporary_tensor, shape=tensor_info.shape, name=name) - return _TemporaryAnalyzerOutputWrapper(asset, graph_tensor) + temp_dir: str, tensor_info: TensorInfo, name: Optional[str] = None +) -> _TemporaryAnalyzerOutputWrapper: + """Create a temporary graph tensor using attributes in `tensor_info`. + + Args: + ---- + temp_dir: Path to a directory to write out any temporary asset files to. + tensor_info: A `TensorInfo` object containing attributes to create the graph + tensor. + name: A string (or None). 
The created graph tensor uses this name. + + Returns: + ------- + A named tuple `_TemporaryAnalyzerOutputWrapper` with: + eager_asset_path: If the analyzer output is an asset file, an eager tensor + pointing to the file path. Else, None. + graph_tensor: The graph tensor representing the analyzer output. + """ + asset = None + with tf.name_scope("temporary_analyzer_output"): + temporary_asset_info = tensor_info.temporary_asset_info + is_asset_filepath = temporary_asset_info is not None + if is_asset_filepath: + # Placeholders cannot be used for assets, if this graph will be serialized + # to a SavedModel, as they will be initialized with the init op. If a + # `temp_dir` is provided, it is assumed that this graph will be + # serialized and a temporary asset file is written out. Else, a + # placeholder is returned. + # TODO(b/149997088): Reduce number of temporary files written out. + if temp_dir: + with tf.init_scope(): + temporary_asset_filepath = _write_to_temporary_asset_file( + temp_dir, temporary_asset_info + ) + asset = tf.constant(temporary_asset_filepath) + graph_tensor = tf.constant( + temporary_asset_filepath, + dtype=tensor_info.dtype, + shape=tensor_info.shape, + name=name, + ) + else: + graph_tensor = tf.raw_ops.Placeholder( + dtype=tensor_info.dtype, shape=tensor_info.shape, name=name + ) + else: + # Using a placeholder with no default value causes tracing to fail if + # there is any control flow dependent on a child tensor of this + # placeholder. Hence, provide a temporary default value for it. + # If dtype is string, we want a tensor that contains '0's instead of b'[] + # to allow string to numeric conversion ops to trace successfully. + temporary_dtype = ( + tf.int64 if tensor_info.dtype == tf.string else tensor_info.dtype + ) + temporary_tensor = tf2_utils.supply_missing_tensor( + 1, tf.TensorShape(tensor_info.shape), temporary_dtype + ) + if tensor_info.dtype == tf.string: + temporary_tensor = tf.strings.as_string(temporary_tensor) + graph_tensor = tf.raw_ops.PlaceholderWithDefault( + input=temporary_tensor, shape=tensor_info.shape, name=name + ) + return _TemporaryAnalyzerOutputWrapper(asset, graph_tensor) def _bind_future_as_tensor_v2( - future: nodes.ValueNode, - tensor_info: TensorInfo, - name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType: - """Bind a future value as a tensor to a TF2 FuncGraph. - - If the future is expected to write out an asset file and this method is - invoked within a `TFGraphContext` that was provided a temporary directory, - a temporary file is written out by this method. - - This could write out a significant number of temporary files depending on - number of times the `preprocessing_fn` is traced and number of asset files - in each tracing. - - Args: - future: Future whose result should replace the graph tensor to which its - bound. - tensor_info: A `TensorInfo` object containing attributes to create the graph - tensor. - name: (Optional) If provided, the graph tensor created uses this name. - - Returns: - A graph tensor or `tf.saved_model.Asset` that this future is bound to. If - this future has already been evaluated in a previous TFT phase, it is - directly returned. - """ - graph = ops.get_default_graph() - temp_dir = TFGraphContext.get_or_create_temp_dir() - temporary_analyzer_info = _get_temporary_analyzer_output( - temp_dir, tensor_info, name) - is_asset_filepath = tensor_info.temporary_asset_info is not None - - # TODO(b/149997088): Switch to using a counter instead of tensor names. 
- # Check if an evaluated value exists for this analyzer node. - evaluated_replacements = TFGraphContext.get_evaluated_replacements() - # evaluated_replacements is a dictionary from placeholder name to evaluated - # tensor. - # If `preprocessing_fn` was traced previously and this future was then - # evaluated in a TFT phase, the result will be present in this dictionary. - analyzer_name = temporary_analyzer_info.graph_tensor.name - tensor_sink = TensorSink(temporary_analyzer_info.graph_tensor, future, - is_asset_filepath) - graph.add_to_collection(ALL_REPLACEMENTS, tensor_sink) - if (evaluated_replacements is not None and - analyzer_name in evaluated_replacements): - replaced_result = evaluated_replacements[analyzer_name] - if is_asset_filepath: - graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - replaced_result) - return replaced_result + future: nodes.ValueNode, tensor_info: TensorInfo, name: Optional[str] = None +) -> common_types.TemporaryAnalyzerOutputType: + """Bind a future value as a tensor to a TF2 FuncGraph. + + If the future is expected to write out an asset file and this method is + invoked within a `TFGraphContext` that was provided a temporary directory, + a temporary file is written out by this method. + + This could write out a significant number of temporary files depending on + number of times the `preprocessing_fn` is traced and number of asset files + in each tracing. + + Args: + ---- + future: Future whose result should replace the graph tensor to which its + bound. + tensor_info: A `TensorInfo` object containing attributes to create the graph + tensor. + name: (Optional) If provided, the graph tensor created uses this name. + + Returns: + ------- + A graph tensor or `tf.saved_model.Asset` that this future is bound to. If + this future has already been evaluated in a previous TFT phase, it is + directly returned. + """ + graph = ops.get_default_graph() + temp_dir = TFGraphContext.get_or_create_temp_dir() + temporary_analyzer_info = _get_temporary_analyzer_output( + temp_dir, tensor_info, name + ) + is_asset_filepath = tensor_info.temporary_asset_info is not None + + # TODO(b/149997088): Switch to using a counter instead of tensor names. + # Check if an evaluated value exists for this analyzer node. + evaluated_replacements = TFGraphContext.get_evaluated_replacements() + # evaluated_replacements is a dictionary from placeholder name to evaluated + # tensor. + # If `preprocessing_fn` was traced previously and this future was then + # evaluated in a TFT phase, the result will be present in this dictionary. 
+ analyzer_name = temporary_analyzer_info.graph_tensor.name + tensor_sink = TensorSink( + temporary_analyzer_info.graph_tensor, future, is_asset_filepath + ) + graph.add_to_collection(ALL_REPLACEMENTS, tensor_sink) + if evaluated_replacements is not None and analyzer_name in evaluated_replacements: + replaced_result = evaluated_replacements[analyzer_name] + if is_asset_filepath: + graph.add_to_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS, replaced_result + ) + return replaced_result + else: + return replaced_result else: - return replaced_result - else: - graph.add_to_collection(TENSOR_REPLACEMENTS, tensor_sink) - eager_asset_path = temporary_analyzer_info.eager_asset_path - if is_asset_filepath and eager_asset_path is not None: - tf_utils.track_asset_analyzer_output(eager_asset_path, - temporary_analyzer_info.graph_tensor) - graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - eager_asset_path) - return temporary_analyzer_info.graph_tensor + graph.add_to_collection(TENSOR_REPLACEMENTS, tensor_sink) + eager_asset_path = temporary_analyzer_info.eager_asset_path + if is_asset_filepath and eager_asset_path is not None: + tf_utils.track_asset_analyzer_output( + eager_asset_path, temporary_analyzer_info.graph_tensor + ) + graph.add_to_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS, eager_asset_path + ) + return temporary_analyzer_info.graph_tensor def bind_future_as_tensor( - future: nodes.ValueNode, - tensor_info: TensorInfo, - name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType: - """Bind a future value as a tensor.""" - if tf.inside_function(): - # If the default graph is a `FuncGraph`, tf.function was used to trace the - # preprocessing fn. - return _bind_future_as_tensor_v2(future, tensor_info, name) - else: - return _bind_future_as_tensor_v1(future, tensor_info, name) + future: nodes.ValueNode, tensor_info: TensorInfo, name: Optional[str] = None +) -> common_types.TemporaryAnalyzerOutputType: + """Bind a future value as a tensor.""" + if tf.inside_function(): + # If the default graph is a `FuncGraph`, tf.function was used to trace the + # preprocessing fn. + return _bind_future_as_tensor_v2(future, tensor_info, name) + else: + return _bind_future_as_tensor_v1(future, tensor_info, name) def wrap_as_tensor( - output_value_node: nodes.ValueNode + output_value_node: nodes.ValueNode, ) -> common_types.TemporaryAnalyzerOutputType: - analyzer_def = output_value_node.parent_operation.operation_def - assert isinstance(analyzer_def, AnalyzerDef) - return bind_future_as_tensor( - output_value_node, - analyzer_def.output_tensor_infos[output_value_node.value_index]) + analyzer_def = output_value_node.parent_operation.operation_def + assert isinstance(analyzer_def, AnalyzerDef) + return bind_future_as_tensor( + output_value_node, + analyzer_def.output_tensor_infos[output_value_node.value_index], + ) class Combiner: - """Analyze using combiner function. + """Analyze using combiner function. - This object mirrors a beam.CombineFn, that will receive a beam PCollection - representing the batched input tensors. - """ + This object mirrors a beam.CombineFn, that will receive a beam PCollection + representing the batched input tensors. + """ - def __repr__(self): - return '<{}>'.format(self.__class__.__name__) + def __repr__(self): + return f"<{self.__class__.__name__}>" - def create_accumulator(self): - """Return a fresh, empty accumulator. + def create_accumulator(self): + """Return a fresh, empty accumulator. - Returns: An empty accumulator. 
This can be any Python value. - """ - raise NotImplementedError + Returns: An empty accumulator. This can be any Python value. + """ + raise NotImplementedError - def add_input(self, accumulator, batch_values): - """Return result of folding a batch of inputs into accumulator. + def add_input(self, accumulator, batch_values): + """Return result of folding a batch of inputs into accumulator. - Args: - accumulator: the current accumulator - batch_values: A list of ndarrays representing the values of the inputs for - a batch, which should be added to the accumulator. + Args: + ---- + accumulator: the current accumulator + batch_values: A list of ndarrays representing the values of the inputs for + a batch, which should be added to the accumulator. - Returns: An accumulator that includes the batch of inputs. - """ - raise NotImplementedError + Returns: An accumulator that includes the batch of inputs. + """ + raise NotImplementedError - def merge_accumulators(self, accumulators): - """Merges several accumulators to a single accumulator value. + def merge_accumulators(self, accumulators): + """Merges several accumulators to a single accumulator value. - Args: - accumulators: the accumulators to merge + Args: + ---- + accumulators: the accumulators to merge - Returns: The sole merged accumulator. - """ - raise NotImplementedError + Returns: The sole merged accumulator. + """ + raise NotImplementedError - def compact(self, accumulator): - """Returns an equivalent but more compact represenation of the accumulator. + def compact(self, accumulator): + """Returns an equivalent but more compact represenation of the accumulator. - Args: - accumulator: the current accumulator. + Args: + ---- + accumulator: the current accumulator. - Returns: A more compact accumulator. - """ - return accumulator + Returns: A more compact accumulator. + """ + return accumulator - def extract_output(self, accumulator): - """Return result of converting accumulator into the output value. + def extract_output(self, accumulator): + """Return result of converting accumulator into the output value. - Args: - accumulator: the final accumulator value. + Args: + ---- + accumulator: the final accumulator value. - Returns: A list of ndarrays representing the result of this combiner. - """ - raise NotImplementedError + Returns: A list of ndarrays representing the result of this combiner. + """ + raise NotImplementedError - def output_tensor_infos(self): - """Return the number / types of outputs that are produced by extract_output. + def output_tensor_infos(self): + """Return the number / types of outputs that are produced by extract_output. - Returns: An iterable of `TensorInfo` describing how the outputs that - extract_output will produce should be wrapped as `Tensor`s. + Returns: An iterable of `TensorInfo` describing how the outputs that + extract_output will produce should be wrapped as `Tensor`s. - Types are required to be TensorFlow dtypes. - """ - raise NotImplementedError + Types are required to be TensorFlow dtypes. 
+ """ + raise NotImplementedError - @property - def accumulator_coder(self): - return JsonNumpyCacheCoder() + @property + def accumulator_coder(self): + return JsonNumpyCacheCoder() class CacheCoder(metaclass=abc.ABCMeta): - """A coder iterface for encoding and decoding cache items.""" + """A coder iterface for encoding and decoding cache items.""" - def __repr__(self): - return '<{}>'.format(self.__class__.__name__) + def __repr__(self): + return f"<{self.__class__.__name__}>" - @abc.abstractmethod - def encode_cache(self, cache): - pass + @abc.abstractmethod + def encode_cache(self, cache): + pass - @abc.abstractmethod - def decode_cache(self, encoded_cache): - pass + @abc.abstractmethod + def decode_cache(self, encoded_cache): + pass class JsonNumpyCacheCoder(CacheCoder): - """An accumulator cache coder that can handle lists.""" - - def __init__(self, np_dtype=None): - self._dtype = np_dtype + """An accumulator cache coder that can handle lists.""" + + def __init__(self, np_dtype=None): + self._dtype = np_dtype + + def _convert_numpy_dtype(self, x): + if hasattr(x, "tolist"): + return x.tolist() + return x + + def encode_cache(self, accumulator): + if isinstance(accumulator, (list, tuple)): + primitive_accumulator = [self._convert_numpy_dtype(a) for a in accumulator] + else: + primitive_accumulator = self._convert_numpy_dtype(accumulator) + # Need to wrap in np.array and call tolist to make it JSON serializable. + return tf.compat.as_bytes(json.dumps(primitive_accumulator)) + + def decode_cache(self, encoded_accumulator): + # TODO(b/268341036): Set dtype correctly for combiners for numpy 1.24. + try: + return np.array( + json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype + ) + except ValueError: + if self._dtype != object: + return np.array( + json.loads(tf.compat.as_text(encoded_accumulator)), dtype=object + ) + raise - def _convert_numpy_dtype(self, x): - if hasattr(x, 'tolist'): - return x.tolist() - return x - def encode_cache(self, accumulator): - if isinstance(accumulator, (list, tuple)): - primitive_accumulator = [ - self._convert_numpy_dtype(a) for a in accumulator - ] - else: - primitive_accumulator = self._convert_numpy_dtype(accumulator) - # Need to wrap in np.array and call tolist to make it JSON serializable. - return tf.compat.as_bytes(json.dumps(primitive_accumulator)) - - def decode_cache(self, encoded_accumulator): - # TODO(b/268341036): Set dtype correctly for combiners for numpy 1.24. - try: - return np.array( - json.loads(tf.compat.as_text(encoded_accumulator)), dtype=self._dtype - ) - except ValueError: - if self._dtype != object: - return np.array( - json.loads(tf.compat.as_text(encoded_accumulator)), dtype=object - ) - raise +class AnalyzerDef(nodes.OperationDef, metaclass=abc.ABCMeta): + """A subclass of OperationDef whose outputs can be constant tensors. + An AnalyzerDef is an OperationDef that also provides enough information to + wrap each of its outputs as constant `Tensor`s in the graph. By inserting + the output of the AnalyzerDef back into the graph, the user can define + multiple levels of anaylsis and transformation. -class AnalyzerDef(nodes.OperationDef, metaclass=abc.ABCMeta): - """A subclass of OperationDef whose outputs can be constant tensors. - - An AnalyzerDef is an OperationDef that also provides enough information to - wrap each of its outputs as constant `Tensor`s in the graph. By inserting - the output of the AnalyzerDef back into the graph, the user can define - multiple levels of anaylsis and transformation. 
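# Illustrative aside, not part of the diff above: a minimal sketch of the
# `Combiner` contract documented in this module. `_SumCombiner` is a
# hypothetical example written for this review, not an analyzer that exists in
# tensorflow_transform; it assumes `Combiner` and `TensorInfo` from this module
# are in scope and uses the default `JsonNumpyCacheCoder` accumulator coder.
import numpy as np
import tensorflow as tf


class _SumCombiner(Combiner):
    """Sums a single input tensor across all batches."""

    def create_accumulator(self):
        return 0.0  # running total

    def add_input(self, accumulator, batch_values):
        # `batch_values` holds one ndarray per input tensor; fold it in.
        return accumulator + np.sum(batch_values[0])

    def merge_accumulators(self, accumulators):
        return sum(accumulators)

    def extract_output(self, accumulator):
        # One ndarray per output declared in `output_tensor_infos`.
        return [np.array(accumulator, dtype=np.float32)]

    def output_tensor_infos(self):
        # A single scalar float32 output with no temporary asset file.
        return [TensorInfo(tf.float32, (), None)]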
- - All `OperationDef`s are placeholders for operations that will be implemented - as `beam.PTransform`s. This is done by a registration system. The subclasses - defined below that inherit from `AnalyzerDef` have there implementations - registered in the module `tensorflow_transform.beam.analyzer_impls`. - """ - - @property - @abc.abstractmethod - def output_tensor_infos(self): - """A description on how to wrap the outputs of this AnalyzerDef. - - An `OperationDef` defines the number of outputs it creates. An - `AnalyzerDef` must implemented this property that defines not only the - number of outputs but how to wrap each output as a tensor. + All `OperationDef`s are placeholders for operations that will be implemented + as `beam.PTransform`s. This is done by a registration system. The subclasses + defined below that inherit from `AnalyzerDef` have there implementations + registered in the module `tensorflow_transform.beam.analyzer_impls`. """ - pass - @property - def num_outputs(self): - """The number of outputs returned by this operation.""" - return len(self.output_tensor_infos) + @property + @abc.abstractmethod + def output_tensor_infos(self): + """A description on how to wrap the outputs of this AnalyzerDef. + + An `OperationDef` defines the number of outputs it creates. An + `AnalyzerDef` must implemented this property that defines not only the + number of outputs but how to wrap each output as a tensor. + """ + pass + + @property + def num_outputs(self): + """The number of outputs returned by this operation.""" + return len(self.output_tensor_infos) # We do the packing of combiners after the caching optimization. Hence, we don't @@ -477,699 +500,813 @@ def num_outputs(self): # more of a Beam execution level optimization and we want to keep it towards the # end. So that, once Beam can automatically pack combines, we can remove this. class PackedCombineAccumulate( - tfx_namedtuple.namedtuple('PackedCombineAccumulate', - ['combiners', 'label']), nodes.OperationDef): - """An analyzer that packs a list of combiners into a single beam CombineFn. + tfx_namedtuple.namedtuple("PackedCombineAccumulate", ["combiners", "label"]), + nodes.OperationDef, +): + """An analyzer that packs a list of combiners into a single beam CombineFn. + + Fields: + combiners: A list of `analysis_graph_builder._CombinerOpWrapper` objects. + label: A unique label for this operation. + """ - Fields: - combiners: A list of `analysis_graph_builder._CombinerOpWrapper` objects. - label: A unique label for this operation. - """ - __slots__ = () + __slots__ = () - def __new__(cls, combiners, label): - return super(PackedCombineAccumulate, cls).__new__( - cls, combiners=combiners, label=_make_label(cls, label)) + def __new__(cls, combiners, label): + return super(PackedCombineAccumulate, cls).__new__( + cls, combiners=combiners, label=_make_label(cls, label) + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 - # Note that this will not have any effect as packing of combiners is done - # after the caching optimization. - @property - def is_partitionable(self): - return True + # Note that this will not have any effect as packing of combiners is done + # after the caching optimization. + @property + def is_partitionable(self): + return True class PackedCombineMerge( - tfx_namedtuple.namedtuple('PackedCombineMerge', ['combiners', 'label']), - nodes.OperationDef): - """An analyzer that packs a list of combiners into a single beam CombineFn. 
+ tfx_namedtuple.namedtuple("PackedCombineMerge", ["combiners", "label"]), + nodes.OperationDef, +): + """An analyzer that packs a list of combiners into a single beam CombineFn. - Fields: - combiners: A list of `analysis_graph_builder._CombinerOpWrapper` objects. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + combiners: A list of `analysis_graph_builder._CombinerOpWrapper` objects. + label: A unique label for this operation. + """ + + __slots__ = () - def __new__(cls, combiners, label): - return super(PackedCombineMerge, cls).__new__( - cls, combiners=combiners, label=_make_label(cls, label)) + def __new__(cls, combiners, label): + return super(PackedCombineMerge, cls).__new__( + cls, combiners=combiners, label=_make_label(cls, label) + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 class CacheableCombineAccumulate( - tfx_namedtuple.namedtuple('CacheableCombineAccumulate', - ['combiner', 'label']), nodes.OperationDef): - """An analyzer that runs a beam CombineFn to accumulate without merging. + tfx_namedtuple.namedtuple("CacheableCombineAccumulate", ["combiner", "label"]), + nodes.OperationDef, +): + """An analyzer that runs a beam CombineFn to accumulate without merging. + + This analyzer reduces the values that it accepts as inputs, using the + provided `Combiner`. The `Combiner` is applied to the data by wrapping it as + a `beam.CombineFn` and applying `beam.Combine`. - This analyzer reduces the values that it accepts as inputs, using the - provided `Combiner`. The `Combiner` is applied to the data by wrapping it as - a `beam.CombineFn` and applying `beam.Combine`. + Fields: + combiner: The Combiner to be applies to the inputs. + label: A unique label for this operation. + """ - Fields: - combiner: The Combiner to be applies to the inputs. - label: A unique label for this operation. - """ - __slots__ = () + __slots__ = () - def __new__(cls, combiner): - return super(CacheableCombineAccumulate, cls).__new__( - cls, combiner=combiner, label=_make_label(cls)) + def __new__(cls, combiner): + return super(CacheableCombineAccumulate, cls).__new__( + cls, combiner=combiner, label=_make_label(cls) + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 - @property - def is_partitionable(self): - return True + @property + def is_partitionable(self): + return True - @property - def cache_coder(self): - return self.combiner.accumulator_coder + @property + def cache_coder(self): + return self.combiner.accumulator_coder class CacheableCombineMerge( - tfx_namedtuple.namedtuple('CacheableCombineMerge', ['combiner', 'label']), - nodes.OperationDef): - """An analyzer that runs a beam CombineFn to only merge computed accumulators. + tfx_namedtuple.namedtuple("CacheableCombineMerge", ["combiner", "label"]), + nodes.OperationDef, +): + """An analyzer that runs a beam CombineFn to only merge computed accumulators. + + This analyzer reduces the values that it accepts as inputs, using the + provided `Combiner`. The `Combiner` is applied to the data by wrapping it as + a `beam.CombineFn` and applying `beam.Combine`. - This analyzer reduces the values that it accepts as inputs, using the - provided `Combiner`. The `Combiner` is applied to the data by wrapping it as - a `beam.CombineFn` and applying `beam.Combine`. + Fields: + combiner: The Combiner to be applied to the inputs. + label: A unique label for this operation. 
+ """ - Fields: - combiner: The Combiner to be applied to the inputs. - label: A unique label for this operation. - """ - __slots__ = () + __slots__ = () - def __new__(cls, combiner): - return super(CacheableCombineMerge, cls).__new__( - cls, combiner=combiner, label=_make_label(cls)) + def __new__(cls, combiner): + return super(CacheableCombineMerge, cls).__new__( + cls, combiner=combiner, label=_make_label(cls) + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 class _CombinerPerKeyAccumulatorCoder(CacheCoder): - """Coder for per-key combiner accumulators.""" + """Coder for per-key combiner accumulators.""" - def __init__(self, value_coder): - self._combiner_coder = value_coder - self._vocabulary_coder = _BaseKVCoder() - super().__init__() + def __init__(self, value_coder): + self._combiner_coder = value_coder + self._vocabulary_coder = _BaseKVCoder() + super().__init__() - def __repr__(self): - return '<{}[{}[{}]]>'.format(self.__class__.__name__, - repr(self._vocabulary_coder), - repr(self._combiner_coder)) + def __repr__(self): + return f"<{self.__class__.__name__}[{repr(self._vocabulary_coder)}[{repr(self._combiner_coder)}]]>" - def encode_cache(self, accumulator): - key, value = accumulator - encoded_value = self._combiner_coder.encode_cache(value) - return self._vocabulary_coder.encode_cache((key, encoded_value)) + def encode_cache(self, accumulator): + key, value = accumulator + encoded_value = self._combiner_coder.encode_cache(value) + return self._vocabulary_coder.encode_cache((key, encoded_value)) - def decode_cache(self, encoded_accumulator): - accumulator = self._vocabulary_coder.decode_cache(encoded_accumulator) - key, encoded_value = accumulator - value = self._combiner_coder.decode_cache(encoded_value) - return (key, value) + def decode_cache(self, encoded_accumulator): + accumulator = self._vocabulary_coder.decode_cache(encoded_accumulator) + key, encoded_value = accumulator + value = self._combiner_coder.decode_cache(encoded_value) + return (key, value) class CacheableCombinePerKeyAccumulate( - tfx_namedtuple.namedtuple('CacheableCombinePerKeyAccumulate', - ['combiner', 'label']), AnalyzerDef): - """An analyzer that runs `beam.CombinePerKey` to accumulate without merging. + tfx_namedtuple.namedtuple( + "CacheableCombinePerKeyAccumulate", ["combiner", "label"] + ), + AnalyzerDef, +): + """An analyzer that runs `beam.CombinePerKey` to accumulate without merging. - This analyzer reduces the values that it accepts as inputs, using the - provided `Combiner`. The `Combiner` is applied to the data by wrapping it as - a `beam.CombineFn` and applying `beam.CombinePerKey`. + This analyzer reduces the values that it accepts as inputs, using the + provided `Combiner`. The `Combiner` is applied to the data by wrapping it as + a `beam.CombineFn` and applying `beam.CombinePerKey`. - This analyzer is implemented by - `tensorflow_transform.beam.analyzer_impls._IntermediateAccumulateCombineImpl`. + This analyzer is implemented by + `tensorflow_transform.beam.analyzer_impls._IntermediateAccumulateCombineImpl`. - Fields: - combiner: The Combiner to be applied to the inputs. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + combiner: The Combiner to be applied to the inputs. + label: A unique label for this operation. 
+ """ + + __slots__ = () - def __new__(cls, combiner): - return super(CacheableCombinePerKeyAccumulate, cls).__new__( - cls, combiner=combiner, label=_make_label(cls)) + def __new__(cls, combiner): + return super(CacheableCombinePerKeyAccumulate, cls).__new__( + cls, combiner=combiner, label=_make_label(cls) + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 - @property - def is_partitionable(self): - return True + @property + def is_partitionable(self): + return True - @property - def cache_coder(self): - return _CombinerPerKeyAccumulatorCoder(self.combiner.accumulator_coder) + @property + def cache_coder(self): + return _CombinerPerKeyAccumulatorCoder(self.combiner.accumulator_coder) class CacheableCombinePerKeyMerge( - tfx_namedtuple.namedtuple('CacheableCombinePerKeyMerge', - ['combiner', 'label']), nodes.OperationDef): - """An analyzer that runs `beam.CombinePerKey` to only merge accumulators. + tfx_namedtuple.namedtuple("CacheableCombinePerKeyMerge", ["combiner", "label"]), + nodes.OperationDef, +): + """An analyzer that runs `beam.CombinePerKey` to only merge accumulators. - This analyzer reduces the values that it accepts as inputs, using the - provided `Combiner`. The `Combiner` is applied to the data by wrapping it as - a `beam.CombineFn` and applying `beam.CombinePerKey`. + This analyzer reduces the values that it accepts as inputs, using the + provided `Combiner`. The `Combiner` is applied to the data by wrapping it as + a `beam.CombineFn` and applying `beam.CombinePerKey`. - This analyzer is implemented by - `tensorflow_transform.beam.analyzer_impls._MergeAccumulatorsCombinePerKeyImpl` + This analyzer is implemented by + `tensorflow_transform.beam.analyzer_impls._MergeAccumulatorsCombinePerKeyImpl` - Fields: - combiner: The Combiner to use for merging and extracting outputs. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + combiner: The Combiner to use for merging and extracting outputs. + label: A unique label for this operation. + """ - def __new__(cls, combiner): - return super(CacheableCombinePerKeyMerge, cls).__new__( - cls, combiner=combiner, label=_make_label(cls)) + __slots__ = () + + def __new__(cls, combiner): + return super(CacheableCombinePerKeyMerge, cls).__new__( + cls, combiner=combiner, label=_make_label(cls) + ) class CacheableCombinePerKeyFormatKeys( - tfx_namedtuple.namedtuple('CacheableCombinePerKeyFormatKeys', - ['combiner', 'label']), AnalyzerDef): - """An analyzer that formats output for the non-stored per-key case. + tfx_namedtuple.namedtuple( + "CacheableCombinePerKeyFormatKeys", ["combiner", "label"] + ), + AnalyzerDef, +): + """An analyzer that formats output for the non-stored per-key case. + + This analyzer converts the (key, output) pairs into a tuple of keys (of type + string) and outputs. - This analyzer converts the (key, output) pairs into a tuple of keys (of type - string) and outputs. + This analyzer is implemented by + `tensorflow_transform.beam.analyzer_impls._CombinePerKeyFormatKeysImpl` - This analyzer is implemented by - `tensorflow_transform.beam.analyzer_impls._CombinePerKeyFormatKeysImpl` + Fields: + combiner: The Combiner to use for extracting outputs. + label: A unique label for this operation. + """ - Fields: - combiner: The Combiner to use for extracting outputs. - label: A unique label for this operation. 
- """ - __slots__ = () + __slots__ = () - def __new__(cls, combiner): - return super(CacheableCombinePerKeyFormatKeys, cls).__new__( - cls, combiner=combiner, label=_make_label(cls)) + def __new__(cls, combiner): + return super(CacheableCombinePerKeyFormatKeys, cls).__new__( + cls, combiner=combiner, label=_make_label(cls) + ) - @property - def output_tensor_infos(self): - # Returns a key vocab and one output per combiner output. - return [TensorInfo(tf.string, (None,), None)] + [ - TensorInfo(info.dtype, (None,) + info.shape, info.temporary_asset_info) - for info in self.combiner.output_tensor_infos() - ] + @property + def output_tensor_infos(self): + # Returns a key vocab and one output per combiner output. + return [TensorInfo(tf.string, (None,), None)] + [ + TensorInfo(info.dtype, (None,) + info.shape, info.temporary_asset_info) + for info in self.combiner.output_tensor_infos() + ] class CacheableCombinePerKeyFormatLarge( - tfx_namedtuple.namedtuple('CacheableCombinePerKeyFormatLarge', ['label']), - nodes.OperationDef): - """An analyzer that formats output prior to writing to file for per-key case. + tfx_namedtuple.namedtuple("CacheableCombinePerKeyFormatLarge", ["label"]), + nodes.OperationDef, +): + """An analyzer that formats output prior to writing to file for per-key case. - This operation operates on the output of CacheableCombinePerKeyAccumulate and - is implemented by `tensorflow_transform.beam.analyzer_impls. - _CombinePerKeyFormatLargeImpl`. - """ - __slots__ = () + This operation operates on the output of CacheableCombinePerKeyAccumulate and + is implemented by `tensorflow_transform.beam.analyzer_impls. + _CombinePerKeyFormatLargeImpl`. + """ - def __new__(cls): - return super(CacheableCombinePerKeyFormatLarge, cls).__new__( - cls, label=_make_label(cls)) + __slots__ = () - @property - def num_outputs(self): - return 1 + def __new__(cls): + return super(CacheableCombinePerKeyFormatLarge, cls).__new__( + cls, label=_make_label(cls) + ) + + @property + def num_outputs(self): + return 1 class ScaleAndFlattenPerKeyBucketBouandaries( - tfx_namedtuple.namedtuple('PostProcessPerKeyBucketBoundaries', - ['output_tensor_dtype', 'label']), AnalyzerDef): - """An analyzer which takes quantile boundaries per key and combines them. + tfx_namedtuple.namedtuple( + "PostProcessPerKeyBucketBoundaries", ["output_tensor_dtype", "label"] + ), + AnalyzerDef, +): + """An analyzer which takes quantile boundaries per key and combines them. - It receives a 2-d array of boundaries, computes scales and shifts to each - row separately, a new boundaries 1-d array which is a combination of - boundaries for all the keys, and the number of buckets defined for each key. + It receives a 2-d array of boundaries, computes scales and shifts to each + row separately, a new boundaries 1-d array which is a combination of + boundaries for all the keys, and the number of buckets defined for each key. - This outputs boundaries, scale_factor_per_key, shift_per_key, num_buckets. + This outputs boundaries, scale_factor_per_key, shift_per_key, num_buckets. 
- For example, for an input boundaries matrix, [[0, 1, 2], [0, 1, 2]] it will - return: - boundaries: [0, 0.5, 1, 1.5, 2] - scale_factor_per_key: [0.5, 0.5] - shift_per_key: [0, 1] - num_buckets: 4 + For example, for an input boundaries matrix, [[0, 1, 2], [0, 1, 2]] it will + return: + boundaries: [0, 0.5, 1, 1.5, 2] + scale_factor_per_key: [0.5, 0.5] + shift_per_key: [0, 1] + num_buckets: 4 - So the transformation of each input x before computing its bucket should be: - F(x, key) = x * scale_factor_per_key[key] + shift_per_key[key] - """ - __slots__ = () + So the transformation of each input x before computing its bucket should be: + F(x, key) = x * scale_factor_per_key[key] + shift_per_key[key] + """ + + __slots__ = () - def __new__(cls, output_tensor_dtype): - return super(ScaleAndFlattenPerKeyBucketBouandaries, cls).__new__( - cls, output_tensor_dtype=output_tensor_dtype, label=_make_label(cls)) + def __new__(cls, output_tensor_dtype): + return super(ScaleAndFlattenPerKeyBucketBouandaries, cls).__new__( + cls, output_tensor_dtype=output_tensor_dtype, label=_make_label(cls) + ) - @property - def output_tensor_infos(self): - # Boundaries, scale_factor_per_key, shift_per_key, num_buckets. - return [TensorInfo(self.output_tensor_dtype, - (None,), None)] * 3 + [TensorInfo(tf.int64, (), None)] + @property + def output_tensor_infos(self): + # Boundaries, scale_factor_per_key, shift_per_key, num_buckets. + return [TensorInfo(self.output_tensor_dtype, (None,), None)] * 3 + [ + TensorInfo(tf.int64, (), None) + ] class VocabularyAccumulate( - tfx_namedtuple.namedtuple('VocabularyAccumulate', - ['vocab_ordering_type', 'input_dtype', 'label']), - nodes.OperationDef): - """An operation that accumulates unique words with their frequency or weight. - - This operation is implemented by - `tensorflow_transform.beam.analyzer_impls._VocabularyAccumulateImpl`. - """ - __slots__ = () - - def __new__(cls, vocab_ordering_type, input_dtype=tf.string.name): - return super(VocabularyAccumulate, cls).__new__( - cls, - vocab_ordering_type=vocab_ordering_type, - input_dtype=input_dtype, - label=_make_label(cls)) + tfx_namedtuple.namedtuple( + "VocabularyAccumulate", ["vocab_ordering_type", "input_dtype", "label"] + ), + nodes.OperationDef, +): + """An operation that accumulates unique words with their frequency or weight. + + This operation is implemented by + `tensorflow_transform.beam.analyzer_impls._VocabularyAccumulateImpl`. 
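The worked example in the `ScaleAndFlattenPerKeyBucketBouandaries` docstring above can be reproduced with a few lines of NumPy. This is only a sketch of the documented transformation (row `i` of the boundaries matrix is mapped affinely onto `[i, i + 1]` and the rows are then flattened); `scale_and_flatten` is a hypothetical helper, and the real logic in `analyzer_impls` may handle edge cases such as constant rows differently.

```python
import numpy as np


def scale_and_flatten(boundaries_matrix):
    # Row i holds the quantile boundaries computed for key i.
    boundaries = np.asarray(boundaries_matrix, dtype=np.float64)
    num_keys, num_boundaries = boundaries.shape
    row_min, row_max = boundaries[:, 0], boundaries[:, -1]
    # Affine map of row i onto [i, i + 1]: F(x, key) = x * scale[key] + shift[key].
    scale_factor_per_key = 1.0 / (row_max - row_min)
    shift_per_key = np.arange(num_keys) - row_min * scale_factor_per_key
    scaled = boundaries * scale_factor_per_key[:, None] + shift_per_key[:, None]
    # Adjacent rows now share an endpoint, so drop the leftmost boundary of
    # every row after the first when flattening.
    flat_boundaries = np.concatenate([scaled[0]] + [row[1:] for row in scaled[1:]])
    return flat_boundaries, scale_factor_per_key, shift_per_key, num_boundaries + 1


# Matches the docstring example:
# boundaries [0. 0.5 1. 1.5 2.], scales [0.5 0.5], shifts [0. 1.], 4 buckets.
print(scale_and_flatten([[0, 1, 2], [0, 1, 2]]))
```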
+ """ + + __slots__ = () + + def __new__(cls, vocab_ordering_type, input_dtype=tf.string.name): + return super(VocabularyAccumulate, cls).__new__( + cls, + vocab_ordering_type=vocab_ordering_type, + input_dtype=input_dtype, + label=_make_label(cls), + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 - @property - def is_partitionable(self): - return True + @property + def is_partitionable(self): + return True - @property - def cache_coder(self): - return _VocabularyAccumulatorCoder(input_dtype=self.input_dtype) + @property + def cache_coder(self): + return _VocabularyAccumulatorCoder(input_dtype=self.input_dtype) class _BaseKVCoder(CacheCoder): - """Coder for key-value based accumulators.""" - - def __init__(self): - self._lengths_prefix_format = 'qq' - self._lengths_prefix_length = struct.calcsize(self._lengths_prefix_format) - super().__init__() - - def encode_cache(self, accumulator): - token, value = accumulator - len_token, len_value = len(token), len(value) - return struct.pack( - '{}{}s{}s'.format(self._lengths_prefix_format, len_token, len_value), - len_token, len_value, token, value) - - def decode_cache(self, encoded_accumulator): - (len_token, len_value) = struct.unpack_from( - self._lengths_prefix_format, - encoded_accumulator[:self._lengths_prefix_length]) - accumulator = struct.unpack_from( - '{}s{}s'.format(len_token, len_value), - encoded_accumulator[self._lengths_prefix_length:]) - return accumulator + """Coder for key-value based accumulators.""" + + def __init__(self): + self._lengths_prefix_format = "qq" + self._lengths_prefix_length = struct.calcsize(self._lengths_prefix_format) + super().__init__() + + def encode_cache(self, accumulator): + token, value = accumulator + len_token, len_value = len(token), len(value) + return struct.pack( + f"{self._lengths_prefix_format}{len_token}s{len_value}s", + len_token, + len_value, + token, + value, + ) + + def decode_cache(self, encoded_accumulator): + (len_token, len_value) = struct.unpack_from( + self._lengths_prefix_format, + encoded_accumulator[: self._lengths_prefix_length], + ) + accumulator = struct.unpack_from( + f"{len_token}s{len_value}s", + encoded_accumulator[self._lengths_prefix_length :], + ) + return accumulator class _VocabularyAccumulatorCoder(_BaseKVCoder): - """Coder for vocabulary accumulators.""" - - def __init__(self, input_dtype=tf.string.name): - self._input_dtype = tf.dtypes.as_dtype(input_dtype) - super().__init__() - - def encode_cache(self, accumulator): - token, value = accumulator - if self._input_dtype is not tf.string: - token = tf.compat.as_bytes(json.dumps(token)) - # If the value is a _WeightedMeanAndVarAccumulator, cast each field to a - # list for serialization. - if isinstance(value, tuple): - value = [ - a.tolist() - for a in (value.count, value.mean, value.variance, value.weight) - ] - value = tf.compat.as_bytes(json.dumps(value)) - return super().encode_cache((token, value)) - - def decode_cache(self, encoded_accumulator): - accumulator = super().decode_cache(encoded_accumulator) - token, value = accumulator - if self._input_dtype is not tf.string: - token = json.loads(tf.compat.as_text(token)) - - value = json.loads(tf.compat.as_text(value)) - if isinstance(value, list): - # If the value is a _WeightedMeanAndVarAccumulator (serialized to tuple), - # cast each field back to a np.array. 
- (count, mean, variance, weight) = value - value = (np.array(count), np.array(mean), np.array(variance), - np.array(weight)) - return token, value + """Coder for vocabulary accumulators.""" + + def __init__(self, input_dtype=tf.string.name): + self._input_dtype = tf.dtypes.as_dtype(input_dtype) + super().__init__() + + def encode_cache(self, accumulator): + token, value = accumulator + if self._input_dtype is not tf.string: + token = tf.compat.as_bytes(json.dumps(token)) + # If the value is a _WeightedMeanAndVarAccumulator, cast each field to a + # list for serialization. + if isinstance(value, tuple): + value = [ + a.tolist() + for a in (value.count, value.mean, value.variance, value.weight) + ] + value = tf.compat.as_bytes(json.dumps(value)) + return super().encode_cache((token, value)) + + def decode_cache(self, encoded_accumulator): + accumulator = super().decode_cache(encoded_accumulator) + token, value = accumulator + if self._input_dtype is not tf.string: + token = json.loads(tf.compat.as_text(token)) + + value = json.loads(tf.compat.as_text(value)) + if isinstance(value, list): + # If the value is a _WeightedMeanAndVarAccumulator (serialized to tuple), + # cast each field back to a np.array. + (count, mean, variance, weight) = value + value = ( + np.array(count), + np.array(mean), + np.array(variance), + np.array(weight), + ) + return token, value class VocabularyCount( - tfx_namedtuple.namedtuple('VocabularyCount', ['label']), - nodes.OperationDef): - """An operation counts the total number of tokens in a vocabulary. + tfx_namedtuple.namedtuple("VocabularyCount", ["label"]), nodes.OperationDef +): + """An operation counts the total number of tokens in a vocabulary. + + This operation takes in the output of VocabularyAccumulate and is implemented + by `tensorflow_transform.beam.analyzer_impls._VocabularyCountImpl`. - This operation takes in the output of VocabularyAccumulate and is implemented - by `tensorflow_transform.beam.analyzer_impls._VocabularyCountImpl`. + The output of this operation is a singleton Integer. - The output of this operation is a singleton Integer. + Fields: + label: A unique label for this operation. + """ - Fields: - label: A unique label for this operation. - """ - __slots__ = () + __slots__ = () - def __new__(cls, label): - return super().__new__(cls, label=_make_label(cls, label)) + def __new__(cls, label): + return super().__new__(cls, label=_make_label(cls, label)) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 class VocabularyMerge( - tfx_namedtuple.namedtuple('VocabularyMerge', [ - 'vocab_ordering_type', 'use_adjusted_mutual_info', 'min_diff_from_avg', - 'label' - ]), nodes.OperationDef): - """An operation that merges the accumulators produced by VocabularyAccumulate. - - This operation operates on the output of VocabularyAccumulate and is - implemented by `tensorflow_transform.beam.analyzer_impls._VocabularyMergeImpl` - . - - See `tft.vocabulary` for a description of the parameters. 
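The cache encoding that `_BaseKVCoder` implements above is a `struct`-packed header of two int64 lengths followed by the raw token and value bytes. The snippet below round-trips a pair through the same framing outside of TFT; the token and value literals are arbitrary.

```python
import struct

_HEADER = "qq"  # two signed 64-bit lengths, as in _BaseKVCoder
_HEADER_SIZE = struct.calcsize(_HEADER)


def encode_kv(token: bytes, value: bytes) -> bytes:
    return struct.pack(
        f"{_HEADER}{len(token)}s{len(value)}s", len(token), len(value), token, value
    )


def decode_kv(encoded: bytes) -> tuple:
    len_token, len_value = struct.unpack_from(_HEADER, encoded[:_HEADER_SIZE])
    return struct.unpack_from(f"{len_token}s{len_value}s", encoded[_HEADER_SIZE:])


assert decode_kv(encode_kv(b"some_token", b"[1.0, 2.0]")) == (b"some_token", b"[1.0, 2.0]")
```

`_VocabularyAccumulatorCoder` layers JSON serialization of non-string tokens and accumulator tuples on top of this framing, as shown in the hunk above.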
- """ - __slots__ = () - - def __new__(cls, vocab_ordering_type, use_adjusted_mutual_info, - min_diff_from_avg): - return super(VocabularyMerge, cls).__new__( - cls, - vocab_ordering_type=vocab_ordering_type, - use_adjusted_mutual_info=use_adjusted_mutual_info, - min_diff_from_avg=min_diff_from_avg, - label=_make_label(cls)) + tfx_namedtuple.namedtuple( + "VocabularyMerge", + [ + "vocab_ordering_type", + "use_adjusted_mutual_info", + "min_diff_from_avg", + "label", + ], + ), + nodes.OperationDef, +): + """An operation that merges the accumulators produced by VocabularyAccumulate. - @property - def num_outputs(self): - return 1 + This operation operates on the output of VocabularyAccumulate and is + implemented by `tensorflow_transform.beam.analyzer_impls._VocabularyMergeImpl` + . + + See `tft.vocabulary` for a description of the parameters. + """ + + __slots__ = () + + def __new__(cls, vocab_ordering_type, use_adjusted_mutual_info, min_diff_from_avg): + return super(VocabularyMerge, cls).__new__( + cls, + vocab_ordering_type=vocab_ordering_type, + use_adjusted_mutual_info=use_adjusted_mutual_info, + min_diff_from_avg=min_diff_from_avg, + label=_make_label(cls), + ) + + @property + def num_outputs(self): + return 1 class VocabularyPrune( - tfx_namedtuple.namedtuple('VocabularyPrune', [ - 'top_k', 'frequency_threshold', 'informativeness_threshold', - 'coverage_top_k', 'coverage_frequency_threshold', - 'coverage_informativeness_threshold', 'key_fn', 'input_dtype', 'label' - ]), nodes.OperationDef): - """An operation that filters and orders a computed vocabulary. - - This operation operates on the output of VocabularyMerge and is implemented by - `tensorflow_transform.beam.analyzer_impls._VocabularyPruneImpl`. - - See `tft.vocabulary` for a description of the parameters. - """ - __slots__ = () - - def __new__(cls, - top_k, - frequency_threshold, - input_dtype, - informativeness_threshold=float('-inf'), - coverage_top_k=None, - coverage_frequency_threshold=0, - coverage_informativeness_threshold=float('-inf'), - key_fn=None): - return super(VocabularyPrune, cls).__new__( + tfx_namedtuple.namedtuple( + "VocabularyPrune", + [ + "top_k", + "frequency_threshold", + "informativeness_threshold", + "coverage_top_k", + "coverage_frequency_threshold", + "coverage_informativeness_threshold", + "key_fn", + "input_dtype", + "label", + ], + ), + nodes.OperationDef, +): + """An operation that filters and orders a computed vocabulary. + + This operation operates on the output of VocabularyMerge and is implemented by + `tensorflow_transform.beam.analyzer_impls._VocabularyPruneImpl`. + + See `tft.vocabulary` for a description of the parameters. 
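`VocabularyPrune` applies the `top_k` and `frequency_threshold` filtering documented for `tft.vocabulary`. The sketch below illustrates those two parameters only, with invented token counts; the real implementation also honors the informativeness and coverage thresholds listed in the hunk above.

```python
# Invented (token, frequency) counts: keep tokens seen at least
# `frequency_threshold` times, then retain the `top_k` most frequent.
counts = {"the": 120, "cat": 7, "sat": 3, "on": 2, "rug": 1}
frequency_threshold, top_k = 2, 3

kept = sorted(
    ((token, freq) for token, freq in counts.items() if freq >= frequency_threshold),
    key=lambda item: -item[1],
)[:top_k]
print(kept)  # [('the', 120), ('cat', 7), ('sat', 3)]
```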
+ """ + + __slots__ = () + + def __new__( cls, - top_k=top_k, - frequency_threshold=frequency_threshold, - informativeness_threshold=informativeness_threshold, - coverage_top_k=coverage_top_k, - coverage_frequency_threshold=coverage_frequency_threshold, - coverage_informativeness_threshold=coverage_informativeness_threshold, - key_fn=key_fn, - input_dtype=input_dtype, - label=_make_label(cls)) + top_k, + frequency_threshold, + input_dtype, + informativeness_threshold=float("-inf"), + coverage_top_k=None, + coverage_frequency_threshold=0, + coverage_informativeness_threshold=float("-inf"), + key_fn=None, + ): + return super(VocabularyPrune, cls).__new__( + cls, + top_k=top_k, + frequency_threshold=frequency_threshold, + informativeness_threshold=informativeness_threshold, + coverage_top_k=coverage_top_k, + coverage_frequency_threshold=coverage_frequency_threshold, + coverage_informativeness_threshold=coverage_informativeness_threshold, + key_fn=key_fn, + input_dtype=input_dtype, + label=_make_label(cls), + ) - @property - def num_outputs(self): - return 1 + @property + def num_outputs(self): + return 1 class VocabularyOrderAndWrite( - tfx_namedtuple.namedtuple('VocabularyOrderAndWrite', [ - 'vocab_filename', 'store_frequency', 'input_dtype', 'label', - 'fingerprint_shuffle', 'file_format', 'input_is_sorted' - ]), AnalyzerDef): - """An analyzer that writes vocabulary files from an accumulator. - - This operation operates on the output of VocabularyPrune and is implemented by - `tensorflow_transform.beam.analyzer_impls._VocabularyOrderAndWriteImpl`. - - See `tft.vocabulary` for a description of the parameters. - """ - __slots__ = () - - def __new__(cls, - vocab_filename, - store_frequency, - fingerprint_shuffle, - file_format, - input_dtype=tf.string.name, - input_is_sorted=False): - return super(VocabularyOrderAndWrite, cls).__new__( + tfx_namedtuple.namedtuple( + "VocabularyOrderAndWrite", + [ + "vocab_filename", + "store_frequency", + "input_dtype", + "label", + "fingerprint_shuffle", + "file_format", + "input_is_sorted", + ], + ), + AnalyzerDef, +): + """An analyzer that writes vocabulary files from an accumulator. + + This operation operates on the output of VocabularyPrune and is implemented by + `tensorflow_transform.beam.analyzer_impls._VocabularyOrderAndWriteImpl`. + + See `tft.vocabulary` for a description of the parameters. + """ + + __slots__ = () + + def __new__( cls, - vocab_filename=vocab_filename, - store_frequency=store_frequency, - fingerprint_shuffle=fingerprint_shuffle, - file_format=file_format, - input_dtype=input_dtype, - input_is_sorted=input_is_sorted, - label=_make_label(cls)) - - @property - def output_tensor_infos(self): - # Define temporary data for this node to write to a file before the actual - # vocab file is evaluated and written out. 
- temporary_asset_value = (b'TEMPORARY_ASSET_VALUE' if tf.dtypes.as_dtype( - self.input_dtype) == tf.string else b'-777777') - if self.store_frequency: - temporary_asset_value = b'1 %s' % temporary_asset_value - - return [ - TensorInfo(tf.string, [], - TemporaryAssetInfo(temporary_asset_value, self.file_format)) - ] + vocab_filename, + store_frequency, + fingerprint_shuffle, + file_format, + input_dtype=tf.string.name, + input_is_sorted=False, + ): + return super(VocabularyOrderAndWrite, cls).__new__( + cls, + vocab_filename=vocab_filename, + store_frequency=store_frequency, + fingerprint_shuffle=fingerprint_shuffle, + file_format=file_format, + input_dtype=input_dtype, + input_is_sorted=input_is_sorted, + label=_make_label(cls), + ) + + @property + def output_tensor_infos(self): + # Define temporary data for this node to write to a file before the actual + # vocab file is evaluated and written out. + temporary_asset_value = ( + b"TEMPORARY_ASSET_VALUE" + if tf.dtypes.as_dtype(self.input_dtype) == tf.string + else b"-777777" + ) + if self.store_frequency: + temporary_asset_value = b"1 %s" % temporary_asset_value + + return [ + TensorInfo( + tf.string, + [], + TemporaryAssetInfo(temporary_asset_value, self.file_format), + ) + ] class ExtractVocabularyReservedTokens( - tfx_namedtuple.namedtuple( - 'ExtractVocabularyReservedTokens', ['name', 'label'] - ), + tfx_namedtuple.namedtuple("ExtractVocabularyReservedTokens", ["name", "label"]), nodes.OperationDef, ): - """An operation which extracts vocabulary reserved tokens from the graph.""" - __slots__ = () + """An operation which extracts vocabulary reserved tokens from the graph.""" - def __new__(cls, name): - return super(ExtractVocabularyReservedTokens, cls).__new__( - cls, name=name, label=_make_label(cls) - ) + __slots__ = () + + def __new__(cls, name): + return super(ExtractVocabularyReservedTokens, cls).__new__( + cls, name=name, label=_make_label(cls) + ) class PTransform( - tfx_namedtuple.namedtuple('PTransform', [ - 'ptransform', 'output_tensor_info_list', 'is_partitionable', - 'cache_coder', 'label' - ]), AnalyzerDef): - """(Experimental) OperationDef for PTransform anaylzer. - - This analyzer is implemented by - `tensorflow_transform.beam.analyzer_impls._PTransformImpl`. - - Fields: - ptransform: The `beam.PTransform` to be applied to the inputs. - output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of - this `PTransform`. - is_partitionable: Whether or not this PTransform is partitionable. - cache_coder: (optional) A `CacheCoder` instance. - label: A unique label for this operation. - """ - __slots__ = () - - def __new__(cls, - ptransform: Any, - output_tensor_info_list: Sequence[TensorInfo], - is_partitionable: bool, - cache_coder: Optional[CacheCoder] = None): - return super(PTransform, cls).__new__( + tfx_namedtuple.namedtuple( + "PTransform", + [ + "ptransform", + "output_tensor_info_list", + "is_partitionable", + "cache_coder", + "label", + ], + ), + AnalyzerDef, +): + """(Experimental) OperationDef for PTransform anaylzer. + + This analyzer is implemented by + `tensorflow_transform.beam.analyzer_impls._PTransformImpl`. + + Fields: + ptransform: The `beam.PTransform` to be applied to the inputs. + output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of + this `PTransform`. + is_partitionable: Whether or not this PTransform is partitionable. + cache_coder: (optional) A `CacheCoder` instance. + label: A unique label for this operation. 
+ """ + + __slots__ = () + + def __new__( cls, - ptransform=ptransform, - output_tensor_info_list=output_tensor_info_list, - is_partitionable=is_partitionable, - cache_coder=cache_coder, - label=_make_label(cls)) + ptransform: Any, + output_tensor_info_list: Sequence[TensorInfo], + is_partitionable: bool, + cache_coder: Optional[CacheCoder] = None, + ): + return super(PTransform, cls).__new__( + cls, + ptransform=ptransform, + output_tensor_info_list=output_tensor_info_list, + is_partitionable=is_partitionable, + cache_coder=cache_coder, + label=_make_label(cls), + ) - @property - def output_tensor_infos(self): - return self.output_tensor_info_list + @property + def output_tensor_infos(self): + return self.output_tensor_info_list class EncodeCache( - tfx_namedtuple.namedtuple('EncodeCache', ['coder', 'label']), - nodes.OperationDef): - """OperationDef for encoding a cache instance. + tfx_namedtuple.namedtuple("EncodeCache", ["coder", "label"]), nodes.OperationDef +): + """OperationDef for encoding a cache instance. - Fields: - coder: An instance of CacheCoder used to encode cache. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + coder: An instance of CacheCoder used to encode cache. + label: A unique label for this operation. + """ - @property - def is_partitionable(self): - return True + __slots__ = () + + @property + def is_partitionable(self): + return True class InstrumentDatasetCache( - tfx_namedtuple.namedtuple('InstrumentDatasetCache', [ - 'input_cache_dataset_keys', 'num_encode_cache', 'num_decode_cache', - 'label' - ]), nodes.OperationDef): - """OperationDef instrumenting cached datasets. + tfx_namedtuple.namedtuple( + "InstrumentDatasetCache", + ["input_cache_dataset_keys", "num_encode_cache", "num_decode_cache", "label"], + ), + nodes.OperationDef, +): + """OperationDef instrumenting cached datasets. - Fields: - input_cache_dataset_keys: A dataset keys for which there's input cache. - num_encode_cache: Number of cache entries encoded. - num_decode_cache: Number of cache entries decoded. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + input_cache_dataset_keys: A dataset keys for which there's input cache. + num_encode_cache: Number of cache entries encoded. + num_decode_cache: Number of cache entries decoded. + label: A unique label for this operation. + """ - @property - def is_partitionable(self): - return True + __slots__ = () + + @property + def is_partitionable(self): + return True class DecodeCache( - tfx_namedtuple.namedtuple('DecodeCache', - ['dataset_key', 'cache_key', 'coder', 'label']), - nodes.OperationDef): - """OperationDef for decoding a cache instance. + tfx_namedtuple.namedtuple( + "DecodeCache", ["dataset_key", "cache_key", "coder", "label"] + ), + nodes.OperationDef, +): + """OperationDef for decoding a cache instance. - Fields: - dataset_key: A dataset key. - cache_key: A cache entry key. - coder: An instance of CacheCoder used to decode cache. - label: A unique label for this operation. - """ - __slots__ = () + Fields: + dataset_key: A dataset key. + cache_key: A cache entry key. + coder: An instance of CacheCoder used to decode cache. + label: A unique label for this operation. 
+ """ - def get_field_str(self, field_name): - if field_name == 'cache_key': - return '' - return super().get_field_str(field_name) + __slots__ = () - @property - def is_partitionable(self): - return True + def get_field_str(self, field_name): + if field_name == "cache_key": + return "" + return super().get_field_str(field_name) + @property + def is_partitionable(self): + return True -class AddKey( - tfx_namedtuple.namedtuple('AddKey', ['key', 'label']), nodes.OperationDef): - """An operation that represents adding a key to a value. - This operation represents a `beam.Map` that is applied to a PCollection. - For each element of the PCollection, this corresponding element of the output - PCollection is a tuple of (key, value). +class AddKey(tfx_namedtuple.namedtuple("AddKey", ["key", "label"]), nodes.OperationDef): + """An operation that represents adding a key to a value. - Attributes: - key: The key which should be added to each element of the input PCollection. - label: A unique label for this operation. - """ - __slots__ = () + This operation represents a `beam.Map` that is applied to a PCollection. + For each element of the PCollection, this corresponding element of the output + PCollection is a tuple of (key, value). - @property - def is_partitionable(self): - return True + Attributes + ---------- + key: The key which should be added to each element of the input PCollection. + label: A unique label for this operation. + """ + + __slots__ = () + + @property + def is_partitionable(self): + return True class FlattenLists( - tfx_namedtuple.namedtuple('FlattenLists', ['label']), nodes.OperationDef): - """An operation that represents flattening a PCollection of lists. + tfx_namedtuple.namedtuple("FlattenLists", ["label"]), nodes.OperationDef +): + """An operation that represents flattening a PCollection of lists. - Attributes: - label: A unique label for this operation. - """ + Attributes + ---------- + label: A unique label for this operation. + """ - def __new__(cls): - return super(FlattenLists, cls).__new__(cls, label=_make_label(cls)) + def __new__(cls): + return super(FlattenLists, cls).__new__(cls, label=_make_label(cls)) - @property - def is_partitionable(self): - return True + @property + def is_partitionable(self): + return True class ExtractCombineMergeOutputs( - tfx_namedtuple.namedtuple('ExtractOutputs', - ['output_tensor_info_list', 'label']), - AnalyzerDef): - """An operation that represents extracting outputs of a combine merge. - - This operation represents a `beam.Map` that is applied to a PCollection. - For each element of the PCollection, this corresponding element of the output - PCollection is a tuple of outputs. - - Attributes: - output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of - this operation. - label: A unique label for this operation. - """ - __slots__ = () - - def __new__(cls, output_tensor_info_list): - return super(ExtractCombineMergeOutputs, cls).__new__( - cls, - output_tensor_info_list=output_tensor_info_list, - label=_make_label(cls)) + tfx_namedtuple.namedtuple("ExtractOutputs", ["output_tensor_info_list", "label"]), + AnalyzerDef, +): + """An operation that represents extracting outputs of a combine merge. - @property - def output_tensor_infos(self): - return self.output_tensor_info_list + This operation represents a `beam.Map` that is applied to a PCollection. + For each element of the PCollection, this corresponding element of the output + PCollection is a tuple of outputs. 
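`AddKey` and `FlattenLists` above are described as thin wrappers over elementary Beam transforms. A minimal standalone equivalent of what their docstrings say, with a placeholder key `'k'` and invented input lists:

```python
import apache_beam as beam

with beam.Pipeline() as pipeline:
    flat = (
        pipeline
        | beam.Create([[1, 2], [3]])
        | "FlattenLists" >> beam.FlatMap(lambda values: values)  # 1, 2, 3
    )
    _ = flat | "AddKey" >> beam.Map(lambda value: ("k", value)) | beam.Map(print)
    # ('k', 1), ('k', 2), ('k', 3)
```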
+ + Attributes + ---------- + output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of + this operation. + label: A unique label for this operation. + """ + + __slots__ = () + + def __new__(cls, output_tensor_info_list): + return super(ExtractCombineMergeOutputs, cls).__new__( + cls, output_tensor_info_list=output_tensor_info_list, label=_make_label(cls) + ) + + @property + def output_tensor_infos(self): + return self.output_tensor_info_list class ExtractPackedCombineMergeOutputs( - tfx_namedtuple.namedtuple('ExtractOutputs', - ['output_tensor_info_list', 'label']), - AnalyzerDef): - """An operation that represents extracting outputs of a packed combine merge. - - This operation represents a `beam.Map` that is applied to a PCollection. - For each element of the PCollection, this corresponding element of the output - PCollection is a tuple of outputs. - - Attributes: - output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of - this operation. - label: A unique label for this operation. - """ - __slots__ = () - - @property - def output_tensor_infos(self): - return self.output_tensor_info_list + tfx_namedtuple.namedtuple("ExtractOutputs", ["output_tensor_info_list", "label"]), + AnalyzerDef, +): + """An operation that represents extracting outputs of a packed combine merge. + + This operation represents a `beam.Map` that is applied to a PCollection. + For each element of the PCollection, this corresponding element of the output + PCollection is a tuple of outputs. + + Attributes + ---------- + output_tensor_info_list: A list of `TensorInfo`s that defines the outputs of + this operation. + label: A unique label for this operation. + """ + + __slots__ = () + + @property + def output_tensor_infos(self): + return self.output_tensor_info_list diff --git a/tensorflow_transform/analyzers.py b/tensorflow_transform/analyzers.py index c393232..a1cc5e5 100644 --- a/tensorflow_transform/analyzers.py +++ b/tensorflow_transform/analyzers.py @@ -29,44 +29,46 @@ import pickle import re from typing import Any, Callable, Collection, List, Optional, Sequence, Tuple, Union -from absl import logging - import numpy as np import pyarrow as pa import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import annotators -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import gaussianization -from tensorflow_transform import nodes -from tensorflow_transform import schema_inference -from tensorflow_transform import tf_utils +from absl import logging +from google.protobuf import descriptor_pb2 from tfx_bsl import sketches + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. 
from tfx_bsl.types import tfx_namedtuple from typing_extensions import Literal -from google.protobuf import descriptor_pb2 +from tensorflow_transform import ( + analyzer_nodes, + annotators, + common, + common_types, + gaussianization, + nodes, + schema_inference, + tf_utils, +) __all__ = [ - 'count_per_key', - 'covariance', - 'histogram', - 'max', - 'mean', - 'min', - 'pca', - 'quantiles', - 'size', - 'sum', - 'tukey_location', - 'tukey_scale', - 'tukey_h_params', - 'var', - 'vocabulary', + "count_per_key", + "covariance", + "histogram", + "max", + "mean", + "min", + "pca", + "quantiles", + "size", + "sum", + "tukey_location", + "tukey_scale", + "tukey_h_params", + "var", + "vocabulary", ] # This module defines max and min functions that override the builtins. @@ -74,11 +76,11 @@ builtin_min = min -DEFAULT_VOCABULARY_FILE_FORMAT: Literal['text'] = 'text' -ALLOWED_VOCABULARY_FILE_FORMATS = ('text', 'tfrecord_gzip') +DEFAULT_VOCABULARY_FILE_FORMAT: Literal["text"] = "text" +ALLOWED_VOCABULARY_FILE_FORMATS = ("text", "tfrecord_gzip") -VOCAB_FILENAME_PREFIX = 'vocab_' -VOCAB_FREQUENCY_FILENAME_PREFIX = 'vocab_frequency_' +VOCAB_FILENAME_PREFIX = "vocab_" +VOCAB_FREQUENCY_FILENAME_PREFIX = "vocab_frequency_" # Experimentally estimated value of top_k after which the exact `tft.vocabulary` # implementation becomes more efficient than @@ -88,7 +90,7 @@ # Matches empty strings and strings with \n or \r (including strings with \n or # \r that contain invalid UTF-8 characters). This has to follow the re2 syntax: # https://github.com/google/re2/wiki/Syntax. -_EMPTY_STRING_OR_NEWLINE_CHARS_REGEX = r'^$|\C*[\n\r]\C*' +_EMPTY_STRING_OR_NEWLINE_CHARS_REGEX = r"^$|\C*[\n\r]\C*" # For some input types, widen the output type of sum analyzer to avoid overflow. _SUM_OUTPUT_DTYPE_MAP = { @@ -121,1582 +123,1788 @@ def apply_cacheable_combine_operation( - combiner: analyzer_nodes.Combiner, - *tensor_inputs: common_types.TensorType) -> Tuple[nodes.ValueNode, ...]: - """Applies combine operation nodes over the whole dataset. + combiner: analyzer_nodes.Combiner, *tensor_inputs: common_types.TensorType +) -> Tuple[nodes.ValueNode, ...]: + """Applies combine operation nodes over the whole dataset. - Applied nodes are subject to analyzer cache optimization. + Applied nodes are subject to analyzer cache optimization. - Args: - combiner: Combiner to be applied. - *tensor_inputs: Tensors representing inputs to the combiner. + Args: + ---- + combiner: Combiner to be applied. + *tensor_inputs: Tensors representing inputs to the combiner. - Returns: - A tuple of ValueNodes representing outputs of the combiner. - """ - input_values_node = analyzer_nodes.get_input_tensors_value_nodes( - tensor_inputs) + Returns: + ------- + A tuple of ValueNodes representing outputs of the combiner. 
+ """ + input_values_node = analyzer_nodes.get_input_tensors_value_nodes(tensor_inputs) - accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombineAccumulate, - input_values_node, - combiner=combiner) + accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombineAccumulate, input_values_node, combiner=combiner + ) - merge_outputs_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombineMerge, - *accumulate_outputs_value_nodes, - combiner=combiner) + merge_outputs_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombineMerge, + *accumulate_outputs_value_nodes, + combiner=combiner, + ) - return nodes.apply_multi_output_operation( - analyzer_nodes.ExtractCombineMergeOutputs, - *merge_outputs_value_nodes, - output_tensor_info_list=combiner.output_tensor_infos()) + return nodes.apply_multi_output_operation( + analyzer_nodes.ExtractCombineMergeOutputs, + *merge_outputs_value_nodes, + output_tensor_info_list=combiner.output_tensor_infos(), + ) def _apply_cacheable_combiner( - combiner: analyzer_nodes.Combiner, - *tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]: - """Applies the combiner over the whole dataset possibly utilizing cache. + combiner: analyzer_nodes.Combiner, *tensor_inputs: common_types.TensorType +) -> Tuple[tf.Tensor, ...]: + """Applies the combiner over the whole dataset possibly utilizing cache. - Similar to above but returns a tuple of output tensors. + Similar to above but returns a tuple of output tensors. - Args: - combiner: Combiner to be applied. - *tensor_inputs: Tensors representing inputs to the combiner. + Args: + ---- + combiner: Combiner to be applied. + *tensor_inputs: Tensors representing inputs to the combiner. - Returns: - A tuple of tensors representing outputs of the combiner. - """ - outputs_value_nodes = apply_cacheable_combine_operation( - combiner, *tensor_inputs) - return tuple(map(analyzer_nodes.wrap_as_tensor, outputs_value_nodes)) # pytype: disable=bad-return-type + Returns: + ------- + A tuple of tensors representing outputs of the combiner. 
+ """ + outputs_value_nodes = apply_cacheable_combine_operation(combiner, *tensor_inputs) + return tuple( + map(analyzer_nodes.wrap_as_tensor, outputs_value_nodes) + ) # pytype: disable=bad-return-type def _apply_cacheable_combiner_per_key( - combiner: analyzer_nodes.Combiner, - *tensor_inputs: common_types.TensorType) -> Tuple[tf.Tensor, ...]: - """Similar to _apply_cacheable_combiner but this is computed per key.""" - input_values_node = analyzer_nodes.get_input_tensors_value_nodes( - tensor_inputs) + combiner: analyzer_nodes.Combiner, *tensor_inputs: common_types.TensorType +) -> Tuple[tf.Tensor, ...]: + """Similar to _apply_cacheable_combiner but this is computed per key.""" + input_values_node = analyzer_nodes.get_input_tensors_value_nodes(tensor_inputs) - accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombinePerKeyAccumulate, - input_values_node, - combiner=combiner) + accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombinePerKeyAccumulate, + input_values_node, + combiner=combiner, + ) - merge_output_value_node = nodes.apply_operation( - analyzer_nodes.CacheableCombinePerKeyMerge, - *accumulate_outputs_value_nodes, - combiner=combiner) + merge_output_value_node = nodes.apply_operation( + analyzer_nodes.CacheableCombinePerKeyMerge, + *accumulate_outputs_value_nodes, + combiner=combiner, + ) - output_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombinePerKeyFormatKeys, - merge_output_value_node, - combiner=combiner) + output_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombinePerKeyFormatKeys, + merge_output_value_node, + combiner=combiner, + ) - return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes)) + return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes)) def _apply_cacheable_combiner_per_key_large( - combiner: analyzer_nodes.Combiner, key_vocabulary_filename: str, - *tensor_inputs: common_types.TensorType + combiner: analyzer_nodes.Combiner, + key_vocabulary_filename: str, + *tensor_inputs: common_types.TensorType, ) -> Union[tf.Tensor, tf.saved_model.Asset]: - """Similar to above but saves the combined result to a file.""" - input_values_node = analyzer_nodes.get_input_tensors_value_nodes( - tensor_inputs) - - accumulate_outputs_value_node = nodes.apply_operation( - analyzer_nodes.CacheableCombinePerKeyAccumulate, - input_values_node, - combiner=combiner) - - merge_output_value_node = nodes.apply_operation( - analyzer_nodes.CacheableCombinePerKeyMerge, - accumulate_outputs_value_node, - combiner=combiner) - - keys_and_values_node = nodes.apply_operation( - analyzer_nodes.CacheableCombinePerKeyFormatLarge, - merge_output_value_node) - - # `store_frequency` is True by default because we want to write some values - # alongside the key "vocabulary". Without doing so it would be equivalent to - # vanilla vocabulary analzyer. `fingerprint_shuffle` is not as important but - # signifies that the values are not required to be ordered here. - key_vocabulary_filename_node = nodes.apply_operation( - analyzer_nodes.VocabularyOrderAndWrite, - keys_and_values_node, - vocab_filename=key_vocabulary_filename, - store_frequency=True, - fingerprint_shuffle=True, - # TODO(b/62379925): Use tfrecord. 
- file_format='text') - - return analyzer_nodes.wrap_as_tensor(key_vocabulary_filename_node) + """Similar to above but saves the combined result to a file.""" + input_values_node = analyzer_nodes.get_input_tensors_value_nodes(tensor_inputs) + accumulate_outputs_value_node = nodes.apply_operation( + analyzer_nodes.CacheableCombinePerKeyAccumulate, + input_values_node, + combiner=combiner, + ) -class NumPyCombiner(analyzer_nodes.Combiner): - """Combines the PCollection only on the 0th dimension using nparray. - - Attributes: - fn: The numpy function representing the reduction to be done. - default_accumulator_value: The default value each accumulator entry is - initialized to. - output_dtypes: The numpy dtype to cast each output to. - output_shapes: List of tuples representing the shapes of the outputs or - Nones if the shapes are not fully defined. - """ - - def __init__(self, fn, default_accumulator_value, output_dtypes, - output_shapes): - self._fn = fn - self._default_accumulator_value = default_accumulator_value - self._default_sub_accumulator = np.array(default_accumulator_value) - self._output_dtypes = output_dtypes - if not all( - isinstance(shape, (tuple, type(None))) for shape in output_shapes): - raise TypeError('Expected all tuples or Nones, but got %r' % - output_shapes) - self._output_shapes = output_shapes - if np.isnan(default_accumulator_value): - # This case is needed because np.nan != np.nan. - self._is_default_sub_accumulator = self._equals_to_scalar_nan - else: - self._is_default_sub_accumulator = self._equals_to_default_sub_accumulator - - def _equals_to_scalar_nan(self, array): - return not array.shape and np.isnan(array) - - def _equals_to_default_sub_accumulator(self, array): - # Note that `np.array_equal` below does at most per-element comparison of - # 0-dim arrays since `_default_sub_accumulator` is a 0-dim array, and - # `np.array_equal` exits early on a shape mismatch. - return np.array_equal(array, self._default_sub_accumulator) - - def _is_default_sub_accumulator(self, array): - raise NotImplementedError('Implementation should be set in __init__.') - - def create_accumulator(self): - return [ - self._create_sub_accumulator(shape) - for shape in self._output_shapes - ] - - def _create_sub_accumulator(self, shape): - # Returns a default subaccumulator of the given shape if it's fully defined - # and a 0-dim default array otherwise. - if shape is None: - return self._default_sub_accumulator - else: - return np.full(shape, self._default_accumulator_value) - - def add_input(self, accumulator, batch_values): - # TODO(b/112414577): Go back to accepting only a single input. - # See comment in _numeric_combine. - # If the first subaccumulator is default, then the accumulator is default - # and can be discarded. - if self._is_default_sub_accumulator(accumulator[0]): - return batch_values - else: - return [ - self._fn((sub_accumulator, batch_value), axis=0) - for sub_accumulator, batch_value in zip(accumulator, batch_values) - ] - - def merge_accumulators(self, accumulators): - # If the first subaccumulator is default, then the accumulator is default - # and can be discarded. - non_default_accumulators = [ - accumulator for accumulator in accumulators - if not self._is_default_sub_accumulator(accumulator[0]) - ] - if non_default_accumulators: - return [ - # numpy's sum, min, max, etc functions operate on array-like objects, - # but not arbitrary iterables. Convert the provided sub_accumulators - # into a list. 
- self._fn(list(sub_accumulators), axis=0) - for sub_accumulators in zip(*non_default_accumulators) - ] - else: - return self.create_accumulator() + merge_output_value_node = nodes.apply_operation( + analyzer_nodes.CacheableCombinePerKeyMerge, + accumulate_outputs_value_node, + combiner=combiner, + ) + + keys_and_values_node = nodes.apply_operation( + analyzer_nodes.CacheableCombinePerKeyFormatLarge, merge_output_value_node + ) + + # `store_frequency` is True by default because we want to write some values + # alongside the key "vocabulary". Without doing so it would be equivalent to + # vanilla vocabulary analzyer. `fingerprint_shuffle` is not as important but + # signifies that the values are not required to be ordered here. + key_vocabulary_filename_node = nodes.apply_operation( + analyzer_nodes.VocabularyOrderAndWrite, + keys_and_values_node, + vocab_filename=key_vocabulary_filename, + store_frequency=True, + fingerprint_shuffle=True, + # TODO(b/62379925): Use tfrecord. + file_format="text", + ) - def extract_output(self, accumulator): - # For each output, cast that output to the specified type. Note there - # will be one output for each input tensor to the analyzer. - return [ - sub_accumulator.astype(output_dtype) for sub_accumulator, output_dtype - in zip(accumulator, self._output_dtypes) - ] + return analyzer_nodes.wrap_as_tensor(key_vocabulary_filename_node) - def output_tensor_infos(self): - return [ - analyzer_nodes.TensorInfo(tf.as_dtype(dtype), shape, None) - for dtype, shape in zip(self._output_dtypes, self._output_shapes) - ] + +class NumPyCombiner(analyzer_nodes.Combiner): + """Combines the PCollection only on the 0th dimension using nparray. + + Attributes + ---------- + fn: The numpy function representing the reduction to be done. + default_accumulator_value: The default value each accumulator entry is + initialized to. + output_dtypes: The numpy dtype to cast each output to. + output_shapes: List of tuples representing the shapes of the outputs or + Nones if the shapes are not fully defined. + """ + + def __init__(self, fn, default_accumulator_value, output_dtypes, output_shapes): + self._fn = fn + self._default_accumulator_value = default_accumulator_value + self._default_sub_accumulator = np.array(default_accumulator_value) + self._output_dtypes = output_dtypes + if not all(isinstance(shape, (tuple, type(None))) for shape in output_shapes): + raise TypeError("Expected all tuples or Nones, but got %r" % output_shapes) + self._output_shapes = output_shapes + if np.isnan(default_accumulator_value): + # This case is needed because np.nan != np.nan. + self._is_default_sub_accumulator = self._equals_to_scalar_nan + else: + self._is_default_sub_accumulator = self._equals_to_default_sub_accumulator + + def _equals_to_scalar_nan(self, array): + return not array.shape and np.isnan(array) + + def _equals_to_default_sub_accumulator(self, array): + # Note that `np.array_equal` below does at most per-element comparison of + # 0-dim arrays since `_default_sub_accumulator` is a 0-dim array, and + # `np.array_equal` exits early on a shape mismatch. 
+ return np.array_equal(array, self._default_sub_accumulator) + + def _is_default_sub_accumulator(self, array): + raise NotImplementedError("Implementation should be set in __init__.") + + def create_accumulator(self): + return [self._create_sub_accumulator(shape) for shape in self._output_shapes] + + def _create_sub_accumulator(self, shape): + # Returns a default subaccumulator of the given shape if it's fully defined + # and a 0-dim default array otherwise. + if shape is None: + return self._default_sub_accumulator + else: + return np.full(shape, self._default_accumulator_value) + + def add_input(self, accumulator, batch_values): + # TODO(b/112414577): Go back to accepting only a single input. + # See comment in _numeric_combine. + # If the first subaccumulator is default, then the accumulator is default + # and can be discarded. + if self._is_default_sub_accumulator(accumulator[0]): + return batch_values + else: + return [ + self._fn((sub_accumulator, batch_value), axis=0) + for sub_accumulator, batch_value in zip(accumulator, batch_values) + ] + + def merge_accumulators(self, accumulators): + # If the first subaccumulator is default, then the accumulator is default + # and can be discarded. + non_default_accumulators = [ + accumulator + for accumulator in accumulators + if not self._is_default_sub_accumulator(accumulator[0]) + ] + if non_default_accumulators: + return [ + # numpy's sum, min, max, etc functions operate on array-like objects, + # but not arbitrary iterables. Convert the provided sub_accumulators + # into a list. + self._fn(list(sub_accumulators), axis=0) + for sub_accumulators in zip(*non_default_accumulators) + ] + else: + return self.create_accumulator() + + def extract_output(self, accumulator): + # For each output, cast that output to the specified type. Note there + # will be one output for each input tensor to the analyzer. + return [ + sub_accumulator.astype(output_dtype) + for sub_accumulator, output_dtype in zip(accumulator, self._output_dtypes) + ] + + def output_tensor_infos(self): + return [ + analyzer_nodes.TensorInfo(tf.as_dtype(dtype), shape, None) + for dtype, shape in zip(self._output_dtypes, self._output_shapes) + ] def _get_output_shape_from_input(x): - if isinstance(x, tf.SparseTensor): - return x.get_shape().as_list()[1:] + if isinstance(x, tf.SparseTensor): + return x.get_shape().as_list()[1:] - # When reducing over batch dimensions, with known shape, the result will be - # the same shape as the input, but without the batch. - if x.shape.rank is not None: - return x.shape.as_list()[1:] - return (None,) + # When reducing over batch dimensions, with known shape, the result will be + # the same shape as the input, but without the batch. + if x.shape.rank is not None: + return x.shape.as_list()[1:] + return (None,) def _get_elementwise_per_key_output_shape( - x: tf.Tensor, key: Optional[tf.Tensor]) -> Optional[Tuple[int]]: - shape = x.get_shape() if key is None else x.get_shape()[1:] - return tuple(shape) if shape.is_fully_defined() else None + x: tf.Tensor, key: Optional[tf.Tensor] +) -> Optional[Tuple[int]]: + shape = x.get_shape() if key is None else x.get_shape()[1:] + return tuple(shape) if shape.is_fully_defined() else None # TODO(b/112414577): Go back to accepting only a single input. # Currently we accept multiple inputs so that we can implement min and max # with a single combiner. Once this is done, add a return pytype as well. 
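Stripped of the default-accumulator and caching machinery, the data flow through `NumPyCombiner` is: reduce each batch over its 0th dimension, fold the partial results with the same NumPy function, then cast to the requested output dtype. A minimal sketch with `np.max` and invented batches (in TFT the per-batch reduction actually happens in the TensorFlow graph, e.g. via `tf_utils.reduce_batch_minus_min_and_max`):

```python
import numpy as np

batches = [np.array([[1, 9], [4, 2]]), np.array([[7, 0]])]

# Per-batch reductions over the batch (0th) dimension.
per_batch = [np.max(batch, axis=0) for batch in batches]  # [array([4, 9]), array([7, 0])]

# add_input / merge_accumulators: fold the partial results with the same function.
merged = np.max(per_batch, axis=0)  # array([7, 9])

# extract_output: cast to the requested output dtype.
print(merged.astype(np.int64))  # [7 9]
```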
-def _numeric_combine(inputs: List[tf.Tensor], - fn: Callable[[np.ndarray], np.ndarray], - default_accumulator_value: Union[float, int], - reduce_instance_dims: bool = True, - output_dtypes: Optional[List[tf.DType]] = None, - key: Optional[tf.Tensor] = None, - key_vocabulary_filename: Optional[str] = None): - """Apply a reduction, defined by a numpy function to multiple inputs. - - Args: - inputs: A list of tensors, which will be independently reduced. - fn: A function to reduce tensors across instances/batches, to get a single - output. - default_accumulator_value: The default scalar value that each accumulator - entry is initialized to. Must be properly processed by the reduction - function. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - output_dtypes: (Optional) A list of dtypes of the output tensors. If None, - the output tensor has the same type as the input one. - key: (Optional) Apply the same operation, but on a per-key basis. - key_vocabulary_filename: (Optional) The file name for the key-output mapping - file. If None and key are provided, this combiner assumes the keys fit in - memory and will not store the result in a file. If empty string, a file - name will be chosen based on the current scope. If not an empty string, - should be unique within a given preprocessing function. - - Returns: - Either: - (A) A list of Tensors with the same length as `inputs`, representing the - input Tensors that have been reduced by `fn` across instances and - batches (if key_vocabulary_filename is None). - (B) A Tensor with the filename where the key-value mapping is stored (if - key_vocabulary_filename is not None). - """ - for x in inputs: - if not isinstance(x, tf.Tensor): - raise TypeError('Expected a Tensor, but got %r' % x) - if not np.isscalar(default_accumulator_value): - raise TypeError('Expected a scalar, but got %r' % default_accumulator_value) - - if output_dtypes is None: - output_dtypes = [x.dtype for x in inputs] - if reduce_instance_dims: - # If reducing over all dimensions, result is scalar. - output_shapes = [() for _ in inputs] - else: - # Reducing over batch dimensions. - output_shapes = [ - _get_elementwise_per_key_output_shape(x, key) for x in inputs - ] - combiner = NumPyCombiner(fn, default_accumulator_value, - [dtype.as_numpy_dtype for dtype in output_dtypes], - output_shapes) - if key is None: - return _apply_cacheable_combiner(combiner, *inputs) - - if key_vocabulary_filename is None: - return _apply_cacheable_combiner_per_key(combiner, key, *inputs) - - return _apply_cacheable_combiner_per_key_large( - combiner, _maybe_get_per_key_vocab_filename(key_vocabulary_filename), key, - *inputs) +def _numeric_combine( + inputs: List[tf.Tensor], + fn: Callable[[np.ndarray], np.ndarray], + default_accumulator_value: Union[float, int], + reduce_instance_dims: bool = True, + output_dtypes: Optional[List[tf.DType]] = None, + key: Optional[tf.Tensor] = None, + key_vocabulary_filename: Optional[str] = None, +): + """Apply a reduction, defined by a numpy function to multiple inputs. + + Args: + ---- + inputs: A list of tensors, which will be independently reduced. + fn: A function to reduce tensors across instances/batches, to get a single + output. + default_accumulator_value: The default scalar value that each accumulator + entry is initialized to. Must be properly processed by the reduction + function. 
+ reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + output_dtypes: (Optional) A list of dtypes of the output tensors. If None, + the output tensor has the same type as the input one. + key: (Optional) Apply the same operation, but on a per-key basis. + key_vocabulary_filename: (Optional) The file name for the key-output mapping + file. If None and key are provided, this combiner assumes the keys fit in + memory and will not store the result in a file. If empty string, a file + name will be chosen based on the current scope. If not an empty string, + should be unique within a given preprocessing function. + + Returns: + ------- + Either: + (A) A list of Tensors with the same length as `inputs`, representing the + input Tensors that have been reduced by `fn` across instances and + batches (if key_vocabulary_filename is None). + (B) A Tensor with the filename where the key-value mapping is stored (if + key_vocabulary_filename is not None). + """ + for x in inputs: + if not isinstance(x, tf.Tensor): + raise TypeError("Expected a Tensor, but got %r" % x) + if not np.isscalar(default_accumulator_value): + raise TypeError("Expected a scalar, but got %r" % default_accumulator_value) + + if output_dtypes is None: + output_dtypes = [x.dtype for x in inputs] + if reduce_instance_dims: + # If reducing over all dimensions, result is scalar. + output_shapes = [() for _ in inputs] + else: + # Reducing over batch dimensions. + output_shapes = [_get_elementwise_per_key_output_shape(x, key) for x in inputs] + combiner = NumPyCombiner( + fn, + default_accumulator_value, + [dtype.as_numpy_dtype for dtype in output_dtypes], + output_shapes, + ) + if key is None: + return _apply_cacheable_combiner(combiner, *inputs) + + if key_vocabulary_filename is None: + return _apply_cacheable_combiner_per_key(combiner, key, *inputs) + + return _apply_cacheable_combiner_per_key_large( + combiner, + _maybe_get_per_key_vocab_filename(key_vocabulary_filename), + key, + *inputs, + ) @common.log_api_use(common.ANALYZER_COLLECTION) def min( # pylint: disable=redefined-builtin x: common_types.TensorType, reduce_instance_dims: bool = True, - name: Optional[str] = None) -> tf.Tensor: - """Computes the minimum of the values of `x` over the whole dataset. + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the minimum of the values of `x` over the whole dataset. - In the case of a `CompositeTensor` missing values will be used in return - value: for float, NaN is used and for other dtypes the max is used. + In the case of a `CompositeTensor` missing values will be used in return + value: for float, NaN is used and for other dtypes the max is used. - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a `Tensor` of the same shape as the input. - name: (Optional) A name for this operation. + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a `Tensor` of the same shape as the input. + name: (Optional) A name for this operation. - Returns: - A `Tensor` with the same type as `x`. 
+ Returns: + ------- + A `Tensor` with the same type as `x`. - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'min'): - return _min_and_max(x, reduce_instance_dims, name)[0] + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "min"): + return _min_and_max(x, reduce_instance_dims, name)[0] @common.log_api_use(common.ANALYZER_COLLECTION) def max( # pylint: disable=redefined-builtin x: common_types.TensorType, reduce_instance_dims: bool = True, - name: Optional[str] = None) -> tf.Tensor: - """Computes the maximum of the values of `x` over the whole dataset. - - In the case of a `CompositeTensor` missing values will be used in return - value: for float, NaN is used and for other dtypes the min is used. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`. Has the same type as `x`. - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'max'): - return _min_and_max(x, reduce_instance_dims, name)[1] - - -def _min_and_max(x: common_types.TensorType, - reduce_instance_dims: bool = True, - name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes the min and max of the values of `x`. - - In the case of a `CompositeTensor` missing values will be used in return - value: - for float, NaN is used and for other dtypes the min is used. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - - Returns: - Two `Tensor`s. Both have the same type as `x`. - - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'min_and_max'): - output_dtype = x.dtype - if (not reduce_instance_dims and isinstance(x, tf.SparseTensor) and - x.dtype.is_floating): - combine_fn = np.nanmax - default_accumulator_value = (np.nan if x.dtype.is_floating else - -output_dtype.max) - elif not reduce_instance_dims and isinstance(x, tf.RaggedTensor): - raise NotImplementedError( - 'Elementwise min_and_max does not support RaggedTensors.') - else: - combine_fn = np.max - default_accumulator_value = (-np.inf if x.dtype.is_floating else - -output_dtype.max) + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the maximum of the values of `x` over the whole dataset. + + In the case of a `CompositeTensor` missing values will be used in return + value: for float, NaN is used and for other dtypes the min is used. - x_batch_minus_min, x_batch_max = tf_utils.reduce_batch_minus_min_and_max( - x, reduce_instance_dims) + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. 
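For context on how these analyzers surface to users, a typical `preprocessing_fn` calls `tft.min`/`tft.max` on a feature and folds the resulting full-dataset constants back into the output tensors. A minimal sketch, assuming a dense numeric feature named `"x"`:

```python
import tensorflow as tf
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    x = tf.cast(inputs["x"], tf.float32)
    x_min = tf.cast(tft.min(inputs["x"]), tf.float32)  # full-dataset minimum
    x_max = tf.cast(tft.max(inputs["x"]), tf.float32)  # full-dataset maximum
    return {"x_scaled": (x - x_min) / (x_max - x_min)}
```

`tft.scale_to_0_1` packages this same min/max pattern as a single mapper.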
- minus_x_min, x_max = _numeric_combine( # pylint: disable=unbalanced-tuple-unpacking - inputs=[x_batch_minus_min, x_batch_max], - fn=combine_fn, - default_accumulator_value=default_accumulator_value, - reduce_instance_dims=reduce_instance_dims) - return tf.cast(0 - minus_x_min, output_dtype), tf.cast(x_max, output_dtype) + Returns: + ------- + A `Tensor`. Has the same type as `x`. + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "max"): + return _min_and_max(x, reduce_instance_dims, name)[1] -def _min_and_max_per_key( + +def _min_and_max( x: common_types.TensorType, - key: common_types.TensorType, reduce_instance_dims: bool = True, - key_vocabulary_filename: Optional[str] = None, - name: Optional[str] = None -) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]: - """Computes the min and max of the values of `x`. - - In the case of a `CompositeTensor` missing values will be used in return - value: for float, NaN is used and for other dtypes the min is used. - - This function operates under the assumption that the size of the key set - is small enough to fit in memory. Anything above a certain size larger is not - guaranteed to be handled properly, but support for larger key sets may be - available in a future version. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. If - `x` is a `CompositeTensor`, `key` must exactly match `x` in everything - except values. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. The False - case is not currently supported for _min_and_max_per_key. - key_vocabulary_filename: (Optional) The file name for the key-output mapping - file. If None and key are provided, this combiner assumes the keys fit in - memory and will not store the result in a file. If empty string, a file - name will be chosen based on the current scope. If not an empty string, - should be unique within a given preprocessing function. - name: (Optional) A name for this operation. - - Returns: - Either: - (A) Three `Tensor`s. The first is the key vocab of type tf.string, and the - second two have same type as `x` (if key_vocabulary_filename is None). - (B) The filename where the key-value mapping is stored (if - key_vocabulary_filename is not None). - - Raises: - TypeError: If the type of `x` is not supported. 
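Both statistics in `_min_and_max` are funneled through a single max-style combiner by tracking the negated minimum (hence `reduce_batch_minus_min_and_max` and the `0 - minus_x_min` in the return above). In plain NumPy terms:

```python
import numpy as np

x = np.array([3.0, -1.0, 7.0])
minus_x_min, x_max = np.max(-x), np.max(x)  # one reduction function yields both stats
assert -minus_x_min == np.min(x) == -1.0 and x_max == 7.0
```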
- """ - if key is None: - raise ValueError('A key is required for _min_and_max_per_key') - - if not reduce_instance_dims and isinstance( - x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'Per-key elementwise reduction of Composite Tensors not supported ') - - with tf.compat.v1.name_scope(name, 'min_and_max_per_key'): - output_dtype = x.dtype - if (not reduce_instance_dims and - isinstance(x, - (tf.SparseTensor, tf.RaggedTensor)) and x.dtype.is_floating): - combine_fn = np.nanmax - default_accumulator_value = (np.nan if x.dtype.is_floating else - -output_dtype.max) - else: - combine_fn = np.max - default_accumulator_value = (-np.inf if x.dtype.is_floating else - -output_dtype.max) - - key_vocab, x_batch_minus_min, x_batch_max = ( - tf_utils.reduce_batch_minus_min_and_max_per_key(x, key, - reduce_instance_dims)) - - key_values = _numeric_combine( # pylint: disable=unbalanced-tuple-unpacking - inputs=[x_batch_minus_min, x_batch_max], - fn=combine_fn, - default_accumulator_value=default_accumulator_value, - reduce_instance_dims=reduce_instance_dims, - key=key_vocab, - key_vocabulary_filename=key_vocabulary_filename) + name: Optional[str] = None, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the min and max of the values of `x`. - if key_vocabulary_filename is not None: - return key_values # pytype: disable=bad-return-type # always-use-return-annotations + In the case of a `CompositeTensor` missing values will be used in return + value: + for float, NaN is used and for other dtypes the min is used. - key, minus_x_min, x_max = key_values - return ( - key, - tf.cast(0 - minus_x_min, output_dtype), - tf.cast(x_max, output_dtype)) + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. + Returns: + ------- + Two `Tensor`s. Both have the same type as `x`. -def _sum_combine_fn_and_dtype( - input_dtype: tf.DType -) -> Tuple[tf.DType, Callable[[np.ndarray], np.ndarray]]: - output_dtype = _SUM_OUTPUT_DTYPE_MAP.get(input_dtype) - if output_dtype is None: - raise TypeError('Tensor type %r is not supported' % input_dtype) + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "min_and_max"): + output_dtype = x.dtype + if ( + not reduce_instance_dims + and isinstance(x, tf.SparseTensor) + and x.dtype.is_floating + ): + combine_fn = np.nanmax + default_accumulator_value = ( + np.nan if x.dtype.is_floating else -output_dtype.max + ) + elif not reduce_instance_dims and isinstance(x, tf.RaggedTensor): + raise NotImplementedError( + "Elementwise min_and_max does not support RaggedTensors." 
+ ) + else: + combine_fn = np.max + default_accumulator_value = ( + -np.inf if x.dtype.is_floating else -output_dtype.max + ) + + x_batch_minus_min, x_batch_max = tf_utils.reduce_batch_minus_min_and_max( + x, reduce_instance_dims + ) - return output_dtype, functools.partial( - np.sum, dtype=output_dtype.as_numpy_dtype) + minus_x_min, x_max = _numeric_combine( # pylint: disable=unbalanced-tuple-unpacking + inputs=[x_batch_minus_min, x_batch_max], + fn=combine_fn, + default_accumulator_value=default_accumulator_value, + reduce_instance_dims=reduce_instance_dims, + ) + return tf.cast(0 - minus_x_min, output_dtype), tf.cast(x_max, output_dtype) -@common.log_api_use(common.ANALYZER_COLLECTION) -def sum( # pylint: disable=redefined-builtin +def _min_and_max_per_key( x: common_types.TensorType, + key: common_types.TensorType, reduce_instance_dims: bool = True, - name: Optional[str] = None) -> tf.Tensor: - """Computes the sum of the values of a `Tensor` over the whole dataset. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}),integral (int{8|16|32|64}), or unsigned - integral (uint{8|16}). - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` containing the sum. If `x` is float32 or float64, the sum will - have the same type as `x`. If `x` is float16, the output is cast to float32. - If `x` is integral, the output is cast to [u]int64. If `x` is sparse and - reduce_inst_dims is False will return 0 in place where column has no values - across batches. - - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'sum'): - if reduce_instance_dims: - x = tf.reduce_sum(input_tensor=tf_utils.get_values(x)) - elif isinstance(x, tf.SparseTensor): - if x.dtype == tf.uint8 or x.dtype == tf.uint16: - x = tf.cast(x, tf.int64) - elif x.dtype == tf.uint32 or x.dtype == tf.uint64: - raise TypeError('Data type %r is not supported' % x.dtype) - x = tf.sparse.reduce_sum(x, axis=0) - elif isinstance(x, tf.RaggedTensor): - raise NotImplementedError( - 'Elementwise sum does not support RaggedTensors.') - else: - x = tf.reduce_sum(input_tensor=x, axis=0) - output_dtype, sum_fn = _sum_combine_fn_and_dtype(x.dtype) - return _numeric_combine( - inputs=[x], - fn=sum_fn, - default_accumulator_value=0, - reduce_instance_dims=reduce_instance_dims, - output_dtypes=[output_dtype])[0] + key_vocabulary_filename: Optional[str] = None, + name: Optional[str] = None, +) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor]: + """Computes the min and max of the values of `x`. + In the case of a `CompositeTensor` missing values will be used in return + value: for float, NaN is used and for other dtypes the min is used. -def remove_leftmost_boundary(boundaries: tf.Tensor) -> tf.Tensor: - """Removes the leftmost boundary from [1, None]-shaped `Tensor` of buckets.""" - return boundaries[:, 1:] + This function operates under the assumption that the size of the key set + is small enough to fit in memory. Anything above a certain size larger is not + guaranteed to be handled properly, but support for larger key sets may be + available in a future version. + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. 
If + `x` is a `CompositeTensor`, `key` must exactly match `x` in everything + except values. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. The False + case is not currently supported for _min_and_max_per_key. + key_vocabulary_filename: (Optional) The file name for the key-output mapping + file. If None and key are provided, this combiner assumes the keys fit in + memory and will not store the result in a file. If empty string, a file + name will be chosen based on the current scope. If not an empty string, + should be unique within a given preprocessing function. + name: (Optional) A name for this operation. -@common.log_api_use(common.ANALYZER_COLLECTION) -def histogram(x: common_types.TensorType, - boundaries: Optional[Union[tf.Tensor, int]] = None, - categorical: Optional[bool] = False, - name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes a histogram over x, given the bin boundaries or bin count. - - Ex (1): - counts, boundaries = histogram([0, 1, 0, 1, 0, 3, 0, 1], range(5)) - counts: [4, 3, 0, 1, 0] - boundaries: [0, 1, 2, 3, 4] - - Ex (2): - Can be used to compute class weights. - counts, classes = histogram([0, 1, 0, 1, 0, 3, 0, 1], categorical=True) - probabilities = counts / tf.reduce_sum(counts) - class_weights = dict(map(lambda (a, b): (a.numpy(), 1.0 / b.numpy()), - zip(classes, probabilities))) - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - boundaries: (Optional) A `Tensor` or `int` used to build the histogram; - ignored if `categorical` is True. If possible, provide boundaries as - multiple sorted values. Default to 10 intervals over the 0-1 range, or - find the min/max if an int is provided (not recommended because - multi-phase analysis is inefficient). - categorical: (Optional) A `bool` that treats `x` as discrete values if true. - name: (Optional) A name for this operation. - - Returns: - counts: The histogram, as counts per bin. - boundaries: A `Tensor` used to build the histogram representing boundaries. - """ - - with tf.compat.v1.name_scope(name, 'histogram'): - x = tf.reshape(tf_utils.get_values(x), [-1]) - if categorical: - x_dtype = x.dtype - x = x if x_dtype == tf.string else tf.strings.as_string(x) - elements, counts = count_per_key(x) - if x_dtype != elements.dtype: - elements = tf.strings.to_number(elements, tf.int64) - return counts, elements - - if boundaries is None: - boundaries = tf.range(11, dtype=tf.float32) / 10.0 - elif isinstance(boundaries, int) or (isinstance(boundaries, tf.Tensor) and - boundaries.get_shape().ndims == 0): - min_value, max_value = _min_and_max(x, True) - boundaries = tf.linspace( - tf.cast(min_value, tf.float32), tf.cast(max_value, tf.float32), - tf.cast(boundaries, tf.int64)) - - # Shift the boundaries slightly to account for floating point errors, - # and due to the fact that the rightmost boundary is essentially ignored. - boundaries = tf.expand_dims(tf.cast(boundaries, tf.float32), 0) - 0.0001 - - bucket_indices = tf_utils.assign_buckets( - tf.cast(x, tf.float32), remove_leftmost_boundary(boundaries)) - bucket_vocab, counts = count_per_key(tf.strings.as_string(bucket_indices)) - counts = tf_utils.reorder_histogram(bucket_vocab, counts, - tf.size(boundaries) - 1) - return counts, boundaries + Returns: + ------- + Either: + (A) Three `Tensor`s. 
The first is the key vocab of type tf.string, and the + second two have same type as `x` (if key_vocabulary_filename is None). + (B) The filename where the key-value mapping is stored (if + key_vocabulary_filename is not None). + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + if key is None: + raise ValueError("A key is required for _min_and_max_per_key") -@common.log_api_use(common.ANALYZER_COLLECTION) -def size(x: common_types.TensorType, - reduce_instance_dims: bool = True, - name: Optional[str] = None) -> tf.Tensor: - """Computes the total size of instances in a `Tensor` over the whole dataset. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` of type int64. - """ - with tf.compat.v1.name_scope(name, 'size'): - # Note: Calling `sum` defined in this module, not the builtin. - if isinstance(x, tf.SparseTensor): - ones_like_x = tf.SparseTensor( - indices=x.indices, - values=tf.ones_like(x.values, tf.int64), - dense_shape=x.dense_shape) - else: - ones_like_x = tf.ones_like(x, dtype=tf.int64) - return sum(ones_like_x, reduce_instance_dims) + if not reduce_instance_dims and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "Per-key elementwise reduction of Composite Tensors not supported " + ) + with tf.compat.v1.name_scope(name, "min_and_max_per_key"): + output_dtype = x.dtype + if ( + not reduce_instance_dims + and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)) + and x.dtype.is_floating + ): + combine_fn = np.nanmax + default_accumulator_value = ( + np.nan if x.dtype.is_floating else -output_dtype.max + ) + else: + combine_fn = np.max + default_accumulator_value = ( + -np.inf if x.dtype.is_floating else -output_dtype.max + ) + + key_vocab, x_batch_minus_min, x_batch_max = ( + tf_utils.reduce_batch_minus_min_and_max_per_key( + x, key, reduce_instance_dims + ) + ) -@common.log_api_use(common.ANALYZER_COLLECTION) -def count_per_key(key: common_types.TensorType, - key_vocabulary_filename: Optional[str] = None, - name: Optional[str] = None): - """Computes the count of each element of a `Tensor`. - - Args: - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string or - tf.int. - key_vocabulary_filename: (Optional) The file name for the key-output mapping - file. If None and key are provided, this combiner assumes the keys fit in - memory and will not store the result in a file. If empty string, a file - name will be chosen based on the current scope. If not an empty string, - should be unique within a given preprocessing function. - name: (Optional) A name for this operation. - - Returns: - Either: - (A) Two `Tensor`s: one the key vocab with dtype of input; - the other the count for each key, dtype tf.int64. (if - key_vocabulary_filename is None). - (B) The filename where the key-value mapping is stored (if - key_vocabulary_filename is not None). - - Raises: - TypeError: If the type of `x` is not supported. 
- """ - - with tf.compat.v1.name_scope(name, 'count_per_key'): - key_dtype = key.dtype - batch_keys, batch_counts = tf_utils.reduce_batch_count_per_key(key) - - output_dtype, sum_fn = _sum_combine_fn_and_dtype(tf.int64) - numeric_combine_result = _numeric_combine( - inputs=[batch_counts], - fn=sum_fn, - default_accumulator_value=0, - reduce_instance_dims=True, - output_dtypes=[output_dtype], - key=batch_keys, - key_vocabulary_filename=key_vocabulary_filename) + key_values = _numeric_combine( # pylint: disable=unbalanced-tuple-unpacking + inputs=[x_batch_minus_min, x_batch_max], + fn=combine_fn, + default_accumulator_value=default_accumulator_value, + reduce_instance_dims=reduce_instance_dims, + key=key_vocab, + key_vocabulary_filename=key_vocabulary_filename, + ) - if key_vocabulary_filename is not None: - return numeric_combine_result - keys, counts = numeric_combine_result - if key_dtype is not tf.string: - keys = tf.strings.to_number(keys, key_dtype) - return keys, counts + if key_vocabulary_filename is not None: + return key_values # pytype: disable=bad-return-type # always-use-return-annotations + key, minus_x_min, x_max = key_values + return ( + key, + tf.cast(0 - minus_x_min, output_dtype), + tf.cast(x_max, output_dtype), + ) -@common.log_api_use(common.ANALYZER_COLLECTION) -def mean(x: common_types.TensorType, - reduce_instance_dims: bool = True, - name: Optional[str] = None, - output_dtype: Optional[tf.DType] = None) -> tf.Tensor: - """Computes the mean of the values of a `Tensor` over the whole dataset. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}), or integral ([u]int{8|16|32|64}). - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - output_dtype: (Optional) If not None, casts the output tensor to this type. - Returns: - A `Tensor` containing the mean. If `x` is floating point, the mean will have - the same type as `x`. If `x` is integral, the output is cast to float32. - NaNs and infinite input values are ignored. +def _sum_combine_fn_and_dtype( + input_dtype: tf.DType, +) -> Tuple[tf.DType, Callable[[np.ndarray], np.ndarray]]: + output_dtype = _SUM_OUTPUT_DTYPE_MAP.get(input_dtype) + if output_dtype is None: + raise TypeError("Tensor type %r is not supported" % input_dtype) - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'mean'): - return _mean_and_var(x, reduce_instance_dims, output_dtype)[0] + return output_dtype, functools.partial(np.sum, dtype=output_dtype.as_numpy_dtype) @common.log_api_use(common.ANALYZER_COLLECTION) -def var(x: common_types.TensorType, - reduce_instance_dims: bool = True, - name: Optional[str] = None, - output_dtype: Optional[tf.DType] = None) -> tf.Tensor: - """Computes the variance of the values of a `Tensor` over the whole dataset. - - Uses the biased variance (0 delta degrees of freedom), as given by - (x - mean(x))**2 / length(x). - - Args: - x: `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}), or integral ([u]int{8|16|32|64}). - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. 
- name: (Optional) A name for this operation. - output_dtype: (Optional) If not None, casts the output tensor to this type. - - Returns: - A `Tensor` containing the variance. If `x` is floating point, the variance - will have the same type as `x`. If `x` is integral, the output is cast to - float32. NaNs and infinite input values are ignored. - - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'var'): - return _mean_and_var(x, reduce_instance_dims, output_dtype)[1] - - -def _mean_and_var(x: common_types.TensorType, - reduce_instance_dims: bool = True, - output_dtype: Optional[tf.DType] = None): - """More efficient combined `mean` and `var`. See `var`.""" - if output_dtype is None: - output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) - if output_dtype is None: - raise TypeError('Tensor type %r is not supported' % x.dtype) - if not reduce_instance_dims and isinstance(x, tf.RaggedTensor): - raise NotImplementedError( - 'Elementwise mean_and_var does not support RaggedTensors.') - - with tf.compat.v1.name_scope('mean_and_var'): +def sum( # pylint: disable=redefined-builtin + x: common_types.TensorType, + reduce_instance_dims: bool = True, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the sum of the values of a `Tensor` over the whole dataset. - x = tf.cast(x, output_dtype) + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}),integral (int{8|16|32|64}), or unsigned + integral (uint{8|16}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. - x_count, x_mean, x_variance = ( - tf_utils.reduce_batch_count_mean_and_var(x, reduce_instance_dims)) + Returns: + ------- + A `Tensor` containing the sum. If `x` is float32 or float64, the sum will + have the same type as `x`. If `x` is float16, the output is cast to float32. + If `x` is integral, the output is cast to [u]int64. If `x` is sparse and + reduce_inst_dims is False will return 0 in place where column has no values + across batches. + + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "sum"): + if reduce_instance_dims: + x = tf.reduce_sum(input_tensor=tf_utils.get_values(x)) + elif isinstance(x, tf.SparseTensor): + if x.dtype == tf.uint8 or x.dtype == tf.uint16: + x = tf.cast(x, tf.int64) + elif x.dtype == tf.uint32 or x.dtype == tf.uint64: + raise TypeError("Data type %r is not supported" % x.dtype) + x = tf.sparse.reduce_sum(x, axis=0) + elif isinstance(x, tf.RaggedTensor): + raise NotImplementedError("Elementwise sum does not support RaggedTensors.") + else: + x = tf.reduce_sum(input_tensor=x, axis=0) + output_dtype, sum_fn = _sum_combine_fn_and_dtype(x.dtype) + return _numeric_combine( + inputs=[x], + fn=sum_fn, + default_accumulator_value=0, + reduce_instance_dims=reduce_instance_dims, + output_dtypes=[output_dtype], + )[0] - combine_inputs = _WeightedMeanAndVarAccumulator( - count=x_count, - mean=x_mean, - variance=x_variance, - weight=tf.zeros([], tf.float32)) - output_shape = () - if not reduce_instance_dims: - # We need to use tf.expand_dims to artificially add a batch dimension. 
- output_shape = _get_output_shape_from_input( - tf.expand_dims(x_count, axis=0)) +def remove_leftmost_boundary(boundaries: tf.Tensor) -> tf.Tensor: + """Removes the leftmost boundary from [1, None]-shaped `Tensor` of buckets.""" + return boundaries[:, 1:] - x_mean, x_var = _apply_cacheable_combiner( - WeightedMeanAndVarCombiner(output_dtype.as_numpy_dtype, output_shape), - *combine_inputs) - return x_mean, x_var +@common.log_api_use(common.ANALYZER_COLLECTION) +def histogram( + x: common_types.TensorType, + boundaries: Optional[Union[tf.Tensor, int]] = None, + categorical: Optional[bool] = False, + name: Optional[str] = None, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes a histogram over x, given the bin boundaries or bin count. + Ex (1): + counts, boundaries = histogram([0, 1, 0, 1, 0, 3, 0, 1], range(5)) + counts: [4, 3, 0, 1, 0] + boundaries: [0, 1, 2, 3, 4] -@common.log_api_use(common.ANALYZER_COLLECTION) -def tukey_location(x: common_types.TensorType, - reduce_instance_dims: Optional[bool] = True, - output_dtype: Optional[tf.DType] = None, - name: Optional[str] = None) -> tf.Tensor: - """Computes the location of the values of a `Tensor` over the whole dataset. - - This computes the location of x, assuming a Tukey HH distribution, i.e. - (x - tukey_location) / tukey_scale is a Tukey HH distribution with parameters - tukey_h_params. See the following publication for the definition of the Tukey - HH distribution: - - Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and - hh-Distributions through L-Moments and the L-Correlation," ISRN Applied - Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}), or integral ([u]int{8|16|32|64}). - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - output_dtype: (Optional) If not None, casts the output tensor to this type. - name: (Optional) A name for this operation. + Ex (2): + Can be used to compute class weights. + counts, classes = histogram([0, 1, 0, 1, 0, 3, 0, 1], categorical=True) + probabilities = counts / tf.reduce_sum(counts) + class_weights = dict(map(lambda (a, b): (a.numpy(), 1.0 / b.numpy()), + zip(classes, probabilities))) - Returns: - A `Tensor` containing the location. If `x` is floating point, the location - will have the same type as `x`. If `x` is integral, the output is cast to - float32. + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + boundaries: (Optional) A `Tensor` or `int` used to build the histogram; + ignored if `categorical` is True. If possible, provide boundaries as + multiple sorted values. Default to 10 intervals over the 0-1 range, or + find the min/max if an int is provided (not recommended because + multi-phase analysis is inefficient). + categorical: (Optional) A `bool` that treats `x` as discrete values if true. + name: (Optional) A name for this operation. - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'tukey_location'): - return _tukey_parameters(x, reduce_instance_dims, output_dtype)[0] + Returns: + ------- + counts: The histogram, as counts per bin. + boundaries: A `Tensor` used to build the histogram representing boundaries. 
+ """ + with tf.compat.v1.name_scope(name, "histogram"): + x = tf.reshape(tf_utils.get_values(x), [-1]) + if categorical: + x_dtype = x.dtype + x = x if x_dtype == tf.string else tf.strings.as_string(x) + elements, counts = count_per_key(x) + if x_dtype != elements.dtype: + elements = tf.strings.to_number(elements, tf.int64) + return counts, elements + + if boundaries is None: + boundaries = tf.range(11, dtype=tf.float32) / 10.0 + elif isinstance(boundaries, int) or ( + isinstance(boundaries, tf.Tensor) and boundaries.get_shape().ndims == 0 + ): + min_value, max_value = _min_and_max(x, True) + boundaries = tf.linspace( + tf.cast(min_value, tf.float32), + tf.cast(max_value, tf.float32), + tf.cast(boundaries, tf.int64), + ) + + # Shift the boundaries slightly to account for floating point errors, + # and due to the fact that the rightmost boundary is essentially ignored. + boundaries = tf.expand_dims(tf.cast(boundaries, tf.float32), 0) - 0.0001 + + bucket_indices = tf_utils.assign_buckets( + tf.cast(x, tf.float32), remove_leftmost_boundary(boundaries) + ) + bucket_vocab, counts = count_per_key(tf.strings.as_string(bucket_indices)) + counts = tf_utils.reorder_histogram( + bucket_vocab, counts, tf.size(boundaries) - 1 + ) + return counts, boundaries @common.log_api_use(common.ANALYZER_COLLECTION) -def tukey_scale(x: common_types.TensorType, - reduce_instance_dims: Optional[bool] = True, - output_dtype: Optional[tf.DType] = None, - name: Optional[str] = None) -> tf.Tensor: - """Computes the scale of the values of a `Tensor` over the whole dataset. - - This computes the scale of x, assuming a Tukey HH distribution, i.e. - (x - tukey_location) / tukey_scale is a Tukey HH distribution with parameters - tukey_h_params. See the following publication for the definition of the Tukey - HH distribution: - - Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and - hh-Distributions through L-Moments and the L-Correlation," ISRN Applied - Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}), or integral ([u]int{8|16|32|64}). - reduce_instance_dims: By default collapses the batch and instance dimensions +def size( + x: common_types.TensorType, + reduce_instance_dims: bool = True, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the total size of instances in a `Tensor` over the whole dataset. + + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + reduce_instance_dims: By default collapses the batch and instance dimensions to arrive at a single scalar output. If False, only collapses the batch dimension and outputs a vector of the same shape as the input. - output_dtype: (Optional) If not None, casts the output tensor to this type. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` containing the scale. If `x` is floating point, the location - will have the same type as `x`. If `x` is integral, the output is cast to - float32. + name: (Optional) A name for this operation. - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'tukey_scale'): - return _tukey_parameters(x, reduce_instance_dims, output_dtype)[1] + Returns: + ------- + A `Tensor` of type int64. + """ + with tf.compat.v1.name_scope(name, "size"): + # Note: Calling `sum` defined in this module, not the builtin. 
+ if isinstance(x, tf.SparseTensor): + ones_like_x = tf.SparseTensor( + indices=x.indices, + values=tf.ones_like(x.values, tf.int64), + dense_shape=x.dense_shape, + ) + else: + ones_like_x = tf.ones_like(x, dtype=tf.int64) + return sum(ones_like_x, reduce_instance_dims) @common.log_api_use(common.ANALYZER_COLLECTION) -def tukey_h_params(x: common_types.TensorType, - reduce_instance_dims: bool = True, - output_dtype: Optional[tf.DType] = None, - name: Optional[str] = None) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes the h parameters of the values of a `Tensor` over the dataset. - - This computes the parameters (hl, hr) of the samples, assuming a Tukey HH - distribution, i.e. (x - tukey_location) / tukey_scale is a Tukey HH - distribution with parameters hl (left parameter) and hr (right parameter). - See the following publication for the definition of the Tukey HH distribution: - - Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and - hh-Distributions through L-Moments and the L-Correlation," ISRN Applied - Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating - point (float{16|32|64}), or integral ([u]int{8|16|32|64}). - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single scalar output. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - output_dtype: (Optional) If not None, casts the output tensor to this type. - name: (Optional) A name for this operation. +def count_per_key( + key: common_types.TensorType, + key_vocabulary_filename: Optional[str] = None, + name: Optional[str] = None, +): + """Computes the count of each element of a `Tensor`. - Returns: - The tuple (hl, hr) containing two `Tensor` instances with the hl and hr - parameters. If `x` is floating point, each parameter will have the same type - as `x`. If `x` is integral, the output is cast to float32. + Args: + ---- + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string or + tf.int. + key_vocabulary_filename: (Optional) The file name for the key-output mapping + file. If None and key are provided, this combiner assumes the keys fit in + memory and will not store the result in a file. If empty string, a file + name will be chosen based on the current scope. If not an empty string, + should be unique within a given preprocessing function. + name: (Optional) A name for this operation. + + Returns: + ------- + Either: + (A) Two `Tensor`s: one the key vocab with dtype of input; + the other the count for each key, dtype tf.int64. (if + key_vocabulary_filename is None). + (B) The filename where the key-value mapping is stored (if + key_vocabulary_filename is not None). - Raises: - TypeError: If the type of `x` is not supported. - """ - with tf.compat.v1.name_scope(name, 'tukey_h_params'): - return _tukey_parameters(x, reduce_instance_dims, output_dtype)[2:] + Raises: + ------ + TypeError: If the type of `x` is not supported. 
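A small sketch of the `size` and `count_per_key` analyzers documented above, assuming they are public as `tft.size` and `tft.count_per_key` and that "label" is a tf.string feature (hypothetical name):

    import tensorflow_transform as tft

    def preprocessing_fn(inputs):
        label = inputs["label"]
        keys, counts = tft.count_per_key(label)  # distinct labels and their counts
        total = tft.size(label)                  # total number of label values
        # keys, counts and total are constant tensors computed over the whole dataset.
        return {"label": label}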
+ """ + with tf.compat.v1.name_scope(name, "count_per_key"): + key_dtype = key.dtype + batch_keys, batch_counts = tf_utils.reduce_batch_count_per_key(key) + + output_dtype, sum_fn = _sum_combine_fn_and_dtype(tf.int64) + numeric_combine_result = _numeric_combine( + inputs=[batch_counts], + fn=sum_fn, + default_accumulator_value=0, + reduce_instance_dims=True, + output_dtypes=[output_dtype], + key=batch_keys, + key_vocabulary_filename=key_vocabulary_filename, + ) + if key_vocabulary_filename is not None: + return numeric_combine_result + keys, counts = numeric_combine_result + if key_dtype is not tf.string: + keys = tf.strings.to_number(keys, key_dtype) + return keys, counts -def _tukey_parameters( + +@common.log_api_use(common.ANALYZER_COLLECTION) +def mean( x: common_types.TensorType, reduce_instance_dims: bool = True, - output_dtype: Optional[tf.DType] = None -) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: - """Efficient computation of L-moments.""" - if output_dtype is None: - output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) - if output_dtype is None: - raise TypeError('Tensor type %r is not supported' % x.dtype) + name: Optional[str] = None, + output_dtype: Optional[tf.DType] = None, +) -> tf.Tensor: + """Computes the mean of the values of a `Tensor` over the whole dataset. - with tf.compat.v1.name_scope('tukey_parameters'): + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}), or integral ([u]int{8|16|32|64}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. + output_dtype: (Optional) If not None, casts the output tensor to this type. - x = tf.cast(x, output_dtype) + Returns: + ------- + A `Tensor` containing the mean. If `x` is floating point, the mean will have + the same type as `x`. If `x` is integral, the output is cast to float32. + NaNs and infinite input values are ignored. + + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "mean"): + return _mean_and_var(x, reduce_instance_dims, output_dtype)[0] - (count_l1, l1, count_l2, l2, count_l3, l3, count_l4, l4) = ( - tf_utils.reduce_batch_count_l_moments(x, reduce_instance_dims)) - combine_inputs = _LMomentsAccumulator( - count_l1=count_l1, - count_l2=count_l2, - count_l3=count_l3, - count_l4=count_l4, - l1=l1, - l2=l2, - l3=l3, - l4=l4) +@common.log_api_use(common.ANALYZER_COLLECTION) +def var( + x: common_types.TensorType, + reduce_instance_dims: bool = True, + name: Optional[str] = None, + output_dtype: Optional[tf.DType] = None, +) -> tf.Tensor: + """Computes the variance of the values of a `Tensor` over the whole dataset. - output_shape = () - if not reduce_instance_dims: - output_shape = _get_output_shape_from_input(x) + Uses the biased variance (0 delta degrees of freedom), as given by + (x - mean(x))**2 / length(x). - x_loc, x_scale, hl_param, hr_param = _apply_cacheable_combiner( - _LMomentsCombiner(output_dtype.as_numpy_dtype, output_shape), - *combine_inputs) + Args: + ---- + x: `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}), or integral ([u]int{8|16|32|64}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. 
If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. + output_dtype: (Optional) If not None, casts the output tensor to this type. - return x_loc, x_scale, hl_param, hr_param + Returns: + ------- + A `Tensor` containing the variance. If `x` is floating point, the variance + will have the same type as `x`. If `x` is integral, the output is cast to + float32. NaNs and infinite input values are ignored. + + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "var"): + return _mean_and_var(x, reduce_instance_dims, output_dtype)[1] -def _mean_and_var_per_key( +def _mean_and_var( x: common_types.TensorType, - key: common_types.TensorType, reduce_instance_dims: bool = True, output_dtype: Optional[tf.DType] = None, - key_vocabulary_filename: Optional[str] = None -) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor, - tf.saved_model.Asset]: - """`mean_and_var` by group, specified by key. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. If - `x` is a `CompositeTensor`, `key` must exactly match `x` in everything - except values. - reduce_instance_dims: (Optional) By default collapses the batch and instance - dimensions to arrive at a single scalar output. The False case is not - currently supported for _mean_and_var_per_key. - output_dtype: (Optional) Desired output dtype, otherwise inferred. - key_vocabulary_filename: (Optional) The file name for the key-output mapping - file. If None and key are provided, this combiner assumes the keys fit in - memory and will not store the result in a file. If empty string, a file - name will be chosen based on the current scope. If not an empty string, - should be unique within a given preprocessing function. - - Returns: - Either: - (A) Three `Tensor`s. The first is the key vocab of type tf.string, and the - second two have same type as `x` (if key_vocabulary_filename is None). - (B) The filename where the key-value mapping is stored (if - key_vocabulary_filename is not None). - NaNs and infinite input values are ignored. - """ - if output_dtype is None: - output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) +): + """More efficient combined `mean` and `var`. See `var`.""" if output_dtype is None: - raise TypeError('Tensor type %r is not supported' % x.dtype) - - if key is None: - raise ValueError('A non-None key is required for _mean_and_var_per_key') - - if not reduce_instance_dims and isinstance( - x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'Per-key elementwise reduction of Composite Tensors not supported ') + output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) + if output_dtype is None: + raise TypeError("Tensor type %r is not supported" % x.dtype) + if not reduce_instance_dims and isinstance(x, tf.RaggedTensor): + raise NotImplementedError( + "Elementwise mean_and_var does not support RaggedTensors." 
+ ) - with tf.compat.v1.name_scope('mean_and_var_per_key'): - x = tf.cast(x, output_dtype) + with tf.compat.v1.name_scope("mean_and_var"): + x = tf.cast(x, output_dtype) - key_vocab, key_counts, key_means, key_variances = ( - tf_utils.reduce_batch_count_mean_and_var_per_key( - x, key, reduce_instance_dims=reduce_instance_dims)) - output_shape = () if reduce_instance_dims else x.get_shape()[1:] + x_count, x_mean, x_variance = tf_utils.reduce_batch_count_mean_and_var( + x, reduce_instance_dims + ) - combine_inputs = _WeightedMeanAndVarAccumulator( - count=key_counts, - mean=key_means, - variance=key_variances, - weight=tf.zeros_like(key_means, tf.float32)) + combine_inputs = _WeightedMeanAndVarAccumulator( + count=x_count, + mean=x_mean, + variance=x_variance, + weight=tf.zeros([], tf.float32), + ) - combiner = WeightedMeanAndVarCombiner(output_dtype.as_numpy_dtype, - output_shape) + output_shape = () + if not reduce_instance_dims: + # We need to use tf.expand_dims to artificially add a batch dimension. + output_shape = _get_output_shape_from_input(tf.expand_dims(x_count, axis=0)) - if key_vocabulary_filename is not None: - key_vocabulary_filename = _maybe_get_per_key_vocab_filename( - key_vocabulary_filename) - return _apply_cacheable_combiner_per_key_large( - combiner, key_vocabulary_filename, key_vocab, *combine_inputs) + x_mean, x_var = _apply_cacheable_combiner( + WeightedMeanAndVarCombiner(output_dtype.as_numpy_dtype, output_shape), + *combine_inputs, + ) - key, key_mean, key_var = _apply_cacheable_combiner_per_key( - combiner, key_vocab, *combine_inputs) + return x_mean, x_var - return key, key_mean, key_var +@common.log_api_use(common.ANALYZER_COLLECTION) +def tukey_location( + x: common_types.TensorType, + reduce_instance_dims: Optional[bool] = True, + output_dtype: Optional[tf.DType] = None, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the location of the values of a `Tensor` over the whole dataset. -class _WeightedMeanAndVarAccumulator( - tfx_namedtuple.namedtuple('WeightedMeanAndVarAccumulator', - ['count', 'mean', 'variance', 'weight'])): - """Container for WeightedMeanAndVarCombiner intermediate values.""" + This computes the location of x, assuming a Tukey HH distribution, i.e. + (x - tukey_location) / tukey_scale is a Tukey HH distribution with parameters + tukey_h_params. See the following publication for the definition of the Tukey + HH distribution: - @classmethod - def make_nan_to_num(cls, - counts, - means, - variances, - weights, - compute_variance=False, - compute_weighted=True): - """Util function to replace NaN with 0 and inf with large finite numbers.""" - if compute_variance: - variances = np.nan_to_num(variances, copy=True) + Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and + hh-Distributions through L-Moments and the L-Correlation," ISRN Applied + Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - if compute_weighted: - weights = np.nan_to_num(weights, copy=True) + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}), or integral ([u]int{8|16|32|64}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + output_dtype: (Optional) If not None, casts the output tensor to this type. + name: (Optional) A name for this operation. 
- return cls( - np.array(counts), np.nan_to_num(means, copy=True), variances, weights) + Returns: + ------- + A `Tensor` containing the location. If `x` is floating point, the location + will have the same type as `x`. If `x` is integral, the output is cast to + float32. + + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + with tf.compat.v1.name_scope(name, "tukey_location"): + return _tukey_parameters(x, reduce_instance_dims, output_dtype)[0] -class WeightedMeanAndVarCombiner(analyzer_nodes.Combiner): - """Combines a PCollection of accumulators to compute mean and variance.""" +@common.log_api_use(common.ANALYZER_COLLECTION) +def tukey_scale( + x: common_types.TensorType, + reduce_instance_dims: Optional[bool] = True, + output_dtype: Optional[tf.DType] = None, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the scale of the values of a `Tensor` over the whole dataset. - accumulator_class = _WeightedMeanAndVarAccumulator + This computes the scale of x, assuming a Tukey HH distribution, i.e. + (x - tukey_location) / tukey_scale is a Tukey HH distribution with parameters + tukey_h_params. See the following publication for the definition of the Tukey + HH distribution: - def __init__(self, - output_numpy_dtype, - output_shape: Optional[Collection[Optional[int]]] = None, - compute_variance: bool = True, - compute_weighted: bool = False): - """Init method for WeightedMeanAndVarCombiner. + Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and + hh-Distributions through L-Moments and the L-Correlation," ISRN Applied + Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - Args: - output_numpy_dtype: A numpy dtype that the outputs are cast to. - output_shape: The shape of the resulting Tensors. - compute_variance: A bool indicating whether or not a variance should be - calculated and returned. - compute_weighted: A bool indicating whether or not weights are provided - and all calculations should be weighted. - """ - self._output_numpy_dtype = output_numpy_dtype - self._output_shape = output_shape - self._compute_variance = compute_variance - self._compute_weighted = compute_weighted - - if self._compute_variance and self._compute_weighted: - raise ValueError( - 'WeightedMeanAndVarCombiner does not yet support weighted variance') - if self._output_shape is None: - raise ValueError('An output_shape must be provided.') - - def create_accumulator(self) -> _WeightedMeanAndVarAccumulator: - """Create an accumulator with all zero entries.""" - # TODO(b/131325061): Determine whether counts/weights should always be - # scalars or if we want to continue supporting multi-dimensional arrays. - initial_count, initial_weight = np.array(0), np.array(0.) - # If we know the exact shape, initialize accumulator values with zeros of - # the exact shape. For unknown dimensions, initialize with a 1D 0 array. - output_shape = [dim if dim is not None else 0 for dim in self._output_shape] - initial_mean, initial_var = np.zeros(output_shape), np.zeros(output_shape) - return _WeightedMeanAndVarAccumulator(initial_count, initial_mean, - initial_var, initial_weight) - - def add_input( - self, accumulator: _WeightedMeanAndVarAccumulator, - batch_values: _WeightedMeanAndVarAccumulator - ) -> _WeightedMeanAndVarAccumulator: - """Composes an accumulator from batch_values and calls merge_accumulators. Args: - accumulator: The `_WeightedMeanAndVarAccumulator` computed so far. - batch_values: A `_WeightedMeanAndVarAccumulator` for the current batch. 
+ ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}), or integral ([u]int{8|16|32|64}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + output_dtype: (Optional) If not None, casts the output tensor to this type. + name: (Optional) A name for this operation. Returns: - A `_WeightedMeanAndVarAccumulator` which is accumulator and batch_values - combined. + ------- + A `Tensor` containing the scale. If `x` is floating point, the location + will have the same type as `x`. If `x` is integral, the output is cast to + float32. + + Raises: + ------ + TypeError: If the type of `x` is not supported. """ - new_accumulator = _WeightedMeanAndVarAccumulator(*batch_values) - return self._combine_mean_and_var_accumulators(accumulator, new_accumulator) + with tf.compat.v1.name_scope(name, "tukey_scale"): + return _tukey_parameters(x, reduce_instance_dims, output_dtype)[1] - def merge_accumulators( - self, accumulators: List[_WeightedMeanAndVarAccumulator] - ) -> _WeightedMeanAndVarAccumulator: - """Merges several `_WeightedMeanAndVarAccumulator`s to a single accumulator. - Args: - accumulators: A list of `_WeightedMeanAndVarAccumulator`s. +@common.log_api_use(common.ANALYZER_COLLECTION) +def tukey_h_params( + x: common_types.TensorType, + reduce_instance_dims: bool = True, + output_dtype: Optional[tf.DType] = None, + name: Optional[str] = None, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the h parameters of the values of a `Tensor` over the dataset. - Returns: - The sole merged `_WeightedMeanAndVarAccumulator`. - """ - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result = self._combine_mean_and_var_accumulators(result, accumulator) - return result + This computes the parameters (hl, hr) of the samples, assuming a Tukey HH + distribution, i.e. (x - tukey_location) / tukey_scale is a Tukey HH + distribution with parameters hl (left parameter) and hr (right parameter). + See the following publication for the definition of the Tukey HH distribution: - def extract_output( - self, accumulator: _WeightedMeanAndVarAccumulator - ) -> Union[Tuple[float, float], _WeightedMeanAndVarAccumulator]: - """Converts an accumulator into the output accumulator or (mean, var) tuple. + Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey h and + hh-Distributions through L-Moments and the L-Correlation," ISRN Applied + Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 Args: - accumulator: the final `_WeightedMeanAndVarAccumulator` value. + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. Its type must be floating + point (float{16|32|64}), or integral ([u]int{8|16|32|64}). + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single scalar output. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + output_dtype: (Optional) If not None, casts the output tensor to this type. + name: (Optional) A name for this operation. Returns: - A _WeightedMeanAndVarAccumulator or a 2-tuple composed of (mean, var). + ------- + The tuple (hl, hr) containing two `Tensor` instances with the hl and hr + parameters. If `x` is floating point, each parameter will have the same type + as `x`. If `x` is integral, the output is cast to float32. 
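A sketch of robust standardization with the Tukey analyzers documented above, assuming they are public as `tft.tukey_location` and `tft.tukey_scale`; "x" is a hypothetical float feature name:

    import tensorflow_transform as tft

    def preprocessing_fn(inputs):
        x = inputs["x"]
        # Center and rescale using the Tukey HH location/scale estimates.
        return {"x_robust": (x - tft.tukey_location(x)) / tft.tukey_scale(x)}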
+ + Raises: + ------ + TypeError: If the type of `x` is not supported. """ + with tf.compat.v1.name_scope(name, "tukey_h_params"): + return _tukey_parameters(x, reduce_instance_dims, output_dtype)[2:] - if self._compute_variance and not self._compute_weighted: - return (self._output_numpy_dtype(accumulator.mean), - self._output_numpy_dtype(accumulator.variance)) - else: - return _WeightedMeanAndVarAccumulator( - np.int64(accumulator.count), - self._output_numpy_dtype(accumulator.mean), - self._output_numpy_dtype(accumulator.variance), - self._output_numpy_dtype(accumulator.weight)) - - def output_tensor_infos(self) -> List[analyzer_nodes.TensorInfo]: - # The output is (mean, var). - if self._compute_variance and not self._compute_weighted: - return [ - analyzer_nodes.TensorInfo( - tf.as_dtype(self._output_numpy_dtype), self._output_shape, None) - ] * 2 - else: - return [ - analyzer_nodes.TensorInfo( - tf.as_dtype(np.int64), self._output_shape, None), - analyzer_nodes.TensorInfo( - tf.as_dtype(self._output_numpy_dtype), self._output_shape, None), - analyzer_nodes.TensorInfo( - tf.as_dtype(self._output_numpy_dtype), self._output_shape, None), - analyzer_nodes.TensorInfo( - tf.as_dtype(self._output_numpy_dtype), self._output_shape, None) - ] - - def _combine_mean_and_var_accumulators( - self, a: _WeightedMeanAndVarAccumulator, - b: _WeightedMeanAndVarAccumulator) -> _WeightedMeanAndVarAccumulator: - """Combines two mean and var accumulators. + +def _tukey_parameters( + x: common_types.TensorType, + reduce_instance_dims: bool = True, + output_dtype: Optional[tf.DType] = None, +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: + """Efficient computation of L-moments.""" + if output_dtype is None: + output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) + if output_dtype is None: + raise TypeError("Tensor type %r is not supported" % x.dtype) + + with tf.compat.v1.name_scope("tukey_parameters"): + x = tf.cast(x, output_dtype) + + (count_l1, l1, count_l2, l2, count_l3, l3, count_l4, l4) = ( + tf_utils.reduce_batch_count_l_moments(x, reduce_instance_dims) + ) + + combine_inputs = _LMomentsAccumulator( + count_l1=count_l1, + count_l2=count_l2, + count_l3=count_l3, + count_l4=count_l4, + l1=l1, + l2=l2, + l3=l3, + l4=l4, + ) + + output_shape = () + if not reduce_instance_dims: + output_shape = _get_output_shape_from_input(x) + + x_loc, x_scale, hl_param, hr_param = _apply_cacheable_combiner( + _LMomentsCombiner(output_dtype.as_numpy_dtype, output_shape), + *combine_inputs, + ) + + return x_loc, x_scale, hl_param, hr_param + + +def _mean_and_var_per_key( + x: common_types.TensorType, + key: common_types.TensorType, + reduce_instance_dims: bool = True, + output_dtype: Optional[tf.DType] = None, + key_vocabulary_filename: Optional[str] = None, +) -> Union[Tuple[tf.Tensor, tf.Tensor, tf.Tensor], tf.Tensor, tf.saved_model.Asset]: + """`mean_and_var` by group, specified by key. Args: - a: A _WeightedMeanAndVarAccumulator. - b: A _WeightedMeanAndVarAccumulator. + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor`. + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. If + `x` is a `CompositeTensor`, `key` must exactly match `x` in everything + except values. + reduce_instance_dims: (Optional) By default collapses the batch and instance + dimensions to arrive at a single scalar output. The False case is not + currently supported for _mean_and_var_per_key. + output_dtype: (Optional) Desired output dtype, otherwise inferred. 
+ key_vocabulary_filename: (Optional) The file name for the key-output mapping + file. If None and key are provided, this combiner assumes the keys fit in + memory and will not store the result in a file. If empty string, a file + name will be chosen based on the current scope. If not an empty string, + should be unique within a given preprocessing function. Returns: - A _WeightedMeanAndVarAccumulator computed as the combination of a and b. + ------- + Either: + (A) Three `Tensor`s. The first is the key vocab of type tf.string, and the + second two have same type as `x` (if key_vocabulary_filename is None). + (B) The filename where the key-value mapping is stored (if + key_vocabulary_filename is not None). + NaNs and infinite input values are ignored. """ - # NaNs get preserved through division by a.count + b.count. - a = _WeightedMeanAndVarAccumulator.make_nan_to_num( - *a, - compute_variance=self._compute_variance, - compute_weighted=self._compute_weighted) - b = _WeightedMeanAndVarAccumulator.make_nan_to_num( - *b, - compute_variance=self._compute_variance, - compute_weighted=self._compute_weighted) - - # a.count >= b.count following this logic. - if np.sum(a.count) < np.sum(b.count): - a, b = b, a - - if np.sum(a.count) == 0: - return b - - a_count, b_count = _pad_arrays_to_match(a.count, b.count) - a_mean, b_mean = _pad_arrays_to_match(a.mean, b.mean) - combined_total = a_count + b_count - if self._compute_weighted: - a_weight, b_weight = _pad_arrays_to_match(a.weight, b.weight) - # Mean and variance update formulas which are more numerically stable when - # a and b vary in magnitude. - combined_weights_mean = ( - a_weight + (b_count / combined_total) * (b_weight - a_weight)) - combined_mean = a_mean + (b_count * b_weight / - (combined_total * combined_weights_mean)) * ( - b_mean - a_mean) - else: - combined_weights_mean = np.ones(shape=combined_total.shape) - combined_mean = a_mean + (b_count / combined_total * (b_mean - a_mean)) - if self._compute_variance: - a_variance, b_variance = _pad_arrays_to_match(a.variance, b.variance) - # TODO(zoyahav): Add an option for weighted variance if needed. - assert not self._compute_weighted - combined_variance = ( - a_variance + (b_count / combined_total) * (b_variance - a_variance + - ((b_mean - combined_mean) * - (b_mean - a_mean)))) - - else: - combined_variance = np.zeros(combined_mean.shape) + if output_dtype is None: + output_dtype = _FLOAT_OUTPUT_DTYPE_MAP.get(x.dtype) + if output_dtype is None: + raise TypeError("Tensor type %r is not supported" % x.dtype) - return _WeightedMeanAndVarAccumulator(combined_total, combined_mean, - combined_variance, - combined_weights_mean) + if key is None: + raise ValueError("A non-None key is required for _mean_and_var_per_key") + if not reduce_instance_dims and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "Per-key elementwise reduction of Composite Tensors not supported " + ) -# TODO(b/165020671): Optimize padding to save up to 15% computing resource. -def _pad_arrays_to_match(a, b): - """Pad the ndarray values to match dimensions as needed. - - If the dimensions of the ndarrays values differ, we pad the smaller of the - two arrays with zeros to be the same shape as the larger. In other words, - the missing accumulator indices are assumed to be zero, and combining - a = [1, 2, 3] with b = [1, 2] is equivalent t combining with b = [1, 2, 0]. 
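A tiny NumPy illustration of the zero-padding described in the `_pad_arrays_to_match` docstring above: combining a = [1, 2, 3] with b = [1, 2] behaves as if b were [1, 2, 0].

    import numpy as np

    a = np.array([1, 2, 3])
    b = np.array([1, 2])
    # Pad the shorter array with trailing zeros to match the longer one.
    b_padded = np.pad(b, (0, a.size - b.size), mode="constant")
    print(b_padded)  # -> [1 2 0]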
- - Args: - a: NDarray to be matched in shaped with b - b: NDarray to be matched in shaped with a - - Returns: - a: a padded to same dimensions as b - b: b padded to same dimensions as a - """ - if a.shape == b.shape: - return a, b - padding_a, padding_b = [], [] - for a_dim, b_dim in zip(a.shape, b.shape): - a_pad = b_pad = (0, 0) - delta = a_dim - b_dim - if delta > 0: - b_pad = (0, abs(delta)) - elif delta < 0: - a_pad = (0, abs(delta)) - padding_a.append(a_pad) - padding_b.append(b_pad) - if padding_a: - a = np.pad(a, padding_a, mode='constant') - if padding_b: - b = np.pad(b, padding_b, mode='constant') - return a, b + with tf.compat.v1.name_scope("mean_and_var_per_key"): + x = tf.cast(x, output_dtype) + key_vocab, key_counts, key_means, key_variances = ( + tf_utils.reduce_batch_count_mean_and_var_per_key( + x, key, reduce_instance_dims=reduce_instance_dims + ) + ) + output_shape = () if reduce_instance_dims else x.get_shape()[1:] -class _LMomentsAccumulator( - tfx_namedtuple.namedtuple('LMomentsAccumulator', [ - 'count_l1', 'count_l2', 'count_l3', 'count_l4', 'l1', 'l2', 'l3', 'l4' - ])): - """Container for _LMomentsCombiner intermediate values.""" + combine_inputs = _WeightedMeanAndVarAccumulator( + count=key_counts, + mean=key_means, + variance=key_variances, + weight=tf.zeros_like(key_means, tf.float32), + ) - @classmethod - def make_nan_to_num(cls, count_l1, count_l2, count_l3, count_l4, - l1, l2, l3, l4): - return cls( - np.array(count_l1), np.array(count_l2), np.array(count_l3), - np.array(count_l4), np.nan_to_num(l1), np.nan_to_num(l2), - np.nan_to_num(l3), np.nan_to_num(l4)) + combiner = WeightedMeanAndVarCombiner(output_dtype.as_numpy_dtype, output_shape) - def __reduce__(self): - return self.__class__, tuple(self) + if key_vocabulary_filename is not None: + key_vocabulary_filename = _maybe_get_per_key_vocab_filename( + key_vocabulary_filename + ) + return _apply_cacheable_combiner_per_key_large( + combiner, key_vocabulary_filename, key_vocab, *combine_inputs + ) + key, key_mean, key_var = _apply_cacheable_combiner_per_key( + combiner, key_vocab, *combine_inputs + ) -class _LMomentsCombiner(analyzer_nodes.Combiner): - """Combines a PCollection of accumulators to compute L-moments.""" + return key, key_mean, key_var - accumulator_class = _LMomentsAccumulator - def __init__(self, output_numpy_dtype, output_shape): - """Init method for _LMomentsCombiner. +class _WeightedMeanAndVarAccumulator( + tfx_namedtuple.namedtuple( + "WeightedMeanAndVarAccumulator", ["count", "mean", "variance", "weight"] + ) +): + """Container for WeightedMeanAndVarCombiner intermediate values.""" + + @classmethod + def make_nan_to_num( + cls, + counts, + means, + variances, + weights, + compute_variance=False, + compute_weighted=True, + ): + """Util function to replace NaN with 0 and inf with large finite numbers.""" + if compute_variance: + variances = np.nan_to_num(variances, copy=True) + + if compute_weighted: + weights = np.nan_to_num(weights, copy=True) + + return cls( + np.array(counts), np.nan_to_num(means, copy=True), variances, weights + ) - Args: - output_numpy_dtype: A numpy dtype that the outputs are cast to. - output_shape: The shape of the resulting Tensors. 
- """ - self._output_numpy_dtype = output_numpy_dtype - self._output_shape = output_shape - def create_accumulator(self): - """Create an accumulator with all zero entries.""" +class WeightedMeanAndVarCombiner(analyzer_nodes.Combiner): + """Combines a PCollection of accumulators to compute mean and variance.""" + + accumulator_class = _WeightedMeanAndVarAccumulator + + def __init__( + self, + output_numpy_dtype, + output_shape: Optional[Collection[Optional[int]]] = None, + compute_variance: bool = True, + compute_weighted: bool = False, + ): + """Init method for WeightedMeanAndVarCombiner. + + Args: + ---- + output_numpy_dtype: A numpy dtype that the outputs are cast to. + output_shape: The shape of the resulting Tensors. + compute_variance: A bool indicating whether or not a variance should be + calculated and returned. + compute_weighted: A bool indicating whether or not weights are provided + and all calculations should be weighted. + """ + self._output_numpy_dtype = output_numpy_dtype + self._output_shape = output_shape + self._compute_variance = compute_variance + self._compute_weighted = compute_weighted + + if self._compute_variance and self._compute_weighted: + raise ValueError( + "WeightedMeanAndVarCombiner does not yet support weighted variance" + ) + if self._output_shape is None: + raise ValueError("An output_shape must be provided.") + + def create_accumulator(self) -> _WeightedMeanAndVarAccumulator: + """Create an accumulator with all zero entries.""" + # TODO(b/131325061): Determine whether counts/weights should always be + # scalars or if we want to continue supporting multi-dimensional arrays. + initial_count, initial_weight = np.array(0), np.array(0.0) + # If we know the exact shape, initialize accumulator values with zeros of + # the exact shape. For unknown dimensions, initialize with a 1D 0 array. + output_shape = [dim if dim is not None else 0 for dim in self._output_shape] + initial_mean, initial_var = np.zeros(output_shape), np.zeros(output_shape) + return _WeightedMeanAndVarAccumulator( + initial_count, initial_mean, initial_var, initial_weight + ) - # If we know the exact shape, initialize accumulator values with zeros of - # the exact shape. For unknown dimensions, initialize with a 1D 0 array - # (this accumulator will be discarded by _combine_accumulators). - output_shape = () if None in self._output_shape else self._output_shape - initial_moment = np.zeros(output_shape, dtype=self._output_numpy_dtype) - initial_count = np.zeros(output_shape, dtype=self._output_numpy_dtype) - return _LMomentsAccumulator( - initial_count, initial_count, initial_count, initial_count, - initial_moment, initial_moment, initial_moment, initial_moment) + def add_input( + self, + accumulator: _WeightedMeanAndVarAccumulator, + batch_values: _WeightedMeanAndVarAccumulator, + ) -> _WeightedMeanAndVarAccumulator: + """Composes an accumulator from batch_values and calls merge_accumulators. + + Args: + ---- + accumulator: The `_WeightedMeanAndVarAccumulator` computed so far. + batch_values: A `_WeightedMeanAndVarAccumulator` for the current batch. + + Returns: + ------- + A `_WeightedMeanAndVarAccumulator` which is accumulator and batch_values + combined. 
+ """ + new_accumulator = _WeightedMeanAndVarAccumulator(*batch_values) + return self._combine_mean_and_var_accumulators(accumulator, new_accumulator) + + def merge_accumulators( + self, accumulators: List[_WeightedMeanAndVarAccumulator] + ) -> _WeightedMeanAndVarAccumulator: + """Merges several `_WeightedMeanAndVarAccumulator`s to a single accumulator. + + Args: + ---- + accumulators: A list of `_WeightedMeanAndVarAccumulator`s. + + Returns: + ------- + The sole merged `_WeightedMeanAndVarAccumulator`. + """ + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result = self._combine_mean_and_var_accumulators(result, accumulator) + return result + + def extract_output( + self, accumulator: _WeightedMeanAndVarAccumulator + ) -> Union[Tuple[float, float], _WeightedMeanAndVarAccumulator]: + """Converts an accumulator into the output accumulator or (mean, var) tuple. + + Args: + ---- + accumulator: the final `_WeightedMeanAndVarAccumulator` value. + + Returns: + ------- + A _WeightedMeanAndVarAccumulator or a 2-tuple composed of (mean, var). + """ + if self._compute_variance and not self._compute_weighted: + return ( + self._output_numpy_dtype(accumulator.mean), + self._output_numpy_dtype(accumulator.variance), + ) + else: + return _WeightedMeanAndVarAccumulator( + np.int64(accumulator.count), + self._output_numpy_dtype(accumulator.mean), + self._output_numpy_dtype(accumulator.variance), + self._output_numpy_dtype(accumulator.weight), + ) + + def output_tensor_infos(self) -> List[analyzer_nodes.TensorInfo]: + # The output is (mean, var). + if self._compute_variance and not self._compute_weighted: + return [ + analyzer_nodes.TensorInfo( + tf.as_dtype(self._output_numpy_dtype), self._output_shape, None + ) + ] * 2 + else: + return [ + analyzer_nodes.TensorInfo( + tf.as_dtype(np.int64), self._output_shape, None + ), + analyzer_nodes.TensorInfo( + tf.as_dtype(self._output_numpy_dtype), self._output_shape, None + ), + analyzer_nodes.TensorInfo( + tf.as_dtype(self._output_numpy_dtype), self._output_shape, None + ), + analyzer_nodes.TensorInfo( + tf.as_dtype(self._output_numpy_dtype), self._output_shape, None + ), + ] + + def _combine_mean_and_var_accumulators( + self, a: _WeightedMeanAndVarAccumulator, b: _WeightedMeanAndVarAccumulator + ) -> _WeightedMeanAndVarAccumulator: + """Combines two mean and var accumulators. + + Args: + ---- + a: A _WeightedMeanAndVarAccumulator. + b: A _WeightedMeanAndVarAccumulator. + + Returns: + ------- + A _WeightedMeanAndVarAccumulator computed as the combination of a and b. + """ + # NaNs get preserved through division by a.count + b.count. + a = _WeightedMeanAndVarAccumulator.make_nan_to_num( + *a, + compute_variance=self._compute_variance, + compute_weighted=self._compute_weighted, + ) + b = _WeightedMeanAndVarAccumulator.make_nan_to_num( + *b, + compute_variance=self._compute_variance, + compute_weighted=self._compute_weighted, + ) - def add_input(self, accumulator, batch_values): - """Composes an accumulator from batch_values and calls merge_accumulators. + # a.count >= b.count following this logic. 
+ if np.sum(a.count) < np.sum(b.count): + a, b = b, a + + if np.sum(a.count) == 0: + return b + + a_count, b_count = _pad_arrays_to_match(a.count, b.count) + a_mean, b_mean = _pad_arrays_to_match(a.mean, b.mean) + combined_total = a_count + b_count + if self._compute_weighted: + a_weight, b_weight = _pad_arrays_to_match(a.weight, b.weight) + # Mean and variance update formulas which are more numerically stable when + # a and b vary in magnitude. + combined_weights_mean = a_weight + (b_count / combined_total) * ( + b_weight - a_weight + ) + combined_mean = a_mean + ( + b_count * b_weight / (combined_total * combined_weights_mean) + ) * (b_mean - a_mean) + else: + combined_weights_mean = np.ones(shape=combined_total.shape) + combined_mean = a_mean + (b_count / combined_total * (b_mean - a_mean)) + if self._compute_variance: + a_variance, b_variance = _pad_arrays_to_match(a.variance, b.variance) + # TODO(zoyahav): Add an option for weighted variance if needed. + assert not self._compute_weighted + combined_variance = a_variance + (b_count / combined_total) * ( + b_variance - a_variance + ((b_mean - combined_mean) * (b_mean - a_mean)) + ) + + else: + combined_variance = np.zeros(combined_mean.shape) + + return _WeightedMeanAndVarAccumulator( + combined_total, combined_mean, combined_variance, combined_weights_mean + ) - Args: - accumulator: The `_LMomentsAccumulator` computed so far. - batch_values: A `_LMomentsAccumulator` for the current batch. - Returns: - A `_LMomentsAccumulator` which is accumulator and batch_values combined. - """ - new_accumulator = _LMomentsAccumulator(*batch_values) - return self._combine_accumulators(accumulator, new_accumulator) +# TODO(b/165020671): Optimize padding to save up to 15% computing resource. +def _pad_arrays_to_match(a, b): + """Pad the ndarray values to match dimensions as needed. - def merge_accumulators(self, accumulators): - """Merges several `_LMomentsAccumulator`s to a single accumulator. + If the dimensions of the ndarrays values differ, we pad the smaller of the + two arrays with zeros to be the same shape as the larger. In other words, + the missing accumulator indices are assumed to be zero, and combining + a = [1, 2, 3] with b = [1, 2] is equivalent t combining with b = [1, 2, 0]. Args: - accumulators: A list of `_LMomentsAccumulator`s. + ---- + a: NDarray to be matched in shaped with b + b: NDarray to be matched in shaped with a Returns: - The sole merged `_LMomentsAccumulator`. + ------- + a: a padded to same dimensions as b + b: b padded to same dimensions as a """ - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result = self._combine_accumulators(result, accumulator) - return result + if a.shape == b.shape: + return a, b + padding_a, padding_b = [], [] + for a_dim, b_dim in zip(a.shape, b.shape): + a_pad = b_pad = (0, 0) + delta = a_dim - b_dim + if delta > 0: + b_pad = (0, abs(delta)) + elif delta < 0: + a_pad = (0, abs(delta)) + padding_a.append(a_pad) + padding_b.append(b_pad) + if padding_a: + a = np.pad(a, padding_a, mode="constant") + if padding_b: + b = np.pad(b, padding_b, mode="constant") + return a, b + - def extract_output(self, accumulator): - """Converts an accumulator into the output (loc, scale, hl, hr) tuple. 
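For reference, the merge step reformatted above is the standard parallel mean/variance update. A scalar, unweighted sketch (illustrative names only; population variance, as with `np.var`) can be checked directly against NumPy:

import numpy as np

def combine(count_a, mean_a, var_a, count_b, mean_b, var_b):
    """Merge two (count, mean, variance) summaries of disjoint data."""
    total = count_a + count_b
    mean = mean_a + (count_b / total) * (mean_b - mean_a)
    var = var_a + (count_b / total) * (
        var_b - var_a + (mean_b - mean) * (mean_b - mean_a)
    )
    return total, mean, var

x = np.array([1.0, 2.0, 3.0, 10.0, 20.0])
a, b = x[:3], x[3:]
count, mean, var = combine(a.size, a.mean(), a.var(), b.size, b.mean(), b.var())
assert np.isclose(mean, x.mean()) and np.isclose(var, x.var())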
+class _LMomentsAccumulator( + tfx_namedtuple.namedtuple( + "LMomentsAccumulator", + ["count_l1", "count_l2", "count_l3", "count_l4", "l1", "l2", "l3", "l4"], + ) +): + """Container for _LMomentsCombiner intermediate values.""" + + @classmethod + def make_nan_to_num(cls, count_l1, count_l2, count_l3, count_l4, l1, l2, l3, l4): + return cls( + np.array(count_l1), + np.array(count_l2), + np.array(count_l3), + np.array(count_l4), + np.nan_to_num(l1), + np.nan_to_num(l2), + np.nan_to_num(l3), + np.nan_to_num(l4), + ) - Estimates the parameters of a Tukey HH distribution, given estimates of the - first four L-moments. The parameters are: location, scale, hl, and hr. If - x is the input sample, then (x - location) / scale is distributed according - to the Tukey HH distribution with parameters hl (left parameter) and hr - (right parameter). + def __reduce__(self): + return self.__class__, tuple(self) - Args: - accumulator: the final `_LMomentsAccumulator` value. - Returns: - A 4-tuple composed of (location, scale, hl, hr). - """ +class _LMomentsCombiner(analyzer_nodes.Combiner): + """Combines a PCollection of accumulators to compute L-moments.""" + + accumulator_class = _LMomentsAccumulator + + def __init__(self, output_numpy_dtype, output_shape): + """Init method for _LMomentsCombiner. + + Args: + ---- + output_numpy_dtype: A numpy dtype that the outputs are cast to. + output_shape: The shape of the resulting Tensors. + """ + self._output_numpy_dtype = output_numpy_dtype + self._output_shape = output_shape + + def create_accumulator(self): + """Create an accumulator with all zero entries.""" + # If we know the exact shape, initialize accumulator values with zeros of + # the exact shape. For unknown dimensions, initialize with a 1D 0 array + # (this accumulator will be discarded by _combine_accumulators). + output_shape = () if None in self._output_shape else self._output_shape + initial_moment = np.zeros(output_shape, dtype=self._output_numpy_dtype) + initial_count = np.zeros(output_shape, dtype=self._output_numpy_dtype) + return _LMomentsAccumulator( + initial_count, + initial_count, + initial_count, + initial_count, + initial_moment, + initial_moment, + initial_moment, + initial_moment, + ) - # To compute kurtosis, we need positive scale and at least one quadruplet. - # If this is not the case, L-kewness and L-kurtosis are set to zero, which - # gives hl=0, hr=0 and samples are treated as in the Gaussian case. - - valid_scale = accumulator.l2 > 0.0 - valid_kurtosis = np.logical_and(valid_scale, accumulator.count_l4 > 0.0) - - l_skewness = np.true_divide(accumulator.l3, accumulator.l2, - where=valid_kurtosis, - out=np.zeros_like(accumulator.l3)) - - l_kurtosis = np.true_divide(accumulator.l4, accumulator.l2, - where=valid_kurtosis, - out=np.zeros_like(accumulator.l4)) - l_skewness_and_kurtosis = np.stack((l_skewness, l_kurtosis), axis=0) - h_params = np.apply_along_axis( - gaussianization.compute_tukey_hh_params, 0, l_skewness_and_kurtosis) - hh_l_mean, hh_l_scale = gaussianization.tukey_hh_l_mean_and_scale(h_params) - - scale = np.true_divide(accumulator.l2, hh_l_scale, - where=valid_scale, out=np.ones_like(accumulator.l2)) - loc = accumulator.l1 - scale * hh_l_mean - hl = h_params[0, ...] - hr = h_params[1, ...] - return [self._output_numpy_dtype(x) for x in [loc, scale, hl, hr]] - - def output_tensor_infos(self): - # The output is (loc, scale, hl, hr). 
- return [ - analyzer_nodes.TensorInfo( - tf.as_dtype(self._output_numpy_dtype), self._output_shape, None) - ] * 4 - - @property - def accumulator_coder(self): - # TODO(b/170510451): Re-enable caching for this Combiner. - return None - - def _combine_accumulators(self, a, b): - """Combines two accumulators. + def add_input(self, accumulator, batch_values): + """Composes an accumulator from batch_values and calls merge_accumulators. + + Args: + ---- + accumulator: The `_LMomentsAccumulator` computed so far. + batch_values: A `_LMomentsAccumulator` for the current batch. + + Returns: + ------- + A `_LMomentsAccumulator` which is accumulator and batch_values combined. + """ + new_accumulator = _LMomentsAccumulator(*batch_values) + return self._combine_accumulators(accumulator, new_accumulator) + + def merge_accumulators(self, accumulators): + """Merges several `_LMomentsAccumulator`s to a single accumulator. + + Args: + ---- + accumulators: A list of `_LMomentsAccumulator`s. + + Returns: + ------- + The sole merged `_LMomentsAccumulator`. + """ + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result = self._combine_accumulators(result, accumulator) + return result + + def extract_output(self, accumulator): + """Converts an accumulator into the output (loc, scale, hl, hr) tuple. + + Estimates the parameters of a Tukey HH distribution, given estimates of the + first four L-moments. The parameters are: location, scale, hl, and hr. If + x is the input sample, then (x - location) / scale is distributed according + to the Tukey HH distribution with parameters hl (left parameter) and hr + (right parameter). + + Args: + ---- + accumulator: the final `_LMomentsAccumulator` value. + + Returns: + ------- + A 4-tuple composed of (location, scale, hl, hr). + """ + # To compute kurtosis, we need positive scale and at least one quadruplet. + # If this is not the case, L-kewness and L-kurtosis are set to zero, which + # gives hl=0, hr=0 and samples are treated as in the Gaussian case. + + valid_scale = accumulator.l2 > 0.0 + valid_kurtosis = np.logical_and(valid_scale, accumulator.count_l4 > 0.0) + + l_skewness = np.true_divide( + accumulator.l3, + accumulator.l2, + where=valid_kurtosis, + out=np.zeros_like(accumulator.l3), + ) - Args: - a: A _LMomentsAccumulator. - b: A _LMomentsAccumulator. + l_kurtosis = np.true_divide( + accumulator.l4, + accumulator.l2, + where=valid_kurtosis, + out=np.zeros_like(accumulator.l4), + ) + l_skewness_and_kurtosis = np.stack((l_skewness, l_kurtosis), axis=0) + h_params = np.apply_along_axis( + gaussianization.compute_tukey_hh_params, 0, l_skewness_and_kurtosis + ) + hh_l_mean, hh_l_scale = gaussianization.tukey_hh_l_mean_and_scale(h_params) - Returns: - A _LMomentsAccumulator computed as the combination of a and b. - """ - # NaNs get preserved through division by a.count + b.count. - a = _LMomentsAccumulator.make_nan_to_num(*a) - b = _LMomentsAccumulator.make_nan_to_num(*b) - - # If one accumulator is empty return the other. 
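The guarded ratios in `extract_output` above (L-skewness `l3/l2`, L-kurtosis `l4/l2`, zeroed whenever the L-scale or the fourth-moment count is not positive, which reduces to the Gaussian case) behave as in this small standalone NumPy sketch with made-up values:

import numpy as np

l2 = np.array([2.0, 0.0, 1.5])          # L-scale per feature
l3 = np.array([0.4, 0.3, -0.6])
l4 = np.array([0.5, 0.2, 0.3])
count_l4 = np.array([10.0, 0.0, 0.0])   # number of quadruplets seen

valid = np.logical_and(l2 > 0.0, count_l4 > 0.0)
l_skewness = np.true_divide(l3, l2, where=valid, out=np.zeros_like(l3))
l_kurtosis = np.true_divide(l4, l2, where=valid, out=np.zeros_like(l4))
# Only the first feature is valid: l_skewness is [0.2, 0, 0], l_kurtosis is [0.25, 0, 0]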
- if np.sum(a.count_l1) < np.sum(b.count_l1): - a, b = b, a - if np.sum(b.count_l1) == 0: - return a - - a_count_l1, b_count_l1 = _pad_arrays_to_match(a.count_l1, b.count_l1) - a_l1, b_l1 = _pad_arrays_to_match(a.l1, b.l1) - a_count_l2, b_count_l2 = _pad_arrays_to_match(a.count_l2, b.count_l2) - a_l2, b_l2 = _pad_arrays_to_match(a.l2, b.l2) - a_count_l3, b_count_l3 = _pad_arrays_to_match(a.count_l3, b.count_l3) - a_l3, b_l3 = _pad_arrays_to_match(a.l3, b.l3) - a_count_l4, b_count_l4 = _pad_arrays_to_match(a.count_l4, b.count_l4) - a_l4, b_l4 = _pad_arrays_to_match(a.l4, b.l4) - - combined_count_l1 = a_count_l1 + b_count_l1 - combined_count_l2 = a_count_l2 + b_count_l2 - combined_count_l3 = a_count_l3 + b_count_l3 - combined_count_l4 = a_count_l4 + b_count_l4 - - combined_l1 = (a_l1 + np.true_divide( - b_count_l1, combined_count_l1, where=combined_count_l1 > 0, - out=np.zeros_like(a_l1)) * (b_l1 - a_l1)) - combined_l2 = (a_l2 + np.true_divide( - b_count_l2, combined_count_l2, where=combined_count_l2 > 0, - out=np.zeros_like(a_l2)) * (b_l2 - a_l2)) - combined_l3 = (a_l3 + np.true_divide( - b_count_l3, combined_count_l3, where=combined_count_l3 > 0, - out=np.zeros_like(a_l3)) * (b_l3 - a_l3)) - combined_l4 = (a_l4 + np.true_divide( - b_count_l4, combined_count_l4, where=combined_count_l4 > 0, - out=np.zeros_like(a_l4)) * (b_l4 - a_l4)) - - return _LMomentsAccumulator( - combined_count_l1, combined_count_l2, combined_count_l3, - combined_count_l4, combined_l1, combined_l2, combined_l3, combined_l4) + scale = np.true_divide( + accumulator.l2, + hh_l_scale, + where=valid_scale, + out=np.ones_like(accumulator.l2), + ) + loc = accumulator.l1 - scale * hh_l_mean + hl = h_params[0, ...] + hr = h_params[1, ...] + return [self._output_numpy_dtype(x) for x in [loc, scale, hl, hr]] + + def output_tensor_infos(self): + # The output is (loc, scale, hl, hr). + return [ + analyzer_nodes.TensorInfo( + tf.as_dtype(self._output_numpy_dtype), self._output_shape, None + ) + ] * 4 + + @property + def accumulator_coder(self): + # TODO(b/170510451): Re-enable caching for this Combiner. + return None + + def _combine_accumulators(self, a, b): + """Combines two accumulators. + + Args: + ---- + a: A _LMomentsAccumulator. + b: A _LMomentsAccumulator. + + Returns: + ------- + A _LMomentsAccumulator computed as the combination of a and b. + """ + # NaNs get preserved through division by a.count + b.count. + a = _LMomentsAccumulator.make_nan_to_num(*a) + b = _LMomentsAccumulator.make_nan_to_num(*b) + + # If one accumulator is empty return the other. 
+ if np.sum(a.count_l1) < np.sum(b.count_l1): + a, b = b, a + if np.sum(b.count_l1) == 0: + return a + + a_count_l1, b_count_l1 = _pad_arrays_to_match(a.count_l1, b.count_l1) + a_l1, b_l1 = _pad_arrays_to_match(a.l1, b.l1) + a_count_l2, b_count_l2 = _pad_arrays_to_match(a.count_l2, b.count_l2) + a_l2, b_l2 = _pad_arrays_to_match(a.l2, b.l2) + a_count_l3, b_count_l3 = _pad_arrays_to_match(a.count_l3, b.count_l3) + a_l3, b_l3 = _pad_arrays_to_match(a.l3, b.l3) + a_count_l4, b_count_l4 = _pad_arrays_to_match(a.count_l4, b.count_l4) + a_l4, b_l4 = _pad_arrays_to_match(a.l4, b.l4) + + combined_count_l1 = a_count_l1 + b_count_l1 + combined_count_l2 = a_count_l2 + b_count_l2 + combined_count_l3 = a_count_l3 + b_count_l3 + combined_count_l4 = a_count_l4 + b_count_l4 + + combined_l1 = a_l1 + np.true_divide( + b_count_l1, + combined_count_l1, + where=combined_count_l1 > 0, + out=np.zeros_like(a_l1), + ) * (b_l1 - a_l1) + combined_l2 = a_l2 + np.true_divide( + b_count_l2, + combined_count_l2, + where=combined_count_l2 > 0, + out=np.zeros_like(a_l2), + ) * (b_l2 - a_l2) + combined_l3 = a_l3 + np.true_divide( + b_count_l3, + combined_count_l3, + where=combined_count_l3 > 0, + out=np.zeros_like(a_l3), + ) * (b_l3 - a_l3) + combined_l4 = a_l4 + np.true_divide( + b_count_l4, + combined_count_l4, + where=combined_count_l4 > 0, + out=np.zeros_like(a_l4), + ) * (b_l4 - a_l4) + + return _LMomentsAccumulator( + combined_count_l1, + combined_count_l2, + combined_count_l3, + combined_count_l4, + combined_l1, + combined_l2, + combined_l3, + combined_l4, + ) def sanitized_vocab_filename(filename=None, prefix=None): - """Generates a sanitized filename either from the given filename or the scope. + """Generates a sanitized filename either from the given filename or the scope. - If filename is specified, provide a sanitized version of the given filename. - Otherwise generate a filename from the current scope. Note that it is the - callers responsibility to ensure that filenames are unique across calls within - a given preprocessing function. + If filename is specified, provide a sanitized version of the given filename. + Otherwise generate a filename from the current scope. Note that it is the + callers responsibility to ensure that filenames are unique across calls within + a given preprocessing function. - Args: - filename: A filename with non-alpha characters replaced with underscores and - spaces to hyphens. - prefix: Prefix to use for the name of the vocab file, if filename - is not given. + Args: + ---- + filename: A filename with non-alpha characters replaced with underscores and + spaces to hyphens. + prefix: Prefix to use for the name of the vocab file, if filename + is not given. - Returns: - A valid filename. + Returns: + ------- + A valid filename. - Raises: - ValueError: If neither filename and prefix are specified, or if both - are specified. - """ - if filename is None and prefix is None: - raise ValueError('Both filename and prefix cannot be None.') + Raises: + ------ + ValueError: If neither filename and prefix are specified, or if both + are specified. 
+ """ + if filename is None and prefix is None: + raise ValueError("Both filename and prefix cannot be None.") - if filename is not None and prefix is not None: - raise ValueError('Only one of filename or prefix can be specified.') + if filename is not None and prefix is not None: + raise ValueError("Only one of filename or prefix can be specified.") - if filename is None: - filename = prefix + tf.compat.v1.get_default_graph().get_name_scope() - # Replace non-alpha characters (excluding whitespaces) with '_'. - filename = re.sub(r'[^\w\s-]', '_', filename).strip() - # Replace whitespaces with '-'. - return re.sub(r'[-\s]+', '-', filename) + if filename is None: + filename = prefix + tf.compat.v1.get_default_graph().get_name_scope() + # Replace non-alpha characters (excluding whitespaces) with '_'. + filename = re.sub(r"[^\w\s-]", "_", filename).strip() + # Replace whitespaces with '-'. + return re.sub(r"[-\s]+", "-", filename) def _get_vocab_filename(vocab_filename, store_frequency): - """Returns a sanitized vocabulary filename with appropriate prefix applied. + """Returns a sanitized vocabulary filename with appropriate prefix applied. - Args: - vocab_filename: The file name for the vocabulary file. If none, the - "vocabulary" scope name in the context of this graph will be used as the - file name. - store_frequency: A bool that is true when the vocabulary for which this - generates a filename stores term frequency. False otherwise. + Args: + ---- + vocab_filename: The file name for the vocabulary file. If none, the + "vocabulary" scope name in the context of this graph will be used as the + file name. + store_frequency: A bool that is true when the vocabulary for which this + generates a filename stores term frequency. False otherwise. - Returns: - A valid filename. - """ - if vocab_filename is not None: - prefix = None - elif store_frequency: - prefix = VOCAB_FREQUENCY_FILENAME_PREFIX - else: - prefix = VOCAB_FILENAME_PREFIX + Returns: + ------- + A valid filename. + """ + if vocab_filename is not None: + prefix = None + elif store_frequency: + prefix = VOCAB_FREQUENCY_FILENAME_PREFIX + else: + prefix = VOCAB_FILENAME_PREFIX - # Make the file name path safe. - return sanitized_vocab_filename(vocab_filename, prefix=prefix) + # Make the file name path safe. + return sanitized_vocab_filename(vocab_filename, prefix=prefix) def _maybe_get_per_key_vocab_filename(key_vocabulary_filename): - if key_vocabulary_filename == '': # pylint: disable=g-explicit-bool-comparison - key_vocabulary_filename = _get_vocab_filename(vocab_filename=None, - store_frequency=False) - return key_vocabulary_filename + if key_vocabulary_filename == "": # pylint: disable=g-explicit-bool-comparison + key_vocabulary_filename = _get_vocab_filename( + vocab_filename=None, store_frequency=False + ) + return key_vocabulary_filename # TODO(b/116308354): frequency_threshold is misleading since this threshold can # be applied to mutual information rather than frequency. 
def _get_top_k_and_frequency_threshold(top_k, frequency_threshold): - """Validate `top_k` and `frequency_threshold` values and convert to number.""" - if top_k is not None: - top_k = int(top_k) - if top_k <= 0: - raise ValueError('top_k must be positive, but got: %r' % top_k) - - if frequency_threshold is not None: - frequency_threshold = float(frequency_threshold) - if frequency_threshold < 0: - raise ValueError( - 'frequency_threshold must be non-negative, but got: %r' % - frequency_threshold) - elif frequency_threshold <= 1: - # Note: this warning is misleading in the context where tokens are ranked - # based on mutual information rather than frequency. - tf.compat.v1.logging.warn( - 'frequency_threshold %d <= 1 is a no-op, use None instead.', - frequency_threshold) - return top_k, frequency_threshold + """Validate `top_k` and `frequency_threshold` values and convert to number.""" + if top_k is not None: + top_k = int(top_k) + if top_k <= 0: + raise ValueError("top_k must be positive, but got: %r" % top_k) + + if frequency_threshold is not None: + frequency_threshold = float(frequency_threshold) + if frequency_threshold < 0: + raise ValueError( + "frequency_threshold must be non-negative, but got: %r" + % frequency_threshold + ) + elif frequency_threshold <= 1: + # Note: this warning is misleading in the context where tokens are ranked + # based on mutual information rather than frequency. + tf.compat.v1.logging.warn( + "frequency_threshold %d <= 1 is a no-op, use None instead.", + frequency_threshold, + ) + return top_k, frequency_threshold class _VocabOrderingType: - """Class for all vocab ordering types.""" - # Orders vocabulary based on the simple frequency of the token - FREQUENCY = 1 - # Orders vocabulary based on the weighted frequency of the token - WEIGHTED_FREQUENCY = 2 - # Orders vocabulary based on the weighted mutual - # information of token with the label - WEIGHTED_MUTUAL_INFORMATION = 3 - # Experimental - WEIGHTED_LABELS = 4 - # Orders vocabulary based on the mutual information - # of token with the label and without weight. - MUTUAL_INFORMATION = 5 - - -def register_vocab(sanitized_filename: str, - vocabulary_size: Optional[tf.Tensor] = None, - vocabulary_key: Optional[str] = None, - file_format: common_types - .VocabularyFileFormatType = DEFAULT_VOCABULARY_FILE_FORMAT): - """Registers the specificed vocabulary within the asset map. - - Args: - sanitized_filename: The santized filename of the vocabulary. - vocabulary_size: The size of the vocabulary. - vocabulary_key: The key of the vocabulary to use. - file_format: The format of the vocabulary file (text or tfrecord_gzip). - """ - if vocabulary_key is None: - vocabulary_key = sanitized_filename - filename = ('{}.tfrecord.gz'.format(sanitized_filename) - if file_format == 'tfrecord_gzip' else sanitized_filename) - annotators.annotate_asset(vocabulary_key, filename) - if vocabulary_size is not None: - annotators.annotate_vocab_size(vocabulary_key, vocabulary_size) + """Class for all vocab ordering types.""" + + # Orders vocabulary based on the simple frequency of the token + FREQUENCY = 1 + # Orders vocabulary based on the weighted frequency of the token + WEIGHTED_FREQUENCY = 2 + # Orders vocabulary based on the weighted mutual + # information of token with the label + WEIGHTED_MUTUAL_INFORMATION = 3 + # Experimental + WEIGHTED_LABELS = 4 + # Orders vocabulary based on the mutual information + # of token with the label and without weight. 
+ MUTUAL_INFORMATION = 5 + + +def register_vocab( + sanitized_filename: str, + vocabulary_size: Optional[tf.Tensor] = None, + vocabulary_key: Optional[str] = None, + file_format: common_types.VocabularyFileFormatType = DEFAULT_VOCABULARY_FILE_FORMAT, +): + """Registers the specificed vocabulary within the asset map. + + Args: + ---- + sanitized_filename: The santized filename of the vocabulary. + vocabulary_size: The size of the vocabulary. + vocabulary_key: The key of the vocabulary to use. + file_format: The format of the vocabulary file (text or tfrecord_gzip). + """ + if vocabulary_key is None: + vocabulary_key = sanitized_filename + filename = ( + f"{sanitized_filename}.tfrecord.gz" + if file_format == "tfrecord_gzip" + else sanitized_filename + ) + annotators.annotate_asset(vocabulary_key, filename) + if vocabulary_size is not None: + annotators.annotate_vocab_size(vocabulary_key, vocabulary_size) def get_empy_vocabulary_dummy_value( - dtype: Union[tf.dtypes.DType, str]) -> Tuple[int, bytes]: - """Returns a vocabulary entry to use in case of an empty vocabulary.""" - # TODO(b/62272023) remove this workaround if/when fixed on tensorflow. - # If the vocabulary is empty add a dummy value with count one so - # the tensorflow index operations don't fail to initialize with empty - # tensors downstream. - dummy_value = (b'49d0cd50-04bb-48c0-bc6f-5b575dce351a' - if tf.dtypes.as_dtype(dtype) == tf.string else b'-1') - return (1, dummy_value) + dtype: Union[tf.dtypes.DType, str], +) -> Tuple[int, bytes]: + """Returns a vocabulary entry to use in case of an empty vocabulary.""" + # TODO(b/62272023) remove this workaround if/when fixed on tensorflow. + # If the vocabulary is empty add a dummy value with count one so + # the tensorflow index operations don't fail to initialize with empty + # tensors downstream. + dummy_value = ( + b"49d0cd50-04bb-48c0-bc6f-5b575dce351a" + if tf.dtypes.as_dtype(dtype) == tf.string + else b"-1" + ) + return (1, dummy_value) # TODO(b/117796748): Add coverage key feature input as alternative to `key_fn`. @@ -1724,212 +1932,228 @@ def vocabulary( file_format: common_types.VocabularyFileFormatType = DEFAULT_VOCABULARY_FILE_FORMAT, name: Optional[str] = None, ) -> common_types.TemporaryAnalyzerOutputType: - r"""Computes the unique values of `x` over the whole dataset. - - Computes The unique values taken by `x`, which can be a `Tensor`, - `SparseTensor`, or `RaggedTensor` of any size. The unique values will be - aggregated over all dimensions of `x` and all instances. - - In case `file_format` is 'text' and one of the tokens contains the '\n' or - '\r' characters or is empty it will be discarded. - - If an integer `Tensor` is provided, its semantic type should be categorical - not a continuous/numeric, since computing a vocabulary over a continuous - feature is not appropriate. - - The unique values are sorted by decreasing frequency and then reverse - lexicographical order (e.g. [('a', 5), ('c', 3), ('b', 3)]). This is true even - if `x` is numerical dtype (e.g. [('3', 5), ('2', 3), ('111', 3)]). - - For large datasets it is highly recommended to either set frequency_threshold - or top_k to control the size of the output, and also the run time of this - operation. - - When labels are provided, we filter the vocabulary based on the relationship - between the token's presence in a record and the label for that record, using - (possibly adjusted) Mutual Information. 
Note: If labels are provided, the x - input must be a unique set of per record, as the semantics of the mutual - information calculation depend on a multi-hot representation of the input. - Having unique input tokens per row is advisable but not required for a - frequency-based vocabulary. - - WARNING: The following is experimental and is still being actively worked on. - - Supply `key_fn` if you would like to generate a vocabulary with coverage over - specific keys. - - A "coverage vocabulary" is the union of two vocabulary "arms". The "standard - arm" of the vocabulary is equivalent to the one generated by the same function - call with no coverage arguments. Adding coverage only appends additional - entries to the end of the standard vocabulary. - - The "coverage arm" of the vocabulary is determined by taking the - `coverage_top_k` most frequent unique terms per key. A term's key is obtained - by applying `key_fn` to the term. Use `coverage_frequency_threshold` to lower - bound the frequency of entries in the coverage arm of the vocabulary. - - Note this is currently implemented for the case where the key is contained - within each vocabulary entry (b/117796748). - - Args: - x: A categorical/discrete input `Tensor`, `SparseTensor`, or `RaggedTensor` - with dtype tf.string or tf.int[8|16|32|64]. The inputs should generally be - unique per row (i.e. a bag of words/ngrams representation). - top_k: Limit the generated vocabulary to the first `top_k` elements. If set - to None, the full vocabulary is generated. - frequency_threshold: Limit the generated vocabulary only to elements whose - absolute frequency is >= to the supplied threshold. If set to None, the - full vocabulary is generated. Absolute frequency means the number of - occurrences of the element in the dataset, as opposed to the proportion of - instances that contain that element. - vocab_filename: The file name for the vocabulary file. If None, a file name - will be chosen based on the current scope. If not None, should be unique - within a given preprocessing function. NOTE To make your pipelines - resilient to implementation details please set `vocab_filename` when you - are using the vocab_filename on a downstream component. - store_frequency: If True, frequency of the words is stored in the vocabulary - file. In the case labels are provided, the mutual information is stored in - the file instead. Each line in the file will be of the form 'frequency - word'. NOTE: if this is True then the computed vocabulary cannot be used - with `tft.apply_vocabulary` directly, since frequencies are added to the - beginning of each row of the vocabulary, which the mapper will not ignore. - reserved_tokens: (Optional) A list of tokens that should appear in the - vocabulary regardless of their appearance in the input. These tokens would - maintain their order, and have a reserved spot at the beginning of the - vocabulary. Note: this field has no affect on cache. - weights: (Optional) Weights `Tensor` for the vocabulary. It must have the - same shape as x. - labels: (Optional) Labels dense `Tensor` for the vocabulary. If provided, - the vocabulary is calculated based on mutual information with the label, - rather than frequency. The labels must have the same batch dimension as x. - If x is sparse, labels should be a 1D tensor reflecting row-wise labels. - If x is dense, labels can either be a 1D tensor of row-wise labels, or a - dense tensor of the identical shape as x (i.e. element-wise labels). 
- Labels should be a discrete integerized tensor (If the label is numeric, - it should first be bucketized; If the label is a string, an integer - vocabulary should first be applied). Note: `CompositeTensor` labels are - not yet supported (b/134931826). WARNING: When labels are provided, the - frequency_threshold argument functions as a mutual information threshold, - which is a float. TODO(b/116308354): Fix confusing naming. - use_adjusted_mutual_info: If true, and labels are provided, calculate - vocabulary using adjusted rather than raw mutual information. - min_diff_from_avg: MI (or AMI) of a feature x label will be adjusted to zero - whenever the difference between count and the expected (average) count is - lower than min_diff_from_average. This can be thought of as a regularizing - parameter that pushes small MI/AMI values to zero. If None, a default - parameter will be selected based on the size of the dataset (see - calculate_recommended_min_diff_from_avg). - coverage_top_k: (Optional), (Experimental) The minimum number of elements - per key to be included in the vocabulary. - coverage_frequency_threshold: (Optional), (Experimental) Limit the coverage - arm of the vocabulary only to elements whose absolute frequency is >= this - threshold for a given key. - key_fn: (Optional), (Experimental) A fn that takes in a single entry of `x` - and returns the corresponding key for coverage calculation. If this is - `None`, no coverage arm is added to the vocabulary. - fingerprint_shuffle: (Optional), (Experimental) Whether to sort the - vocabularies by fingerprint instead of counts. This is useful for load - balancing on the training parameter servers. Shuffle only happens while - writing the files, so all the filters above (top_k, frequency_threshold, - etc) will still take effect. - file_format: (Optional) A str. The format of the resulting vocabulary file. - Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires - tensorflow>=2.4. The default value is 'text'. - name: (Optional) A name for this operation. - - Returns: - The path name for the vocabulary file containing the unique values of `x`. - - Raises: - ValueError: If `top_k` or `frequency_threshold` is negative. - If `coverage_top_k` or `coverage_frequency_threshold` is negative. - If either `coverage_top_k` or `coverage_frequency_threshold` is specified - and `key_fn` is not. - If `key_fn` is specified and neither `coverage_top_k`, nor - """ - top_k, frequency_threshold = _get_top_k_and_frequency_threshold( - top_k, frequency_threshold) - - if (coverage_top_k or coverage_frequency_threshold) and not key_fn: - raise ValueError('You must specify `key_fn` if you specify `coverage_top_k' - ' or `coverage_frequency_threshold` in `vocabulary`.') - - if key_fn and not (coverage_top_k or coverage_frequency_threshold): - raise ValueError('You must specify `coverage_top_k` or ' - '`coverage_frequency_threshold` if you specify `key_fn` in' - ' `vocabulary`.') - - if file_format not in ALLOWED_VOCABULARY_FILE_FORMATS: - raise ValueError( - '"{}" is not an accepted file_format. 
It should be one of: {}'.format( - file_format, ALLOWED_VOCABULARY_FILE_FORMATS)) - - coverage_top_k, coverage_frequency_threshold = ( - _get_top_k_and_frequency_threshold( - coverage_top_k, coverage_frequency_threshold)) - - if x.dtype != tf.string and not x.dtype.is_integer: - raise ValueError('expected tf.string or integer but got %r' % x.dtype) - - if labels is not None and not labels.dtype.is_integer: - raise ValueError('expected integer labels but got %r' % labels.dtype) - - if (frequency_threshold is None and labels is None and key_fn is None and - not fingerprint_shuffle and top_k is not None and - top_k <= LARGE_VOCAB_TOP_K): - logging.info('If the number of unique tokens is smaller than the provided ' - 'top_k or approximation error is acceptable, consider using ' - 'tft.experimental.approximate_vocabulary for a potentially ' - 'more efficient implementation.') - - with tf.compat.v1.name_scope(name, 'vocabulary'): - vocabulary_key = vocab_filename - vocab_filename = _get_vocab_filename(vocab_filename, store_frequency) - informativeness_threshold = float('-inf') - coverage_informativeness_threshold = float('-inf') - if labels is not None: - if weights is not None: - vocab_ordering_type = _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION - else: - vocab_ordering_type = _VocabOrderingType.MUTUAL_INFORMATION - # Correct for the overloaded `frequency_threshold` API. - if frequency_threshold is not None: - informativeness_threshold = frequency_threshold - frequency_threshold = 0.0 - if coverage_frequency_threshold is not None: - coverage_informativeness_threshold = coverage_frequency_threshold - coverage_frequency_threshold = 0.0 - elif weights is not None: - vocab_ordering_type = _VocabOrderingType.WEIGHTED_FREQUENCY - else: - vocab_ordering_type = _VocabOrderingType.FREQUENCY - analyzer_inputs = _get_vocabulary_analyzer_inputs( - vocab_ordering_type=vocab_ordering_type, - x=x, - file_format=file_format, - labels=labels, - weights=weights) - return _vocabulary_analyzer_nodes( - analyzer_inputs=analyzer_inputs, - input_dtype=x.dtype.name, - vocab_ordering_type=vocab_ordering_type, - vocab_filename=vocab_filename, - top_k=top_k, - frequency_threshold=frequency_threshold or 0, - informativeness_threshold=informativeness_threshold, - use_adjusted_mutual_info=use_adjusted_mutual_info, - min_diff_from_avg=min_diff_from_avg, - fingerprint_shuffle=fingerprint_shuffle, - store_frequency=store_frequency, - key_fn=key_fn, - coverage_top_k=coverage_top_k, - coverage_frequency_threshold=coverage_frequency_threshold or 0, - coverage_informativeness_threshold=coverage_informativeness_threshold, - file_format=file_format, - vocabulary_key=vocabulary_key, - reserved_tokens=reserved_tokens, + r"""Computes the unique values of `x` over the whole dataset. + + Computes The unique values taken by `x`, which can be a `Tensor`, + `SparseTensor`, or `RaggedTensor` of any size. The unique values will be + aggregated over all dimensions of `x` and all instances. + + In case `file_format` is 'text' and one of the tokens contains the '\n' or + '\r' characters or is empty it will be discarded. + + If an integer `Tensor` is provided, its semantic type should be categorical + not a continuous/numeric, since computing a vocabulary over a continuous + feature is not appropriate. + + The unique values are sorted by decreasing frequency and then reverse + lexicographical order (e.g. [('a', 5), ('c', 3), ('b', 3)]). This is true even + if `x` is numerical dtype (e.g. [('3', 5), ('2', 3), ('111', 3)]). 
+ + For large datasets it is highly recommended to either set frequency_threshold + or top_k to control the size of the output, and also the run time of this + operation. + + When labels are provided, we filter the vocabulary based on the relationship + between the token's presence in a record and the label for that record, using + (possibly adjusted) Mutual Information. Note: If labels are provided, the x + input must be a unique set of per record, as the semantics of the mutual + information calculation depend on a multi-hot representation of the input. + Having unique input tokens per row is advisable but not required for a + frequency-based vocabulary. + + WARNING: The following is experimental and is still being actively worked on. + + Supply `key_fn` if you would like to generate a vocabulary with coverage over + specific keys. + + A "coverage vocabulary" is the union of two vocabulary "arms". The "standard + arm" of the vocabulary is equivalent to the one generated by the same function + call with no coverage arguments. Adding coverage only appends additional + entries to the end of the standard vocabulary. + + The "coverage arm" of the vocabulary is determined by taking the + `coverage_top_k` most frequent unique terms per key. A term's key is obtained + by applying `key_fn` to the term. Use `coverage_frequency_threshold` to lower + bound the frequency of entries in the coverage arm of the vocabulary. + + Note this is currently implemented for the case where the key is contained + within each vocabulary entry (b/117796748). + + Args: + ---- + x: A categorical/discrete input `Tensor`, `SparseTensor`, or `RaggedTensor` + with dtype tf.string or tf.int[8|16|32|64]. The inputs should generally be + unique per row (i.e. a bag of words/ngrams representation). + top_k: Limit the generated vocabulary to the first `top_k` elements. If set + to None, the full vocabulary is generated. + frequency_threshold: Limit the generated vocabulary only to elements whose + absolute frequency is >= to the supplied threshold. If set to None, the + full vocabulary is generated. Absolute frequency means the number of + occurrences of the element in the dataset, as opposed to the proportion of + instances that contain that element. + vocab_filename: The file name for the vocabulary file. If None, a file name + will be chosen based on the current scope. If not None, should be unique + within a given preprocessing function. NOTE To make your pipelines + resilient to implementation details please set `vocab_filename` when you + are using the vocab_filename on a downstream component. + store_frequency: If True, frequency of the words is stored in the vocabulary + file. In the case labels are provided, the mutual information is stored in + the file instead. Each line in the file will be of the form 'frequency + word'. NOTE: if this is True then the computed vocabulary cannot be used + with `tft.apply_vocabulary` directly, since frequencies are added to the + beginning of each row of the vocabulary, which the mapper will not ignore. + reserved_tokens: (Optional) A list of tokens that should appear in the + vocabulary regardless of their appearance in the input. These tokens would + maintain their order, and have a reserved spot at the beginning of the + vocabulary. Note: this field has no affect on cache. + weights: (Optional) Weights `Tensor` for the vocabulary. It must have the + same shape as x. + labels: (Optional) Labels dense `Tensor` for the vocabulary. 
If provided, + the vocabulary is calculated based on mutual information with the label, + rather than frequency. The labels must have the same batch dimension as x. + If x is sparse, labels should be a 1D tensor reflecting row-wise labels. + If x is dense, labels can either be a 1D tensor of row-wise labels, or a + dense tensor of the identical shape as x (i.e. element-wise labels). + Labels should be a discrete integerized tensor (If the label is numeric, + it should first be bucketized; If the label is a string, an integer + vocabulary should first be applied). Note: `CompositeTensor` labels are + not yet supported (b/134931826). WARNING: When labels are provided, the + frequency_threshold argument functions as a mutual information threshold, + which is a float. TODO(b/116308354): Fix confusing naming. + use_adjusted_mutual_info: If true, and labels are provided, calculate + vocabulary using adjusted rather than raw mutual information. + min_diff_from_avg: MI (or AMI) of a feature x label will be adjusted to zero + whenever the difference between count and the expected (average) count is + lower than min_diff_from_average. This can be thought of as a regularizing + parameter that pushes small MI/AMI values to zero. If None, a default + parameter will be selected based on the size of the dataset (see + calculate_recommended_min_diff_from_avg). + coverage_top_k: (Optional), (Experimental) The minimum number of elements + per key to be included in the vocabulary. + coverage_frequency_threshold: (Optional), (Experimental) Limit the coverage + arm of the vocabulary only to elements whose absolute frequency is >= this + threshold for a given key. + key_fn: (Optional), (Experimental) A fn that takes in a single entry of `x` + and returns the corresponding key for coverage calculation. If this is + `None`, no coverage arm is added to the vocabulary. + fingerprint_shuffle: (Optional), (Experimental) Whether to sort the + vocabularies by fingerprint instead of counts. This is useful for load + balancing on the training parameter servers. Shuffle only happens while + writing the files, so all the filters above (top_k, frequency_threshold, + etc) will still take effect. + file_format: (Optional) A str. The format of the resulting vocabulary file. + Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires + tensorflow>=2.4. The default value is 'text'. + name: (Optional) A name for this operation. + + Returns: + ------- + The path name for the vocabulary file containing the unique values of `x`. + + Raises: + ------ + ValueError: If `top_k` or `frequency_threshold` is negative. + If `coverage_top_k` or `coverage_frequency_threshold` is negative. + If either `coverage_top_k` or `coverage_frequency_threshold` is specified + and `key_fn` is not. + If `key_fn` is specified and neither `coverage_top_k`, nor + """ + top_k, frequency_threshold = _get_top_k_and_frequency_threshold( + top_k, frequency_threshold + ) + + if (coverage_top_k or coverage_frequency_threshold) and not key_fn: + raise ValueError( + "You must specify `key_fn` if you specify `coverage_top_k" + " or `coverage_frequency_threshold` in `vocabulary`." + ) + + if key_fn and not (coverage_top_k or coverage_frequency_threshold): + raise ValueError( + "You must specify `coverage_top_k` or " + "`coverage_frequency_threshold` if you specify `key_fn` in" + " `vocabulary`." + ) + + if file_format not in ALLOWED_VOCABULARY_FILE_FORMATS: + raise ValueError( + f'"{file_format}" is not an accepted file_format. 
It should be one of: {ALLOWED_VOCABULARY_FILE_FORMATS}' + ) + + coverage_top_k, coverage_frequency_threshold = _get_top_k_and_frequency_threshold( + coverage_top_k, coverage_frequency_threshold ) + if x.dtype != tf.string and not x.dtype.is_integer: + raise ValueError("expected tf.string or integer but got %r" % x.dtype) + + if labels is not None and not labels.dtype.is_integer: + raise ValueError("expected integer labels but got %r" % labels.dtype) + + if ( + frequency_threshold is None + and labels is None + and key_fn is None + and not fingerprint_shuffle + and top_k is not None + and top_k <= LARGE_VOCAB_TOP_K + ): + logging.info( + "If the number of unique tokens is smaller than the provided " + "top_k or approximation error is acceptable, consider using " + "tft.experimental.approximate_vocabulary for a potentially " + "more efficient implementation." + ) + + with tf.compat.v1.name_scope(name, "vocabulary"): + vocabulary_key = vocab_filename + vocab_filename = _get_vocab_filename(vocab_filename, store_frequency) + informativeness_threshold = float("-inf") + coverage_informativeness_threshold = float("-inf") + if labels is not None: + if weights is not None: + vocab_ordering_type = _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION + else: + vocab_ordering_type = _VocabOrderingType.MUTUAL_INFORMATION + # Correct for the overloaded `frequency_threshold` API. + if frequency_threshold is not None: + informativeness_threshold = frequency_threshold + frequency_threshold = 0.0 + if coverage_frequency_threshold is not None: + coverage_informativeness_threshold = coverage_frequency_threshold + coverage_frequency_threshold = 0.0 + elif weights is not None: + vocab_ordering_type = _VocabOrderingType.WEIGHTED_FREQUENCY + else: + vocab_ordering_type = _VocabOrderingType.FREQUENCY + analyzer_inputs = _get_vocabulary_analyzer_inputs( + vocab_ordering_type=vocab_ordering_type, + x=x, + file_format=file_format, + labels=labels, + weights=weights, + ) + return _vocabulary_analyzer_nodes( + analyzer_inputs=analyzer_inputs, + input_dtype=x.dtype.name, + vocab_ordering_type=vocab_ordering_type, + vocab_filename=vocab_filename, + top_k=top_k, + frequency_threshold=frequency_threshold or 0, + informativeness_threshold=informativeness_threshold, + use_adjusted_mutual_info=use_adjusted_mutual_info, + min_diff_from_avg=min_diff_from_avg, + fingerprint_shuffle=fingerprint_shuffle, + store_frequency=store_frequency, + key_fn=key_fn, + coverage_top_k=coverage_top_k, + coverage_frequency_threshold=coverage_frequency_threshold or 0, + coverage_informativeness_threshold=coverage_informativeness_threshold, + file_format=file_format, + vocabulary_key=vocabulary_key, + reserved_tokens=reserved_tokens, + ) + def _get_vocabulary_analyzer_inputs( vocab_ordering_type: int, @@ -1938,55 +2162,64 @@ def _get_vocabulary_analyzer_inputs( labels: Optional[Union[tf.Tensor, tf.SparseTensor]], weights: Optional[tf.Tensor], ): - """Helper for constructing analyzer inputs from tensors. - - Args: - vocab_ordering_type: VocabOrderingType specifying how to select vocabulary. - x: Tensor to compute vocabulary over. - file_format: The format of the resulting vocabulary file. Accepted formats - are 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires tensorflow>=2.4. - labels: Optional tensor of integerized labels. - weights: Optional tensor of weights. - - Returns: - A list of batch-reduced tensors to feed to vocabulary analysis. 
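As a usage note for the `vocabulary` analyzer reformatted above: it is typically called from a `preprocessing_fn`, and its returned vocabulary path is consumed by `tft.apply_vocabulary`. A minimal, hypothetical example (feature names and thresholds are illustrative):

import tensorflow_transform as tft

def preprocessing_fn(inputs):
    terms = inputs["terms"]
    vocab_path = tft.vocabulary(
        terms,
        top_k=10000,
        frequency_threshold=5,
        vocab_filename="terms_vocab",  # set explicitly for downstream components
    )
    # store_frequency is left False so the file is directly usable here.
    return {"terms_ids": tft.apply_vocabulary(terms, vocab_path)}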
- """ - filter_regex = get_vocab_newline_characters_regex(x.dtype, file_format) - if vocab_ordering_type == _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION: - if not isinstance(labels, tf.SparseTensor): - labels = tf.reshape(labels, [-1]) - reduced_batch = tf_utils.reduce_batch_weighted_cooccurrences( - x, labels, weights, filter_regex=filter_regex) - return [ - reduced_batch.unique_x, reduced_batch.summed_weights_per_x, - reduced_batch.summed_positive_per_x_and_y, reduced_batch.counts_per_x - ] - elif vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: - if not isinstance(labels, tf.SparseTensor): - labels = tf.reshape(labels, [-1]) - reduced_batch = tf_utils.reduce_batch_weighted_cooccurrences( - x, labels, weights, filter_regex=filter_regex) - return [ - reduced_batch.unique_x, reduced_batch.summed_positive_per_x_and_y, - reduced_batch.counts_per_x - ] - elif vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY: - reduced_batch = tf_utils.reduce_batch_weighted_counts( - x, weights, filter_regex=filter_regex) - return [reduced_batch.unique_x, reduced_batch.summed_weights_per_x] - else: - reduced_batch = tf_utils.reduce_batch_weighted_counts( - x, filter_regex=filter_regex) - return [reduced_batch.unique_x] + """Helper for constructing analyzer inputs from tensors. + + Args: + ---- + vocab_ordering_type: VocabOrderingType specifying how to select vocabulary. + x: Tensor to compute vocabulary over. + file_format: The format of the resulting vocabulary file. Accepted formats + are 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires tensorflow>=2.4. + labels: Optional tensor of integerized labels. + weights: Optional tensor of weights. + + Returns: + ------- + A list of batch-reduced tensors to feed to vocabulary analysis. + """ + filter_regex = get_vocab_newline_characters_regex(x.dtype, file_format) + if vocab_ordering_type == _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION: + if not isinstance(labels, tf.SparseTensor): + labels = tf.reshape(labels, [-1]) + reduced_batch = tf_utils.reduce_batch_weighted_cooccurrences( + x, labels, weights, filter_regex=filter_regex + ) + return [ + reduced_batch.unique_x, + reduced_batch.summed_weights_per_x, + reduced_batch.summed_positive_per_x_and_y, + reduced_batch.counts_per_x, + ] + elif vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: + if not isinstance(labels, tf.SparseTensor): + labels = tf.reshape(labels, [-1]) + reduced_batch = tf_utils.reduce_batch_weighted_cooccurrences( + x, labels, weights, filter_regex=filter_regex + ) + return [ + reduced_batch.unique_x, + reduced_batch.summed_positive_per_x_and_y, + reduced_batch.counts_per_x, + ] + elif vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY: + reduced_batch = tf_utils.reduce_batch_weighted_counts( + x, weights, filter_regex=filter_regex + ) + return [reduced_batch.unique_x, reduced_batch.summed_weights_per_x] + else: + reduced_batch = tf_utils.reduce_batch_weighted_counts( + x, filter_regex=filter_regex + ) + return [reduced_batch.unique_x] def get_vocab_newline_characters_regex( - input_dtype: tf.dtypes.DType, - file_format: common_types.VocabularyFileFormatType) -> Optional[str]: - if input_dtype == tf.string and file_format == 'text': - return _EMPTY_STRING_OR_NEWLINE_CHARS_REGEX - else: - return None + input_dtype: tf.dtypes.DType, file_format: common_types.VocabularyFileFormatType +) -> Optional[str]: + if input_dtype == tf.string and file_format == "text": + return _EMPTY_STRING_OR_NEWLINE_CHARS_REGEX + else: + return None def 
_vocabulary_analyzer_nodes( @@ -1996,7 +2229,7 @@ def _vocabulary_analyzer_nodes( vocab_filename: str, top_k: Optional[int] = None, frequency_threshold: int = 0, - informativeness_threshold: float = float('-inf'), + informativeness_threshold: float = float("-inf"), use_adjusted_mutual_info: bool = False, min_diff_from_avg: Optional[int] = None, fingerprint_shuffle: bool = False, @@ -2004,293 +2237,317 @@ def _vocabulary_analyzer_nodes( key_fn: Optional[Callable[[Any], Any]] = None, coverage_top_k: Optional[int] = None, coverage_frequency_threshold: float = 0.0, - coverage_informativeness_threshold: float = float('-inf'), + coverage_informativeness_threshold: float = float("-inf"), file_format: common_types.VocabularyFileFormatType = DEFAULT_VOCABULARY_FILE_FORMAT, vocabulary_key: Optional[str] = None, reserved_tokens: Optional[Union[Sequence[str], tf.Tensor]] = None, ) -> common_types.TemporaryAnalyzerOutputType: - """Internal helper for analyzing vocab. See `vocabulary` doc string.""" - - input_values_node = analyzer_nodes.get_input_tensors_value_nodes( - analyzer_inputs) - - accumulate_output_value_node = nodes.apply_operation( - analyzer_nodes.VocabularyAccumulate, - input_values_node, - vocab_ordering_type=vocab_ordering_type, - input_dtype=input_dtype) - - merge_output_value_node = nodes.apply_operation( - analyzer_nodes.VocabularyMerge, - accumulate_output_value_node, - use_adjusted_mutual_info=use_adjusted_mutual_info, - min_diff_from_avg=min_diff_from_avg, - vocab_ordering_type=vocab_ordering_type) - - filtered_value_node = nodes.apply_operation( - analyzer_nodes.VocabularyPrune, - merge_output_value_node, - coverage_top_k=coverage_top_k, - coverage_frequency_threshold=coverage_frequency_threshold, - coverage_informativeness_threshold=coverage_informativeness_threshold, - key_fn=key_fn, - top_k=top_k, - frequency_threshold=frequency_threshold, - informativeness_threshold=informativeness_threshold, - input_dtype=input_dtype) - - reserved_tokens_size = 0 - extra_apply_order_and_write_op_args = [] - if reserved_tokens is not None: - reserved_tokens_size = tf_utils.register_vocabulary_reserved_tokens( - vocab_filename, reserved_tokens + """Internal helper for analyzing vocab. 
See `vocabulary` doc string.""" + input_values_node = analyzer_nodes.get_input_tensors_value_nodes(analyzer_inputs) + + accumulate_output_value_node = nodes.apply_operation( + analyzer_nodes.VocabularyAccumulate, + input_values_node, + vocab_ordering_type=vocab_ordering_type, + input_dtype=input_dtype, + ) + + merge_output_value_node = nodes.apply_operation( + analyzer_nodes.VocabularyMerge, + accumulate_output_value_node, + use_adjusted_mutual_info=use_adjusted_mutual_info, + min_diff_from_avg=min_diff_from_avg, + vocab_ordering_type=vocab_ordering_type, + ) + + filtered_value_node = nodes.apply_operation( + analyzer_nodes.VocabularyPrune, + merge_output_value_node, + coverage_top_k=coverage_top_k, + coverage_frequency_threshold=coverage_frequency_threshold, + coverage_informativeness_threshold=coverage_informativeness_threshold, + key_fn=key_fn, + top_k=top_k, + frequency_threshold=frequency_threshold, + informativeness_threshold=informativeness_threshold, + input_dtype=input_dtype, ) - extra_apply_order_and_write_op_args.append( - nodes.apply_operation( - analyzer_nodes.ExtractVocabularyReservedTokens, name=vocab_filename + + reserved_tokens_size = 0 + extra_apply_order_and_write_op_args = [] + if reserved_tokens is not None: + reserved_tokens_size = tf_utils.register_vocabulary_reserved_tokens( + vocab_filename, reserved_tokens ) + extra_apply_order_and_write_op_args.append( + nodes.apply_operation( + analyzer_nodes.ExtractVocabularyReservedTokens, name=vocab_filename + ) + ) + + vocab_filename_node = nodes.apply_operation( + analyzer_nodes.VocabularyOrderAndWrite, + filtered_value_node, + *extra_apply_order_and_write_op_args, + vocab_filename=vocab_filename, + store_frequency=store_frequency, + fingerprint_shuffle=fingerprint_shuffle, + input_dtype=input_dtype, + file_format=file_format, + # LINT.IfChange(input_is_sorted) + input_is_sorted=( + top_k is not None and key_fn is None and not fingerprint_shuffle + ), + # LINT.ThenChange(beam/analyzer_impls.py:top_k_impl) + ) + + scope = tf.compat.v1.get_default_graph().get_name_scope() + unfiltered_vocab_size_node = nodes.apply_operation( + analyzer_nodes.VocabularyCount, + merge_output_value_node, + label=f"VocabularyCountUnfiltered[{scope}]", + ) + unfiltered_vocab_size = analyzer_nodes.bind_future_as_tensor( + unfiltered_vocab_size_node, + analyzer_nodes.TensorInfo(tf.int64, [], None), + name=f"{vocab_filename}_unpruned_vocab_size", + ) + filtered_vocab_size_node = nodes.apply_operation( + analyzer_nodes.VocabularyCount, + filtered_value_node, + label=f"VocabularyCountFiltered[{scope}]", + ) + filtered_vocab_size = analyzer_nodes.bind_future_as_tensor( + filtered_vocab_size_node, + analyzer_nodes.TensorInfo(tf.int64, [], None), + name=f"{vocab_filename}_pruned_vocab_size", ) - vocab_filename_node = nodes.apply_operation( - analyzer_nodes.VocabularyOrderAndWrite, - filtered_value_node, - *extra_apply_order_and_write_op_args, - vocab_filename=vocab_filename, - store_frequency=store_frequency, - fingerprint_shuffle=fingerprint_shuffle, - input_dtype=input_dtype, - file_format=file_format, - # LINT.IfChange(input_is_sorted) - input_is_sorted=( - top_k is not None and key_fn is None and not fingerprint_shuffle - ), - # LINT.ThenChange(beam/analyzer_impls.py:top_k_impl) - ) - - scope = tf.compat.v1.get_default_graph().get_name_scope() - unfiltered_vocab_size_node = nodes.apply_operation( - analyzer_nodes.VocabularyCount, - merge_output_value_node, - label=f'VocabularyCountUnfiltered[{scope}]') - unfiltered_vocab_size = 
analyzer_nodes.bind_future_as_tensor( - unfiltered_vocab_size_node, - analyzer_nodes.TensorInfo(tf.int64, [], None), - name=f'{vocab_filename}_unpruned_vocab_size') - filtered_vocab_size_node = nodes.apply_operation( - analyzer_nodes.VocabularyCount, - filtered_value_node, - label=f'VocabularyCountFiltered[{scope}]') - filtered_vocab_size = analyzer_nodes.bind_future_as_tensor( - filtered_vocab_size_node, - analyzer_nodes.TensorInfo(tf.int64, [], None), - name=f'{vocab_filename}_pruned_vocab_size') - - unfiltered_vocab_size += reserved_tokens_size - filtered_vocab_size += reserved_tokens_size - _maybe_annotate_vocab_metadata(vocab_filename, unfiltered_vocab_size, - filtered_vocab_size) - - register_vocab( - vocab_filename, - vocabulary_size=filtered_vocab_size, - vocabulary_key=vocabulary_key, - file_format=file_format) - return analyzer_nodes.wrap_as_tensor(vocab_filename_node) + unfiltered_vocab_size += reserved_tokens_size + filtered_vocab_size += reserved_tokens_size + _maybe_annotate_vocab_metadata( + vocab_filename, unfiltered_vocab_size, filtered_vocab_size + ) + + register_vocab( + vocab_filename, + vocabulary_size=filtered_vocab_size, + vocabulary_key=vocabulary_key, + file_format=file_format, + ) + return analyzer_nodes.wrap_as_tensor(vocab_filename_node) def calculate_recommended_min_diff_from_avg(dataset_size: int) -> int: - """Calculates a recommended min_diff_from_avg argument to tft.vocabulary. - - Computes a default min_diff_from_average parameter based on the size of the - dataset. The MI (or AMI) of a token x label will be pushed to zero whenever - the difference between the observed and the expected (average) cooccurrence - with the label is < min_diff_from_average. This can be thought of as a - regularization parameter for mutual information based vocabularies. - - Args: - dataset_size: The number of recods in the dataset. The bigger the dataset, - the higher the min_diff_from_average will be. - - Returns: - An integer that is recomended to use as the min_diff_from_avg parameter of - `vocabulary`. - """ - # The minimum and maximum min_diff_from_avg parameter to use. - min_value, max_value = 2, 25 - # Heuristics for a "small" and "large" dataset. The selected parameter will - # be between min_value and max_value depending on where the dataset_size falls - # relative to these values. - small_dataset_size, large_dataset_size = 10000, 1000000 - return int( - builtin_min( - max_value, - builtin_max(min_value, (dataset_size - small_dataset_size) / - (large_dataset_size - small_dataset_size) * - (max_value - min_value) + min_value))) + """Calculates a recommended min_diff_from_avg argument to tft.vocabulary. + + Computes a default min_diff_from_average parameter based on the size of the + dataset. The MI (or AMI) of a token x label will be pushed to zero whenever + the difference between the observed and the expected (average) cooccurrence + with the label is < min_diff_from_average. This can be thought of as a + regularization parameter for mutual information based vocabularies. + + Args: + ---- + dataset_size: The number of recods in the dataset. The bigger the dataset, + the higher the min_diff_from_average will be. + + Returns: + ------- + An integer that is recomended to use as the min_diff_from_avg parameter of + `vocabulary`. + """ + # The minimum and maximum min_diff_from_avg parameter to use. + min_value, max_value = 2, 25 + # Heuristics for a "small" and "large" dataset. 
The selected parameter will + # be between min_value and max_value depending on where the dataset_size falls + # relative to these values. + small_dataset_size, large_dataset_size = 10000, 1000000 + return int( + builtin_min( + max_value, + builtin_max( + min_value, + (dataset_size - small_dataset_size) + / (large_dataset_size - small_dataset_size) + * (max_value - min_value) + + min_value, + ), + ) + ) # Code related to this class is performance sensitive, so (micro-)benchmarks # should be run when it is updated. class QuantilesCombiner(analyzer_nodes.Combiner): - """Computes quantiles on the PCollection. - - This implementation is based on go/squawd. - For additional details on the algorithm, such as streaming and summary, - see also http://web.cs.ucla.edu/~weiwang/paper/SSDBM07_2.pdf - """ - - def __init__(self, - num_quantiles, - epsilon, - bucket_numpy_dtype, - has_weights=False, - output_shape=None, - include_max_and_min=False, - feature_shape=None): - self._num_quantiles = num_quantiles - self._epsilon = epsilon - # Expected upper bound on the total number of input elements per feature. - # Theoretical error bound is guaranteed to be <= epsilon as long as the - # number of input elements is <= max_num_values. - self._max_num_values = 1 << 32 - self._bucket_numpy_dtype = bucket_numpy_dtype - self._has_weights = has_weights - self._include_max_and_min = include_max_and_min - num_outputs = (num_quantiles + - 1) if include_max_and_min else (num_quantiles - 1) - if feature_shape is None: - feature_shape = [] - elif isinstance(feature_shape, int): - feature_shape = [feature_shape] - if output_shape is None: - self._output_shape = list(feature_shape) + [num_outputs] - else: - self._output_shape = output_shape - self._num_features = np.prod(feature_shape, dtype=np.int64).item() - - def create_accumulator(self): - return sketches.QuantilesSketch(self._epsilon, self._max_num_values, - self._num_features) - - def add_input(self, accumulator, next_input): - # Flattened input array will be split on inputs for each feature. - # C-contiguous order of flattened array is required. - flat_values = pa.array(np.ravel(next_input[0])) - if self._has_weights: - flat_weights = pa.array(np.ravel(next_input[1])) - accumulator.AddValues(flat_values, flat_weights) - else: - accumulator.AddValues(flat_values) - return accumulator - - def merge_accumulators(self, accumulators): - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.Merge(accumulator) - return result - - def compact(self, accumulator): - accumulator.Compact() - return accumulator - - def extract_output(self, accumulator): - result = accumulator.GetQuantiles(self._num_quantiles).to_pylist() - if not result: - return [np.zeros(self._output_shape, self._bucket_numpy_dtype)] - result = np.array(result, self._bucket_numpy_dtype) - # Trim elementwise results if max and min should be excluded. - if not self._include_max_and_min: - result = result[:, 1:-1] - return [np.reshape(result, self._output_shape)] - - def output_tensor_infos(self): - return [ - analyzer_nodes.TensorInfo( - tf.as_dtype(self._bucket_numpy_dtype), self._output_shape, None) - ] - - @property - def accumulator_coder(self): - return _QuantilesSketchCacheCoder() + """Computes quantiles on the PCollection. + + This implementation is based on go/squawd. 
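For reference, the clamped linear interpolation implemented by `calculate_recommended_min_diff_from_avg` above can be sketched standalone; the helper name and sample sizes below are illustrative, and only the bounds (2, 25) and breakpoints (10,000 and 1,000,000 records) come from the code in this diff.

def recommended_min_diff_from_avg(dataset_size: int) -> int:
    # Linearly interpolate between 2 and 25 as dataset_size moves from
    # 10,000 to 1,000,000 records, clamping at both ends.
    min_value, max_value = 2, 25
    small, large = 10_000, 1_000_000
    fraction = (dataset_size - small) / (large - small)
    return int(min(max_value,
                   max(min_value, fraction * (max_value - min_value) + min_value)))

assert recommended_min_diff_from_avg(5_000) == 2        # clamped at the lower bound
assert recommended_min_diff_from_avg(505_000) == 13     # roughly halfway between 2 and 25
assert recommended_min_diff_from_avg(10_000_000) == 25  # clamped at the upper bound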
+ For additional details on the algorithm, such as streaming and summary, + see also http://web.cs.ucla.edu/~weiwang/paper/SSDBM07_2.pdf + """ + + def __init__( + self, + num_quantiles, + epsilon, + bucket_numpy_dtype, + has_weights=False, + output_shape=None, + include_max_and_min=False, + feature_shape=None, + ): + self._num_quantiles = num_quantiles + self._epsilon = epsilon + # Expected upper bound on the total number of input elements per feature. + # Theoretical error bound is guaranteed to be <= epsilon as long as the + # number of input elements is <= max_num_values. + self._max_num_values = 1 << 32 + self._bucket_numpy_dtype = bucket_numpy_dtype + self._has_weights = has_weights + self._include_max_and_min = include_max_and_min + num_outputs = ( + (num_quantiles + 1) if include_max_and_min else (num_quantiles - 1) + ) + if feature_shape is None: + feature_shape = [] + elif isinstance(feature_shape, int): + feature_shape = [feature_shape] + if output_shape is None: + self._output_shape = list(feature_shape) + [num_outputs] + else: + self._output_shape = output_shape + self._num_features = np.prod(feature_shape, dtype=np.int64).item() + + def create_accumulator(self): + return sketches.QuantilesSketch( + self._epsilon, self._max_num_values, self._num_features + ) + + def add_input(self, accumulator, next_input): + # Flattened input array will be split on inputs for each feature. + # C-contiguous order of flattened array is required. + flat_values = pa.array(np.ravel(next_input[0])) + if self._has_weights: + flat_weights = pa.array(np.ravel(next_input[1])) + accumulator.AddValues(flat_values, flat_weights) + else: + accumulator.AddValues(flat_values) + return accumulator + + def merge_accumulators(self, accumulators): + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.Merge(accumulator) + return result + + def compact(self, accumulator): + accumulator.Compact() + return accumulator + + def extract_output(self, accumulator): + result = accumulator.GetQuantiles(self._num_quantiles).to_pylist() + if not result: + return [np.zeros(self._output_shape, self._bucket_numpy_dtype)] + result = np.array(result, self._bucket_numpy_dtype) + # Trim elementwise results if max and min should be excluded. + if not self._include_max_and_min: + result = result[:, 1:-1] + return [np.reshape(result, self._output_shape)] + + def output_tensor_infos(self): + return [ + analyzer_nodes.TensorInfo( + tf.as_dtype(self._bucket_numpy_dtype), self._output_shape, None + ) + ] + + @property + def accumulator_coder(self): + return _QuantilesSketchCacheCoder() class _QuantilesSketchCacheCoder(analyzer_nodes.CacheCoder): - """Cache coder for the quantiles accumulator.""" + """Cache coder for the quantiles accumulator.""" - def encode_cache(self, accumulator): - return pickle.dumps(accumulator) + def encode_cache(self, accumulator): + return pickle.dumps(accumulator) - def decode_cache(self, encoded_accumulator): - return pickle.loads(encoded_accumulator) + def decode_cache(self, encoded_accumulator): + return pickle.loads(encoded_accumulator) @common.log_api_use(common.ANALYZER_COLLECTION) -def quantiles(x: tf.Tensor, - num_buckets: int, - epsilon: float, - weights: Optional[tf.Tensor] = None, - reduce_instance_dims: bool = True, - name: Optional[str] = None) -> tf.Tensor: - """Computes the quantile boundaries of a `Tensor` over the whole dataset. 
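As a usage note for `QuantilesCombiner` above, here is a minimal sketch of the combiner lifecycle (create an accumulator, add inputs per batch, merge, extract), assuming numpy, pyarrow and tfx_bsl, which this module already imports, are installed; the batch contents are illustrative.

import numpy as np
from tensorflow_transform import analyzers

combiner = analyzers.QuantilesCombiner(
    num_quantiles=4, epsilon=0.01, bucket_numpy_dtype=np.float32)

# One accumulator per worker; each batch arrives as a tuple of ndarrays.
acc_a = combiner.add_input(combiner.create_accumulator(), (np.arange(0.0, 50.0),))
acc_b = combiner.add_input(combiner.create_accumulator(), (np.arange(50.0, 100.0),))

merged = combiner.merge_accumulators([acc_a, acc_b])
(boundaries,) = combiner.extract_output(merged)
print(boundaries)  # roughly [25., 50., 75.]: 3 boundaries for 4 buckets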
- - Quantile boundaries are computed using approximate quantiles, - and error tolerance is specified using `epsilon`. The boundaries divide the - input tensor into approximately equal `num_buckets` parts. - See go/squawd for details, and how to control the error due to approximation. - NaN input values and values with NaN weights are ignored. - - Args: - x: An input `Tensor`. - num_buckets: Values in the `x` are divided into approximately equal-sized - buckets, where the number of buckets is `num_buckets`. The number of - returned quantiles is `num_buckets` - 1. - epsilon: Error tolerance, typically a small fraction close to zero (e.g. - 0.01). Higher values of epsilon increase the quantile approximation, and - hence result in more unequal buckets, but could improve performance, - and resource consumption. Some measured results on memory consumption: - For epsilon = 0.001, the amount of memory for each buffer to hold the - summary for 1 trillion input values is ~25000 bytes. If epsilon is - relaxed to 0.01, the buffer size drops to ~2000 bytes for the same input - size. The buffer size also determines the amount of work in the - different stages of the beam pipeline, in general, larger epsilon - results in fewer and smaller stages, and less time. For more performance - trade-offs see also http://web.cs.ucla.edu/~weiwang/paper/SSDBM07_2.pdf - weights: (Optional) Weights tensor for the quantiles. Tensor must have the - same batch size as x. - reduce_instance_dims: By default collapses the batch and instance dimensions - to arrive at a single output vector. If False, only collapses the batch - dimension and outputs a vector of the same shape as the input. - name: (Optional) A name for this operation. - - Returns: - The bucket boundaries represented as a list, with num_bucket-1 elements, - unless reduce_instance_dims is False, which results in a Tensor of - shape x.shape + [num_bucket-1]. - See code below for discussion on the type of bucket boundaries. - """ - # Quantile ops convert input values to double under the hood. Keep bucket - # boundaries as float for all numeric types. - bucket_dtype = tf.float32 - with tf.compat.v1.name_scope(name, 'quantiles'): - if weights is None: - analyzer_inputs = [x] - has_weights = False - else: - analyzer_inputs = [x, weights] - has_weights = True - feature_shape = [] if reduce_instance_dims else x.get_shape().as_list()[1:] - output_shape = (feature_shape if feature_shape else [1]) + [num_buckets - 1] - combiner = QuantilesCombiner( - num_buckets, - epsilon, - bucket_dtype.as_numpy_dtype, - has_weights=has_weights, - output_shape=output_shape, - feature_shape=feature_shape) - (quantile_boundaries,) = _apply_cacheable_combiner(combiner, - *analyzer_inputs) - return quantile_boundaries +def quantiles( + x: tf.Tensor, + num_buckets: int, + epsilon: float, + weights: Optional[tf.Tensor] = None, + reduce_instance_dims: bool = True, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes the quantile boundaries of a `Tensor` over the whole dataset. + + Quantile boundaries are computed using approximate quantiles, + and error tolerance is specified using `epsilon`. The boundaries divide the + input tensor into approximately equal `num_buckets` parts. + See go/squawd for details, and how to control the error due to approximation. + NaN input values and values with NaN weights are ignored. + + Args: + ---- + x: An input `Tensor`. + num_buckets: Values in the `x` are divided into approximately equal-sized + buckets, where the number of buckets is `num_buckets`. 
The number of + returned quantiles is `num_buckets` - 1. + epsilon: Error tolerance, typically a small fraction close to zero (e.g. + 0.01). Higher values of epsilon increase the quantile approximation, and + hence result in more unequal buckets, but could improve performance, + and resource consumption. Some measured results on memory consumption: + For epsilon = 0.001, the amount of memory for each buffer to hold the + summary for 1 trillion input values is ~25000 bytes. If epsilon is + relaxed to 0.01, the buffer size drops to ~2000 bytes for the same input + size. The buffer size also determines the amount of work in the + different stages of the beam pipeline, in general, larger epsilon + results in fewer and smaller stages, and less time. For more performance + trade-offs see also http://web.cs.ucla.edu/~weiwang/paper/SSDBM07_2.pdf + weights: (Optional) Weights tensor for the quantiles. Tensor must have the + same batch size as x. + reduce_instance_dims: By default collapses the batch and instance dimensions + to arrive at a single output vector. If False, only collapses the batch + dimension and outputs a vector of the same shape as the input. + name: (Optional) A name for this operation. + + Returns: + ------- + The bucket boundaries represented as a list, with num_bucket-1 elements, + unless reduce_instance_dims is False, which results in a Tensor of + shape x.shape + [num_bucket-1]. + See code below for discussion on the type of bucket boundaries. + """ + # Quantile ops convert input values to double under the hood. Keep bucket + # boundaries as float for all numeric types. + bucket_dtype = tf.float32 + with tf.compat.v1.name_scope(name, "quantiles"): + if weights is None: + analyzer_inputs = [x] + has_weights = False + else: + analyzer_inputs = [x, weights] + has_weights = True + feature_shape = [] if reduce_instance_dims else x.get_shape().as_list()[1:] + output_shape = (feature_shape if feature_shape else [1]) + [num_buckets - 1] + combiner = QuantilesCombiner( + num_buckets, + epsilon, + bucket_dtype.as_numpy_dtype, + has_weights=has_weights, + output_shape=output_shape, + feature_shape=feature_shape, + ) + (quantile_boundaries,) = _apply_cacheable_combiner(combiner, *analyzer_inputs) + return quantile_boundaries def _quantiles_per_key( @@ -2299,403 +2556,438 @@ def _quantiles_per_key( num_buckets: int, epsilon: float, weights: Optional[tf.Tensor] = None, - name: Optional[str] = None + name: Optional[str] = None, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, int]: - """Like quantiles but per-key. - - For private use in tf.Transform implementation only. - - Args: - x: An input `Tensor`. - key: An input `Tensor` with rank 1 and size same as the fist dimension of - `x`. All values of `x` will be aggregated according to the corresponding - value of `key`. - num_buckets: See `quantiles`. - epsilon: See `quantiles`. - weights: See `quantiles`. - name: (Optional) A name for this operation. - - Returns: - A 4-tuple of (boundaries, scale, shift, num_buckets). - The returned boundaries is a 1-d Tensor of size: - ((num_buckets - 2) * num_keys) + 1 - - And the returned scale and shift 1-d Tensors can be used to transform a - value before applying bucketization and shift the resulting bucket. 
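For context on the public API reformatted above, a hedged sketch of how `tft.quantiles` is typically paired with `tft.apply_buckets` inside a `preprocessing_fn`; the feature name and bucket count are illustrative, not taken from this diff.

import tensorflow_transform as tft

def preprocessing_fn(inputs):
    # "income" is a hypothetical dense float feature.
    boundaries = tft.quantiles(inputs["income"], num_buckets=10, epsilon=0.01)
    # `boundaries` holds 9 cut points computed over the whole dataset;
    # apply_buckets maps each value to its bucket index.
    return {"income_bucket": tft.apply_buckets(inputs["income"], boundaries)}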
- So the transformation of each input x before computing its bucket should be: - F(x, key) = x * scale_factor_per_key[key] + shift_per_key[key] - - For example, if there are 2 keys, and the following boundaries are computed - for them: [[0, 1, 2], [0, 1, 2]], this will return: - boundaries: [0, 0.5, 1, 1.5, 2] - scale_factor_per_key: [0.5, 0.5] - shift_per_key: [0, 1] - num_buckets: 4 - - Raises: - ValueError: If key has wrong dtype. - """ - if key.dtype != tf.string: - raise ValueError('key must have type tf.string') - # Quantile ops convert input values to double under the hood. Keep bucket - # boundaries as float for all numeric types. - bucket_dtype = tf.float32 - with tf.compat.v1.name_scope(name, 'quantiles_by_key'): - combiner = QuantilesCombiner( - num_buckets, - epsilon, - bucket_dtype.as_numpy_dtype, - has_weights=weights is not None, - output_shape=(num_buckets - 1,)) - - input_values_node = analyzer_nodes.get_input_tensors_value_nodes(( - key, x) if weights is None else (key, x, weights)) - - accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombinePerKeyAccumulate, - input_values_node, - combiner=combiner) - - merge_output_value_node = nodes.apply_operation( - analyzer_nodes.CacheableCombinePerKeyMerge, - *accumulate_outputs_value_nodes, - combiner=combiner) + """Like quantiles but per-key. - key_value_node, bucket_boundaries = nodes.apply_multi_output_operation( - analyzer_nodes.CacheableCombinePerKeyFormatKeys, - merge_output_value_node, - combiner=combiner) + For private use in tf.Transform implementation only. - boundaries, scale_factor, shift, num_buckets_node = ( - nodes.apply_multi_output_operation( - analyzer_nodes.ScaleAndFlattenPerKeyBucketBouandaries, - bucket_boundaries, - output_tensor_dtype=bucket_dtype)) + Args: + ---- + x: An input `Tensor`. + key: An input `Tensor` with rank 1 and size same as the fist dimension of + `x`. All values of `x` will be aggregated according to the corresponding + value of `key`. + num_buckets: See `quantiles`. + epsilon: See `quantiles`. + weights: See `quantiles`. + name: (Optional) A name for this operation. - return tuple( - map(analyzer_nodes.wrap_as_tensor, - [key_value_node, boundaries, scale_factor, shift, num_buckets_node - ])) + Returns: + ------- + A 4-tuple of (boundaries, scale, shift, num_buckets). + The returned boundaries is a 1-d Tensor of size: + ((num_buckets - 2) * num_keys) + 1 + + And the returned scale and shift 1-d Tensors can be used to transform a + value before applying bucketization and shift the resulting bucket. + So the transformation of each input x before computing its bucket should be: + F(x, key) = x * scale_factor_per_key[key] + shift_per_key[key] + + For example, if there are 2 keys, and the following boundaries are computed + for them: [[0, 1, 2], [0, 1, 2]], this will return: + boundaries: [0, 0.5, 1, 1.5, 2] + scale_factor_per_key: [0.5, 0.5] + shift_per_key: [0, 1] + num_buckets: 4 + + Raises: + ------ + ValueError: If key has wrong dtype. + """ + if key.dtype != tf.string: + raise ValueError("key must have type tf.string") + # Quantile ops convert input values to double under the hood. Keep bucket + # boundaries as float for all numeric types. 
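The worked example in the `_quantiles_per_key` docstring can be checked numerically. A small NumPy sketch, with all values taken from that docstring (the slice arithmetic below is specific to this example):

import numpy as np

per_key_boundaries = np.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])
flat_boundaries = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
scale_factor_per_key = np.array([0.5, 0.5])
shift_per_key = np.array([0.0, 1.0])

# F(x, key) = x * scale_factor_per_key[key] + shift_per_key[key] maps each
# key's boundaries onto a contiguous slice of the flattened boundary vector.
for key in (0, 1):
    transformed = (per_key_boundaries[key] * scale_factor_per_key[key]
                   + shift_per_key[key])
    # key 0 -> [0.0, 0.5, 1.0], key 1 -> [1.0, 1.5, 2.0]
    assert np.allclose(transformed, flat_boundaries[2 * key: 2 * key + 3])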
+ bucket_dtype = tf.float32 + with tf.compat.v1.name_scope(name, "quantiles_by_key"): + combiner = QuantilesCombiner( + num_buckets, + epsilon, + bucket_dtype.as_numpy_dtype, + has_weights=weights is not None, + output_shape=(num_buckets - 1,), + ) + input_values_node = analyzer_nodes.get_input_tensors_value_nodes( + (key, x) if weights is None else (key, x, weights) + ) -class CovarianceCombiner(analyzer_nodes.Combiner): - """Combines the PCollection to compute the biased covariance matrix.""" - - def __init__(self, output_shape, numpy_dtype=np.float64): - """Store the dtype and shape for np arrays/matrices for precision.""" - self._output_shape = output_shape - self._numpy_dtype = numpy_dtype - - def create_accumulator(self): - """Create an accumulator with all zero entries.""" - return [ - np.zeros((self._output_shape[0], self._output_shape[0]), - self._numpy_dtype), - np.zeros((self._output_shape[0],), self._numpy_dtype), - np.zeros((), self._numpy_dtype) - ] - - def add_input(self, accumulator, batch_values): - """Compute sum of input cross-terms, sum of inputs, and count. - - The cross terms for a numeric 1d array x are given by the set: - {z_ij = x_i * x_j for all indices i and j}. This is stored as a 2d array. - Since next_input is an array of 1d numeric arrays (i.e. a 2d array), - matmul(transpose(next_input), next_input) will automatically sum up - the cross terms of each 1d array in next_input. + accumulate_outputs_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombinePerKeyAccumulate, + input_values_node, + combiner=combiner, + ) - Args: - accumulator: running sum of cross terms, input vectors, and count - batch_values: entries from the pipeline, which must be single element list - containing a 2d array - representing multiple 1d arrays + merge_output_value_node = nodes.apply_operation( + analyzer_nodes.CacheableCombinePerKeyMerge, + *accumulate_outputs_value_nodes, + combiner=combiner, + ) - Returns: - An accumulator with next_input considered in its running list of - sum_product, sum_vectors, and count of input rows. - """ - # Expect a single input representing the batch for the input tensor. 
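The cross-terms claim in the `add_input` docstring is a standard identity, X^T X = sum_i x_i x_i^T; a quick NumPy check with arbitrary shapes:

import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(size=(5, 3))  # 5 instances, 3 features (illustrative sizes)

# A single matmul accumulates the sum of per-row outer products ("cross terms").
summed_outer_products = sum(np.outer(row, row) for row in batch)
assert np.allclose(batch.T @ batch, summed_outer_products)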
- batch_value, = batch_values + key_value_node, bucket_boundaries = nodes.apply_multi_output_operation( + analyzer_nodes.CacheableCombinePerKeyFormatKeys, + merge_output_value_node, + combiner=combiner, + ) - assert len(np.shape(batch_value)) == 2 + boundaries, scale_factor, shift, num_buckets_node = ( + nodes.apply_multi_output_operation( + analyzer_nodes.ScaleAndFlattenPerKeyBucketBouandaries, + bucket_boundaries, + output_tensor_dtype=bucket_dtype, + ) + ) - batch_cross_terms = np.matmul( - np.transpose(batch_value), - batch_value - ).astype(self._numpy_dtype) + return tuple( + map( + analyzer_nodes.wrap_as_tensor, + [key_value_node, boundaries, scale_factor, shift, num_buckets_node], + ) + ) - batch_sum = np.array(np.sum(batch_value, axis=0), self._numpy_dtype) - batch_count = np.shape(batch_value)[0] - sum_product, sum_vectors, count = accumulator +class CovarianceCombiner(analyzer_nodes.Combiner): + """Combines the PCollection to compute the biased covariance matrix.""" + + def __init__(self, output_shape, numpy_dtype=np.float64): + """Store the dtype and shape for np arrays/matrices for precision.""" + self._output_shape = output_shape + self._numpy_dtype = numpy_dtype + + def create_accumulator(self): + """Create an accumulator with all zero entries.""" + return [ + np.zeros((self._output_shape[0], self._output_shape[0]), self._numpy_dtype), + np.zeros((self._output_shape[0],), self._numpy_dtype), + np.zeros((), self._numpy_dtype), + ] + + def add_input(self, accumulator, batch_values): + """Compute sum of input cross-terms, sum of inputs, and count. + + The cross terms for a numeric 1d array x are given by the set: + {z_ij = x_i * x_j for all indices i and j}. This is stored as a 2d array. + Since next_input is an array of 1d numeric arrays (i.e. a 2d array), + matmul(transpose(next_input), next_input) will automatically sum up + the cross terms of each 1d array in next_input. + + Args: + ---- + accumulator: running sum of cross terms, input vectors, and count + batch_values: entries from the pipeline, which must be single element list + containing a 2d array + representing multiple 1d arrays + + Returns: + ------- + An accumulator with next_input considered in its running list of + sum_product, sum_vectors, and count of input rows. + """ + # Expect a single input representing the batch for the input tensor. + (batch_value,) = batch_values + + assert len(np.shape(batch_value)) == 2 + + batch_cross_terms = np.matmul(np.transpose(batch_value), batch_value).astype( + self._numpy_dtype + ) - return [ - sum_product + batch_cross_terms, sum_vectors + batch_sum, - count + batch_count - ] + batch_sum = np.array(np.sum(batch_value, axis=0), self._numpy_dtype) + batch_count = np.shape(batch_value)[0] + + sum_product, sum_vectors, count = accumulator + + return [ + sum_product + batch_cross_terms, + sum_vectors + batch_sum, + count + batch_count, + ] + + def merge_accumulators(self, accumulators): + """Sums values in each accumulator entry.""" + products, vectors, counts = zip(*accumulators) + return [ + np.sum(products, axis=0), + np.sum(vectors, axis=0), + np.sum(counts, axis=0), + ] + + def extract_output(self, accumulator): + """Run covariance logic on sum_product, sum of input vectors, and count. + + The formula used to compute the covariance is cov(x) = E(xx^T) - uu^T, + where x is the original input to the combiner, and u = mean(x). + E(xx^T) is computed by dividing sum of cross terms (index 0) by count + (index 2). 
u is computed by taking the sum of rows (index 1) and dividing by + the count (index 2). + + Args: + ---- + accumulator: final accumulator as a list of the sum of cross-terms matrix, + sum of input vectors, and count. + + Returns: + ------- + A list containing a single 2d ndarray, the covariance matrix. + """ + sum_product, sum_vectors, count = accumulator + if count == 0: + return [np.zeros(self._output_shape, self._numpy_dtype)] + expected_cross_terms = sum_product / count + expected_terms = sum_vectors / count + return [ + np.ndarray.astype( # TODO(b/64987151): # pytype: disable=attribute-error + expected_cross_terms - np.outer(expected_terms, expected_terms), + self._numpy_dtype, + ) + ] + + def output_tensor_infos(self): + return [ + analyzer_nodes.TensorInfo( + tf.as_dtype(self._numpy_dtype), self._output_shape, None + ) + ] + + @property + def accumulator_coder(self): + # Needed since NumPy 1.24 no longer automatically infers dtype=object when + # ragged sequences are passed to np.array(). + return analyzer_nodes.JsonNumpyCacheCoder(np_dtype=object) - def merge_accumulators(self, accumulators): - """Sums values in each accumulator entry.""" - products, vectors, counts = zip(*accumulators) - return [ - np.sum(products, axis=0), - np.sum(vectors, axis=0), - np.sum(counts, axis=0) - ] - def extract_output(self, accumulator): - """Run covariance logic on sum_product, sum of input vectors, and count. +@common.log_api_use(common.ANALYZER_COLLECTION) +def covariance(x: tf.Tensor, dtype: tf.DType, name: Optional[str] = None) -> tf.Tensor: + """Computes the covariance matrix over the whole dataset. - The formula used to compute the covariance is cov(x) = E(xx^T) - uu^T, - where x is the original input to the combiner, and u = mean(x). - E(xx^T) is computed by dividing sum of cross terms (index 0) by count - (index 2). u is computed by taking the sum of rows (index 1) and dividing by - the count (index 2). + The covariance matrix M is defined as follows: + Let x[:j] be a tensor of the jth element of all input vectors in x, and let + u_j = mean(x[:j]). The entry M[i,j] = E[(x[:i] - u_i)(x[:j] - u_j)]. + Notice that the diagonal entries correspond to variances of individual + elements in the vector, i.e. M[i,i] corresponds to the variance of x[:i]. Args: - accumulator: final accumulator as a list of the sum of cross-terms matrix, - sum of input vectors, and count. + ---- + x: A rank-2 `Tensor`, 0th dim are rows, 1st dim are indices in each input + vector. + dtype: Tensorflow dtype of entries in the returned matrix. + name: (Optional) A name for this operation. + + Raises: + ------ + ValueError: if input is not a rank-2 Tensor. Returns: - A list containing a single 2d ndarray, the covariance matrix. 
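The covariance identity described in `extract_output` can be sanity-checked against NumPy's own biased estimator; the data sizes below are arbitrary.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4))  # 100 rows of 4-dimensional vectors

expected_cross_terms = (x.T @ x) / len(x)   # E[x x^T]: cross terms / count
mean = x.mean(axis=0)                       # u: sum of rows / count
biased_cov = expected_cross_terms - np.outer(mean, mean)

# np.cov with bias=True uses the same 1/N normalization.
assert np.allclose(biased_cov, np.cov(x, rowvar=False, bias=True))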
+ ------- + A rank-2 (matrix) covariance `Tensor` """ + if not isinstance(x, tf.Tensor): + raise TypeError("Expected a Tensor, but got %r" % x) - sum_product, sum_vectors, count = accumulator - if count == 0: - return [np.zeros(self._output_shape, self._numpy_dtype)] - expected_cross_terms = sum_product / count - expected_terms = sum_vectors / count - return [ - np.ndarray.astype( # TODO(b/64987151): # pytype: disable=attribute-error - expected_cross_terms - np.outer(expected_terms, expected_terms), - self._numpy_dtype) - ] - - def output_tensor_infos(self): - return [ - analyzer_nodes.TensorInfo( - tf.as_dtype(self._numpy_dtype), self._output_shape, None) - ] - - @property - def accumulator_coder(self): - # Needed since NumPy 1.24 no longer automatically infers dtype=object when - # ragged sequences are passed to np.array(). - return analyzer_nodes.JsonNumpyCacheCoder(np_dtype=object) + with tf.compat.v1.name_scope(name, "covariance"): + x.shape.assert_has_rank(2) + input_dim = x.shape.as_list()[1] + shape = (input_dim, input_dim) -@common.log_api_use(common.ANALYZER_COLLECTION) -def covariance(x: tf.Tensor, - dtype: tf.DType, - name: Optional[str] = None) -> tf.Tensor: - """Computes the covariance matrix over the whole dataset. + (result,) = _apply_cacheable_combiner( + CovarianceCombiner(shape, dtype.as_numpy_dtype), x + ) + return result - The covariance matrix M is defined as follows: - Let x[:j] be a tensor of the jth element of all input vectors in x, and let - u_j = mean(x[:j]). The entry M[i,j] = E[(x[:i] - u_i)(x[:j] - u_j)]. - Notice that the diagonal entries correspond to variances of individual - elements in the vector, i.e. M[i,i] corresponds to the variance of x[:i]. - Args: - x: A rank-2 `Tensor`, 0th dim are rows, 1st dim are indices in each input - vector. - dtype: Tensorflow dtype of entries in the returned matrix. - name: (Optional) A name for this operation. +class PCACombiner(CovarianceCombiner): + """Compute PCA of accumulated data using the biased covariance matrix.""" + + def __init__(self, output_shape, output_dim=None, numpy_dtype=np.float64): + """Store pca output dimension, shape and dtype for precision.""" + super().__init__(output_shape, numpy_dtype=numpy_dtype) + self._output_dim = output_dim + + def extract_output(self, accumulator): + """Compute PCA of the accumulated data using the biased covariance matrix. + + Following the covariance computation in CovarianceCombiner, this method runs + eigenvalue decomposition on the covariance matrix, sorts eigenvalues in + decreasing order, and returns the first output_dim corresponding + eigenvectors (principal components) as a matrix. + + Args: + ---- + accumulator: final accumulator as a list of the sum of cross-terms matrix, + sum of input vectors, and count. + + Returns: + ------- + A list containing a matrix of shape (input_dim, output_dim). + """ + sum_product, sum_vectors, count = accumulator + if count == 0: + # In this case all eigenvalues==0 and we output (possibly truncated) basis + # vectors. Note that if _output_dim is None, then M is set to N in np.eye. 
+ return [ + np.eye( + N=self._output_shape[0], M=self._output_dim, dtype=self._numpy_dtype + ) + ] + expected_cross_terms = sum_product / count + expected_terms = sum_vectors / count + cov = np.ndarray.astype( # TODO(b/64987151): # pytype: disable=attribute-error + expected_cross_terms - np.outer(expected_terms, expected_terms), + self._numpy_dtype, + ) + vals, vecs = np.linalg.eigh(cov) + sorted_vecs = vecs[:, np.argsort(vals)[::-1]] + if self._output_dim is None: + return [sorted_vecs] + else: + return [sorted_vecs[:, : self._output_dim]] - Raises: - ValueError: if input is not a rank-2 Tensor. - Returns: - A rank-2 (matrix) covariance `Tensor` - """ +@common.log_api_use(common.ANALYZER_COLLECTION) +def pca( + x: tf.Tensor, output_dim: int, dtype: tf.DType, name: Optional[str] = None +) -> tf.Tensor: + """Computes PCA on the dataset using biased covariance. + + The PCA analyzer computes output_dim orthonormal vectors that capture + directions/axes corresponding to the highest variances in the input vectors of + `x`. The output vectors are returned as a rank-2 tensor with shape + `(input_dim, output_dim)`, where the 0th dimension are the components of each + output vector, and the 1st dimension are the output vectors representing + orthogonal directions in the input space, sorted in order of decreasing + variances. + + The output rank-2 tensor (matrix) serves a useful transform purpose. Formally, + the matrix can be used downstream in the transform step by multiplying it to + the input tensor `x`. This transform reduces the dimension of input vectors to + output_dim in a way that retains the maximal variance. + + NOTE: To properly use PCA, input vector components should be converted to + similar units of measurement such that the vectors represent a Euclidean + space. If no such conversion is available (e.g. one element represents time, + another element distance), the canonical approach is to first apply a + transformation to the input data to normalize numerical variances, i.e. + `tft.scale_to_z_score()`. Normalization allows PCA to choose output axes that + help decorrelate input axes. + + Below are a couple intuitive examples of PCA. + + Consider a simple 2-dimensional example: + + Input x is a series of vectors `[e, e]` where `e` is Gaussian with mean 0, + variance 1. The two components are perfectly correlated, and the resulting + covariance matrix is + + ``` + [[1 1], + [1 1]]. + ``` + + Applying PCA with `output_dim = 1` would discover the first principal + component `[1 / sqrt(2), 1 / sqrt(2)]`. When multipled to the original + example, each vector `[e, e]` would be mapped to a scalar `sqrt(2) * e`. The + second principal component would be `[-1 / sqrt(2), 1 / sqrt(2)]` and would + map `[e, e]` to 0, which indicates that the second component captures no + variance at all. This agrees with our intuition since we know that the two + axes in the input are perfectly correlated and can be fully explained by a + single scalar `e`. + + Consider a 3-dimensional example: + + Input `x` is a series of vectors `[a, a, b]`, where `a` is a zero-mean, unit + variance Gaussian and `b` is a zero-mean, variance 4 Gaussian and is + independent of `a`. The first principal component of the unnormalized vector + would be `[0, 0, 1]` since `b` has a much larger variance than any linear + combination of the first two components. This would map `[a, a, b]` onto `b`, + asserting that the axis with highest energy is the third component. 
While this + may be the desired output if `a` and `b` correspond to the same units, it is + not statistically desireable when the units are irreconciliable. In such a + case, one should first normalize each component to unit variance first, i.e. + `b := b / 2`. The first principal component of a normalized vector would yield + `[1 / sqrt(2), 1 / sqrt(2), 0]`, and would map `[a, a, b]` to `sqrt(2) * a`. + The second component would be `[0, 0, 1]` and map `[a, a, b]` to `b`. As can + be seen, the benefit of normalization is that PCA would capture highly + correlated components first and collapse them into a lower dimension. - if not isinstance(x, tf.Tensor): - raise TypeError('Expected a Tensor, but got %r' % x) + Args: + ---- + x: A rank-2 `Tensor`, 0th dim are rows, 1st dim are indices in row vectors. + output_dim: The PCA output dimension (number of eigenvectors to return). + dtype: Tensorflow dtype of entries in the returned matrix. + name: (Optional) A name for this operation. - with tf.compat.v1.name_scope(name, 'covariance'): - x.shape.assert_has_rank(2) + Raises: + ------ + ValueError: if input is not a rank-2 Tensor. - input_dim = x.shape.as_list()[1] - shape = (input_dim, input_dim) + Returns: + ------- + A 2D `Tensor` (matrix) M of shape (input_dim, output_dim). + """ + if not isinstance(x, tf.Tensor): + raise TypeError("Expected a Tensor, but got %r" % x) - (result,) = _apply_cacheable_combiner( - CovarianceCombiner(shape, dtype.as_numpy_dtype), x) - return result + with tf.compat.v1.name_scope(name, "pca"): + x.shape.assert_has_rank(2) + input_dim = x.shape.as_list()[1] + shape = (input_dim, output_dim) -class PCACombiner(CovarianceCombiner): - """Compute PCA of accumulated data using the biased covariance matrix.""" + (result,) = _apply_cacheable_combiner( + PCACombiner(shape, output_dim, dtype.as_numpy_dtype), x + ) + return result - def __init__(self, output_shape, output_dim=None, numpy_dtype=np.float64): - """Store pca output dimension, shape and dtype for precision.""" - super().__init__(output_shape, numpy_dtype=numpy_dtype) - self._output_dim = output_dim - def extract_output(self, accumulator): - """Compute PCA of the accumulated data using the biased covariance matrix. +def _maybe_annotate_vocab_metadata( + vocab_filename: str, + unfiltered_vocabulary_size: tf.Tensor, + filtered_vocabulary_size: tf.Tensor, +): + """Annotates a bucketized tensor with the boundaries that were applied. - Following the covariance computation in CovarianceCombiner, this method runs - eigenvalue decomposition on the covariance matrix, sorts eigenvalues in - decreasing order, and returns the first output_dim corresponding - eigenvectors (principal components) as a matrix. + Creates a deferred annotation for the specified tensor. Args: - accumulator: final accumulator as a list of the sum of cross-terms matrix, - sum of input vectors, and count. - - Returns: - A list containing a matrix of shape (input_dim, output_dim). + ---- + vocab_filename: The name of the vocabulary. + unfiltered_vocabulary_size: A tf.int64 tensor containing the unfiltered + vocab size. + filtered_vocabulary_size: A tf.int64 tensor containing the filtered vocab + size. """ - sum_product, sum_vectors, count = accumulator - if count == 0: - # In this case all eigenvalues==0 and we output (possibly truncated) basis - # vectors. Note that if _output_dim is None, then M is set to N in np.eye. 
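To make the PCA computation concrete, here is a NumPy-only sketch of what `PCACombiner.extract_output` does once the biased covariance is available: eigendecompose, order the eigenvectors by decreasing eigenvalue, and keep the first `output_dim` columns. The synthetic data below is illustrative.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 3)) @ np.diag([3.0, 1.0, 0.1])  # anisotropic data

cov = np.cov(x, rowvar=False, bias=True)
vals, vecs = np.linalg.eigh(cov)
sorted_vecs = vecs[:, np.argsort(vals)[::-1]]  # highest-variance axes first

output_dim = 2
components = sorted_vecs[:, :output_dim]  # shape (input_dim, output_dim)
projected = x @ components                # the downstream transform step
print(components.shape, projected.shape)  # (3, 2) (200, 2)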
- return [np.eye(N=self._output_shape[0], M=self._output_dim, - dtype=self._numpy_dtype)] - expected_cross_terms = sum_product / count - expected_terms = sum_vectors / count - cov = np.ndarray.astype( # TODO(b/64987151): # pytype: disable=attribute-error - expected_cross_terms - np.outer(expected_terms, expected_terms), - self._numpy_dtype) - vals, vecs = np.linalg.eigh(cov) - sorted_vecs = vecs[:, np.argsort(vals)[::-1]] - if self._output_dim is None: - return [sorted_vecs] - else: - return [sorted_vecs[:, :self._output_dim]] + if not common.IS_ANNOTATIONS_PB_AVAILABLE: + return + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) -@common.log_api_use(common.ANALYZER_COLLECTION) -def pca(x: tf.Tensor, - output_dim: int, - dtype: tf.DType, - name: Optional[str] = None) -> tf.Tensor: - """Computes PCA on the dataset using biased covariance. - - The PCA analyzer computes output_dim orthonormal vectors that capture - directions/axes corresponding to the highest variances in the input vectors of - `x`. The output vectors are returned as a rank-2 tensor with shape - `(input_dim, output_dim)`, where the 0th dimension are the components of each - output vector, and the 1st dimension are the output vectors representing - orthogonal directions in the input space, sorted in order of decreasing - variances. - - The output rank-2 tensor (matrix) serves a useful transform purpose. Formally, - the matrix can be used downstream in the transform step by multiplying it to - the input tensor `x`. This transform reduces the dimension of input vectors to - output_dim in a way that retains the maximal variance. - - NOTE: To properly use PCA, input vector components should be converted to - similar units of measurement such that the vectors represent a Euclidean - space. If no such conversion is available (e.g. one element represents time, - another element distance), the canonical approach is to first apply a - transformation to the input data to normalize numerical variances, i.e. - `tft.scale_to_z_score()`. Normalization allows PCA to choose output axes that - help decorrelate input axes. - - Below are a couple intuitive examples of PCA. - - Consider a simple 2-dimensional example: - - Input x is a series of vectors `[e, e]` where `e` is Gaussian with mean 0, - variance 1. The two components are perfectly correlated, and the resulting - covariance matrix is - - ``` - [[1 1], - [1 1]]. - ``` - - Applying PCA with `output_dim = 1` would discover the first principal - component `[1 / sqrt(2), 1 / sqrt(2)]`. When multipled to the original - example, each vector `[e, e]` would be mapped to a scalar `sqrt(2) * e`. The - second principal component would be `[-1 / sqrt(2), 1 / sqrt(2)]` and would - map `[e, e]` to 0, which indicates that the second component captures no - variance at all. This agrees with our intuition since we know that the two - axes in the input are perfectly correlated and can be fully explained by a - single scalar `e`. - - Consider a 3-dimensional example: - - Input `x` is a series of vectors `[a, a, b]`, where `a` is a zero-mean, unit - variance Gaussian and `b` is a zero-mean, variance 4 Gaussian and is - independent of `a`. The first principal component of the unnormalized vector - would be `[0, 0, 1]` since `b` has a much larger variance than any linear - combination of the first two components. This would map `[a, a, b]` onto `b`, - asserting that the axis with highest energy is the third component. 
While this - may be the desired output if `a` and `b` correspond to the same units, it is - not statistically desireable when the units are irreconciliable. In such a - case, one should first normalize each component to unit variance first, i.e. - `b := b / 2`. The first principal component of a normalized vector would yield - `[1 / sqrt(2), 1 / sqrt(2), 0]`, and would map `[a, a, b]` to `sqrt(2) * a`. - The second component would be `[0, 0, 1]` and map `[a, a, b]` to `b`. As can - be seen, the benefit of normalization is that PCA would capture highly - correlated components first and collapse them into a lower dimension. - - Args: - x: A rank-2 `Tensor`, 0th dim are rows, 1st dim are indices in row vectors. - output_dim: The PCA output dimension (number of eigenvectors to return). - dtype: Tensorflow dtype of entries in the returned matrix. - name: (Optional) A name for this operation. - - Raises: - ValueError: if input is not a rank-2 Tensor. - - Returns: - A 2D `Tensor` (matrix) M of shape (input_dim, output_dim). - """ - - if not isinstance(x, tf.Tensor): - raise TypeError('Expected a Tensor, but got %r' % x) - - with tf.compat.v1.name_scope(name, 'pca'): - x.shape.assert_has_rank(2) - - input_dim = x.shape.as_list()[1] - shape = (input_dim, output_dim) - - (result,) = _apply_cacheable_combiner( - PCACombiner(shape, output_dim, dtype.as_numpy_dtype), x) - return result - - -def _maybe_annotate_vocab_metadata(vocab_filename: str, - unfiltered_vocabulary_size: tf.Tensor, - filtered_vocabulary_size: tf.Tensor): - """Annotates a bucketized tensor with the boundaries that were applied. - - Creates a deferred annotation for the specified tensor. - - Args: - vocab_filename: The name of the vocabulary. - unfiltered_vocabulary_size: A tf.int64 tensor containing the unfiltered - vocab size. - filtered_vocabulary_size: A tf.int64 tensor containing the filtered vocab - size. - """ - if not common.IS_ANNOTATIONS_PB_AVAILABLE: - return - - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top - message_type = annotations_pb2.VocabularyMetadata.DESCRIPTOR.full_name - unfiltered_vocabulary_size = tf.expand_dims(unfiltered_vocabulary_size, 0) - filtered_vocabulary_size = tf.expand_dims(filtered_vocabulary_size, 0) - file_name = tf.convert_to_tensor([vocab_filename]) - descriptor_source = descriptor_pb2.FileDescriptorSet() - annotations_pb2.VocabularyMetadata.DESCRIPTOR.file.CopyToProto( - descriptor_source.file.add()) - descriptor_source_str = b'bytes://' + descriptor_source.SerializeToString() - message_proto = tf_utils._encode_proto( # pylint: disable=protected-access - { - 'unfiltered_vocabulary_size': unfiltered_vocabulary_size, - 'filtered_vocabulary_size': filtered_vocabulary_size, - 'file_name': file_name, - }, message_type, descriptor_source=descriptor_source_str) - assert message_proto.shape == [1] - message_proto = message_proto[0] - - # Note: we annotate globally here (tied to a vocabulary by filename) rather - # than attaching to a tensor, because this annotation is tied to an analysis - # output not a final tensor produced by a mapper. 
- type_url = os.path.join(common.ANNOTATION_PREFIX_URL, message_type) - schema_inference.annotate(type_url, message_proto) + message_type = annotations_pb2.VocabularyMetadata.DESCRIPTOR.full_name + unfiltered_vocabulary_size = tf.expand_dims(unfiltered_vocabulary_size, 0) + filtered_vocabulary_size = tf.expand_dims(filtered_vocabulary_size, 0) + file_name = tf.convert_to_tensor([vocab_filename]) + descriptor_source = descriptor_pb2.FileDescriptorSet() + annotations_pb2.VocabularyMetadata.DESCRIPTOR.file.CopyToProto( + descriptor_source.file.add() + ) + descriptor_source_str = b"bytes://" + descriptor_source.SerializeToString() + message_proto = tf_utils._encode_proto( # pylint: disable=protected-access + { + "unfiltered_vocabulary_size": unfiltered_vocabulary_size, + "filtered_vocabulary_size": filtered_vocabulary_size, + "file_name": file_name, + }, + message_type, + descriptor_source=descriptor_source_str, + ) + assert message_proto.shape == [1] + message_proto = message_proto[0] + + # Note: we annotate globally here (tied to a vocabulary by filename) rather + # than attaching to a tensor, because this annotation is tied to an analysis + # output not a final tensor produced by a mapper. + type_url = os.path.join(common.ANNOTATION_PREFIX_URL, message_type) + schema_inference.annotate(type_url, message_proto) diff --git a/tensorflow_transform/analyzers_test.py b/tensorflow_transform/analyzers_test.py index 38a8a44..c470a95 100644 --- a/tensorflow_transform/analyzers_test.py +++ b/tensorflow_transform/analyzers_test.py @@ -18,18 +18,18 @@ import numpy as np import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import test_case +from tensorflow_transform import analyzers, test_case _NP_TYPES = (np.float32, np.float64, np.int32, np.int64) _SUM_TEST = dict( - testcase_name='Sum', + testcase_name="Sum", combiner=analyzers.NumPyCombiner( fn=np.sum, default_accumulator_value=0, output_dtypes=[np.int64], - output_shapes=[None]), + output_shapes=[None], + ), batches=[ (np.array([1, 2, 3, 4, 5, 6]),), (np.array([1, 2, 3, 4, 5, 6]),), @@ -38,12 +38,13 @@ ) _SUM_SCALAR_TEST = dict( - testcase_name='SumScalar', + testcase_name="SumScalar", combiner=analyzers.NumPyCombiner( fn=np.sum, default_accumulator_value=0, output_dtypes=[np.int64], - output_shapes=[None]), + output_shapes=[None], + ), batches=[ (np.array(1),), (np.array(2),), @@ -52,12 +53,13 @@ ) _SUM_OF_SIZE_ZERO_TENSORS_TEST = dict( - testcase_name='SumOfSizeZeroTensors', + testcase_name="SumOfSizeZeroTensors", combiner=analyzers.NumPyCombiner( fn=np.sum, default_accumulator_value=0, output_dtypes=[np.int64], - output_shapes=[None]), + output_shapes=[None], + ), batches=[ (np.array([]),), (np.array([]),), @@ -66,9 +68,8 @@ ) _COVARIANCE_SIZE_ZERO_TENSORS_TEST = dict( - testcase_name='CovarianceSizeZeroTensors', - combiner=analyzers.CovarianceCombiner(output_shape=(0, 0), - numpy_dtype=np.float64), + testcase_name="CovarianceSizeZeroTensors", + combiner=analyzers.CovarianceCombiner(output_shape=(0, 0), numpy_dtype=np.float64), batches=[ (np.empty((1, 0)),), (np.empty((2, 0)),), @@ -77,23 +78,19 @@ ) _COVARIANCE_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST = dict( - testcase_name='CovarianceWithDegenerateCovarianceMatrix', - combiner=analyzers.CovarianceCombiner(output_shape=(3, 3), - numpy_dtype=np.float64), + testcase_name="CovarianceWithDegenerateCovarianceMatrix", + combiner=analyzers.CovarianceCombiner(output_shape=(3, 3), numpy_dtype=np.float64), batches=[ (np.array([[0, 0, 1]]),), (np.array([[4, 0, 1], 
[2, -1, 1]]),), (np.array([[2, 1, 1]]),), ], - expected_outputs=[ - np.array([[2, 0, 0], [0, 0.5, 0], [0, 0, 0]], dtype=np.float64) - ], + expected_outputs=[np.array([[2, 0, 0], [0, 0.5, 0], [0, 0, 0]], dtype=np.float64)], ) _COVARIANCE_WITH_LARGE_NUMBERS_TEST = dict( - testcase_name='CovarianceWithLargeNumbers', - combiner=analyzers.CovarianceCombiner(output_shape=(2, 2), - numpy_dtype=np.float64), + testcase_name="CovarianceWithLargeNumbers", + combiner=analyzers.CovarianceCombiner(output_shape=(2, 2), numpy_dtype=np.float64), batches=[ (np.array([[2e15, 0], [1e15, 0]]),), (np.array([[-2e15, 0], [-1e15, 0]]),), @@ -102,35 +99,33 @@ ) _PCA_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST = dict( - testcase_name='PCAWithDegenerateCovarianceMatrix', - combiner=analyzers.PCACombiner(output_shape=(3, 3), - numpy_dtype=np.float64), + testcase_name="PCAWithDegenerateCovarianceMatrix", + combiner=analyzers.PCACombiner(output_shape=(3, 3), numpy_dtype=np.float64), batches=[ (np.array([[0, 0, 1]]),), (np.array([[4, 0, 1], [2, -1, 1]]),), (np.array([[2, 1, 1]]),), ], - expected_outputs=[ - np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64) - ], + expected_outputs=[np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)], ) def _make_mean_and_var_accumulator_from_instance(instance, axis=None): - return analyzers._WeightedMeanAndVarAccumulator( - count=np.sum(np.ones_like(instance), axis=axis), - mean=np.mean(instance, axis=axis), - weight=np.sum(np.ones_like(instance), axis=axis), - variance=np.var(instance, axis=axis)) + return analyzers._WeightedMeanAndVarAccumulator( + count=np.sum(np.ones_like(instance), axis=axis), + mean=np.mean(instance, axis=axis), + weight=np.sum(np.ones_like(instance), axis=axis), + variance=np.var(instance, axis=axis), + ) + _MEAN_AND_VAR_TEST = dict( - testcase_name='WeightedMeanAndVar', + testcase_name="WeightedMeanAndVar", combiner=analyzers.WeightedMeanAndVarCombiner(np.float32, output_shape=()), batches=[ _make_mean_and_var_accumulator_from_instance([[1, 2, 3, 4, 5, 6, 7]]), # Count is 5*0xFFFF=327675 for this accumulator. - _make_mean_and_var_accumulator_from_instance([[8, 9, 10, 11, 12]] * - 0xFFFF), + _make_mean_and_var_accumulator_from_instance([[8, 9, 10, 11, 12]] * 0xFFFF), _make_mean_and_var_accumulator_from_instance([[100, 200, 3000]]), ], expected_outputs=[ @@ -140,27 +135,26 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): ) _MEAN_AND_VAR_SIMPLE_TEST = dict( - testcase_name='WeightedMeanAndVarSimple', + testcase_name="WeightedMeanAndVarSimple", combiner=analyzers.WeightedMeanAndVarCombiner( - np.float32, - output_shape=(), - compute_variance=False, - compute_weighted=False), + np.float32, output_shape=(), compute_variance=False, compute_weighted=False + ), batches=[ _make_mean_and_var_accumulator_from_instance([[1, 2, 3, 4, 5, 6, 7]]), # Count is 5*0xFFFF=327675 for this accumulator. 
- _make_mean_and_var_accumulator_from_instance([[8, 9, 10, 11, 12]] * - 0xFFFF), + _make_mean_and_var_accumulator_from_instance([[8, 9, 10, 11, 12]] * 0xFFFF), _make_mean_and_var_accumulator_from_instance([[100, 200, 3000]]), ], expected_outputs=analyzers._WeightedMeanAndVarAccumulator( count=np.array(327685), mean=np.float32(10.00985092390558), weight=np.float32(1.0), - variance=np.float32(0.0))) + variance=np.float32(0.0), + ), +) _MEAN_AND_VAR_BIG_TEST = dict( - testcase_name='WeightedMeanAndVarBig', + testcase_name="WeightedMeanAndVarBig", combiner=analyzers.WeightedMeanAndVarCombiner(np.float32, output_shape=()), batches=[ _make_mean_and_var_accumulator_from_instance([[1, 2, 3, 4, 5, 6, 7]]), @@ -168,34 +162,32 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): _make_mean_and_var_accumulator_from_instance([[100, 200]]), ], expected_outputs=[ - np.float32(2.50e+14), - np.float32(3.541666666665e+29), + np.float32(2.50e14), + np.float32(3.541666666665e29), ], ) _MEAN_AND_VAR_VECTORS_TEST = dict( - testcase_name='WeightedMeanAndVarForVectors', - combiner=analyzers.WeightedMeanAndVarCombiner( - np.float32, output_shape=(None,)), + testcase_name="WeightedMeanAndVarForVectors", + combiner=analyzers.WeightedMeanAndVarCombiner(np.float32, output_shape=(None,)), batches=[ - _make_mean_and_var_accumulator_from_instance([[1, 2, 3, 4, 5, 6]], - axis=0), - _make_mean_and_var_accumulator_from_instance([[7, 8, 9, 10, 11, 12]], - axis=0), + _make_mean_and_var_accumulator_from_instance([[1, 2, 3, 4, 5, 6]], axis=0), + _make_mean_and_var_accumulator_from_instance([[7, 8, 9, 10, 11, 12]], axis=0), _make_mean_and_var_accumulator_from_instance( - [[100, 200, 3000, 17, 27, 53]], axis=0), + [[100, 200, 3000, 17, 27, 53]], axis=0 + ), ], expected_outputs=[ - np.float32([36., 70., 1004., 10.33333333, 14.33333333, 23.66666667]), - np.float32( - [2054., 8456., 1992014., 28.22222222, 86.22222222, 436.22222222]), + np.float32([36.0, 70.0, 1004.0, 10.33333333, 14.33333333, 23.66666667]), + np.float32([2054.0, 8456.0, 1992014.0, 28.22222222, 86.22222222, 436.22222222]), ], ) _MEAN_AND_VAR_ND_TEST = dict( - testcase_name='WeightedMeanAndVarForNDVectors', + testcase_name="WeightedMeanAndVarForNDVectors", combiner=analyzers.WeightedMeanAndVarCombiner( - np.float32, output_shape=(None, None)), + np.float32, output_shape=(None, None) + ), batches=[ _make_mean_and_var_accumulator_from_instance([[[1], [1], [2]]], axis=2), _make_mean_and_var_accumulator_from_instance([[[1], [2], [2]]], axis=2), @@ -203,224 +195,240 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): ], expected_outputs=[ np.float32([[1.333333333, 1.666666666, 2]]), - np.float32([[.222222222, .222222222, 0]]), + np.float32([[0.222222222, 0.222222222, 0]]), ], ) -_L_MOMENTS_TESTS = [dict( - testcase_name='LMoments_one_batch', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), - batches=[ - # Accumulator for the sequence: - # np.concatenate((np.power(2.0, np.arange(0, 10, 0.01)), - # -np.power(1.9, np.arange(0, 10, 0.01))) - analyzers._LMomentsAccumulator( - count_l1=np.float32(2000.), - count_l2=np.float32(1999000.), - count_l3=np.float32(1.331334e+09), - count_l4=np.float32(6.6466854e+11), - l1=np.float32(26.00855), - l2=np.float32(103.25489), - l3=np.float32(17.549286), - l4=np.float32(47.41136)) - ], - expected_outputs=[ - np.float32(5.769684), - np.float32(81.381424), - np.float32(0.39079103), - np.float32(0.55846965) - ], -), dict( - testcase_name='LMoments_small_batch', - 
combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), - batches=[ - # Accumulator for the sequence: [1., 1., 2., 2.]. - analyzers._LMomentsAccumulator( - count_l1=np.float32(4.), - count_l2=np.float32(6.), - count_l3=np.float32(4.), - count_l4=np.float32(1.), - l1=np.float32(1.5), - l2=np.float32(0.33333334), - l3=np.float32(0.), - l4=np.float32(-0.5)) - ], - expected_outputs=[ - np.float32(1.5), - np.float32(np.sqrt(np.pi) / 3.0), - np.float32(0.0), - np.float32(0.0) - ], -), dict( - testcase_name='LMoments_one_sample', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), - batches=[ - # Accumulator for the sequence: [1.]. - analyzers._LMomentsAccumulator( - count_l1=np.float32(1.), - count_l2=np.float32(0.), - count_l3=np.float32(-0.), - count_l4=np.float32(0.), - l1=np.float32(1.), - l2=np.float32(0.), - l3=np.float32(-0.), - l4=np.float32(0.)) - ], - expected_outputs=[ - np.float32(1.0), - np.float32(1.0), - np.float32(0.0), - np.float32(0.0) - ], -), dict( - testcase_name='LMoments_two_samples', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), - batches=[ - # Accumulator for the sequence: [1., 1.]. - analyzers._LMomentsAccumulator( - count_l1=np.float32(2.), - count_l2=np.float32(1.), - count_l3=np.float32(0.), - count_l4=np.float32(-0.), - l1=np.float32(1.), - l2=np.float32(0.), - l3=np.float32(0.), - l4=np.float32(0.)) - ], - expected_outputs=[ - np.float32(1.0), - np.float32(1.0), - np.float32(0.0), - np.float32(0.0) - ], -), dict( - testcase_name='LMoments_multiple_batches', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), - batches=[ - # Accumulator for the sequence: - # np.concatenate((np.power(2.0, np.arange(0, 10, 0.02)), - # -np.power(1.9, np.arange(0, 10, 0.02))) - analyzers._LMomentsAccumulator( - count_l1=np.float32(1000.), - count_l2=np.float32(499500.), - count_l3=np.float32(1.66167e+08), - count_l4=np.float32(4.1417126e+10), - l1=np.float32(25.90623), - l2=np.float32(102.958664), - l3=np.float32(17.50719), - l4=np.float32(47.393063)), - # Accumulator for the sequence: - # np.concatenate((np.power(2.0, np.arange(0.01, 10, 0.02)), - # -np.power(1.9, np.arange(0.01, 10, 0.02))) - analyzers._LMomentsAccumulator( - count_l1=np.float32(1000.), - count_l2=np.float32(499500.), - count_l3=np.float32(1.66167e+08), - count_l4=np.float32(4.1417126e+10), - l1=np.float32(26.110888), - l2=np.float32(103.65407), - l3=np.float32(17.64386), - l4=np.float32(47.71353)), - ], - expected_outputs=[ - np.float32(5.751478), - np.float32(81.16352), - np.float32(0.3923474), - np.float32(0.55972165) - ], -)] +_L_MOMENTS_TESTS = [ + dict( + testcase_name="LMoments_one_batch", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), + batches=[ + # Accumulator for the sequence: + # np.concatenate((np.power(2.0, np.arange(0, 10, 0.01)), + # -np.power(1.9, np.arange(0, 10, 0.01))) + analyzers._LMomentsAccumulator( + count_l1=np.float32(2000.0), + count_l2=np.float32(1999000.0), + count_l3=np.float32(1.331334e09), + count_l4=np.float32(6.6466854e11), + l1=np.float32(26.00855), + l2=np.float32(103.25489), + l3=np.float32(17.549286), + l4=np.float32(47.41136), + ) + ], + expected_outputs=[ + np.float32(5.769684), + np.float32(81.381424), + np.float32(0.39079103), + np.float32(0.55846965), + ], + ), + dict( + testcase_name="LMoments_small_batch", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), + batches=[ + # Accumulator for the sequence: [1., 1., 2., 2.]. 
+ analyzers._LMomentsAccumulator( + count_l1=np.float32(4.0), + count_l2=np.float32(6.0), + count_l3=np.float32(4.0), + count_l4=np.float32(1.0), + l1=np.float32(1.5), + l2=np.float32(0.33333334), + l3=np.float32(0.0), + l4=np.float32(-0.5), + ) + ], + expected_outputs=[ + np.float32(1.5), + np.float32(np.sqrt(np.pi) / 3.0), + np.float32(0.0), + np.float32(0.0), + ], + ), + dict( + testcase_name="LMoments_one_sample", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), + batches=[ + # Accumulator for the sequence: [1.]. + analyzers._LMomentsAccumulator( + count_l1=np.float32(1.0), + count_l2=np.float32(0.0), + count_l3=np.float32(-0.0), + count_l4=np.float32(0.0), + l1=np.float32(1.0), + l2=np.float32(0.0), + l3=np.float32(-0.0), + l4=np.float32(0.0), + ) + ], + expected_outputs=[ + np.float32(1.0), + np.float32(1.0), + np.float32(0.0), + np.float32(0.0), + ], + ), + dict( + testcase_name="LMoments_two_samples", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), + batches=[ + # Accumulator for the sequence: [1., 1.]. + analyzers._LMomentsAccumulator( + count_l1=np.float32(2.0), + count_l2=np.float32(1.0), + count_l3=np.float32(0.0), + count_l4=np.float32(-0.0), + l1=np.float32(1.0), + l2=np.float32(0.0), + l3=np.float32(0.0), + l4=np.float32(0.0), + ) + ], + expected_outputs=[ + np.float32(1.0), + np.float32(1.0), + np.float32(0.0), + np.float32(0.0), + ], + ), + dict( + testcase_name="LMoments_multiple_batches", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=()), + batches=[ + # Accumulator for the sequence: + # np.concatenate((np.power(2.0, np.arange(0, 10, 0.02)), + # -np.power(1.9, np.arange(0, 10, 0.02))) + analyzers._LMomentsAccumulator( + count_l1=np.float32(1000.0), + count_l2=np.float32(499500.0), + count_l3=np.float32(1.66167e08), + count_l4=np.float32(4.1417126e10), + l1=np.float32(25.90623), + l2=np.float32(102.958664), + l3=np.float32(17.50719), + l4=np.float32(47.393063), + ), + # Accumulator for the sequence: + # np.concatenate((np.power(2.0, np.arange(0.01, 10, 0.02)), + # -np.power(1.9, np.arange(0.01, 10, 0.02))) + analyzers._LMomentsAccumulator( + count_l1=np.float32(1000.0), + count_l2=np.float32(499500.0), + count_l3=np.float32(1.66167e08), + count_l4=np.float32(4.1417126e10), + l1=np.float32(26.110888), + l2=np.float32(103.65407), + l3=np.float32(17.64386), + l4=np.float32(47.71353), + ), + ], + expected_outputs=[ + np.float32(5.751478), + np.float32(81.16352), + np.float32(0.3923474), + np.float32(0.55972165), + ], + ), +] -_L_MOMENTS_ND_TESTS = [dict( - testcase_name='LMomentsOneBatchForNDVectors', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=(None, None)), - batches=[ - # Accumulator for the sequence: - # np.concatenate(( - # np.concatenate(( - # np.power(2.0, np.arange(0, 10, 0.01)), - # -np.power(1.9, np.arange(0, 10, 0.01)))).reshape( - # [-1, 1, 1]), - # np.concatenate(( - # np.power(1.9, np.arange(0, 10, 0.01)), - # -np.power(2.0, np.arange(0, 10, 0.01)))).reshape( - # [-1, 1, 1])), axis=2), - # axis=0), - analyzers._LMomentsAccumulator( - count_l1=np.array([[2000., 2000.]], dtype=np.float32), - count_l2=np.array([[1999000., 1999000.]], dtype=np.float32), - count_l3=np.array([[1.331334e+09, 1.331334e+09]], dtype=np.float32), - count_l4=np.array( - [[6.6466854e+11, 6.6466854e+11]], dtype=np.float32), - l1=np.array([[26.00855, -26.008562]], dtype=np.float32), - l2=np.array([[103.25489, 103.25489]], dtype=np.float32), - l3=np.array([[17.549286, -17.549274]], dtype=np.float32), - 
l4=np.array([[47.41136, 47.41136]], dtype=np.float32)) - ], - expected_outputs=[ - np.array([[5.7696896, -5.7697697]], dtype=np.float32), - np.array([[81.38142, 81.381386]], dtype=np.float32), - np.array([[0.39079103, 0.55846965]], dtype=np.float32), - np.array([[0.55846965, 0.39079177]], dtype=np.float32) - ], -), dict( - testcase_name='LMomentsMultipleBatchesForNDVectors', - combiner=analyzers._LMomentsCombiner(np.float32, output_shape=(None, None)), - batches=[ - # Accumulator for the sequence: - # np.concatenate(( - # np.concatenate(( - # np.power(2.0, np.arange(0, 10, 0.02)), - # -np.power(1.9, np.arange(0., 10, 0.02)))).reshape( - # [-1, 1, 1]), - # np.concatenate(( - # np.power(1.9, np.arange(0, 10, 0.02)), - # -np.power(2.0, np.arange(0., 10, 0.02)))).reshape( - # [-1, 1, 1])), axis=2), - # axis=0) - analyzers._LMomentsAccumulator( - count_l1=np.array([[1000., 1000.]], dtype=np.float32), - count_l2=np.array([[499500., 499500.]], dtype=np.float32), - count_l3=np.array([[1.66167e+08, 1.66167e+08]], dtype=np.float32), - count_l4=np.array( - [[4.1417126e+10, 4.1417126e+10]], dtype=np.float32), - l1=np.array([[25.90623, -25.90623]], dtype=np.float32), - l2=np.array([[102.958664, 102.958664]], dtype=np.float32), - l3=np.array([[17.50719, -17.507195]], dtype=np.float32), - l4=np.array([[47.393063, 47.393066]], dtype=np.float32)), - # Accumulator for the sequence: - # np.concatenate(( - # np.concatenate(( - # np.power(2.0, np.arange(0.01, 10, 0.02)), - # -np.power(1.9, np.arange(0.01, 10, 0.02)))).reshape( - # [-1, 1, 1]), - # np.concatenate(( - # np.power(1.9, np.arange(0.01, 10, 0.02)), - # -np.power(2.0, np.arange(0.01, 10, 0.02)))).reshape( - # [-1, 1, 1])), axis=2), - # axis=0) - analyzers._LMomentsAccumulator( - count_l1=np.array([[1000., 1000.]], dtype=np.float32), - count_l2=np.array([[499500., 499500.]], dtype=np.float32), - count_l3=np.array([[1.66167e+08, 1.66167e+08]], dtype=np.float32), - count_l4=np.array( - [[4.1417126e+10, 4.1417126e+10]], dtype=np.float32), - l1=np.array([[26.110888, -26.110888]], dtype=np.float32), - l2=np.array([[103.65407, 103.654076]], dtype=np.float32), - l3=np.array([[17.64386, -17.643852]], dtype=np.float32), - l4=np.array([[47.71353, 47.71353]], dtype=np.float32)) - ], - expected_outputs=[ - np.array([[5.751478, -5.751478]], dtype=np.float32), - np.array([[81.16352, 81.16352]], dtype=np.float32), - np.array([[0.3923474, 0.55972165]], dtype=np.float32), - np.array([[0.55972165, 0.3923474]], dtype=np.float32) - ], -)] +_L_MOMENTS_ND_TESTS = [ + dict( + testcase_name="LMomentsOneBatchForNDVectors", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=(None, None)), + batches=[ + # Accumulator for the sequence: + # np.concatenate(( + # np.concatenate(( + # np.power(2.0, np.arange(0, 10, 0.01)), + # -np.power(1.9, np.arange(0, 10, 0.01)))).reshape( + # [-1, 1, 1]), + # np.concatenate(( + # np.power(1.9, np.arange(0, 10, 0.01)), + # -np.power(2.0, np.arange(0, 10, 0.01)))).reshape( + # [-1, 1, 1])), axis=2), + # axis=0), + analyzers._LMomentsAccumulator( + count_l1=np.array([[2000.0, 2000.0]], dtype=np.float32), + count_l2=np.array([[1999000.0, 1999000.0]], dtype=np.float32), + count_l3=np.array([[1.331334e09, 1.331334e09]], dtype=np.float32), + count_l4=np.array([[6.6466854e11, 6.6466854e11]], dtype=np.float32), + l1=np.array([[26.00855, -26.008562]], dtype=np.float32), + l2=np.array([[103.25489, 103.25489]], dtype=np.float32), + l3=np.array([[17.549286, -17.549274]], dtype=np.float32), + l4=np.array([[47.41136, 47.41136]], dtype=np.float32), 
+ ) + ], + expected_outputs=[ + np.array([[5.7696896, -5.7697697]], dtype=np.float32), + np.array([[81.38142, 81.381386]], dtype=np.float32), + np.array([[0.39079103, 0.55846965]], dtype=np.float32), + np.array([[0.55846965, 0.39079177]], dtype=np.float32), + ], + ), + dict( + testcase_name="LMomentsMultipleBatchesForNDVectors", + combiner=analyzers._LMomentsCombiner(np.float32, output_shape=(None, None)), + batches=[ + # Accumulator for the sequence: + # np.concatenate(( + # np.concatenate(( + # np.power(2.0, np.arange(0, 10, 0.02)), + # -np.power(1.9, np.arange(0., 10, 0.02)))).reshape( + # [-1, 1, 1]), + # np.concatenate(( + # np.power(1.9, np.arange(0, 10, 0.02)), + # -np.power(2.0, np.arange(0., 10, 0.02)))).reshape( + # [-1, 1, 1])), axis=2), + # axis=0) + analyzers._LMomentsAccumulator( + count_l1=np.array([[1000.0, 1000.0]], dtype=np.float32), + count_l2=np.array([[499500.0, 499500.0]], dtype=np.float32), + count_l3=np.array([[1.66167e08, 1.66167e08]], dtype=np.float32), + count_l4=np.array([[4.1417126e10, 4.1417126e10]], dtype=np.float32), + l1=np.array([[25.90623, -25.90623]], dtype=np.float32), + l2=np.array([[102.958664, 102.958664]], dtype=np.float32), + l3=np.array([[17.50719, -17.507195]], dtype=np.float32), + l4=np.array([[47.393063, 47.393066]], dtype=np.float32), + ), + # Accumulator for the sequence: + # np.concatenate(( + # np.concatenate(( + # np.power(2.0, np.arange(0.01, 10, 0.02)), + # -np.power(1.9, np.arange(0.01, 10, 0.02)))).reshape( + # [-1, 1, 1]), + # np.concatenate(( + # np.power(1.9, np.arange(0.01, 10, 0.02)), + # -np.power(2.0, np.arange(0.01, 10, 0.02)))).reshape( + # [-1, 1, 1])), axis=2), + # axis=0) + analyzers._LMomentsAccumulator( + count_l1=np.array([[1000.0, 1000.0]], dtype=np.float32), + count_l2=np.array([[499500.0, 499500.0]], dtype=np.float32), + count_l3=np.array([[1.66167e08, 1.66167e08]], dtype=np.float32), + count_l4=np.array([[4.1417126e10, 4.1417126e10]], dtype=np.float32), + l1=np.array([[26.110888, -26.110888]], dtype=np.float32), + l2=np.array([[103.65407, 103.654076]], dtype=np.float32), + l3=np.array([[17.64386, -17.643852]], dtype=np.float32), + l4=np.array([[47.71353, 47.71353]], dtype=np.float32), + ), + ], + expected_outputs=[ + np.array([[5.751478, -5.751478]], dtype=np.float32), + np.array([[81.16352, 81.16352]], dtype=np.float32), + np.array([[0.3923474, 0.55972165]], dtype=np.float32), + np.array([[0.55972165, 0.3923474]], dtype=np.float32), + ], + ), +] _QUANTILES_NO_ELEMENTS_TEST = dict( - testcase_name='ComputeQuantilesNoElements', + testcase_name="ComputeQuantilesNoElements", combiner=analyzers.QuantilesCombiner( - num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32), + num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32 + ), batches=[ (np.empty((0, 1), dtype=np.float32),), ], @@ -428,9 +436,10 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): ) _QUANTILES_EXACT_NO_ELEMENTS_TEST = dict( - testcase_name='ComputeExactQuantilesNoElements', + testcase_name="ComputeExactQuantilesNoElements", combiner=analyzers.QuantilesCombiner( - num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32), + num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32 + ), batches=[ (np.empty((0, 1), dtype=np.float32),), ], @@ -438,12 +447,13 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): ) _QUANTILES_NO_TRIM_TEST = dict( - testcase_name='NoTrimQuantilesTest', + testcase_name="NoTrimQuantilesTest", combiner=analyzers.QuantilesCombiner( num_quantiles=4, 
epsilon=0.00001, bucket_numpy_dtype=np.float32, - include_max_and_min=True), + include_max_and_min=True, + ), batches=[ (np.array([1, 1]),), ], @@ -453,9 +463,10 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): # pylint: disable=g-complex-comprehension _QUANTILES_SINGLE_BATCH_TESTS = [ dict( - testcase_name='ComputeQuantilesSingleBatch-{}'.format(np_type), + testcase_name=f"ComputeQuantilesSingleBatch-{np_type}", combiner=analyzers.QuantilesCombiner( - num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32), + num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32 + ), batches=[ (np.linspace(1, 100, 100, dtype=np_type),), (np.linspace(101, 200, 100, dtype=np_type),), @@ -463,166 +474,187 @@ def _make_mean_and_var_accumulator_from_instance(instance, axis=None): (np.empty((0, 3)),), ], expected_outputs=[np.array([61, 121, 181, 241], dtype=np.float32)], - ) for np_type in _NP_TYPES + ) + for np_type in _NP_TYPES ] _QUANTILES_ELEMENTWISE_TESTS = [ dict( - testcase_name='ComputeQuantilesElementwise-{}'.format(np_type), + testcase_name=f"ComputeQuantilesElementwise-{np_type}", combiner=analyzers.QuantilesCombiner( num_quantiles=5, epsilon=0.00001, bucket_numpy_dtype=np.float32, - feature_shape=[3]), + feature_shape=[3], + ), batches=[ - (np.vstack([np.linspace(1, 100, 100, dtype=np_type), + ( + np.vstack( + [ + np.linspace(1, 100, 100, dtype=np_type), np.linspace(101, 200, 100, dtype=np_type), - np.linspace(201, 300, 100, dtype=np_type)]).T,), + np.linspace(201, 300, 100, dtype=np_type), + ] + ).T, + ), (np.empty((0, 3)),), ], - expected_outputs=[np.array([[21, 41, 61, 81], - [121, 141, 161, 181], - [221, 241, 261, 281]], dtype=np.float32)], - ) for np_type in _NP_TYPES + expected_outputs=[ + np.array( + [[21, 41, 61, 81], [121, 141, 161, 181], [221, 241, 261, 281]], + dtype=np.float32, + ) + ], + ) + for np_type in _NP_TYPES ] _QUANTILES_MULTIPLE_BATCH_TESTS = [ dict( - testcase_name='ComputeQuantilesMultipleBatch-{}'.format(np_type), + testcase_name=f"ComputeQuantilesMultipleBatch-{np_type}", combiner=analyzers.QuantilesCombiner( - num_quantiles=3, epsilon=0.00001, bucket_numpy_dtype=np.float32), + num_quantiles=3, epsilon=0.00001, bucket_numpy_dtype=np.float32 + ), batches=[ (np.linspace(1, 100, 100, np_type),), ], expected_outputs=[np.array([34, 67], dtype=np.float32)], - ) for np_type in _NP_TYPES + ) + for np_type in _NP_TYPES ] _EXACT_NUM_QUANTILES_TESTS = [ dict( - testcase_name='ComputeExactNumQuantiles-{}'.format(np_type), + testcase_name=f"ComputeExactNumQuantiles-{np_type}", combiner=analyzers.QuantilesCombiner( - num_quantiles=4, epsilon=0.00001, bucket_numpy_dtype=np.float32), + num_quantiles=4, epsilon=0.00001, bucket_numpy_dtype=np.float32 + ), batches=[ (np.array([1, 1]),), ], expected_outputs=[np.array([1, 1, 1], dtype=np.float32)], - ) for np_type in _NP_TYPES + ) + for np_type in _NP_TYPES ] # pylint: enable=g-complex-comprehension class AnalyzersTest(test_case.TransformTestCase): - - @test_case.named_parameters( - *[ - _SUM_TEST, - _SUM_SCALAR_TEST, - _SUM_OF_SIZE_ZERO_TENSORS_TEST, - _COVARIANCE_SIZE_ZERO_TENSORS_TEST, - _COVARIANCE_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST, - _COVARIANCE_WITH_LARGE_NUMBERS_TEST, - _PCA_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST, - _MEAN_AND_VAR_TEST, - _MEAN_AND_VAR_SIMPLE_TEST, - _MEAN_AND_VAR_BIG_TEST, - _MEAN_AND_VAR_VECTORS_TEST, - _MEAN_AND_VAR_ND_TEST, - _QUANTILES_NO_ELEMENTS_TEST, - _QUANTILES_NO_TRIM_TEST, - _QUANTILES_EXACT_NO_ELEMENTS_TEST, - ] + _L_MOMENTS_TESTS + _L_MOMENTS_ND_TESTS + - 
_QUANTILES_SINGLE_BATCH_TESTS + _QUANTILES_MULTIPLE_BATCH_TESTS + - _QUANTILES_ELEMENTWISE_TESTS + _EXACT_NUM_QUANTILES_TESTS) - def testCombiner(self, combiner, batches, expected_outputs): - """Tests the provided combiner. - - Args: - combiner: An object implementing the Combiner interface. - batches: A list of batches, each is a tuples of ndarrays. each ndarray - represents the values of an input tensor of the analyzer over a single - batch. - expected_outputs: The expected outputs from extract_output. - - Exercises create_accumulator, add_input, merge_accumulators, - and extract_output. - """ - # Test serialization faithfully reproduces the object. If tests - # mysteriously break, it could be because __reduce__ is missing something. - combiner = pickle.loads(pickle.dumps(combiner)) - - # Note `accumulators` is a generator, not list. We do this to ensure that - # add_input is not relying on its input being a list. - accumulators = ( - combiner.add_input(combiner.create_accumulator(), batch) - for batch in batches) - - final_accumulator = combiner.merge_accumulators(accumulators) - outputs = combiner.extract_output(final_accumulator) - tensor_infos = combiner.output_tensor_infos() - self.assertEqual(len(outputs), len(expected_outputs)) - self.assertEqual(len(outputs), len(tensor_infos)) - for output, expected_output, tensor_info in zip( - outputs, expected_outputs, tensor_infos): - self.assertEqual(output.dtype, expected_output.dtype) - self.assertEqual(tensor_info.dtype, tf.as_dtype(expected_output.dtype)) - - self.assertAllClose(output, expected_output, rtol=1e-4, atol=1e-4) - - @test_case.named_parameters( - { - 'testcase_name': '1d', - 'a': np.array([1]), - 'b': np.array([1, 1]), - 'expected_a': np.array([1, 0]), - 'expected_b': np.array([1, 1]), - }, - { - 'testcase_name': '2d_1different', - 'a': np.array([[1], [1]]), - 'b': np.array([[1], [1], [2]]), - 'expected_a': np.array([[1], [1], [0]]), - 'expected_b': np.array([[1], [1], [2]]), - }, - { - 'testcase_name': '2d_2different', - 'a': np.array([[1, 3], [1, 3]]), - 'b': np.array([[1], [1], [2]]), - 'expected_a': np.array([[1, 3], [1, 3], [0, 0]]), - 'expected_b': np.array([[1, 0], [1, 0], [2, 0]]), - }, - { - 'testcase_name': '3d_1different', - 'a': np.array([[[1], [1]], [[1], [1]]]), - 'b': np.array([[[1], [1]]]), - 'expected_a': np.array([[[1], [1]], [[1], [1]]]), - 'expected_b': np.array([[[1], [1]], [[0], [0]]]), - }, - { - 'testcase_name': '3d_2different', - 'a': np.array([[[1], [1]], [[1], [1]]]), - 'b': np.array([[[1, 1], [1, 1]]]), - 'expected_a': np.array([[[1, 0], [1, 0]], [[1, 0], [1, 0]]]), - 'expected_b': np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]]]), - }, - ) - def test_pad_arrays_to_match(self, a, b, expected_a, expected_b): - a2, b2 = analyzers._pad_arrays_to_match(a, b) - self.assertAllClose(a2, expected_a) - self.assertAllClose(b2, expected_b) - - def testMinDiffFromAvg(self): - # Small dataset gets the minimum of 2 - self.assertEqual( - analyzers.calculate_recommended_min_diff_from_avg(10000), 2) - self.assertEqual( - analyzers.calculate_recommended_min_diff_from_avg(100000), 4) - self.assertEqual( - analyzers.calculate_recommended_min_diff_from_avg(500000), 13) - # Large dataset gets the maximum of 25 - self.assertEqual( - analyzers.calculate_recommended_min_diff_from_avg(100000000), 25) - - -if __name__ == '__main__': - test_case.main() + @test_case.named_parameters( + *[ + _SUM_TEST, + _SUM_SCALAR_TEST, + _SUM_OF_SIZE_ZERO_TENSORS_TEST, + _COVARIANCE_SIZE_ZERO_TENSORS_TEST, + 
_COVARIANCE_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST, + _COVARIANCE_WITH_LARGE_NUMBERS_TEST, + _PCA_WITH_DEGENERATE_COVARIANCE_MATRIX_TEST, + _MEAN_AND_VAR_TEST, + _MEAN_AND_VAR_SIMPLE_TEST, + _MEAN_AND_VAR_BIG_TEST, + _MEAN_AND_VAR_VECTORS_TEST, + _MEAN_AND_VAR_ND_TEST, + _QUANTILES_NO_ELEMENTS_TEST, + _QUANTILES_NO_TRIM_TEST, + _QUANTILES_EXACT_NO_ELEMENTS_TEST, + ] + + _L_MOMENTS_TESTS + + _L_MOMENTS_ND_TESTS + + _QUANTILES_SINGLE_BATCH_TESTS + + _QUANTILES_MULTIPLE_BATCH_TESTS + + _QUANTILES_ELEMENTWISE_TESTS + + _EXACT_NUM_QUANTILES_TESTS + ) + def testCombiner(self, combiner, batches, expected_outputs): + """Tests the provided combiner. + + Args: + ---- + combiner: An object implementing the Combiner interface. + batches: A list of batches, each is a tuples of ndarrays. each ndarray + represents the values of an input tensor of the analyzer over a single + batch. + expected_outputs: The expected outputs from extract_output. + + Exercises create_accumulator, add_input, merge_accumulators, + and extract_output. + """ + # Test serialization faithfully reproduces the object. If tests + # mysteriously break, it could be because __reduce__ is missing something. + combiner = pickle.loads(pickle.dumps(combiner)) + + # Note `accumulators` is a generator, not list. We do this to ensure that + # add_input is not relying on its input being a list. + accumulators = ( + combiner.add_input(combiner.create_accumulator(), batch) + for batch in batches + ) + + final_accumulator = combiner.merge_accumulators(accumulators) + outputs = combiner.extract_output(final_accumulator) + tensor_infos = combiner.output_tensor_infos() + self.assertEqual(len(outputs), len(expected_outputs)) + self.assertEqual(len(outputs), len(tensor_infos)) + for output, expected_output, tensor_info in zip( + outputs, expected_outputs, tensor_infos + ): + self.assertEqual(output.dtype, expected_output.dtype) + self.assertEqual(tensor_info.dtype, tf.as_dtype(expected_output.dtype)) + + self.assertAllClose(output, expected_output, rtol=1e-4, atol=1e-4) + + @test_case.named_parameters( + { + "testcase_name": "1d", + "a": np.array([1]), + "b": np.array([1, 1]), + "expected_a": np.array([1, 0]), + "expected_b": np.array([1, 1]), + }, + { + "testcase_name": "2d_1different", + "a": np.array([[1], [1]]), + "b": np.array([[1], [1], [2]]), + "expected_a": np.array([[1], [1], [0]]), + "expected_b": np.array([[1], [1], [2]]), + }, + { + "testcase_name": "2d_2different", + "a": np.array([[1, 3], [1, 3]]), + "b": np.array([[1], [1], [2]]), + "expected_a": np.array([[1, 3], [1, 3], [0, 0]]), + "expected_b": np.array([[1, 0], [1, 0], [2, 0]]), + }, + { + "testcase_name": "3d_1different", + "a": np.array([[[1], [1]], [[1], [1]]]), + "b": np.array([[[1], [1]]]), + "expected_a": np.array([[[1], [1]], [[1], [1]]]), + "expected_b": np.array([[[1], [1]], [[0], [0]]]), + }, + { + "testcase_name": "3d_2different", + "a": np.array([[[1], [1]], [[1], [1]]]), + "b": np.array([[[1, 1], [1, 1]]]), + "expected_a": np.array([[[1, 0], [1, 0]], [[1, 0], [1, 0]]]), + "expected_b": np.array([[[1, 1], [1, 1]], [[0, 0], [0, 0]]]), + }, + ) + def test_pad_arrays_to_match(self, a, b, expected_a, expected_b): + a2, b2 = analyzers._pad_arrays_to_match(a, b) + self.assertAllClose(a2, expected_a) + self.assertAllClose(b2, expected_b) + + def testMinDiffFromAvg(self): + # Small dataset gets the minimum of 2 + self.assertEqual(analyzers.calculate_recommended_min_diff_from_avg(10000), 2) + self.assertEqual(analyzers.calculate_recommended_min_diff_from_avg(100000), 4) + 
self.assertEqual(analyzers.calculate_recommended_min_diff_from_avg(500000), 13) + # Large dataset gets the maximum of 25 + self.assertEqual( + analyzers.calculate_recommended_min_diff_from_avg(100000000), 25 + ) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/annotators.py b/tensorflow_transform/annotators.py index 5da478b..553c9c8 100644 --- a/tensorflow_transform/annotators.py +++ b/tensorflow_transform/annotators.py @@ -22,213 +22,232 @@ from typing import Callable, List, Optional import tensorflow as tf +from tensorflow.python.trackable import ( + base, # pylint: disable=g-direct-tensorflow-import +) + from tensorflow_transform.graph_context import TFGraphContext from tensorflow_transform.keras_lib import tf_keras -from tensorflow.python.trackable import base # pylint: disable=g-direct-tensorflow-import -__all__ = ['annotate_asset', 'make_and_track_object'] +__all__ = ["annotate_asset", "make_and_track_object"] -_ASSET_KEY_COLLECTION = 'tft_asset_key_collection' -_ASSET_FILENAME_COLLECTION = 'tft_asset_filename_collection' +_ASSET_KEY_COLLECTION = "tft_asset_key_collection" +_ASSET_FILENAME_COLLECTION = "tft_asset_filename_collection" # Thread-Hostile _OBJECT_TRACKER = None -VOCABULARY_SIZE_BY_NAME_COLLECTION = 'tft_vocabulary_size_by_name_collection' +VOCABULARY_SIZE_BY_NAME_COLLECTION = "tft_vocabulary_size_by_name_collection" class ObjectTracker: - """A class that tracks a list of trackable objects.""" - - __slots__ = ['_trackable_objects'] - - def __init__(self): - self._trackable_objects = [] - - @property - def trackable_objects(self) -> List[base.Trackable]: - return self._trackable_objects - - def add_trackable_object(self, trackable_object: base.Trackable, - name: Optional[str]): - """Add `trackable_object` to list of objects tracked.""" - if name is None: - self._trackable_objects.append(trackable_object) - else: - module = TFGraphContext.get_module_to_export() - # The `preprocessing_fn` should always be invoked within a TFGraphContext. - # If not, module will be None. - if module is None: - raise RuntimeError( - f'No module found to track {name} with. Check that the ' - '`preprocessing_fn` is invoked within a `TFGraphContext` with a ' - 'valid `TFGraphContext.module_to_export`.') - if hasattr(module, name): - raise ValueError( - f'An object with name {name} is already being tracked. Check that a ' - 'unique name was passed.') - setattr(module, name, trackable_object) + """A class that tracks a list of trackable objects.""" + + __slots__ = ["_trackable_objects"] + + def __init__(self): + self._trackable_objects = [] + + @property + def trackable_objects(self) -> List[base.Trackable]: + return self._trackable_objects + + def add_trackable_object( + self, trackable_object: base.Trackable, name: Optional[str] + ): + """Add `trackable_object` to list of objects tracked.""" + if name is None: + self._trackable_objects.append(trackable_object) + else: + module = TFGraphContext.get_module_to_export() + # The `preprocessing_fn` should always be invoked within a TFGraphContext. + # If not, module will be None. + if module is None: + raise RuntimeError( + f"No module found to track {name} with. Check that the " + "`preprocessing_fn` is invoked within a `TFGraphContext` with a " + "valid `TFGraphContext.module_to_export`." + ) + if hasattr(module, name): + raise ValueError( + f"An object with name {name} is already being tracked. Check that a " + "unique name was passed." 
+ ) + setattr(module, name, trackable_object) # Thread-Hostile @contextlib.contextmanager def object_tracker_scope(object_tracker: ObjectTracker): - """A context to manage trackable objects. + """A context to manage trackable objects. - Collects all trackable objects annotated using `track_object` within the body - of its scope. + Collects all trackable objects annotated using `track_object` within the body + of its scope. - Args: - object_tracker: The passed in ObjectTracker object + Args: + ---- + object_tracker: The passed in ObjectTracker object - Yields: - A scope in which the object_tracker is active. - """ - global _OBJECT_TRACKER - # Multiple nested object_tracker_scope calls are not expected. - assert _OBJECT_TRACKER is None - _OBJECT_TRACKER = object_tracker - try: - yield - finally: - _OBJECT_TRACKER = None + Yields: + ------ + A scope in which the object_tracker is active. + """ + global _OBJECT_TRACKER + # Multiple nested object_tracker_scope calls are not expected. + assert _OBJECT_TRACKER is None + _OBJECT_TRACKER = object_tracker + try: + yield + finally: + _OBJECT_TRACKER = None def _get_object(name: str) -> Optional[base.Trackable]: - """If an object is being tracked using `name` return it, else None.""" - module = TFGraphContext.get_module_to_export() - # The `preprocessing_fn` should always be invoked within a TFGraphContext. If - # not, module will be None. - if module is None: - raise RuntimeError( - f'No module found to track {name} with. Check that the `preprocessing_fn` is' - ' invoked within a `TFGraphContext` with a valid ' - '`TFGraphContext.module_to_export`.') - return getattr(module, name, None) + """If an object is being tracked using `name` return it, else None.""" + module = TFGraphContext.get_module_to_export() + # The `preprocessing_fn` should always be invoked within a TFGraphContext. If + # not, module will be None. + if module is None: + raise RuntimeError( + f"No module found to track {name} with. Check that the `preprocessing_fn` is" + " invoked within a `TFGraphContext` with a valid " + "`TFGraphContext.module_to_export`." + ) + return getattr(module, name, None) # Thread-Hostile def track_object(trackable: base.Trackable, name: Optional[str]): - """Add `trackable` to the object trackers active in this scope.""" - global _OBJECT_TRACKER - # The transform tf.function should always be traced - # (call to get_concrete_function) within an object_tracker_scope. - assert _OBJECT_TRACKER is not None - _OBJECT_TRACKER.add_trackable_object(trackable, name) + """Add `trackable` to the object trackers active in this scope.""" + global _OBJECT_TRACKER + # The transform tf.function should always be traced + # (call to get_concrete_function) within an object_tracker_scope. + assert _OBJECT_TRACKER is not None + _OBJECT_TRACKER.add_trackable_object(trackable, name) # Thread-Hostile -def make_and_track_object(trackable_factory_callable: Callable[[], - base.Trackable], - name: Optional[str] = None) -> base.Trackable: - # pyformat: disable - """Keeps track of the object created by invoking `trackable_factory_callable`. - - This API is only for use when Transform APIs are run with TF2 behaviors - enabled and `tft_beam.Context.force_tf_compat_v1` is set to False. - - Use this API to track TF Trackable objects created in the `preprocessing_fn` - such as tf.hub modules, tf.data.Dataset etc. This ensures they are serialized - correctly when exporting to SavedModel. - - Args: - trackable_factory_callable: A callable that creates and returns a Trackable - object. 
- name: (Optional) Provide a unique name to track this object with. If the - Trackable object created is a Keras Layer or Model this is needed for - proper tracking. - - Example: - - >>> def preprocessing_fn(inputs): - ... dataset = tft.make_and_track_object( - ... lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3])) - ... with tf.init_scope(): - ... dataset_list = list(dataset.as_numpy_iterator()) - ... return {'x_0': dataset_list[0] + inputs['x']} - >>> raw_data = [dict(x=1), dict(x=2), dict(x=3)] - >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp(), - ... force_tf_compat_v1=False): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'x_0': 2}, {'x_0': 3}, {'x_0': 4}] - - Returns: - The object returned when trackable_factory_callable is invoked. The object - creation is lifted out to the eager context using `tf.init_scope`. - """ - # pyformat: enable - if not tf.inside_function(): - raise ValueError('This API should only be invoked inside the user defined ' - '`preprocessing_fn` with TF2 behaviors enabled and ' - '`force_tf_compat_v1=False`. ') - result = _get_object(name) if name is not None else None - if result is None: - with tf.init_scope(): - result = trackable_factory_callable() - if name is None and isinstance(result, tf_keras.layers.Layer): +def make_and_track_object( + trackable_factory_callable: Callable[[], base.Trackable], name: Optional[str] = None +) -> base.Trackable: + # pyformat: disable + """Keeps track of the object created by invoking `trackable_factory_callable`. + + This API is only for use when Transform APIs are run with TF2 behaviors + enabled and `tft_beam.Context.force_tf_compat_v1` is set to False. + + Use this API to track TF Trackable objects created in the `preprocessing_fn` + such as tf.hub modules, tf.data.Dataset etc. This ensures they are serialized + correctly when exporting to SavedModel. + + Args: + ---- + trackable_factory_callable: A callable that creates and returns a Trackable + object. + name: (Optional) Provide a unique name to track this object with. If the + Trackable object created is a Keras Layer or Model this is needed for + proper tracking. + + Example: + ------- + >>> def preprocessing_fn(inputs): + ... dataset = tft.make_and_track_object( + ... lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3])) + ... with tf.init_scope(): + ... dataset_list = list(dataset.as_numpy_iterator()) + ... return {'x_0': dataset_list[0] + inputs['x']} + >>> raw_data = [dict(x=1), dict(x=2), dict(x=3)] + >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp(), + ... force_tf_compat_v1=False): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'x_0': 2}, {'x_0': 3}, {'x_0': 4}] + + Returns: + ------- + The object returned when trackable_factory_callable is invoked. The object + creation is lifted out to the eager context using `tf.init_scope`. 
+ """ + # pyformat: enable + if not tf.inside_function(): raise ValueError( - 'Please pass a unique `name` to this API to ensure Keras objects ' - 'are tracked correctly.') - track_object(result, name) - return result + "This API should only be invoked inside the user defined " + "`preprocessing_fn` with TF2 behaviors enabled and " + "`force_tf_compat_v1=False`. " + ) + result = _get_object(name) if name is not None else None + if result is None: + with tf.init_scope(): + result = trackable_factory_callable() + if name is None and isinstance(result, tf_keras.layers.Layer): + raise ValueError( + "Please pass a unique `name` to this API to ensure Keras objects " + "are tracked correctly." + ) + track_object(result, name) + return result def get_asset_annotations(graph: tf.Graph): - """Obtains the asset annotations in the specified graph. - - Args: - graph: A `tf.Graph` object. - - Returns: - A dict that maps asset_keys to asset_filenames. Note that if multiple - entries for the same key exist, later ones will override earlier ones. - """ - asset_key_collection = graph.get_collection(_ASSET_KEY_COLLECTION) - asset_filename_collection = graph.get_collection(_ASSET_FILENAME_COLLECTION) - assert len(asset_key_collection) == len( - asset_filename_collection - ), 'Length of asset key and filename collections must match.' - # Remove scope. - annotations = { - os.path.basename(key): os.path.basename(filename) - for key, filename in zip(asset_key_collection, asset_filename_collection) - } - return annotations + """Obtains the asset annotations in the specified graph. + + Args: + ---- + graph: A `tf.Graph` object. + + Returns: + ------- + A dict that maps asset_keys to asset_filenames. Note that if multiple + entries for the same key exist, later ones will override earlier ones. + """ + asset_key_collection = graph.get_collection(_ASSET_KEY_COLLECTION) + asset_filename_collection = graph.get_collection(_ASSET_FILENAME_COLLECTION) + assert len(asset_key_collection) == len( + asset_filename_collection + ), "Length of asset key and filename collections must match." + # Remove scope. + annotations = { + os.path.basename(key): os.path.basename(filename) + for key, filename in zip(asset_key_collection, asset_filename_collection) + } + return annotations def clear_asset_annotations(graph: tf.Graph): - """Clears the asset annotations. + """Clears the asset annotations. - Args: - graph: A `tf.Graph` object. - """ - graph.clear_collection(_ASSET_KEY_COLLECTION) - graph.clear_collection(_ASSET_FILENAME_COLLECTION) + Args: + ---- + graph: A `tf.Graph` object. + """ + graph.clear_collection(_ASSET_KEY_COLLECTION) + graph.clear_collection(_ASSET_FILENAME_COLLECTION) def annotate_asset(asset_key: str, asset_filename: str): - """Creates mapping between user-defined keys and SavedModel assets. + """Creates mapping between user-defined keys and SavedModel assets. - This mapping is made available in `BeamDatasetMetadata` and is also used to - resolve vocabularies in `tft.TFTransformOutput`. + This mapping is made available in `BeamDatasetMetadata` and is also used to + resolve vocabularies in `tft.TFTransformOutput`. - Note: multiple mappings for the same key will overwrite the previous one. + Note: multiple mappings for the same key will overwrite the previous one. - Args: - asset_key: The key to associate with the asset. - asset_filename: The filename as it appears within the assets/ subdirectory. - Must be sanitized and complete (e.g. include the tfrecord.gz for suffix - appropriate files). 
- """ - tf.compat.v1.add_to_collection(_ASSET_KEY_COLLECTION, asset_key) - tf.compat.v1.add_to_collection(_ASSET_FILENAME_COLLECTION, asset_filename) + Args: + ---- + asset_key: The key to associate with the asset. + asset_filename: The filename as it appears within the assets/ subdirectory. + Must be sanitized and complete (e.g. include the tfrecord.gz for suffix + appropriate files). + """ + tf.compat.v1.add_to_collection(_ASSET_KEY_COLLECTION, asset_key) + tf.compat.v1.add_to_collection(_ASSET_FILENAME_COLLECTION, asset_filename) def annotate_vocab_size(vocab_filename: str, vocab_size: tf.Tensor): - """Adds annotation to retrieve the vocabulary size for `vocab_filename`.""" - tf.compat.v1.add_to_collection(VOCABULARY_SIZE_BY_NAME_COLLECTION, - (vocab_filename, vocab_size)) + """Adds annotation to retrieve the vocabulary size for `vocab_filename`.""" + tf.compat.v1.add_to_collection( + VOCABULARY_SIZE_BY_NAME_COLLECTION, (vocab_filename, vocab_size) + ) diff --git a/tensorflow_transform/annotators_test.py b/tensorflow_transform/annotators_test.py index 7058331..5d01efc 100644 --- a/tensorflow_transform/annotators_test.py +++ b/tensorflow_transform/annotators_test.py @@ -14,56 +14,55 @@ """Tests for tensorflow_transform.annotators.""" import tensorflow as tf -from tensorflow_transform import annotators -from tensorflow_transform import test_case +from tensorflow_transform import annotators, test_case -class AnnotatorsTest(test_case.TransformTestCase): - @test_case.named_parameters( - dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True), - dict(testcase_name='tf2', use_tf_compat_v1=False)) - def test_annotate_asset(self, use_tf_compat_v1): - if not use_tf_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') +class AnnotatorsTest(test_case.TransformTestCase): + @test_case.named_parameters( + dict(testcase_name="tf_compat_v1", use_tf_compat_v1=True), + dict(testcase_name="tf2", use_tf_compat_v1=False), + ) + def test_annotate_asset(self, use_tf_compat_v1): + if not use_tf_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") - def foo(): - annotators.annotate_asset('scope/my_key', 'scope/my_value') - annotators.annotate_asset('my_key2', 'should_be_replaced') - annotators.annotate_asset('my_key2', 'my_value2') + def foo(): + annotators.annotate_asset("scope/my_key", "scope/my_value") + annotators.annotate_asset("my_key2", "should_be_replaced") + annotators.annotate_asset("my_key2", "my_value2") - if use_tf_compat_v1: - with tf.Graph().as_default() as graph: - foo() - else: - graph = tf.function(foo).get_concrete_function().graph + if use_tf_compat_v1: + with tf.Graph().as_default() as graph: + foo() + else: + graph = tf.function(foo).get_concrete_function().graph - self.assertDictEqual( - annotators.get_asset_annotations(graph), { - 'my_key': 'my_value', - 'my_key2': 'my_value2' - }) + self.assertDictEqual( + annotators.get_asset_annotations(graph), + {"my_key": "my_value", "my_key2": "my_value2"}, + ) - annotators.clear_asset_annotations(graph) - self.assertDictEqual(annotators.get_asset_annotations(graph), {}) + annotators.clear_asset_annotations(graph) + self.assertDictEqual(annotators.get_asset_annotations(graph), {}) - def test_object_tracker(self): - test_case.skip_if_not_tf2('Tensorflow 2.x required') + def test_object_tracker(self): + test_case.skip_if_not_tf2("Tensorflow 2.x required") - trackable_object = tf.__internal__.tracking.Trackable() + trackable_object = tf.__internal__.tracking.Trackable() - @tf.function - def preprocessing_fn(): - _ = 
annotators.make_and_track_object(lambda: trackable_object) - return 1 + @tf.function + def preprocessing_fn(): + _ = annotators.make_and_track_object(lambda: trackable_object) + return 1 - object_tracker = annotators.ObjectTracker() - with annotators.object_tracker_scope(object_tracker): - _ = preprocessing_fn() + object_tracker = annotators.ObjectTracker() + with annotators.object_tracker_scope(object_tracker): + _ = preprocessing_fn() - self.assertLen(object_tracker.trackable_objects, 1) - self.assertEqual(trackable_object, object_tracker.trackable_objects[0]) + self.assertLen(object_tracker.trackable_objects, 1) + self.assertEqual(trackable_object, object_tracker.trackable_objects[0]) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/beam/__init__.py b/tensorflow_transform/beam/__init__.py index 0164261..04744cd 100644 --- a/tensorflow_transform/beam/__init__.py +++ b/tensorflow_transform/beam/__init__.py @@ -17,15 +17,15 @@ # The doc-generator's `explicit_package_contents_filter` requires that # sub-modules you want documented are explicitly imported. # Also: analyzer_impls registers implementation of analyzers. -from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform.beam import analyzer_impls -from tensorflow_transform.beam import experimental +from tensorflow_transform.beam import analyzer_cache, analyzer_impls, experimental from tensorflow_transform.beam.context import Context -from tensorflow_transform.beam.impl import AnalyzeAndTransformDataset -from tensorflow_transform.beam.impl import AnalyzeDataset -from tensorflow_transform.beam.impl import AnalyzeDatasetWithCache -from tensorflow_transform.beam.impl import EncodeTransformedDataset -from tensorflow_transform.beam.impl import TransformDataset +from tensorflow_transform.beam.impl import ( + AnalyzeAndTransformDataset, + AnalyzeDataset, + AnalyzeDatasetWithCache, + EncodeTransformedDataset, + TransformDataset, +) from tensorflow_transform.beam.tft_beam_io import * # pylint: enable=wildcard-import @@ -34,6 +34,6 @@ # `tensorflow_io` package. Hence, this import is needed wherever we touch the # filesystem. try: - import tensorflow_io as _ # pytype: disable=import-error # pylint: disable=g-import-not-at-top + import tensorflow_io as _ # pytype: disable=import-error # pylint: disable=g-import-not-at-top except ModuleNotFoundError: - pass + pass diff --git a/tensorflow_transform/beam/analysis_graph_builder.py b/tensorflow_transform/beam/analysis_graph_builder.py index e90bd28..ee2fccd 100644 --- a/tensorflow_transform/beam/analysis_graph_builder.py +++ b/tensorflow_transform/beam/analysis_graph_builder.py @@ -20,17 +20,17 @@ from typing import OrderedDict as OrderedDictType import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import common_types -from tensorflow_transform import graph_tools -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -from tensorflow_transform import tf2_utils -from tensorflow_transform import tf_utils -from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform.beam import beam_nodes -from tensorflow_transform.beam import combiner_packing_util +from tensorflow_transform import ( + analyzer_nodes, + common_types, + graph_tools, + impl_helper, + nodes, + tf2_utils, + tf_utils, +) +from tensorflow_transform.beam import analyzer_cache, beam_nodes, combiner_packing_util # Used for debugging only. 
This will point to the most recent graph built. _ANALYSIS_GRAPH = None @@ -42,382 +42,427 @@ def _tensor_name(tensor): - """Get a name of a tensor without trailing ":0" when relevant.""" - # tensor.name is unicode in Python 3 and bytes in Python 2 so convert to - # bytes here. - name = str(tensor.name) - return name[:-2] if name.endswith(':0') else name + """Get a name of a tensor without trailing ":0" when relevant.""" + # tensor.name is unicode in Python 3 and bytes in Python 2 so convert to + # bytes here. + name = str(tensor.name) + return name[:-2] if name.endswith(":0") else name class _ReadyVisitor(nodes.Visitor): - """Visitor to determine if a node is ready to run.""" + """Visitor to determine if a node is ready to run.""" - def __init__(self, graph_analyzer): - self._graph_analyzer = graph_analyzer - self._visited_operation_def_labels = set() + def __init__(self, graph_analyzer): + self._graph_analyzer = graph_analyzer + self._visited_operation_def_labels = set() - def _validate_operation_label_uniqueness(self, operation_def): - assert operation_def.label not in self._visited_operation_def_labels, ( - f'An operation with label {operation_def.label} ' - 'already exists in the operations graph.') - self._visited_operation_def_labels.add(operation_def.label) + def _validate_operation_label_uniqueness(self, operation_def): + assert operation_def.label not in self._visited_operation_def_labels, ( + f"An operation with label {operation_def.label} " + "already exists in the operations graph." + ) + self._visited_operation_def_labels.add(operation_def.label) - def visit(self, operation_def, input_values): - self._validate_operation_label_uniqueness(operation_def) + def visit(self, operation_def, input_values): + self._validate_operation_label_uniqueness(operation_def) - if isinstance(operation_def, analyzer_nodes.TensorSource): - is_ready = all(self._graph_analyzer.ready_to_run(tensor) - for tensor in operation_def.tensors) - else: - is_ready = all(input_values) - return (is_ready,) * operation_def.num_outputs + if isinstance(operation_def, analyzer_nodes.TensorSource): + is_ready = all( + self._graph_analyzer.ready_to_run(tensor) + for tensor in operation_def.tensors + ) + else: + is_ready = all(input_values) + return (is_ready,) * operation_def.num_outputs - def validate_value(self, value): - assert isinstance(value, bool) + def validate_value(self, value): + assert isinstance(value, bool) class _TranslateVisitor(nodes.Visitor): - """Visitor that translates the operation graph. - - The original graph is defined by the user in the preprocessing_fn. The - translated graph represents a Beam pipeline. - """ - - def __init__(self): - self.phase = None - self.extracted_values_dict = None - self.intermediate_output_signature = {} - - def visit(self, operation_def, input_values): - if isinstance(operation_def, analyzer_nodes.TensorSource): - tensors = operation_def.tensors - label = operation_def.label - # Add tensor to signature so it gets produced by the SavedModel. - for tensor in tensors: - self.intermediate_output_signature[_tensor_name(tensor)] = tensor - keys = tuple(map(_tensor_name, tensors)) - output = nodes.apply_operation( - beam_nodes.ExtractFromDict, self.extracted_values_dict, - keys=keys, label=label) - return (output,) - else: - return nodes.OperationNode(operation_def, input_values).outputs - - def validate_value(self, value): - assert isinstance(value, nodes.ValueNode) + """Visitor that translates the operation graph. 
+ + The original graph is defined by the user in the preprocessing_fn. The + translated graph represents a Beam pipeline. + """ + + def __init__(self): + self.phase = None + self.extracted_values_dict = None + self.intermediate_output_signature = {} + + def visit(self, operation_def, input_values): + if isinstance(operation_def, analyzer_nodes.TensorSource): + tensors = operation_def.tensors + label = operation_def.label + # Add tensor to signature so it gets produced by the SavedModel. + for tensor in tensors: + self.intermediate_output_signature[_tensor_name(tensor)] = tensor + keys = tuple(map(_tensor_name, tensors)) + output = nodes.apply_operation( + beam_nodes.ExtractFromDict, + self.extracted_values_dict, + keys=keys, + label=label, + ) + return (output,) + else: + return nodes.OperationNode(operation_def, input_values).outputs + + def validate_value(self, value): + assert isinstance(value, nodes.ValueNode) @dataclasses.dataclass(frozen=True) class _OptimizationView: - """A container for operation outputs during _OptimizeVisitor traversal. + """A container for operation outputs during _OptimizeVisitor traversal. - This is used in order to maintain both a flattened view, and a fine grained - view that can be used for caching. + This is used in order to maintain both a flattened view, and a fine grained + view that can be used for caching. - `prefer_fine_grained_view` is a hint that means that if True, the - `fine_grained_view` should be used. It should be set to true if the upstream - view has cacheing operations that haven't been flattened yet. - """ - prefer_fine_grained_view: bool - flattened_view: nodes.ValueNode - fine_grained_view: Optional[OrderedDictType[str, nodes.ValueNode]] - hashed_path: Optional[bytes] + `prefer_fine_grained_view` is a hint that means that if True, the + `fine_grained_view` should be used. It should be set to true if the upstream + view has cacheing operations that haven't been flattened yet. + """ + + prefer_fine_grained_view: bool + flattened_view: nodes.ValueNode + fine_grained_view: Optional[OrderedDictType[str, nodes.ValueNode]] + hashed_path: Optional[bytes] - def __post_init__(self): - if self.prefer_fine_grained_view and not self.fine_grained_view: - raise ValueError( - 'Cannot prefer fine_grained_view when one is not provided') + def __post_init__(self): + if self.prefer_fine_grained_view and not self.fine_grained_view: + raise ValueError("Cannot prefer fine_grained_view when one is not provided") class _OptimizeVisitor(nodes.Visitor): - """Visitor optimizes the operation graph (see nodes.py). - - This operates on the translated graph which is emitted by the - `_TranslateVisitor`, and performs optimizations. - - Namely, when enabled, this enables reading and writing from/to analyzer - accumulator cache to avoid recomputing them over already seen datasets. - This type of optimization requires also creating a partitioned view of the - input data, according to the `is_partitionable` annotation. - """ - - def __init__( - self, - dataset_keys: Collection[analyzer_cache.DatasetKey], - cache_dict: Optional[analyzer_cache.BeamAnalysisCache], - tensor_keys_to_paths: Mapping[str, str], - cache_output_nodes: _IntermediateCacheType, - num_phases: int, - ): - """Init method for _OptimizeVisitor. + """Visitor optimizes the operation graph (see nodes.py). - Args: - dataset_keys: An iterable of strings which are keys for a partitioned - dataset. - cache_dict: A dictionary of input cache that can be used in place of a - cacheable accumulate operation. 
A dictionary from dataset_keys to - dictionaries of cache keys to PCollections. This can be None if there is - no cache. - tensor_keys_to_paths: A dictionary from a tensor key to a unique TF graph - path hash. - cache_output_nodes: A dictionary from (dataset_key, cache_key) to encoded - cache ValueNode. This is the output cache for this graph. - num_phases: The number of phases of analysis. + This operates on the translated graph which is emitted by the + `_TranslateVisitor`, and performs optimizations. + + Namely, when enabled, this enables reading and writing from/to analyzer + accumulator cache to avoid recomputing them over already seen datasets. + This type of optimization requires also creating a partitioned view of the + input data, according to the `is_partitionable` annotation. """ - self._sorted_dataset_keys = sorted(dataset_keys) - self._cache_dict = cache_dict - self._tensor_keys_to_paths = tensor_keys_to_paths - self._dataset_has_cache_misses = collections.defaultdict(bool) - self._num_encode_cache_nodes = 0 - self._num_decode_cache_nodes = 0 - self.cache_output_nodes = cache_output_nodes - self._num_phases = num_phases - - def _validate_operation_def(self, operation_def): - if operation_def.cache_coder is not None: - if not operation_def.is_partitionable: - raise ValueError( - 'Non partitionable OperationDefs cannot be cacheable: ' - f'{operation_def.label}' + + def __init__( + self, + dataset_keys: Collection[analyzer_cache.DatasetKey], + cache_dict: Optional[analyzer_cache.BeamAnalysisCache], + tensor_keys_to_paths: Mapping[str, str], + cache_output_nodes: _IntermediateCacheType, + num_phases: int, + ): + """Init method for _OptimizeVisitor. + + Args: + ---- + dataset_keys: An iterable of strings which are keys for a partitioned + dataset. + cache_dict: A dictionary of input cache that can be used in place of a + cacheable accumulate operation. A dictionary from dataset_keys to + dictionaries of cache keys to PCollections. This can be None if there is + no cache. + tensor_keys_to_paths: A dictionary from a tensor key to a unique TF graph + path hash. + cache_output_nodes: A dictionary from (dataset_key, cache_key) to encoded + cache ValueNode. This is the output cache for this graph. + num_phases: The number of phases of analysis. + """ + self._sorted_dataset_keys = sorted(dataset_keys) + self._cache_dict = cache_dict + self._tensor_keys_to_paths = tensor_keys_to_paths + self._dataset_has_cache_misses = collections.defaultdict(bool) + self._num_encode_cache_nodes = 0 + self._num_decode_cache_nodes = 0 + self.cache_output_nodes = cache_output_nodes + self._num_phases = num_phases + + def _validate_operation_def(self, operation_def): + if operation_def.cache_coder is not None: + if not operation_def.is_partitionable: + raise ValueError( + "Non partitionable OperationDefs cannot be cacheable: " + f"{operation_def.label}" + ) + if operation_def.is_partitionable or operation_def.cache_coder is not None: + if operation_def.num_outputs != 1: + raise ValueError( + "Cacheable OperationDefs must have exactly 1 output: " + f"{operation_def.label}" + ) + + def get_detached_sideeffect_leafs(self): + """Returns a list of sideeffect leaf nodes after the visit is done.""" + # If this is a multi-phase analysis, then all datasets have to be read + # anyway, and so we'll not instrument full cache coverage for this case. 
+ if self._num_phases > 1: + return [] + dataset_keys_with_decoded_cache = [] + for dataset_key in self._sorted_dataset_keys: + # Default to True here, if the dataset_key is not in the cache misses map + # then treat it like it does have cache misses because it has not been + # visited in the optimization traversal. + if self._dataset_has_cache_misses.get(dataset_key, True): + continue + # Default to None if the dataset_key isn't present in the cache dict, it + # means that there is not cache present for this dataset, so we should not + # instrument cache for it. + cache_dict = self._cache_dict or {} + dataset_cache_entries = cache_dict.get(dataset_key, None) + if dataset_cache_entries is not None and dataset_cache_entries.metadata: + dataset_keys_with_decoded_cache.append(dataset_key) + if ( + dataset_keys_with_decoded_cache + or self._num_encode_cache_nodes + or self._num_decode_cache_nodes + ): + return [ + nodes.apply_operation( + analyzer_nodes.InstrumentDatasetCache, + input_cache_dataset_keys=dataset_keys_with_decoded_cache, + num_encode_cache=self._num_encode_cache_nodes, + num_decode_cache=self._num_decode_cache_nodes, + label="InstrumentDatasetCache", + ) + ] + return [] + + def _make_next_hashed_path(self, parent_hashed_paths, operation_def): + # Making a copy of parent_hashed_paths. + paths_to_hash = list(parent_hashed_paths) + paths_to_hash.append(tf.compat.as_bytes(operation_def.__class__.__name__)) + + if isinstance(operation_def, beam_nodes.ExtractFromDict): + for key in operation_def.keys: + path = self._tensor_keys_to_paths[key] + paths_to_hash.append(path) + else: + for attr in sorted( + [x for x in dir(operation_def) if x not in operation_def._fields] + ): + if attr.startswith("_") or callable(getattr(operation_def, attr)): + continue + paths_to_hash.append( + tf.compat.as_bytes(str((attr, getattr(operation_def, attr)))) + ) + for field in operation_def._fields: + paths_to_hash.append( + tf.compat.as_bytes(str((field, operation_def.get_field_str(field)))) + ) + + hash_container = hashlib.sha1() + for path in paths_to_hash: + if path is None: + return None + hash_container.update(path) + return hash_container.digest() + + def visit(self, operation_def, input_values): + self._validate_operation_def(operation_def) + + if ( + isinstance(operation_def, beam_nodes.ApplySavedModel) + and operation_def.phase == 0 + ): + return self._visit_apply_savedmodel_operation(operation_def, input_values) + + # When self._cache_dict is None this means that we shouldn't do any cacheing + # for this pipeline, and so there's no need to create any fine grained + # views. + if self._cache_dict is not None and operation_def.is_partitionable: + return self._visit_partitionable_operation(operation_def, input_values) + + if input_values and any( + v.fine_grained_view and v.prefer_fine_grained_view for v in input_values + ): + # We can 'flatten' the cached outputs of the parent operation since this + # operation doesn't support partitioning. + disaggregated_input_values = [] + for view in input_values: + disaggregated_input_values.extend(view.fine_grained_view.values()) + + # Each cache item should be a single ValueNode. + assert all( + isinstance(value, nodes.ValueNode) + for value in disaggregated_input_values + ) + + next_inputs = nodes.apply_multi_output_operation( + beam_nodes.Flatten, + *disaggregated_input_values, + label=f"FlattenCache[{operation_def.label}]", + ) + else: + # Parent operation output is not cacheable, therefore we can just use + # a flattened view. 
+ next_inputs = tuple(v.flattened_view for v in input_values) + + flattened_view = nodes.OperationNode(operation_def, next_inputs).outputs + + return tuple( + _OptimizationView( # pylint: disable=g-complex-comprehension + prefer_fine_grained_view=False, + flattened_view=flat, + fine_grained_view=None, + hashed_path=None, + ) + for flat in flattened_view + ) + + def _visit_partitionable_operation(self, operation_def, upstream_views): + # This is a hint for whether or not the `fine_grained_view` should be used + # downstream. It should be set to true if either the upstream view has + # cacheing operations that haven't been flattened yet, or the current + # operation is cacheable. + all_fine_grained_views_available = all( + v.fine_grained_view for v in upstream_views ) - if operation_def.is_partitionable or operation_def.cache_coder is not None: - if operation_def.num_outputs != 1: - raise ValueError( - 'Cacheable OperationDefs must have exactly 1 output: ' - f'{operation_def.label}' + prefer_fine_grained_view = ( + any(v.prefer_fine_grained_view for v in upstream_views) + or all_fine_grained_views_available + and operation_def.cache_coder is not None ) - def get_detached_sideeffect_leafs(self): - """Returns a list of sideeffect leaf nodes after the visit is done.""" - # If this is a multi-phase analysis, then all datasets have to be read - # anyway, and so we'll not instrument full cache coverage for this case. - if self._num_phases > 1: - return [] - dataset_keys_with_decoded_cache = [] - for dataset_key in self._sorted_dataset_keys: - # Default to True here, if the dataset_key is not in the cache misses map - # then treat it like it does have cache misses because it has not been - # visited in the optimization traversal. - if self._dataset_has_cache_misses.get(dataset_key, True): - continue - # Default to None if the dataset_key isn't present in the cache dict, it - # means that there is not cache present for this dataset, so we should not - # instrument cache for it. - cache_dict = self._cache_dict or {} - dataset_cache_entries = cache_dict.get(dataset_key, None) - if dataset_cache_entries is not None and dataset_cache_entries.metadata: - dataset_keys_with_decoded_cache.append(dataset_key) - if (dataset_keys_with_decoded_cache or self._num_encode_cache_nodes or - self._num_decode_cache_nodes): - return [ - nodes.apply_operation( - analyzer_nodes.InstrumentDatasetCache, - input_cache_dataset_keys=dataset_keys_with_decoded_cache, - num_encode_cache=self._num_encode_cache_nodes, - num_decode_cache=self._num_decode_cache_nodes, - label='InstrumentDatasetCache') - ] - return [] - - def _make_next_hashed_path(self, parent_hashed_paths, operation_def): - # Making a copy of parent_hashed_paths. 
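# Note: in the reformatted `_visit_partitionable_operation` above, `and` binds
# tighter than `or`, so the expression is evaluated as
#
#     prefer_fine_grained_view = (
#         any(v.prefer_fine_grained_view for v in upstream_views)
#         or (all_fine_grained_views_available
#             and operation_def.cache_coder is not None)
#     )
#
# i.e. the fine grained view is preferred when some upstream view already prefers
# it, or when every upstream view has one and the current operation is cacheable.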
- paths_to_hash = list(parent_hashed_paths) - paths_to_hash.append(tf.compat.as_bytes(operation_def.__class__.__name__)) - - if isinstance(operation_def, beam_nodes.ExtractFromDict): - for key in operation_def.keys: - path = self._tensor_keys_to_paths[key] - paths_to_hash.append(path) - else: - for attr in sorted( - [x for x in dir(operation_def) if x not in operation_def._fields]): - if attr.startswith('_') or callable(getattr(operation_def, attr)): - continue - paths_to_hash.append( - tf.compat.as_bytes(str((attr, getattr(operation_def, attr))))) - for field in operation_def._fields: - paths_to_hash.append( - tf.compat.as_bytes( - str((field, operation_def.get_field_str(field))))) - - hash_container = hashlib.sha1() - for path in paths_to_hash: - if path is None: - return None - hash_container.update(path) - return hash_container.digest() - - def visit(self, operation_def, input_values): - self._validate_operation_def(operation_def) - - if (isinstance(operation_def, beam_nodes.ApplySavedModel) and - operation_def.phase == 0): - return self._visit_apply_savedmodel_operation(operation_def, input_values) - - # When self._cache_dict is None this means that we shouldn't do any cacheing - # for this pipeline, and so there's no need to create any fine grained - # views. - if self._cache_dict is not None and operation_def.is_partitionable: - return self._visit_partitionable_operation(operation_def, input_values) - - if input_values and any(v.fine_grained_view and v.prefer_fine_grained_view - for v in input_values): - # We can 'flatten' the cached outputs of the parent operation since this - # operation doesn't support partitioning. - disaggregated_input_values = [] - for view in input_values: - disaggregated_input_values.extend(view.fine_grained_view.values()) - - # Each cache item should be a single ValueNode. - assert all( - isinstance(value, nodes.ValueNode) - for value in disaggregated_input_values - ) - - next_inputs = nodes.apply_multi_output_operation( - beam_nodes.Flatten, - *disaggregated_input_values, - label=f'FlattenCache[{operation_def.label}]') - else: - # Parent operation output is not cacheable, therefore we can just use - # a flattened view. - next_inputs = tuple(v.flattened_view for v in input_values) - - flattened_view = nodes.OperationNode(operation_def, next_inputs).outputs - - return tuple( - _OptimizationView( # pylint: disable=g-complex-comprehension - prefer_fine_grained_view=False, - flattened_view=flat, - fine_grained_view=None, - hashed_path=None) for flat in flattened_view) - - def _visit_partitionable_operation(self, operation_def, upstream_views): - - # This is a hint for whether or not the `fine_grained_view` should be used - # downstream. It should be set to true if either the upstream view has - # cacheing operations that haven't been flattened yet, or the current - # operation is cacheable. 
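# The `_make_next_hashed_path` helper above derives a deterministic fingerprint for
# each operation by feeding the parent fingerprints, the operation's class name and
# its (sorted) attributes into SHA-1. A minimal standalone sketch of that idea,
# using only hypothetical names and values rather than tf.Transform APIs:
import hashlib


def toy_hashed_path(parent_digests, op_class_name, op_fields):
    """Combines parent digests and operation metadata into one stable digest."""
    hasher = hashlib.sha1()
    for digest in parent_digests:
        if digest is None:
            # Mirrors the code above: an un-hashable parent disables caching.
            return None
        hasher.update(digest)
    hasher.update(op_class_name.encode("utf-8"))
    for name, value in sorted(op_fields.items()):
        hasher.update(repr((name, value)).encode("utf-8"))
    return hasher.digest()


# The same inputs always yield the same digest, so cache keys stay stable across runs.
assert toy_hashed_path(
    [b"p"], "CacheableCombineAccumulate", {"label": "Sum[x]"}
) == toy_hashed_path([b"p"], "CacheableCombineAccumulate", {"label": "Sum[x]"})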
- all_fine_grained_views_available = all( - v.fine_grained_view for v in upstream_views) - prefer_fine_grained_view = ( - any(v.prefer_fine_grained_view for v in upstream_views) or - all_fine_grained_views_available and - operation_def.cache_coder is not None) - - next_hashed_path = self._make_next_hashed_path( - [v.hashed_path for v in upstream_views], operation_def) - if all_fine_grained_views_available: - fine_grained_views = (self._apply_operation_on_fine_grained_view( - operation_def, tuple(v.fine_grained_view for v in upstream_views), - next_hashed_path),) - else: - fine_grained_views = (None,) * operation_def.num_outputs - - flattened_views = nodes.OperationNode( - operation_def, tuple(v.flattened_view for v in upstream_views)).outputs - - assert len(fine_grained_views) == len(flattened_views) - return tuple( - _OptimizationView( # pylint: disable=g-complex-comprehension - prefer_fine_grained_view=prefer_fine_grained_view, - flattened_view=flat, - fine_grained_view=fine, - hashed_path=next_hashed_path) - for flat, fine in zip(flattened_views, fine_grained_views)) - - def _apply_operation_on_fine_grained_view(self, operation_def, - fine_grained_views, - next_hashed_path): - """Applies a shardable operation on a fine grained view. - - This also updates `cache_output_nodes` when necessary. + next_hashed_path = self._make_next_hashed_path( + [v.hashed_path for v in upstream_views], operation_def + ) + if all_fine_grained_views_available: + fine_grained_views = ( + self._apply_operation_on_fine_grained_view( + operation_def, + tuple(v.fine_grained_view for v in upstream_views), + next_hashed_path, + ), + ) + else: + fine_grained_views = (None,) * operation_def.num_outputs + + flattened_views = nodes.OperationNode( + operation_def, tuple(v.flattened_view for v in upstream_views) + ).outputs + + assert len(fine_grained_views) == len(flattened_views) + return tuple( + _OptimizationView( # pylint: disable=g-complex-comprehension + prefer_fine_grained_view=prefer_fine_grained_view, + flattened_view=flat, + fine_grained_view=fine, + hashed_path=next_hashed_path, + ) + for flat, fine in zip(flattened_views, fine_grained_views) + ) - Args: - operation_def: A shardable `OperationDef`. - fine_grained_views: A tuple of `_OptimizationView.fine_grained_view`s. - next_hashed_path: The hashed path for the currently processed - operation_def. + def _apply_operation_on_fine_grained_view( + self, operation_def, fine_grained_views, next_hashed_path + ): + """Applies a shardable operation on a fine grained view. - Returns: - The resulting list of `_OptimizationView.fine_grained_view`s. - """ - result_fine_grained_view = collections.OrderedDict() - - cache_entry_key = analyzer_cache.make_cache_entry_key( - tf.compat.as_bytes(operation_def.label) + b'-' + next_hashed_path) - - for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys): - # We use an index for the label in order to make beam labels more stable. 
- infix = f'AnalysisIndex{dataset_idx}' - if ( - operation_def.cache_coder - and self._cache_dict - and self._cache_dict.get(dataset_key, {}).get(cache_entry_key) - is not None - ): - self._dataset_has_cache_misses[dataset_key] |= False - decode_cache = analyzer_nodes.DecodeCache( - dataset_key, - cache_entry_key, - coder=operation_def.cache_coder, - label=f'DecodeCache[{operation_def.label}][{infix}]') - (op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs - self._num_decode_cache_nodes += 1 - else: - value_nodes = tuple(v[dataset_key] for v in fine_grained_views) - (op_output,) = nodes.OperationNode( - operation_def._replace(label=f'{operation_def.label}[{infix}]'), - value_nodes).outputs - if operation_def.cache_coder and dataset_key.is_cached: - self._dataset_has_cache_misses[dataset_key] = True - encode_cache = nodes.apply_operation( - analyzer_nodes.EncodeCache, - op_output, - coder=operation_def.cache_coder, - label=f'EncodeCache[{operation_def.label}][{infix}]') - self.cache_output_nodes[(dataset_key, cache_entry_key)] = encode_cache - self._num_encode_cache_nodes += 1 - result_fine_grained_view[dataset_key] = op_output - - return result_fine_grained_view - - def _visit_apply_savedmodel_operation(self, operation_def, upstream_views): - if any(v.fine_grained_view for v in upstream_views): - raise ValueError( - 'Was not expecting a fine_grained_view input for ApplySavedModel') - (saved_model_path_upstream_view, input_upstream_view) = upstream_views - - fine_grained_view = collections.OrderedDict() - for (dataset_idx, dataset_key) in enumerate(self._sorted_dataset_keys): - infix = f'AnalysisIndex{dataset_idx}' - input_node = nodes.apply_operation( - beam_nodes.ExtractInputForSavedModel, - dataset_key=dataset_key, - label=f'ExtractInputForSavedModel[{infix}]') - # We use an index for the label in order to make beam labels more stable. - (fine_grained_view[dataset_key],) = ( - nodes.OperationNode( - operation_def._replace(label=f'{operation_def.label}[{infix}]'), - (saved_model_path_upstream_view.flattened_view, - input_node)).outputs) - - (flattened_view,) = nodes.OperationNode( - operation_def, (saved_model_path_upstream_view.flattened_view, - input_upstream_view.flattened_view)).outputs - - return (_OptimizationView( - prefer_fine_grained_view=False, - flattened_view=flattened_view, - fine_grained_view=fine_grained_view, - hashed_path=b'APPLY_SAVEDMODEL'),) - - def validate_value(self, value): - assert isinstance(value, _OptimizationView), value - if value.fine_grained_view: - assert set(value.fine_grained_view.keys()) == set( - self._sorted_dataset_keys - ), (f'{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}') + This also updates `cache_output_nodes` when necessary. + + Args: + ---- + operation_def: A shardable `OperationDef`. + fine_grained_views: A tuple of `_OptimizationView.fine_grained_view`s. + next_hashed_path: The hashed path for the currently processed + operation_def. + + Returns: + ------- + The resulting list of `_OptimizationView.fine_grained_view`s. + """ + result_fine_grained_view = collections.OrderedDict() + + cache_entry_key = analyzer_cache.make_cache_entry_key( + tf.compat.as_bytes(operation_def.label) + b"-" + next_hashed_path + ) + + for dataset_idx, dataset_key in enumerate(self._sorted_dataset_keys): + # We use an index for the label in order to make beam labels more stable. 
+ infix = f"AnalysisIndex{dataset_idx}" + if ( + operation_def.cache_coder + and self._cache_dict + and self._cache_dict.get(dataset_key, {}).get(cache_entry_key) + is not None + ): + self._dataset_has_cache_misses[dataset_key] |= False + decode_cache = analyzer_nodes.DecodeCache( + dataset_key, + cache_entry_key, + coder=operation_def.cache_coder, + label=f"DecodeCache[{operation_def.label}][{infix}]", + ) + (op_output,) = nodes.OperationNode(decode_cache, tuple()).outputs + self._num_decode_cache_nodes += 1 + else: + value_nodes = tuple(v[dataset_key] for v in fine_grained_views) + (op_output,) = nodes.OperationNode( + operation_def._replace(label=f"{operation_def.label}[{infix}]"), + value_nodes, + ).outputs + if operation_def.cache_coder and dataset_key.is_cached: + self._dataset_has_cache_misses[dataset_key] = True + encode_cache = nodes.apply_operation( + analyzer_nodes.EncodeCache, + op_output, + coder=operation_def.cache_coder, + label=f"EncodeCache[{operation_def.label}][{infix}]", + ) + self.cache_output_nodes[(dataset_key, cache_entry_key)] = ( + encode_cache + ) + self._num_encode_cache_nodes += 1 + result_fine_grained_view[dataset_key] = op_output + + return result_fine_grained_view + + def _visit_apply_savedmodel_operation(self, operation_def, upstream_views): + if any(v.fine_grained_view for v in upstream_views): + raise ValueError( + "Was not expecting a fine_grained_view input for ApplySavedModel" + ) + (saved_model_path_upstream_view, input_upstream_view) = upstream_views + + fine_grained_view = collections.OrderedDict() + for dataset_idx, dataset_key in enumerate(self._sorted_dataset_keys): + infix = f"AnalysisIndex{dataset_idx}" + input_node = nodes.apply_operation( + beam_nodes.ExtractInputForSavedModel, + dataset_key=dataset_key, + label=f"ExtractInputForSavedModel[{infix}]", + ) + # We use an index for the label in order to make beam labels more stable. 
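# The per-dataset branch above follows a read-through cache pattern: if an encoded
# accumulator already exists for (dataset_key, cache_entry_key) it is decoded
# (DecodeCache); otherwise the analyzer runs on that dataset alone and, when the
# dataset is cacheable, the fresh accumulator is written back (EncodeCache). A
# minimal dictionary-backed sketch of the same control flow (names hypothetical,
# not the Beam implementation):
def read_through(cache, dataset_key, cache_entry_key, compute, is_cached):
    """Returns a cached value if present, otherwise computes and maybe stores it."""
    hit = cache.get((dataset_key, cache_entry_key))
    if hit is not None:
        return hit  # analogous to DecodeCache
    value = compute(dataset_key)  # analogous to running the analyzer on one dataset
    if is_cached:
        cache[(dataset_key, cache_entry_key)] = value  # analogous to EncodeCache
    return value


accumulators = {}
read_through(accumulators, "day1", b"mean-x", lambda key: 42.0, is_cached=True)  # computes
read_through(accumulators, "day1", b"mean-x", lambda key: 42.0, is_cached=True)  # cache hit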
+ (fine_grained_view[dataset_key],) = nodes.OperationNode( + operation_def._replace(label=f"{operation_def.label}[{infix}]"), + (saved_model_path_upstream_view.flattened_view, input_node), + ).outputs + + (flattened_view,) = nodes.OperationNode( + operation_def, + ( + saved_model_path_upstream_view.flattened_view, + input_upstream_view.flattened_view, + ), + ).outputs + + return ( + _OptimizationView( + prefer_fine_grained_view=False, + flattened_view=flattened_view, + fine_grained_view=fine_grained_view, + hashed_path=b"APPLY_SAVEDMODEL", + ), + ) + + def validate_value(self, value): + assert isinstance(value, _OptimizationView), value + if value.fine_grained_view: + assert set(value.fine_grained_view.keys()) == set( + self._sorted_dataset_keys + ), f"{value.fine_grained_view.keys()} != {self._sorted_dataset_keys}" def _perform_cache_optimization( @@ -431,134 +476,144 @@ def _perform_cache_optimization( Optional[_IntermediateCacheType], Collection[nodes.ValueNode], ]: - """Performs cache optimization on the given graph.""" - cache_output_nodes = {} - optimize_visitor = _OptimizeVisitor(dataset_keys or {}, cache_dict, - tensor_keys_to_paths, cache_output_nodes, - num_phases) - optimize_traverser = nodes.Traverser(optimize_visitor) - optimized = optimize_traverser.visit_value_node( - saved_model_future).flattened_view + """Performs cache optimization on the given graph.""" + cache_output_nodes = {} + optimize_visitor = _OptimizeVisitor( + dataset_keys or {}, + cache_dict, + tensor_keys_to_paths, + cache_output_nodes, + num_phases, + ) + optimize_traverser = nodes.Traverser(optimize_visitor) + optimized = optimize_traverser.visit_value_node(saved_model_future).flattened_view + + if cache_dict is None: + assert not cache_output_nodes + cache_output_nodes = None + + return ( + optimized, + cache_output_nodes, + optimize_visitor.get_detached_sideeffect_leafs(), + ) - if cache_dict is None: - assert not cache_output_nodes - cache_output_nodes = None - return (optimized, cache_output_nodes, - optimize_visitor.get_detached_sideeffect_leafs()) +class _InspectVisitor(nodes.Visitor): + """A visitor that inspects the graph and looks for dataset keys in use.""" + + def __init__(self, required_dataset_keys_output): + self._required_dataset_keys = required_dataset_keys_output + + def visit(self, operation_def, input_values): + if isinstance(operation_def, beam_nodes.ExtractInputForSavedModel): + self._required_dataset_keys.add(operation_def.dataset_key) + return nodes.OperationNode(operation_def, input_values).outputs + + def validate_value(self, value): + assert isinstance(value, nodes.ValueNode) + + +def _build_analysis_graph_for_inspection( + preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1 +): + """Builds the analysis graph for inspection.""" + if not force_tf_compat_v1: + assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, + specs, + use_tf_compat_v1=tf2_utils.use_tf_compat_v1(force_tf_compat_v1), + ) + ) + transform_fn_future, cache_dict, _ = build( + graph, + structured_inputs, + structured_outputs, + dataset_keys=dataset_keys, + cache_dict=input_cache, + ) + return transform_fn_future, cache_dict -class _InspectVisitor(nodes.Visitor): - """A visitor that inspects the graph and looks for dataset keys in use.""" - - def __init__(self, required_dataset_keys_output): - self._required_dataset_keys = required_dataset_keys_output - - def visit(self, 
operation_def, input_values): - if isinstance(operation_def, beam_nodes.ExtractInputForSavedModel): - self._required_dataset_keys.add(operation_def.dataset_key) - return nodes.OperationNode(operation_def, input_values).outputs - - def validate_value(self, value): - assert isinstance(value, nodes.ValueNode) - - -def _build_analysis_graph_for_inspection(preprocessing_fn, specs, dataset_keys, - input_cache, force_tf_compat_v1): - """Builds the analysis graph for inspection.""" - if not force_tf_compat_v1: - assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, - specs, - use_tf_compat_v1=tf2_utils.use_tf_compat_v1(force_tf_compat_v1))) - - transform_fn_future, cache_dict, _ = build( - graph, - structured_inputs, - structured_outputs, - dataset_keys=dataset_keys, - cache_dict=input_cache) - return transform_fn_future, cache_dict - - -def get_analysis_dataset_keys(preprocessing_fn, - specs, - dataset_keys, - input_cache, - force_tf_compat_v1): - """Computes the dataset keys that are required in order to perform analysis. - - Args: - preprocessing_fn: A tf.transform preprocessing_fn. - specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is - True, this can also be feature specifications. - dataset_keys: A set of strings which are dataset keys, they uniquely - identify these datasets across analysis runs. - input_cache: A cache dictionary. - force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode. - - Returns: - A set of dataset keys that are required for analysis. - """ - transform_fn_future, _ = _build_analysis_graph_for_inspection( - preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1) - - result = set() - inspect_visitor = _InspectVisitor(result) - inspect_traverser = nodes.Traverser(inspect_visitor) - _ = inspect_traverser.visit_value_node(transform_fn_future) - - # If None is present this means that a flattened version of the entire dataset - # is required, therefore this will be returning all of the given dataset_keys. - if any(k.is_flattened_dataset_key() for k in result): - result = dataset_keys - return result - - -def get_analysis_cache_entry_keys(preprocessing_fn, - specs, - dataset_keys, - force_tf_compat_v1): - """Computes the cache entry keys that would be useful for analysis. - - Args: - preprocessing_fn: A tf.transform preprocessing_fn. - specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is - True, this can also be feature specifications. - dataset_keys: A set of strings which are dataset keys, they uniquely - identify these datasets across analysis runs. - force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode. - - Returns: - A set of cache entry keys which would be useful for analysis. - """ - _, cache_dict = _build_analysis_graph_for_inspection(preprocessing_fn, specs, - dataset_keys, {}, - force_tf_compat_v1) - result = set() - if cache_dict: - for dataset_cache in cache_dict.values(): - result.update(dataset_cache.keys()) - return result - - -AnalysisCache = Mapping[ - analyzer_cache.DatasetKey, Mapping[str, nodes.ValueNode] -] + +def get_analysis_dataset_keys( + preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1 +): + """Computes the dataset keys that are required in order to perform analysis. + + Args: + ---- + preprocessing_fn: A tf.transform preprocessing_fn. + specs: A dict of feature name to tf.TypeSpecs. 
If `force_tf_compat_v1` is + True, this can also be feature specifications. + dataset_keys: A set of strings which are dataset keys, they uniquely + identify these datasets across analysis runs. + input_cache: A cache dictionary. + force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode. + + Returns: + ------- + A set of dataset keys that are required for analysis. + """ + transform_fn_future, _ = _build_analysis_graph_for_inspection( + preprocessing_fn, specs, dataset_keys, input_cache, force_tf_compat_v1 + ) + + result = set() + inspect_visitor = _InspectVisitor(result) + inspect_traverser = nodes.Traverser(inspect_visitor) + _ = inspect_traverser.visit_value_node(transform_fn_future) + + # If None is present this means that a flattened version of the entire dataset + # is required, therefore this will be returning all of the given dataset_keys. + if any(k.is_flattened_dataset_key() for k in result): + result = dataset_keys + return result + + +def get_analysis_cache_entry_keys( + preprocessing_fn, specs, dataset_keys, force_tf_compat_v1 +): + """Computes the cache entry keys that would be useful for analysis. + + Args: + ---- + preprocessing_fn: A tf.transform preprocessing_fn. + specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is + True, this can also be feature specifications. + dataset_keys: A set of strings which are dataset keys, they uniquely + identify these datasets across analysis runs. + force_tf_compat_v1: If `True`, use Tensorflow in compat.v1 mode. + + Returns: + ------- + A set of cache entry keys which would be useful for analysis. + """ + _, cache_dict = _build_analysis_graph_for_inspection( + preprocessing_fn, specs, dataset_keys, {}, force_tf_compat_v1 + ) + result = set() + if cache_dict: + for dataset_cache in cache_dict.values(): + result.update(dataset_cache.keys()) + return result + + +AnalysisCache = Mapping[analyzer_cache.DatasetKey, Mapping[str, nodes.ValueNode]] def _format_output_cache( cache_value_nodes: _IntermediateCacheType, ) -> Optional[AnalysisCache]: - """Triggers dataset cache encoding and composes analysis cache output.""" - if cache_value_nodes is None: - return None - cache_dict = collections.defaultdict(dict) - for (dataset_key, cache_key), value_node in cache_value_nodes.items(): - cache_dict[dataset_key][cache_key] = value_node - return cache_dict + """Triggers dataset cache encoding and composes analysis cache output.""" + if cache_value_nodes is None: + return None + cache_dict = collections.defaultdict(dict) + for (dataset_key, cache_key), value_node in cache_value_nodes.items(): + cache_dict[dataset_key][cache_key] = value_node + return cache_dict def build( @@ -567,168 +622,184 @@ def build( output_signature: Mapping[str, common_types.TensorType], dataset_keys: Optional[Collection[analyzer_cache.DatasetKey]] = None, cache_dict: Optional[analyzer_cache.BeamAnalysisCache] = None, -) -> Tuple[ - nodes.ValueNode, Optional[AnalysisCache], Collection[nodes.ValueNode] -]: - """Returns a list of `Phase`s describing how to execute the pipeline. - - The default graph is assumed to contain some `Analyzer`s which must be - executed by doing a full pass over the dataset, and passing the inputs for - that analyzer into some implementation, then taking the results and replacing - the `Analyzer`s outputs with constants in the graph containing these results. - - The execution plan is described by a list of `Phase`s. 
Each phase contains - a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in - that phase, together with a list of ops, which are the table initializers that - are ready to run in that phase. - - An `Analyzer` or op is ready to run when all its dependencies in the graph - have been computed. Thus if the graph is constructed by - - def preprocessing_fn(input) - x = inputs['x'] - scaled_0 = x - tft.min(x) - scaled_0_1 = scaled_0 / tft.max(scaled_0) - - Then the first phase will contain the analyzer corresponding to the call to - `min`, because `x` is an input and so is ready to compute in the first phase, - while the second phase will contain the analyzer corresponding to the call to - `max` since `scaled_1` depends on the result of the call to `tft.min` which - is computed in the first phase. - - More generally, we define a level for each op and each `Analyzer` by walking - the graph, assigning to each operation the max level of its inputs, to each - `Tensor` the level of its operation, unless it's the output of an `Analyzer` - in which case we assign the level of its `Analyzer` plus one. - - Args: - graph: A `tf.Graph`. - input_signature: A dict whose keys are strings and values are `Tensor`s, - `SparseTensor`s, or `RaggedTensor`s. - output_signature: A dict whose keys are strings and values are `Tensor`s, - `SparseTensor`s, or `RaggedTensor`s. - dataset_keys: (Optional) A set of `DatasetKeys`, which uniquely identify - these datasets across analysis runs. - cache_dict: (Optional): A cache dictionary. - - Returns: - A tuple of: - * A SavedModel future node. - * A dictionary of output cache `ValueNode`s. - * Side affect leaf nodes. - - Raises: - ValueError: if the graph cannot be analyzed. - """ - tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) - graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS) - phase = 0 - tensor_bindings = [] - sink_tensors_ready = { - tf_utils.hashable_tensor_or_op(tensor_sink.tensor): - False for tensor_sink in tensor_sinks - } - translate_visitor = _TranslateVisitor() - translate_traverser = nodes.Traverser(translate_visitor) - - analyzers_input_signature = {} - graph_analyzer = None - - extracted_input_node = nodes.apply_operation( - beam_nodes.ExtractInputForSavedModel, - dataset_key=analyzer_cache._make_flattened_dataset_key(), # pylint: disable=protected-access - label='ExtractInputForSavedModel[FlattenedDataset]') - - while not all(sink_tensors_ready.values()): - infix = f'Phase{phase}' - # Determine which table init ops are ready to run in this phase - # Determine which keys of pending_tensor_replacements are ready to run - # in this phase, based in whether their dependencies are ready. - graph_analyzer = graph_tools.InitializableGraphAnalyzer( - graph, input_signature, list(sink_tensors_ready.items()), - graph_tools.describe_path_as_analyzer_cache_hash) - ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer)) - - # Now create and apply a SavedModel with all tensors in tensor_bindings - # bound, which outputs all the tensors in the required tensor tuples. - intermediate_output_signature = collections.OrderedDict() +) -> Tuple[nodes.ValueNode, Optional[AnalysisCache], Collection[nodes.ValueNode]]: + """Returns a list of `Phase`s describing how to execute the pipeline. 
+
+    The default graph is assumed to contain some `Analyzer`s which must be
+    executed by doing a full pass over the dataset, and passing the inputs for
+    that analyzer into some implementation, then taking the results and replacing
+    the `Analyzer`s outputs with constants in the graph containing these results.
+
+    The execution plan is described by a list of `Phase`s. Each phase contains
+    a list of `Analyzer`s, which are the `Analyzer`s which are ready to run in
+    that phase, together with a list of ops, which are the table initializers that
+    are ready to run in that phase.
+
+    An `Analyzer` or op is ready to run when all its dependencies in the graph
+    have been computed. Thus if the graph is constructed by
+
+    def preprocessing_fn(inputs):
+      x = inputs['x']
+      scaled_0 = x - tft.min(x)
+      scaled_0_1 = scaled_0 / tft.max(scaled_0)
+
+    Then the first phase will contain the analyzer corresponding to the call to
+    `min`, because `x` is an input and so is ready to compute in the first phase,
+    while the second phase will contain the analyzer corresponding to the call to
+    `max` since `scaled_0_1` depends on the result of the call to `tft.min` which
+    is computed in the first phase.
+
+    More generally, we define a level for each op and each `Analyzer` by walking
+    the graph, assigning to each operation the max level of its inputs, to each
+    `Tensor` the level of its operation, unless it's the output of an `Analyzer`
+    in which case we assign the level of its `Analyzer` plus one.
+
+    Args:
+    ----
+    graph: A `tf.Graph`.
+    input_signature: A dict whose keys are strings and values are `Tensor`s,
+        `SparseTensor`s, or `RaggedTensor`s.
+    output_signature: A dict whose keys are strings and values are `Tensor`s,
+        `SparseTensor`s, or `RaggedTensor`s.
+    dataset_keys: (Optional) A set of `DatasetKeys`, which uniquely identify
+        these datasets across analysis runs.
+    cache_dict: (Optional) A cache dictionary.
+
+    Returns:
+    -------
+    A tuple of:
+        * A SavedModel future node.
+        * A dictionary of output cache `ValueNode`s.
+        * Side effect leaf nodes.
+
+    Raises:
+    ------
+    ValueError: if the graph cannot be analyzed.
+    """
+    tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
+    graph.clear_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
+    phase = 0
+    tensor_bindings = []
+    sink_tensors_ready = {
+        tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False
+        for tensor_sink in tensor_sinks
+    }
+    translate_visitor = _TranslateVisitor()
+    translate_traverser = nodes.Traverser(translate_visitor)
+
+    analyzers_input_signature = {}
+    graph_analyzer = None
+
+    extracted_input_node = nodes.apply_operation(
+        beam_nodes.ExtractInputForSavedModel,
+        dataset_key=analyzer_cache._make_flattened_dataset_key(),  # pylint: disable=protected-access
+        label="ExtractInputForSavedModel[FlattenedDataset]",
+    )
+
+    while not all(sink_tensors_ready.values()):
+        infix = f"Phase{phase}"
+        # Determine which table init ops are ready to run in this phase
+        # Determine which keys of pending_tensor_replacements are ready to run
+        # in this phase, based on whether their dependencies are ready.
+        graph_analyzer = graph_tools.InitializableGraphAnalyzer(
+            graph,
+            input_signature,
+            list(sink_tensors_ready.items()),
+            graph_tools.describe_path_as_analyzer_cache_hash,
+        )
+        ready_traverser = nodes.Traverser(_ReadyVisitor(graph_analyzer))
+
+        # Now create and apply a SavedModel with all tensors in tensor_bindings
+        # bound, which outputs all the tensors in the required tensor tuples.
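# The docstring above describes how analyzers are split into phases. Spelled out as
# a runnable tf.Transform preprocessing_fn (a sketch assuming `tensorflow_transform`
# is available as `tft`; it is traced by build(), not executed directly):
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    x = inputs["x"]
    # Phase 1: tft.min(x) depends only on the raw input `x`.
    scaled_0 = x - tft.min(x)
    # Phase 2: tft.max(scaled_0) depends on the result of tft.min(x).
    scaled_0_1 = scaled_0 / tft.max(scaled_0)
    return {"scaled_0_1": scaled_0_1}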
+ intermediate_output_signature = collections.OrderedDict() + saved_model_future = nodes.apply_operation( + beam_nodes.CreateSavedModel, + *tensor_bindings, + table_initializers=tuple(graph_analyzer.ready_table_initializers), + output_signature=intermediate_output_signature, + label=f"CreateSavedModelForAnalyzerInputs[{infix}]", + ) + + extracted_values_dict = nodes.apply_operation( + beam_nodes.ApplySavedModel, + saved_model_future, + extracted_input_node, + phase=phase, + label=f"ApplySavedModel[{infix}]", + ) + + translate_visitor.phase = phase + translate_visitor.intermediate_output_signature = intermediate_output_signature + translate_visitor.extracted_values_dict = extracted_values_dict + for tensor, value_node, is_asset_filepath in tensor_sinks: + hashable_tensor = tf_utils.hashable_tensor_or_op(tensor) + # Don't compute a binding/sink/replacement that's already been computed + if sink_tensors_ready[hashable_tensor]: + continue + + if not ready_traverser.visit_value_node(value_node): + continue + + translated_value_node = translate_traverser.visit_value_node(value_node) + + name = _tensor_name(tensor) + tensor_bindings.append( + nodes.apply_operation( + beam_nodes.CreateTensorBinding, + translated_value_node, + tensor_name=str(tensor.name), + dtype_enum=tensor.dtype.as_datatype_enum, + is_asset_filepath=is_asset_filepath, + label=analyzer_nodes.sanitize_label(f"CreateTensorBinding[{name}]"), + ) + ) + sink_tensors_ready[hashable_tensor] = True + + analyzers_input_signature.update(intermediate_output_signature) + phase += 1 + + # We need to make sure that the representation of this output_signature is + # deterministic. + output_signature = collections.OrderedDict( + sorted(output_signature.items(), key=lambda t: t[0]) + ) + + # TODO(KesterTong): check all table initializers are ready, check all output + # tensors are ready. saved_model_future = nodes.apply_operation( beam_nodes.CreateSavedModel, *tensor_bindings, - table_initializers=tuple(graph_analyzer.ready_table_initializers), - output_signature=intermediate_output_signature, - label=f'CreateSavedModelForAnalyzerInputs[{infix}]') - - extracted_values_dict = nodes.apply_operation( - beam_nodes.ApplySavedModel, - saved_model_future, - extracted_input_node, - phase=phase, - label=f'ApplySavedModel[{infix}]') - - translate_visitor.phase = phase - translate_visitor.intermediate_output_signature = ( - intermediate_output_signature) - translate_visitor.extracted_values_dict = extracted_values_dict - for tensor, value_node, is_asset_filepath in tensor_sinks: - hashable_tensor = tf_utils.hashable_tensor_or_op(tensor) - # Don't compute a binding/sink/replacement that's already been computed - if sink_tensors_ready[hashable_tensor]: - continue - - if not ready_traverser.visit_value_node(value_node): - continue - - translated_value_node = translate_traverser.visit_value_node(value_node) - - name = _tensor_name(tensor) - tensor_bindings.append( - nodes.apply_operation( - beam_nodes.CreateTensorBinding, - translated_value_node, - tensor_name=str(tensor.name), - dtype_enum=tensor.dtype.as_datatype_enum, - is_asset_filepath=is_asset_filepath, - label=analyzer_nodes.sanitize_label( - f'CreateTensorBinding[{name}]'))) - sink_tensors_ready[hashable_tensor] = True - - analyzers_input_signature.update(intermediate_output_signature) - phase += 1 - - # We need to make sure that the representation of this output_signature is - # deterministic. 
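# The while-loop above is a fixed-point "levelization": each pass binds every
# analyzer output whose dependencies are already available and then starts a new
# phase, until all tensor sinks are bound. A toy version over a plain dependency
# dict (hypothetical data, assumed acyclic):
deps = {"min_x": set(), "max_scaled": {"min_x"}}  # analyzer -> prerequisite analyzers
ready = {name: False for name in deps}
phase = 0
while not all(ready.values()):
    newly_ready = [
        name
        for name, prereqs in deps.items()
        if not ready[name] and all(ready[p] for p in prereqs)
    ]
    for name in newly_ready:
        ready[name] = True  # analogous to creating the tensor binding for this sink
    phase += 1
assert phase == 2  # matches the two analysis phases in the docstring example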
- output_signature = collections.OrderedDict( - sorted(output_signature.items(), key=lambda t: t[0])) - - # TODO(KesterTong): check all table initializers are ready, check all output - # tensors are ready. - saved_model_future = nodes.apply_operation( - beam_nodes.CreateSavedModel, - *tensor_bindings, - table_initializers=tuple( - graph.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)), - output_signature=output_signature, - label='CreateSavedModel') - - tensor_keys_to_paths = { - tensor_key: - graph_analyzer.get_unique_path(analyzers_input_signature[tensor_key]) # pytype: disable=attribute-error - for tensor_key in analyzers_input_signature - } - (optimized_saved_model_future, output_cache_value_nodes, - detached_sideeffect_leafs) = _perform_cache_optimization( - saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict, - phase) - - (optimized_saved_model_future, output_cache_value_nodes) = ( - combiner_packing_util.perform_combiner_packing_optimization( - optimized_saved_model_future, output_cache_value_nodes, phase)) - - global _ANALYSIS_GRAPH - _ANALYSIS_GRAPH = optimized_saved_model_future - return ( - optimized_saved_model_future, - _format_output_cache(output_cache_value_nodes), - detached_sideeffect_leafs, - ) + table_initializers=tuple( + graph.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) + ), + output_signature=output_signature, + label="CreateSavedModel", + ) + + tensor_keys_to_paths = { + tensor_key: graph_analyzer.get_unique_path( + analyzers_input_signature[tensor_key] + ) # pytype: disable=attribute-error + for tensor_key in analyzers_input_signature + } + ( + optimized_saved_model_future, + output_cache_value_nodes, + detached_sideeffect_leafs, + ) = _perform_cache_optimization( + saved_model_future, dataset_keys, tensor_keys_to_paths, cache_dict, phase + ) + + (optimized_saved_model_future, output_cache_value_nodes) = ( + combiner_packing_util.perform_combiner_packing_optimization( + optimized_saved_model_future, output_cache_value_nodes, phase + ) + ) + + global _ANALYSIS_GRAPH + _ANALYSIS_GRAPH = optimized_saved_model_future + return ( + optimized_saved_model_future, + _format_output_cache(output_cache_value_nodes), + detached_sideeffect_leafs, + ) diff --git a/tensorflow_transform/beam/analysis_graph_builder_test.py b/tensorflow_transform/beam/analysis_graph_builder_test.py index b1aca33..7f3e831 100644 --- a/tensorflow_transform/beam/analysis_graph_builder_test.py +++ b/tensorflow_transform/beam/analysis_graph_builder_test.py @@ -16,30 +16,27 @@ import os import tensorflow as tf -import tensorflow_transform as tft -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -from tensorflow_transform import tf2_utils -from tensorflow_transform.beam import analysis_graph_builder -from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform.beam import tft_unit + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. 
from tfx_bsl.types import tfx_namedtuple +import tensorflow_transform as tft +from tensorflow_transform import analyzer_nodes, impl_helper, nodes, tf2_utils +from tensorflow_transform.beam import analysis_graph_builder, analyzer_cache, tft_unit + mock = tf.compat.v1.test.mock def _preprocessing_fn_with_no_analyzers(inputs): - x = inputs['x'] - x_plus_1 = x + 1 - return {'x_plus_1': x_plus_1} + x = inputs["x"] + x_plus_1 = x + 1 + return {"x_plus_1": x_plus_1} _NO_ANALYZERS_CASE = dict( - testcase_name='with_no_analyzers', - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}, + testcase_name="with_no_analyzers", + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, preprocessing_fn=_preprocessing_fn_with_no_analyzers, expected_dot_graph_str=r"""digraph G { directed=True; @@ -52,24 +49,24 @@ def _preprocessing_fn_with_no_analyzers(inputs): node [shape=Mrecord]; CreateSavedModel [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_plus_1', \"Tensor\\>\")])|label: CreateSavedModel}"]; } -""") +""", +) def _preprocessing_fn_with_one_analyzer(inputs): + @tf.function + def _plus_one(x): + return x + 1 - @tf.function - def _plus_one(x): - return x + 1 - - x = _plus_one(inputs['x']) - x_mean = tft.mean(x, name='x') - x_centered = x - x_mean - return {'x_centered': x_centered} + x = _plus_one(inputs["x"]) + x_mean = tft.mean(x, name="x") + x_centered = x - x_mean + return {"x_centered": x_centered} _ONE_ANALYZER_CASE = dict( - testcase_name='with_one_analyzer', - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}, + testcase_name="with_one_analyzer", + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, preprocessing_fn=_preprocessing_fn_with_one_analyzer, expected_dot_graph_str=r"""digraph G { directed=True; @@ -120,26 +117,28 @@ def _plus_one(x): "CreateTensorBinding[x#mean_and_var#temporary_analyzer_output#PlaceholderWithDefault]" -> CreateSavedModel; "CreateTensorBinding[x#mean_and_var#temporary_analyzer_output_1#PlaceholderWithDefault]" -> CreateSavedModel; } -""") +""", +) def _preprocessing_fn_with_table(inputs): - x = inputs['x'] - x_vocab = tft.vocabulary(x, name='x') - initializer = tf.lookup.TextFileInitializer( - x_vocab, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - x_integerized = table.lookup(x) - return {'x_integerized': x_integerized} + x = inputs["x"] + x_vocab = tft.vocabulary(x, name="x") + initializer = tf.lookup.TextFileInitializer( + x_vocab, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + x_integerized = table.lookup(x) + return {"x_integerized": x_integerized} _WITH_TABLE_CASE = dict( - testcase_name='with_table', - feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}, + testcase_name="with_table", + feature_spec={"x": tf.io.FixedLenFeature([], tf.string)}, preprocessing_fn=_preprocessing_fn_with_table, expected_dot_graph_str=r"""digraph G { directed=True; @@ -208,21 +207,22 @@ def _preprocessing_fn_with_table(inputs): "CreateTensorBinding[x#temporary_analyzer_output_1#vocab_x_pruned_vocab_size]" -> CreateSavedModel; "CreateTensorBinding[x#temporary_analyzer_output_2#Const]" -> CreateSavedModel; } -""") +""", +) def _preprocessing_fn_with_two_phases(inputs): - x = inputs['x'] - 
x_mean = tft.mean(x, name='x') - x_square_deviations = tf.square(x - x_mean) - x_var = tft.mean(x_square_deviations, name='x_square_deviations') - x_normalized = (x - x_mean) / tf.sqrt(x_var) - return {'x_normalized': x_normalized} + x = inputs["x"] + x_mean = tft.mean(x, name="x") + x_square_deviations = tf.square(x - x_mean) + x_var = tft.mean(x_square_deviations, name="x_square_deviations") + x_normalized = (x - x_mean) / tf.sqrt(x_var) + return {"x_normalized": x_normalized} _TWO_PHASES_CASE = dict( - testcase_name='with_two_phases', - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}, + testcase_name="with_two_phases", + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, preprocessing_fn=_preprocessing_fn_with_two_phases, expected_dot_graph_str=r"""digraph G { directed=True; @@ -313,38 +313,40 @@ def _preprocessing_fn_with_two_phases(inputs): "CreateTensorBinding[x_square_deviations#mean_and_var#temporary_analyzer_output#PlaceholderWithDefault]" -> CreateSavedModel; "CreateTensorBinding[x_square_deviations#mean_and_var#temporary_analyzer_output_1#PlaceholderWithDefault]" -> CreateSavedModel; } -""") +""", +) def _preprocessing_fn_with_chained_ptransforms(inputs): - - class FakeChainable( - tfx_namedtuple.namedtuple('FakeChainable', ['label']), - nodes.OperationDef): - - def __new__(cls): - scope = tf.compat.v1.get_default_graph().get_name_scope() - label = '{}[{}]'.format(cls.__name__, scope) - return super(FakeChainable, cls).__new__(cls, label=label) - - with tf.compat.v1.name_scope('x'): - input_values_node = nodes.apply_operation( - analyzer_nodes.TensorSource, tensors=[inputs['x']]) - with tf.compat.v1.name_scope('ptransform1'): - intermediate_value_node = nodes.apply_operation(FakeChainable, - input_values_node) - with tf.compat.v1.name_scope('ptransform2'): - output_value_node = nodes.apply_operation(FakeChainable, - intermediate_value_node) - x_chained = analyzer_nodes.bind_future_as_tensor( - output_value_node, analyzer_nodes.TensorInfo(tf.float32, (17, 27), - None)) - return {'x_chained': x_chained} + class FakeChainable( + tfx_namedtuple.namedtuple("FakeChainable", ["label"]), nodes.OperationDef + ): + def __new__(cls): + scope = tf.compat.v1.get_default_graph().get_name_scope() + label = f"{cls.__name__}[{scope}]" + return super(FakeChainable, cls).__new__(cls, label=label) + + with tf.compat.v1.name_scope("x"): + input_values_node = nodes.apply_operation( + analyzer_nodes.TensorSource, tensors=[inputs["x"]] + ) + with tf.compat.v1.name_scope("ptransform1"): + intermediate_value_node = nodes.apply_operation( + FakeChainable, input_values_node + ) + with tf.compat.v1.name_scope("ptransform2"): + output_value_node = nodes.apply_operation( + FakeChainable, intermediate_value_node + ) + x_chained = analyzer_nodes.bind_future_as_tensor( + output_value_node, analyzer_nodes.TensorInfo(tf.float32, (17, 27), None) + ) + return {"x_chained": x_chained} _CHAINED_PTRANSFORMS_CASE = dict( - testcase_name='with_chained_ptransforms', - feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}, + testcase_name="with_chained_ptransforms", + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, preprocessing_fn=_preprocessing_fn_with_chained_ptransforms, expected_dot_graph_str=r"""digraph G { directed=True; @@ -385,7 +387,8 @@ def __new__(cls): CreateSavedModel [label="{CreateSavedModel|table_initializers: 0|output_signature: OrderedDict([('x_chained', \"Tensor\\>\")])|label: CreateSavedModel}"]; "CreateTensorBinding[x#temporary_analyzer_output#PlaceholderWithDefault]" -> 
CreateSavedModel; } -""") +""", +) _ANALYZE_TEST_CASES = [ _NO_ANALYZERS_CASE, @@ -397,187 +400,214 @@ def __new__(cls): class AnalysisGraphBuilderTest(tft_unit.TransformTestCase): - - @tft_unit.named_parameters( - *tft_unit.cross_named_parameters( - _ANALYZE_TEST_CASES, - [ - dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True), - dict(testcase_name='tf2', use_tf_compat_v1=False), - ], - ) - ) - def test_build(self, feature_spec, preprocessing_fn, expected_dot_graph_str, - expected_dot_graph_str_tf2, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required') - specs = ( - feature_spec if use_tf_compat_v1 else - impl_helper.get_type_specs_from_feature_specs(feature_spec)) - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, - specs, - use_tf_compat_v1=use_tf_compat_v1, - base_temp_dir=os.path.join(self.get_temp_dir(), - self._testMethodName))) - (transform_fn_future, unused_cache, - unused_sideeffects) = analysis_graph_builder.build(graph, - structured_inputs, - structured_outputs) - - dot_string = nodes.get_dot_graph([transform_fn_future]).to_string() - self.WriteRenderedDotFile(dot_string) - self.assertMultiLineEqual( - msg='Result dot graph is:\n{}'.format(dot_string), - first=dot_string, - second=(expected_dot_graph_str - if use_tf_compat_v1 else expected_dot_graph_str_tf2)) - - @tft_unit.named_parameters( - *tft_unit.cross_named_parameters( - [ - dict( - testcase_name='one_dataset_cached_single_phase', - preprocessing_fn=_preprocessing_fn_with_one_analyzer, - full_dataset_keys=['a', 'b'], - cached_dataset_keys=['a'], - expected_dataset_keys=['b'], - ), - dict( - testcase_name='all_datasets_cached_single_phase', - preprocessing_fn=_preprocessing_fn_with_one_analyzer, - full_dataset_keys=['a', 'b'], - cached_dataset_keys=['a', 'b'], - expected_dataset_keys=[], - ), - dict( - testcase_name='mixed_single_phase', - preprocessing_fn=lambda d: dict( # pylint: disable=g-long-lambda - list( - _preprocessing_fn_with_chained_ptransforms(d).items() - ) - + list(_preprocessing_fn_with_one_analyzer(d).items()) - ), - full_dataset_keys=['a', 'b'], - cached_dataset_keys=['a', 'b'], - expected_dataset_keys=['a', 'b'], - ), - dict( - testcase_name='multi_phase', - preprocessing_fn=_preprocessing_fn_with_two_phases, - full_dataset_keys=['a', 'b'], - cached_dataset_keys=['a', 'b'], - expected_dataset_keys=['a', 'b'], - ), - ], - [ - dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True), - dict(testcase_name='tf2', use_tf_compat_v1=False), - ], - ) - ) - def test_get_analysis_dataset_keys(self, preprocessing_fn, full_dataset_keys, - cached_dataset_keys, expected_dataset_keys, - use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required') - full_dataset_keys = list( - map(analyzer_cache.DatasetKey, full_dataset_keys)) - cached_dataset_keys = map(analyzer_cache.DatasetKey, cached_dataset_keys) - expected_dataset_keys = map( - analyzer_cache.DatasetKey, expected_dataset_keys) - # We force all dataset keys with entries in the cache dict will have a cache - # hit. 
- mocked_cache_entry_key = b'M' - input_cache = { - key: analyzer_cache.DatasetCache({mocked_cache_entry_key: 'C'}, None) - for key in cached_dataset_keys - } - feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - specs = ( - feature_spec if use_tf_compat_v1 else - impl_helper.get_type_specs_from_feature_specs(feature_spec)) - with mock.patch( - 'tensorflow_transform.beam.analysis_graph_builder.' - 'analyzer_cache.make_cache_entry_key', - return_value=mocked_cache_entry_key): - dataset_keys = ( - analysis_graph_builder.get_analysis_dataset_keys( - preprocessing_fn, - specs, - full_dataset_keys, - input_cache, - force_tf_compat_v1=use_tf_compat_v1)) - self.DebugPublishLatestsRenderedTFTGraph() - self.assertCountEqual(expected_dataset_keys, dataset_keys) - - @tft_unit.named_parameters( - dict(testcase_name='tf_compat_v1', use_tf_compat_v1=True), - dict(testcase_name='tf2', use_tf_compat_v1=False), - ) - def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required') - full_dataset_keys = map(analyzer_cache.DatasetKey, ['a', 'b']) - def preprocessing_fn(inputs): - return {'x': tft.scale_to_0_1(inputs['x'])} - mocked_cache_entry_key = 'A' - def mocked_make_cache_entry_key(_): - return mocked_cache_entry_key - feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - specs = ( - feature_spec if use_tf_compat_v1 else - impl_helper.get_type_specs_from_feature_specs(feature_spec)) - with mock.patch( - 'tensorflow_transform.beam.analysis_graph_builder.' - 'analyzer_cache.make_cache_entry_key', - side_effect=mocked_make_cache_entry_key): - cache_entry_keys = ( - analysis_graph_builder.get_analysis_cache_entry_keys( - preprocessing_fn, - specs, - full_dataset_keys, - force_tf_compat_v1=use_tf_compat_v1)) - self.DebugPublishLatestsRenderedTFTGraph() - self.assertCountEqual(cache_entry_keys, [mocked_cache_entry_key]) - - def test_duplicate_label_error(self): - - def _preprocessing_fn(inputs): - - class _Analyzer( - tfx_namedtuple.namedtuple('_Analyzer', ['label']), - nodes.OperationDef): - pass - - input_values_node = nodes.apply_operation( - analyzer_nodes.TensorSource, tensors=[inputs['x']]) - intermediate_value_node = nodes.apply_operation( - _Analyzer, input_values_node, label='SameLabel') - output_value_node = nodes.apply_operation( - _Analyzer, intermediate_value_node, label='SameLabel') - x_chained = analyzer_nodes.bind_future_as_tensor( - output_value_node, - analyzer_nodes.TensorInfo(tf.float32, (17, 27), None)) - return {'x_chained': x_chained} - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(False) - specs = ( - feature_spec if use_tf_compat_v1 else - impl_helper.get_type_specs_from_feature_specs(feature_spec)) - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - _preprocessing_fn, - specs, - use_tf_compat_v1=use_tf_compat_v1, - base_temp_dir=os.path.join(self.get_temp_dir(), - self._testMethodName))) - with self.assertRaisesRegex(AssertionError, 'SameLabel'): - _ = analysis_graph_builder.build(graph, structured_inputs, - structured_outputs) - - -if __name__ == '__main__': - tft_unit.main() + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + _ANALYZE_TEST_CASES, + [ + dict(testcase_name="tf_compat_v1", use_tf_compat_v1=True), + dict(testcase_name="tf2", use_tf_compat_v1=False), + ], + ) + ) + def test_build( + self, + feature_spec, + preprocessing_fn, + expected_dot_graph_str, + 
expected_dot_graph_str_tf2, + use_tf_compat_v1, + ): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required") + specs = ( + feature_spec + if use_tf_compat_v1 + else impl_helper.get_type_specs_from_feature_specs(feature_spec) + ) + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, + specs, + use_tf_compat_v1=use_tf_compat_v1, + base_temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName), + ) + ) + (transform_fn_future, unused_cache, unused_sideeffects) = ( + analysis_graph_builder.build(graph, structured_inputs, structured_outputs) + ) + + dot_string = nodes.get_dot_graph([transform_fn_future]).to_string() + self.WriteRenderedDotFile(dot_string) + self.assertMultiLineEqual( + msg=f"Result dot graph is:\n{dot_string}", + first=dot_string, + second=( + expected_dot_graph_str + if use_tf_compat_v1 + else expected_dot_graph_str_tf2 + ), + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict( + testcase_name="one_dataset_cached_single_phase", + preprocessing_fn=_preprocessing_fn_with_one_analyzer, + full_dataset_keys=["a", "b"], + cached_dataset_keys=["a"], + expected_dataset_keys=["b"], + ), + dict( + testcase_name="all_datasets_cached_single_phase", + preprocessing_fn=_preprocessing_fn_with_one_analyzer, + full_dataset_keys=["a", "b"], + cached_dataset_keys=["a", "b"], + expected_dataset_keys=[], + ), + dict( + testcase_name="mixed_single_phase", + preprocessing_fn=lambda d: dict( # pylint: disable=g-long-lambda + list(_preprocessing_fn_with_chained_ptransforms(d).items()) + + list(_preprocessing_fn_with_one_analyzer(d).items()) + ), + full_dataset_keys=["a", "b"], + cached_dataset_keys=["a", "b"], + expected_dataset_keys=["a", "b"], + ), + dict( + testcase_name="multi_phase", + preprocessing_fn=_preprocessing_fn_with_two_phases, + full_dataset_keys=["a", "b"], + cached_dataset_keys=["a", "b"], + expected_dataset_keys=["a", "b"], + ), + ], + [ + dict(testcase_name="tf_compat_v1", use_tf_compat_v1=True), + dict(testcase_name="tf2", use_tf_compat_v1=False), + ], + ) + ) + def test_get_analysis_dataset_keys( + self, + preprocessing_fn, + full_dataset_keys, + cached_dataset_keys, + expected_dataset_keys, + use_tf_compat_v1, + ): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required") + full_dataset_keys = list(map(analyzer_cache.DatasetKey, full_dataset_keys)) + cached_dataset_keys = map(analyzer_cache.DatasetKey, cached_dataset_keys) + expected_dataset_keys = map(analyzer_cache.DatasetKey, expected_dataset_keys) + # We force all dataset keys with entries in the cache dict will have a cache + # hit. + mocked_cache_entry_key = b"M" + input_cache = { + key: analyzer_cache.DatasetCache({mocked_cache_entry_key: "C"}, None) + for key in cached_dataset_keys + } + feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)} + specs = ( + feature_spec + if use_tf_compat_v1 + else impl_helper.get_type_specs_from_feature_specs(feature_spec) + ) + with mock.patch( + "tensorflow_transform.beam.analysis_graph_builder." 
+ "analyzer_cache.make_cache_entry_key", + return_value=mocked_cache_entry_key, + ): + dataset_keys = analysis_graph_builder.get_analysis_dataset_keys( + preprocessing_fn, + specs, + full_dataset_keys, + input_cache, + force_tf_compat_v1=use_tf_compat_v1, + ) + self.DebugPublishLatestsRenderedTFTGraph() + self.assertCountEqual(expected_dataset_keys, dataset_keys) + + @tft_unit.named_parameters( + dict(testcase_name="tf_compat_v1", use_tf_compat_v1=True), + dict(testcase_name="tf2", use_tf_compat_v1=False), + ) + def test_get_analysis_cache_entry_keys(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required") + full_dataset_keys = map(analyzer_cache.DatasetKey, ["a", "b"]) + + def preprocessing_fn(inputs): + return {"x": tft.scale_to_0_1(inputs["x"])} + + mocked_cache_entry_key = "A" + + def mocked_make_cache_entry_key(_): + return mocked_cache_entry_key + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)} + specs = ( + feature_spec + if use_tf_compat_v1 + else impl_helper.get_type_specs_from_feature_specs(feature_spec) + ) + with mock.patch( + "tensorflow_transform.beam.analysis_graph_builder." + "analyzer_cache.make_cache_entry_key", + side_effect=mocked_make_cache_entry_key, + ): + cache_entry_keys = analysis_graph_builder.get_analysis_cache_entry_keys( + preprocessing_fn, + specs, + full_dataset_keys, + force_tf_compat_v1=use_tf_compat_v1, + ) + self.DebugPublishLatestsRenderedTFTGraph() + self.assertCountEqual(cache_entry_keys, [mocked_cache_entry_key]) + + def test_duplicate_label_error(self): + def _preprocessing_fn(inputs): + class _Analyzer( + tfx_namedtuple.namedtuple("_Analyzer", ["label"]), nodes.OperationDef + ): + pass + + input_values_node = nodes.apply_operation( + analyzer_nodes.TensorSource, tensors=[inputs["x"]] + ) + intermediate_value_node = nodes.apply_operation( + _Analyzer, input_values_node, label="SameLabel" + ) + output_value_node = nodes.apply_operation( + _Analyzer, intermediate_value_node, label="SameLabel" + ) + x_chained = analyzer_nodes.bind_future_as_tensor( + output_value_node, analyzer_nodes.TensorInfo(tf.float32, (17, 27), None) + ) + return {"x_chained": x_chained} + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)} + use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(False) + specs = ( + feature_spec + if use_tf_compat_v1 + else impl_helper.get_type_specs_from_feature_specs(feature_spec) + ) + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + _preprocessing_fn, + specs, + use_tf_compat_v1=use_tf_compat_v1, + base_temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName), + ) + ) + with self.assertRaisesRegex(AssertionError, "SameLabel"): + _ = analysis_graph_builder.build( + graph, structured_inputs, structured_outputs + ) + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/analyzer_cache.py b/tensorflow_transform/beam/analyzer_cache.py index 6e31192..232726f 100644 --- a/tensorflow_transform/beam/analyzer_cache.py +++ b/tensorflow_transform/beam/analyzer_cache.py @@ -17,10 +17,11 @@ import pickle import re import sys -from typing import Iterable, Mapping, List, Optional, Union, Tuple +from typing import Iterable, List, Mapping, Optional, Tuple, Union import apache_beam as beam import tensorflow as tf + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. 
from tfx_bsl.types import tfx_namedtuple @@ -28,388 +29,413 @@ # This should be advanced whenever a non-backwards compatible change is made # that affects analyzer cache. For example, changing accumulator format. _CACHE_VERSION_NUMBER = 1 -_PYTHON_VERSION = f'{sys.version_info.major}.{sys.version_info.minor}' -_CACHE_VERSION = tf.compat.as_bytes( - f'__v{_CACHE_VERSION_NUMBER}__{_PYTHON_VERSION}_') +_PYTHON_VERSION = f"{sys.version_info.major}.{sys.version_info.minor}" +_CACHE_VERSION = tf.compat.as_bytes(f"__v{_CACHE_VERSION_NUMBER}__{_PYTHON_VERSION}_") -_METADATA_FILE_NAME = 'METADATA' +_METADATA_FILE_NAME = "METADATA" _CACHE_COMPONENT_CHARACTER_REPLACEMENTS = ( - ('/', '-'), - ('\\', '-'), - ('*', 'STAR'), - ('@', 'AT'), - ('[', '--'), - (']', '--'), - (':', 'P'), - ('=', 'E'), + ("/", "-"), + ("\\", "-"), + ("*", "STAR"), + ("@", "AT"), + ("[", "--"), + ("]", "--"), + (":", "P"), + ("=", "E"), ) def _make_valid_cache_component(name: str) -> str: - result = name - for unsupported_char, replacement in _CACHE_COMPONENT_CHARACTER_REPLACEMENTS: - result = result.replace(unsupported_char, replacement) - return result + result = name + for unsupported_char, replacement in _CACHE_COMPONENT_CHARACTER_REPLACEMENTS: + result = result.replace(unsupported_char, replacement) + return result + +class DatasetKey(tfx_namedtuple.namedtuple("DatasetKey", ["key", "is_cached"])): + """A key for a dataset used for analysis.""" -class DatasetKey(tfx_namedtuple.namedtuple('DatasetKey', ['key', 'is_cached'])): - """A key for a dataset used for analysis.""" - __slots__ = () - _FLATTENED_DATASET_KEY = object() + __slots__ = () + _FLATTENED_DATASET_KEY = object() - def non_cacheable(self) -> 'DatasetKey': - """Creates a non cacheable dataset key, for which no cache will be produced.""" - return self._replace(key=f'uncached_{self.key}', is_cached=False) + def non_cacheable(self) -> "DatasetKey": + """Creates a non cacheable dataset key, for which no cache will be produced.""" + return self._replace(key=f"uncached_{self.key}", is_cached=False) - def __new__( - cls, dataset_key: Union[str, object], is_cached: bool = True - ) -> 'DatasetKey': - if dataset_key is not DatasetKey._FLATTENED_DATASET_KEY: - if not isinstance(dataset_key, str): - raise ValueError( - f'User provided dataset_key must be a str. Got: {dataset_key}') - dataset_key = _make_valid_cache_component(dataset_key) - return super().__new__(cls, key=dataset_key, is_cached=is_cached) + def __new__( + cls, dataset_key: Union[str, object], is_cached: bool = True + ) -> "DatasetKey": + if dataset_key is not DatasetKey._FLATTENED_DATASET_KEY: + if not isinstance(dataset_key, str): + raise ValueError( + f"User provided dataset_key must be a str. 
Got: {dataset_key}" + ) + dataset_key = _make_valid_cache_component(dataset_key) + return super().__new__(cls, key=dataset_key, is_cached=is_cached) - def __str__(self): - if self.is_flattened_dataset_key(): - return str(DatasetKey('FlattenedDataset')) - else: - return super().__str__() + def __str__(self): + if self.is_flattened_dataset_key(): + return str(DatasetKey("FlattenedDataset")) + else: + return super().__str__() - def is_flattened_dataset_key(self) -> bool: - return self.key == self._FLATTENED_DATASET_KEY + def is_flattened_dataset_key(self) -> bool: + return self.key == self._FLATTENED_DATASET_KEY def _make_flattened_dataset_key() -> DatasetKey: - return DatasetKey(DatasetKey._FLATTENED_DATASET_KEY, is_cached=False) # pylint: disable=protected-access + return DatasetKey(DatasetKey._FLATTENED_DATASET_KEY, is_cached=False) # pylint: disable=protected-access def _get_dataset_cache_path(base_dir: str, dataset_key: DatasetKey) -> str: - return os.path.join(base_dir, dataset_key.key) + return os.path.join(base_dir, dataset_key.key) class DatasetCacheMetadata( - tfx_namedtuple.TypedNamedTuple('DatasetCacheMetadata', - [('dataset_size', int)])): - """Metadata about a cached dataset.""" + tfx_namedtuple.TypedNamedTuple("DatasetCacheMetadata", [("dataset_size", int)]) +): + """Metadata about a cached dataset.""" - __slots__ = () + __slots__ = () - def encode(self) -> bytes: - return pickle.dumps(self._asdict(), protocol=0) + def encode(self) -> bytes: + return pickle.dumps(self._asdict(), protocol=0) - @classmethod - def decode(cls, value: bytes) -> 'DatasetCacheMetadata': - return cls(**pickle.loads(value)) + @classmethod + def decode(cls, value: bytes) -> "DatasetCacheMetadata": + return cls(**pickle.loads(value)) class DatasetCache( tfx_namedtuple.TypedNamedTuple( - 'DatasetCache', - [('cache_dict', Mapping[str, beam.PCollection[bytes]]), - ('metadata', Optional[Union[beam.PCollection[DatasetCacheMetadata], - DatasetCacheMetadata]])])): - """Complete cache for a dataset as well as metadata.""" - __slots__ = () + "DatasetCache", + [ + ("cache_dict", Mapping[str, beam.PCollection[bytes]]), + ( + "metadata", + Optional[ + Union[beam.PCollection[DatasetCacheMetadata], DatasetCacheMetadata] + ], + ), + ], + ) +): + """Complete cache for a dataset as well as metadata.""" - def get(self, key): - return self.cache_dict.get(key) + __slots__ = () - def values(self): - return self.cache_dict.values() + def get(self, key): + return self.cache_dict.get(key) - def keys(self): - return self.cache_dict.keys() + def values(self): + return self.cache_dict.values() - def items(self): - return self.cache_dict.items() + def keys(self): + return self.cache_dict.keys() + + def items(self): + return self.cache_dict.items() BeamAnalysisCache = Mapping[DatasetKey, DatasetCache] class _ManifestFile: - """A manifest file wrapper used to read and write tft cache manifest files.""" - - # TODO(b/37788560): Use artifacts instead. 
- _MANIFEST_FILE_NAME = 'MANIFEST' - - def __init__(self, base_path: str): - self._base_path = base_path - self._manifest_path = os.path.join(base_path, self._MANIFEST_FILE_NAME) - self._file = None - - def _open(self): - assert self._file is None - if not tf.io.gfile.isdir(self._base_path): - tf.io.gfile.makedirs(self._base_path) - self._file = tf.io.gfile.GFile(self._manifest_path, 'wb+') - - def _close(self): - if self._file: - self._file.close() - self._file = None - - def _delete(self): - self._close() - tf.io.gfile.remove(self._manifest_path) - - def __enter__(self): - self._open() - return self - - def __exit__(self, *exn_info): - self._close() - - def _get_manifest_contents(self, manifest_file_handle) -> Mapping[str, int]: - """Reads, decodes and returns the manifest contents.""" - manifest_file_handle.seek(0) - try: - result = pickle.loads(manifest_file_handle.read()) - assert isinstance(result, dict) - return result - except Exception as e: # pylint: disable=broad-except - # Any exception at this point would be an indication that the cache is - # likely invalidated. Returning an empty dict allows the pipeline to - # "gracefully" recover (by proceeding without cache) as opposed to - # entering a crash-loop it can't recover from. - tf.compat.v1.logging.error('Can\'t load cache manifest contents: %s', - str(e)) - return {} - - def read(self): - if not tf.io.gfile.exists(self._manifest_path): - return {} - - if self._file is not None: - return self._get_manifest_contents(self._file) - else: - with tf.io.gfile.GFile(self._manifest_path, 'rb') as f: - return self._get_manifest_contents(f) - - def write(self, manifest: Mapping[str, int]): - """Writes the manifest to the file.""" - try: - # First attempt to delete the manifest if it exists in case it can't be - # edited in-place. - self._delete() - except tf.errors.NotFoundError: - pass - self._open() - # Manifests are small, so writing in a semi-human readable form (protocol=0) - # is preferred over the efficiency gains of higher protocols. - assert self._file is not None - self._file.write(pickle.dumps(manifest, protocol=0)) + """A manifest file wrapper used to read and write tft cache manifest files.""" + + # TODO(b/37788560): Use artifacts instead. + _MANIFEST_FILE_NAME = "MANIFEST" + + def __init__(self, base_path: str): + self._base_path = base_path + self._manifest_path = os.path.join(base_path, self._MANIFEST_FILE_NAME) + self._file = None + + def _open(self): + assert self._file is None + if not tf.io.gfile.isdir(self._base_path): + tf.io.gfile.makedirs(self._base_path) + self._file = tf.io.gfile.GFile(self._manifest_path, "wb+") + + def _close(self): + if self._file: + self._file.close() + self._file = None + + def _delete(self): + self._close() + tf.io.gfile.remove(self._manifest_path) + + def __enter__(self): + self._open() + return self + + def __exit__(self, *exn_info): + self._close() + + def _get_manifest_contents(self, manifest_file_handle) -> Mapping[str, int]: + """Reads, decodes and returns the manifest contents.""" + manifest_file_handle.seek(0) + try: + result = pickle.loads(manifest_file_handle.read()) + assert isinstance(result, dict) + return result + except Exception as e: # pylint: disable=broad-except + # Any exception at this point would be an indication that the cache is + # likely invalidated. Returning an empty dict allows the pipeline to + # "gracefully" recover (by proceeding without cache) as opposed to + # entering a crash-loop it can't recover from. 
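            # The expected payload is a small dict mapping cache entry keys to
            # file indices (see _write_cache below); a truncated or otherwise
            # corrupted MANIFEST file fails to unpickle here and is simply
            # treated as an empty manifest.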
+ tf.compat.v1.logging.error("Can't load cache manifest contents: %s", str(e)) + return {} + + def read(self): + if not tf.io.gfile.exists(self._manifest_path): + return {} + + if self._file is not None: + return self._get_manifest_contents(self._file) + else: + with tf.io.gfile.GFile(self._manifest_path, "rb") as f: + return self._get_manifest_contents(f) + + def write(self, manifest: Mapping[str, int]): + """Writes the manifest to the file.""" + try: + # First attempt to delete the manifest if it exists in case it can't be + # edited in-place. + self._delete() + except tf.errors.NotFoundError: + pass + self._open() + # Manifests are small, so writing in a semi-human readable form (protocol=0) + # is preferred over the efficiency gains of higher protocols. + assert self._file is not None + self._file.write(pickle.dumps(manifest, protocol=0)) class _WriteToTFRecordGzip(beam.io.WriteToTFRecord): - - def __init__(self, file_path_prefix): - super().__init__(file_path_prefix, file_name_suffix='.gz') + def __init__(self, file_path_prefix): + super().__init__(file_path_prefix, file_name_suffix=".gz") class _WriteMetadata(beam.PTransform): + def __init__(self, dataset_key_dir: str): + self._path = os.path.join(dataset_key_dir, _METADATA_FILE_NAME) - def __init__(self, dataset_key_dir: str): - self._path = os.path.join(dataset_key_dir, _METADATA_FILE_NAME) - - def expand( - self, - metadata: beam.PCollection[DatasetCacheMetadata]) -> beam.pvalue.PDone: - return (metadata - | 'EncodeCacheMetadata' >> beam.Map(lambda x: x.encode()) - | 'WriteCacheMetadata' >> beam.io.WriteToTFRecord(self._path)) + def expand( + self, metadata: beam.PCollection[DatasetCacheMetadata] + ) -> beam.pvalue.PDone: + return ( + metadata + | "EncodeCacheMetadata" >> beam.Map(lambda x: x.encode()) + | "WriteCacheMetadata" >> beam.io.WriteToTFRecord(self._path) + ) class _ReadMetadata(beam.PTransform): + def __init__(self, dataset_key_dir: str): + self._cache_metadata_path = os.path.join( + dataset_key_dir, f"{_METADATA_FILE_NAME}-*-of-*" + ) - def __init__(self, dataset_key_dir: str): - self._cache_metadata_path = os.path.join(dataset_key_dir, - f'{_METADATA_FILE_NAME}-*-of-*') - - def expand(self, - pipeline: beam.Pipeline) -> beam.PCollection[DatasetCacheMetadata]: - if tf.io.gfile.glob(self._cache_metadata_path): - return (pipeline - | 'ReadMetadata' >> beam.io.ReadFromTFRecord( - self._cache_metadata_path, validate=False) - | 'Decode' >> beam.Map(DatasetCacheMetadata.decode)) + def expand(self, pipeline: beam.Pipeline) -> beam.PCollection[DatasetCacheMetadata]: + if tf.io.gfile.glob(self._cache_metadata_path): + return ( + pipeline + | "ReadMetadata" + >> beam.io.ReadFromTFRecord(self._cache_metadata_path, validate=False) + | "Decode" >> beam.Map(DatasetCacheMetadata.decode) + ) class WriteAnalysisCacheToFS(beam.PTransform): - """Writes a cache object that can be read by ReadAnalysisCacheFromFS. - - Given a cache collection, this writes it to the configured directory. - If the configured directory already contains cache, this will merge the new - cache with the old. - NOTE: This merging of cache is determined at beam graph construction time, - so the cache must already exist there when constructing this. - """ - - def __init__(self, - pipeline: beam.Pipeline, - cache_base_dir: str, - dataset_keys: Optional[Iterable[DatasetKey]] = None, - sink: Optional[object] = None): - """Init method. - - Args: - pipeline: A beam Pipeline. - cache_base_dir: A str, the path that the cache should be stored in. 
- dataset_keys: (Optional) An iterable of strings. - sink: (Optional) A PTransform class that takes a path in its constructor, - and is used to write the cache. If not provided this uses a GZipped - TFRecord sink. + """Writes a cache object that can be read by ReadAnalysisCacheFromFS. + + Given a cache collection, this writes it to the configured directory. + If the configured directory already contains cache, this will merge the new + cache with the old. + NOTE: This merging of cache is determined at beam graph construction time, + so the cache must already exist there when constructing this. """ - self.pipeline = pipeline - self._cache_base_dir = cache_base_dir - if dataset_keys is None: - self._sorted_dataset_keys = None - else: - self._sorted_dataset_keys = sorted(dataset_keys) - self._sink = sink - if self._sink is None: - # TODO(b/37788560): Possibly use Riegeli as a default file format once - # possible. - self._sink = _WriteToTFRecordGzip - - def _extract_input_pvalues( - self, dataset_cache_dict: BeamAnalysisCache - ) -> Tuple[BeamAnalysisCache, List[beam.pvalue.PValue]]: - pvalues = [] - for value in dataset_cache_dict.values(): - if value.metadata: - pvalues.append(value.metadata) - return dataset_cache_dict, pvalues - - def _write_cache(self, manifest_file, dataset_key_index, dataset_key_dir, - cache): - manifest = manifest_file.read() - start_cache_idx = max(manifest.values()) + 1 if manifest else 0 - - dataset_identifier = f'AnalysisIndex{dataset_key_index}' - cache_is_written = [] - for cache_key_idx, (cache_entry_key, - cache_pcoll) in enumerate(cache.cache_dict.items(), - start_cache_idx): - cache_identifier = f'CacheKeyIndex{cache_key_idx}' - path = os.path.join(dataset_key_dir, str(cache_key_idx)) - manifest[cache_entry_key] = cache_key_idx - cache_is_written.append( - cache_pcoll - | f'Write[{dataset_identifier}][{cache_identifier}]' >> self._sink( - path)) - if cache.metadata is not None: - cache_is_written.append(cache.metadata - | f'WriteMetadata[{dataset_identifier}]' >> - _WriteMetadata(dataset_key_dir)) - - manifest_file.write(manifest) - return cache_is_written - - # TODO(b/269419184): Add typehints when possible: - # expand(self, dataset_cache_dict: BeamAnalysisCache) -> List[beam.pvalue.PDone] # pylint: disable=line-too-long - def expand(self, dataset_cache_dict): - if self._sorted_dataset_keys is None: - sorted_dataset_keys_list = sorted(dataset_cache_dict.keys()) - else: - sorted_dataset_keys_list = self._sorted_dataset_keys - missing_keys = set(dataset_cache_dict.keys()).difference( - set(sorted_dataset_keys_list)) - if missing_keys: - raise ValueError( - 'The dataset keys in the cache dictionary must be a subset of the ' - 'keys in dataset_keys. Missing {}.'.format(missing_keys)) - if not all(isinstance(d, DatasetKey) for d in sorted_dataset_keys_list): - raise ValueError('Expected dataset_keys to be of type DatasetKey') - - cache_is_written = [] - for dataset_key, cache in dataset_cache_dict.items(): - dataset_key_idx = sorted_dataset_keys_list.index(dataset_key) - dataset_key_dir = _get_dataset_cache_path(self._cache_base_dir, - dataset_key) - with _ManifestFile(dataset_key_dir) as manifest_file: - cache_is_written.extend( - self._write_cache(manifest_file, dataset_key_idx, dataset_key_dir, - cache)) - - return cache_is_written + + def __init__( + self, + pipeline: beam.Pipeline, + cache_base_dir: str, + dataset_keys: Optional[Iterable[DatasetKey]] = None, + sink: Optional[object] = None, + ): + """Init method. 
+ + Args: + ---- + pipeline: A beam Pipeline. + cache_base_dir: A str, the path that the cache should be stored in. + dataset_keys: (Optional) An iterable of strings. + sink: (Optional) A PTransform class that takes a path in its constructor, + and is used to write the cache. If not provided this uses a GZipped + TFRecord sink. + """ + self.pipeline = pipeline + self._cache_base_dir = cache_base_dir + if dataset_keys is None: + self._sorted_dataset_keys = None + else: + self._sorted_dataset_keys = sorted(dataset_keys) + self._sink = sink + if self._sink is None: + # TODO(b/37788560): Possibly use Riegeli as a default file format once + # possible. + self._sink = _WriteToTFRecordGzip + + def _extract_input_pvalues( + self, dataset_cache_dict: BeamAnalysisCache + ) -> Tuple[BeamAnalysisCache, List[beam.pvalue.PValue]]: + pvalues = [] + for value in dataset_cache_dict.values(): + if value.metadata: + pvalues.append(value.metadata) + return dataset_cache_dict, pvalues + + def _write_cache(self, manifest_file, dataset_key_index, dataset_key_dir, cache): + manifest = manifest_file.read() + start_cache_idx = max(manifest.values()) + 1 if manifest else 0 + + dataset_identifier = f"AnalysisIndex{dataset_key_index}" + cache_is_written = [] + for cache_key_idx, (cache_entry_key, cache_pcoll) in enumerate( + cache.cache_dict.items(), start_cache_idx + ): + cache_identifier = f"CacheKeyIndex{cache_key_idx}" + path = os.path.join(dataset_key_dir, str(cache_key_idx)) + manifest[cache_entry_key] = cache_key_idx + cache_is_written.append( + cache_pcoll + | f"Write[{dataset_identifier}][{cache_identifier}]" >> self._sink(path) + ) + if cache.metadata is not None: + cache_is_written.append( + cache.metadata + | f"WriteMetadata[{dataset_identifier}]" + >> _WriteMetadata(dataset_key_dir) + ) + + manifest_file.write(manifest) + return cache_is_written + + # TODO(b/269419184): Add typehints when possible: + # expand(self, dataset_cache_dict: BeamAnalysisCache) -> List[beam.pvalue.PDone] # pylint: disable=line-too-long + def expand(self, dataset_cache_dict): + if self._sorted_dataset_keys is None: + sorted_dataset_keys_list = sorted(dataset_cache_dict.keys()) + else: + sorted_dataset_keys_list = self._sorted_dataset_keys + missing_keys = set(dataset_cache_dict.keys()).difference( + set(sorted_dataset_keys_list) + ) + if missing_keys: + raise ValueError( + "The dataset keys in the cache dictionary must be a subset of the " + f"keys in dataset_keys. Missing {missing_keys}." + ) + if not all(isinstance(d, DatasetKey) for d in sorted_dataset_keys_list): + raise ValueError("Expected dataset_keys to be of type DatasetKey") + + cache_is_written = [] + for dataset_key, cache in dataset_cache_dict.items(): + dataset_key_idx = sorted_dataset_keys_list.index(dataset_key) + dataset_key_dir = _get_dataset_cache_path(self._cache_base_dir, dataset_key) + with _ManifestFile(dataset_key_dir) as manifest_file: + cache_is_written.extend( + self._write_cache( + manifest_file, dataset_key_idx, dataset_key_dir, cache + ) + ) + + return cache_is_written class ReadAnalysisCacheFromFS(beam.PTransform): - """Reads cache from the FS written by WriteAnalysisCacheToFS.""" - - def __init__(self, - cache_base_dir: str, - dataset_keys: Iterable[DatasetKey], - cache_entry_keys: Optional[Iterable[bytes]] = None, - source: Optional[object] = None): - """Init method. - - Args: - cache_base_dir: A string, the path that the cache should be stored in. - dataset_keys: An iterable of `DatasetKey`s. 
- cache_entry_keys: (Optional) An iterable of cache entry key strings. If - provided, only cache entries that exist in `cache_entry_keys` will be - read. - source: (Optional) A PTransform class that takes a path argument in its - constructor, and is used to read the cache. - """ - self._cache_base_dir = cache_base_dir - if not all(isinstance(d, DatasetKey) for d in dataset_keys): - raise ValueError('Expected dataset_keys to be of type DatasetKey') - self._sorted_dataset_keys = sorted(dataset_keys) - self._filtered_cache_entry_keys = (None if cache_entry_keys is None else - set(cache_entry_keys)) - # TODO(b/37788560): Possibly use Riegeli as a default file format once - # possible. - self._source = source if source is not None else beam.io.ReadFromTFRecord - - def _should_read_cache_entry_key(self, key: str) -> bool: - return (self._filtered_cache_entry_keys is None or - key in self._filtered_cache_entry_keys) - - # TODO(b/269419184): Add typehints when possible: - # expand(self, pipeline: beam.Pipeline) -> BeamAnalysisCache - def expand(self, pipeline: beam.Pipeline): - result = {} - - for dataset_key_idx, dataset_key in enumerate(self._sorted_dataset_keys): - - dataset_cache_path = _get_dataset_cache_path(self._cache_base_dir, - dataset_key) - manifest_file = _ManifestFile(dataset_cache_path) - manifest = manifest_file.read() - if not manifest: - continue - dataset_id = f'AnalysisIndex{dataset_key_idx}' - cache_dict = {} - for key, cache_key_idx in manifest.items(): - if self._should_read_cache_entry_key(key): - cache_dict[key] = ( - pipeline - | f'Read[{dataset_id}]][CacheKeyIndex{cache_key_idx}]' >> - self._source( - f'{os.path.join(dataset_cache_path, str(cache_key_idx))}-*-of-*' - )) - metadata = pipeline | f'ReadMetadata[{dataset_id}]' >> _ReadMetadata( - dataset_cache_path) - result[dataset_key] = DatasetCache(cache_dict, metadata) - return result + """Reads cache from the FS written by WriteAnalysisCacheToFS.""" + + def __init__( + self, + cache_base_dir: str, + dataset_keys: Iterable[DatasetKey], + cache_entry_keys: Optional[Iterable[bytes]] = None, + source: Optional[object] = None, + ): + """Init method. + + Args: + ---- + cache_base_dir: A string, the path that the cache should be stored in. + dataset_keys: An iterable of `DatasetKey`s. + cache_entry_keys: (Optional) An iterable of cache entry key strings. If + provided, only cache entries that exist in `cache_entry_keys` will be + read. + source: (Optional) A PTransform class that takes a path argument in its + constructor, and is used to read the cache. + """ + self._cache_base_dir = cache_base_dir + if not all(isinstance(d, DatasetKey) for d in dataset_keys): + raise ValueError("Expected dataset_keys to be of type DatasetKey") + self._sorted_dataset_keys = sorted(dataset_keys) + self._filtered_cache_entry_keys = ( + None if cache_entry_keys is None else set(cache_entry_keys) + ) + # TODO(b/37788560): Possibly use Riegeli as a default file format once + # possible. 
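        # Typical usage, as a sketch (`p`, `cache_dir` and `keys` are assumed
        # to exist in the calling pipeline code):
        #
        #   cache = p | ReadAnalysisCacheFromFS(cache_dir, keys)
        #   # -> {DatasetKey: DatasetCache}, with one entry per dataset key
        #   #    that has a readable MANIFEST under cache_dir.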
+ self._source = source if source is not None else beam.io.ReadFromTFRecord + + def _should_read_cache_entry_key(self, key: str) -> bool: + return ( + self._filtered_cache_entry_keys is None + or key in self._filtered_cache_entry_keys + ) + + # TODO(b/269419184): Add typehints when possible: + # expand(self, pipeline: beam.Pipeline) -> BeamAnalysisCache + def expand(self, pipeline: beam.Pipeline): + result = {} + + for dataset_key_idx, dataset_key in enumerate(self._sorted_dataset_keys): + dataset_cache_path = _get_dataset_cache_path( + self._cache_base_dir, dataset_key + ) + manifest_file = _ManifestFile(dataset_cache_path) + manifest = manifest_file.read() + if not manifest: + continue + dataset_id = f"AnalysisIndex{dataset_key_idx}" + cache_dict = {} + for key, cache_key_idx in manifest.items(): + if self._should_read_cache_entry_key(key): + cache_dict[key] = ( + pipeline + | f"Read[{dataset_id}]][CacheKeyIndex{cache_key_idx}]" + >> self._source( + f"{os.path.join(dataset_cache_path, str(cache_key_idx))}-*-of-*" + ) + ) + metadata = pipeline | f"ReadMetadata[{dataset_id}]" >> _ReadMetadata( + dataset_cache_path + ) + result[dataset_key] = DatasetCache(cache_dict, metadata) + return result def validate_dataset_keys(dataset_keys: Iterable[DatasetKey]): - regex = re.compile(r'^[a-zA-Z0-9\.\-_]+$') - for dataset_key in dataset_keys: - if not isinstance(dataset_key, DatasetKey): - raise ValueError('Dataset key {} must be of type DatasetKey') - if not regex.match(dataset_key.key): - raise ValueError( - 'Dataset key {!r} does not match allowed pattern: {!r}'.format( - dataset_key.key, regex.pattern)) + regex = re.compile(r"^[a-zA-Z0-9\.\-_]+$") + for dataset_key in dataset_keys: + if not isinstance(dataset_key, DatasetKey): + raise ValueError("Dataset key {} must be of type DatasetKey") + if not regex.match(dataset_key.key): + raise ValueError( + f"Dataset key {dataset_key.key!r} does not match allowed pattern: {regex.pattern!r}" + ) def make_cache_entry_key(cache_key: str) -> str: - return _CACHE_VERSION + tf.compat.as_bytes(cache_key) + return _CACHE_VERSION + tf.compat.as_bytes(cache_key) diff --git a/tensorflow_transform/beam/analyzer_cache_test.py b/tensorflow_transform/beam/analyzer_cache_test.py index b23c5cb..1cc1e38 100644 --- a/tensorflow_transform/beam/analyzer_cache_test.py +++ b/tensorflow_transform/beam/analyzer_cache_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2018 Google Inc. All Rights Reserved. 
# @@ -18,298 +17,341 @@ import os import apache_beam as beam -from apache_beam.testing import util as beam_test_util import numpy as np - import tensorflow as tf +from apache_beam.testing import util as beam_test_util -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import analyzers +from tensorflow_transform import analyzer_nodes, analyzers, test_case from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform import test_case mock = tf.compat.v1.test.mock def _get_quantiles_accumulator(): - - qcombiner = analyzers.QuantilesCombiner( - num_quantiles=2, - epsilon=0.01, - bucket_numpy_dtype=np.float32, - has_weights=False, - output_shape=None, - include_max_and_min=False, - feature_shape=[1]) - accumulator = qcombiner.create_accumulator() - return qcombiner.add_input(accumulator, [np.array([1.0, 2.0, 3.0])]) + qcombiner = analyzers.QuantilesCombiner( + num_quantiles=2, + epsilon=0.01, + bucket_numpy_dtype=np.float32, + has_weights=False, + output_shape=None, + include_max_and_min=False, + feature_shape=[1], + ) + accumulator = qcombiner.create_accumulator() + return qcombiner.add_input(accumulator, [np.array([1.0, 2.0, 3.0])]) class AnalyzerCacheTest(test_case.TransformTestCase): - - def test_validate_dataset_keys(self): - analyzer_cache.validate_dataset_keys({ - analyzer_cache.DatasetKey(k) - for k in ('foo', 'Foo', 'A1', 'A_1', 'A.1', 'A-1', 'foo@1', 'foo*', - 'foo[]', 'foo/goo') - }) - - for key in {analyzer_cache.DatasetKey(k) for k in ('^foo^', 'foo 1')}: - with self.assertRaisesRegex( - ValueError, 'Dataset key .* does not match allowed pattern:' - ): - analyzer_cache.validate_dataset_keys({key}) - - @test_case.named_parameters( - dict( - testcase_name='JsonNumpyCacheCoder', - coder=analyzer_nodes.JsonNumpyCacheCoder(), - value=[1, 2.5, 3, '4']), - dict( - testcase_name='JsonNumpyCacheCoderNpArray', - coder=analyzer_nodes.JsonNumpyCacheCoder(), - value=np.array([1, 2.5, 3, '4'])), - dict( - testcase_name='JsonNumpyCacheCoderNestedNpTypes', - coder=analyzer_nodes.JsonNumpyCacheCoder(), - value=[np.int64(1), np.float32(2.5), 3, '4']), - dict( - testcase_name='_VocabularyAccumulatorCoderIntAccumulator', - coder=analyzer_nodes._VocabularyAccumulatorCoder(), - value=[b'A', 17]), - dict( - testcase_name='_VocabularyAccumulatorCoderIntAccumulatorNonUtf8', - coder=analyzer_nodes._VocabularyAccumulatorCoder(), - value=[b'\x8a', 29]), - dict( - testcase_name='_WeightedMeanAndVarAccumulatorPerKey', - coder=analyzer_nodes._VocabularyAccumulatorCoder(), - value=[ - b'A', - analyzers._WeightedMeanAndVarAccumulator( - count=np.array(5), - mean=np.array([.4, .9, 1.5]), - variance=np.array([.1, .4, .5]), - weight=np.array(0.), - ) - ]), - dict( - testcase_name='_WeightedMeanAndVarAccumulatorKeepDims', - coder=analyzer_nodes.JsonNumpyCacheCoder(), - # TODO(b/268341036): Remove this complication once np 1.24 issue is - # fixed. 
- value=analyzer_nodes.JsonNumpyCacheCoder(object).decode_cache( - analyzer_nodes.JsonNumpyCacheCoder().encode_cache( - analyzers._WeightedMeanAndVarAccumulator( - count=np.array(0), - mean=np.array([], dtype=np.float64), - variance=np.array([], dtype=np.float64), - weight=np.array(0.0)))) - ), - dict( - testcase_name='_QuantilesAccumulatorCoderClassAccumulator', - coder=analyzers._QuantilesSketchCacheCoder(), - value=_get_quantiles_accumulator()), - dict( - testcase_name='_CombinerPerKeyAccumulatorCoder', - coder=analyzer_nodes._CombinerPerKeyAccumulatorCoder( - analyzer_nodes.JsonNumpyCacheCoder()), - value=[b'\x8a', [np.int64(1), np.float32(2.5), 3, '4']]), - ) - def test_coders_round_trip(self, coder, value): - encoded = coder.encode_cache(value) - if isinstance(coder, analyzers._QuantilesSketchCacheCoder): - # Quantiles accumulator becomes a different object after pickle round trip - # and doesn't have a deep __eq__ defined. That's why we compare the output - # of accumulator before and after pickling. - np.testing.assert_equal( - coder.decode_cache(encoded).GetQuantiles(10).to_pylist(), - value.GetQuantiles(10).to_pylist()) - else: - np.testing.assert_equal(coder.decode_cache(encoded), value) - - def test_cache_helpers_round_trip(self): - base_test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - dataset_key_0_metadata = analyzer_cache.DatasetCacheMetadata(42) - dataset_key_1_metadata = analyzer_cache.DatasetCacheMetadata(17) - dataset_key_0 = analyzer_cache.DatasetKey('dataset_key_0') - dataset_key_1 = analyzer_cache.DatasetKey('dataset_key_1') - dataset_keys = (dataset_key_0, dataset_key_1) - - with beam.Pipeline() as p: - cache_pcoll_dict = { - dataset_key_0: - analyzer_cache.DatasetCache( - { - b'\x8a': p | 'CreateA' >> beam.Create([b'[1, 2, 3]']), - b'\x8b': p | 'CreateB' >> beam.Create([b'[5]']), - b'\x8b1': p | 'CreateB1' >> beam.Create([b'[6]']), - }, p | 'CreateM0' >> beam.Create([dataset_key_0_metadata])), - dataset_key_1: - analyzer_cache.DatasetCache( - { - b'\x8c': p | 'CreateC' >> beam.Create([b'[9, 5, 2, 1]']), - }, p | 'CreateM1' >> beam.Create([dataset_key_1_metadata])), - } - - _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( - p, base_test_dir, dataset_keys) - - with beam.Pipeline() as p: - read_cache = p | analyzer_cache.ReadAnalysisCacheFromFS( - base_test_dir, list(cache_pcoll_dict.keys()), - [b'\x8a', b'\x8b', b'\x8c']) - - beam_test_util.assert_that( - read_cache[dataset_key_0].cache_dict[b'\x8a'], - beam_test_util.equal_to([b'[1, 2, 3]']), - label='AssertA') - beam_test_util.assert_that( - read_cache[dataset_key_0].cache_dict[b'\x8b'], - beam_test_util.equal_to([b'[5]']), - label='AssertB') - beam_test_util.assert_that( - read_cache[dataset_key_0].metadata, - beam_test_util.equal_to([dataset_key_0_metadata]), - label='Assert0Size') - beam_test_util.assert_that( - read_cache[dataset_key_1].cache_dict[b'\x8c'], - beam_test_util.equal_to([b'[9, 5, 2, 1]']), - label='AssertC') - beam_test_util.assert_that( - read_cache[dataset_key_1].metadata, - beam_test_util.equal_to([dataset_key_1_metadata]), - label='Assert1Size') - - def test_cache_write_empty(self): - base_test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - with beam.Pipeline() as p: - _ = {} | analyzer_cache.WriteAnalysisCacheToFS( - p, base_test_dir, (analyzer_cache.DatasetKey('dataset_key_0'),)) - self.assertFalse(os.path.isdir(base_test_dir)) - - def 
test_cache_merge(self): - base_test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - - dataset_key_0 = analyzer_cache.DatasetKey('dataset_key_0') - dataset_key_1 = analyzer_cache.DatasetKey('dataset_key_1') - dataset_keys = (dataset_key_0, dataset_key_1) - cache_keys = list('abcd') - - def read_manifests(): - return [ - analyzer_cache._ManifestFile( - analyzer_cache._get_dataset_cache_path(base_test_dir, key)).read() - for key in dataset_keys - ] - - with beam.Pipeline() as p: - cache_pcoll_dict = { - dataset_key_0: - analyzer_cache.DatasetCache( - { - 'a': p | 'CreateA' >> beam.Create([b'a']), - 'b': p | 'CreateB' >> beam.Create([b'b']), - }, None), - dataset_key_1: - analyzer_cache.DatasetCache( - { - 'c': p | 'CreateC' >> beam.Create([b'c']), - 'd': p | 'CreateD' >> beam.Create([b'd']), - }, None), - } - _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( - p, base_test_dir, dataset_keys) - - first_manifests = read_manifests() - - with beam.Pipeline() as p: - cache_pcoll_dict = { - dataset_key_0: - analyzer_cache.DatasetCache( - { - 'c': p | 'CreateC' >> beam.Create([b'c']), - 'd': p | 'CreateD' >> beam.Create([b'd']), - }, None), - dataset_key_1: - analyzer_cache.DatasetCache( - { - 'a': p | 'CreateA' >> beam.Create([b'a']), - 'b': p | 'CreateB' >> beam.Create([b'b']), - }, None), - } - _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( - p, base_test_dir, dataset_keys) - - second_manifests = read_manifests() - self.assertEqual(len(first_manifests), len(second_manifests)) - for manifest_a, manifest_b in zip(first_manifests, second_manifests): - for key_value_pair in manifest_a.items(): - self.assertIn(key_value_pair, manifest_b.items()) - - self.assertEqual(2, len(manifest_a)) - self.assertCountEqual(range(len(manifest_a)), manifest_a.values()) - - self.assertEqual(4, len(manifest_b)) - self.assertCountEqual(range(len(manifest_b)), manifest_b.values()) - self.assertCountEqual(cache_keys, manifest_b.keys()) - - def test_cache_helpers_with_alternative_io(self): - - class LocalSink(beam.PTransform): - - def __init__(self, path): - self._path = path - - def expand(self, pcoll): - - def write_to_file(value): - tf.io.gfile.makedirs(self._path) - with open(os.path.join(self._path, 'cache'), 'wb') as f: - f.write(value) - - return pcoll | beam.Map(write_to_file) - - dataset_key = analyzer_cache.DatasetKey('a') - test_cache_dict = { - dataset_key: - analyzer_cache.DatasetCache({'b': [bytes([17, 19, 27, 31])]}, None) - } - dataset_keys = list(test_cache_dict.keys()) - - class LocalSource(beam.PTransform): - - def __init__(self, path): - del path - - def expand(self, pbegin): - return pbegin | beam.Create( - [test_cache_dict[k].cache_dict['b'] for k in dataset_keys]) - - cache_dir = self.get_temp_dir() - with beam.Pipeline() as p: - _ = test_cache_dict | analyzer_cache.WriteAnalysisCacheToFS( - p, cache_dir, dataset_keys, sink=LocalSink) - - read_cache = p | analyzer_cache.ReadAnalysisCacheFromFS( - cache_dir, dataset_keys, source=LocalSource) - - self.assertCountEqual(read_cache, [dataset_key]) - self.assertCountEqual(read_cache[dataset_key].cache_dict.keys(), ['b']) - - for key in dataset_keys: - beam_test_util.assert_that( - read_cache[key].cache_dict['b'], - beam_test_util.equal_to([test_cache_dict[key].cache_dict['b']])) - - -if __name__ == '__main__': - test_case.main() + def test_validate_dataset_keys(self): + analyzer_cache.validate_dataset_keys( + { + analyzer_cache.DatasetKey(k) + for k in ( + "foo", 
+ "Foo", + "A1", + "A_1", + "A.1", + "A-1", + "foo@1", + "foo*", + "foo[]", + "foo/goo", + ) + } + ) + + for key in {analyzer_cache.DatasetKey(k) for k in ("^foo^", "foo 1")}: + with self.assertRaisesRegex( + ValueError, "Dataset key .* does not match allowed pattern:" + ): + analyzer_cache.validate_dataset_keys({key}) + + @test_case.named_parameters( + dict( + testcase_name="JsonNumpyCacheCoder", + coder=analyzer_nodes.JsonNumpyCacheCoder(), + value=[1, 2.5, 3, "4"], + ), + dict( + testcase_name="JsonNumpyCacheCoderNpArray", + coder=analyzer_nodes.JsonNumpyCacheCoder(), + value=np.array([1, 2.5, 3, "4"]), + ), + dict( + testcase_name="JsonNumpyCacheCoderNestedNpTypes", + coder=analyzer_nodes.JsonNumpyCacheCoder(), + value=[np.int64(1), np.float32(2.5), 3, "4"], + ), + dict( + testcase_name="_VocabularyAccumulatorCoderIntAccumulator", + coder=analyzer_nodes._VocabularyAccumulatorCoder(), + value=[b"A", 17], + ), + dict( + testcase_name="_VocabularyAccumulatorCoderIntAccumulatorNonUtf8", + coder=analyzer_nodes._VocabularyAccumulatorCoder(), + value=[b"\x8a", 29], + ), + dict( + testcase_name="_WeightedMeanAndVarAccumulatorPerKey", + coder=analyzer_nodes._VocabularyAccumulatorCoder(), + value=[ + b"A", + analyzers._WeightedMeanAndVarAccumulator( + count=np.array(5), + mean=np.array([0.4, 0.9, 1.5]), + variance=np.array([0.1, 0.4, 0.5]), + weight=np.array(0.0), + ), + ], + ), + dict( + testcase_name="_WeightedMeanAndVarAccumulatorKeepDims", + coder=analyzer_nodes.JsonNumpyCacheCoder(), + # TODO(b/268341036): Remove this complication once np 1.24 issue is + # fixed. + value=analyzer_nodes.JsonNumpyCacheCoder(object).decode_cache( + analyzer_nodes.JsonNumpyCacheCoder().encode_cache( + analyzers._WeightedMeanAndVarAccumulator( + count=np.array(0), + mean=np.array([], dtype=np.float64), + variance=np.array([], dtype=np.float64), + weight=np.array(0.0), + ) + ) + ), + ), + dict( + testcase_name="_QuantilesAccumulatorCoderClassAccumulator", + coder=analyzers._QuantilesSketchCacheCoder(), + value=_get_quantiles_accumulator(), + ), + dict( + testcase_name="_CombinerPerKeyAccumulatorCoder", + coder=analyzer_nodes._CombinerPerKeyAccumulatorCoder( + analyzer_nodes.JsonNumpyCacheCoder() + ), + value=[b"\x8a", [np.int64(1), np.float32(2.5), 3, "4"]], + ), + ) + def test_coders_round_trip(self, coder, value): + encoded = coder.encode_cache(value) + if isinstance(coder, analyzers._QuantilesSketchCacheCoder): + # Quantiles accumulator becomes a different object after pickle round trip + # and doesn't have a deep __eq__ defined. That's why we compare the output + # of accumulator before and after pickling. 
+ np.testing.assert_equal( + coder.decode_cache(encoded).GetQuantiles(10).to_pylist(), + value.GetQuantiles(10).to_pylist(), + ) + else: + np.testing.assert_equal(coder.decode_cache(encoded), value) + + def test_cache_helpers_round_trip(self): + base_test_dir = os.path.join( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", self.get_temp_dir()), + self._testMethodName, + ) + + dataset_key_0_metadata = analyzer_cache.DatasetCacheMetadata(42) + dataset_key_1_metadata = analyzer_cache.DatasetCacheMetadata(17) + dataset_key_0 = analyzer_cache.DatasetKey("dataset_key_0") + dataset_key_1 = analyzer_cache.DatasetKey("dataset_key_1") + dataset_keys = (dataset_key_0, dataset_key_1) + + with beam.Pipeline() as p: + cache_pcoll_dict = { + dataset_key_0: analyzer_cache.DatasetCache( + { + b"\x8a": p | "CreateA" >> beam.Create([b"[1, 2, 3]"]), + b"\x8b": p | "CreateB" >> beam.Create([b"[5]"]), + b"\x8b1": p | "CreateB1" >> beam.Create([b"[6]"]), + }, + p | "CreateM0" >> beam.Create([dataset_key_0_metadata]), + ), + dataset_key_1: analyzer_cache.DatasetCache( + { + b"\x8c": p | "CreateC" >> beam.Create([b"[9, 5, 2, 1]"]), + }, + p | "CreateM1" >> beam.Create([dataset_key_1_metadata]), + ), + } + + _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( + p, base_test_dir, dataset_keys + ) + + with beam.Pipeline() as p: + read_cache = p | analyzer_cache.ReadAnalysisCacheFromFS( + base_test_dir, + list(cache_pcoll_dict.keys()), + [b"\x8a", b"\x8b", b"\x8c"], + ) + + beam_test_util.assert_that( + read_cache[dataset_key_0].cache_dict[b"\x8a"], + beam_test_util.equal_to([b"[1, 2, 3]"]), + label="AssertA", + ) + beam_test_util.assert_that( + read_cache[dataset_key_0].cache_dict[b"\x8b"], + beam_test_util.equal_to([b"[5]"]), + label="AssertB", + ) + beam_test_util.assert_that( + read_cache[dataset_key_0].metadata, + beam_test_util.equal_to([dataset_key_0_metadata]), + label="Assert0Size", + ) + beam_test_util.assert_that( + read_cache[dataset_key_1].cache_dict[b"\x8c"], + beam_test_util.equal_to([b"[9, 5, 2, 1]"]), + label="AssertC", + ) + beam_test_util.assert_that( + read_cache[dataset_key_1].metadata, + beam_test_util.equal_to([dataset_key_1_metadata]), + label="Assert1Size", + ) + + def test_cache_write_empty(self): + base_test_dir = os.path.join( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", self.get_temp_dir()), + self._testMethodName, + ) + + with beam.Pipeline() as p: + _ = {} | analyzer_cache.WriteAnalysisCacheToFS( + p, base_test_dir, (analyzer_cache.DatasetKey("dataset_key_0"),) + ) + self.assertFalse(os.path.isdir(base_test_dir)) + + def test_cache_merge(self): + base_test_dir = os.path.join( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", self.get_temp_dir()), + self._testMethodName, + ) + + dataset_key_0 = analyzer_cache.DatasetKey("dataset_key_0") + dataset_key_1 = analyzer_cache.DatasetKey("dataset_key_1") + dataset_keys = (dataset_key_0, dataset_key_1) + cache_keys = list("abcd") + + def read_manifests(): + return [ + analyzer_cache._ManifestFile( + analyzer_cache._get_dataset_cache_path(base_test_dir, key) + ).read() + for key in dataset_keys + ] + + with beam.Pipeline() as p: + cache_pcoll_dict = { + dataset_key_0: analyzer_cache.DatasetCache( + { + "a": p | "CreateA" >> beam.Create([b"a"]), + "b": p | "CreateB" >> beam.Create([b"b"]), + }, + None, + ), + dataset_key_1: analyzer_cache.DatasetCache( + { + "c": p | "CreateC" >> beam.Create([b"c"]), + "d": p | "CreateD" >> beam.Create([b"d"]), + }, + None, + ), + } + _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( + p, 
base_test_dir, dataset_keys + ) + + first_manifests = read_manifests() + + with beam.Pipeline() as p: + cache_pcoll_dict = { + dataset_key_0: analyzer_cache.DatasetCache( + { + "c": p | "CreateC" >> beam.Create([b"c"]), + "d": p | "CreateD" >> beam.Create([b"d"]), + }, + None, + ), + dataset_key_1: analyzer_cache.DatasetCache( + { + "a": p | "CreateA" >> beam.Create([b"a"]), + "b": p | "CreateB" >> beam.Create([b"b"]), + }, + None, + ), + } + _ = cache_pcoll_dict | analyzer_cache.WriteAnalysisCacheToFS( + p, base_test_dir, dataset_keys + ) + + second_manifests = read_manifests() + self.assertEqual(len(first_manifests), len(second_manifests)) + for manifest_a, manifest_b in zip(first_manifests, second_manifests): + for key_value_pair in manifest_a.items(): + self.assertIn(key_value_pair, manifest_b.items()) + + self.assertEqual(2, len(manifest_a)) + self.assertCountEqual(range(len(manifest_a)), manifest_a.values()) + + self.assertEqual(4, len(manifest_b)) + self.assertCountEqual(range(len(manifest_b)), manifest_b.values()) + self.assertCountEqual(cache_keys, manifest_b.keys()) + + def test_cache_helpers_with_alternative_io(self): + class LocalSink(beam.PTransform): + def __init__(self, path): + self._path = path + + def expand(self, pcoll): + def write_to_file(value): + tf.io.gfile.makedirs(self._path) + with open(os.path.join(self._path, "cache"), "wb") as f: + f.write(value) + + return pcoll | beam.Map(write_to_file) + + dataset_key = analyzer_cache.DatasetKey("a") + test_cache_dict = { + dataset_key: analyzer_cache.DatasetCache( + {"b": [bytes([17, 19, 27, 31])]}, None + ) + } + dataset_keys = list(test_cache_dict.keys()) + + class LocalSource(beam.PTransform): + def __init__(self, path): + del path + + def expand(self, pbegin): + return pbegin | beam.Create( + [test_cache_dict[k].cache_dict["b"] for k in dataset_keys] + ) + + cache_dir = self.get_temp_dir() + with beam.Pipeline() as p: + _ = test_cache_dict | analyzer_cache.WriteAnalysisCacheToFS( + p, cache_dir, dataset_keys, sink=LocalSink + ) + + read_cache = p | analyzer_cache.ReadAnalysisCacheFromFS( + cache_dir, dataset_keys, source=LocalSource + ) + + self.assertCountEqual(read_cache, [dataset_key]) + self.assertCountEqual(read_cache[dataset_key].cache_dict.keys(), ["b"]) + + for key in dataset_keys: + beam_test_util.assert_that( + read_cache[key].cache_dict["b"], + beam_test_util.equal_to([test_cache_dict[key].cache_dict["b"]]), + ) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/beam/analyzer_impls.py b/tensorflow_transform/beam/analyzer_impls.py index 476b779..0de97c2 100644 --- a/tensorflow_transform/beam/analyzer_impls.py +++ b/tensorflow_transform/beam/analyzer_impls.py @@ -20,28 +20,30 @@ import os import typing -from absl import logging import apache_beam as beam - -from apache_beam.transforms.ptransform import ptransform_fn -from apache_beam.typehints import Any -from apache_beam.typehints import Dict -from apache_beam.typehints import Iterable -from apache_beam.typehints import KV -from apache_beam.typehints import List -from apache_beam.typehints import Tuple -from apache_beam.typehints import TypeVariable -from apache_beam.typehints import Union - import numpy as np import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import analyzers -from tensorflow_transform import common_types -from tensorflow_transform import info_theory -from tensorflow_transform import tf_utils -from tensorflow_transform.beam import common -from 
tensorflow_transform.beam import experimental +from absl import logging +from apache_beam.transforms.ptransform import ptransform_fn +from apache_beam.typehints import ( + KV, + Any, + Dict, + Iterable, + List, + Tuple, + TypeVariable, + Union, +) + +from tensorflow_transform import ( + analyzer_nodes, + analyzers, + common_types, + info_theory, + tf_utils, +) +from tensorflow_transform.beam import common, experimental # TODO(b/199789764): Enable beam type checks (and remove this) after the # violations are fixed. @@ -49,8 +51,9 @@ _VocabOrderingType = analyzers._VocabOrderingType # pylint: disable=protected-access _VocabTokenType = Union[bytes, int] -_VocabAccumulatedIndicatorType = Union[int, Tuple[ - float, float], analyzers.WeightedMeanAndVarCombiner.accumulator_class] +_VocabAccumulatedIndicatorType = Union[ + int, Tuple[float, float], analyzers.WeightedMeanAndVarCombiner.accumulator_class +] _VocabIndicatorType = Union[float, int, bytes, Tuple[float, float]] # TODO(b/140645408, b/31727404, b/160207487): Remove this manual fanout when @@ -65,85 +68,89 @@ @ptransform_fn @beam.typehints.with_input_types(KV[_VocabIndicatorType, _VocabTokenType]) -@beam.typehints.with_output_types(Iterable[KV[_VocabIndicatorType, - _VocabTokenType]]) +@beam.typehints.with_output_types(Iterable[KV[_VocabIndicatorType, _VocabTokenType]]) def _BatchAndPreSort(counts, sort_kwargs): # pylint: disable=invalid-name - """Batches vocabulary and pre-sorts the batches.""" - # This PTransform aims to partially parallelize vocabulary sorting in - # `VocabularyOrderAndWrite`. Pre-sorting of batches in a parallel mode with - # `beam.Map` allows to significantly speed up final single-threaded sorting of - # a union in `_OrderElementsFn`. This is because `list.sort()` uses an - # adaptive merge sort that identifies pre-existing order in the union. See - # https://en.wikipedia.org/wiki/Timsort for more details. - return (counts - | 'BatchVocabulary' >> beam.BatchElements( - min_batch_size=_PRESORT_BATCH_SIZE, - max_batch_size=_PRESORT_BATCH_SIZE) - | 'SortBatches' >> beam.Map(lambda b: sorted(b, **sort_kwargs))) # pylint: disable=unnecessary-lambda + """Batches vocabulary and pre-sorts the batches.""" + # This PTransform aims to partially parallelize vocabulary sorting in + # `VocabularyOrderAndWrite`. Pre-sorting of batches in a parallel mode with + # `beam.Map` allows to significantly speed up final single-threaded sorting of + # a union in `_OrderElementsFn`. This is because `list.sort()` uses an + # adaptive merge sort that identifies pre-existing order in the union. See + # https://en.wikipedia.org/wiki/Timsort for more details. 
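    # For instance (illustrative values only): if two workers emit the
    # pre-sorted batches [(9, b"a"), (5, b"b")] and [(8, b"c"), (4, b"d")],
    # the final single-threaded list.sort() over their concatenation mostly
    # merges already-ordered runs instead of sorting from scratch.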
+ return ( + counts + | "BatchVocabulary" + >> beam.BatchElements( + min_batch_size=_PRESORT_BATCH_SIZE, max_batch_size=_PRESORT_BATCH_SIZE + ) + | "SortBatches" >> beam.Map(lambda b: sorted(b, **sort_kwargs)) + ) # pylint: disable=unnecessary-lambda def maybe_add_empty_vocabulary_dummy( - counts: List[KV[_VocabIndicatorType, - _VocabTokenType]], dtype: Union[tf.dtypes.DType, str] + counts: List[KV[_VocabIndicatorType, _VocabTokenType]], + dtype: Union[tf.dtypes.DType, str], ) -> List[KV[_VocabIndicatorType, _VocabTokenType]]: - """Returns a list with a dummy token if counts list is empty.""" - if not counts: - return [analyzers.get_empy_vocabulary_dummy_value(dtype)] - else: - return counts + """Returns a list with a dummy token if counts list is empty.""" + if not counts: + return [analyzers.get_empy_vocabulary_dummy_value(dtype)] + else: + return counts -def _count_and_token_to_bytes(count: _VocabIndicatorType, - token: _VocabTokenType) -> bytes: - # Converts `token` (bytes) to unicode first as otherwise the result will - # look like b"1 b'real_string'" in PY3. We convert everything to bytes - # afterwards to get b'1 real_string'. - return tf.compat.as_bytes('{} {}'.format(count, tf.compat.as_str_any(token))) +def _count_and_token_to_bytes( + count: _VocabIndicatorType, token: _VocabTokenType +) -> bytes: + # Converts `token` (bytes) to unicode first as otherwise the result will + # look like b"1 b'real_string'" in PY3. We convert everything to bytes + # afterwards to get b'1 real_string'. + return tf.compat.as_bytes(f"{count} {tf.compat.as_str_any(token)}") class _OrderElementsFn(beam.DoFn): - """Sort the vocabulary by either descending frequency count or hash order.""" - - def __init__(self, store_frequency, sort_kwargs, input_dtype): - self._store_frequency = store_frequency - self._sort_kwargs = sort_kwargs - self._input_dtype = input_dtype - - # Metrics. - self._vocab_size = beam.metrics.Metrics.gauge( - common.METRICS_NAMESPACE, 'vocabulary_size') - - def process( - self, - unused_element, - batched_counts_iter, - reserved_tokens: typing.Optional[typing.List[np.ndarray]] = None, - ) -> typing.Iterable[bytes]: - counts = [] - if reserved_tokens: - (reserved_tokens,) = reserved_tokens - reserved_tokens_set = set(reserved_tokens) - for batch in batched_counts_iter: - # Filtering input tokens that already have a reserved spot. - counts.extend( - filter(lambda ct: ct[1] not in reserved_tokens_set, batch) - ) - else: - for c in batched_counts_iter: - counts.extend(c) + """Sort the vocabulary by either descending frequency count or hash order.""" - counts.sort(**self._sort_kwargs) - # Prepending reserved tokens after computed tokens have already been sorted. - if reserved_tokens is not None: - counts[:0] = [(-1, t) for t in reserved_tokens] - counts = maybe_add_empty_vocabulary_dummy(counts, self._input_dtype) - self._vocab_size.set(len(counts)) + def __init__(self, store_frequency, sort_kwargs, input_dtype): + self._store_frequency = store_frequency + self._sort_kwargs = sort_kwargs + self._input_dtype = input_dtype - for count, entry in counts: - if self._store_frequency: - yield _count_and_token_to_bytes(count, entry) - else: - yield entry + # Metrics. 
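        # This gauge ends up reporting the final written vocabulary size,
        # including any reserved tokens and the dummy value added for an
        # otherwise empty vocabulary (see process() below).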
+ self._vocab_size = beam.metrics.Metrics.gauge( + common.METRICS_NAMESPACE, "vocabulary_size" + ) + + def process( + self, + unused_element, + batched_counts_iter, + reserved_tokens: typing.Optional[typing.List[np.ndarray]] = None, + ) -> typing.Iterable[bytes]: + counts = [] + if reserved_tokens: + (reserved_tokens,) = reserved_tokens + reserved_tokens_set = set(reserved_tokens) + for batch in batched_counts_iter: + # Filtering input tokens that already have a reserved spot. + counts.extend( + filter(lambda ct: ct[1] not in reserved_tokens_set, batch) + ) + else: + for c in batched_counts_iter: + counts.extend(c) + + counts.sort(**self._sort_kwargs) + # Prepending reserved tokens after computed tokens have already been sorted. + if reserved_tokens is not None: + counts[:0] = [(-1, t) for t in reserved_tokens] + counts = maybe_add_empty_vocabulary_dummy(counts, self._input_dtype) + self._vocab_size.set(len(counts)) + + for count, entry in counts: + if self._store_frequency: + yield _count_and_token_to_bytes(count, entry) + else: + yield entry @ptransform_fn @@ -154,286 +161,321 @@ def _ApplyThresholdsAndTopK( # pylint: disable=invalid-name frequency_threshold, top_k, input_dtype, - info_threshold=float('-inf'), - key_fn=None): - """Applies `frequency_threshold` and `top_k` to (count, value) pairs.""" - # TODO(b/117796748): Filter frequency per-key when key feature input enabled. - # Filter is cheaper than TopK computation and the two commute, so filter - # first. - if frequency_threshold > 0 or info_threshold > float('-inf'): - - def filter_by_thresholds(values): - """Returns True if values are greater than specified thresholds.""" - values, _ = values - # The values can be a single number (the frequency) or a tuple of the - # informativeness and the frequency. - if isinstance(values, tuple): - informativeness, freq = values - else: - informativeness = float('inf') - freq = values - if freq >= frequency_threshold and informativeness >= info_threshold: - return True - return False - - counts |= ('FilterByThresholds(%s)' % frequency_threshold >> - beam.Filter(filter_by_thresholds)) - # If a tuple of multiple metrics, flatten to only the first. This is needed - # for the case the accumulator has tracked informativeness and frequency. - def flatten_to_single_metric(values): - value, term = values - value = value[0] if isinstance(value, tuple) else value - return value, term - - counts |= 'FlattenToSingleMetric' >> beam.Map(flatten_to_single_metric) - - if input_dtype != tf.string.name: - counts |= 'EncodeNumericalTerms' >> beam.MapTuple( - lambda k, v: (k, tf.compat.as_bytes(tf.compat.as_str_any(v)))) - - if top_k is not None: - # TODO(katsiapis): Perhaps enhance Beam's Top to accept an N that can - # signify "unlimited" and then we can simplify a lot of our code (though - # that might come at a performance penalty). - if key_fn: - def map_key_to_count_and_term(kv, key_fn): - """Parses key from term with `key_fn` and maps it to count and term.""" - count, term = kv - # TODO(b/184196242): Ideally we wouldn't be producing numpy.float64 - # counts in the first place, as opposed to casting to float here. See - # also b/79751861. 
- count = float(count) if isinstance(count, np.float64) else count - key = key_fn(term) - return key, (count, term) - - counts = ( - counts - | 'MapKeyToCountAndTerm' >> beam.Map( - lambda x: map_key_to_count_and_term(x, key_fn)) - | 'CoverageTop(%s)' % top_k >> beam.combiners.Top.LargestPerKey(top_k) - | 'FlattenCoverageTerms' >> beam.FlatMap(lambda kv: kv[1])) - else: - # LINT.IfChange(top_k_impl) - # Stages that follow this block rely on the sorted order of `Top.Of`'s - # output and fusion with the `FlattenList`. If changing this part of - # implementation, either make sure that these hold true or adjust the - # appropriate arg of `VocabularyOrderAndWrite` node. - counts = ( - counts - | 'Top(%s)' % top_k >> beam.combiners.Top.Of(top_k) - | 'MaybeAddDummy' >> beam.Map( - maybe_add_empty_vocabulary_dummy, dtype=input_dtype) - | 'FlattenList' >> beam.FlatMap(lambda lst: lst)) - # LINT.ThenChange(../analyzers.py:input_is_sorted) - - return counts + info_threshold=float("-inf"), + key_fn=None, +): + """Applies `frequency_threshold` and `top_k` to (count, value) pairs.""" + # TODO(b/117796748): Filter frequency per-key when key feature input enabled. + # Filter is cheaper than TopK computation and the two commute, so filter + # first. + if frequency_threshold > 0 or info_threshold > float("-inf"): + + def filter_by_thresholds(values): + """Returns True if values are greater than specified thresholds.""" + values, _ = values + # The values can be a single number (the frequency) or a tuple of the + # informativeness and the frequency. + if isinstance(values, tuple): + informativeness, freq = values + else: + informativeness = float("inf") + freq = values + if freq >= frequency_threshold and informativeness >= info_threshold: + return True + return False + + counts |= "FilterByThresholds(%s)" % frequency_threshold >> beam.Filter( + filter_by_thresholds + ) + + # If a tuple of multiple metrics, flatten to only the first. This is needed + # for the case the accumulator has tracked informativeness and frequency. + def flatten_to_single_metric(values): + value, term = values + value = value[0] if isinstance(value, tuple) else value + return value, term + + counts |= "FlattenToSingleMetric" >> beam.Map(flatten_to_single_metric) + + if input_dtype != tf.string.name: + counts |= "EncodeNumericalTerms" >> beam.MapTuple( + lambda k, v: (k, tf.compat.as_bytes(tf.compat.as_str_any(v))) + ) + + if top_k is not None: + # TODO(katsiapis): Perhaps enhance Beam's Top to accept an N that can + # signify "unlimited" and then we can simplify a lot of our code (though + # that might come at a performance penalty). + if key_fn: + + def map_key_to_count_and_term(kv, key_fn): + """Parses key from term with `key_fn` and maps it to count and term.""" + count, term = kv + # TODO(b/184196242): Ideally we wouldn't be producing numpy.float64 + # counts in the first place, as opposed to casting to float here. See + # also b/79751861. + count = float(count) if isinstance(count, np.float64) else count + key = key_fn(term) + return key, (count, term) + + counts = ( + counts + | "MapKeyToCountAndTerm" + >> beam.Map(lambda x: map_key_to_count_and_term(x, key_fn)) + | "CoverageTop(%s)" % top_k >> beam.combiners.Top.LargestPerKey(top_k) + | "FlattenCoverageTerms" >> beam.FlatMap(lambda kv: kv[1]) + ) + else: + # LINT.IfChange(top_k_impl) + # Stages that follow this block rely on the sorted order of `Top.Of`'s + # output and fusion with the `FlattenList`. 
If changing this part of + # implementation, either make sure that these hold true or adjust the + # appropriate arg of `VocabularyOrderAndWrite` node. + counts = ( + counts + | "Top(%s)" % top_k >> beam.combiners.Top.Of(top_k) + | "MaybeAddDummy" + >> beam.Map(maybe_add_empty_vocabulary_dummy, dtype=input_dtype) + | "FlattenList" >> beam.FlatMap(lambda lst: lst) + ) + # LINT.ThenChange(../analyzers.py:input_is_sorted) + + return counts # Experimental def sum_labeled_weights( - accs: List[Tuple[float, List[float]]]) -> Tuple[float, List[float]]: - """Sums up a collection of labeled-weight tables. - - Args: - accs: a list of (w, lw) tuples, where w is the total weight (floating point) - and lw is a list of weights for each label. - - Returns: - component-wise sum of the inputs in the same format. - """ - total_weight, labeled_weights = 0., [] - for acc in accs: - total_weight = total_weight + acc[0] - accumulator_labeled_weights = acc[1] - if len(accumulator_labeled_weights) > len(labeled_weights): - labeled_weights.extend( - [0.] * (len(accumulator_labeled_weights) - len(labeled_weights))) - for i in range(len(accumulator_labeled_weights)): - labeled_weights[i] = labeled_weights[i] + accumulator_labeled_weights[i] - return (total_weight, labeled_weights) + accs: List[Tuple[float, List[float]]], +) -> Tuple[float, List[float]]: + """Sums up a collection of labeled-weight tables. + + Args: + ---- + accs: a list of (w, lw) tuples, where w is the total weight (floating point) + and lw is a list of weights for each label. + + Returns: + ------- + component-wise sum of the inputs in the same format. + """ + total_weight, labeled_weights = 0.0, [] + for acc in accs: + total_weight = total_weight + acc[0] + accumulator_labeled_weights = acc[1] + if len(accumulator_labeled_weights) > len(labeled_weights): + labeled_weights.extend( + [0.0] * (len(accumulator_labeled_weights) - len(labeled_weights)) + ) + for i in range(len(accumulator_labeled_weights)): + labeled_weights[i] = labeled_weights[i] + accumulator_labeled_weights[i] + return (total_weight, labeled_weights) @common.register_ptransform(analyzer_nodes.VocabularyAccumulate) @beam.typehints.with_input_types(Tuple[np.ndarray, ...]) -@beam.typehints.with_output_types(KV[_VocabTokenType, - _VocabAccumulatedIndicatorType]) +@beam.typehints.with_output_types(KV[_VocabTokenType, _VocabAccumulatedIndicatorType]) class _VocabularyAccumulateImpl(beam.PTransform): - """Accumulates the unique elements in a PCollection of batches.""" - - def __init__(self, operation, extra_args): - self._vocab_ordering_type = operation.vocab_ordering_type - - def expand(self, inputs): - pcoll, = inputs - - # Create a PCollection of (count, element) pairs, then iterates over - # this to create a single element PCollection containing this list of - # pairs in sorted order by decreasing counts (and by values for equal - # counts). - - # TODO(b/112916494): Unify the graph in both cases once possible. 
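A quick usage sketch of `sum_labeled_weights` above, assuming it is called directly from this module: accumulators are summed component-wise, and a shorter label-weight list is treated as zero-padded.

# Each accumulator is (total_weight, [weight_per_label, ...]).
accs = [
    (2.0, [1.0, 0.5]),
    (3.0, [0.0, 1.0, 2.0]),
]
print(sum_labeled_weights(accs))
# (5.0, [1.0, 1.5, 2.0])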
- if (self._vocab_ordering_type == - _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION): - flatten_map_fn = functools.partial( - _flatten_to_key_and_means_accumulator_list, compute_weighted=True) - combine_transform = _MutualInformationTransformAccumulate( # pylint: disable=no-value-for-parameter - compute_weighted=True) - elif self._vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: - flatten_map_fn = functools.partial( - _flatten_to_key_and_means_accumulator_list, compute_weighted=False) - combine_transform = _MutualInformationTransformAccumulate( # pylint: disable=no-value-for-parameter - compute_weighted=False) - elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY: - flatten_map_fn = _flatten_value_and_weights_to_list_of_tuples - combine_transform = beam.CombinePerKey(sum) - elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_LABELS: - flatten_map_fn = _flatten_value_and_labeled_weights_to_list_of_tuples - # TODO(b/199789764) This returns a type outside of - # _VocabAccumulatedIndicatorType union - - # it is only supported by the specific contrib transform (vocab_map). - # This probably should be moved to vocab_map itself in order to not make - # every transform to support this type. - combine_transform = beam.CombinePerKey(sum_labeled_weights) - else: - flatten_map_fn = _flatten_value_to_list - combine_transform = beam.combiners.Count.PerElement() + """Accumulates the unique elements in a PCollection of batches.""" - if not _ENABLE_BEAM_TYPE_CHECKS: - combine_transform = combine_transform.with_input_types(Any) - combine_transform = combine_transform.with_output_types(Any) + def __init__(self, operation, extra_args): + self._vocab_ordering_type = operation.vocab_ordering_type - result = ( - pcoll - | 'FlattenTokensAndMaybeWeightsLabels' >> beam.FlatMap(flatten_map_fn) - | 'CountPerToken' >> combine_transform) + def expand(self, inputs): + (pcoll,) = inputs - return result + # Create a PCollection of (count, element) pairs, then iterates over + # this to create a single element PCollection containing this list of + # pairs in sorted order by decreasing counts (and by values for equal + # counts). + + # TODO(b/112916494): Unify the graph in both cases once possible. + if self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION: + flatten_map_fn = functools.partial( + _flatten_to_key_and_means_accumulator_list, compute_weighted=True + ) + combine_transform = _MutualInformationTransformAccumulate( # pylint: disable=no-value-for-parameter + compute_weighted=True + ) + elif self._vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: + flatten_map_fn = functools.partial( + _flatten_to_key_and_means_accumulator_list, compute_weighted=False + ) + combine_transform = _MutualInformationTransformAccumulate( # pylint: disable=no-value-for-parameter + compute_weighted=False + ) + elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_FREQUENCY: + flatten_map_fn = _flatten_value_and_weights_to_list_of_tuples + combine_transform = beam.CombinePerKey(sum) + elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_LABELS: + flatten_map_fn = _flatten_value_and_labeled_weights_to_list_of_tuples + # TODO(b/199789764) This returns a type outside of + # _VocabAccumulatedIndicatorType union - + # it is only supported by the specific contrib transform (vocab_map). + # This probably should be moved to vocab_map itself in order to not make + # every transform to support this type. 
+ combine_transform = beam.CombinePerKey(sum_labeled_weights) + else: + flatten_map_fn = _flatten_value_to_list + combine_transform = beam.combiners.Count.PerElement() + + if not _ENABLE_BEAM_TYPE_CHECKS: + combine_transform = combine_transform.with_input_types(Any) + combine_transform = combine_transform.with_output_types(Any) + + result = ( + pcoll + | "FlattenTokensAndMaybeWeightsLabels" >> beam.FlatMap(flatten_map_fn) + | "CountPerToken" >> combine_transform + ) + + return result @common.register_ptransform(analyzer_nodes.VocabularyCount) @beam.typehints.with_input_types(KV[_VocabIndicatorType, _VocabTokenType]) @beam.typehints.with_output_types(np.int64) class _VocabularyCountImpl(beam.PTransform): - """Counts the total number of tokens in the vocabulary.""" + """Counts the total number of tokens in the vocabulary.""" - def __init__(self, operation, extra_args): - super().__init__() + def __init__(self, operation, extra_args): + super().__init__() - def _format_count(self, count): - # Count should be at least one because empty vocabularies get populated with - # a single dummy value when written. - # TODO(b/62272023) remove this workaround if/when fixed on tensorflow. - return np.int64(np.maximum(count, 1)) + def _format_count(self, count): + # Count should be at least one because empty vocabularies get populated with + # a single dummy value when written. + # TODO(b/62272023) remove this workaround if/when fixed on tensorflow. + return np.int64(np.maximum(count, 1)) - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - return (pcoll - | 'TotalVocabSize' >> beam.combiners.Count.Globally() - | 'FormatCount' >> beam.Map(self._format_count)) + return ( + pcoll + | "TotalVocabSize" >> beam.combiners.Count.Globally() + | "FormatCount" >> beam.Map(self._format_count) + ) @common.register_ptransform(analyzer_nodes.VocabularyMerge) -@beam.typehints.with_input_types(KV[_VocabTokenType, - _VocabAccumulatedIndicatorType]) +@beam.typehints.with_input_types(KV[_VocabTokenType, _VocabAccumulatedIndicatorType]) @beam.typehints.with_output_types(KV[_VocabIndicatorType, _VocabTokenType]) class _VocabularyMergeImpl(beam.PTransform): - """Merges vocabulary accumulators of (token, num) pairs.""" - - def __init__(self, operation, extra_args): - self._vocab_ordering_type = operation.vocab_ordering_type - self._use_adjusted_mutual_info = operation.use_adjusted_mutual_info - self._min_diff_from_avg = operation.min_diff_from_avg - - def expand(self, inputs): - if (self._vocab_ordering_type == - _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION): - combine_transform = _MutualInformationTransformMerge( # pylint: disable=no-value-for-parameter - self._use_adjusted_mutual_info, - self._min_diff_from_avg, - compute_weighted=True) - elif self._vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: - combine_transform = _MutualInformationTransformMerge( # pylint: disable=no-value-for-parameter - self._use_adjusted_mutual_info, - self._min_diff_from_avg, - compute_weighted=False) - elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_LABELS: - # TODO(b/199789764) This returns a type outside of - # _VocabAccumulatedIndicatorType union - - # it is only supported by the specific contrib transform (vocab_map). - # This probably should be moved to vocab_map itself in order to not make - # every transform to support this type. 
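For the plain-frequency ordering type above, the accumulate and merge stages boil down to flattening each batch into a list of tokens and counting occurrences per token. A rough pure-Python equivalent of 'FlattenTokensAndMaybeWeightsLabels' followed by 'CountPerToken' (illustrative only; the pipeline does this with Beam transforms over PCollections):

import collections

import numpy as np

# Two input batches, each a tuple holding a single 1-D token array.
batches = [
    (np.array([b"a", b"b", b"a"]),),
    (np.array([b"b", b"c"]),),
]
counts = collections.Counter()
for batch in batches:
    # _flatten_value_to_list: unpack the single array and go to native types.
    (batch_value,) = batch
    counts.update(batch_value.tolist())
print(dict(counts))  # {b'a': 2, b'b': 2, b'c': 1}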
- combine_transform = beam.CombinePerKey(sum_labeled_weights) - else: - combine_transform = beam.CombinePerKey(sum) - - pcoll, = inputs - if not _ENABLE_BEAM_TYPE_CHECKS: - combine_transform = combine_transform.with_input_types(Any) - combine_transform = combine_transform.with_output_types(Any) - - return (pcoll - | 'MergeCountPerToken' >> combine_transform - | 'SwapTokensAndCounts' >> beam.KvSwap()) + """Merges vocabulary accumulators of (token, num) pairs.""" + + def __init__(self, operation, extra_args): + self._vocab_ordering_type = operation.vocab_ordering_type + self._use_adjusted_mutual_info = operation.use_adjusted_mutual_info + self._min_diff_from_avg = operation.min_diff_from_avg + + def expand(self, inputs): + if self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_MUTUAL_INFORMATION: + combine_transform = _MutualInformationTransformMerge( # pylint: disable=no-value-for-parameter + self._use_adjusted_mutual_info, + self._min_diff_from_avg, + compute_weighted=True, + ) + elif self._vocab_ordering_type == _VocabOrderingType.MUTUAL_INFORMATION: + combine_transform = _MutualInformationTransformMerge( # pylint: disable=no-value-for-parameter + self._use_adjusted_mutual_info, + self._min_diff_from_avg, + compute_weighted=False, + ) + elif self._vocab_ordering_type == _VocabOrderingType.WEIGHTED_LABELS: + # TODO(b/199789764) This returns a type outside of + # _VocabAccumulatedIndicatorType union - + # it is only supported by the specific contrib transform (vocab_map). + # This probably should be moved to vocab_map itself in order to not make + # every transform to support this type. + combine_transform = beam.CombinePerKey(sum_labeled_weights) + else: + combine_transform = beam.CombinePerKey(sum) + + (pcoll,) = inputs + if not _ENABLE_BEAM_TYPE_CHECKS: + combine_transform = combine_transform.with_input_types(Any) + combine_transform = combine_transform.with_output_types(Any) + + return ( + pcoll + | "MergeCountPerToken" >> combine_transform + | "SwapTokensAndCounts" >> beam.KvSwap() + ) @common.register_ptransform(analyzer_nodes.VocabularyPrune) @beam.typehints.with_input_types(KV[_VocabIndicatorType, _VocabTokenType]) @beam.typehints.with_output_types(KV[_VocabIndicatorType, _VocabTokenType]) class _VocabularyPruneImpl(beam.PTransform): - """Order, filters and writes the computed vocabulary file.""" - - def __init__(self, operation, extra_args): - self._top_k = operation.top_k - self._frequency_threshold = operation.frequency_threshold - self._informativeness_threshold = operation.informativeness_threshold - self._coverage_top_k = operation.coverage_top_k - self._coverage_frequency_threshold = operation.coverage_frequency_threshold - self._coverage_informativeness_threshold = ( - operation.coverage_informativeness_threshold) - self._key_fn = operation.key_fn - self._input_dtype = operation.input_dtype - - def expand(self, inputs): - if self._top_k is not None and self._top_k < 0: - raise ValueError('top_k for VocabularyImpl should be >= 0 or None, got ' - '{}.'.format(self._top_k)) - if self._frequency_threshold is not None and self._frequency_threshold < 0: - raise ValueError( - 'frequency_threshold for VocabularyImpl should be >= 0 or None, ' - 'got {}.'.format(self._frequency_threshold)) - if self._coverage_top_k is not None and self._coverage_top_k < 0: - raise ValueError('coverage_top_k for VocabularyImpl should be >= 0 or ' - 'None, got {}.'.format(self._coverage_top_k)) - if (self._coverage_frequency_threshold is not None and - self._coverage_frequency_threshold < 0): - 
raise ValueError( - 'coverage_frequency_threshold for VocabularyImpl should be >= 0 or ' - 'None, got {}.'.format(self._coverage_frequency_threshold)) - pcoll, = inputs - - result = ( - pcoll - | 'ApplyThresholdsAndTopK' >> ( + """Order, filters and writes the computed vocabulary file.""" + + def __init__(self, operation, extra_args): + self._top_k = operation.top_k + self._frequency_threshold = operation.frequency_threshold + self._informativeness_threshold = operation.informativeness_threshold + self._coverage_top_k = operation.coverage_top_k + self._coverage_frequency_threshold = operation.coverage_frequency_threshold + self._coverage_informativeness_threshold = ( + operation.coverage_informativeness_threshold + ) + self._key_fn = operation.key_fn + self._input_dtype = operation.input_dtype + + def expand(self, inputs): + if self._top_k is not None and self._top_k < 0: + raise ValueError( + "top_k for VocabularyImpl should be >= 0 or None, got " + f"{self._top_k}." + ) + if self._frequency_threshold is not None and self._frequency_threshold < 0: + raise ValueError( + "frequency_threshold for VocabularyImpl should be >= 0 or None, " + f"got {self._frequency_threshold}." + ) + if self._coverage_top_k is not None and self._coverage_top_k < 0: + raise ValueError( + "coverage_top_k for VocabularyImpl should be >= 0 or " + f"None, got {self._coverage_top_k}." + ) + if ( + self._coverage_frequency_threshold is not None + and self._coverage_frequency_threshold < 0 + ): + raise ValueError( + "coverage_frequency_threshold for VocabularyImpl should be >= 0 or " + f"None, got {self._coverage_frequency_threshold}." + ) + (pcoll,) = inputs + + result = pcoll | "ApplyThresholdsAndTopK" >> ( _ApplyThresholdsAndTopK( # pylint: disable=no-value-for-parameter - self._frequency_threshold, self._top_k, self._input_dtype, - self._informativeness_threshold, None))) + self._frequency_threshold, + self._top_k, + self._input_dtype, + self._informativeness_threshold, + None, + ) + ) - if self._key_fn: - # Note: current APIs do not allow for specifying a coverage - # informativeness threshold. - coverage_counts = ( - pcoll | 'ApplyCoverageThresholdAndTopK' >> ( - _ApplyThresholdsAndTopK( # pylint: disable=no-value-for-parameter - self._coverage_frequency_threshold, self._coverage_top_k, - self._input_dtype, self._coverage_informativeness_threshold, - self._key_fn))) + if self._key_fn: + # Note: current APIs do not allow for specifying a coverage + # informativeness threshold. 
+ coverage_counts = pcoll | "ApplyCoverageThresholdAndTopK" >> ( + _ApplyThresholdsAndTopK( # pylint: disable=no-value-for-parameter + self._coverage_frequency_threshold, + self._coverage_top_k, + self._input_dtype, + self._coverage_informativeness_threshold, + self._key_fn, + ) + ) - result = ((result, coverage_counts) - | 'MergeStandardAndCoverageArms' >> beam.Flatten() - | 'RemoveDuplicates' >> beam.Distinct()) + result = ( + (result, coverage_counts) + | "MergeStandardAndCoverageArms" >> beam.Flatten() + | "RemoveDuplicates" >> beam.Distinct() + ) - return result + return result @common.register_ptransform(analyzer_nodes.VocabularyOrderAndWrite) @@ -442,1065 +484,1134 @@ def expand(self, inputs): ) @beam.typehints.with_output_types(np.ndarray) class _VocabularyOrderAndWriteImpl(beam.PTransform): - """Writes the computed vocabulary file.""" - - def __init__(self, operation, extra_args): - self._base_temp_dir = extra_args.base_temp_dir - self._store_frequency = operation.store_frequency - self._vocab_filename = operation.vocab_filename - self._fingerprint_shuffle = operation.fingerprint_shuffle - self._input_dtype = operation.input_dtype - self._file_format: common_types.VocabularyFileFormatType = ( - operation.file_format) - self._input_is_sorted = operation.input_is_sorted - - def expand(self, inputs): - reserved_tokens = None - counts = inputs[0] - if len(inputs) > 1: - reserved_tokens = inputs[1] - assert len(inputs) < 3 - vocabulary_file = os.path.join(self._base_temp_dir, self._vocab_filename) - - def fingerprint_sort_fn(kv): - # hashlib.sha1 expects bytes - return hashlib.sha1(kv[1]).digest() - - # TODO(b/62379925) For now force a single file. We can write a sharded - # file instead. - # TODO(b/190580668) Here we are relying on fusion (an implementation - # detail) for the ordering to be maintained when the results are written - # to disk. This includes fusion of `_OrderElementsFn` and writing PTransform - # when `_input_is_sorted` is false and fusion of the last stage in - # `_ApplyThresholdsAndTopK` and writing PTransform when `_input_is_sorted` - # is true. - # Perform the write within the body of `OrderElements` maybe - # `OrderElementsAndWrite`. This would mean using TF IO instead of Beam - # IO so it's perhaps not great. - # Alternatively, we could verify the proper ordering after vocabulary is - # written during `TransformDataset` stage. - if self._file_format == 'text': - write_ptransform = 'WriteToText' >> beam.io.WriteToText( - vocabulary_file, shard_name_template='') - elif self._file_format == 'tfrecord_gzip': - # Setting the suffix as .gz ensures that the vocabulary will be written - # with GZIP compression. - vocabulary_file = '{}.tfrecord.gz'.format(vocabulary_file) - write_ptransform = 'WriteToTFRecord' >> beam.io.WriteToTFRecord( - vocabulary_file, shard_name_template='') - - # TODO(b/282952880): Refactor and allow input_is_sorted and reserved_tokens - # inputs to rely on their sorting, for improved performance. - if self._input_is_sorted and not reserved_tokens: - assert not self._fingerprint_shuffle - if self._store_frequency: - formatted_vocabulary = ( - counts | 'ToBytes' >> beam.MapTuple(_count_and_token_to_bytes)) - else: - formatted_vocabulary = counts | 'ExtractTokens' >> beam.Values() - else: - if self._fingerprint_shuffle: - sort_kwargs = dict(key=fingerprint_sort_fn) - else: - sort_kwargs = dict(reverse=True) # Largest first. 
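`_VocabularyPruneImpl` above keeps the overall top-k terms and, when a `key_fn` is supplied, additionally keeps the top coverage terms per key, then flattens both arms and removes duplicates. A pure-Python sketch of that selection over plain lists (the helper name `prune_with_coverage` and the toy `key_fn` are illustrative):

import heapq

def prune_with_coverage(counts, top_k, coverage_top_k, key_fn):
    """Overall top-k plus per-key top-k, deduplicated (sketch of the two arms)."""
    # Standard arm: overall top-k by count.
    standard = heapq.nlargest(top_k, counts)
    # Coverage arm: top-k within each key derived from the term.
    per_key = {}
    for count, term in counts:
        per_key.setdefault(key_fn(term), []).append((count, term))
    coverage = [
        ct for group in per_key.values()
        for ct in heapq.nlargest(coverage_top_k, group)
    ]
    # MergeStandardAndCoverageArms + RemoveDuplicates.
    return set(standard) | set(coverage)

counts = [(10, b"aa"), (9, b"ab"), (1, b"ba"), (8, b"ac")]
print(prune_with_coverage(counts, top_k=2, coverage_top_k=1, key_fn=lambda t: t[:1]))
# {(10, b'aa'), (9, b'ab'), (1, b'ba')} -- set print order may vary.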
- batched_counts = counts | 'BatchAndPreSort' >> _BatchAndPreSort( # pylint: disable=no-value-for-parameter - sort_kwargs=sort_kwargs) - - kwargs = dict(batched_counts_iter=beam.pvalue.AsIter(batched_counts)) - if reserved_tokens: - kwargs.update(reserved_tokens=beam.pvalue.AsSingleton(reserved_tokens)) - formatted_vocabulary = ( - batched_counts.pipeline - | 'Prepare' >> beam.Create([None]) - | 'OrderElements' - >> beam.ParDo( - _OrderElementsFn( - self._store_frequency, sort_kwargs, self._input_dtype - ), - **kwargs - ) - ) - vocab_is_written = formatted_vocabulary | write_ptransform - - # Return the vocabulary path. - wait_for_vocabulary_transform = ( - counts.pipeline - | 'CreatePath' >> beam.Create([np.array(vocabulary_file)]) - # Ensure that the analysis returns only after the file is written. - | 'WaitForVocabularyFile' >> beam.Map( - lambda x, y: x, y=beam.pvalue.AsIter(vocab_is_written))) - return (wait_for_vocabulary_transform,) - - -def _flatten_value_to_list( - batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]: - """Converts an N-D dense or sparse batch to a 1-D list.""" - batch_value, = batch_values - - # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so - # that we go to native Python types for more efficient followup - # processing. - return batch_value.tolist() + """Writes the computed vocabulary file.""" + + def __init__(self, operation, extra_args): + self._base_temp_dir = extra_args.base_temp_dir + self._store_frequency = operation.store_frequency + self._vocab_filename = operation.vocab_filename + self._fingerprint_shuffle = operation.fingerprint_shuffle + self._input_dtype = operation.input_dtype + self._file_format: common_types.VocabularyFileFormatType = operation.file_format + self._input_is_sorted = operation.input_is_sorted + + def expand(self, inputs): + reserved_tokens = None + counts = inputs[0] + if len(inputs) > 1: + reserved_tokens = inputs[1] + assert len(inputs) < 3 + vocabulary_file = os.path.join(self._base_temp_dir, self._vocab_filename) + + def fingerprint_sort_fn(kv): + # hashlib.sha1 expects bytes + return hashlib.sha1(kv[1]).digest() + + # TODO(b/62379925) For now force a single file. We can write a sharded + # file instead. + # TODO(b/190580668) Here we are relying on fusion (an implementation + # detail) for the ordering to be maintained when the results are written + # to disk. This includes fusion of `_OrderElementsFn` and writing PTransform + # when `_input_is_sorted` is false and fusion of the last stage in + # `_ApplyThresholdsAndTopK` and writing PTransform when `_input_is_sorted` + # is true. + # Perform the write within the body of `OrderElements` maybe + # `OrderElementsAndWrite`. This would mean using TF IO instead of Beam + # IO so it's perhaps not great. + # Alternatively, we could verify the proper ordering after vocabulary is + # written during `TransformDataset` stage. + if self._file_format == "text": + write_ptransform = "WriteToText" >> beam.io.WriteToText( + vocabulary_file, shard_name_template="" + ) + elif self._file_format == "tfrecord_gzip": + # Setting the suffix as .gz ensures that the vocabulary will be written + # with GZIP compression. + vocabulary_file = f"{vocabulary_file}.tfrecord.gz" + write_ptransform = "WriteToTFRecord" >> beam.io.WriteToTFRecord( + vocabulary_file, shard_name_template="" + ) + + # TODO(b/282952880): Refactor and allow input_is_sorted and reserved_tokens + # inputs to rely on their sorting, for improved performance. 
+ if self._input_is_sorted and not reserved_tokens: + assert not self._fingerprint_shuffle + if self._store_frequency: + formatted_vocabulary = counts | "ToBytes" >> beam.MapTuple( + _count_and_token_to_bytes + ) + else: + formatted_vocabulary = counts | "ExtractTokens" >> beam.Values() + else: + if self._fingerprint_shuffle: + sort_kwargs = dict(key=fingerprint_sort_fn) + else: + sort_kwargs = dict(reverse=True) # Largest first. + batched_counts = counts | "BatchAndPreSort" >> _BatchAndPreSort( # pylint: disable=no-value-for-parameter + sort_kwargs=sort_kwargs + ) + + kwargs = dict(batched_counts_iter=beam.pvalue.AsIter(batched_counts)) + if reserved_tokens: + kwargs.update(reserved_tokens=beam.pvalue.AsSingleton(reserved_tokens)) + formatted_vocabulary = ( + batched_counts.pipeline + | "Prepare" >> beam.Create([None]) + | "OrderElements" + >> beam.ParDo( + _OrderElementsFn( + self._store_frequency, sort_kwargs, self._input_dtype + ), + **kwargs, + ) + ) + vocab_is_written = formatted_vocabulary | write_ptransform + + # Return the vocabulary path. + wait_for_vocabulary_transform = ( + counts.pipeline + | "CreatePath" >> beam.Create([np.array(vocabulary_file)]) + # Ensure that the analysis returns only after the file is written. + | "WaitForVocabularyFile" + >> beam.Map(lambda x, y: x, y=beam.pvalue.AsIter(vocab_is_written)) + ) + return (wait_for_vocabulary_transform,) + + +def _flatten_value_to_list(batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]: + """Converts an N-D dense or sparse batch to a 1-D list.""" + (batch_value,) = batch_values + + # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so + # that we go to native Python types for more efficient followup + # processing. + return batch_value.tolist() def _flatten_value_and_weights_to_list_of_tuples( - batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]: - """Converts a batch of vocabulary and weights to a list of KV tuples.""" - batch_value, weights = batch_values + batch_values: Tuple[np.ndarray, ...], +) -> Iterable[Any]: + """Converts a batch of vocabulary and weights to a list of KV tuples.""" + batch_value, weights = batch_values - # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so - # that we go to native Python types for more efficient followup - # processing. - batch_value = batch_value.tolist() - weights = weights.tolist() - return zip(batch_value, weights) + # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so + # that we go to native Python types for more efficient followup + # processing. + batch_value = batch_value.tolist() + weights = weights.tolist() + return zip(batch_value, weights) # Experimental def _flatten_value_and_labeled_weights_to_list_of_tuples( - batch_values: Tuple[np.ndarray, ...]) -> Iterable[Any]: - """Converts a batch of vocabulary and labeled weights to a list of KV tuples. - - Args: - batch_values: A row in the batch consists of a value, a (total) weight, and - a list of weights for each label. - """ - batch_value, weights, labeled_weights = batch_values - - # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so - # that we go to native Python types for more efficient followup - # processing. 
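When `fingerprint_shuffle` is enabled in the write transform above, the vocabulary is ordered by the SHA-1 digest of each token instead of by count, giving a deterministic pseudo-random order. A small sketch of that sort key:

import hashlib

counts = [(5, b"the"), (2, b"cat"), (9, b"dog")]
# Same key as fingerprint_sort_fn: hashlib.sha1 over the token bytes.
shuffled = sorted(counts, key=lambda kv: hashlib.sha1(kv[1]).digest())
print([token for _, token in shuffled])
# Order is driven by the token fingerprints, not the counts, and is stable across runs.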
- batch_value = batch_value.tolist() - weights = weights.tolist() - labeled_weights = labeled_weights.tolist() - return zip(batch_value, zip(weights, labeled_weights)) - - -def _make_count_and_weights_means_accumulator(sum_positive, weights_sum_total, - count): - """Create a WeightedMeanAndVarCombiner according to the parameter values.""" - # TODO(b/165003832): We're going back and forth from lists to numpy here. - # Can we remove this overhead? - if weights_sum_total is None: - mean = np.array(sum_positive) / count - weight = None - else: - mean = np.array(sum_positive) / weights_sum_total - weight = weights_sum_total / count - - return analyzers.WeightedMeanAndVarCombiner.accumulator_class( - count=np.array(count), - mean=mean, - variance=np.array(0.), # Variance is not used for vocabularies. - weight=weight) - - -def _flatten_to_key_and_means_accumulator_list(batch_values, - compute_weighted=True): - """Converts a batch of keys, weights, and counts to a list of KV pairs.""" - if compute_weighted: - keys, total_weights, positive_label_weights, counts = batch_values - total_weights = total_weights.tolist() - else: - keys, positive_label_weights, counts = batch_values - total_weights = [] - - # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so - # that we go to native Python types for more efficient followup - # processing. - keys = keys.tolist() - positive_label_weights = positive_label_weights.tolist() - counts = counts.tolist() - - return zip(keys, [ - _make_count_and_weights_means_accumulator(*batch) for batch in - itertools.zip_longest(positive_label_weights, total_weights, counts) - ]) + batch_values: Tuple[np.ndarray, ...], +) -> Iterable[Any]: + """Converts a batch of vocabulary and labeled weights to a list of KV tuples. + + Args: + ---- + batch_values: A row in the batch consists of a value, a (total) weight, and + a list of weights for each label. + """ + batch_value, weights, labeled_weights = batch_values + + # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so + # that we go to native Python types for more efficient followup + # processing. + batch_value = batch_value.tolist() + weights = weights.tolist() + labeled_weights = labeled_weights.tolist() + return zip(batch_value, zip(weights, labeled_weights)) + + +def _make_count_and_weights_means_accumulator(sum_positive, weights_sum_total, count): + """Create a WeightedMeanAndVarCombiner according to the parameter values.""" + # TODO(b/165003832): We're going back and forth from lists to numpy here. + # Can we remove this overhead? + if weights_sum_total is None: + mean = np.array(sum_positive) / count + weight = None + else: + mean = np.array(sum_positive) / weights_sum_total + weight = weights_sum_total / count + + return analyzers.WeightedMeanAndVarCombiner.accumulator_class( + count=np.array(count), + mean=mean, + variance=np.array(0.0), # Variance is not used for vocabularies. + weight=weight, + ) + + +def _flatten_to_key_and_means_accumulator_list(batch_values, compute_weighted=True): + """Converts a batch of keys, weights, and counts to a list of KV pairs.""" + if compute_weighted: + keys, total_weights, positive_label_weights, counts = batch_values + total_weights = total_weights.tolist() + else: + keys, positive_label_weights, counts = batch_values + total_weights = [] + + # TODO(b/36603294): Perhaps obviate the tolist(). It is currently used so + # that we go to native Python types for more efficient followup + # processing. 
+ keys = keys.tolist() + positive_label_weights = positive_label_weights.tolist() + counts = counts.tolist() + + return zip( + keys, + [ + _make_count_and_weights_means_accumulator(*batch) + for batch in itertools.zip_longest( + positive_label_weights, total_weights, counts + ) + ], + ) def _clip_probability(p, epsilon=1e-6): - return np.clip(p, epsilon, 1 - epsilon) - - -def _calculate_mutual_information_for_feature_value(feature_and_accumulator, - global_accumulator, - use_adjusted_mutual_info, - min_diff_from_avg): - """Calculates the (possibly adjusted) mutual information of a feature value. - - Used as a measure of relatedness between a single feature value and a label. - - Mutual information is calculated as: - H(x, y) = (sum(weights) * - [(P(y|x)*log2(P(y|x)/P(y))) + (P(~y|x)*log2(P(~y|x)/P(~y)))]) - where x is feature and y is label. We use sum(weights) instead of P(x), as - this makes the mutual information more interpretable. - If we don't divide by sum(weights), it can be thought of as an adjusted - weighted count. - - If use_adjusted_mutual_info is True, we use Adjusted Mutual Information (AMI) - which accounts for relatedness due to chance. AMI is generally calculated as: - AMI(x, y) = MI(x, y) - EMI(x, y) / (max(H(x), H(y)) - EMI(x, y)) - where x is the feature and y is label. Here, we leave off the normalization - and only subtract expected mutual information (EMI) from mutual information. - The calculation is based on the following paper: - - Vinh, N. X.; Epps, J.; Bailey, J. (2009). "Information theoretic measures for - clusterings comparison". Proceedings of the 26th Annual International Confere - nce on Machine Learning - ICML '09. p. 1. - doi:10.1145/1553374.1553511. ISBN 9781605585161. - - Short summary can be found in the Wikipedia link: - https://en.wikipedia.org/wiki/Adjusted_mutual_information - - Args: - feature_and_accumulator: A tuple of the form: - (feature, WeightedMeanAndVarCombiner.accumulator_class) where: `feature` - is the single token in the vocabulary for which (possibly adjusted) - mutual information with the label is being computed. `mean` is the - weighted mean positive for each label value given x. `count` is the - count of weights for a feature. `weight` is the mean of the weights for - a feature. - global_accumulator: A WeightedMeanAndVarCombiner.accumulator_class where: - `mean` is the weighted mean positive for each label value for all - features. `count` is the count for all features. `weight` is the mean of - the weights for all features. - use_adjusted_mutual_info: If set to True, use adjusted mutual information. - min_diff_from_avg: A regularization parameter that pushes low MI/AMI towards - zero. The Mutual information of a feature x label pair will be adjusted to - zero whenever the absolute difference the weight and the expected - (average) weight is lower than min_diff_from_average. - - Returns: - A tuple of: - The feature value - The mutual information with the label. If use_adjusted_mutual_info, this - is the mutual information - the expected mutual information, otherwise - it is the raw mutual information. - The expected mutual information (EMI) if use_adjusted_mutual_info is - True, otherwise NaN. - The total weighted sum for the feature value. - """ - # Compute the frequency of each label value. 
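The accumulator built by `_make_count_and_weights_means_accumulator` above stores, per token, the row count, the per-label weighted means of the positive weights, and the mean row weight. The weighted branch with concrete numbers (a sketch of the arithmetic only, not the real accumulator class):

import numpy as np

sum_positive = [3.0, 1.0]  # per-label sums of positive weights for this token
weights_sum_total = 4.0    # total weight observed for this token
count = 2                  # number of rows containing this token

mean = np.array(sum_positive) / weights_sum_total  # per-label weighted means
weight = weights_sum_total / count                 # mean row weight
print(mean)    # [0.75 0.25]
print(weight)  # 2.0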
- global_label_counts = ( - global_accumulator.mean * global_accumulator.weight * - global_accumulator.count) - feature_value, current_accumulator = feature_and_accumulator - total_label_counts = sum(global_label_counts) - n = global_accumulator.count * global_accumulator.weight - # TODO(b/168469757): Consider raising here once b/168469757 is resolved. - if round(total_label_counts) != round(n): - logging.warn( - 'Weighted label sum (%s) != total weighted count (%s), label means=%s', - total_label_counts, n, global_accumulator.mean) - if n == 0: - return (feature_value, (float('NaN'), float('NaN'), 0)) - - mutual_information = 0 - expected_mutual_information = 0 if use_adjusted_mutual_info else None - x_i = (current_accumulator.count * current_accumulator.weight) - # If x_i == n, the feature is a constant and thus has no information. - if round(x_i) == round(n): - return feature_value, (0, 0, x_i) - if round(x_i) > round(n): - raise ValueError( - 'Frequency of token {} higher than number of records {} > {}'.format( - feature_value, x_i, n) + - ' This likely means you have provided tft.vocabulary with input that' - ' has repeated tokens per row, rather than a set representation.') - for label_ix in range(len(global_label_counts)): - y_i = global_label_counts[label_ix] - if y_i == 0: - continue - local_mean = 0 - if label_ix < len(current_accumulator.mean): - local_mean = current_accumulator.mean[label_ix] - n_i = ( - _clip_probability(local_mean) * current_accumulator.weight * - current_accumulator.count) - diff_from_avg = (x_i * y_i / n) - n_i - if abs(diff_from_avg) < min_diff_from_avg: - continue - mutual_information += ( - info_theory.calculate_partial_mutual_information(n_i, x_i, y_i, n)) - if use_adjusted_mutual_info: - expected_mutual_information += ( - info_theory.calculate_partial_expected_mutual_information( - n, x_i, y_i)) + return np.clip(p, epsilon, 1 - epsilon) + + +def _calculate_mutual_information_for_feature_value( + feature_and_accumulator, + global_accumulator, + use_adjusted_mutual_info, + min_diff_from_avg, +): + """Calculates the (possibly adjusted) mutual information of a feature value. + + Used as a measure of relatedness between a single feature value and a label. + + Mutual information is calculated as: + H(x, y) = (sum(weights) * + [(P(y|x)*log2(P(y|x)/P(y))) + (P(~y|x)*log2(P(~y|x)/P(~y)))]) + where x is feature and y is label. We use sum(weights) instead of P(x), as + this makes the mutual information more interpretable. + If we don't divide by sum(weights), it can be thought of as an adjusted + weighted count. + + If use_adjusted_mutual_info is True, we use Adjusted Mutual Information (AMI) + which accounts for relatedness due to chance. AMI is generally calculated as: + AMI(x, y) = MI(x, y) - EMI(x, y) / (max(H(x), H(y)) - EMI(x, y)) + where x is the feature and y is label. Here, we leave off the normalization + and only subtract expected mutual information (EMI) from mutual information. + The calculation is based on the following paper: + + Vinh, N. X.; Epps, J.; Bailey, J. (2009). "Information theoretic measures for + clusterings comparison". Proceedings of the 26th Annual International Confere + nce on Machine Learning - ICML '09. p. 1. + doi:10.1145/1553374.1553511. ISBN 9781605585161. 
+ + Short summary can be found in the Wikipedia link: + https://en.wikipedia.org/wiki/Adjusted_mutual_information - if use_adjusted_mutual_info: - # TODO(b/127366670): Consider implementing the normalization step as per - # AMI(x, y) = MI(x, y) - EMI(x, y) / (max(H(x), H(y)) - EMI(x, y)) - return (feature_value, (mutual_information - expected_mutual_information, - expected_mutual_information, x_i)) - else: - return (feature_value, (mutual_information, float('NaN'), x_i)) + Args: + ---- + feature_and_accumulator: A tuple of the form: + (feature, WeightedMeanAndVarCombiner.accumulator_class) where: `feature` + is the single token in the vocabulary for which (possibly adjusted) + mutual information with the label is being computed. `mean` is the + weighted mean positive for each label value given x. `count` is the + count of weights for a feature. `weight` is the mean of the weights for + a feature. + global_accumulator: A WeightedMeanAndVarCombiner.accumulator_class where: + `mean` is the weighted mean positive for each label value for all + features. `count` is the count for all features. `weight` is the mean of + the weights for all features. + use_adjusted_mutual_info: If set to True, use adjusted mutual information. + min_diff_from_avg: A regularization parameter that pushes low MI/AMI towards + zero. The Mutual information of a feature x label pair will be adjusted to + zero whenever the absolute difference the weight and the expected + (average) weight is lower than min_diff_from_average. + + Returns: + ------- + A tuple of: + The feature value + The mutual information with the label. If use_adjusted_mutual_info, this + is the mutual information - the expected mutual information, otherwise + it is the raw mutual information. + The expected mutual information (EMI) if use_adjusted_mutual_info is + True, otherwise NaN. + The total weighted sum for the feature value. + """ + # Compute the frequency of each label value. + global_label_counts = ( + global_accumulator.mean * global_accumulator.weight * global_accumulator.count + ) + feature_value, current_accumulator = feature_and_accumulator + total_label_counts = sum(global_label_counts) + n = global_accumulator.count * global_accumulator.weight + # TODO(b/168469757): Consider raising here once b/168469757 is resolved. + if round(total_label_counts) != round(n): + logging.warn( + "Weighted label sum (%s) != total weighted count (%s), label means=%s", + total_label_counts, + n, + global_accumulator.mean, + ) + if n == 0: + return (feature_value, (float("NaN"), float("NaN"), 0)) + + mutual_information = 0 + expected_mutual_information = 0 if use_adjusted_mutual_info else None + x_i = current_accumulator.count * current_accumulator.weight + # If x_i == n, the feature is a constant and thus has no information. + if round(x_i) == round(n): + return feature_value, (0, 0, x_i) + if round(x_i) > round(n): + raise ValueError( + f"Frequency of token {feature_value} higher than number of records {x_i} > {n}" + + " This likely means you have provided tft.vocabulary with input that" + " has repeated tokens per row, rather than a set representation." 
+ ) + for label_ix in range(len(global_label_counts)): + y_i = global_label_counts[label_ix] + if y_i == 0: + continue + local_mean = 0 + if label_ix < len(current_accumulator.mean): + local_mean = current_accumulator.mean[label_ix] + n_i = ( + _clip_probability(local_mean) + * current_accumulator.weight + * current_accumulator.count + ) + diff_from_avg = (x_i * y_i / n) - n_i + if abs(diff_from_avg) < min_diff_from_avg: + continue + mutual_information += info_theory.calculate_partial_mutual_information( + n_i, x_i, y_i, n + ) + if use_adjusted_mutual_info: + expected_mutual_information += ( + info_theory.calculate_partial_expected_mutual_information(n, x_i, y_i) + ) + + if use_adjusted_mutual_info: + # TODO(b/127366670): Consider implementing the normalization step as per + # AMI(x, y) = MI(x, y) - EMI(x, y) / (max(H(x), H(y)) - EMI(x, y)) + return ( + feature_value, + ( + mutual_information - expected_mutual_information, + expected_mutual_information, + x_i, + ), + ) + else: + return (feature_value, (mutual_information, float("NaN"), x_i)) @ptransform_fn @beam.typehints.with_input_types( - KV[_VocabTokenType, analyzers.WeightedMeanAndVarCombiner.accumulator_class]) -@beam.typehints.with_output_types(KV[_VocabTokenType, - _VocabAccumulatedIndicatorType]) + KV[_VocabTokenType, analyzers.WeightedMeanAndVarCombiner.accumulator_class] +) +@beam.typehints.with_output_types(KV[_VocabTokenType, _VocabAccumulatedIndicatorType]) def _MutualInformationTransformAccumulate(pcol, compute_weighted=True): # pylint: disable=invalid-name - """Accumulates information needed for mutual information computation.""" - return (pcol | 'VocabCountPerLabelPerTokenAccumulate' >> beam.CombinePerKey( - _WeightedMeanCombineFn( - output_shape=(None,), compute_weighted=compute_weighted))) + """Accumulates information needed for mutual information computation.""" + return pcol | "VocabCountPerLabelPerTokenAccumulate" >> beam.CombinePerKey( + _WeightedMeanCombineFn(output_shape=(None,), compute_weighted=compute_weighted) + ) def _extract_sentinels(kv): - """Separate out label sentinel accumulators from vocab accumulators. + """Separate out label sentinel accumulators from vocab accumulators. - To keep track of the frequencies of label values, we store global label - frequencies associated with a special sentinel value. These are accumulated - just like other vocabulary tokens, but must be separated out before computing - mutual information. + To keep track of the frequencies of label values, we store global label + frequencies associated with a special sentinel value. These are accumulated + just like other vocabulary tokens, but must be separated out before computing + mutual information. - Args: - kv: tuple of key, accumulator + Args: + ---- + kv: tuple of key, accumulator - Yields: - A Beam TaggedOutout separating the sentinel and regular tokens. - """ - token, _ = kv - if (token == tf_utils.GLOBAL_Y_COUNT_SENTINEL_STRING or - token == tf_utils.GLOBAL_Y_COUNT_SENTINEL_INT): - # Throw away the sentinel token, since it's not needed. - yield beam.pvalue.TaggedOutput('global', kv[1]) - else: - yield beam.pvalue.TaggedOutput('feature', kv) + Yields: + ------ + A Beam TaggedOutout separating the sentinel and regular tokens. + """ + token, _ = kv + if ( + token == tf_utils.GLOBAL_Y_COUNT_SENTINEL_STRING + or token == tf_utils.GLOBAL_Y_COUNT_SENTINEL_INT + ): + # Throw away the sentinel token, since it's not needed. 
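For a binary label, the mutual information documented above multiplies the token's total weight by a two-term sum over the label outcomes. A direct transcription of that docstring formula, with the same probability clipping as `_clip_probability` (a sketch under the binary-label assumption; the real implementation works from accumulators and the `info_theory` helpers):

import numpy as np

def token_mutual_information(weight_x, p_y_given_x, p_y, epsilon=1e-6):
    """weight_x * [P(y|x)*log2(P(y|x)/P(y)) + P(~y|x)*log2(P(~y|x)/P(~y))]."""
    p_y_given_x = np.clip(p_y_given_x, epsilon, 1 - epsilon)
    p_y = np.clip(p_y, epsilon, 1 - epsilon)
    return weight_x * (
        p_y_given_x * np.log2(p_y_given_x / p_y)
        + (1 - p_y_given_x) * np.log2((1 - p_y_given_x) / (1 - p_y))
    )

# A token with total weight 10 whose positive rate is far from the global rate.
print(token_mutual_information(weight_x=10.0, p_y_given_x=0.9, p_y=0.5))  # ~5.31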
+ yield beam.pvalue.TaggedOutput("global", kv[1]) + else: + yield beam.pvalue.TaggedOutput("feature", kv) @ptransform_fn -@beam.typehints.with_input_types(KV[_VocabTokenType, - _VocabAccumulatedIndicatorType]) +@beam.typehints.with_input_types(KV[_VocabTokenType, _VocabAccumulatedIndicatorType]) @beam.typehints.with_output_types(KV[_VocabTokenType, Tuple[float, float]]) def _MutualInformationTransformMerge( # pylint: disable=invalid-name - pcol, use_adjusted_mutual_info, min_diff_from_avg, compute_weighted): - """Computes mutual information for each key using the given accumulators.""" - feature_accumulator_pcol = ( - pcol | 'VocabCountPerLabelPerTokenMerge' >> beam.CombinePerKey( - _WeightedMeanCombineFn( - output_shape=(None,), compute_weighted=compute_weighted))) - - accumulators_by_feature, global_accumulator = ( - feature_accumulator_pcol - | 'ExtractSentinels' >> beam.FlatMap(_extract_sentinels).with_outputs( - 'feature', 'global')) - if min_diff_from_avg is None: - min_diff_from_avg = ( - global_accumulator | 'AutoMinDiffFromAvg' >> - beam.Map(lambda acc: analyzers.calculate_recommended_min_diff_from_avg( # pylint: disable=g-long-lambda - acc.count * acc.weight))) - min_diff_from_avg = beam.pvalue.AsSingleton(min_diff_from_avg) - - def _extract_merged_values(term, results): - """Returns the key and tuple of (mutual information, frequency).""" - # Ignore the second value, which is the Expected Mutual Info. - (mi, _, frequency) = results - return term, (mi, frequency) - - return (accumulators_by_feature - | 'CalculateMutualInformationPerToken' >> beam.Map( - _calculate_mutual_information_for_feature_value, - beam.pvalue.AsSingleton(global_accumulator), - use_adjusted_mutual_info=use_adjusted_mutual_info, - min_diff_from_avg=min_diff_from_avg) - | beam.MapTuple(_extract_merged_values)) + pcol, use_adjusted_mutual_info, min_diff_from_avg, compute_weighted +): + """Computes mutual information for each key using the given accumulators.""" + feature_accumulator_pcol = ( + pcol + | "VocabCountPerLabelPerTokenMerge" + >> beam.CombinePerKey( + _WeightedMeanCombineFn( + output_shape=(None,), compute_weighted=compute_weighted + ) + ) + ) + + accumulators_by_feature, global_accumulator = ( + feature_accumulator_pcol + | "ExtractSentinels" + >> beam.FlatMap(_extract_sentinels).with_outputs("feature", "global") + ) + if min_diff_from_avg is None: + min_diff_from_avg = global_accumulator | "AutoMinDiffFromAvg" >> beam.Map( + lambda acc: analyzers.calculate_recommended_min_diff_from_avg( # pylint: disable=g-long-lambda + acc.count * acc.weight + ) + ) + min_diff_from_avg = beam.pvalue.AsSingleton(min_diff_from_avg) + def _extract_merged_values(term, results): + """Returns the key and tuple of (mutual information, frequency).""" + # Ignore the second value, which is the Expected Mutual Info. 
+ (mi, _, frequency) = results + return term, (mi, frequency) -class _WeightedMeanCombineFn(beam.CombineFn): - """_WeightedMeanCombineFn calculates total count and weighted means.""" + return ( + accumulators_by_feature + | "CalculateMutualInformationPerToken" + >> beam.Map( + _calculate_mutual_information_for_feature_value, + beam.pvalue.AsSingleton(global_accumulator), + use_adjusted_mutual_info=use_adjusted_mutual_info, + min_diff_from_avg=min_diff_from_avg, + ) + | beam.MapTuple(_extract_merged_values) + ) - def __init__(self, output_shape, compute_weighted=True): - self._combiner = analyzers.WeightedMeanAndVarCombiner( - np.float32, - output_shape=output_shape, - compute_variance=False, - compute_weighted=compute_weighted) - def create_accumulator(self): - """Create an accumulator with all zero entries.""" - return self._combiner.create_accumulator() +class _WeightedMeanCombineFn(beam.CombineFn): + """_WeightedMeanCombineFn calculates total count and weighted means.""" + + def __init__(self, output_shape, compute_weighted=True): + self._combiner = analyzers.WeightedMeanAndVarCombiner( + np.float32, + output_shape=output_shape, + compute_variance=False, + compute_weighted=compute_weighted, + ) - def add_input(self, accumulator, batch_values): - """Composes an accumulator from batch_values and calls merge_accumulators. + def create_accumulator(self): + """Create an accumulator with all zero entries.""" + return self._combiner.create_accumulator() + + def add_input(self, accumulator, batch_values): + """Composes an accumulator from batch_values and calls merge_accumulators. + + Args: + ---- + accumulator: The `WeightedMeanAndVarCombiner.accumulator_class` computed + so far. + batch_values: A `WeightedMeanAndVarCombiner.accumulator_class` for the + current batch. + + Returns: + ------- + A `WeightedMeanAndVarCombiner.accumulator_class` which is accumulator and + batch_values + combined. + """ + return self._combiner.add_input(accumulator, batch_values) + + def merge_accumulators(self, accumulators): + """Merges several `WeightedMeanAndVarCombiner.accumulator_class`s. + + Args: + ---- + accumulators: A list of `WeightedMeanAndVarCombiner.accumulator_class`s + and/or Nones. + + Returns: + ------- + The sole merged `WeightedMeanAndVarCombiner.accumulator_class`. + """ + return self._combiner.merge_accumulators(accumulators) + + def extract_output(self, accumulator): + """Returns the accumulator as the output. + + Args: + ---- + accumulator: the final `WeightedMeanAndVarCombiner.accumulator_class` + value. + + Returns: + ------- + The accumulator which could be None. + """ + return self._combiner.extract_output(accumulator) - Args: - accumulator: The `WeightedMeanAndVarCombiner.accumulator_class` computed - so far. - batch_values: A `WeightedMeanAndVarCombiner.accumulator_class` for the - current batch. - Returns: - A `WeightedMeanAndVarCombiner.accumulator_class` which is accumulator and - batch_values - combined. - """ - return self._combiner.add_input(accumulator, batch_values) +class _CombinerWrapper(beam.CombineFn): + """Class to wrap a analyzer_nodes.Combiner as a beam.CombineFn.""" + + def __init__(self, combiner, is_combining_accumulators): + """Init method for _CombinerWrapper. + + Args: + ---- + combiner: A `analyzer_nodes.Combiner` object used to combine. + is_combining_accumulators: A bool which indicates whether this is + combining single or batched inputs, or already accumulated objects. 
In + the former case, output of the CombineFn is an accumulator, whereas + in the latter case, output is extracted from the combined accumulators + using the combiner's extract_output. + """ + self._combiner = combiner + self._is_combining_accumulators = is_combining_accumulators + + def create_accumulator(self): + return self._combiner.create_accumulator() + + def add_input(self, accumulator, next_input): + if self._is_combining_accumulators: + # First accumulator can be None. + accumulators = [] + if accumulator is not None: + accumulators.append(accumulator) + if next_input is not None: + accumulators.append(next_input) + return self.merge_accumulators(accumulators) + return self._combiner.add_input(accumulator, next_input) + + def merge_accumulators(self, accumulators): + return self._combiner.merge_accumulators(accumulators) + + def compact(self, accumulator): + return self._combiner.compact(accumulator) + + def extract_output(self, accumulator): + if self._is_combining_accumulators: + return self._combiner.extract_output(accumulator) + return accumulator - def merge_accumulators(self, accumulators): - """Merges several `WeightedMeanAndVarCombiner.accumulator_class`s. - Args: - accumulators: A list of `WeightedMeanAndVarCombiner.accumulator_class`s - and/or Nones. +@beam.typehints.with_input_types(Union[Dict[str, Any], Tuple[str, Any]]) +@beam.typehints.with_output_types(Dict[str, Any]) +class _PackedCombinerWrapper(beam.combiners.TupleCombineFn): + """Class to wrap a analyzer_nodes.Combiner as a beam.CombineFn. - Returns: - The sole merged `WeightedMeanAndVarCombiner.accumulator_class`. + PackedCombineWrapper is used for combining input batches as well as + accumulators. When combining input batches, the input is a PCollection of + Dicts from feature keys to numpy arrays. When combining accumulators, the + input is a PCollection of tuples (key, accumulator), where the key represents + the individual combine label that is being packed. """ - return self._combiner.merge_accumulators(accumulators) - def extract_output(self, accumulator): - """Returns the accumulator as the output. + def __init__(self, combiner_ops, is_combining_accumulators): + """Init method for _PackedCombinerWrapper. + + Args: + ---- + combiner_ops: A List `analysis_graph_builder._CombinerOpWrapper` objects. + is_combining_accumulators: A bool which indicates whether this is + combining single or batched inputs, or already accumulated objects. + """ + super().__init__( + *[ + _CombinerWrapper(c.combiner, is_combining_accumulators) + for c in combiner_ops + ] + ) + self._is_combining_accumulators = is_combining_accumulators + if self._is_combining_accumulators: + # When combining accumulators, we expect to have only a single key which + # represents the label of the individual combine. 
+ for op in combiner_ops: + assert len(op.keys) == 1 + self._combiner_label_to_index = { + op.keys[0]: index for index, op in enumerate(combiner_ops) + } + else: + self._combiner_keys = [c.keys for c in combiner_ops] + self._combiner_labels = [c.label for c in combiner_ops] + + def add_input(self, accumulator, element): + if self._is_combining_accumulators: + key, value = element + index = self._combiner_label_to_index[key] + accumulator[index] = self._combiners[index].add_input( + accumulator[index], value + ) + return accumulator + else: + return super().add_input( + accumulator, + [tuple(element[key] for key in keys) for keys in self._combiner_keys], + ) - Args: - accumulator: the final `WeightedMeanAndVarCombiner.accumulator_class` - value. + def extract_output(self, accumulator): + outputs = super().extract_output(accumulator) + return { + combiner_label: output + for combiner_label, output in zip(self._combiner_labels, outputs) + } - Returns: - The accumulator which could be None. - """ - return self._combiner.extract_output(accumulator) +def _split_inputs_by_key(batch_values): + """Takes inputs where first input is a key, and returns (key, value) pairs. -class _CombinerWrapper(beam.CombineFn): - """Class to wrap a analyzer_nodes.Combiner as a beam.CombineFn.""" + Takes inputs of the form (key, arg0, ..., arg{N-1}) where `key` is a vector + and arg0, ..., arg{N-1} have dimension >1 with size in the first dimension + matching `key`. - def __init__(self, - combiner, - is_combining_accumulators): - """Init method for _CombinerWrapper. + It yields pairs of the form - Args: - combiner: A `analyzer_nodes.Combiner` object used to combine. - is_combining_accumulators: A bool which indicates whether this is - combining single or batched inputs, or already accumulated objects. In - the former case, output of the CombineFn is an accumulator, whereas - in the latter case, output is extracted from the combined accumulators - using the combiner's extract_output. - """ - self._combiner = combiner - self._is_combining_accumulators = is_combining_accumulators + (key[i], [arg0[i], ..., arg{N-1}[i]]) - def create_accumulator(self): - return self._combiner.create_accumulator() + for 0 < i < len(key). - def add_input(self, accumulator, next_input): - if self._is_combining_accumulators: - # First accumulator can be None. - accumulators = [] - if accumulator is not None: - accumulators.append(accumulator) - if next_input is not None: - accumulators.append(next_input) - return self.merge_accumulators(accumulators) - return self._combiner.add_input(accumulator, next_input) + Args: + ---- + batch_values: A list of ndarrays representing the input from a batch. - def merge_accumulators(self, accumulators): - return self._combiner.merge_accumulators(accumulators) + Yields: + ------ + (key, args) pairs where args is a list of ndarrays. - def compact(self, accumulator): - return self._combiner.compact(accumulator) + Raises: + ------ + ValueError: if inputs do not have correct sizes. + """ + # TODO(b/77873002): Raise these errors in the graph where more informative + # errors can be generated. Keep these as a fallback for user-defined + # `Combiner`s. 
+ keys = batch_values[0] + if keys.ndim != 1: + raise ValueError( + f"keys for CombinePerKey should have rank 1, got shape {keys.shape}" + ) + for arg_index, arg_values in enumerate(batch_values[1:]): + if arg_values.ndim < 1: + raise ValueError( + f"Argument {arg_index} for CombinePerKey should have rank >=1, " + f"got shape {arg_values.shape}" + ) + if arg_values.shape[0] != keys.shape[0]: + raise ValueError( + f"Argument {arg_index} had shape {arg_values.shape} whose first dimension was not equal to the " + f"size of the keys vector ({keys.shape[0]})" + ) - def extract_output(self, accumulator): - if self._is_combining_accumulators: - return self._combiner.extract_output(accumulator) - return accumulator + for instance_index, key in enumerate(keys): + instance_args = [arg_values[instance_index] for arg_values in batch_values[1:]] + yield (key, instance_args) -@beam.typehints.with_input_types(Union[Dict[str, Any], Tuple[str, Any]]) -@beam.typehints.with_output_types(Dict[str, Any]) -class _PackedCombinerWrapper(beam.combiners.TupleCombineFn): - """Class to wrap a analyzer_nodes.Combiner as a beam.CombineFn. +def _merge_outputs_by_key(keys_and_outputs, outputs_dtype): + """Merge outputs of analyzers per key into a single output. - PackedCombineWrapper is used for combining input batches as well as - accumulators. When combining input batches, the input is a PCollection of - Dicts from feature keys to numpy arrays. When combining accumulators, the - input is a PCollection of tuples (key, accumulator), where the key represents - the individual combine label that is being packed. - """ + Takes a list of elements of the form (key, [output0, ..., output{N-1}]) and + returns a list of ndarrays of the form [keys, outputs0, ..., outputs[{N-1}]] + where keys is formed by stacking the values of `key` from the list and + similarly outputs{k} is formed by stacking the individual elements of + output{k} from the list. - def __init__(self, - combiner_ops, - is_combining_accumulators): - """Init method for _PackedCombinerWrapper. + For each k, output{k} must be an ndarray whose size is the same for each + element of the list. Args: - combiner_ops: A List `analysis_graph_builder._CombinerOpWrapper` objects. - is_combining_accumulators: A bool which indicates whether this is - combining single or batched inputs, or already accumulated objects. + ---- + keys_and_outputs: A list of elements of the form + (key, [output0, ..., output{N-1}]) + outputs_dtype: A list of tf.DType. Each element corresponds to an output. + + Yields: + ------ + The `TaggedOutput`s: keys, outputs0, ..., outputs[{N-1}] + + Raises: + ------ + ValueError: If the number is outputs doesn't match num_outputs. """ - super().__init__(*[ - _CombinerWrapper(c.combiner, is_combining_accumulators) - for c in combiner_ops - ]) - self._is_combining_accumulators = is_combining_accumulators - if self._is_combining_accumulators: - # When combining accumulators, we expect to have only a single key which - # represents the label of the individual combine. 
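A concrete call to `_split_inputs_by_key` above, assuming it is invoked directly on a toy batch: the first array is the key vector, and every other array contributes its i-th slice to the i-th (key, args) pair.

import numpy as np

batch_values = [
    np.array([b"a", b"b", b"a"]),        # keys
    np.array([1.0, 2.0, 3.0]),           # arg0, first dimension matches keys
    np.array([[1, 0], [0, 1], [1, 1]]),  # arg1, first dimension matches keys
]
for key, args in _split_inputs_by_key(batch_values):
    print(key, *args)
# b'a' 1.0 [1 0]
# b'b' 2.0 [0 1]
# b'a' 3.0 [1 1]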
- for op in combiner_ops: - assert len(op.keys) == 1 - self._combiner_label_to_index = { - op.keys[0]: index for index, op in enumerate(combiner_ops)} - else: - self._combiner_keys = [c.keys for c in combiner_ops] - self._combiner_labels = [c.label for c in combiner_ops] - - def add_input(self, accumulator, element): - if self._is_combining_accumulators: - key, value = element - index = self._combiner_label_to_index[key] - accumulator[index] = self._combiners[index].add_input( - accumulator[index], value) - return accumulator + num_outputs = len(outputs_dtype) + + # Sort a copy of keys_and_outputs by keys. + sorted_keys_and_outputs = sorted(keys_and_outputs, key=lambda x: x[0]) + + # Convert from a list of pairs of the form (key, outputs_for_key) to a list of + # keys and a list of outputs (where the outer dimension is the number of + # outputs not the number of keys). + key = [] + outputs = [] + for k, o in sorted_keys_and_outputs: + key.append(k) + outputs.append(o) + if not outputs: + outputs = [[]] * num_outputs else: - return super().add_input( - accumulator, - [tuple(element[key] for key in keys) for keys in self._combiner_keys]) - - def extract_output(self, accumulator): - outputs = super().extract_output(accumulator) - return { - combiner_label: output - for combiner_label, output in zip(self._combiner_labels, outputs) - } - - -def _split_inputs_by_key(batch_values): - """Takes inputs where first input is a key, and returns (key, value) pairs. - - Takes inputs of the form (key, arg0, ..., arg{N-1}) where `key` is a vector - and arg0, ..., arg{N-1} have dimension >1 with size in the first dimension - matching `key`. - - It yields pairs of the form - - (key[i], [arg0[i], ..., arg{N-1}[i]]) - - for 0 < i < len(key). - - Args: - batch_values: A list of ndarrays representing the input from a batch. - - Yields: - (key, args) pairs where args is a list of ndarrays. - - Raises: - ValueError: if inputs do not have correct sizes. - """ - # TODO(b/77873002): Raise these errors in the graph where more informative - # errors can be generated. Keep these as a fallback for user-defined - # `Combiner`s. - keys = batch_values[0] - if keys.ndim != 1: - raise ValueError( - 'keys for CombinePerKey should have rank 1, got shape {}'.format( - keys.shape)) - for arg_index, arg_values in enumerate(batch_values[1:]): - if arg_values.ndim < 1: - raise ValueError( - 'Argument {} for CombinePerKey should have rank >=1, ' - 'got shape {}'.format(arg_index, arg_values.shape)) - if arg_values.shape[0] != keys.shape[0]: - raise ValueError( - 'Argument {} had shape {} whose first dimension was not equal to the ' - 'size of the keys vector ({})'.format( - arg_index, arg_values.shape, keys.shape[0])) - - for instance_index, key in enumerate(keys): - instance_args = [arg_values[instance_index] - for arg_values in batch_values[1:]] - yield (key, instance_args) - - -def _merge_outputs_by_key(keys_and_outputs, outputs_dtype): - """Merge outputs of analyzers per key into a single output. - - Takes a list of elements of the form (key, [output0, ..., output{N-1}]) and - returns a list of ndarrays of the form [keys, outputs0, ..., outputs[{N-1}]] - where keys is formed by stacking the values of `key` from the list and - similarly outputs{k} is formed by stacking the individual elements of - output{k} from the list. - - For each k, output{k} must be an ndarray whose size is the same for each - element of the list. 
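`_merge_outputs_by_key` is the counterpart that stacks the per-key outputs back into one ndarray per output and emits them as `TaggedOutput`s, with the keys tagged 'key' and the remaining outputs tagged '0', '1', and so on. A small sketch with the values used in the test added later in this patch (illustrative only):

import numpy as np
import tensorflow as tf

from tensorflow_transform.beam import analyzer_impls

keys_and_outputs = [
    ("my_key", [np.array(20), np.array([21, 22])]),
    ("my_other_key", [np.array(23), np.array([24, 25])]),
]
tagged = list(
    analyzer_impls._merge_outputs_by_key(keys_and_outputs, [tf.int64, tf.int64])
)
# tagged[0].value stacks the keys, tagged[1].value is [20, 23], and
# tagged[2].value is [[21, 22], [24, 25]].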
- - Args: - keys_and_outputs: A list of elements of the form - (key, [output0, ..., output{N-1}]) - outputs_dtype: A list of tf.DType. Each element corresponds to an output. - - Yields: - The `TaggedOutput`s: keys, outputs0, ..., outputs[{N-1}] - - Raises: - ValueError: If the number is outputs doesn't match num_outputs. - """ - num_outputs = len(outputs_dtype) - - # Sort a copy of keys_and_outputs by keys. - sorted_keys_and_outputs = sorted(keys_and_outputs, key=lambda x: x[0]) - - # Convert from a list of pairs of the form (key, outputs_for_key) to a list of - # keys and a list of outputs (where the outer dimension is the number of - # outputs not the number of keys). - key = [] - outputs = [] - for k, o in sorted_keys_and_outputs: - key.append(k) - outputs.append(o) - if not outputs: - outputs = [[]] * num_outputs - else: - outputs = list(zip(*outputs)) - yield beam.pvalue.TaggedOutput('key', - np.array(key, dtype=tf.string.as_numpy_dtype)) - if len(outputs) != num_outputs: - raise ValueError( - 'Analyzer has {} outputs but its implementation produced {} ' - 'values'.format(num_outputs, len(outputs))) - for i, (output, dtype) in enumerate(zip(outputs, outputs_dtype)): - yield beam.pvalue.TaggedOutput(str(i), np.array(output, - dtype=dtype.as_numpy_dtype)) + outputs = list(zip(*outputs)) + yield beam.pvalue.TaggedOutput("key", np.array(key, dtype=tf.string.as_numpy_dtype)) + if len(outputs) != num_outputs: + raise ValueError( + f"Analyzer has {num_outputs} outputs but its implementation produced {len(outputs)} " + "values" + ) + for i, (output, dtype) in enumerate(zip(outputs, outputs_dtype)): + yield beam.pvalue.TaggedOutput( + str(i), np.array(output, dtype=dtype.as_numpy_dtype) + ) def _make_strictly_increasing_boundaries_rows( - boundary_matrix: np.ndarray) -> np.ndarray: - """Converts a 2-d array of increasing rows to strictly increasing rows. + boundary_matrix: np.ndarray, +) -> np.ndarray: + """Converts a 2-d array of increasing rows to strictly increasing rows. - Args: - boundary_matrix: A 2-d np.array where each row is increasing. + Args: + ---- + boundary_matrix: A 2-d np.array where each row is increasing. - Returns: - A 2-d np.array of the same size as `boundary_matrix` where each row is - strictly increasing. - """ - epsilon = (1e-6 * - np.expand_dims(boundary_matrix[:, -1] - boundary_matrix[:, 0], 1)) + Returns: + ------- + A 2-d np.array of the same size as `boundary_matrix` where each row is + strictly increasing. + """ + epsilon = 1e-6 * np.expand_dims(boundary_matrix[:, -1] - boundary_matrix[:, 0], 1) - # Make sure every value in epsilon is positive. - epsilon[epsilon <= 0] = 1e-6 + # Make sure every value in epsilon is positive. + epsilon[epsilon <= 0] = 1e-6 - deltas = np.diff(boundary_matrix, axis=1) - corrected_deltas = np.maximum(deltas, epsilon) + deltas = np.diff(boundary_matrix, axis=1) + corrected_deltas = np.maximum(deltas, epsilon) - # Reconstruct the matrix with corrected deltas without the 1st column. - corrected_boundaries = ( - np.cumsum(corrected_deltas, axis=1) + - np.expand_dims(boundary_matrix[:, 0], 1)) + # Reconstruct the matrix with corrected deltas without the 1st column. + corrected_boundaries = np.cumsum(corrected_deltas, axis=1) + np.expand_dims( + boundary_matrix[:, 0], 1 + ) - # Reinsert the 1st column. - return np.insert(corrected_boundaries, 0, boundary_matrix[:, 0], axis=1) + # Reinsert the 1st column. 
+ return np.insert(corrected_boundaries, 0, boundary_matrix[:, 0], axis=1) def _join_boundary_rows( - boundary_matrix: np.ndarray + boundary_matrix: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Joins boundaries per key, by scaling and shifting them. + """Joins boundaries per key, by scaling and shifting them. - This returns a new list of boundaries which is composed from the given 2-d - array. For each row we compute a scale factor, and a shift value which are - used to compute the transformed boundaries, and should be used to transform - a value before its bucket is computed. + This returns a new list of boundaries which is composed from the given 2-d + array. For each row we compute a scale factor, and a shift value which are + used to compute the transformed boundaries, and should be used to transform + a value before its bucket is computed. - Neighboring key bucket boundaries have their adjacent boundaries merged into - one. + Neighboring key bucket boundaries have their adjacent boundaries merged into + one. - Args: - boundary_matrix: A 2-d np.array where each row is a list of boundaries for a - certain key. + Args: + ---- + boundary_matrix: A 2-d np.array where each row is a list of boundaries for a + certain key. - Returns: - A 4-tuple of (boundaries, scale, shift, num_buckets). - The returned boundaries is a 1-d np.array of size: - ((num_buckets - 2) * num_keys) + 1 - """ - boundary_matrix = _make_strictly_increasing_boundaries_rows(boundary_matrix) + Returns: + ------- + A 4-tuple of (boundaries, scale, shift, num_buckets). + The returned boundaries is a 1-d np.array of size: + ((num_buckets - 2) * num_keys) + 1 + """ + boundary_matrix = _make_strictly_increasing_boundaries_rows(boundary_matrix) - num_buckets = np.array(boundary_matrix.shape[1] + 1, dtype=np.int64) + num_buckets = np.array(boundary_matrix.shape[1] + 1, dtype=np.int64) - # Min boundary for each row. - min_boundary = np.min(boundary_matrix, axis=1) + # Min boundary for each row. + min_boundary = np.min(boundary_matrix, axis=1) - # Max boundary for each row. - max_boundary = np.max(boundary_matrix, axis=1) + # Max boundary for each row. + max_boundary = np.max(boundary_matrix, axis=1) - boundary_difference = max_boundary - min_boundary - scale = np.divide( - 1.0, - boundary_difference, - out=np.ones_like(boundary_difference), - where=boundary_difference != 0) + boundary_difference = max_boundary - min_boundary + scale = np.divide( + 1.0, + boundary_difference, + out=np.ones_like(boundary_difference), + where=boundary_difference != 0, + ) - # Shifts what would shift values so that when applied to min[key_id] we - # get: min[key_id] * scale[key_id] + shift[key_id] = key_id - # Therefore shift is defined as: - # shift[key_id] = key_id - min[key_id] * scale[key_id] - shift = np.arange(scale.size, dtype=np.float32) - min_boundary * scale + # Shifts what would shift values so that when applied to min[key_id] we + # get: min[key_id] * scale[key_id] + shift[key_id] = key_id + # Therefore shift is defined as: + # shift[key_id] = key_id - min[key_id] * scale[key_id] + shift = np.arange(scale.size, dtype=np.float32) - min_boundary * scale - scaled_buckets = ( - boundary_matrix[:, 1:] * np.expand_dims(scale, axis=1) + - np.expand_dims(shift, axis=1)) - boundaries = np.insert(scaled_buckets.flatten(), 0, 0.) 
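A worked instance of the scale and shift relationship described in the comments above, taken from the 'Simple' case of the test added later in this patch: two keys share the boundaries [0, 1, 2], so each row is squeezed into a unit-wide range and shifted onto its key index (illustrative only):

import numpy as np

from tensorflow_transform.beam import analyzer_impls

boundary_matrix = np.array([[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]])
boundaries, scale, shift, num_buckets = analyzer_impls._join_boundary_rows(
    boundary_matrix
)
# boundaries ~ [0, 0.5, 1, 1.5, 2], scale ~ [0.5, 0.5], shift ~ [0, 1],
# num_buckets == 4, and min[key_id] * scale[key_id] + shift[key_id] == key_id.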
+ scaled_buckets = boundary_matrix[:, 1:] * np.expand_dims( + scale, axis=1 + ) + np.expand_dims(shift, axis=1) + boundaries = np.insert(scaled_buckets.flatten(), 0, 0.0) - return boundaries, scale, shift, num_buckets + return boundaries, scale, shift, num_buckets -@common.register_ptransform( - analyzer_nodes.ScaleAndFlattenPerKeyBucketBouandaries) +@common.register_ptransform(analyzer_nodes.ScaleAndFlattenPerKeyBucketBouandaries) class _ScaleAndFlattenPerKeyBucketBouandariesImpl(beam.PTransform): - """Combines boundaries per-key to a single list of boundaries.""" - - _OUTPUT_TAGS = ('boundaries', 'scale_factor_per_key', 'shift_per_key', - 'num_buckets') - - def __init__(self, operation, extra_args): - self._dtype = operation.output_tensor_dtype - self._name = operation.label - - def _transform_boundaries(self, boundary_matrix): - results = _join_boundary_rows(boundary_matrix) - assert len(self._OUTPUT_TAGS) == len(results) - return [ - beam.pvalue.TaggedOutput(tag, value) - for tag, value in zip(self._OUTPUT_TAGS, results) - ] - - def expand(self, inputs): - pcoll, = inputs - output_dict = pcoll | beam.FlatMap( - self._transform_boundaries).with_outputs(*self._OUTPUT_TAGS) - return tuple(output_dict[key] for key in self._OUTPUT_TAGS) + """Combines boundaries per-key to a single list of boundaries.""" + + _OUTPUT_TAGS = ( + "boundaries", + "scale_factor_per_key", + "shift_per_key", + "num_buckets", + ) + + def __init__(self, operation, extra_args): + self._dtype = operation.output_tensor_dtype + self._name = operation.label + + def _transform_boundaries(self, boundary_matrix): + results = _join_boundary_rows(boundary_matrix) + assert len(self._OUTPUT_TAGS) == len(results) + return [ + beam.pvalue.TaggedOutput(tag, value) + for tag, value in zip(self._OUTPUT_TAGS, results) + ] + + def expand(self, inputs): + (pcoll,) = inputs + output_dict = pcoll | beam.FlatMap(self._transform_boundaries).with_outputs( + *self._OUTPUT_TAGS + ) + return tuple(output_dict[key] for key in self._OUTPUT_TAGS) @common.register_ptransform(analyzer_nodes.PackedCombineAccumulate) @beam.typehints.with_input_types(Dict[str, Any]) @beam.typehints.with_output_types(Dict[str, Any]) class _InitialAccumulatePackedCombineImpl(beam.PTransform): - """Implement an packed analyzer accumulate based on a Combine.""" - - def __init__(self, operation, extra_args): - self._combiners = operation.combiners - - def expand(self, inputs): - pcoll, = inputs - # We specify a fanout so that the packed combiner doesn't exhibit stragglers - # during the 'reduce' phase when we have a lot of combine analyzers packed. - fanout = int(math.ceil(math.sqrt(len(self._combiners)))) - fanout = max(_DEFAULT_COMBINE_GLOBALLY_FANOUT, fanout) - return ( - pcoll - | 'InitialPackedCombineGlobally' >> beam.CombineGlobally( - _PackedCombinerWrapper( - self._combiners, - is_combining_accumulators=False - ) - ).with_fanout(fanout) - | 'Count' >> - common.IncrementCounter('num_packed_accumulate_combiners')) + """Implement an packed analyzer accumulate based on a Combine.""" + + def __init__(self, operation, extra_args): + self._combiners = operation.combiners + + def expand(self, inputs): + (pcoll,) = inputs + # We specify a fanout so that the packed combiner doesn't exhibit stragglers + # during the 'reduce' phase when we have a lot of combine analyzers packed. 
+ fanout = int(math.ceil(math.sqrt(len(self._combiners)))) + fanout = max(_DEFAULT_COMBINE_GLOBALLY_FANOUT, fanout) + return ( + pcoll + | "InitialPackedCombineGlobally" + >> beam.CombineGlobally( + _PackedCombinerWrapper(self._combiners, is_combining_accumulators=False) + ).with_fanout(fanout) + | "Count" >> common.IncrementCounter("num_packed_accumulate_combiners") + ) @common.register_ptransform(analyzer_nodes.PackedCombineMerge) @beam.typehints.with_input_types(Tuple[str, Any]) @beam.typehints.with_output_types(Dict[str, Any]) class _MergeAccumulatorsPackedCombineImpl(beam.PTransform): - """Implement an packed analyzer merge based on a Combine.""" + """Implement an packed analyzer merge based on a Combine.""" - def __init__(self, operation, extra_args): - self._combiners = operation.combiners + def __init__(self, operation, extra_args): + self._combiners = operation.combiners - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - return ( - pcoll - | 'MergePackedCombinesGlobally' >> beam.CombineGlobally( - _PackedCombinerWrapper( - self._combiners, - is_combining_accumulators=True)) - | 'Count' >> - common.IncrementCounter('num_packed_merge_combiners')) + return ( + pcoll + | "MergePackedCombinesGlobally" + >> beam.CombineGlobally( + _PackedCombinerWrapper(self._combiners, is_combining_accumulators=True) + ) + | "Count" >> common.IncrementCounter("num_packed_merge_combiners") + ) # TODO(zoyahav): Share logic with _InitialAccumulatePackedCombineImpl. @common.register_ptransform(analyzer_nodes.CacheableCombineAccumulate) @beam.typehints.with_input_types(Tuple[np.ndarray, ...]) class _InitialAccumulateCombineImpl(beam.PTransform): - """Implement an analyzer based on a Combine.""" + """Implement an analyzer based on a Combine.""" - def __init__(self, operation, extra_args): - self._combiner = operation.combiner - self._num_outputs = operation.num_outputs - self._name = operation.label + def __init__(self, operation, extra_args): + self._combiner = operation.combiner + self._num_outputs = operation.num_outputs + self._name = operation.label - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - return (pcoll - | 'InitialCombineGlobally' >> beam.CombineGlobally( - _CombinerWrapper( - self._combiner, - is_combining_accumulators=False)).with_fanout( - _DEFAULT_COMBINE_GLOBALLY_FANOUT)) + return pcoll | "InitialCombineGlobally" >> beam.CombineGlobally( + _CombinerWrapper(self._combiner, is_combining_accumulators=False) + ).with_fanout(_DEFAULT_COMBINE_GLOBALLY_FANOUT) @common.register_ptransform(analyzer_nodes.CacheableCombineMerge) class _MergeAccumulatorsCombineImpl(beam.PTransform): - """Implement an analyzer based on a Combine.""" + """Implement an analyzer based on a Combine.""" - def __init__(self, operation, extra_args): - self._combiner = operation.combiner - self._name = operation.label + def __init__(self, operation, extra_args): + self._combiner = operation.combiner + self._name = operation.label - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - return ( - pcoll - | 'MergeCombinesGlobally' >> beam.CombineGlobally( - _CombinerWrapper( - self._combiner, - is_combining_accumulators=True))) + return pcoll | "MergeCombinesGlobally" >> beam.CombineGlobally( + _CombinerWrapper(self._combiner, is_combining_accumulators=True) + ) @common.register_ptransform(analyzer_nodes.CacheableCombinePerKeyAccumulate) class _InitialAccumulateCombinePerKeyImpl(beam.PTransform): - 
"""Implement an analyzer based on a CombinePerKey.""" - - def __init__(self, operation, extra_args): - self._combiner = operation.combiner - - def expand(self, inputs): - pcoll, = inputs - return (pcoll - | 'SplitByKey' >> beam.FlatMap(_split_inputs_by_key) - | 'CombinePerKey' >> beam.CombinePerKey( - _CombinerWrapper( - self._combiner, - is_combining_accumulators=False))) + """Implement an analyzer based on a CombinePerKey.""" + + def __init__(self, operation, extra_args): + self._combiner = operation.combiner + + def expand(self, inputs): + (pcoll,) = inputs + return ( + pcoll + | "SplitByKey" >> beam.FlatMap(_split_inputs_by_key) + | "CombinePerKey" + >> beam.CombinePerKey( + _CombinerWrapper(self._combiner, is_combining_accumulators=False) + ) + ) @common.register_ptransform(analyzer_nodes.CacheableCombinePerKeyMerge) class _MergeAccumulatorsCombinePerKeyImpl(beam.PTransform): - """Implement an analyzer based on a CombinePerKey.""" + """Implement an analyzer based on a CombinePerKey.""" - def __init__(self, operation, extra_args): - self._combiner = operation.combiner + def __init__(self, operation, extra_args): + self._combiner = operation.combiner - def expand(self, inputs): - pcoll, = inputs - return ( - pcoll - | 'MergeCombinePerKey' >> beam.CombinePerKey( - _CombinerWrapper( - self._combiner, - is_combining_accumulators=True))) + def expand(self, inputs): + (pcoll,) = inputs + return pcoll | "MergeCombinePerKey" >> beam.CombinePerKey( + _CombinerWrapper(self._combiner, is_combining_accumulators=True) + ) @common.register_ptransform(analyzer_nodes.CacheableCombinePerKeyFormatKeys) class _CombinePerKeyFormatKeysImpl(beam.PTransform): - """An analyzer that formats output for the non-stored per-key case.""" - - def __init__(self, operation, extra_args): - self._combiner = operation.combiner - - def expand(self, inputs): - pcoll, = inputs - output_keys = ( - ['key' - ] + [str(i) for i in range(len(self._combiner.output_tensor_infos()))]) - outputs_tuple = ( - pcoll - | 'ToList' >> beam.combiners.ToList() - | 'MergeByKey' >> beam.FlatMap(_merge_outputs_by_key, [ - info.dtype for info in self._combiner.output_tensor_infos() - ]).with_outputs(*output_keys)) - return tuple(outputs_tuple[key] for key in output_keys) + """An analyzer that formats output for the non-stored per-key case.""" + + def __init__(self, operation, extra_args): + self._combiner = operation.combiner + + def expand(self, inputs): + (pcoll,) = inputs + output_keys = ["key"] + [ + str(i) for i in range(len(self._combiner.output_tensor_infos())) + ] + outputs_tuple = ( + pcoll + | "ToList" >> beam.combiners.ToList() + | "MergeByKey" + >> beam.FlatMap( + _merge_outputs_by_key, + [info.dtype for info in self._combiner.output_tensor_infos()], + ).with_outputs(*output_keys) + ) + return tuple(outputs_tuple[key] for key in output_keys) @common.register_ptransform(analyzer_nodes.CacheableCombinePerKeyFormatLarge) class _CombinePerKeyFormatLargeImpl(beam.PTransform): - """An analyzer that formats output before writing to file for per-key case.""" + """An analyzer that formats output before writing to file for per-key case.""" - def __init__(self, operation, extra_args): - super().__init__() + def __init__(self, operation, extra_args): + super().__init__() - def expand(self, inputs): - to_str = tf.compat.as_str_any - pcoll, = inputs - return ( - pcoll - | 'EncodeValueAndSwapWithKey' >> beam.MapTuple( - lambda k, v: (to_str(','.join(map(to_str, v))), k))) + def expand(self, inputs): + to_str = tf.compat.as_str_any + (pcoll,) = 
inputs + return pcoll | "EncodeValueAndSwapWithKey" >> beam.MapTuple( + lambda k, v: (to_str(",".join(map(to_str, v))), k) + ) @common.register_ptransform(analyzer_nodes.PTransform) class _PTransformImpl(beam.PTransform): - """Implements a registered PTransform node by passing through the inputs.""" + """Implements a registered PTransform node by passing through the inputs.""" - def __init__(self, operation, extra_args): - self._ptransform = operation.ptransform - if isinstance(self._ptransform, experimental.PTransformAnalyzer): - self._ptransform.base_temp_dir = common.get_unique_temp_path( - extra_args.base_temp_dir) + def __init__(self, operation, extra_args): + self._ptransform = operation.ptransform + if isinstance(self._ptransform, experimental.PTransformAnalyzer): + self._ptransform.base_temp_dir = common.get_unique_temp_path( + extra_args.base_temp_dir + ) - def expand(self, inputs): - pcoll, = inputs - return pcoll | self._ptransform + def expand(self, inputs): + (pcoll,) = inputs + return pcoll | self._ptransform @common.register_ptransform(analyzer_nodes.EncodeCache) @beam.typehints.with_input_types(Any) @beam.typehints.with_output_types(bytes) class _EncodeCacheImpl(beam.PTransform): - """A PTransform that encodes cache entries.""" + """A PTransform that encodes cache entries.""" - def __init__(self, operation, extra_args): - self._coder = operation.coder + def __init__(self, operation, extra_args): + self._coder = operation.coder - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - return pcoll | 'Encode' >> beam.Map(self._coder.encode_cache) + return pcoll | "Encode" >> beam.Map(self._coder.encode_cache) @common.register_ptransform(analyzer_nodes.InstrumentDatasetCache) @beam.typehints.with_input_types(beam.pvalue.PBegin) @beam.typehints.with_output_types(None) class _InstrumentDatasetCacheImpl(beam.PTransform): - """Instruments pipeline analysis cache usage.""" - - def __init__(self, operation, extra_args): - self.pipeline = extra_args.pipeline - self._metadata_pcolls = tuple(extra_args.cache_pcoll_dict[k].metadata - for k in operation.input_cache_dataset_keys) - self._num_encode_cache = operation.num_encode_cache - self._num_decode_cache = operation.num_decode_cache - - def _make_and_increment_counter(self, value, name): - beam.metrics.Metrics.counter(common.METRICS_NAMESPACE, name).inc(value) - - def expand(self, pbegin): - if self._num_encode_cache > 0: - _ = ( - pbegin - | 'CreateSoleCacheEncodeInstrument' >> beam.Create( - [self._num_encode_cache]) - | 'InstrumentCacheEncode' >> beam.Map( - self._make_and_increment_counter, 'cache_entries_encoded')) - if self._num_decode_cache > 0: - _ = ( - self.pipeline - | 'CreateSoleCacheDecodeInstrument' >> beam.Create( - [self._num_decode_cache]) - | 'InstrumentCacheDecode' >> beam.Map( - self._make_and_increment_counter, 'cache_entries_decoded')) - if self._metadata_pcolls: - # Instruments datasets not read due to cache hit. 
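The instrumentation in this class (and `common.IncrementCounter` used above) relies on ordinary Beam metrics. A standalone sketch of how such a counter is bumped and later read back from the pipeline result; the namespace and counter name here are placeholders rather than the real values from `common`:

import apache_beam as beam
from apache_beam.metrics.metric import Metrics, MetricsFilter

pipeline = beam.Pipeline()
_ = (
    pipeline
    | beam.Create([7])  # a single element carrying the value to count
    | beam.Map(lambda n: Metrics.counter("tft_demo", "cache_entries_encoded").inc(n))
)
result = pipeline.run()
result.wait_until_finish()
query = result.metrics().query(MetricsFilter().with_name("cache_entries_encoded"))
# query["counters"][0].committed holds the total on runners that report metrics.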
- _ = ( - self._metadata_pcolls | beam.Flatten(pipeline=self.pipeline) - | 'ExtractCachedInputBytes' >> - beam.Map(lambda m: m.dataset_size if m else 0) - | 'SumCachedInputBytes' >> beam.CombineGlobally(sum) - | 'InstrumentCachedInputBytes' >> beam.Map( - self._make_and_increment_counter, - 'analysis_input_bytes_from_cache')) - return pbegin | 'CreateSoleEmptyOutput' >> beam.Create([]) + """Instruments pipeline analysis cache usage.""" + + def __init__(self, operation, extra_args): + self.pipeline = extra_args.pipeline + self._metadata_pcolls = tuple( + extra_args.cache_pcoll_dict[k].metadata + for k in operation.input_cache_dataset_keys + ) + self._num_encode_cache = operation.num_encode_cache + self._num_decode_cache = operation.num_decode_cache + + def _make_and_increment_counter(self, value, name): + beam.metrics.Metrics.counter(common.METRICS_NAMESPACE, name).inc(value) + + def expand(self, pbegin): + if self._num_encode_cache > 0: + _ = ( + pbegin + | "CreateSoleCacheEncodeInstrument" + >> beam.Create([self._num_encode_cache]) + | "InstrumentCacheEncode" + >> beam.Map(self._make_and_increment_counter, "cache_entries_encoded") + ) + if self._num_decode_cache > 0: + _ = ( + self.pipeline + | "CreateSoleCacheDecodeInstrument" + >> beam.Create([self._num_decode_cache]) + | "InstrumentCacheDecode" + >> beam.Map(self._make_and_increment_counter, "cache_entries_decoded") + ) + if self._metadata_pcolls: + # Instruments datasets not read due to cache hit. + _ = ( + self._metadata_pcolls + | beam.Flatten(pipeline=self.pipeline) + | "ExtractCachedInputBytes" + >> beam.Map(lambda m: m.dataset_size if m else 0) + | "SumCachedInputBytes" >> beam.CombineGlobally(sum) + | "InstrumentCachedInputBytes" + >> beam.Map( + self._make_and_increment_counter, "analysis_input_bytes_from_cache" + ) + ) + return pbegin | "CreateSoleEmptyOutput" >> beam.Create([]) @common.register_ptransform(analyzer_nodes.DecodeCache) @beam.typehints.with_input_types(beam.pvalue.PBegin) @beam.typehints.with_output_types(Any) class _DecodeCacheImpl(beam.PTransform): - """A PTransform method that extracts and decodes a cache object.""" + """A PTransform method that extracts and decodes a cache object.""" - def __init__(self, operation, extra_args): - self._cache_pcoll = ( - extra_args.cache_pcoll_dict[operation.dataset_key].get( - operation.cache_key)) - self._coder = operation.coder + def __init__(self, operation, extra_args): + self._cache_pcoll = extra_args.cache_pcoll_dict[operation.dataset_key].get( + operation.cache_key + ) + self._coder = operation.coder - def expand(self, pbegin): - del pbegin # unused + def expand(self, pbegin): + del pbegin # unused - return self._cache_pcoll | 'Decode' >> beam.Map(self._coder.decode_cache) + return self._cache_pcoll | "Decode" >> beam.Map(self._coder.decode_cache) @common.register_ptransform(analyzer_nodes.AddKey) @beam.typehints.with_input_types(Any) @beam.typehints.with_output_types(Tuple[str, Any]) class _AddKeyImpl(beam.PTransform): - """Implements AddKey.""" + """Implements AddKey.""" - def __init__(self, operation, extra_args): - del extra_args # unused - self._key = operation.key + def __init__(self, operation, extra_args): + del extra_args # unused + self._key = operation.key - def expand(self, inputs): - pcoll, = inputs - return pcoll | 'AddKey' >> beam.Map(lambda value: (self._key, value)) + def expand(self, inputs): + (pcoll,) = inputs + return pcoll | "AddKey" >> beam.Map(lambda value: (self._key, value)) -_FlattenListsItemType = TypeVariable('_FlattenListsItemType') 
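For the `_FlattenListsImpl` transform registered just below, the Beam typehints simply record that the element type is preserved, and the implementation is a `FlatMap` over the identity function. A minimal standalone sketch of that pattern (illustrative only):

import apache_beam as beam

with beam.Pipeline() as pipeline:
    flattened = (
        pipeline
        | beam.Create([[1, 2], [3], []])
        | "FlattenLists" >> beam.FlatMap(lambda x: x)  # emits 1, 2, 3
    )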
+_FlattenListsItemType = TypeVariable("_FlattenListsItemType") @common.register_ptransform(analyzer_nodes.FlattenLists) @beam.typehints.with_input_types(List[_FlattenListsItemType]) @beam.typehints.with_output_types(_FlattenListsItemType) class _FlattenListsImpl(beam.PTransform): - """PTransform to flatten a PCollection of lists.""" + """PTransform to flatten a PCollection of lists.""" - def __init__(self, operation, extra_args): - del operation, extra_args # unused + def __init__(self, operation, extra_args): + del operation, extra_args # unused - def expand(self, inputs): - pcoll, = inputs - return pcoll | 'FlattenLists' >> beam.FlatMap(lambda x: x) + def expand(self, inputs): + (pcoll,) = inputs + return pcoll | "FlattenLists" >> beam.FlatMap(lambda x: x) @common.register_ptransform(analyzer_nodes.ExtractCombineMergeOutputs) @common.register_ptransform(analyzer_nodes.ExtractPackedCombineMergeOutputs) class _ExtractOutputImpl(beam.PTransform): - """Implements ExtractOutputs.""" + """Implements ExtractOutputs.""" - def __init__(self, operation, extra_args): - del extra_args # unused - self._num_outputs = operation.num_outputs + def __init__(self, operation, extra_args): + del extra_args # unused + self._num_outputs = operation.num_outputs - def expand(self, inputs): - pcoll, = inputs - def extract_outputs(outputs, num_outputs): - if len(outputs) != num_outputs: - raise ValueError( - 'Analyzer has {} outputs but its implementation produced {} ' - 'values'.format(num_outputs, len(outputs))) - for i, output in enumerate(outputs): - yield beam.pvalue.TaggedOutput(str(i), output) + def expand(self, inputs): + (pcoll,) = inputs + + def extract_outputs(outputs, num_outputs): + if len(outputs) != num_outputs: + raise ValueError( + f"Analyzer has {num_outputs} outputs but its implementation produced {len(outputs)} " + "values" + ) + for i, output in enumerate(outputs): + yield beam.pvalue.TaggedOutput(str(i), output) - output_keys = [str(i) for i in range(self._num_outputs)] - outputs_tuple = ( - pcoll | - 'ExtractOutputs' >> beam.FlatMap( - extract_outputs, self._num_outputs).with_outputs(*output_keys)) - return tuple(outputs_tuple[key] for key in output_keys) + output_keys = [str(i) for i in range(self._num_outputs)] + outputs_tuple = pcoll | "ExtractOutputs" >> beam.FlatMap( + extract_outputs, self._num_outputs + ).with_outputs(*output_keys) + return tuple(outputs_tuple[key] for key in output_keys) @common.register_ptransform(analyzer_nodes.ExtractVocabularyReservedTokens) class _ExtractVocabularyReservedTokensImpl(beam.PTransform): - """Extracts vocabulary reserved tokens values from the working graph.""" - - def __init__( - self, - operation: analyzer_nodes.ExtractVocabularyReservedTokens, - extra_args: common.ConstructBeamPipelineVisitor.ExtraArgs, - ): - self._name = operation.name - self._graph = extra_args.graph - - def expand( - self, pbegin: beam.pvalue.PBegin - ) -> beam.pvalue.PCollection[typing.List[np.ndarray]]: - tokens = tf_utils.fetch_vocabulary_reserved_tokens(self._graph, self._name) - return pbegin | 'CreateReservedToekensSinglton' >> beam.Create([[tokens]]) + """Extracts vocabulary reserved tokens values from the working graph.""" + + def __init__( + self, + operation: analyzer_nodes.ExtractVocabularyReservedTokens, + extra_args: common.ConstructBeamPipelineVisitor.ExtraArgs, + ): + self._name = operation.name + self._graph = extra_args.graph + + def expand( + self, pbegin: beam.pvalue.PBegin + ) -> beam.pvalue.PCollection[typing.List[np.ndarray]]: + tokens = 
tf_utils.fetch_vocabulary_reserved_tokens(self._graph, self._name) + return pbegin | "CreateReservedToekensSinglton" >> beam.Create([[tokens]]) diff --git a/tensorflow_transform/beam/analyzer_impls_test.py b/tensorflow_transform/beam/analyzer_impls_test.py index 671977c..02d13de 100644 --- a/tensorflow_transform/beam/analyzer_impls_test.py +++ b/tensorflow_transform/beam/analyzer_impls_test.py @@ -14,127 +14,154 @@ """Tests for tensorflow_transform.beam.analyzer_impls.""" import apache_beam as beam - import numpy as np import tensorflow as tf -from tensorflow_transform.beam import analyzer_impls -from tensorflow_transform.beam import tft_unit +from tensorflow_transform.beam import analyzer_impls, tft_unit -class AnalyzerImplsTest(tft_unit.TransformTestCase): - def testSplitInputsByKey(self): - inputs = [ - np.array(['my_key', 'my_other_key']), - np.array([[1, 2], [3, 4]]), - np.array([5, 6]) - ] - split_inputs = list(analyzer_impls._split_inputs_by_key(inputs)) - self.assertEqual(len(split_inputs), 2) +class AnalyzerImplsTest(tft_unit.TransformTestCase): + def testSplitInputsByKey(self): + inputs = [ + np.array(["my_key", "my_other_key"]), + np.array([[1, 2], [3, 4]]), + np.array([5, 6]), + ] + split_inputs = list(analyzer_impls._split_inputs_by_key(inputs)) + self.assertEqual(len(split_inputs), 2) - self.assertEqual(len(split_inputs[0]), 2) - self.assertEqual(split_inputs[0][0], 'my_key') - self.assertEqual(len(split_inputs[0][1]), 2) - self.assertAllEqual(split_inputs[0][1][0], np.array([1, 2])) - self.assertAllEqual(split_inputs[0][1][1], np.array(5)) + self.assertEqual(len(split_inputs[0]), 2) + self.assertEqual(split_inputs[0][0], "my_key") + self.assertEqual(len(split_inputs[0][1]), 2) + self.assertAllEqual(split_inputs[0][1][0], np.array([1, 2])) + self.assertAllEqual(split_inputs[0][1][1], np.array(5)) - self.assertEqual(len(split_inputs[1]), 2) - self.assertEqual(split_inputs[1][0], 'my_other_key') - self.assertEqual(len(split_inputs[1][1]), 2) - self.assertAllEqual(split_inputs[1][1][0], np.array([3, 4])) - self.assertAllEqual(split_inputs[1][1][1], np.array(6)) + self.assertEqual(len(split_inputs[1]), 2) + self.assertEqual(split_inputs[1][0], "my_other_key") + self.assertEqual(len(split_inputs[1][1]), 2) + self.assertAllEqual(split_inputs[1][1][0], np.array([3, 4])) + self.assertAllEqual(split_inputs[1][1][1], np.array(6)) - def testMergeOutputsByKey(self): - outputs = [ - ('my_key', [np.array(20), np.array([21, 22])]), - ('my_other_key', [np.array(23), np.array([24, 25])]) - ] - outputs_pcoll = [outputs] - merged_outputs_pcolls = tuple(outputs_pcoll | beam.FlatMap( - analyzer_impls._merge_outputs_by_key, - outputs_dtype=[tf.int64, tf.int64]).with_outputs('key', '0', '1')) - self.assertAllEqual(merged_outputs_pcolls[0][0], - np.array(['my_key', 'my_other_key'])) - self.assertAllEqual(merged_outputs_pcolls[1][0], - np.array([20, 23])) - self.assertAllEqual(merged_outputs_pcolls[2][0], - np.array([[21, 22], [24, 25]])) + def testMergeOutputsByKey(self): + outputs = [ + ("my_key", [np.array(20), np.array([21, 22])]), + ("my_other_key", [np.array(23), np.array([24, 25])]), + ] + outputs_pcoll = [outputs] + merged_outputs_pcolls = tuple( + outputs_pcoll + | beam.FlatMap( + analyzer_impls._merge_outputs_by_key, outputs_dtype=[tf.int64, tf.int64] + ).with_outputs("key", "0", "1") + ) + self.assertAllEqual( + merged_outputs_pcolls[0][0], np.array(["my_key", "my_other_key"]) + ) + self.assertAllEqual(merged_outputs_pcolls[1][0], np.array([20, 23])) + 
self.assertAllEqual(merged_outputs_pcolls[2][0], np.array([[21, 22], [24, 25]])) - def testMergeOutputsByKeyEmptyInput(self): - outputs = [] - outputs_pcoll = [outputs] - merged_outputs_pcolls = tuple(outputs_pcoll | beam.FlatMap( - analyzer_impls._merge_outputs_by_key, - outputs_dtype=[tf.float32, tf.float32]).with_outputs('key', '0', '1')) - self.assertAllEqual(merged_outputs_pcolls[0][0], - np.array([])) - self.assertAllEqual(merged_outputs_pcolls[1][0], np.array([])) - self.assertAllEqual(merged_outputs_pcolls[2][0], np.array([])) + def testMergeOutputsByKeyEmptyInput(self): + outputs = [] + outputs_pcoll = [outputs] + merged_outputs_pcolls = tuple( + outputs_pcoll + | beam.FlatMap( + analyzer_impls._merge_outputs_by_key, + outputs_dtype=[tf.float32, tf.float32], + ).with_outputs("key", "0", "1") + ) + self.assertAllEqual(merged_outputs_pcolls[0][0], np.array([])) + self.assertAllEqual(merged_outputs_pcolls[1][0], np.array([])) + self.assertAllEqual(merged_outputs_pcolls[2][0], np.array([])) - @tft_unit.named_parameters( - dict( - testcase_name='Increasing', - input_boundaries=np.array([[1, 1.00000001], [1, 2]]), - expected_boundaries=np.array([[1, 1.00000001], [1, 2]])), - dict( - testcase_name='Repeating', - input_boundaries=np.array([[1, 1, 1], [4, 4, 4]]), - expected_boundaries=np.array([[1, 1.000001, 1.000002], - [4, 4.000001, 4.000002]])), - dict( - testcase_name='NonIncreasing', - input_boundaries=np.array([[3, 5.1, 5.1], [4.01, 4.01, 4.2]]), - expected_boundaries=np.array([[3, 5.1, 5.1000021], - [4.01, 4.01000019, 4.20000019]]), - atol=1e-6), - ) - def testMakeStrictlyIncreasingBoundariesRows(self, - input_boundaries, - expected_boundaries, - atol=None): - result = analyzer_impls._make_strictly_increasing_boundaries_rows( - input_boundaries) - if atol is None: - self.assertAllEqual(result, expected_boundaries) - else: - self.assertAllClose(result, expected_boundaries, atol=atol) + @tft_unit.named_parameters( + dict( + testcase_name="Increasing", + input_boundaries=np.array([[1, 1.00000001], [1, 2]]), + expected_boundaries=np.array([[1, 1.00000001], [1, 2]]), + ), + dict( + testcase_name="Repeating", + input_boundaries=np.array([[1, 1, 1], [4, 4, 4]]), + expected_boundaries=np.array( + [[1, 1.000001, 1.000002], [4, 4.000001, 4.000002]] + ), + ), + dict( + testcase_name="NonIncreasing", + input_boundaries=np.array([[3, 5.1, 5.1], [4.01, 4.01, 4.2]]), + expected_boundaries=np.array( + [[3, 5.1, 5.1000021], [4.01, 4.01000019, 4.20000019]] + ), + atol=1e-6, + ), + ) + def testMakeStrictlyIncreasingBoundariesRows( + self, input_boundaries, expected_boundaries, atol=None + ): + result = analyzer_impls._make_strictly_increasing_boundaries_rows( + input_boundaries + ) + if atol is None: + self.assertAllEqual(result, expected_boundaries) + else: + self.assertAllClose(result, expected_boundaries, atol=atol) - @tft_unit.named_parameters( - dict( - testcase_name='Simple', - input_boundaries=np.array([[0, 1, 2], [0, 1, 2]]), - expected_boundaries=np.array([0, 0.5, 1, 1.5, 2]), - expected_scales=np.array([0.5, 0.5]), - expected_shifts=np.array([0, 1]), - expected_num_buckets=np.array(4)), - dict( - testcase_name='Complex', - input_boundaries=np.array([[0, 1, 2, 3], [3, 3, 3, 3], [2, 4, 6, 8]]), - expected_boundaries=np.array([ - 0, 0.33333333, 0.66666667, 1, 1.33333333, 1.66666667, 2, - 2.33333333, 2.66666667, 3 - ]), - expected_scales=np.array([0.333333333, 333333.333, 0.166666667]), - expected_shifts=np.array([0, -999999, 1.66666667]), - expected_num_buckets=np.array(5)), - dict( - 
testcase_name='SingleBoundary', - input_boundaries=np.array([[1], [2]]), - expected_boundaries=np.array([0]), - expected_scales=np.array([1., 1.]), - expected_shifts=np.array([-1, -1]), - expected_num_buckets=np.array(2)), - ) - def testJoinBoundarieRows(self, input_boundaries, expected_boundaries, - expected_scales, expected_shifts, - expected_num_buckets): - boundaries, scales, shifts, num_buckets = ( - analyzer_impls._join_boundary_rows(input_boundaries)) - self.assertAllClose(boundaries, expected_boundaries) - self.assertAllClose(scales, expected_scales) - self.assertAllClose(shifts, expected_shifts) - self.assertAllEqual(num_buckets, expected_num_buckets) + @tft_unit.named_parameters( + dict( + testcase_name="Simple", + input_boundaries=np.array([[0, 1, 2], [0, 1, 2]]), + expected_boundaries=np.array([0, 0.5, 1, 1.5, 2]), + expected_scales=np.array([0.5, 0.5]), + expected_shifts=np.array([0, 1]), + expected_num_buckets=np.array(4), + ), + dict( + testcase_name="Complex", + input_boundaries=np.array([[0, 1, 2, 3], [3, 3, 3, 3], [2, 4, 6, 8]]), + expected_boundaries=np.array( + [ + 0, + 0.33333333, + 0.66666667, + 1, + 1.33333333, + 1.66666667, + 2, + 2.33333333, + 2.66666667, + 3, + ] + ), + expected_scales=np.array([0.333333333, 333333.333, 0.166666667]), + expected_shifts=np.array([0, -999999, 1.66666667]), + expected_num_buckets=np.array(5), + ), + dict( + testcase_name="SingleBoundary", + input_boundaries=np.array([[1], [2]]), + expected_boundaries=np.array([0]), + expected_scales=np.array([1.0, 1.0]), + expected_shifts=np.array([-1, -1]), + expected_num_buckets=np.array(2), + ), + ) + def testJoinBoundarieRows( + self, + input_boundaries, + expected_boundaries, + expected_scales, + expected_shifts, + expected_num_buckets, + ): + boundaries, scales, shifts, num_buckets = analyzer_impls._join_boundary_rows( + input_boundaries + ) + self.assertAllClose(boundaries, expected_boundaries) + self.assertAllClose(scales, expected_scales) + self.assertAllClose(shifts, expected_shifts) + self.assertAllEqual(num_buckets, expected_num_buckets) -if __name__ == '__main__': - tft_unit.main() +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/annotators_test.py b/tensorflow_transform/beam/annotators_test.py index 5b016d3..e7bb254 100644 --- a/tensorflow_transform/beam/annotators_test.py +++ b/tensorflow_transform/beam/annotators_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2023 Google Inc. All Rights Reserved. 
# @@ -16,40 +15,41 @@ """Tests for tft annotators.""" import tensorflow as tf -import tensorflow_transform as tft -from tensorflow_transform.beam import tft_unit from google.protobuf import text_format from tensorflow_metadata.proto.v0 import schema_pb2 +import tensorflow_transform as tft +from tensorflow_transform.beam import tft_unit _TF_VERSION_NAMED_PARAMETERS = [ - dict(testcase_name='CompatV1', use_tf_compat_v1=True), - dict(testcase_name='V2', use_tf_compat_v1=False), + dict(testcase_name="CompatV1", use_tf_compat_v1=True), + dict(testcase_name="V2", use_tf_compat_v1=False), ] class AnnotatorsTest(tft_unit.TransformTestCase): + @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) + def test_annotate_sparse_outputs(self, use_tf_compat_v1): + def preprocessing_fn(inputs): + outputs = inputs.copy() + x = tf.sparse.expand_dims(inputs["x"], -1) + outputs["x"] = x + tft.experimental.annotate_sparse_output_shape(x, tf.constant([1, 1])) + tft.experimental.annotate_sparse_output_shape(outputs["y"], [17]) + tft.experimental.annotate_true_sparse_output(outputs["z"]) + return outputs - @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) - def test_annotate_sparse_outputs(self, use_tf_compat_v1): - def preprocessing_fn(inputs): - outputs = inputs.copy() - x = tf.sparse.expand_dims(inputs['x'], -1) - outputs['x'] = x - tft.experimental.annotate_sparse_output_shape(x, tf.constant([1, 1])) - tft.experimental.annotate_sparse_output_shape(outputs['y'], [17]) - tft.experimental.annotate_true_sparse_output(outputs['z']) - return outputs - - input_data_dicts = [dict(x=[1], y=[2], z=[3], t=[4]) for x in range(10)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.VarLenFeature(tf.int64), - 'y': tf.io.VarLenFeature(tf.int64), - 'z': tf.io.VarLenFeature(tf.int64), - 't': tf.io.VarLenFeature(tf.int64), - }) - schema = text_format.Parse( - """ + input_data_dicts = [dict(x=[1], y=[2], z=[3], t=[4]) for x in range(10)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tf.int64), + "y": tf.io.VarLenFeature(tf.int64), + "z": tf.io.VarLenFeature(tf.int64), + "t": tf.io.VarLenFeature(tf.int64), + } + ) + schema = text_format.Parse( + """ feature { name: "t" type: INT @@ -128,35 +128,35 @@ def preprocessing_fn(inputs): } } """, - schema_pb2.Schema(), - ) - if not tft_unit.is_external_environment(): - schema.generate_legacy_feature_spec = False - self.assertAnalyzeAndTransformResults( - input_data_dicts, - input_metadata, - preprocessing_fn, - expected_metadata=tft.DatasetMetadata(schema), - force_tf_compat_v1=use_tf_compat_v1, - output_record_batches=True, - ) + schema_pb2.Schema(), + ) + if not tft_unit.is_external_environment(): + schema.generate_legacy_feature_spec = False + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + expected_metadata=tft.DatasetMetadata(schema), + force_tf_compat_v1=use_tf_compat_v1, + output_record_batches=True, + ) - @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) - def test_conflicting_sparse_outputs_annotations(self, use_tf_compat_v1): - def preprocessing_fn(inputs): - tft.experimental.annotate_sparse_output_shape(inputs['x'], [3]) - tft.experimental.annotate_sparse_output_shape(inputs['x'], [17]) - tft.experimental.annotate_true_sparse_output(inputs['x']) - return inputs + @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) + def test_conflicting_sparse_outputs_annotations(self, use_tf_compat_v1): + def preprocessing_fn(inputs): + 
tft.experimental.annotate_sparse_output_shape(inputs["x"], [3]) + tft.experimental.annotate_sparse_output_shape(inputs["x"], [17]) + tft.experimental.annotate_true_sparse_output(inputs["x"]) + return inputs - input_data_dicts = [dict(x=[1]) for x in range(10)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'x': tf.io.VarLenFeature(tf.int64), - } - ) - schema = text_format.Parse( - """ + input_data_dicts = [dict(x=[1]) for x in range(10)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tf.int64), + } + ) + schema = text_format.Parse( + """ feature { name: "x$sparse_indices_0" type: INT @@ -180,84 +180,88 @@ def preprocessing_fn(inputs): } } """, - schema_pb2.Schema(), - ) - if not tft_unit.is_external_environment(): - schema.generate_legacy_feature_spec = False - self.assertAnalyzeAndTransformResults( - input_data_dicts, - input_metadata, - preprocessing_fn, - expected_metadata=tft.DatasetMetadata(schema), - force_tf_compat_v1=use_tf_compat_v1, - output_record_batches=True, - ) + schema_pb2.Schema(), + ) + if not tft_unit.is_external_environment(): + schema.generate_legacy_feature_spec = False + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + expected_metadata=tft.DatasetMetadata(schema), + force_tf_compat_v1=use_tf_compat_v1, + output_record_batches=True, + ) - @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) - def test_invalid_sparse_outputs_annotations(self, use_tf_compat_v1): - def preprocessing_fn(inputs): - tft.experimental.annotate_sparse_output_shape(inputs['x'], [3, 42]) - return inputs + @tft_unit.named_parameters(*_TF_VERSION_NAMED_PARAMETERS) + def test_invalid_sparse_outputs_annotations(self, use_tf_compat_v1): + def preprocessing_fn(inputs): + tft.experimental.annotate_sparse_output_shape(inputs["x"], [3, 42]) + return inputs - input_data_dicts = [dict(x=[1]) for x in range(10)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'x': tf.io.VarLenFeature(tf.int64), - } - ) - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, - r'Annotated shape \[3, 42\] was expected to have rank 1', - ): - self.assertAnalyzeAndTransformResults( - input_data_dicts, - input_metadata, - preprocessing_fn, - force_tf_compat_v1=use_tf_compat_v1, - ) + input_data_dicts = [dict(x=[1]) for x in range(10)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tf.int64), + } + ) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, + r"Annotated shape \[3, 42\] was expected to have rank 1", + ): + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + force_tf_compat_v1=use_tf_compat_v1, + ) - @tft_unit.named_parameters( - dict( - testcase_name='sanity', - values=['hello', 'world', 'world'], - expected_size=2, - ), - dict( - testcase_name='single_token', - values=['hello', 'hello', 'hello'], - expected_size=1, - ), - dict( - testcase_name='empty', - values=['', '', ''], - expected_size=1, - ), - ) - def test_get_vocabulary_size_by_name(self, values, expected_size): - vocab_filename = 'vocab' + @tft_unit.named_parameters( + dict( + testcase_name="sanity", + values=["hello", "world", "world"], + expected_size=2, + ), + dict( + testcase_name="single_token", + values=["hello", "hello", "hello"], + expected_size=1, + ), + dict( + testcase_name="empty", + values=["", "", ""], + expected_size=1, + ), + ) + def 
test_get_vocabulary_size_by_name(self, values, expected_size): + vocab_filename = "vocab" - def preprocessing_fn(inputs): - tft.vocabulary(inputs['s'], vocab_filename=vocab_filename) - size = tf.zeros_like( - inputs['s'], dtype=tf.int64 - ) + tft.experimental.get_vocabulary_size_by_name(vocab_filename) - return {'size': size} + def preprocessing_fn(inputs): + tft.vocabulary(inputs["s"], vocab_filename=vocab_filename) + size = tf.zeros_like( + inputs["s"], dtype=tf.int64 + ) + tft.experimental.get_vocabulary_size_by_name(vocab_filename) + return {"size": size} - input_data_dicts = [dict(s=v) for v in values] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 's': tf.io.FixedLenFeature([], tf.string), - }) - expected_data = [{ - 'size': expected_size, - }] * len(values) - self.assertAnalyzeAndTransformResults( - input_data_dicts, - input_metadata, - preprocessing_fn, - force_tf_compat_v1=False, - expected_data=expected_data, - ) + input_data_dicts = [dict(s=v) for v in values] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "s": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_data = [ + { + "size": expected_size, + } + ] * len(values) + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + force_tf_compat_v1=False, + expected_data=expected_data, + ) -if __name__ == '__main__': - tft_unit.main() +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/beam_nodes.py b/tensorflow_transform/beam/beam_nodes.py index c3cc0ef..cc89026 100644 --- a/tensorflow_transform/beam/beam_nodes.py +++ b/tensorflow_transform/beam/beam_nodes.py @@ -41,154 +41,168 @@ """ import tensorflow as tf -from tensorflow_transform import nodes + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple +from tensorflow_transform import nodes + class CreateTensorBinding( tfx_namedtuple.namedtuple( - 'CreateTensorBinding', - ['tensor_name', 'dtype_enum', 'is_asset_filepath', 'label']), - nodes.OperationDef): - """An operation that represents creating a tensor binding from a value. - - This `OperationDef` represents a `beam.PTransform` that applies a ParDo - (where the input PCollection is assumed to contain a single element), which - combines the single element with the a tensor name and `is_asset_filepath` - to create a tensor binding. - - Attributes: - tensor_name: The name of the tensor that the given value should replace as a - constant tensor. - dtype_enum: The Dtype of the tensor as a TF `types_pb2.DataType`. - is_asset_filepath: If true, then the replaced value will be added to the - ASSET_FILEPATHS collection if exporting a TF1 Graph. - label: A unique label for this operation. - """ - __slots__ = () + "CreateTensorBinding", + ["tensor_name", "dtype_enum", "is_asset_filepath", "label"], + ), + nodes.OperationDef, +): + """An operation that represents creating a tensor binding from a value. + + This `OperationDef` represents a `beam.PTransform` that applies a ParDo + (where the input PCollection is assumed to contain a single element), which + combines the single element with the a tensor name and `is_asset_filepath` + to create a tensor binding. + + Attributes + ---------- + tensor_name: The name of the tensor that the given value should replace as a + constant tensor. + dtype_enum: The Dtype of the tensor as a TF `types_pb2.DataType`. 
+ is_asset_filepath: If true, then the replaced value will be added to the + ASSET_FILEPATHS collection if exporting a TF1 Graph. + label: A unique label for this operation. + """ + + __slots__ = () class CreateSavedModel( tfx_namedtuple.namedtuple( - 'CreateSavedModel', - ['table_initializers', 'output_signature', 'label']), - nodes.OperationDef): - """An operation that represents creating a SavedModel with bound values. - - This operation represents creating a SavedModel. Its output is a - PCollection containing a single element which is the directory containing the - `SavedModel`. The inputs are a PCollection of tensor bindings. A tensor - binding is the specification of a tensor and a value that it should be - replaced with in the graph. - - This allows us to create a `SavedModel` in a deferred manner, which depends on - deferred values (the tensor bindings) which were not known when the Beam graph - was constructed. - - - Attributes: - table_initializers: A list of table initializer ops that should be run as - part of this SavedModel. - output_signature: The output signature of this `SavedModel`, as a dictionary - whose keys are feature names and values are `Tensor`s or - `SparseTensor`s. - label: A unique label for this operation. - """ - __slots__ = () - - def _get_tensor_type_name(self, tensor): - if isinstance(tensor, tf.Tensor): - return 'Tensor' - elif isinstance(tensor, tf.SparseTensor): - return 'SparseTensor' - elif isinstance(tensor, tf.RaggedTensor): - return 'RaggedTensor' - raise ValueError('Got a {}, expected a Tensor or SparseTensor'.format( - type(tensor))) - - def get_field_str(self, field_name): - # Overriding the str representation of table initializers since it may be - # different for various versions of TF. - if field_name == 'table_initializers': - return '{}'.format(len(self.table_initializers)) - elif field_name == 'output_signature': - copied = self.output_signature.copy() - for key in copied: - value = self.output_signature[key] - copied[key] = '{}'.format( - self._get_tensor_type_name(value), value.shape.as_list(), - value.dtype) - return str(copied) - return super().get_field_str(field_name) + "CreateSavedModel", ["table_initializers", "output_signature", "label"] + ), + nodes.OperationDef, +): + """An operation that represents creating a SavedModel with bound values. + + This operation represents creating a SavedModel. Its output is a + PCollection containing a single element which is the directory containing the + `SavedModel`. The inputs are a PCollection of tensor bindings. A tensor + binding is the specification of a tensor and a value that it should be + replaced with in the graph. + + This allows us to create a `SavedModel` in a deferred manner, which depends on + deferred values (the tensor bindings) which were not known when the Beam graph + was constructed. + + + Attributes + ---------- + table_initializers: A list of table initializer ops that should be run as + part of this SavedModel. + output_signature: The output signature of this `SavedModel`, as a dictionary + whose keys are feature names and values are `Tensor`s or + `SparseTensor`s. + label: A unique label for this operation. 
+ """ + + __slots__ = () + + def _get_tensor_type_name(self, tensor): + if isinstance(tensor, tf.Tensor): + return "Tensor" + elif isinstance(tensor, tf.SparseTensor): + return "SparseTensor" + elif isinstance(tensor, tf.RaggedTensor): + return "RaggedTensor" + raise ValueError(f"Got a {type(tensor)}, expected a Tensor or SparseTensor") + + def get_field_str(self, field_name): + # Overriding the str representation of table initializers since it may be + # different for various versions of TF. + if field_name == "table_initializers": + return f"{len(self.table_initializers)}" + elif field_name == "output_signature": + copied = self.output_signature.copy() + for key in copied: + value = self.output_signature[key] + copied[key] = ( + f"{self._get_tensor_type_name(value)}" + ) + return str(copied) + return super().get_field_str(field_name) class ExtractInputForSavedModel( - tfx_namedtuple.namedtuple('ExtractInputForSavedModel', - ['dataset_key', 'label']), nodes.OperationDef): - """An operation that forwards the requested dataset in PCollection form. + tfx_namedtuple.namedtuple("ExtractInputForSavedModel", ["dataset_key", "label"]), + nodes.OperationDef, +): + """An operation that forwards the requested dataset in PCollection form. - The resulting PCollection is either the dataset corresponding to - `dataset_key`, or a flattened PCollection if `dataset_key` is not specified. + The resulting PCollection is either the dataset corresponding to + `dataset_key`, or a flattened PCollection if `dataset_key` is not specified. - Attributes: - dataset_key: (Optional) dataset key str. - label: A unique label for this operation. - """ - __slots__ = () + Attributes + ---------- + dataset_key: (Optional) dataset key str. + label: A unique label for this operation. + """ + + __slots__ = () class ApplySavedModel( - tfx_namedtuple.namedtuple('ApplySavedModel', ['phase', 'label']), - nodes.OperationDef): - """An operation that represents applying a SavedModel as a `beam.ParDo`. + tfx_namedtuple.namedtuple("ApplySavedModel", ["phase", "label"]), nodes.OperationDef +): + """An operation that represents applying a SavedModel as a `beam.ParDo`. + + This operation represents applying a `SavedModel`, which is the input to this + operation, to the input values. The inputs values are not an input to this + operation, but are provided to the implementation by + `tensorflow_transform.beam.common.ConstructBeamPipelineVisitor.ExtraArgs`. - This operation represents applying a `SavedModel`, which is the input to this - operation, to the input values. The inputs values are not an input to this - operation, but are provided to the implementation by - `tensorflow_transform.beam.common.ConstructBeamPipelineVisitor.ExtraArgs`. + The input should be a PCollection containing a single element which is the + directory containing the SavedModel to be run. - The input should be a PCollection containing a single element which is the - directory containing the SavedModel to be run. + Attributes + ---------- + phase: An integer which is the phase that this operation is run as part of. + label: A unique label for this operation. + """ - Attributes: - phase: An integer which is the phase that this operation is run as part of. - label: A unique label for this operation. 
- """ - __slots__ = () + __slots__ = () - @property - def is_partitionable(self): - return True + @property + def is_partitionable(self): + return True class ExtractFromDict( - tfx_namedtuple.namedtuple('ExtractFromDict', ['keys', 'label']), - nodes.OperationDef): - """An operation that represents extracting values from a dictionary. - - This operation represents a `beam.ParDo` that is applied to a PCollection - whose elements are assumed to be a dictionary of values. For each element of - the PCollection, this corresponding element of the output PCollection is a - tuple of values, one for each key. - - Attributes: - keys: The keys whose values should be extracted from each element of the - input PCollection. keys should either be a tuple or a string. - label: A unique label for this operation. - """ - __slots__ = () - - @property - def is_partitionable(self): - return True - - -class Flatten( - tfx_namedtuple.namedtuple('Flatten', ['label']), nodes.OperationDef): - __slots__ = () - - @property - def is_partitionable(self): - return True + tfx_namedtuple.namedtuple("ExtractFromDict", ["keys", "label"]), nodes.OperationDef +): + """An operation that represents extracting values from a dictionary. + + This operation represents a `beam.ParDo` that is applied to a PCollection + whose elements are assumed to be a dictionary of values. For each element of + the PCollection, this corresponding element of the output PCollection is a + tuple of values, one for each key. + + Attributes + ---------- + keys: The keys whose values should be extracted from each element of the + input PCollection. keys should either be a tuple or a string. + label: A unique label for this operation. + """ + + __slots__ = () + + @property + def is_partitionable(self): + return True + + +class Flatten(tfx_namedtuple.namedtuple("Flatten", ["label"]), nodes.OperationDef): + __slots__ = () + + @property + def is_partitionable(self): + return True diff --git a/tensorflow_transform/beam/bucketize_integration_test.py b/tensorflow_transform/beam/bucketize_integration_test.py index 03493f0..f30f7cd 100644 --- a/tensorflow_transform/beam/bucketize_integration_test.py +++ b/tensorflow_transform/beam/bucketize_integration_test.py @@ -17,878 +17,1004 @@ import random import numpy as np - import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2 + import tensorflow_transform as tft from tensorflow_transform import analyzers from tensorflow_transform.beam import impl as beam_impl from tensorflow_transform.beam import tft_unit -from tensorflow_metadata.proto.v0 import schema_pb2 # pylint: disable=g-complex-comprehension def _construct_test_bucketization_tight_sequence_parameters(): - # (test_inputs, expected_boundaries, dtype, num_buckets, num_expected_buckets) - args = ( - ([1, 2, 3, 4], np.array([[3]], np.float32), tf.int32, 2, 2), - ([1, 2, 3, 4], np.array([[2, 3]], np.float32), tf.int32, 3, 3), - ([1, 2, 3, 4], np.array([[2, 3, 4]], np.float32), tf.int32, 4, 4), - ([1, 2, 3, 4], np.array([[1, 2, 3, 4]], np.float32), tf.int32, 5, 5), - ([1, 2, 3, 4], np.array([[1, 2, 3, 3, 4]], np.float32), tf.int32, 6, 6), - ([1, 2, 3, 4], np.array([[1, 1, 2, 2, 3, 3, 3, 4, 4]], - np.float32), tf.int32, 10, 10), - ) - return args + # (test_inputs, expected_boundaries, dtype, num_buckets, num_expected_buckets) + args = ( + ([1, 2, 3, 4], np.array([[3]], np.float32), tf.int32, 2, 2), + ([1, 2, 3, 4], np.array([[2, 3]], np.float32), tf.int32, 3, 3), + ([1, 2, 3, 4], np.array([[2, 3, 4]], np.float32), tf.int32, 4, 4), + ([1, 2, 3, 4], 
np.array([[1, 2, 3, 4]], np.float32), tf.int32, 5, 5), + ([1, 2, 3, 4], np.array([[1, 2, 3, 3, 4]], np.float32), tf.int32, 6, 6), + ( + [1, 2, 3, 4], + np.array([[1, 1, 2, 2, 3, 3, 3, 4, 4]], np.float32), + tf.int32, + 10, + 10, + ), + ) + return args def _construct_test_bucketization_parameters(): - args_without_dtype = ( - (range(1, 10), [4, 7], False, None, False, False), - (range(1, 100), [25, 50, 75], False, None, False, False), - - # The following is similar to range(1, 100) test above, except that - # only odd numbers are in the input; so boundaries differ (26 -> 27 and - # 76 -> 77). - (range(1, 100, 2), [24, 50, 75], False, None, False, False), - - # Test some inversely sorted inputs, and with different strides, and - # boundaries/buckets. - (range(9, 0, -1), [4, 7], False, None, False, False), - (range(19, 0, -1), [10], False, None, False, False), - (range(99, 0, -1), [50], False, None, False, False), - (range(99, 0, -1), [34, 67], False, None, False, False), - (range(99, 0, -2), [33, 67], False, None, False, False), - (range(99, 0, -1), range(10, 100, 10), False, None, False, False), - - # These tests do a random shuffle of the inputs, which must not affect the - # boundaries (or the computed buckets). - (range(99, 0, -1), range(10, 100, 10), True, None, False, False), - (range(1, 100), range(10, 100, 10), True, None, False, False), - - # The following test is with multiple batches (3 batches with default - # batch of 1000). - (range(1, 3000), [1500], False, None, False, False), - (range(1, 3000), [1000, 2000], False, None, False, False), - - # Test with specific error for bucket boundaries. This is same as the test - # above with 3 batches and a single boundary, but with a stricter error - # tolerance (0.001) than the default error (0.01). - (range(1, 3000), [1500], False, 0.001, False, False), - - # Tests for tft.apply_buckets. - (range(1, 100), [25, 50, 75], False, 0.00001, True, False), - (range(1, 100), [25, 50, 75], False, 0.00001, True, True), - ) - dtypes = (tf.int32, tf.int64, tf.float32, tf.float64, tf.double) - - args_with_dtype = [ - # Tests for handling np.nan input values. - (list(range(1, 100)) + [np.nan] * 10, [25, 50, 75], False, 0.01, - True, True, tf.float32), - (list(range(1, 100)) + [np.nan] * 10, [25, 50, 75], False, 0.01, - False, True, tf.float32), - (list(range(1, 100)) + [np.nan] * 10, [25, 50, 75], False, 0.01, - False, False, tf.float32), - ] - return ([x + (dtype,) for x in args_without_dtype for dtype in dtypes] + - args_with_dtype) + args_without_dtype = ( + (range(1, 10), [4, 7], False, None, False, False), + (range(1, 100), [25, 50, 75], False, None, False, False), + # The following is similar to range(1, 100) test above, except that + # only odd numbers are in the input; so boundaries differ (26 -> 27 and + # 76 -> 77). + (range(1, 100, 2), [24, 50, 75], False, None, False, False), + # Test some inversely sorted inputs, and with different strides, and + # boundaries/buckets. + (range(9, 0, -1), [4, 7], False, None, False, False), + (range(19, 0, -1), [10], False, None, False, False), + (range(99, 0, -1), [50], False, None, False, False), + (range(99, 0, -1), [34, 67], False, None, False, False), + (range(99, 0, -2), [33, 67], False, None, False, False), + (range(99, 0, -1), range(10, 100, 10), False, None, False, False), + # These tests do a random shuffle of the inputs, which must not affect the + # boundaries (or the computed buckets). 
+ (range(99, 0, -1), range(10, 100, 10), True, None, False, False), + (range(1, 100), range(10, 100, 10), True, None, False, False), + # The following test is with multiple batches (3 batches with default + # batch of 1000). + (range(1, 3000), [1500], False, None, False, False), + (range(1, 3000), [1000, 2000], False, None, False, False), + # Test with specific error for bucket boundaries. This is same as the test + # above with 3 batches and a single boundary, but with a stricter error + # tolerance (0.001) than the default error (0.01). + (range(1, 3000), [1500], False, 0.001, False, False), + # Tests for tft.apply_buckets. + (range(1, 100), [25, 50, 75], False, 0.00001, True, False), + (range(1, 100), [25, 50, 75], False, 0.00001, True, True), + ) + dtypes = (tf.int32, tf.int64, tf.float32, tf.float64, tf.double) + + args_with_dtype = [ + # Tests for handling np.nan input values. + ( + list(range(1, 100)) + [np.nan] * 10, + [25, 50, 75], + False, + 0.01, + True, + True, + tf.float32, + ), + ( + list(range(1, 100)) + [np.nan] * 10, + [25, 50, 75], + False, + 0.01, + False, + True, + tf.float32, + ), + ( + list(range(1, 100)) + [np.nan] * 10, + [25, 50, 75], + False, + 0.01, + False, + False, + tf.float32, + ), + ] + return [ + x + (dtype,) for x in args_without_dtype for dtype in dtypes + ] + args_with_dtype + # Per-key buckets for: # input_data = [1, 2, ..., 100], # key = ['a' if val <50 else 'b'], -_SIMPLE_PER_KEY_BUCKETS = {'a': [17, 33], 'b': [66, 83]} +_SIMPLE_PER_KEY_BUCKETS = {"a": [17, 33], "b": [66, 83]} # Same as above, but with weights = [0 if val in range(25, 75) else 1] _WEIGHTED_PER_KEY_0_RANGE = range(25, 75) -_WEIGHTED_PER_KEY_BUCKETS = {'a': [9, 17], 'b': [83, 91]} +_WEIGHTED_PER_KEY_BUCKETS = {"a": [9, 17], "b": [83, 91]} def _compute_simple_per_key_bucket(val, key, weighted=False): - if weighted: - return np.digitize(val, _WEIGHTED_PER_KEY_BUCKETS[key]) - else: - return np.digitize(val, _SIMPLE_PER_KEY_BUCKETS[key]) + if weighted: + return np.digitize(val, _WEIGHTED_PER_KEY_BUCKETS[key]) + else: + return np.digitize(val, _SIMPLE_PER_KEY_BUCKETS[key]) _BUCKETIZE_COMPOSITE_INPUT_TEST_CASES = [ dict( - testcase_name='sparse', - input_data=[{ - 'val': [x], - 'idx0': [x % 4], - 'idx1': [x % 5] - } for x in range(1, 10)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature(['idx0', 'idx1'], 'val', tf.float32, - [4, 5]), - }), - expected_data=[{ - 'x_bucketized$sparse_values': [(x - 1) // 3], - 'x_bucketized$sparse_indices_0': [x % 4], - 'x_bucketized$sparse_indices_1': [x % 5] - } for x in range(1, 10)]), + testcase_name="sparse", + input_data=[ + {"val": [x], "idx0": [x % 4], "idx1": [x % 5]} for x in range(1, 10) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature(["idx0", "idx1"], "val", tf.float32, [4, 5]), + } + ), + expected_data=[ + { + "x_bucketized$sparse_values": [(x - 1) // 3], + "x_bucketized$sparse_indices_0": [x % 4], + "x_bucketized$sparse_indices_1": [x % 5], + } + for x in range(1, 10) + ], + ), dict( - testcase_name='ragged', - input_data=[{ - 'val': [x, 10 - x], - 'row_lengths': [0, x % 3, 2 - x % 3], - } for x in range(1, 10)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.RaggedFeature( + testcase_name="ragged", + input_data=[ + { + "val": [x, 10 - x], + "row_lengths": [0, x % 3, 2 - x % 3], + } + for x in range(1, 10) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.RaggedFeature( tf.int64, - value_key='val', + 
value_key="val", partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error - ]), - }), - expected_data=[{ - 'x_bucketized$ragged_values': [(x - 1) // 3, (9 - x) // 3], - 'x_bucketized$row_lengths_1': [0, x % 3, 2 - x % 3], - } for x in range(1, 10)]), + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ) # pytype: disable=attribute-error + ], + ), + } + ), + expected_data=[ + { + "x_bucketized$ragged_values": [(x - 1) // 3, (9 - x) // 3], + "x_bucketized$row_lengths_1": [0, x % 3, 2 - x % 3], + } + for x in range(1, 10) + ], + ), ] _BUCKETIZE_PER_KEY_TEST_CASES = [ dict( - testcase_name='dense', - input_data=[{ - 'x': x, - 'key': 'a' if x < 50 else 'b' - } for x in range(1, 100)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }), - expected_data=[{ - 'x_bucketized': - _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b') - } for x in range(1, 100)], + testcase_name="dense", + input_data=[{"x": x, "key": "a" if x < 50 else "b"} for x in range(1, 100)], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + expected_data=[ + {"x_bucketized": _compute_simple_per_key_bucket(x, "a" if x < 50 else "b")} + for x in range(1, 100) + ], expected_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x_bucketized': tf.io.FixedLenFeature([], tf.int64), - }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), - })), + "x_bucketized": tf.io.FixedLenFeature([], tf.int64), + }, + { + "x_bucketized": schema_pb2.IntDomain(min=0, max=2, is_categorical=True), + }, + ), + ), dict( - testcase_name='sparse', - input_data=[{ - 'x': [x], - 'idx0': [0], - 'idx1': [0], - 'key': ['a'] if x < 50 else ['b'] - } for x in range(1, 100)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.SparseFeature(['idx0', 'idx1'], 'x', tf.float32, (2, 2)), - 'key': tf.io.VarLenFeature(tf.string) - }), - expected_data=[{ - 'x_bucketized$sparse_values': [ - _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b') - ], - 'x_bucketized$sparse_indices_0': [0], - 'x_bucketized$sparse_indices_1': [0], - } for x in range(1, 100)], + testcase_name="sparse", + input_data=[ + {"x": [x], "idx0": [0], "idx1": [0], "key": ["a"] if x < 50 else ["b"]} + for x in range(1, 100) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature(["idx0", "idx1"], "x", tf.float32, (2, 2)), + "key": tf.io.VarLenFeature(tf.string), + } + ), + expected_data=[ + { + "x_bucketized$sparse_values": [ + _compute_simple_per_key_bucket(x, "a" if x < 50 else "b") + ], + "x_bucketized$sparse_indices_0": [0], + "x_bucketized$sparse_indices_1": [0], + } + for x in range(1, 100) + ], expected_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x_bucketized': - tf.io.SparseFeature([ - 'x_bucketized$sparse_indices_0', - 'x_bucketized$sparse_indices_1' - ], - 'x_bucketized$sparse_values', - tf.int64, (None, None), - already_sorted=True), - }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), - })), + "x_bucketized": tf.io.SparseFeature( + ["x_bucketized$sparse_indices_0", "x_bucketized$sparse_indices_1"], + "x_bucketized$sparse_values", + tf.int64, + (None, None), + already_sorted=True, + ), + }, + { + "x_bucketized": schema_pb2.IntDomain(min=0, max=2, is_categorical=True), + }, + ), + ), dict( - 
testcase_name='dense_weighted', - input_data=[{ - 'x': x, - 'key': 'a' if x < 50 else 'b', - 'weights': 0 if x in _WEIGHTED_PER_KEY_0_RANGE else 1, - } for x in range(1, 100)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - 'weights': tf.io.FixedLenFeature([], tf.float32), - }), - expected_data=[{ - 'x_bucketized': - _compute_simple_per_key_bucket( - x, 'a' if x < 50 else 'b', weighted=True) - } for x in range(1, 100)], + testcase_name="dense_weighted", + input_data=[ + { + "x": x, + "key": "a" if x < 50 else "b", + "weights": 0 if x in _WEIGHTED_PER_KEY_0_RANGE else 1, + } + for x in range(1, 100) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + "weights": tf.io.FixedLenFeature([], tf.float32), + } + ), + expected_data=[ + { + "x_bucketized": _compute_simple_per_key_bucket( + x, "a" if x < 50 else "b", weighted=True + ) + } + for x in range(1, 100) + ], expected_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x_bucketized': tf.io.FixedLenFeature([], tf.int64), - }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), - })), + "x_bucketized": tf.io.FixedLenFeature([], tf.int64), + }, + { + "x_bucketized": schema_pb2.IntDomain(min=0, max=2, is_categorical=True), + }, + ), + ), dict( - testcase_name='ragged', - input_data=[{ - 'val': [x, x], - 'row_lengths': [x % 3, 2 - (x % 3)], - 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], - 'key_row_lengths': [x % 3, 2 - (x % 3)], - } for x in range(1, 100)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.RaggedFeature( + testcase_name="ragged", + input_data=[ + { + "val": [x, x], + "row_lengths": [x % 3, 2 - (x % 3)], + "key_val": ["a", "a"] if x < 50 else ["b", "b"], + "key_row_lengths": [x % 3, 2 - (x % 3)], + } + for x in range(1, 100) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.RaggedFeature( tf.int64, - value_key='val', + value_key="val", partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error - ]), - 'key': - tf.io.RaggedFeature( + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ) # pytype: disable=attribute-error + ], + ), + "key": tf.io.RaggedFeature( tf.string, - value_key='key_val', + value_key="key_val", partitions=[ - tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error - ]), - }), - expected_data=[{ - 'x_bucketized$ragged_values': [ - _compute_simple_per_key_bucket(x, 'a' if x < 50 else 'b'), - ] * 2, - 'x_bucketized$row_lengths_1': [x % 3, 2 - (x % 3)], - } for x in range(1, 100)], + tf.io.RaggedFeature.RowLengths( + "key_row_lengths" + ) # pytype: disable=attribute-error + ], + ), + } + ), + expected_data=[ + { + "x_bucketized$ragged_values": [ + _compute_simple_per_key_bucket(x, "a" if x < 50 else "b"), + ] + * 2, + "x_bucketized$row_lengths_1": [x % 3, 2 - (x % 3)], + } + for x in range(1, 100) + ], expected_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x_bucketized': - tf.io.RaggedFeature( - tf.int64, - value_key='x_bucketized$ragged_values', - partitions=[ - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error - 'x_bucketized$row_lengths_1') - ]), + "x_bucketized": tf.io.RaggedFeature( + tf.int64, + value_key="x_bucketized$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error + 
"x_bucketized$row_lengths_1" + ) + ], + ), }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), - })), + "x_bucketized": schema_pb2.IntDomain(min=0, max=2, is_categorical=True), + }, + ), + ), dict( - testcase_name='ragged_weighted', - input_data=[{ - 'val': [x, x], - 'row_lengths': [2 - (x % 3), x % 3], - 'key_val': ['a', 'a'] if x < 50 else ['b', 'b'], - 'key_row_lengths': [ - 2 - (x % 3), - x % 3, - ], - 'weights_val': - ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]), - 'weights_row_lengths': [ - 2 - (x % 3), - x % 3, - ], - } for x in range(1, 100)], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.RaggedFeature( + testcase_name="ragged_weighted", + input_data=[ + { + "val": [x, x], + "row_lengths": [2 - (x % 3), x % 3], + "key_val": ["a", "a"] if x < 50 else ["b", "b"], + "key_row_lengths": [ + 2 - (x % 3), + x % 3, + ], + "weights_val": ([0, 0] if x in _WEIGHTED_PER_KEY_0_RANGE else [1, 1]), + "weights_row_lengths": [ + 2 - (x % 3), + x % 3, + ], + } + for x in range(1, 100) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.RaggedFeature( tf.int64, - value_key='val', + value_key="val", partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error - ]), - 'key': - tf.io.RaggedFeature( + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ) # pytype: disable=attribute-error + ], + ), + "key": tf.io.RaggedFeature( tf.string, - value_key='key_val', + value_key="key_val", partitions=[ - tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error - ]), - 'weights': - tf.io.RaggedFeature( + tf.io.RaggedFeature.RowLengths( + "key_row_lengths" + ) # pytype: disable=attribute-error + ], + ), + "weights": tf.io.RaggedFeature( tf.int64, - value_key='weights_val', + value_key="weights_val", partitions=[ - tf.io.RaggedFeature.RowLengths('weights_row_lengths') # pytype: disable=attribute-error - ]), - }), - expected_data=[{ - 'x_bucketized$ragged_values': [ - _compute_simple_per_key_bucket( - x, 'a' if x < 50 else 'b', weighted=True), - ] * 2, - 'x_bucketized$row_lengths_1': [2 - (x % 3), x % 3], - } for x in range(1, 100)], + tf.io.RaggedFeature.RowLengths( + "weights_row_lengths" + ) # pytype: disable=attribute-error + ], + ), + } + ), + expected_data=[ + { + "x_bucketized$ragged_values": [ + _compute_simple_per_key_bucket( + x, "a" if x < 50 else "b", weighted=True + ), + ] + * 2, + "x_bucketized$row_lengths_1": [2 - (x % 3), x % 3], + } + for x in range(1, 100) + ], expected_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x_bucketized': - tf.io.RaggedFeature( - tf.int64, - value_key='x_bucketized$ragged_values', - partitions=[ - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error - 'x_bucketized$row_lengths_1') - ]), + "x_bucketized": tf.io.RaggedFeature( + tf.int64, + value_key="x_bucketized$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error + "x_bucketized$row_lengths_1" + ) + ], + ), }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=2, is_categorical=True), - })), + "x_bucketized": schema_pb2.IntDomain(min=0, max=2, is_categorical=True), + }, + ), + ), ] class BucketizeIntegrationTest(tft_unit.TransformTestCase): - - def setUp(self): - self._context = beam_impl.Context(use_deep_copy_optimization=True) - self._context.__enter__() - super().setUp() - - def tearDown(self): - self._context.__exit__() - super().tearDown() - - @tft_unit.parameters( - # Test for all integral types, 
each type is in a separate testcase to - # increase parallelism of test shards (and reduce test time from ~250 - # seconds to ~80 seconds) - *_construct_test_bucketization_parameters()) - def testBucketization(self, test_inputs, expected_boundaries, do_shuffle, - epsilon, should_apply, is_manual_boundaries, - input_dtype): - test_inputs = list(test_inputs) - - # Shuffle the input to add randomness to input generated with - # simple range(). - if do_shuffle: - random.shuffle(test_inputs) - - def preprocessing_fn(inputs): - x = tf.cast(inputs['x'], input_dtype) - num_buckets = len(expected_boundaries) + 1 - if should_apply: - if is_manual_boundaries: - bucket_boundaries = [expected_boundaries] - else: - bucket_boundaries = tft.quantiles(inputs['x'], num_buckets, epsilon) - result = tft.apply_buckets(x, bucket_boundaries) - else: - result = tft.bucketize(x, num_buckets=num_buckets, epsilon=epsilon) - return {'q_b': result} - - input_data = [{'x': [x]} for x in test_inputs] - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature([1], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - - # Sort the input based on value, index is used to create expected_data. - indexed_input = enumerate(test_inputs) - # We put all np.nans in the end of the list so that they get assigned to the - # last bucket. - sorted_list = sorted( - indexed_input, key=lambda p: np.inf if np.isnan(p[1]) else p[1]) - - # Expected data has the same size as input, one bucket per input value. - expected_data = [None] * len(test_inputs) - bucket = 0 - for (index, x) in sorted_list: - # Increment the bucket number when crossing the boundary - if (bucket < len(expected_boundaries) and - x >= expected_boundaries[bucket]): - bucket += 1 - expected_data[index] = {'q_b': [bucket]} - - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'q_b': tf.io.FixedLenFeature([1], tf.int64), - }, { - 'q_b': - schema_pb2.IntDomain( - min=0, max=len(expected_boundaries), is_categorical=True), - }) - - @contextlib.contextmanager - def no_assert(): - yield None - - assertion = no_assert() - if input_dtype == tf.float16: - assertion = self.assertRaisesRegex( - TypeError, '.*DataType float16 not in list of allowed values.*' - ) - - with assertion: - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - desired_batch_size=1000) - - @tft_unit.parameters( - # Test for all integral types, each type is in a separate testcase to - # increase parallelism of test shards (and reduce test time from ~250 - # seconds to ~80 seconds) - *_construct_test_bucketization_parameters()) - def testBucketizationElementwise(self, test_inputs, expected_boundaries, - do_shuffle, epsilon, should_apply, - is_manual_boundaries, input_dtype): - test_inputs = list(test_inputs) - - # Shuffle the input to add randomness to input generated with - # simple range(). 
- if do_shuffle: - random.shuffle(test_inputs) - - def preprocessing_fn(inputs): - x = tf.cast(inputs['x'], input_dtype) - - num_buckets = len(expected_boundaries) + 1 - if should_apply: - if is_manual_boundaries: - bucket_boundaries = [ - expected_boundaries, [2 * b for b in expected_boundaries] - ] - else: - bucket_boundaries = tft.quantiles( - x, num_buckets, epsilon, reduce_instance_dims=False) - bucket_boundaries = tf.unstack(bucket_boundaries, axis=0) - - result = [] - for i, boundaries in enumerate(bucket_boundaries): - boundaries = tf.cast(boundaries, tf.float32) - result.append( - tft.apply_buckets(x[:, i], tf.expand_dims(boundaries, axis=0))) - result = tf.stack(result, axis=1) - - else: - result = tft.bucketize( - x, num_buckets=num_buckets, epsilon=epsilon, elementwise=True) - return {'q_b': result} - - input_data = [{'x': [x, 2 * x]} for x in test_inputs] - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature([2], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - - # Sort the input based on value, index is used to create expected_data. - sorted_list = sorted(enumerate(test_inputs), key=lambda p: p[1]) - - # Expected data has the same size as input, one bucket per input value. - expected_data = [[None, None]] * len(test_inputs) - bucket = 0 - - for (index, x) in sorted_list: - # Increment the bucket number when crossing the boundary - if (bucket < len(expected_boundaries) and - x >= expected_boundaries[bucket]): - bucket += 1 - expected_data[index] = {'q_b': [bucket, bucket]} - - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'q_b': tf.io.FixedLenFeature([2], tf.int64), - }, None) - - @contextlib.contextmanager - def no_assert(): - yield None - - assertion = no_assert() - if input_dtype == tf.float16: - assertion = self.assertRaisesRegex( - TypeError, '.*DataType float16 not in list of allowed values.*' - ) - - with assertion: - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - desired_batch_size=1000) - - @tft_unit.named_parameters(*_BUCKETIZE_COMPOSITE_INPUT_TEST_CASES) - def testBucketizeCompositeInput(self, input_data, input_metadata, - expected_data): - - def preprocessing_fn(inputs): - return { - 'x_bucketized': - tft.bucketize(inputs['x'], num_buckets=3, epsilon=0.00001) - } - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data) - - # Test for all numerical types, each type is in a separate testcase to - # increase parallelism of test shards and reduce test time. - @tft_unit.parameters( - (tf.int32, False), - (tf.int64, False), - (tf.float32, False), - (tf.float32, True), - (tf.float64, False), - (tf.float64, True), - (tf.double, False), - (tf.double, True), - # TODO(b/64836936): Enable test after bucket inconsistency is fixed. 
- # (tf.float16, False) - ) - def testQuantileBucketsWithWeights(self, input_dtype, with_nans): - - def analyzer_fn(inputs): - return { - 'q_b': - tft.quantiles( - tf.cast(inputs['x'], input_dtype), - num_buckets=3, - epsilon=0.00001, - weights=inputs['weights']) - } - - input_data = [{'x': [x], 'weights': [x / 100.]} for x in range(1, 3000)] - if with_nans: - input_data += [{ - 'x': [np.nan], - 'weights': [100000] - }, { - 'x': [100000], - 'weights': [np.nan] - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature( - [1], tft_unit.canonical_numeric_dtype(input_dtype)), - 'weights': - tf.io.FixedLenFeature([1], tf.float32) - }) - # The expected data has 2 boundaries that divides the data into 3 buckets. - expected_outputs = {'q_b': np.array([[1732, 2449]], np.float32)} - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=1000) - - # Test for all numerical types, each type is in a separate testcase to - # increase parallelism of test shards and reduce test time. - @tft_unit.parameters( - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - (tf.double,), - # TODO(b/64836936): Enable test after bucket inconsistency is fixed. - # (tf.float16,) - ) - def testElementwiseQuantileBucketsWithWeights(self, input_dtype): - - def analyzer_fn(inputs): - return { - 'q_b': - tft.quantiles( - tf.cast(inputs['x'], input_dtype), - num_buckets=3, - epsilon=0.00001, - weights=inputs['weights'], - reduce_instance_dims=False) - } - - input_data = [{ - 'x': [[x, 2 * x], [2 * x, x]], - 'weights': [x / 100.] - } for x in range(1, 3000)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature( - [2, 2], tft_unit.canonical_numeric_dtype(input_dtype)), - 'weights': - tf.io.FixedLenFeature([1], tf.float32) - }) - # The expected data has 2 boundaries that divides the data into 3 buckets. - expected_outputs = { - 'q_b': - np.array( - [[[1732, 2449], [3464, 4898]], [[3464, 4898], [1732, 2449]]], - np.float32) - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=1000) - - # Test for all numerical types, each type is in a separate testcase to - # increase parallelism of test shards and reduce test time. - @tft_unit.parameters( - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - (tf.double,), - # TODO(b/64836936): Enable test after bucket inconsistency is fixed. - # (tf.float16,) - ) - def testQuantileBuckets(self, input_dtype): - - def analyzer_fn(inputs): - return { - 'q_b': - tft.quantiles( - tf.cast(inputs['x'], input_dtype), - num_buckets=3, - epsilon=0.00001) - } - - # NOTE: We force 3 batches: data has 3000 elements and we request a batch - # size of 1000. - input_data = [{'x': [x]} for x in range(1, 3000)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature([1], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - # The expected data has 2 boundaries that divides the data into 3 buckets. 
- expected_outputs = {'q_b': np.array([[1000, 2000]], np.float32)} - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=1000) - - def testQuantilesPerKey(self): - - def analyzer_fn(inputs): - key_vocab, q_b, scale_factor_per_key, shift_per_key, num_buckets = ( - analyzers._quantiles_per_key( - inputs['x'], inputs['key'], num_buckets=3, epsilon=0.00001)) - return { - 'key_vocab': key_vocab, - 'q_b': q_b, - 'scale_factor_per_key': scale_factor_per_key, - 'shift_per_key': shift_per_key, - 'num_buckets': num_buckets, - } - - # NOTE: We force 10 batches: data has 100 elements and we request a batch - # size of 10. - input_data = [{ - 'x': [x], - 'key': 'a' if x < 50 else 'b' - } for x in range(1, 100)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([1], tf.int64), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - # The expected data has 2 boundaries that divides the data into 3 buckets. - expected_outputs = { - 'key_vocab': np.array([b'a', b'b'], object), - 'q_b': np.array([0., 1., 2.], np.float32), - 'scale_factor_per_key': np.array([0.0625, 0.05882353], np.float32), - 'shift_per_key': np.array([-1.0625, -2.88235283], np.float32), - 'num_buckets': np.array(3, np.int64), - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=10) - - @tft_unit.named_parameters(*_BUCKETIZE_PER_KEY_TEST_CASES) - def testBucketizePerKey(self, input_data, input_metadata, expected_data, - expected_metadata): - - def preprocessing_fn(inputs): - weights = inputs.get('weights', None) - x_bucketized = tft.bucketize_per_key( - inputs['x'], - inputs['key'], - num_buckets=3, - epsilon=0.00001, - weights=weights) - return {'x_bucketized': x_bucketized} - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testBucketizePerKeyWithInfrequentKeys(self): - - def preprocessing_fn(inputs): - x_bucketized = tft.bucketize_per_key( - inputs['x'], inputs['key'], num_buckets=4, epsilon=0.00001) - return {'x': inputs['x'], 'x_bucketized': x_bucketized} - - input_data = [ - {'x': [], 'key': []}, - {'x': [5, 6], 'key': ['a', 'a']}, - {'x': [7], 'key': ['a']}, - {'x': [12], 'key': ['b']}, - {'x': [13], 'key': ['b']}, - {'x': [15], 'key': ['c']}, - {'x': [2], 'key': ['d']}, - {'x': [4], 'key': ['d']}, - {'x': [6], 'key': ['d']}, - {'x': [8], 'key': ['d']}, - {'x': [2], 'key': ['e']}, - {'x': [4], 'key': ['e']}, - {'x': [6], 'key': ['e']}, - {'x': [8], 'key': ['e']}, - {'x': [10], 'key': ['e']}, - {'x': [11], 'key': ['e']}, - {'x': [12], 'key': ['e']}, - {'x': [13], 'key': ['e']} - ] # pyformat: disable - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.VarLenFeature(tf.float32), - 'key': tf.io.VarLenFeature(tf.string) - }) - expected_data = [ - {'x': [], 'x_bucketized': []}, - {'x': [5, 6], 'x_bucketized': [1, 2]}, - {'x': [7], 'x_bucketized': [3]}, - {'x': [12], 'x_bucketized': [1]}, - {'x': [13], 'x_bucketized': [3]}, - {'x': [15], 'x_bucketized': [1]}, - {'x': [2], 'x_bucketized': [0]}, - {'x': [4], 'x_bucketized': [1]}, - {'x': [6], 'x_bucketized': [2]}, - {'x': [8], 'x_bucketized': [3]}, - {'x': [2], 'x_bucketized': [0]}, - {'x': [4], 'x_bucketized': [0]}, - {'x': [6], 'x_bucketized': [1]}, - {'x': [8], 'x_bucketized': [1]}, - {'x': [10], 'x_bucketized': [2]}, - {'x': [11], 'x_bucketized': [2]}, - {'x': [12], 'x_bucketized': [3]}, - {'x': [13], 'x_bucketized': [2]} 
- ] # pyformat: disable - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'x': tf.io.VarLenFeature(tf.float32), - 'x_bucketized': tf.io.VarLenFeature(tf.int64), - }, { - 'x_bucketized': - schema_pb2.IntDomain(min=0, max=3, is_categorical=True), - }) - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - desired_batch_size=10) - - def _assert_quantile_boundaries(self, - test_inputs, - expected_boundaries, - input_dtype, - num_buckets=None, - num_expected_buckets=None): - - if not num_buckets: - num_buckets = len(expected_boundaries) + 1 - if not num_expected_buckets: - num_expected_buckets = num_buckets - - def analyzer_fn(inputs): - x = tf.cast(inputs['x'], input_dtype) - return {'q_b': tft.quantiles(x, num_buckets, epsilon=0.0001)} - - input_data = [{'x': [x]} for x in test_inputs] - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature([1], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - - expected_data = {'q_b': expected_boundaries} - - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_data, - desired_batch_size=1000) - - @tft_unit.parameters( - *_construct_test_bucketization_tight_sequence_parameters()) - def testBucketizationForTightSequence(self, test_inputs, expected_boundaries, - dtype, num_buckets, - num_expected_buckets): - self._assert_quantile_boundaries( + def setUp(self): + self._context = beam_impl.Context(use_deep_copy_optimization=True) + self._context.__enter__() + super().setUp() + + def tearDown(self): + self._context.__exit__() + super().tearDown() + + @tft_unit.parameters( + # Test for all integral types, each type is in a separate testcase to + # increase parallelism of test shards (and reduce test time from ~250 + # seconds to ~80 seconds) + *_construct_test_bucketization_parameters() + ) + def testBucketization( + self, test_inputs, expected_boundaries, - dtype, - num_buckets=num_buckets, - num_expected_buckets=num_expected_buckets) - - def testBucketizationEqualDistributionInSequence(self): - # Input pattern is of the form [1, 1, 1, ..., 2, 2, 2, ..., 3, 3, 3, ...] - inputs = [] - for i in range(1, 101): - inputs += [i] * 100 - # Expect 100 equally spaced buckets. - expected_buckets = np.expand_dims( - np.arange(1, 101, dtype=np.float32), axis=0) - self._assert_quantile_boundaries( - inputs, expected_buckets, tf.int32, num_buckets=101) - - def testBucketizationEqualDistributionInterleaved(self): - # Input pattern is of the form [1, 2, 3, ..., 1, 2, 3, ..., 1, 2, 3, ...] - sequence = range(1, 101) - inputs = [] - for _ in range(1, 101): - inputs += sequence - # Expect 100 equally spaced buckets. - expected_buckets = np.expand_dims( - np.arange(1, 101, dtype=np.float32), axis=0) - self._assert_quantile_boundaries( - inputs, expected_buckets, tf.int32, num_buckets=101) - - def testBucketizationSpecificDistribution(self): - # Distribution of input values. - # This distribution is taken from one of the user pipelines. 
- dist = ( - # Format: ((, ), num-values) - ((0.51, 0.67), 4013), - ((0.67, 0.84), 2321), - ((0.84, 1.01), 7145), - ((1.01, 1.17), 64524), - ((1.17, 1.34), 42886), - ((1.34, 1.51), 154809), - ((1.51, 1.67), 382678), - ((1.67, 1.84), 582744), - ((1.84, 2.01), 252221), - ((2.01, 2.17), 7299)) - - inputs = [] - for (mn, mx), num in dist: - step = (mx - mn) / 100 - for ix in range(num // 100): - inputs += [mn + (ix * step)] - - expected_boundaries = np.array([[2.3084, 3.5638, 5.0972, 7.07]], - dtype=np.float32) - - self._assert_quantile_boundaries( - inputs, expected_boundaries, tf.float32, num_buckets=5) - - -if __name__ == '__main__': - tft_unit.main() + do_shuffle, + epsilon, + should_apply, + is_manual_boundaries, + input_dtype, + ): + test_inputs = list(test_inputs) + + # Shuffle the input to add randomness to input generated with + # simple range(). + if do_shuffle: + random.shuffle(test_inputs) + + def preprocessing_fn(inputs): + x = tf.cast(inputs["x"], input_dtype) + num_buckets = len(expected_boundaries) + 1 + if should_apply: + if is_manual_boundaries: + bucket_boundaries = [expected_boundaries] + else: + bucket_boundaries = tft.quantiles(inputs["x"], num_buckets, epsilon) + result = tft.apply_buckets(x, bucket_boundaries) + else: + result = tft.bucketize(x, num_buckets=num_buckets, epsilon=epsilon) + return {"q_b": result} + + input_data = [{"x": [x]} for x in test_inputs] + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + + # Sort the input based on value, index is used to create expected_data. + indexed_input = enumerate(test_inputs) + # We put all np.nans in the end of the list so that they get assigned to the + # last bucket. + sorted_list = sorted( + indexed_input, key=lambda p: np.inf if np.isnan(p[1]) else p[1] + ) + + # Expected data has the same size as input, one bucket per input value. + expected_data = [None] * len(test_inputs) + bucket = 0 + for index, x in sorted_list: + # Increment the bucket number when crossing the boundary + if bucket < len(expected_boundaries) and x >= expected_boundaries[bucket]: + bucket += 1 + expected_data[index] = {"q_b": [bucket]} + + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "q_b": tf.io.FixedLenFeature([1], tf.int64), + }, + { + "q_b": schema_pb2.IntDomain( + min=0, max=len(expected_boundaries), is_categorical=True + ), + }, + ) + + @contextlib.contextmanager + def no_assert(): + yield None + + assertion = no_assert() + if input_dtype == tf.float16: + assertion = self.assertRaisesRegex( + TypeError, ".*DataType float16 not in list of allowed values.*" + ) + + with assertion: + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + desired_batch_size=1000, + ) + + @tft_unit.parameters( + # Test for all integral types, each type is in a separate testcase to + # increase parallelism of test shards (and reduce test time from ~250 + # seconds to ~80 seconds) + *_construct_test_bucketization_parameters() + ) + def testBucketizationElementwise( + self, + test_inputs, + expected_boundaries, + do_shuffle, + epsilon, + should_apply, + is_manual_boundaries, + input_dtype, + ): + test_inputs = list(test_inputs) + + # Shuffle the input to add randomness to input generated with + # simple range(). 
+ if do_shuffle: + random.shuffle(test_inputs) + + def preprocessing_fn(inputs): + x = tf.cast(inputs["x"], input_dtype) + + num_buckets = len(expected_boundaries) + 1 + if should_apply: + if is_manual_boundaries: + bucket_boundaries = [ + expected_boundaries, + [2 * b for b in expected_boundaries], + ] + else: + bucket_boundaries = tft.quantiles( + x, num_buckets, epsilon, reduce_instance_dims=False + ) + bucket_boundaries = tf.unstack(bucket_boundaries, axis=0) + + result = [] + for i, boundaries in enumerate(bucket_boundaries): + boundaries = tf.cast(boundaries, tf.float32) + result.append( + tft.apply_buckets(x[:, i], tf.expand_dims(boundaries, axis=0)) + ) + result = tf.stack(result, axis=1) + + else: + result = tft.bucketize( + x, num_buckets=num_buckets, epsilon=epsilon, elementwise=True + ) + return {"q_b": result} + + input_data = [{"x": [x, 2 * x]} for x in test_inputs] + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [2], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + + # Sort the input based on value, index is used to create expected_data. + sorted_list = sorted(enumerate(test_inputs), key=lambda p: p[1]) + + # Expected data has the same size as input, one bucket per input value. + expected_data = [[None, None]] * len(test_inputs) + bucket = 0 + + for index, x in sorted_list: + # Increment the bucket number when crossing the boundary + if bucket < len(expected_boundaries) and x >= expected_boundaries[bucket]: + bucket += 1 + expected_data[index] = {"q_b": [bucket, bucket]} + + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "q_b": tf.io.FixedLenFeature([2], tf.int64), + }, + None, + ) + + @contextlib.contextmanager + def no_assert(): + yield None + + assertion = no_assert() + if input_dtype == tf.float16: + assertion = self.assertRaisesRegex( + TypeError, ".*DataType float16 not in list of allowed values.*" + ) + + with assertion: + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + desired_batch_size=1000, + ) + + @tft_unit.named_parameters(*_BUCKETIZE_COMPOSITE_INPUT_TEST_CASES) + def testBucketizeCompositeInput(self, input_data, input_metadata, expected_data): + def preprocessing_fn(inputs): + return { + "x_bucketized": tft.bucketize( + inputs["x"], num_buckets=3, epsilon=0.00001 + ) + } + + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_data + ) + + # Test for all numerical types, each type is in a separate testcase to + # increase parallelism of test shards and reduce test time. + @tft_unit.parameters( + (tf.int32, False), + (tf.int64, False), + (tf.float32, False), + (tf.float32, True), + (tf.float64, False), + (tf.float64, True), + (tf.double, False), + (tf.double, True), + # TODO(b/64836936): Enable test after bucket inconsistency is fixed. 
+ # (tf.float16, False) + ) + def testQuantileBucketsWithWeights(self, input_dtype, with_nans): + def analyzer_fn(inputs): + return { + "q_b": tft.quantiles( + tf.cast(inputs["x"], input_dtype), + num_buckets=3, + epsilon=0.00001, + weights=inputs["weights"], + ) + } + + input_data = [{"x": [x], "weights": [x / 100.0]} for x in range(1, 3000)] + if with_nans: + input_data += [ + {"x": [np.nan], "weights": [100000]}, + {"x": [100000], "weights": [np.nan]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(input_dtype) + ), + "weights": tf.io.FixedLenFeature([1], tf.float32), + } + ) + # The expected data has 2 boundaries that divides the data into 3 buckets. + expected_outputs = {"q_b": np.array([[1732, 2449]], np.float32)} + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=1000, + ) + + # Test for all numerical types, each type is in a separate testcase to + # increase parallelism of test shards and reduce test time. + @tft_unit.parameters( + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + (tf.double,), + # TODO(b/64836936): Enable test after bucket inconsistency is fixed. + # (tf.float16,) + ) + def testElementwiseQuantileBucketsWithWeights(self, input_dtype): + def analyzer_fn(inputs): + return { + "q_b": tft.quantiles( + tf.cast(inputs["x"], input_dtype), + num_buckets=3, + epsilon=0.00001, + weights=inputs["weights"], + reduce_instance_dims=False, + ) + } + + input_data = [ + {"x": [[x, 2 * x], [2 * x, x]], "weights": [x / 100.0]} + for x in range(1, 3000) + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [2, 2], tft_unit.canonical_numeric_dtype(input_dtype) + ), + "weights": tf.io.FixedLenFeature([1], tf.float32), + } + ) + # The expected data has 2 boundaries that divides the data into 3 buckets. + expected_outputs = { + "q_b": np.array( + [[[1732, 2449], [3464, 4898]], [[3464, 4898], [1732, 2449]]], np.float32 + ) + } + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=1000, + ) + + # Test for all numerical types, each type is in a separate testcase to + # increase parallelism of test shards and reduce test time. + @tft_unit.parameters( + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + (tf.double,), + # TODO(b/64836936): Enable test after bucket inconsistency is fixed. + # (tf.float16,) + ) + def testQuantileBuckets(self, input_dtype): + def analyzer_fn(inputs): + return { + "q_b": tft.quantiles( + tf.cast(inputs["x"], input_dtype), num_buckets=3, epsilon=0.00001 + ) + } + + # NOTE: We force 3 batches: data has 3000 elements and we request a batch + # size of 1000. + input_data = [{"x": [x]} for x in range(1, 3000)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + # The expected data has 2 boundaries that divides the data into 3 buckets. 
+ expected_outputs = {"q_b": np.array([[1000, 2000]], np.float32)} + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=1000, + ) + + def testQuantilesPerKey(self): + def analyzer_fn(inputs): + key_vocab, q_b, scale_factor_per_key, shift_per_key, num_buckets = ( + analyzers._quantiles_per_key( + inputs["x"], inputs["key"], num_buckets=3, epsilon=0.00001 + ) + ) + return { + "key_vocab": key_vocab, + "q_b": q_b, + "scale_factor_per_key": scale_factor_per_key, + "shift_per_key": shift_per_key, + "num_buckets": num_buckets, + } + + # NOTE: We force 10 batches: data has 100 elements and we request a batch + # size of 10. + input_data = [{"x": [x], "key": "a" if x < 50 else "b"} for x in range(1, 100)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([1], tf.int64), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + # The expected data has 2 boundaries that divides the data into 3 buckets. + expected_outputs = { + "key_vocab": np.array([b"a", b"b"], object), + "q_b": np.array([0.0, 1.0, 2.0], np.float32), + "scale_factor_per_key": np.array([0.0625, 0.05882353], np.float32), + "shift_per_key": np.array([-1.0625, -2.88235283], np.float32), + "num_buckets": np.array(3, np.int64), + } + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=10, + ) + + @tft_unit.named_parameters(*_BUCKETIZE_PER_KEY_TEST_CASES) + def testBucketizePerKey( + self, input_data, input_metadata, expected_data, expected_metadata + ): + def preprocessing_fn(inputs): + weights = inputs.get("weights", None) + x_bucketized = tft.bucketize_per_key( + inputs["x"], + inputs["key"], + num_buckets=3, + epsilon=0.00001, + weights=weights, + ) + return {"x_bucketized": x_bucketized} + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testBucketizePerKeyWithInfrequentKeys(self): + def preprocessing_fn(inputs): + x_bucketized = tft.bucketize_per_key( + inputs["x"], inputs["key"], num_buckets=4, epsilon=0.00001 + ) + return {"x": inputs["x"], "x_bucketized": x_bucketized} + + input_data = [ + {"x": [], "key": []}, + {"x": [5, 6], "key": ["a", "a"]}, + {"x": [7], "key": ["a"]}, + {"x": [12], "key": ["b"]}, + {"x": [13], "key": ["b"]}, + {"x": [15], "key": ["c"]}, + {"x": [2], "key": ["d"]}, + {"x": [4], "key": ["d"]}, + {"x": [6], "key": ["d"]}, + {"x": [8], "key": ["d"]}, + {"x": [2], "key": ["e"]}, + {"x": [4], "key": ["e"]}, + {"x": [6], "key": ["e"]}, + {"x": [8], "key": ["e"]}, + {"x": [10], "key": ["e"]}, + {"x": [11], "key": ["e"]}, + {"x": [12], "key": ["e"]}, + {"x": [13], "key": ["e"]}, + ] # pyformat: disable + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tf.float32), + "key": tf.io.VarLenFeature(tf.string), + } + ) + expected_data = [ + {"x": [], "x_bucketized": []}, + {"x": [5, 6], "x_bucketized": [1, 2]}, + {"x": [7], "x_bucketized": [3]}, + {"x": [12], "x_bucketized": [1]}, + {"x": [13], "x_bucketized": [3]}, + {"x": [15], "x_bucketized": [1]}, + {"x": [2], "x_bucketized": [0]}, + {"x": [4], "x_bucketized": [1]}, + {"x": [6], "x_bucketized": [2]}, + {"x": [8], "x_bucketized": [3]}, + {"x": [2], "x_bucketized": [0]}, + {"x": [4], "x_bucketized": [0]}, + {"x": [6], "x_bucketized": [1]}, + {"x": [8], "x_bucketized": [1]}, + {"x": [10], "x_bucketized": [2]}, + {"x": [11], "x_bucketized": [2]}, + {"x": [12], "x_bucketized": 
[3]}, + {"x": [13], "x_bucketized": [2]}, + ] # pyformat: disable + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tf.float32), + "x_bucketized": tf.io.VarLenFeature(tf.int64), + }, + { + "x_bucketized": schema_pb2.IntDomain(min=0, max=3, is_categorical=True), + }, + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + desired_batch_size=10, + ) + + def _assert_quantile_boundaries( + self, + test_inputs, + expected_boundaries, + input_dtype, + num_buckets=None, + num_expected_buckets=None, + ): + if not num_buckets: + num_buckets = len(expected_boundaries) + 1 + if not num_expected_buckets: + num_expected_buckets = num_buckets + + def analyzer_fn(inputs): + x = tf.cast(inputs["x"], input_dtype) + return {"q_b": tft.quantiles(x, num_buckets, epsilon=0.0001)} + + input_data = [{"x": [x]} for x in test_inputs] + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + + expected_data = {"q_b": expected_boundaries} + + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_data, + desired_batch_size=1000, + ) + + @tft_unit.parameters(*_construct_test_bucketization_tight_sequence_parameters()) + def testBucketizationForTightSequence( + self, test_inputs, expected_boundaries, dtype, num_buckets, num_expected_buckets + ): + self._assert_quantile_boundaries( + test_inputs, + expected_boundaries, + dtype, + num_buckets=num_buckets, + num_expected_buckets=num_expected_buckets, + ) + + def testBucketizationEqualDistributionInSequence(self): + # Input pattern is of the form [1, 1, 1, ..., 2, 2, 2, ..., 3, 3, 3, ...] + inputs = [] + for i in range(1, 101): + inputs += [i] * 100 + # Expect 100 equally spaced buckets. + expected_buckets = np.expand_dims(np.arange(1, 101, dtype=np.float32), axis=0) + self._assert_quantile_boundaries( + inputs, expected_buckets, tf.int32, num_buckets=101 + ) + + def testBucketizationEqualDistributionInterleaved(self): + # Input pattern is of the form [1, 2, 3, ..., 1, 2, 3, ..., 1, 2, 3, ...] + sequence = range(1, 101) + inputs = [] + for _ in range(1, 101): + inputs += sequence + # Expect 100 equally spaced buckets. + expected_buckets = np.expand_dims(np.arange(1, 101, dtype=np.float32), axis=0) + self._assert_quantile_boundaries( + inputs, expected_buckets, tf.int32, num_buckets=101 + ) + + def testBucketizationSpecificDistribution(self): + # Distribution of input values. + # This distribution is taken from one of the user pipelines. 
+ dist = ( + # Format: ((, ), num-values) + ((0.51, 0.67), 4013), + ((0.67, 0.84), 2321), + ((0.84, 1.01), 7145), + ((1.01, 1.17), 64524), + ((1.17, 1.34), 42886), + ((1.34, 1.51), 154809), + ((1.51, 1.67), 382678), + ((1.67, 1.84), 582744), + ((1.84, 2.01), 252221), + ((2.01, 2.17), 7299), + ) + + inputs = [] + for (mn, mx), num in dist: + step = (mx - mn) / 100 + for ix in range(num // 100): + inputs += [mn + (ix * step)] + + expected_boundaries = np.array( + [[2.3084, 3.5638, 5.0972, 7.07]], dtype=np.float32 + ) + + self._assert_quantile_boundaries( + inputs, expected_boundaries, tf.float32, num_buckets=5 + ) + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/cached_impl_test.py b/tensorflow_transform/beam/cached_impl_test.py index fdaf1da..c03d9ca 100644 --- a/tensorflow_transform/beam/cached_impl_test.py +++ b/tensorflow_transform/beam/cached_impl_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2018 Google Inc. All Rights Reserved. # @@ -18,27 +17,24 @@ import functools import os import struct -from typing import Callable, Mapping, List import uuid +from typing import Callable, List, Mapping + import apache_beam as beam -from apache_beam.testing import util as beam_test_util import numpy as np - import tensorflow as tf -import tensorflow_transform as tft -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import common_types -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -import tensorflow_transform.beam as tft_beam -from tensorflow_transform.beam import analysis_graph_builder -from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform.beam import tft_unit -from tensorflow_transform.tf_metadata import dataset_metadata +from apache_beam.testing import util as beam_test_util + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple +import tensorflow_transform as tft +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import analyzer_nodes, common_types, impl_helper, nodes +from tensorflow_transform.beam import analysis_graph_builder, analyzer_cache, tft_unit +from tensorflow_transform.tf_metadata import dataset_metadata + mock = tf.compat.v1.test.mock _SINGLE_PHASE_NUM_SAVED_MODELS = 2 @@ -46,38 +42,43 @@ def _make_cache_key(cache_identifier): - return analyzer_cache._CACHE_VERSION + cache_identifier + b'-HASH' + return analyzer_cache._CACHE_VERSION + cache_identifier + b"-HASH" def _encode_vocabulary_accumulator(token_bytes, value_bytes): - return struct.pack('qq{}s{}s'.format(len(token_bytes), len(value_bytes)), - len(token_bytes), len(value_bytes), token_bytes, - value_bytes) + return struct.pack( + f"qq{len(token_bytes)}s{len(value_bytes)}s", + len(token_bytes), + len(value_bytes), + token_bytes, + value_bytes, + ) def _preprocessing_fn_for_common_optimize_traversal(inputs): - _ = tft.vocabulary(inputs['s']) - x = inputs['x'] - x_mean = tft.mean(x, name='x') - x_square_deviations = tf.square(x - x_mean) + _ = tft.vocabulary(inputs["s"]) + x = inputs["x"] + x_mean = tft.mean(x, name="x") + x_square_deviations = tf.square(x - x_mean) - # 2nd analysis phase defined here. - x_var = tft.mean(x_square_deviations, name='x_square_deviations') - x_normalized = (x - x_mean) / tf.sqrt(x_var) - return {'x_normalized': x_normalized} + # 2nd analysis phase defined here. 
+ x_var = tft.mean(x_square_deviations, name="x_square_deviations") + x_normalized = (x - x_mean) / tf.sqrt(x_var) + return {"x_normalized": x_normalized} _OPTIMIZE_TRAVERSAL_COMMON_CASE = dict( - testcase_name='common', + testcase_name="common", feature_spec={ - 'x': tf.io.FixedLenFeature([], tf.float32), - 's': tf.io.FixedLenFeature([], tf.string) + "x": tf.io.FixedLenFeature([], tf.float32), + "s": tf.io.FixedLenFeature([], tf.string), }, preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal, - dataset_input_cache_dicts=[{ - _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'): - 'cache hit', - }], + dataset_input_cache_dicts=[ + { + _make_cache_key(b"CacheableCombineAccumulate[x#mean_and_var]"): "cache hit", + } + ], expected_dot_graph_str=r"""digraph G { directed=True; node [shape=Mrecord]; @@ -170,21 +171,23 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs): "EncodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]" [label="{EncodeCache|coder: \<_VocabularyAccumulatorCoder\>|label: EncodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]|partitionable: True}"]; "VocabularyAccumulate[vocabulary][AnalysisIndex1]" -> "EncodeCache[VocabularyAccumulate[vocabulary]][AnalysisIndex1]"; } -""") +""", +) _OPTIMIZE_TRAVERSAL_MULTI_PHASE_FULL_CACHE_HIT_CASE = dict( - testcase_name='multi_phase_full_cache_coverage', + testcase_name="multi_phase_full_cache_coverage", feature_spec={ - 'x': tf.io.FixedLenFeature([], tf.float32), - 's': tf.io.FixedLenFeature([], tf.string) + "x": tf.io.FixedLenFeature([], tf.float32), + "s": tf.io.FixedLenFeature([], tf.string), }, preprocessing_fn=_preprocessing_fn_for_common_optimize_traversal, - dataset_input_cache_dicts=[{ - _make_cache_key(b'CacheableCombineAccumulate[x#mean_and_var]'): - 'cache hit', - _make_cache_key(b'VocabularyAccumulate[vocabulary]'): - 'cache hit', - }] * 2, + dataset_input_cache_dicts=[ + { + _make_cache_key(b"CacheableCombineAccumulate[x#mean_and_var]"): "cache hit", + _make_cache_key(b"VocabularyAccumulate[vocabulary]"): "cache hit", + } + ] + * 2, expected_dot_graph_str=r"""digraph G { directed=True; node [shape=Mrecord]; @@ -253,106 +256,110 @@ def _preprocessing_fn_for_common_optimize_traversal(inputs): "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder]" -> CreateSavedModel; "CreateTensorBinding[x_square_deviations#mean_and_var#Placeholder_1]" -> CreateSavedModel; } -""") +""", +) _TF_VERSION_NAMED_PARAMETERS = [ - dict(testcase_name='CompatV1', use_tf_compat_v1=True), - dict(testcase_name='V2', use_tf_compat_v1=False), + dict(testcase_name="CompatV1", use_tf_compat_v1=True), + dict(testcase_name="V2", use_tf_compat_v1=False), ] def _preprocessing_fn_for_generalized_chained_ptransforms(inputs): - - class FakeChainablePartitionable( - tfx_namedtuple.namedtuple('FakeChainablePartitionable', ['label']), - nodes.OperationDef): - - def __new__(cls): - scope = tf.compat.v1.get_default_graph().get_name_scope() - label = '{}[{}]'.format(cls.__name__, scope) - return super(FakeChainablePartitionable, cls).__new__(cls, label=label) - - @property - def num_outputs(self): - return 1 - - @property - def is_partitionable(self): - return True - - class FakeChainableCacheable( - tfx_namedtuple.namedtuple('FakeChainableCacheable', ['label']), - nodes.OperationDef): - - def __new__(cls): - scope = tf.compat.v1.get_default_graph().get_name_scope() - label = '{}[{}]'.format(cls.__name__, scope) - return super(FakeChainableCacheable, cls).__new__(cls, label=label) - - @property - def 
num_outputs(self): - return 1 - - @property - def is_partitionable(self): - return True - - @property - def cache_coder(self): - return 'Not-a-coder-but-thats-ok!' - - class FakeChainable( - tfx_namedtuple.namedtuple('FakeChainable', ['label']), - nodes.OperationDef): - - def __new__(cls): - scope = tf.compat.v1.get_default_graph().get_name_scope() - label = '{}[{}]'.format(cls.__name__, scope) - return super(FakeChainable, cls).__new__(cls, label=label) - - @property - def num_outputs(self): - return 1 - - @property - def is_partitionable(self): - return False - - with tf.compat.v1.name_scope('x'): - input_values_node = nodes.apply_operation( - analyzer_nodes.TensorSource, tensors=[inputs['x']]) - with tf.compat.v1.name_scope('partitionable1'): - partitionable_outputs = nodes.apply_multi_output_operation( - FakeChainablePartitionable, input_values_node) - with tf.compat.v1.name_scope('cacheable1'): - intermediate_cached_value_node = nodes.apply_multi_output_operation( - FakeChainableCacheable, *partitionable_outputs) - with tf.compat.v1.name_scope('partitionable2'): - partitionable_outputs = nodes.apply_multi_output_operation( - FakeChainablePartitionable, *intermediate_cached_value_node) - with tf.compat.v1.name_scope('cacheable2'): - cached_value_node = nodes.apply_multi_output_operation( - FakeChainableCacheable, *partitionable_outputs) - with tf.compat.v1.name_scope('partitionable3'): - output_value_node = nodes.apply_multi_output_operation( - FakeChainablePartitionable, *cached_value_node) - with tf.compat.v1.name_scope('merge'): - output_value_node = nodes.apply_operation(FakeChainable, - *output_value_node) - with tf.compat.v1.name_scope('not-cacheable'): - non_cached_output = nodes.apply_operation(FakeChainable, - input_values_node) - x_chained = analyzer_nodes.bind_future_as_tensor( - output_value_node, analyzer_nodes.TensorInfo(tf.float32, (17, 27), - None)) - x_plain = analyzer_nodes.bind_future_as_tensor( - non_cached_output, analyzer_nodes.TensorInfo(tf.int64, (7, 13), None)) - return {'x_chained': x_chained, 'x_plain': x_plain} + class FakeChainablePartitionable( + tfx_namedtuple.namedtuple("FakeChainablePartitionable", ["label"]), + nodes.OperationDef, + ): + def __new__(cls): + scope = tf.compat.v1.get_default_graph().get_name_scope() + label = f"{cls.__name__}[{scope}]" + return super(FakeChainablePartitionable, cls).__new__(cls, label=label) + + @property + def num_outputs(self): + return 1 + + @property + def is_partitionable(self): + return True + + class FakeChainableCacheable( + tfx_namedtuple.namedtuple("FakeChainableCacheable", ["label"]), + nodes.OperationDef, + ): + def __new__(cls): + scope = tf.compat.v1.get_default_graph().get_name_scope() + label = f"{cls.__name__}[{scope}]" + return super(FakeChainableCacheable, cls).__new__(cls, label=label) + + @property + def num_outputs(self): + return 1 + + @property + def is_partitionable(self): + return True + + @property + def cache_coder(self): + return "Not-a-coder-but-thats-ok!" 
+ + class FakeChainable( + tfx_namedtuple.namedtuple("FakeChainable", ["label"]), nodes.OperationDef + ): + def __new__(cls): + scope = tf.compat.v1.get_default_graph().get_name_scope() + label = f"{cls.__name__}[{scope}]" + return super(FakeChainable, cls).__new__(cls, label=label) + + @property + def num_outputs(self): + return 1 + + @property + def is_partitionable(self): + return False + + with tf.compat.v1.name_scope("x"): + input_values_node = nodes.apply_operation( + analyzer_nodes.TensorSource, tensors=[inputs["x"]] + ) + with tf.compat.v1.name_scope("partitionable1"): + partitionable_outputs = nodes.apply_multi_output_operation( + FakeChainablePartitionable, input_values_node + ) + with tf.compat.v1.name_scope("cacheable1"): + intermediate_cached_value_node = nodes.apply_multi_output_operation( + FakeChainableCacheable, *partitionable_outputs + ) + with tf.compat.v1.name_scope("partitionable2"): + partitionable_outputs = nodes.apply_multi_output_operation( + FakeChainablePartitionable, *intermediate_cached_value_node + ) + with tf.compat.v1.name_scope("cacheable2"): + cached_value_node = nodes.apply_multi_output_operation( + FakeChainableCacheable, *partitionable_outputs + ) + with tf.compat.v1.name_scope("partitionable3"): + output_value_node = nodes.apply_multi_output_operation( + FakeChainablePartitionable, *cached_value_node + ) + with tf.compat.v1.name_scope("merge"): + output_value_node = nodes.apply_operation(FakeChainable, *output_value_node) + with tf.compat.v1.name_scope("not-cacheable"): + non_cached_output = nodes.apply_operation(FakeChainable, input_values_node) + x_chained = analyzer_nodes.bind_future_as_tensor( + output_value_node, analyzer_nodes.TensorInfo(tf.float32, (17, 27), None) + ) + x_plain = analyzer_nodes.bind_future_as_tensor( + non_cached_output, analyzer_nodes.TensorInfo(tf.int64, (7, 13), None) + ) + return {"x_chained": x_chained, "x_plain": x_plain} _OPTIMIZE_TRAVERSAL_GENERALIZED_CHAINED_PTRANSFORMS_CASE = dict( - testcase_name='generalized_chained_ptransforms', - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}, + testcase_name="generalized_chained_ptransforms", + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, preprocessing_fn=_preprocessing_fn_for_generalized_chained_ptransforms, dataset_input_cache_dicts=None, expected_dot_graph_str=r"""digraph G { @@ -421,7 +428,8 @@ def is_partitionable(self): "FakeChainableCacheable[x/cacheable2][AnalysisIndex1]" -> "EncodeCache[FakeChainableCacheable[x/cacheable2]][AnalysisIndex1]"; InstrumentDatasetCache [label="{InstrumentDatasetCache|input_cache_dataset_keys: []|num_encode_cache: 4|num_decode_cache: 0|label: InstrumentDatasetCache|partitionable: True}"]; } -""") +""", +) _OPTIMIZE_TRAVERSAL_TEST_CASES = [ _OPTIMIZE_TRAVERSAL_COMMON_CASE, @@ -431,1323 +439,1350 @@ def is_partitionable(self): def mock_out_cache_hash(test_fn): + def _make_next_hashed_path_for_test(*unused_args): + return b"HASH" - def _make_next_hashed_path_for_test(*unused_args): - return b'HASH' + def _run_test(*args, **kwargs): + with mock.patch.object( + analysis_graph_builder._OptimizeVisitor, + "_make_next_hashed_path", + _make_next_hashed_path_for_test, + ): + return test_fn(*args, **kwargs) - def _run_test(*args, **kwargs): - with mock.patch.object(analysis_graph_builder._OptimizeVisitor, - '_make_next_hashed_path', - _make_next_hashed_path_for_test): - return test_fn(*args, **kwargs) - - return _run_test + return _run_test _RunPipelineResult = tfx_namedtuple.namedtuple( # pylint: disable=invalid-name - 
'_RunPipelineResult', ['cache_output', 'metrics', 'transform_output'] + "_RunPipelineResult", ["cache_output", "metrics", "transform_output"] ) class CachedImplTest(tft_unit.TransformTestCase): + def setUp(self): + super().setUp() + self.base_test_dir = os.path.join( + os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", self.get_temp_dir()), + self._testMethodName, + ) + self._cache_dir = os.path.join(self.base_test_dir, "cache") + self._running_index = 0 + + self._context = tft_beam.Context(temp_dir=self.get_temp_dir()) + self._context.__enter__() + + def tearDown(self): + self._context.__exit__() + super().tearDown() + + def _get_running_index(self): + self._running_index += 1 + return self._running_index + + def _publish_rendered_dot_graph_file_from_leaf_nodes(self, leaf_nodes): + dot_string = nodes.get_dot_graph(leaf_nodes).to_string() + tf.io.gfile.makedirs(self.base_test_dir) + output_file = os.path.join( + self.base_test_dir, + f"rendered_graph_{self._get_running_index()}.svg", + ) + self.WriteRenderedDotFile(dot_string, output_file=output_file) + return dot_string - def setUp(self): - super().setUp() - self.base_test_dir = os.path.join( - os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()), - self._testMethodName) - self._cache_dir = os.path.join(self.base_test_dir, 'cache') - self._running_index = 0 - - self._context = tft_beam.Context(temp_dir=self.get_temp_dir()) - self._context.__enter__() - - def tearDown(self): - self._context.__exit__() - super().tearDown() - - def _get_running_index(self): - self._running_index += 1 - return self._running_index - - def _publish_rendered_dot_graph_file_from_leaf_nodes(self, leaf_nodes): - dot_string = nodes.get_dot_graph(leaf_nodes).to_string() - tf.io.gfile.makedirs(self.base_test_dir) - output_file = os.path.join( - self.base_test_dir, - 'rendered_graph_{}.svg'.format(self._get_running_index()), - ) - self.WriteRenderedDotFile(dot_string, output_file=output_file) - return dot_string - - def _publish_rendered_dot_graph_file(self, - preprocessing_fn, - feature_spec, - dataset_keys, - pcoll_cache_dict, - use_tf_compat_v1=True): - specs = feature_spec - base_temp_dir = None - if not use_tf_compat_v1: - specs = impl_helper.get_type_specs_from_feature_specs(specs) - base_temp_dir = self.base_test_dir - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, - specs, - use_tf_compat_v1=use_tf_compat_v1, - base_temp_dir=base_temp_dir)) - (transform_fn_future, cache_output_dict, - sideeffects) = analysis_graph_builder.build(graph, structured_inputs, - structured_outputs, - dataset_keys, pcoll_cache_dict) - def sort_value_node_values(cache_dict): - result = [] - if cache_dict is None: - return result - for dataset_cache in cache_dict.values(): - result.extend(dataset_cache.values()) - return sorted(result, key=str) - - return self._publish_rendered_dot_graph_file_from_leaf_nodes( - [transform_fn_future] - + sort_value_node_values(cache_output_dict) - + list(sideeffects) - ) - - def _run_pipeline( - self, - feature_spec, - input_data_dict, - preprocessing_fn, - cache_dict=None, - should_read_cache=False, - datasets_to_transform=None, - expected_transform_data=None, - expected_cache=None, - transform_fn_output_dir='', - use_tf_compat_v1=True, - ) -> _RunPipelineResult: - """Runs an analysis pipeline with cache. - - Args: - feature_spec: A feature_spec for the input data. - input_data_dict: Dict[str, List[Dict[str, primitive]]] the input data used - for analysis. 
- preprocessing_fn: The preprocessing_fn used for analysis. - cache_dict: Dict[str, Dict[str, List[bytes]]], input cache dict. If - provided, should_read_cache must be False. - should_read_cache: A bool indicating if the pipeline should read cache. If - True, cache_dict must be False. - datasets_to_transform: List[str], list of dataset keys to transform. - expected_transform_data: List[Dict[str, primitive]], the expected - transformed data, should be the same for each dataset. - expected_cache: Dict[str, Dict[str, bytes]], expected encoded cache. - transform_fn_output_dir: A directory where the output transform_fn should - be written to, if None provided it will not be written. - use_tf_compat_v1: If True, TFT's public APIs (e.g. AnalyzeDataset) will - use Tensorflow in compat.v1 mode. Defaults to `True`. - - Returns: - A _RunPipelineResult. - """ - input_metadata = dataset_metadata.DatasetMetadata.from_feature_spec( - feature_spec) - with self._TestPipeline() as p: - with tft_beam.Context(force_tf_compat_v1=use_tf_compat_v1): - - # Wraps each value in input_data_dict as a PCollection. - input_data_pcoll_dict = {} - for a, b in input_data_dict.items(): - pcoll = p | a.key >> beam.Create(b) - input_data_pcoll_dict[a] = pcoll - - pcoll_cache_dict = {} - - # If provided with a cache dictionary this wraps cache entries in - # PCollections. - if cache_dict is not None: - assert not should_read_cache - for dataset in cache_dict: - cache_entry = {} - for idx, (k, v) in enumerate(cache_dict[dataset].items()): - cache_entry[k] = ( - p | f'CreateCache[{dataset}][{idx}]' >> beam.Create(v)) - metadata = ( - p | f'CreateCacheMetadata[{dataset}]' >> beam.Create( - [cache_dict[dataset].metadata])) - pcoll_cache_dict[dataset] = analyzer_cache.DatasetCache( - cache_entry, metadata) - - # If requested, reads cache from the test cache directory. - if should_read_cache: - assert cache_dict is None - pcoll_cache_dict = p | analyzer_cache.ReadAnalysisCacheFromFS( - self._cache_dir, list(input_data_dict.keys())) - - self._publish_rendered_dot_graph_file( - preprocessing_fn, - feature_spec, - set(input_data_dict.keys()), - pcoll_cache_dict, - use_tf_compat_v1=use_tf_compat_v1) - - transform_fn, cache_output = ( - (input_data_pcoll_dict, pcoll_cache_dict, input_metadata) - | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)) - _ = ( - cache_output - | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS( - p, self._cache_dir)) - - # Transforms the requested datasets. - if datasets_to_transform is None: - transformed_dataset = None - else: - flattened_transform_data = ( - [input_data_pcoll_dict[d] for d in datasets_to_transform] - | 'FlattenTransformData' >> beam.Flatten()) - transformed_dataset = (( - (flattened_transform_data, input_metadata), transform_fn) - | 'Transform' >> tft_beam.TransformDataset()) - - # Validate the transformed data is as expected. This requires providing - # datasets_to_transform. 
- if expected_transform_data is not None: - assert transformed_dataset is not None - transformed_data, unused_transformed_metadata = transformed_dataset - beam_test_util.assert_that( - transformed_data, - beam_test_util.equal_to(expected_transform_data)) - - if expected_cache is not None: - for dataset in expected_cache: - cache_dict = cache_output[dataset].cache_dict - self.assertCountEqual(cache_dict.keys(), - expected_cache[dataset].keys()) - beam_test_util.assert_that( - cache_output[dataset].metadata, - beam_test_util.is_not_empty(), - label='AssertCacheMetadata[{}]'.format(dataset)) - for idx, (key, value) in enumerate(expected_cache[dataset].items()): - beam_test_util.assert_that( - cache_dict[key], - beam_test_util.equal_to(value), - label='AssertCache[{}][{}]'.format(dataset, idx)) - - # Write transform_fn if provided with an output directory. - tft_output = None - if transform_fn_output_dir is not None: - if not transform_fn_output_dir: - transform_fn_output_dir = os.path.join( - self.base_test_dir, uuid.uuid4().hex + def _publish_rendered_dot_graph_file( + self, + preprocessing_fn, + feature_spec, + dataset_keys, + pcoll_cache_dict, + use_tf_compat_v1=True, + ): + specs = feature_spec + base_temp_dir = None + if not use_tf_compat_v1: + specs = impl_helper.get_type_specs_from_feature_specs(specs) + base_temp_dir = self.base_test_dir + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, + specs, + use_tf_compat_v1=use_tf_compat_v1, + base_temp_dir=base_temp_dir, ) - _ = transform_fn | tft_beam.WriteTransformFn(transform_fn_output_dir) - tft_output = tft.TFTransformOutput(transform_fn_output_dir) - - return _RunPipelineResult(cache_output, p.metrics, tft_output) - - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - @mock_out_cache_hash - def test_single_phase_mixed_analyzer_run_once(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - - def preprocessing_fn(inputs): - - _ = tft.bucketize(inputs['x'], 2, name='bucketize') - - return { - 'integerized_s': - tft.compute_and_apply_vocabulary(inputs['s']), - 'x_min': - tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']), - 'x_mean': - tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']), - 'y_min': - tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']), - 'y_mean': - tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']), - } - - # Run AnalyzeAndTransform on some input data and compare with expected - # output. 
- input_data = [{'x': 12, 'y': 1, 's': 'd'}, {'x': 10, 'y': 1, 's': 'c'}] - feature_spec = { - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 's': tf.io.FixedLenFeature([], tf.string), - } - input_data_dict = { - span_0_key: [{ - 'x': -2, - 'y': 1, - 's': 'b', - }, { - 'x': 4, - 'y': -4, - 's': 'b', - }], - span_1_key: input_data, - } - - span_0_size = 42 - cache_dict = { - span_0_key: - analyzer_cache.DatasetCache( - { - _make_cache_key( - b'CacheableCombineAccumulate[x_1#mean_and_var]'): - [b'[2.0, 1.0, 9.0, 0.0]'], - _make_cache_key(b'CacheableCombineAccumulate[x#x]'): - [b'[2.0, 4.0]'], - _make_cache_key( - b'CacheableCombineAccumulate[y_1#mean_and_var]'): - [b'[2.0, -1.5, 6.25, 0.0]'], - _make_cache_key(b'CacheableCombineAccumulate[y#y]'): - [b'[4.0, 1.0]'], - }, analyzer_cache.DatasetCacheMetadata(span_0_size)), - span_1_key: - analyzer_cache.DatasetCache({}, None), - } - - expected_transformed = [ - { - 'x_mean': 6.0, - 'x_min': -2.0, - 'y_mean': -0.25, - 'y_min': -4.0, - 'integerized_s': 1, - }, - { - 'x_mean': 6.0, - 'x_min': -2.0, - 'y_mean': -0.25, - 'y_min': -4.0, - 'integerized_s': 2, - }, - ] + ) + (transform_fn_future, cache_output_dict, sideeffects) = ( + analysis_graph_builder.build( + graph, + structured_inputs, + structured_outputs, + dataset_keys, + pcoll_cache_dict, + ) + ) - run_result = self._run_pipeline( + def sort_value_node_values(cache_dict): + result = [] + if cache_dict is None: + return result + for dataset_cache in cache_dict.values(): + result.extend(dataset_cache.values()) + return sorted(result, key=str) + + return self._publish_rendered_dot_graph_file_from_leaf_nodes( + [transform_fn_future] + + sort_value_node_values(cache_output_dict) + + list(sideeffects) + ) + + def _run_pipeline( + self, feature_spec, input_data_dict, preprocessing_fn, - cache_dict=cache_dict, - datasets_to_transform=[span_1_key], - expected_transform_data=expected_transformed, - transform_fn_output_dir=os.path.join(self.base_test_dir, - 'transform_fn'), - use_tf_compat_v1=use_tf_compat_v1) - - # The output cache should not have entries for the cache that is present - # in the input cache. - self.assertEqual( - len(run_result.cache_output[span_0_key].cache_dict), - len(run_result.cache_output[span_1_key].cache_dict) - 4) - - metrics = run_result.metrics - # 4 from analyzing 2 spans, and 2 from transform. - self.assertMetricsCounterEqual(metrics, 'num_instances', 6) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 4) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 8) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'num_packed_accumulate_combiners', 1 - ) - self.assertMetricsCounterEqual(metrics, 'num_packed_merge_combiners', 1) - # All datasets were processed even though some of the analyzers were covered - # by cache. - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + cache_dict=None, + should_read_cache=False, + datasets_to_transform=None, + expected_transform_data=None, + expected_cache=None, + transform_fn_output_dir="", + use_tf_compat_v1=True, + ) -> _RunPipelineResult: + """Runs an analysis pipeline with cache. + + Args: + ---- + feature_spec: A feature_spec for the input data. + input_data_dict: Dict[str, List[Dict[str, primitive]]] the input data used + for analysis. + preprocessing_fn: The preprocessing_fn used for analysis. 
+ cache_dict: Dict[str, Dict[str, List[bytes]]], input cache dict. If + provided, should_read_cache must be False. + should_read_cache: A bool indicating if the pipeline should read cache. If + True, cache_dict must be False. + datasets_to_transform: List[str], list of dataset keys to transform. + expected_transform_data: List[Dict[str, primitive]], the expected + transformed data, should be the same for each dataset. + expected_cache: Dict[str, Dict[str, bytes]], expected encoded cache. + transform_fn_output_dir: A directory where the output transform_fn should + be written to, if None provided it will not be written. + use_tf_compat_v1: If True, TFT's public APIs (e.g. AnalyzeDataset) will + use Tensorflow in compat.v1 mode. Defaults to `True`. + + Returns: + ------- + A _RunPipelineResult. + """ + input_metadata = dataset_metadata.DatasetMetadata.from_feature_spec( + feature_spec + ) + with self._TestPipeline() as p: + with tft_beam.Context(force_tf_compat_v1=use_tf_compat_v1): + # Wraps each value in input_data_dict as a PCollection. + input_data_pcoll_dict = {} + for a, b in input_data_dict.items(): + pcoll = p | a.key >> beam.Create(b) + input_data_pcoll_dict[a] = pcoll + + pcoll_cache_dict = {} + + # If provided with a cache dictionary this wraps cache entries in + # PCollections. + if cache_dict is not None: + assert not should_read_cache + for dataset in cache_dict: + cache_entry = {} + for idx, (k, v) in enumerate(cache_dict[dataset].items()): + cache_entry[k] = ( + p | f"CreateCache[{dataset}][{idx}]" >> beam.Create(v) + ) + metadata = p | f"CreateCacheMetadata[{dataset}]" >> beam.Create( + [cache_dict[dataset].metadata] + ) + pcoll_cache_dict[dataset] = analyzer_cache.DatasetCache( + cache_entry, metadata + ) + + # If requested, reads cache from the test cache directory. + if should_read_cache: + assert cache_dict is None + pcoll_cache_dict = p | analyzer_cache.ReadAnalysisCacheFromFS( + self._cache_dir, list(input_data_dict.keys()) + ) + + self._publish_rendered_dot_graph_file( + preprocessing_fn, + feature_spec, + set(input_data_dict.keys()), + pcoll_cache_dict, + use_tf_compat_v1=use_tf_compat_v1, + ) + + transform_fn, cache_output = ( + input_data_pcoll_dict, + pcoll_cache_dict, + input_metadata, + ) | "Analyze" >> tft_beam.AnalyzeDatasetWithCache(preprocessing_fn) + _ = ( + cache_output + | "WriteCache" + >> analyzer_cache.WriteAnalysisCacheToFS(p, self._cache_dir) + ) + + # Transforms the requested datasets. + if datasets_to_transform is None: + transformed_dataset = None + else: + flattened_transform_data = [ + input_data_pcoll_dict[d] for d in datasets_to_transform + ] | "FlattenTransformData" >> beam.Flatten() + transformed_dataset = ( + (flattened_transform_data, input_metadata), + transform_fn, + ) | "Transform" >> tft_beam.TransformDataset() + + # Validate the transformed data is as expected. This requires providing + # datasets_to_transform. 
+ if expected_transform_data is not None: + assert transformed_dataset is not None + transformed_data, unused_transformed_metadata = transformed_dataset + beam_test_util.assert_that( + transformed_data, + beam_test_util.equal_to(expected_transform_data), + ) + + if expected_cache is not None: + for dataset in expected_cache: + cache_dict = cache_output[dataset].cache_dict + self.assertCountEqual( + cache_dict.keys(), expected_cache[dataset].keys() + ) + beam_test_util.assert_that( + cache_output[dataset].metadata, + beam_test_util.is_not_empty(), + label=f"AssertCacheMetadata[{dataset}]", + ) + for idx, (key, value) in enumerate( + expected_cache[dataset].items() + ): + beam_test_util.assert_that( + cache_dict[key], + beam_test_util.equal_to(value), + label=f"AssertCache[{dataset}][{idx}]", + ) + + # Write transform_fn if provided with an output directory. + tft_output = None + if transform_fn_output_dir is not None: + if not transform_fn_output_dir: + transform_fn_output_dir = os.path.join( + self.base_test_dir, uuid.uuid4().hex + ) + _ = transform_fn | tft_beam.WriteTransformFn( + transform_fn_output_dir + ) + tft_output = tft.TFTransformOutput(transform_fn_output_dir) + + return _RunPipelineResult(cache_output, p.metrics, tft_output) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + @mock_out_cache_hash + def test_single_phase_mixed_analyzer_run_once(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + + def preprocessing_fn(inputs): + _ = tft.bucketize(inputs["x"], 2, name="bucketize") + + return { + "integerized_s": tft.compute_and_apply_vocabulary(inputs["s"]), + "x_min": tft.min(inputs["x"], name="x") + tf.zeros_like(inputs["x"]), + "x_mean": tft.mean(inputs["x"], name="x") + tf.zeros_like(inputs["x"]), + "y_min": tft.min(inputs["y"], name="y") + tf.zeros_like(inputs["y"]), + "y_mean": tft.mean(inputs["y"], name="y") + tf.zeros_like(inputs["y"]), + } + + # Run AnalyzeAndTransform on some input data and compare with expected + # output. 
+ input_data = [{"x": 12, "y": 1, "s": "d"}, {"x": 10, "y": 1, "s": "c"}] + feature_spec = { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "s": tf.io.FixedLenFeature([], tf.string), + } + input_data_dict = { + span_0_key: [ + { + "x": -2, + "y": 1, + "s": "b", + }, + { + "x": 4, + "y": -4, + "s": "b", + }, + ], + span_1_key: input_data, + } - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - def test_single_phase_run_twice(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - span_2_key = analyzer_cache.DatasetKey('span-2') - - def preprocessing_fn(inputs): - - _ = tft.vocabulary(inputs['s'], vocab_filename='vocab1') - - _ = tft.bucketize(inputs['x'], 2, name='bucketize') - - y_cov = tft.covariance(tf.expand_dims(inputs['y'], axis=1), tf.float32) - return { - 'x_min': - tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']), - 'x_mean': - tft.mean(inputs['x'], name='x') + tf.zeros_like(inputs['x']), - 'y_min': - tft.min(inputs['y'], name='y') + tf.zeros_like(inputs['y']), - 'y_mean': - tft.mean(inputs['y'], name='y') + tf.zeros_like(inputs['y']), - 'y_cov': - tf.math.reduce_sum(y_cov) + tf.zeros_like(inputs['y']), - 's_integerized': - tft.compute_and_apply_vocabulary( - inputs['s'], - labels=inputs['label'], - use_adjusted_mutual_info=True), - } - - feature_spec = { - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 's': tf.io.FixedLenFeature([], tf.string), - 'label': tf.io.FixedLenFeature([], tf.int64), - } - input_data_dict = { - span_0_key: [], - span_1_key: [{ - 'x': -2, - 'y': 1, - 's': 'a', - 'label': 0, - }, { - 'x': 4, - 'y': -4, - 's': 'a', - 'label': 1, - }, { - 'x': 5, - 'y': 11, - 's': 'a', - 'label': 1, - }, { - 'x': 1, - 'y': -4, - 's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), - 'label': 1, - }], - span_2_key: [{ - 'x': 12, - 'y': 1, - 's': u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), - 'label': 0 - }, { - 'x': 10, - 'y': 1, - 's': 'c', - 'label': 1 - }], - } - expected_vocabulary_contents = np.array( - [b'a', u'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), b'c'], dtype=object) - - expected_transformed_data = [ - { - 'x_mean': 5.0, - 'x_min': -2.0, - 'y_mean': 1.0, - 'y_min': -4.0, - 'y_cov': 25.0, - 's_integerized': 0, - }, - { - 'x_mean': 5.0, - 'x_min': -2.0, - 'y_mean': 1.0, - 'y_min': -4.0, - 'y_cov': 25.0, - 's_integerized': 2, - }, - ] + span_0_size = 42 + cache_dict = { + span_0_key: analyzer_cache.DatasetCache( + { + _make_cache_key(b"CacheableCombineAccumulate[x_1#mean_and_var]"): [ + b"[2.0, 1.0, 9.0, 0.0]" + ], + _make_cache_key(b"CacheableCombineAccumulate[x#x]"): [ + b"[2.0, 4.0]" + ], + _make_cache_key(b"CacheableCombineAccumulate[y_1#mean_and_var]"): [ + b"[2.0, -1.5, 6.25, 0.0]" + ], + _make_cache_key(b"CacheableCombineAccumulate[y#y]"): [ + b"[4.0, 1.0]" + ], + }, + analyzer_cache.DatasetCacheMetadata(span_0_size), + ), + span_1_key: analyzer_cache.DatasetCache({}, None), + } - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - datasets_to_transform=[span_2_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=use_tf_compat_v1) + expected_transformed = [ + { + "x_mean": 6.0, + "x_min": -2.0, + "y_mean": -0.25, + "y_min": -4.0, + "integerized_s": 1, + }, + { + "x_mean": 6.0, + "x_min": -2.0, + "y_mean": -0.25, + "y_min": -4.0, + "integerized_s": 2, + }, + ] - for key 
in input_data_dict: - self.assertIn(key, first_run_result.cache_output) - self.assertEqual(8, len(first_run_result.cache_output[key].cache_dict)) + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + cache_dict=cache_dict, + datasets_to_transform=[span_1_key], + expected_transform_data=expected_transformed, + transform_fn_output_dir=os.path.join(self.base_test_dir, "transform_fn"), + use_tf_compat_v1=use_tf_compat_v1, + ) - vocab1_path = first_run_result.transform_output.vocabulary_file_by_name( - 'vocab1' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - - metrics = first_run_result.metrics - # 6 from analyzing 3 spans, and 2 from transform. - self.assertMetricsCounterEqual(metrics, 'num_instances', 8) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - # 8 entries for each of 3 spans. Note that default values for the empty span - # are also encoded. - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 24) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + # The output cache should not have entries for the cache that is present + # in the input cache. + self.assertEqual( + len(run_result.cache_output[span_0_key].cache_dict), + len(run_result.cache_output[span_1_key].cache_dict) - 4, + ) - second_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - datasets_to_transform=[span_2_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=use_tf_compat_v1) + metrics = run_result.metrics + # 4 from analyzing 2 spans, and 2 from transform. + self.assertMetricsCounterEqual(metrics, "num_instances", 6) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 4) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 8) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "num_packed_accumulate_combiners", 1) + self.assertMetricsCounterEqual(metrics, "num_packed_merge_combiners", 1) + # All datasets were processed even though some of the analyzers were covered + # by cache. 
+ self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + def test_single_phase_run_twice(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + span_2_key = analyzer_cache.DatasetKey("span-2") + + def preprocessing_fn(inputs): + _ = tft.vocabulary(inputs["s"], vocab_filename="vocab1") + + _ = tft.bucketize(inputs["x"], 2, name="bucketize") + + y_cov = tft.covariance(tf.expand_dims(inputs["y"], axis=1), tf.float32) + return { + "x_min": tft.min(inputs["x"], name="x") + tf.zeros_like(inputs["x"]), + "x_mean": tft.mean(inputs["x"], name="x") + tf.zeros_like(inputs["x"]), + "y_min": tft.min(inputs["y"], name="y") + tf.zeros_like(inputs["y"]), + "y_mean": tft.mean(inputs["y"], name="y") + tf.zeros_like(inputs["y"]), + "y_cov": tf.math.reduce_sum(y_cov) + tf.zeros_like(inputs["y"]), + "s_integerized": tft.compute_and_apply_vocabulary( + inputs["s"], labels=inputs["label"], use_adjusted_mutual_info=True + ), + } + + feature_spec = { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "s": tf.io.FixedLenFeature([], tf.string), + "label": tf.io.FixedLenFeature([], tf.int64), + } + input_data_dict = { + span_0_key: [], + span_1_key: [ + { + "x": -2, + "y": 1, + "s": "a", + "label": 0, + }, + { + "x": 4, + "y": -4, + "s": "a", + "label": 1, + }, + { + "x": 5, + "y": 11, + "s": "a", + "label": 1, + }, + { + "x": 1, + "y": -4, + "s": "ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴".encode(), + "label": 1, + }, + ], + span_2_key: [ + {"x": 12, "y": 1, "s": "ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴".encode(), "label": 0}, + {"x": 10, "y": 1, "s": "c", "label": 1}, + ], + } + expected_vocabulary_contents = np.array( + [b"a", "ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴".encode(), b"c"], dtype=object + ) - vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( - 'vocab1' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + expected_transformed_data = [ + { + "x_mean": 5.0, + "x_min": -2.0, + "y_mean": 1.0, + "y_min": -4.0, + "y_cov": 25.0, + "s_integerized": 0, + }, + { + "x_mean": 5.0, + "x_min": -2.0, + "y_mean": 1.0, + "y_min": -4.0, + "y_cov": 25.0, + "s_integerized": 2, + }, + ] + + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + datasets_to_transform=[span_2_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=use_tf_compat_v1, + ) - self.assertFalse(second_run_result.cache_output) + for key in input_data_dict: + self.assertIn(key, first_run_result.cache_output) + self.assertEqual(8, len(first_run_result.cache_output[key].cache_dict)) - metrics = second_run_result.metrics - # Only 2 from transform. - self.assertMetricsCounterEqual(metrics, 'num_instances', 2) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 24) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) + vocab1_path = first_run_result.transform_output.vocabulary_file_by_name( + "vocab1" + ) + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + + metrics = first_run_result.metrics + # 6 from analyzing 3 spans, and 2 from transform. + self.assertMetricsCounterEqual(metrics, "num_instances", 8) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + # 8 entries for each of 3 spans. Note that default values for the empty span + # are also encoded. 
+ self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 24) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - # The root CreateSavedModel is optimized away because the data doesn't get - # processed at all (only cache). - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _ZERO_PHASE_NUM_SAVED_MODELS - ) - # Cache coverage allowed us to avoid processing this many bytes of data. - self.assertMetricsCounterGreater( - metrics, 'analysis_input_bytes_from_cache', 413 - ) + second_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + datasets_to_transform=[span_2_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=use_tf_compat_v1, + ) + + vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( + "vocab1" + ) + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + + self.assertFalse(second_run_result.cache_output) + + metrics = second_run_result.metrics + # Only 2 from transform. + self.assertMetricsCounterEqual(metrics, "num_instances", 2) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 24) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - @mock_out_cache_hash - def test_caching_vocab_for_integer_categorical(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - - def preprocessing_fn(inputs): - return { - 'x_vocab': - tft.compute_and_apply_vocabulary( - inputs['x'], frequency_threshold=2) - } - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)} - input_data_dict = { - span_0_key: [{'x': -2}, {'x': -4}, {'x': -1}, {'x': 4}], - span_1_key: [{'x': -2}, {'x': -1}, {'x': 6}, {'x': 7}], - } # pyformat: disable - expected_transformed_data = [ - {'x_vocab': 0}, {'x_vocab': 1}, {'x_vocab': -1}, {'x_vocab': -1} - ] # pyformat: disable - - dataset_size = 17 - cache_dict = { - span_0_key: - analyzer_cache.DatasetCache( + # The root CreateSavedModel is optimized away because the data doesn't get + # processed at all (only cache). + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _ZERO_PHASE_NUM_SAVED_MODELS + ) + # Cache coverage allowed us to avoid processing this many bytes of data. 
+ self.assertMetricsCounterGreater( + metrics, "analysis_input_bytes_from_cache", 413 + ) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + @mock_out_cache_hash + def test_caching_vocab_for_integer_categorical(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + + def preprocessing_fn(inputs): + return { + "x_vocab": tft.compute_and_apply_vocabulary( + inputs["x"], frequency_threshold=2 + ) + } + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.int64)} + input_data_dict = { + span_0_key: [{"x": -2}, {"x": -4}, {"x": -1}, {"x": 4}], + span_1_key: [{"x": -2}, {"x": -1}, {"x": 6}, {"x": 7}], + } # pyformat: disable + expected_transformed_data = [ + {"x_vocab": 0}, + {"x_vocab": 1}, + {"x_vocab": -1}, + {"x_vocab": -1}, + ] # pyformat: disable + + dataset_size = 17 + cache_dict = { + span_0_key: analyzer_cache.DatasetCache( { _make_cache_key( - b'VocabularyAccumulate[compute_and_apply_vocabulary#vocabulary]' + b"VocabularyAccumulate[compute_and_apply_vocabulary#vocabulary]" ): [ - _encode_vocabulary_accumulator(b'-2', b'2'), - _encode_vocabulary_accumulator(b'-4', b'1'), - _encode_vocabulary_accumulator(b'-1', b'1'), - _encode_vocabulary_accumulator(b'4', b'1'), + _encode_vocabulary_accumulator(b"-2", b"2"), + _encode_vocabulary_accumulator(b"-4", b"1"), + _encode_vocabulary_accumulator(b"-1", b"1"), + _encode_vocabulary_accumulator(b"4", b"1"), ] }, - analyzer_cache.DatasetCacheMetadata(dataset_size=dataset_size)), - span_1_key: - analyzer_cache.DatasetCache({}, None), - } + analyzer_cache.DatasetCacheMetadata(dataset_size=dataset_size), + ), + span_1_key: analyzer_cache.DatasetCache({}, None), + } - run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - cache_dict=cache_dict, - datasets_to_transform=[span_1_key], - expected_transform_data=expected_transformed_data, - transform_fn_output_dir=None, - use_tf_compat_v1=use_tf_compat_v1, - ) + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + cache_dict=cache_dict, + datasets_to_transform=[span_1_key], + expected_transform_data=expected_transformed_data, + transform_fn_output_dir=None, + use_tf_compat_v1=use_tf_compat_v1, + ) - self.assertNotIn(span_0_key, run_result.cache_output) + self.assertNotIn(span_0_key, run_result.cache_output) - metrics = run_result.metrics - # 4 from analysis since 1 span was completely cached, and 4 from transform. - self.assertMetricsCounterEqual(metrics, 'num_instances', 8) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 1) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 1) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - # Cache coverage allowed us to avoid processing this many bytes of data. - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', dataset_size - ) + metrics = run_result.metrics + # 4 from analysis since 1 span was completely cached, and 4 from transform. + self.assertMetricsCounterEqual(metrics, "num_instances", 8) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 1) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 1) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + # Cache coverage allowed us to avoid processing this many bytes of data. 
+ self.assertMetricsCounterEqual( + metrics, "analysis_input_bytes_from_cache", dataset_size + ) - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - @mock_out_cache_hash - def test_non_frequency_vocabulary_merge(self, use_tf_compat_v1): - """This test compares vocabularies produced with and without cache.""" - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - mi_vocab_name = 'mutual_information_vocab' - adjusted_mi_vocab_name = 'adjusted_mutual_information_vocab' - weighted_frequency_vocab_name = 'weighted_frequency_vocab' - - def preprocessing_fn(inputs): - _ = tft.vocabulary( - inputs['s'], - labels=inputs['label'], - store_frequency=True, - vocab_filename=mi_vocab_name, - min_diff_from_avg=0.1, - use_adjusted_mutual_info=False, - name='with_mi') - - _ = tft.vocabulary( - inputs['s'], - labels=inputs['label'], - store_frequency=True, - vocab_filename=adjusted_mi_vocab_name, - min_diff_from_avg=1.0, - use_adjusted_mutual_info=True, - name='with_adjusted_mi') - - _ = tft.vocabulary( - inputs['s'], - weights=inputs['weight'], - store_frequency=True, - vocab_filename=weighted_frequency_vocab_name, - use_adjusted_mutual_info=False, - name='with_weight') - return inputs - - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - - input_data = [ - dict(s='a', weight=1, label=1), - dict(s='a', weight=0.5, label=1), - dict(s='b', weight=0.75, label=1), - dict(s='b', weight=1, label=0), - ] - feature_spec = { - 's': tf.io.FixedLenFeature([], tf.string), - 'label': tf.io.FixedLenFeature([], tf.int64), - 'weight': tf.io.FixedLenFeature([], tf.float32), - } - input_data_dict = { - span_0_key: input_data, - span_1_key: input_data, - } - transform_fn_with_cache_dir = os.path.join(self.base_test_dir, - 'transform_fn_with_cache') - - expected_accumulators = { - _make_cache_key(b'VocabularyAccumulate[with_mi]'): [ - _encode_vocabulary_accumulator(b'a', - b'[2, [0.0, 1.0], [0.0, 0.0], 1.0]'), - _encode_vocabulary_accumulator(b'b', - b'[2, [0.5, 0.5], [0.0, 0.0], 1.0]'), - _encode_vocabulary_accumulator( - b'global_y_count_sentinel', - b'[4, [0.25, 0.75], [0.0, 0.0], 1.0]'), - ], - _make_cache_key(b'VocabularyAccumulate[with_adjusted_mi]'): [ - _encode_vocabulary_accumulator(b'a', - b'[2, [0.0, 1.0], [0.0, 0.0], 1.0]'), - _encode_vocabulary_accumulator(b'b', - b'[2, [0.5, 0.5], [0.0, 0.0], 1.0]'), - _encode_vocabulary_accumulator( - b'global_y_count_sentinel', - b'[4, [0.25, 0.75], [0.0, 0.0], 1.0]'), - ], - _make_cache_key(b'VocabularyAccumulate[with_weight]'): [ - _encode_vocabulary_accumulator(b'a', b'1.5'), - _encode_vocabulary_accumulator(b'b', b'1.75') + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + @mock_out_cache_hash + def test_non_frequency_vocabulary_merge(self, use_tf_compat_v1): + """This test compares vocabularies produced with and without cache.""" + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + mi_vocab_name = "mutual_information_vocab" + adjusted_mi_vocab_name = "adjusted_mutual_information_vocab" + weighted_frequency_vocab_name = "weighted_frequency_vocab" + + def preprocessing_fn(inputs): + _ = tft.vocabulary( + inputs["s"], + labels=inputs["label"], + store_frequency=True, + vocab_filename=mi_vocab_name, + min_diff_from_avg=0.1, + use_adjusted_mutual_info=False, + name="with_mi", + ) + + _ = tft.vocabulary( + inputs["s"], + labels=inputs["label"], + store_frequency=True, + vocab_filename=adjusted_mi_vocab_name, + min_diff_from_avg=1.0, + 
use_adjusted_mutual_info=True, + name="with_adjusted_mi", + ) + + _ = tft.vocabulary( + inputs["s"], + weights=inputs["weight"], + store_frequency=True, + vocab_filename=weighted_frequency_vocab_name, + use_adjusted_mutual_info=False, + name="with_weight", + ) + return inputs + + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + + input_data = [ + dict(s="a", weight=1, label=1), + dict(s="a", weight=0.5, label=1), + dict(s="b", weight=0.75, label=1), + dict(s="b", weight=1, label=0), + ] + feature_spec = { + "s": tf.io.FixedLenFeature([], tf.string), + "label": tf.io.FixedLenFeature([], tf.int64), + "weight": tf.io.FixedLenFeature([], tf.float32), + } + input_data_dict = { + span_0_key: input_data, + span_1_key: input_data, + } + transform_fn_with_cache_dir = os.path.join( + self.base_test_dir, "transform_fn_with_cache" + ) + + expected_accumulators = { + _make_cache_key(b"VocabularyAccumulate[with_mi]"): [ + _encode_vocabulary_accumulator( + b"a", b"[2, [0.0, 1.0], [0.0, 0.0], 1.0]" + ), + _encode_vocabulary_accumulator( + b"b", b"[2, [0.5, 0.5], [0.0, 0.0], 1.0]" + ), + _encode_vocabulary_accumulator( + b"global_y_count_sentinel", b"[4, [0.25, 0.75], [0.0, 0.0], 1.0]" + ), + ], + _make_cache_key(b"VocabularyAccumulate[with_adjusted_mi]"): [ + _encode_vocabulary_accumulator( + b"a", b"[2, [0.0, 1.0], [0.0, 0.0], 1.0]" + ), + _encode_vocabulary_accumulator( + b"b", b"[2, [0.5, 0.5], [0.0, 0.0], 1.0]" + ), + _encode_vocabulary_accumulator( + b"global_y_count_sentinel", b"[4, [0.25, 0.75], [0.0, 0.0], 1.0]" + ), + ], + _make_cache_key(b"VocabularyAccumulate[with_weight]"): [ + _encode_vocabulary_accumulator(b"a", b"1.5"), + _encode_vocabulary_accumulator(b"b", b"1.75"), + ], + } + expected_cache = { + span: expected_accumulators for span in [span_0_key, span_1_key] + } + + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + transform_fn_output_dir=transform_fn_with_cache_dir, + expected_cache=expected_cache, + use_tf_compat_v1=use_tf_compat_v1, + ) + + metrics = run_result.metrics + # 4 from analysis on each of the input spans. + self.assertMetricsCounterEqual(metrics, "num_instances", 8) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 6) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) + + with self._TestPipeline() as p: + with tft_beam.Context(): + flat_data = p | "CreateInputData" >> beam.Create(input_data * 2) + + input_metadata = dataset_metadata.DatasetMetadata.from_feature_spec( + feature_spec + ) + transform_fn_no_cache = ( + flat_data, + input_metadata, + ) | tft_beam.AnalyzeDataset(preprocessing_fn) + + transform_fn_no_cache_dir = os.path.join( + self.base_test_dir, "transform_fn_no_cache" + ) + _ = transform_fn_no_cache | tft_beam.WriteTransformFn( + transform_fn_no_cache_dir + ) + + # 4 from analysis on each of the input spans. 
+ self.assertMetricsCounterEqual(p.metrics, "num_instances", 8) + self.assertMetricsCounterEqual(p.metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(p.metrics, "cache_entries_encoded", 0) + self.assertMetricsCounterEqual( + p.metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(p.metrics, "analysis_input_bytes_from_cache", 0) + + tft_output_cache = tft.TFTransformOutput(transform_fn_with_cache_dir) + tft_output_no_cache = tft.TFTransformOutput(transform_fn_no_cache_dir) + + for vocab_filename in ( + mi_vocab_name, + adjusted_mi_vocab_name, + weighted_frequency_vocab_name, + ): + cache_path = tft_output_cache.vocabulary_file_by_name(vocab_filename) + no_cache_path = tft_output_no_cache.vocabulary_file_by_name(vocab_filename) + with tf.io.gfile.GFile(cache_path, "rb") as f1, tf.io.gfile.GFile( + no_cache_path, "rb" + ) as f2: + self.assertEqual( + f1.readlines(), + f2.readlines(), + f"vocab with cache != vocab without cache for: {vocab_filename}", + ) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + @mock_out_cache_hash + def test_cached_ptransform_analyzer(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + + class _AnalyzerMakeAccumulators(beam.PTransform): + def expand(self, pcoll): + input_sum = ( + pcoll | beam.FlatMap(sum) | "ReduceSum" >> beam.CombineGlobally(sum) + ) + size = ( + pcoll + | beam.Map(np.size) + | "ReduceCount" >> beam.CombineGlobally(sum) + ) + + return ( + pcoll.pipeline + | beam.Create([None]) + | beam.Map( + lambda _, a, b: (a, b), # pyformat: disable + beam.pvalue.AsSingleton(input_sum), + beam.pvalue.AsSingleton(size), + ) + ) + + class _AnalyzerMergeAccumulators(beam.PTransform): + def expand(self, pcoll): + def merge(x): + zipped = list(zip(*x)) + assert len(zipped) == 2, zipped + return sum(zipped[0]), sum(zipped[1]) + + return pcoll | beam.CombineGlobally(merge) + + class _AnalyzerExtractOutput(beam.PTransform): + def expand(self, pcoll): + return pcoll | beam.Map(lambda p: p[0] / p[1]) + + analyzer = tft.experimental.CacheablePTransformAnalyzer( + make_accumulators_ptransform=_AnalyzerMakeAccumulators(), + merge_accumulators_ptransform=_AnalyzerMergeAccumulators(), + extract_output_ptransform=_AnalyzerExtractOutput(), + cache_coder=tft.experimental.SimpleJsonPTransformAnalyzerCacheCoder(), + ) + + def preprocessing_fn(inputs): + y = tft.experimental.ptransform_analyzer( + [inputs["x"]], analyzer, [tf.int64], [[]] + ) + return {"y": tf.zeros_like(inputs["x"]) + y} + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.int64)} + span_0_key = analyzer_cache.DatasetKey("span-0") + input_data_dict = {span_0_key: [{"x": x} for x in range(7)]} + expected_cache_dict = { + span_0_key: { + _make_cache_key(b"PTransform[ptransform#local_merge_accumulators]"): [ + b"[21, 7]" + ], + }, + } + expected_transformed_data = [{"y": 3} for _ in range(7)] + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + datasets_to_transform=[span_0_key], + expected_transform_data=expected_transformed_data, + expected_cache=expected_cache_dict, + use_tf_compat_v1=use_tf_compat_v1, + ) + metrics = first_run_result.metrics + # Incremented for both analysis and transform (7 * 2). 
+ self.assertMetricsCounterEqual(metrics, "num_instances", 14) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 1) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) + + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + datasets_to_transform=[span_0_key], + expected_transform_data=expected_transformed_data, + expected_cache={}, + use_tf_compat_v1=use_tf_compat_v1, + ) + metrics = first_run_result.metrics + # This time analysis is skipped due to cache, only transform dataset counts. + self.assertMetricsCounterEqual(metrics, "num_instances", 7) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 1) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) + self.assertMetricsCounterGreater( + metrics, "analysis_input_bytes_from_cache", 100 + ) + + @tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES) + @mock_out_cache_hash + def test_optimize_traversal( + self, + feature_spec: Mapping[str, common_types.FeatureSpecType], + preprocessing_fn: Callable[ + [Mapping[str, common_types.TensorType]], + Mapping[str, common_types.TensorType], ], - } - expected_cache = { - span: expected_accumulators for span in [span_0_key, span_1_key] - } + dataset_input_cache_dicts: List[Mapping[str, str]], + expected_dot_graph_str: str, + ): + dataset_keys = [ + analyzer_cache.DatasetKey("span-0"), + analyzer_cache.DatasetKey("span-1"), + ] + if dataset_input_cache_dicts is not None: + cache = { + key: analyzer_cache.DatasetCache( + cache_dict, analyzer_cache.DatasetCacheMetadata(1) + ) + for key, cache_dict in zip(dataset_keys, dataset_input_cache_dicts) + } # pyformat: disable + else: + cache = {} + dot_string = self._publish_rendered_dot_graph_file( + preprocessing_fn, feature_spec, set(dataset_keys), cache + ) - run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - transform_fn_output_dir=transform_fn_with_cache_dir, - expected_cache=expected_cache, - use_tf_compat_v1=use_tf_compat_v1) - - metrics = run_result.metrics - # 4 from analysis on each of the input spans. - self.assertMetricsCounterEqual(metrics, 'num_instances', 8) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 6) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + self.assertSameElements( + expected_dot_graph_str.split("\n"), + dot_string.split("\n"), + msg=f"Result dot graph is:\n{dot_string}", + ) - with self._TestPipeline() as p: - with tft_beam.Context(): - flat_data = p | 'CreateInputData' >> beam.Create(input_data * 2) + def test_no_data_needed(self): + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + + def preprocessing_fn(inputs): + return {k: tf.identity(v) for k, v in inputs.items()} input_metadata = dataset_metadata.DatasetMetadata.from_feature_spec( - feature_spec) - transform_fn_no_cache = ((flat_data, input_metadata) - | tft_beam.AnalyzeDataset(preprocessing_fn)) - - transform_fn_no_cache_dir = os.path.join(self.base_test_dir, - 'transform_fn_no_cache') - _ = transform_fn_no_cache | tft_beam.WriteTransformFn( - transform_fn_no_cache_dir) - - # 4 from analysis on each of the input spans. 
- self.assertMetricsCounterEqual(p.metrics, 'num_instances', 8) - self.assertMetricsCounterEqual(p.metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(p.metrics, 'cache_entries_encoded', 0) - self.assertMetricsCounterEqual(p.metrics, 'saved_models_created', - _SINGLE_PHASE_NUM_SAVED_MODELS) - self.assertMetricsCounterEqual(p.metrics, 'analysis_input_bytes_from_cache', - 0) - - tft_output_cache = tft.TFTransformOutput(transform_fn_with_cache_dir) - tft_output_no_cache = tft.TFTransformOutput(transform_fn_no_cache_dir) - - for vocab_filename in (mi_vocab_name, adjusted_mi_vocab_name, - weighted_frequency_vocab_name): - cache_path = tft_output_cache.vocabulary_file_by_name(vocab_filename) - no_cache_path = tft_output_no_cache.vocabulary_file_by_name( - vocab_filename) - with tf.io.gfile.GFile(cache_path, 'rb') as f1, tf.io.gfile.GFile( - no_cache_path, 'rb') as f2: - self.assertEqual( - f1.readlines(), f2.readlines(), - 'vocab with cache != vocab without cache for: {}'.format( - vocab_filename)) - - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - @mock_out_cache_hash - def test_cached_ptransform_analyzer(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - - class _AnalyzerMakeAccumulators(beam.PTransform): - - def expand(self, pcoll): - input_sum = pcoll | beam.FlatMap( - sum) | 'ReduceSum' >> beam.CombineGlobally(sum) - size = pcoll | beam.Map( - np.size) | 'ReduceCount' >> beam.CombineGlobally(sum) - - return (pcoll.pipeline - | beam.Create([None]) - | beam.Map( - lambda _, a, b: (a, b), # pyformat: disable - beam.pvalue.AsSingleton(input_sum), - beam.pvalue.AsSingleton(size))) - - class _AnalyzerMergeAccumulators(beam.PTransform): - - def expand(self, pcoll): - - def merge(x): - zipped = list(zip(*x)) - assert len(zipped) == 2, zipped - return sum(zipped[0]), sum(zipped[1]) - - return pcoll | beam.CombineGlobally(merge) - - class _AnalyzerExtractOutput(beam.PTransform): - - def expand(self, pcoll): - - return pcoll | beam.Map(lambda p: p[0] / p[1]) - - analyzer = tft.experimental.CacheablePTransformAnalyzer( - make_accumulators_ptransform=_AnalyzerMakeAccumulators(), - merge_accumulators_ptransform=_AnalyzerMergeAccumulators(), - extract_output_ptransform=_AnalyzerExtractOutput(), - cache_coder=tft.experimental.SimpleJsonPTransformAnalyzerCacheCoder()) - - def preprocessing_fn(inputs): - y = tft.experimental.ptransform_analyzer([inputs['x']], analyzer, - [tf.int64], [[]]) - return {'y': tf.zeros_like(inputs['x']) + y} - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)} - span_0_key = analyzer_cache.DatasetKey('span-0') - input_data_dict = {span_0_key: [{'x': x} for x in range(7)]} - expected_cache_dict = { - span_0_key: { - _make_cache_key(b'PTransform[ptransform#local_merge_accumulators]'): - [b'[21, 7]'], - }, - } - expected_transformed_data = [{'y': 3} for _ in range(7)] - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - datasets_to_transform=[span_0_key], - expected_transform_data=expected_transformed_data, - expected_cache=expected_cache_dict, - use_tf_compat_v1=use_tf_compat_v1) - metrics = first_run_result.metrics - # Incremented for both analysis and transform (7 * 2). 
- self.assertMetricsCounterEqual(metrics, 'num_instances', 14) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 1) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + { + "x": tf.io.FixedLenFeature([], tf.float32), + } + ) + input_data_dict = { + span_0_key: None, + span_1_key: None, + } + + with self._TestPipeline() as p: + cache_dict = { + span_0_key: {}, + span_1_key: {}, + } + + _, output_cache = ( + input_data_dict, + cache_dict, + input_metadata, + ) | "Analyze" >> tft_beam.AnalyzeDatasetWithCache( + preprocessing_fn, pipeline=p + ) + self.assertFalse(output_cache) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + def test_tf_function_works_with_cache(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + + def preprocessing_fn(inputs, should_add_one): + @tf.function + def identity(x): + if should_add_one: + x = x + 1 + return x + + return { + "x_mean": tft.mean(identity(inputs["x"]), name="x") + + tf.zeros_like(inputs["x"]) + } + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)} + input_data_dict = {analyzer_cache.DatasetKey("span-0"): [dict(x=-2), dict(x=4)]} + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + functools.partial(preprocessing_fn, should_add_one=False), + transform_fn_output_dir=None, + use_tf_compat_v1=use_tf_compat_v1, + ) + first_cache_output, metrics = run_result.cache_output, run_result.metrics - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - datasets_to_transform=[span_0_key], - expected_transform_data=expected_transformed_data, - expected_cache={}, - use_tf_compat_v1=use_tf_compat_v1) - metrics = first_run_result.metrics - # This time analysis is skipped due to cache, only transform dataset counts. 
- self.assertMetricsCounterEqual(metrics, 'num_instances', 7) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 1) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) - self.assertMetricsCounterGreater( - metrics, 'analysis_input_bytes_from_cache', 100 - ) + for key in input_data_dict: + self.assertIn(key, first_cache_output) + self.assertEqual(1, len(first_cache_output[key].cache_dict)) - @tft_unit.named_parameters(*_OPTIMIZE_TRAVERSAL_TEST_CASES) - @mock_out_cache_hash - def test_optimize_traversal( - self, feature_spec: Mapping[str, common_types.FeatureSpecType], - preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], - Mapping[str, common_types.TensorType]], - dataset_input_cache_dicts: List[Mapping[str, str]], - expected_dot_graph_str: str): - dataset_keys = [ - analyzer_cache.DatasetKey('span-0'), - analyzer_cache.DatasetKey('span-1') - ] - if dataset_input_cache_dicts is not None: - cache = { - key: analyzer_cache.DatasetCache( - cache_dict, analyzer_cache.DatasetCacheMetadata(1)) - for key, cache_dict in zip(dataset_keys, dataset_input_cache_dicts) - } # pyformat: disable - else: - cache = {} - dot_string = self._publish_rendered_dot_graph_file(preprocessing_fn, - feature_spec, - set(dataset_keys), cache) - - self.assertSameElements( - expected_dot_graph_str.split('\n'), - dot_string.split('\n'), - msg='Result dot graph is:\n{}'.format(dot_string)) - - def test_no_data_needed(self): - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - - def preprocessing_fn(inputs): - return {k: tf.identity(v) for k, v in inputs.items()} - - input_metadata = dataset_metadata.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - }) - input_data_dict = { - span_0_key: None, - span_1_key: None, - } - - with self._TestPipeline() as p: - cache_dict = { - span_0_key: {}, - span_1_key: {}, - } - - _, output_cache = ((input_data_dict, cache_dict, input_metadata) - | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache( - preprocessing_fn, pipeline=p)) - self.assertFalse(output_cache) - - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - def test_tf_function_works_with_cache(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - - def preprocessing_fn(inputs, should_add_one): - - @tf.function - def identity(x): - if should_add_one: - x = x + 1 - return x - - return { - 'x_mean': - tft.mean(identity(inputs['x']), name='x') + - tf.zeros_like(inputs['x']) - } - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - input_data_dict = { - analyzer_cache.DatasetKey('span-0'): [dict(x=-2), dict(x=4)] - } - run_result = self._run_pipeline( - feature_spec, - input_data_dict, - functools.partial(preprocessing_fn, should_add_one=False), - transform_fn_output_dir=None, - use_tf_compat_v1=use_tf_compat_v1, - ) - first_cache_output, metrics = run_result.cache_output, run_result.metrics + self.assertMetricsCounterEqual(metrics, "num_instances", 2) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 1) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - for key in input_data_dict: - self.assertIn(key, first_cache_output) - self.assertEqual(1, len(first_cache_output[key].cache_dict)) + # Cache is still valid since the 
contents of the tf.function are the same. + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + functools.partial(preprocessing_fn, should_add_one=False), + should_read_cache=True, + transform_fn_output_dir=None, + use_tf_compat_v1=use_tf_compat_v1, + ) + second_cache_output, metrics = run_result.cache_output, run_result.metrics - self.assertMetricsCounterEqual(metrics, 'num_instances', 2) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 1) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + self.assertFalse(second_cache_output) - # Cache is still valid since the contents of the tf.function are the same. - run_result = self._run_pipeline( - feature_spec, - input_data_dict, - functools.partial(preprocessing_fn, should_add_one=False), - should_read_cache=True, - transform_fn_output_dir=None, - use_tf_compat_v1=use_tf_compat_v1, - ) - second_cache_output, metrics = run_result.cache_output, run_result.metrics + self.assertMetricsCounterEqual(metrics, "num_instances", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 1) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _ZERO_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterGreater(metrics, "analysis_input_bytes_from_cache", 23) - self.assertFalse(second_cache_output) + # Modifying the tf.function contents causes cache invalidation. + run_result = self._run_pipeline( + feature_spec, + input_data_dict, + functools.partial(preprocessing_fn, should_add_one=True), + should_read_cache=True, + transform_fn_output_dir=None, + use_tf_compat_v1=use_tf_compat_v1, + ) + third_output_cache, metrics = run_result.cache_output, run_result.metrics - self.assertMetricsCounterEqual(metrics, 'num_instances', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 1) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _ZERO_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterGreater( - metrics, 'analysis_input_bytes_from_cache', 23 - ) + for key in input_data_dict: + self.assertIn(key, third_output_cache) + self.assertEqual(1, len(third_output_cache[key].cache_dict)) - # Modifying the tf.function contents causes cache invalidation. 
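# Descriptive note on the three runs in this test (inferred from the metric
# assertions, not new behavior): the first run traces the tf.function and
# encodes one cache entry; the second run re-traces an identical function
# body, so analysis is skipped (num_instances == 0) and the entry is decoded
# from cache; flipping should_add_one changes the traced graph, which
# invalidates the cache and forces the full re-analysis below.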
- run_result = self._run_pipeline( - feature_spec, - input_data_dict, - functools.partial(preprocessing_fn, should_add_one=True), - should_read_cache=True, - transform_fn_output_dir=None, - use_tf_compat_v1=use_tf_compat_v1, - ) - third_output_cache, metrics = run_result.cache_output, run_result.metrics + self.assertMetricsCounterEqual(metrics, "num_instances", 2) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 1) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - for key in input_data_dict: - self.assertIn(key, third_output_cache) - self.assertEqual(1, len(third_output_cache[key].cache_dict)) + def test_cache_with_missing_metadata(self): + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1") + span_2_key = analyzer_cache.DatasetKey("span-2") - self.assertMetricsCounterEqual(metrics, 'num_instances', 2) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 1) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + def preprocessing_fn(inputs): + return { + "x_min": tft.min(inputs["x"], name="x") + tf.zeros_like(inputs["x"]) + } - def test_cache_with_missing_metadata(self): - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1') - span_2_key = analyzer_cache.DatasetKey('span-2') + feature_spec = {"x": tf.io.FixedLenFeature([], tf.float32)} + input_data_dict = { + span_0_key: [], + span_1_key: [{"x": idx} for idx in range(4)], + span_2_key: [{"x": idx} for idx in range(2)], + } # pyformat: disable - def preprocessing_fn(inputs): - return { - 'x_min': tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']) - } + expected_transformed_data = [{"x_min": 0} for _ in range(2)] - feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - input_data_dict = { - span_0_key: [], - span_1_key: [{'x': idx} for idx in range(4)], - span_2_key: [{'x': idx} for idx in range(2)], - } # pyformat: disable + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + datasets_to_transform=[span_2_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=False, + ) - expected_transformed_data = [{'x_min': 0} for _ in range(2)] + # Deleting dataset cache metadata files. + for key in input_data_dict: + self.assertIn(key, first_run_result.cache_output) + dataset_cache_dir = os.path.join(self._cache_dir, key.key) + cache_metadata_files = tf.io.gfile.glob( + os.path.join( + dataset_cache_dir, f"{analyzer_cache._METADATA_FILE_NAME}*" + ) + ) + self.assertLen(cache_metadata_files, 1) + os.rename( + cache_metadata_files[0], + os.path.join( + dataset_cache_dir, f"deleted_{analyzer_cache._METADATA_FILE_NAME}" + ), + ) - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - datasets_to_transform=[span_2_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=False) - - # Deleting dataset cache metadata files. 
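# Descriptive note (inferred from the assertions that follow): renaming the
# metadata files away leaves the cache entries themselves intact, so the
# second run still decodes all three entries, but without the per-dataset
# metadata the pipeline cannot attribute byte sizes to the cached data and
# 'analysis_input_bytes_from_cache' is reported as 0.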
- for key in input_data_dict: - self.assertIn(key, first_run_result.cache_output) - dataset_cache_dir = os.path.join(self._cache_dir, key.key) - cache_metadata_files = tf.io.gfile.glob( - os.path.join(dataset_cache_dir, - f'{analyzer_cache._METADATA_FILE_NAME}*')) - self.assertLen(cache_metadata_files, 1) - os.rename( - cache_metadata_files[0], - os.path.join(dataset_cache_dir, - f'deleted_{analyzer_cache._METADATA_FILE_NAME}')) - - metrics = first_run_result.metrics - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 3) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + metrics = first_run_result.metrics + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 3) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - second_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - datasets_to_transform=[span_2_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=False) - - self.assertFalse(second_run_result.cache_output) - - metrics = second_run_result.metrics - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 3) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) - # Because the metadata associated with cached datasets was deleted, we - # report 0 for this counter. - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + second_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + datasets_to_transform=[span_2_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=False, + ) - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - def test_non_cached_dataset_run_twice(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - span_0_key = analyzer_cache.DatasetKey('span-0') - span_1_key = analyzer_cache.DatasetKey('span-1').non_cacheable() - - num_analyzers = 2 - - def preprocessing_fn(inputs): - return { - 'x_min': tft.min(inputs['x'], name='x') + tf.zeros_like(inputs['x']), - 's_integerized': tft.compute_and_apply_vocabulary( - inputs['s'], vocab_filename='vocab' - ), - } - - feature_spec = { - 'x': tf.io.FixedLenFeature([], tf.float32), - 's': tf.io.FixedLenFeature([], tf.string), - } - input_data_dict = { - span_0_key: [ - { - 'x': -2, - 's': 'a', - }, - { - 'x': 4, - 's': 'a', - }, - { - 'x': 5, - 's': 'a', - }, - { - 'x': 1, - 's': 'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), - }, - ], - span_1_key: [ + self.assertFalse(second_run_result.cache_output) + + metrics = second_run_result.metrics + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 3) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) + # Because the metadata associated with cached datasets was deleted, we + # report 0 for this counter. 
+ self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) + + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + def test_non_cached_dataset_run_twice(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + span_0_key = analyzer_cache.DatasetKey("span-0") + span_1_key = analyzer_cache.DatasetKey("span-1").non_cacheable() + + num_analyzers = 2 + + def preprocessing_fn(inputs): + return { + "x_min": tft.min(inputs["x"], name="x") + tf.zeros_like(inputs["x"]), + "s_integerized": tft.compute_and_apply_vocabulary( + inputs["s"], vocab_filename="vocab" + ), + } + + feature_spec = { + "x": tf.io.FixedLenFeature([], tf.float32), + "s": tf.io.FixedLenFeature([], tf.string), + } + input_data_dict = { + span_0_key: [ + { + "x": -2, + "s": "a", + }, + { + "x": 4, + "s": "a", + }, + { + "x": 5, + "s": "a", + }, + { + "x": 1, + "s": "ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴".encode(), + }, + ], + span_1_key: [ + { + "x": 5, + "s": "c", + }, + { + "x": 5, + "s": "2", + }, + ], + } + expected_vocabulary_contents = np.array( + [b"a", "ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴".encode(), b"c", b"2"], dtype=object + ) + + expected_transformed_data = [ { - 'x': 5, - 's': 'c', + "x_min": -2.0, + "s_integerized": 2, }, { - 'x': 5, - 's': '2', + "x_min": -2.0, + "s_integerized": 3, }, - ], - } - expected_vocabulary_contents = np.array( - [b'a', 'ȟᎥ𝒋ǩľḿꞑȯ𝘱𝑞𝗋𝘴'.encode('utf-8'), b'c', b'2'], dtype=object - ) - - expected_transformed_data = [ - { - 'x_min': -2.0, - 's_integerized': 2, - }, - { - 'x_min': -2.0, - 's_integerized': 3, - }, - ] + ] - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - datasets_to_transform=[span_1_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=use_tf_compat_v1, - ) - - for key in input_data_dict: - if key.is_cached: - self.assertIn(key, first_run_result.cache_output) - self.assertEqual( - num_analyzers, len(first_run_result.cache_output[key].cache_dict) + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + datasets_to_transform=[span_1_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=use_tf_compat_v1, ) - else: - self.assertNotIn(key, first_run_result.cache_output) - vocab1_path = first_run_result.transform_output.vocabulary_file_by_name( - 'vocab' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - - metrics = first_run_result.metrics - # 6 from analyzing 2 spans, and 2 from transform. - self.assertMetricsCounterEqual(metrics, 'num_instances', 8) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - # 2 analyzers, computed for just one dataset. - self.assertMetricsCounterEqual( - metrics, 'cache_entries_encoded', num_analyzers - ) - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + for key in input_data_dict: + if key.is_cached: + self.assertIn(key, first_run_result.cache_output) + self.assertEqual( + num_analyzers, len(first_run_result.cache_output[key].cache_dict) + ) + else: + self.assertNotIn(key, first_run_result.cache_output) + + vocab1_path = first_run_result.transform_output.vocabulary_file_by_name("vocab") + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + + metrics = first_run_result.metrics + # 6 from analyzing 2 spans, and 2 from transform. 
+ self.assertMetricsCounterEqual(metrics, "num_instances", 8) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + # 2 analyzers, computed for just one dataset. + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", num_analyzers) + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - second_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - datasets_to_transform=[span_1_key], - expected_transform_data=expected_transformed_data, - use_tf_compat_v1=use_tf_compat_v1, - ) + second_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + datasets_to_transform=[span_1_key], + expected_transform_data=expected_transformed_data, + use_tf_compat_v1=use_tf_compat_v1, + ) - vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( - 'vocab' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( + "vocab" + ) + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - self.assertFalse(second_run_result.cache_output) + self.assertFalse(second_run_result.cache_output) - metrics = second_run_result.metrics - # Only 2 from transform, plus 2 from non-cache dataset. - self.assertMetricsCounterEqual(metrics, 'num_instances', 4) - self.assertMetricsCounterEqual( - metrics, 'cache_entries_decoded', num_analyzers - ) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) + metrics = second_run_result.metrics + # Only 2 from transform, plus 2 from non-cache dataset. + self.assertMetricsCounterEqual(metrics, "num_instances", 4) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", num_analyzers) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) - # The non cached dataset causes the same number of saved models created as - # in the first iteration. - self.assertMetricsCounterEqual( - metrics, 'saved_models_created', _SINGLE_PHASE_NUM_SAVED_MODELS - ) - # Cache coverage allowed us to avoid processing this many bytes of data. - self.assertMetricsCounterGreater( - metrics, 'analysis_input_bytes_from_cache', 100 - ) + # The non cached dataset causes the same number of saved models created as + # in the first iteration. + self.assertMetricsCounterEqual( + metrics, "saved_models_created", _SINGLE_PHASE_NUM_SAVED_MODELS + ) + # Cache coverage allowed us to avoid processing this many bytes of data. 
+ self.assertMetricsCounterGreater( + metrics, "analysis_input_bytes_from_cache", 100 + ) - @tft_unit.parameters( - dict( - num_non_packed_analyzers=0, - num_packed_analyzers=0, - num_input_datasets=10, - expected_node_count_with_cache=1, - expected_node_count_without_cache=1, - ), - dict( - num_non_packed_analyzers=2, - num_packed_analyzers=0, - num_input_datasets=1, - expected_node_count_with_cache=29, - expected_node_count_without_cache=24, - ), - dict( - num_non_packed_analyzers=0, - num_packed_analyzers=2, - num_input_datasets=1, - expected_node_count_with_cache=24, - expected_node_count_without_cache=19, - ), - dict( - num_non_packed_analyzers=2, - num_packed_analyzers=0, - num_input_datasets=10, - expected_node_count_with_cache=101, - expected_node_count_without_cache=24, - ), - dict( - num_non_packed_analyzers=0, - num_packed_analyzers=2, - num_input_datasets=10, - expected_node_count_with_cache=87, - expected_node_count_without_cache=19, - ), - ) - def test_node_count( - self, - num_non_packed_analyzers, - num_packed_analyzers, - num_input_datasets, - expected_node_count_with_cache, - expected_node_count_without_cache, - ): - dataset_keys = [ - analyzer_cache.DatasetKey(str(x)) for x in range(num_input_datasets) - ] - specs = impl_helper.get_type_specs_from_feature_specs( - {'x': tf.io.FixedLenFeature([], tf.int64)} + @tft_unit.parameters( + dict( + num_non_packed_analyzers=0, + num_packed_analyzers=0, + num_input_datasets=10, + expected_node_count_with_cache=1, + expected_node_count_without_cache=1, + ), + dict( + num_non_packed_analyzers=2, + num_packed_analyzers=0, + num_input_datasets=1, + expected_node_count_with_cache=29, + expected_node_count_without_cache=24, + ), + dict( + num_non_packed_analyzers=0, + num_packed_analyzers=2, + num_input_datasets=1, + expected_node_count_with_cache=24, + expected_node_count_without_cache=19, + ), + dict( + num_non_packed_analyzers=2, + num_packed_analyzers=0, + num_input_datasets=10, + expected_node_count_with_cache=101, + expected_node_count_without_cache=24, + ), + dict( + num_non_packed_analyzers=0, + num_packed_analyzers=2, + num_input_datasets=10, + expected_node_count_with_cache=87, + expected_node_count_without_cache=19, + ), ) + def test_node_count( + self, + num_non_packed_analyzers, + num_packed_analyzers, + num_input_datasets, + expected_node_count_with_cache, + expected_node_count_without_cache, + ): + dataset_keys = [ + analyzer_cache.DatasetKey(str(x)) for x in range(num_input_datasets) + ] + specs = impl_helper.get_type_specs_from_feature_specs( + {"x": tf.io.FixedLenFeature([], tf.int64)} + ) - def preprocessing_fn(inputs): - for _ in range(num_packed_analyzers): - tft.mean(inputs['x']) - for _ in range(num_non_packed_analyzers): - tft.vocabulary(inputs['x']) - return inputs - - def get_graph_leaf_nodes(cache_enabled): - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, - specs, - use_tf_compat_v1=False, - base_temp_dir=self.base_test_dir, - ) - ) - (transform_fn_future, cache_output_dict, sideeffects) = ( - analysis_graph_builder.build( - graph, - structured_inputs, - structured_outputs, - dataset_keys, - cache_dict={} if cache_enabled else None, - ) - ) - cache_output_nodes = [] - if cache_output_dict: - for cache in cache_output_dict.values(): - cache_output_nodes.extend(cache.values()) - return [transform_fn_future] + cache_output_nodes + list(sideeffects) - - with_cache_graph = get_graph_leaf_nodes(True) - without_cache_graph = get_graph_leaf_nodes(False) 
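# Descriptive note on the expected counts above (derived from the parameter
# dicts, not new behavior): without a cache dict the analysis graph size is
# independent of the number of input datasets, while enabling the cache adds
# per-dataset cache nodes, so the node count grows with num_input_datasets
# (e.g. 24 vs. 29 for one dataset but 24 vs. 101 for ten, with two unpacked
# analyzers).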
- node_count_with_cache = nodes.count_graph_nodes(with_cache_graph) - node_count_without_cache = nodes.count_graph_nodes(without_cache_graph) - self._publish_rendered_dot_graph_file_from_leaf_nodes(without_cache_graph) - self._publish_rendered_dot_graph_file_from_leaf_nodes(with_cache_graph) - self.assertEqual( - (expected_node_count_without_cache, expected_node_count_with_cache), - (node_count_without_cache, node_count_with_cache), - ) + def preprocessing_fn(inputs): + for _ in range(num_packed_analyzers): + tft.mean(inputs["x"]) + for _ in range(num_non_packed_analyzers): + tft.vocabulary(inputs["x"]) + return inputs + + def get_graph_leaf_nodes(cache_enabled): + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, + specs, + use_tf_compat_v1=False, + base_temp_dir=self.base_test_dir, + ) + ) + (transform_fn_future, cache_output_dict, sideeffects) = ( + analysis_graph_builder.build( + graph, + structured_inputs, + structured_outputs, + dataset_keys, + cache_dict={} if cache_enabled else None, + ) + ) + cache_output_nodes = [] + if cache_output_dict: + for cache in cache_output_dict.values(): + cache_output_nodes.extend(cache.values()) + return [transform_fn_future] + cache_output_nodes + list(sideeffects) + + with_cache_graph = get_graph_leaf_nodes(True) + without_cache_graph = get_graph_leaf_nodes(False) + node_count_with_cache = nodes.count_graph_nodes(with_cache_graph) + node_count_without_cache = nodes.count_graph_nodes(without_cache_graph) + self._publish_rendered_dot_graph_file_from_leaf_nodes(without_cache_graph) + self._publish_rendered_dot_graph_file_from_leaf_nodes(with_cache_graph) + self.assertEqual( + (expected_node_count_without_cache, expected_node_count_with_cache), + (node_count_without_cache, node_count_with_cache), + ) - @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) - def test_vocab_reserved_tokens_not_cached(self, use_tf_compat_v1): - if not use_tf_compat_v1: - tft_unit.skip_if_not_tf2('Tensorflow 2.x required.') - - def preprocessing_fn(inputs): - return { - 's_integerized': tft.compute_and_apply_vocabulary( - inputs['s'], vocab_filename='vocab', reserved_tokens=['reserved'] - ), - } - - feature_spec = { - 's': tf.io.FixedLenFeature([], tf.string), - } - input_data_dict = { - analyzer_cache.DatasetKey('span-0'): [dict(s='a')] * 3 + [dict(s='b')] - } - expected_vocabulary_contents = np.array( - [b'reserved', b'a', b'b'], dtype=object - ) + @tft_unit.named_parameters(_TF_VERSION_NAMED_PARAMETERS) + def test_vocab_reserved_tokens_not_cached(self, use_tf_compat_v1): + if not use_tf_compat_v1: + tft_unit.skip_if_not_tf2("Tensorflow 2.x required.") + + def preprocessing_fn(inputs): + return { + "s_integerized": tft.compute_and_apply_vocabulary( + inputs["s"], vocab_filename="vocab", reserved_tokens=["reserved"] + ), + } + + feature_spec = { + "s": tf.io.FixedLenFeature([], tf.string), + } + input_data_dict = { + analyzer_cache.DatasetKey("span-0"): [dict(s="a")] * 3 + [dict(s="b")] + } + expected_vocabulary_contents = np.array([b"reserved", b"a", b"b"], dtype=object) + + first_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + use_tf_compat_v1=use_tf_compat_v1, + ) - first_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - use_tf_compat_v1=use_tf_compat_v1, - ) + vocab1_path = first_run_result.transform_output.vocabulary_file_by_name("vocab") + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - 
vocab1_path = first_run_result.transform_output.vocabulary_file_by_name( - 'vocab' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - - metrics = first_run_result.metrics - self.assertMetricsCounterEqual(metrics, 'num_instances', 4) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 1) - self.assertMetricsCounterEqual( - metrics, 'analysis_input_bytes_from_cache', 0 - ) + metrics = first_run_result.metrics + self.assertMetricsCounterEqual(metrics, "num_instances", 4) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 1) + self.assertMetricsCounterEqual(metrics, "analysis_input_bytes_from_cache", 0) - second_run_result = self._run_pipeline( - feature_spec, - input_data_dict, - preprocessing_fn, - should_read_cache=True, - use_tf_compat_v1=use_tf_compat_v1, - ) + second_run_result = self._run_pipeline( + feature_spec, + input_data_dict, + preprocessing_fn, + should_read_cache=True, + use_tf_compat_v1=use_tf_compat_v1, + ) - vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( - 'vocab' - ) - self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) + vocab1_path = second_run_result.transform_output.vocabulary_file_by_name( + "vocab" + ) + self.AssertVocabularyContents(vocab1_path, expected_vocabulary_contents) - self.assertFalse(second_run_result.cache_output) + self.assertFalse(second_run_result.cache_output) - metrics = second_run_result.metrics - self.assertMetricsCounterEqual(metrics, 'num_instances', 0) - self.assertMetricsCounterEqual(metrics, 'cache_entries_decoded', 1) - self.assertMetricsCounterEqual(metrics, 'cache_entries_encoded', 0) - self.assertMetricsCounterGreater( - metrics, 'analysis_input_bytes_from_cache', 1 - ) + metrics = second_run_result.metrics + self.assertMetricsCounterEqual(metrics, "num_instances", 0) + self.assertMetricsCounterEqual(metrics, "cache_entries_decoded", 1) + self.assertMetricsCounterEqual(metrics, "cache_entries_encoded", 0) + self.assertMetricsCounterGreater(metrics, "analysis_input_bytes_from_cache", 1) -if __name__ == '__main__': - tft_unit.main() +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/combiner_packing_util.py b/tensorflow_transform/beam/combiner_packing_util.py index c8ad360..465f58d 100644 --- a/tensorflow_transform/beam/combiner_packing_util.py +++ b/tensorflow_transform/beam/combiner_packing_util.py @@ -39,84 +39,86 @@ import dataclasses from typing import Sequence -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import nodes +from tensorflow_transform import analyzer_nodes, nodes from tensorflow_transform.beam import beam_nodes @dataclasses.dataclass(frozen=True) class _CombinerOpWrapper: - combiner: analyzer_nodes.Combiner - keys: Sequence[str] - label: str + combiner: analyzer_nodes.Combiner + keys: Sequence[str] + label: str class _ValidationVisitor(nodes.Visitor): - """Visitor to determine if a node is ready to run.""" + """Visitor to determine if a node is ready to run.""" - def __init__(self): - self._visited_operation_def_labels = set() + def __init__(self): + self._visited_operation_def_labels = set() - def validate_operation_def(self, operation_def): - assert operation_def.label not in self._visited_operation_def_labels - self._visited_operation_def_labels.add(operation_def.label) + def validate_operation_def(self, 
operation_def): + assert operation_def.label not in self._visited_operation_def_labels + self._visited_operation_def_labels.add(operation_def.label) - def validate_value(self, value): - assert isinstance(value, nodes.ValueNode) + def validate_value(self, value): + assert isinstance(value, nodes.ValueNode) class _InspectAccumulateCombineVisitor(_ValidationVisitor): - """A visitor that inspects the graph and looks for combine nodes. - - As this visitor visits the TFT Beam Graph, we group together all the - packable combine nodes. Specifically, we look for the following path: - ExtractFromDict --> CacheableCombineAccumulate - The combines under the same grand parent can be packed together. - In this visitor, we group all the packable combines for each unique - grand parent node and save their reference in the `packable_combines` class - attribute. - """ - - def __init__(self): - super().__init__() - # Group all packable combines. We pack all the combines that have the same - # grand parent. - # {grand_parent_label: List of packable _CombinerOpWrapper's} - self.packable_combines = collections.defaultdict(list) - - def visit(self, operation_def, input_values): - self.validate_operation_def(operation_def) - self._maybe_add_packable_combine(operation_def, input_values) - return nodes.OperationNode(operation_def, input_values).outputs - - def _maybe_add_packable_combine(self, operation_def, input_values): - # We cannot pack the per-key combine analyzers as the key may be different - # for each analyzer. - if not isinstance(operation_def, analyzer_nodes.CacheableCombineAccumulate): - return - assert len(input_values) == 1 - - # Get the ExtractFromDict parent node of the current - # CacheableCombineAccumulate node. - parent = input_values[0].parent_operation - if not isinstance(parent.operation_def, beam_nodes.ExtractFromDict): - return - assert len(parent.inputs) == 1 - - # Get the parent of the current ExtractFromDict node. - grand_parent = parent.inputs[0].parent_operation - assert isinstance(grand_parent.operation_def, beam_nodes.ApplySavedModel) - - # This is a packable combine. - grand_parent_label = grand_parent.operation_def.label - self.packable_combines[grand_parent_label].append(_CombinerOpWrapper( - combiner=operation_def.combiner, - keys=parent.operation_def.keys, - label=operation_def.label)) + """A visitor that inspects the graph and looks for combine nodes. + + As this visitor visits the TFT Beam Graph, we group together all the + packable combine nodes. Specifically, we look for the following path: + ExtractFromDict --> CacheableCombineAccumulate + The combines under the same grand parent can be packed together. + In this visitor, we group all the packable combines for each unique + grand parent node and save their reference in the `packable_combines` class + attribute. + """ + + def __init__(self): + super().__init__() + # Group all packable combines. We pack all the combines that have the same + # grand parent. + # {grand_parent_label: List of packable _CombinerOpWrapper's} + self.packable_combines = collections.defaultdict(list) + + def visit(self, operation_def, input_values): + self.validate_operation_def(operation_def) + self._maybe_add_packable_combine(operation_def, input_values) + return nodes.OperationNode(operation_def, input_values).outputs + + def _maybe_add_packable_combine(self, operation_def, input_values): + # We cannot pack the per-key combine analyzers as the key may be different + # for each analyzer. 
+ if not isinstance(operation_def, analyzer_nodes.CacheableCombineAccumulate): + return + assert len(input_values) == 1 + + # Get the ExtractFromDict parent node of the current + # CacheableCombineAccumulate node. + parent = input_values[0].parent_operation + if not isinstance(parent.operation_def, beam_nodes.ExtractFromDict): + return + assert len(parent.inputs) == 1 + + # Get the parent of the current ExtractFromDict node. + grand_parent = parent.inputs[0].parent_operation + assert isinstance(grand_parent.operation_def, beam_nodes.ApplySavedModel) + + # This is a packable combine. + grand_parent_label = grand_parent.operation_def.label + self.packable_combines[grand_parent_label].append( + _CombinerOpWrapper( + combiner=operation_def.combiner, + keys=parent.operation_def.keys, + label=operation_def.label, + ) + ) class _PackAccumulateCombineVisitor(_ValidationVisitor): - r"""A visitor that packs combine nodes in the graph. + r"""A visitor that packs combine nodes in the graph. This visitor takes the grouped combines and performs the packing of those combines. @@ -138,101 +140,111 @@ class _PackAccumulateCombineVisitor(_ValidationVisitor): to the individual combines. """ - def __init__(self, packable_combines): - super().__init__() - self._packable_combines = packable_combines - - self._combine_to_grand_parent = {} - for grand_parent_label, group in self._packable_combines.items(): - for combine_op in group: - self._combine_to_grand_parent[combine_op.label] = grand_parent_label - - # Cache the packed combine node. - # Grand parent node label -> Packed combine node - self._packed_combine_cache = {} - - def visit(self, operation_def, input_values): - self.validate_operation_def(operation_def) - # If we see a combine node which can be packed, create the packed combine - # node and cache it as we will use the same packed node for all the combines - # in the group. - if operation_def.label in self._combine_to_grand_parent: - return self._get_packed_combine(operation_def, input_values) - return nodes.OperationNode(operation_def, input_values).outputs - - def _get_packed_combine(self, operation_def, input_values): - grand_parent_label = self._combine_to_grand_parent[operation_def.label] - # If we are seeing a combine from a group for the first time, create the - # the packed combine node and cache it. - if grand_parent_label not in self._packed_combine_cache: - # Get the grand parent node of the CacheableCombineAccumulate node. - # We will make this node as the parent of the - # PackedCombineAccumulate node. - assert len(input_values) == 1 - parent_node = input_values[0] - assert isinstance(parent_node.parent_operation.operation_def, - beam_nodes.ExtractFromDict) - assert len(parent_node.parent_operation.inputs) == 1 - grand_parent_node = parent_node.parent_operation.inputs[0] - assert (grand_parent_node.parent_operation.operation_def.label == - grand_parent_label) - self._packed_combine_cache[grand_parent_label] = ( - nodes.apply_operation( - analyzer_nodes.PackedCombineAccumulate, - grand_parent_node, - combiners=self._packable_combines[grand_parent_label], - label='PackedCombineAccumulate[{}]'.format(grand_parent_label))) - # For the current combine, create the ExtractFromDict node which - # extracts the accumulator corresponding to this combine from the - # packed combine output. 
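# Descriptive sketch of the rewrite this visitor performs (structure taken
# from the surrounding code; labels illustrative):
#   before: ApplySavedModel -> ExtractFromDict(keys) -> CacheableCombineAccumulate   (one chain per analyzer)
#   after:  ApplySavedModel -> PackedCombineAccumulate(combiners=group)
#                           -> ExtractFromDict(keys=<original combine label>)        (one extract per analyzer)
# The packed node is created once per grand parent and cached in
# _packed_combine_cache; each original combine then reads back its own
# accumulator by its label, which is what the code below constructs.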
- result = nodes.apply_operation( - beam_nodes.ExtractFromDict, - self._packed_combine_cache[grand_parent_label], - keys=operation_def.label, label=operation_def.label) - return (result,) + def __init__(self, packable_combines): + super().__init__() + self._packable_combines = packable_combines + + self._combine_to_grand_parent = {} + for grand_parent_label, group in self._packable_combines.items(): + for combine_op in group: + self._combine_to_grand_parent[combine_op.label] = grand_parent_label + + # Cache the packed combine node. + # Grand parent node label -> Packed combine node + self._packed_combine_cache = {} + + def visit(self, operation_def, input_values): + self.validate_operation_def(operation_def) + # If we see a combine node which can be packed, create the packed combine + # node and cache it as we will use the same packed node for all the combines + # in the group. + if operation_def.label in self._combine_to_grand_parent: + return self._get_packed_combine(operation_def, input_values) + return nodes.OperationNode(operation_def, input_values).outputs + + def _get_packed_combine(self, operation_def, input_values): + grand_parent_label = self._combine_to_grand_parent[operation_def.label] + # If we are seeing a combine from a group for the first time, create the + # the packed combine node and cache it. + if grand_parent_label not in self._packed_combine_cache: + # Get the grand parent node of the CacheableCombineAccumulate node. + # We will make this node as the parent of the + # PackedCombineAccumulate node. + assert len(input_values) == 1 + parent_node = input_values[0] + assert isinstance( + parent_node.parent_operation.operation_def, beam_nodes.ExtractFromDict + ) + assert len(parent_node.parent_operation.inputs) == 1 + grand_parent_node = parent_node.parent_operation.inputs[0] + assert ( + grand_parent_node.parent_operation.operation_def.label + == grand_parent_label + ) + self._packed_combine_cache[grand_parent_label] = nodes.apply_operation( + analyzer_nodes.PackedCombineAccumulate, + grand_parent_node, + combiners=self._packable_combines[grand_parent_label], + label=f"PackedCombineAccumulate[{grand_parent_label}]", + ) + # For the current combine, create the ExtractFromDict node which + # extracts the accumulator corresponding to this combine from the + # packed combine output. + result = nodes.apply_operation( + beam_nodes.ExtractFromDict, + self._packed_combine_cache[grand_parent_label], + keys=operation_def.label, + label=operation_def.label, + ) + return (result,) + _COMBINE_PARENT_NODE_TYPES = ( - beam_nodes.ExtractFromDict, beam_nodes.Flatten, analyzer_nodes.DecodeCache) + beam_nodes.ExtractFromDict, + beam_nodes.Flatten, + analyzer_nodes.DecodeCache, +) class _InspectMergeCombineVisitor(_ValidationVisitor): - """A visitor that inspects the graph and looks for merge combine nodes.""" - - def __init__(self): - super().__init__() - # Gather all the packable merge combines. - # Dict {ExtractCombineMergeOutputs (child of CacheableCombineMerge) label: - # _CombinerOpWrapper} - self.packable_combine_extract_outputs = collections.OrderedDict() - - def visit(self, operation_def, input_values): - self.validate_operation_def(operation_def) - self._maybe_add_packable_combine(operation_def, input_values) - return nodes.OperationNode(operation_def, input_values).outputs - - def _maybe_add_packable_combine(self, operation_def, input_values): - if not isinstance(operation_def, analyzer_nodes.ExtractCombineMergeOutputs): - return - # Verify we have a CacheableCombineMerge parent. 
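# Descriptive note on the pattern matched here (taken from the checks that
# follow): a merge combine is packable when the node chain is
#   {ExtractFromDict | Flatten | DecodeCache} -> CacheableCombineMerge
#                                             -> ExtractCombineMergeOutputs
# i.e. the grand parent must be one of _COMBINE_PARENT_NODE_TYPES; each match
# is recorded as a _CombinerOpWrapper keyed by the ExtractCombineMergeOutputs
# label in packable_combine_extract_outputs.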
- parent = input_values[0].parent_operation - if not isinstance(parent.operation_def, - analyzer_nodes.CacheableCombineMerge): - return - assert len(parent.inputs) == 1 - grand_parent = parent.inputs[0].parent_operation - # We look for packable combines. Specifically, CacheableCombineMerge nodes - # whose parent is one of the type in _COMBINE_PARENT_NODE_TYPES. - if isinstance(grand_parent.operation_def, _COMBINE_PARENT_NODE_TYPES): - # This is a packable combine. - self.packable_combine_extract_outputs[operation_def.label] = ( - _CombinerOpWrapper( - combiner=parent.operation_def.combiner, - keys=(parent.operation_def.label,), - label=parent.operation_def.label)) + """A visitor that inspects the graph and looks for merge combine nodes.""" + + def __init__(self): + super().__init__() + # Gather all the packable merge combines. + # Dict {ExtractCombineMergeOutputs (child of CacheableCombineMerge) label: + # _CombinerOpWrapper} + self.packable_combine_extract_outputs = collections.OrderedDict() + + def visit(self, operation_def, input_values): + self.validate_operation_def(operation_def) + self._maybe_add_packable_combine(operation_def, input_values) + return nodes.OperationNode(operation_def, input_values).outputs + + def _maybe_add_packable_combine(self, operation_def, input_values): + if not isinstance(operation_def, analyzer_nodes.ExtractCombineMergeOutputs): + return + # Verify we have a CacheableCombineMerge parent. + parent = input_values[0].parent_operation + if not isinstance(parent.operation_def, analyzer_nodes.CacheableCombineMerge): + return + assert len(parent.inputs) == 1 + grand_parent = parent.inputs[0].parent_operation + # We look for packable combines. Specifically, CacheableCombineMerge nodes + # whose parent is one of the type in _COMBINE_PARENT_NODE_TYPES. + if isinstance(grand_parent.operation_def, _COMBINE_PARENT_NODE_TYPES): + # This is a packable combine. + self.packable_combine_extract_outputs[operation_def.label] = ( + _CombinerOpWrapper( + combiner=parent.operation_def.combiner, + keys=(parent.operation_def.label,), + label=parent.operation_def.label, + ) + ) class _PackMergeCombineVisitor(_ValidationVisitor): - r"""A visitor that inspects the graph and looks for combine nodes. + r"""A visitor that inspects the graph and looks for combine nodes. This visitor takes the grouped combines and performs the packing of those combines. @@ -266,75 +278,76 @@ class _PackMergeCombineVisitor(_ValidationVisitor): redundant packed merge and flatten nodes which needs to be removed. """ - def __init__(self, packable_combine_extract_outputs): - super().__init__() - self._packable_combine_extract_outputs = packable_combine_extract_outputs - # Gather all the input nodes that we need to flatten to be passed as input - # to the packed merge node. - self._flatten_inputs = [] - # Keep track of the label of the final packed merge combine node. 
- self.final_packed_merge_combine_label = None - - def visit(self, operation_def, input_values): - self.validate_operation_def(operation_def) - # We look for the ExtractOutputs node of packable combines - if operation_def.label in self._packable_combine_extract_outputs: - return self._add_flatten_placeholder(operation_def, input_values) - return nodes.OperationNode(operation_def, input_values).outputs - - def _add_flatten_placeholder(self, operation_def, input_values): - assert isinstance(operation_def, analyzer_nodes.ExtractCombineMergeOutputs) - parent = input_values[0].parent_operation - assert isinstance(parent.operation_def, - analyzer_nodes.CacheableCombineMerge) - packed_combine = self._get_packed_combine( - parent.operation_def, parent.inputs) - # For the current combine, create the ExtractFromDict node which - # extracts the accumulator corresponding to this combine from the - # packed combine output. - extract_dict_node = nodes.apply_operation( - beam_nodes.ExtractFromDict, - packed_combine, - keys=parent.operation_def.label, - label='ExtractFromDict[{}]'.format(parent.operation_def.label)) - # Create the new ExtractPackedCombineMergeOutputs node. - return nodes.apply_multi_output_operation( - analyzer_nodes.ExtractPackedCombineMergeOutputs, - extract_dict_node, - output_tensor_info_list=operation_def.output_tensor_infos, - label='ExtractPackedCombineMergeOutputs[{}]'.format( - parent.operation_def.label) - ) - - def _get_packed_combine(self, operation_def, input_values): - for value in input_values: - keyed_value = nodes.apply_operation( - analyzer_nodes.AddKey, - value, - key=operation_def.label, - label='AddKey[{}]'.format(operation_def.label)) - self._flatten_inputs.append(keyed_value) - # TODO(b/134414978): When we add support for multi-phase merge packing, - # add phase number to the flatten and packed combine labels. - flatten_label = 'FlattenInputForPackedCombineMerge[{}]'.format( - len(self._flatten_inputs)) - flatten_node = nodes.apply_operation( - beam_nodes.Flatten, *self._flatten_inputs, label=flatten_label) - packed_combine_label = 'PackedCombineMerge[{}]'.format( - len(self._flatten_inputs)) - packed_combine = nodes.apply_operation( - analyzer_nodes.PackedCombineMerge, - flatten_node, - combiners=list(self._packable_combine_extract_outputs.values()), - label=packed_combine_label) - self.final_packed_merge_combine_label = packed_combine_label - return packed_combine + def __init__(self, packable_combine_extract_outputs): + super().__init__() + self._packable_combine_extract_outputs = packable_combine_extract_outputs + # Gather all the input nodes that we need to flatten to be passed as input + # to the packed merge node. + self._flatten_inputs = [] + # Keep track of the label of the final packed merge combine node. 
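# Descriptive sketch of the packing performed by this visitor (flow taken from
# _add_flatten_placeholder and _get_packed_combine below):
#   per packable combine:  input -> AddKey[<combine label>]
#   all keyed inputs    -> Flatten -> PackedCombineMerge
#   per combine again   -> ExtractFromDict[<combine label>]
#                        -> ExtractPackedCombineMergeOutputs
# A new Flatten/PackedCombineMerge pair is emitted each time another packable
# combine is visited; only the last (final) one is kept, and the earlier,
# redundant ones are removed by _RemoveRedundantPackedMergeCombineVisitor.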
+ self.final_packed_merge_combine_label = None + + def visit(self, operation_def, input_values): + self.validate_operation_def(operation_def) + # We look for the ExtractOutputs node of packable combines + if operation_def.label in self._packable_combine_extract_outputs: + return self._add_flatten_placeholder(operation_def, input_values) + return nodes.OperationNode(operation_def, input_values).outputs + + def _add_flatten_placeholder(self, operation_def, input_values): + assert isinstance(operation_def, analyzer_nodes.ExtractCombineMergeOutputs) + parent = input_values[0].parent_operation + assert isinstance(parent.operation_def, analyzer_nodes.CacheableCombineMerge) + packed_combine = self._get_packed_combine(parent.operation_def, parent.inputs) + # For the current combine, create the ExtractFromDict node which + # extracts the accumulator corresponding to this combine from the + # packed combine output. + extract_dict_node = nodes.apply_operation( + beam_nodes.ExtractFromDict, + packed_combine, + keys=parent.operation_def.label, + label=f"ExtractFromDict[{parent.operation_def.label}]", + ) + # Create the new ExtractPackedCombineMergeOutputs node. + return nodes.apply_multi_output_operation( + analyzer_nodes.ExtractPackedCombineMergeOutputs, + extract_dict_node, + output_tensor_info_list=operation_def.output_tensor_infos, + label=f"ExtractPackedCombineMergeOutputs[{parent.operation_def.label}]", + ) + + def _get_packed_combine(self, operation_def, input_values): + for value in input_values: + keyed_value = nodes.apply_operation( + analyzer_nodes.AddKey, + value, + key=operation_def.label, + label=f"AddKey[{operation_def.label}]", + ) + self._flatten_inputs.append(keyed_value) + # TODO(b/134414978): When we add support for multi-phase merge packing, + # add phase number to the flatten and packed combine labels. + flatten_label = ( + f"FlattenInputForPackedCombineMerge[{len(self._flatten_inputs)}]" + ) + flatten_node = nodes.apply_operation( + beam_nodes.Flatten, *self._flatten_inputs, label=flatten_label + ) + packed_combine_label = f"PackedCombineMerge[{len(self._flatten_inputs)}]" + packed_combine = nodes.apply_operation( + analyzer_nodes.PackedCombineMerge, + flatten_node, + combiners=list(self._packable_combine_extract_outputs.values()), + label=packed_combine_label, + ) + self.final_packed_merge_combine_label = packed_combine_label + return packed_combine @dataclasses.dataclass(frozen=True) class _TensorBindingInfo: - intermediate_post_processing_op_defs: Sequence[nodes.OperationDef] - output_index: int + intermediate_post_processing_op_defs: Sequence[nodes.OperationDef] + output_index: int # Maximum search depth for packed post-processing nodes. @@ -342,242 +355,266 @@ class _TensorBindingInfo: class _RemoveRedundantPackedMergeCombineVisitor(_ValidationVisitor): - """A visitor that inspects the graph and removes redundant merge nodes. - - This visitor removes the redundant flatten and packed merge nodes added - by the _PackMergeCombineVisitor and reconstructs the descendants of the - removed nodes with the final flatten and packed merge node. 
- """ - - def __init__(self, final_packed_merge_combine_label): - super().__init__() - self._final_packed_merge_combine_label = final_packed_merge_combine_label - self._packed_post_processing_nodes_cache = {} - - def visit(self, operation_def, input_values): - self.validate_operation_def(operation_def) - if input_values and isinstance(operation_def, beam_nodes.CreateSavedModel): - # This will only be called once since this is a single phase analysis - # graph and in that case only the final CreateSavedModel node has inputs. - return self._remove_redundant_nodes(operation_def, input_values) - return nodes.OperationNode(operation_def, input_values).outputs - - def _remove_redundant_nodes(self, operation_def, input_values): - # Input values to be used as input to CreateSavedModel. - # Since some of the input values are generated from the redundant nodes, - # those needs to be reconstructed with the final packed merge node. - reconstructed_input_values = [] - - redundant_values, non_redundant_values = ( - self._get_redundant_and_non_redundant_input_values(input_values)) - - # Keep track of the final packed merge combine node. For those input nodes - # which are descendants of the redundant nodes, we would create a new node - # generated from the final packed merge combine node. - (final_packed_merge_combine, final_packed_merge_combine_tensor_bindings) = ( - self._get_final_packed_combine_and_tensor_bindings(redundant_values)) - reconstructed_input_values.extend( - final_packed_merge_combine_tensor_bindings) - - # Add the non-redundant nodes to the input values. - reconstructed_input_values.extend(non_redundant_values) - - # Keep track of the info needed to reconstruct the descendents of the - # redundant nodes. - to_be_created_tensor_bindings = ( - self._get_to_be_created_tensor_bindings_info(redundant_values)) - - reconstructed_input_values.extend(self._create_tensor_bindings( - to_be_created_tensor_bindings, final_packed_merge_combine)) - assert len(input_values) == len(reconstructed_input_values) - return nodes.OperationNode( - operation_def, tuple(reconstructed_input_values)).outputs - - def _is_packed_post_processing_node(self, - value_node: nodes.ValueNode) -> bool: - # ValueNode is considered a packed post-processing node iff - # PackedCombineMerge node is its ancestor. - if value_node in self._packed_post_processing_nodes_cache: - return self._packed_post_processing_nodes_cache[value_node] - - input_nodes = set() - search_depth = 0 - result = False - while (value_node.parent_operation.inputs and - search_depth < _MAX_PACKED_POST_PROCESSING_DEPTH): - # Post-processing nodes form a tree. Looking only at the first input. - input_nodes.add(value_node) - value_node = value_node.parent_operation.inputs[0] - if isinstance(value_node.parent_operation.operation_def, - analyzer_nodes.PackedCombineMerge): - result = True - break - search_depth += 1 - self._packed_post_processing_nodes_cache.update( - {node: result for node in input_nodes}) - return result - - def _get_redundant_and_non_redundant_input_values( - self, input_values): - redundant_values, non_redundant_values = [], [] - for value in input_values: - assert isinstance(value.parent_operation.operation_def, - beam_nodes.CreateTensorBinding) - # If it's from a packed combine node, this is a redundant value. 
- if self._is_packed_post_processing_node(value): - redundant_values.append(value) - else: - non_redundant_values.append(value) - return redundant_values, non_redundant_values - - def _get_final_packed_combine_and_tensor_bindings(self, input_values): - final_packed_merge_combine = None - final_packed_merge_combine_tensor_bindings = [] - for value in input_values: - # PackedCombineMerge is the first not post-processing node on backwards - # traversal. Post-processing nodes form a tree, it is enough to iterate - # through first inputs. - packed_combine = value.parent_operation.inputs[0] - while self._is_packed_post_processing_node(packed_combine): - packed_combine = packed_combine.parent_operation.inputs[0] - # If the input is generated from the final packed merge node, add it to - # the filtered inputs and keep track of the node for reconstruction of - # the other inputs. - packed_combine_op_def = packed_combine.parent_operation.operation_def - if (isinstance(packed_combine_op_def, analyzer_nodes.PackedCombineMerge) - and (packed_combine_op_def.label - == self._final_packed_merge_combine_label)): - final_packed_merge_combine = packed_combine - final_packed_merge_combine_tensor_bindings.append(value) - return (final_packed_merge_combine, - final_packed_merge_combine_tensor_bindings) - - def _get_to_be_created_tensor_bindings_info(self, input_values): - result = [] - for value in input_values: - intermidiate_post_processing_op_defs = [] - intermidiate_value = value - output_index = None - while self._is_packed_post_processing_node(intermidiate_value): - intermidiate_op_def = intermidiate_value.parent_operation.operation_def - intermidiate_post_processing_op_defs.append(intermidiate_op_def) - if isinstance(intermidiate_op_def, - analyzer_nodes.ExtractPackedCombineMergeOutputs): - assert output_index is None - output_index = intermidiate_value.value_index - intermidiate_value = intermidiate_value.parent_operation.inputs[0] - - # If the input is not generated from the final packed merge node, keep - # track of the node for reconstruction of the other inputs. - if (intermidiate_value.parent_operation.operation_def.label != - self._final_packed_merge_combine_label): - # Store the info needed to reconstruct the input node, including - # CreateTensorBinding node's input value index. - result.append( - _TensorBindingInfo(intermidiate_post_processing_op_defs, - output_index)) - return result - - def _create_tensor_bindings(self, to_be_created_tensor_bindings, - final_packed_merge_combine): - labels_to_new_nodes = {} - def _maybe_create_node(op_def, inputs): - if op_def.label in labels_to_new_nodes: - return labels_to_new_nodes[op_def.label] - new_node = nodes.OperationNode(op_def, inputs).outputs - labels_to_new_nodes[op_def.label] = new_node - return new_node - - result = [] - if to_be_created_tensor_bindings: - assert final_packed_merge_combine is not None - # Reconstruct the remaining inputs from the final packed merge node. - for tensor_binding_info in to_be_created_tensor_bindings: - intermediate_nodes = (final_packed_merge_combine,) - for op_def in reversed( - tensor_binding_info.intermediate_post_processing_op_defs): - intermediate_nodes = _maybe_create_node(op_def, intermediate_nodes) - if isinstance(op_def, - analyzer_nodes.ExtractPackedCombineMergeOutputs): - intermediate_nodes = ( - intermediate_nodes[tensor_binding_info.output_index],) - # The last node must be a single CreateTensorBinding. 
- assert len(intermediate_nodes) == 1, intermediate_nodes - assert isinstance(intermediate_nodes[0].parent_operation.operation_def, - beam_nodes.CreateTensorBinding), intermediate_nodes[0] - result.append(intermediate_nodes[0]) - return result + """A visitor that inspects the graph and removes redundant merge nodes. + + This visitor removes the redundant flatten and packed merge nodes added + by the _PackMergeCombineVisitor and reconstructs the descendants of the + removed nodes with the final flatten and packed merge node. + """ + + def __init__(self, final_packed_merge_combine_label): + super().__init__() + self._final_packed_merge_combine_label = final_packed_merge_combine_label + self._packed_post_processing_nodes_cache = {} + + def visit(self, operation_def, input_values): + self.validate_operation_def(operation_def) + if input_values and isinstance(operation_def, beam_nodes.CreateSavedModel): + # This will only be called once since this is a single phase analysis + # graph and in that case only the final CreateSavedModel node has inputs. + return self._remove_redundant_nodes(operation_def, input_values) + return nodes.OperationNode(operation_def, input_values).outputs + + def _remove_redundant_nodes(self, operation_def, input_values): + # Input values to be used as input to CreateSavedModel. + # Since some of the input values are generated from the redundant nodes, + # those needs to be reconstructed with the final packed merge node. + reconstructed_input_values = [] + + redundant_values, non_redundant_values = ( + self._get_redundant_and_non_redundant_input_values(input_values) + ) + + # Keep track of the final packed merge combine node. For those input nodes + # which are descendants of the redundant nodes, we would create a new node + # generated from the final packed merge combine node. + (final_packed_merge_combine, final_packed_merge_combine_tensor_bindings) = ( + self._get_final_packed_combine_and_tensor_bindings(redundant_values) + ) + reconstructed_input_values.extend(final_packed_merge_combine_tensor_bindings) + + # Add the non-redundant nodes to the input values. + reconstructed_input_values.extend(non_redundant_values) + + # Keep track of the info needed to reconstruct the descendents of the + # redundant nodes. + to_be_created_tensor_bindings = self._get_to_be_created_tensor_bindings_info( + redundant_values + ) + + reconstructed_input_values.extend( + self._create_tensor_bindings( + to_be_created_tensor_bindings, final_packed_merge_combine + ) + ) + assert len(input_values) == len(reconstructed_input_values) + return nodes.OperationNode( + operation_def, tuple(reconstructed_input_values) + ).outputs + + def _is_packed_post_processing_node(self, value_node: nodes.ValueNode) -> bool: + # ValueNode is considered a packed post-processing node iff + # PackedCombineMerge node is its ancestor. + if value_node in self._packed_post_processing_nodes_cache: + return self._packed_post_processing_nodes_cache[value_node] + + input_nodes = set() + search_depth = 0 + result = False + while ( + value_node.parent_operation.inputs + and search_depth < _MAX_PACKED_POST_PROCESSING_DEPTH + ): + # Post-processing nodes form a tree. Looking only at the first input. 
+ input_nodes.add(value_node) + value_node = value_node.parent_operation.inputs[0] + if isinstance( + value_node.parent_operation.operation_def, + analyzer_nodes.PackedCombineMerge, + ): + result = True + break + search_depth += 1 + self._packed_post_processing_nodes_cache.update( + {node: result for node in input_nodes} + ) + return result + + def _get_redundant_and_non_redundant_input_values(self, input_values): + redundant_values, non_redundant_values = [], [] + for value in input_values: + assert isinstance( + value.parent_operation.operation_def, beam_nodes.CreateTensorBinding + ) + # If it's from a packed combine node, this is a redundant value. + if self._is_packed_post_processing_node(value): + redundant_values.append(value) + else: + non_redundant_values.append(value) + return redundant_values, non_redundant_values + + def _get_final_packed_combine_and_tensor_bindings(self, input_values): + final_packed_merge_combine = None + final_packed_merge_combine_tensor_bindings = [] + for value in input_values: + # PackedCombineMerge is the first not post-processing node on backwards + # traversal. Post-processing nodes form a tree, it is enough to iterate + # through first inputs. + packed_combine = value.parent_operation.inputs[0] + while self._is_packed_post_processing_node(packed_combine): + packed_combine = packed_combine.parent_operation.inputs[0] + # If the input is generated from the final packed merge node, add it to + # the filtered inputs and keep track of the node for reconstruction of + # the other inputs. + packed_combine_op_def = packed_combine.parent_operation.operation_def + if isinstance( + packed_combine_op_def, analyzer_nodes.PackedCombineMerge + ) and ( + packed_combine_op_def.label == self._final_packed_merge_combine_label + ): + final_packed_merge_combine = packed_combine + final_packed_merge_combine_tensor_bindings.append(value) + return (final_packed_merge_combine, final_packed_merge_combine_tensor_bindings) + + def _get_to_be_created_tensor_bindings_info(self, input_values): + result = [] + for value in input_values: + intermidiate_post_processing_op_defs = [] + intermidiate_value = value + output_index = None + while self._is_packed_post_processing_node(intermidiate_value): + intermidiate_op_def = intermidiate_value.parent_operation.operation_def + intermidiate_post_processing_op_defs.append(intermidiate_op_def) + if isinstance( + intermidiate_op_def, analyzer_nodes.ExtractPackedCombineMergeOutputs + ): + assert output_index is None + output_index = intermidiate_value.value_index + intermidiate_value = intermidiate_value.parent_operation.inputs[0] + + # If the input is not generated from the final packed merge node, keep + # track of the node for reconstruction of the other inputs. + if ( + intermidiate_value.parent_operation.operation_def.label + != self._final_packed_merge_combine_label + ): + # Store the info needed to reconstruct the input node, including + # CreateTensorBinding node's input value index. 
+ result.append( + _TensorBindingInfo( + intermidiate_post_processing_op_defs, output_index + ) + ) + return result + + def _create_tensor_bindings( + self, to_be_created_tensor_bindings, final_packed_merge_combine + ): + labels_to_new_nodes = {} + + def _maybe_create_node(op_def, inputs): + if op_def.label in labels_to_new_nodes: + return labels_to_new_nodes[op_def.label] + new_node = nodes.OperationNode(op_def, inputs).outputs + labels_to_new_nodes[op_def.label] = new_node + return new_node + + result = [] + if to_be_created_tensor_bindings: + assert final_packed_merge_combine is not None + # Reconstruct the remaining inputs from the final packed merge node. + for tensor_binding_info in to_be_created_tensor_bindings: + intermediate_nodes = (final_packed_merge_combine,) + for op_def in reversed( + tensor_binding_info.intermediate_post_processing_op_defs + ): + intermediate_nodes = _maybe_create_node(op_def, intermediate_nodes) + if isinstance( + op_def, analyzer_nodes.ExtractPackedCombineMergeOutputs + ): + intermediate_nodes = ( + intermediate_nodes[tensor_binding_info.output_index], + ) + # The last node must be a single CreateTensorBinding. + assert len(intermediate_nodes) == 1, intermediate_nodes + assert isinstance( + intermediate_nodes[0].parent_operation.operation_def, + beam_nodes.CreateTensorBinding, + ), intermediate_nodes[0] + result.append(intermediate_nodes[0]) + return result def _update_cache_value_node_references(cache_value_nodes, traverser): - """Updates value node references in the cache.""" - if cache_value_nodes: - cache_value_nodes = { - key: traverser.visit_value_node(value_node) - for key, value_node in cache_value_nodes.items() + """Updates value node references in the cache.""" + if cache_value_nodes: + cache_value_nodes = { + key: traverser.visit_value_node(value_node) + for key, value_node in cache_value_nodes.items() + } + return cache_value_nodes + + +def perform_combiner_packing_optimization( + saved_model_future, cache_value_nodes, num_phases +): + """Optimizes the graph by packing possible combine nodes.""" + # Inspect the graph to identify all the packable combines. + inspect_acc_combine_visitor = _InspectAccumulateCombineVisitor() + inspect_acc_combine_traverser = nodes.Traverser(inspect_acc_combine_visitor) + _ = inspect_acc_combine_traverser.visit_value_node(saved_model_future) + + packable_combines = inspect_acc_combine_visitor.packable_combines + # Do not pack if we have only a single combine in the group. + packable_combines = { + label: group for label, group in packable_combines.items() if len(group) > 1 } - return cache_value_nodes - - -def perform_combiner_packing_optimization(saved_model_future, - cache_value_nodes, num_phases): - """Optimizes the graph by packing possible combine nodes.""" - # Inspect the graph to identify all the packable combines. - inspect_acc_combine_visitor = _InspectAccumulateCombineVisitor() - inspect_acc_combine_traverser = nodes.Traverser(inspect_acc_combine_visitor) - _ = inspect_acc_combine_traverser.visit_value_node(saved_model_future) - - packable_combines = inspect_acc_combine_visitor.packable_combines - # Do not pack if we have only a single combine in the group. 
- packable_combines = { - label: group for label, group in packable_combines.items() - if len(group) > 1 - } - - pack_acc_combine_visitor = _PackAccumulateCombineVisitor(packable_combines) - pack_acc_combine_traverser = nodes.Traverser(pack_acc_combine_visitor) - saved_model_future = pack_acc_combine_traverser.visit_value_node( - saved_model_future) - - # Replace cache nodes to point to the corresponding new nodes. - cache_value_nodes = _update_cache_value_node_references( - cache_value_nodes, pack_acc_combine_traverser) - - # TODO(b/134414978): Consider also packing the merges even when we have - # multiple phases. - if num_phases > 1: - return (saved_model_future, cache_value_nodes) - # Identify the merge combines that can be packed together. - inspect_merge_combine_visitor = _InspectMergeCombineVisitor() - inspect_merge_combine_traverser = nodes.Traverser( - inspect_merge_combine_visitor) - _ = inspect_merge_combine_traverser.visit_value_node(saved_model_future) + pack_acc_combine_visitor = _PackAccumulateCombineVisitor(packable_combines) + pack_acc_combine_traverser = nodes.Traverser(pack_acc_combine_visitor) + saved_model_future = pack_acc_combine_traverser.visit_value_node(saved_model_future) - # Only pack if we have more than one merge combines. - if len(inspect_merge_combine_visitor.packable_combine_extract_outputs) <= 1: - return (saved_model_future, cache_value_nodes) + # Replace cache nodes to point to the corresponding new nodes. + cache_value_nodes = _update_cache_value_node_references( + cache_value_nodes, pack_acc_combine_traverser + ) - # Add flatten and packed merge nodes. - pack_merge_combine_visitor = _PackMergeCombineVisitor( - packable_combine_extract_outputs= - inspect_merge_combine_visitor.packable_combine_extract_outputs) - pack_merge_combine_traverser = nodes.Traverser(pack_merge_combine_visitor) - saved_model_future = pack_merge_combine_traverser.visit_value_node( - saved_model_future) - # Replace cache nodes to point to the corresponding new nodes. - cache_value_nodes = _update_cache_value_node_references( - cache_value_nodes, pack_merge_combine_traverser) - - # Remove redundant flatten and packed merge nodes. - remove_redundant_visitor = _RemoveRedundantPackedMergeCombineVisitor( - final_packed_merge_combine_label= - pack_merge_combine_visitor.final_packed_merge_combine_label) - remove_redundant_traverser = nodes.Traverser(remove_redundant_visitor) - saved_model_future = remove_redundant_traverser.visit_value_node( - saved_model_future) - # Replace cache nodes to point to the corresponding new nodes. - cache_value_nodes = _update_cache_value_node_references( - cache_value_nodes, remove_redundant_traverser) - - return (saved_model_future, cache_value_nodes) + # TODO(b/134414978): Consider also packing the merges even when we have + # multiple phases. + if num_phases > 1: + return (saved_model_future, cache_value_nodes) + + # Identify the merge combines that can be packed together. + inspect_merge_combine_visitor = _InspectMergeCombineVisitor() + inspect_merge_combine_traverser = nodes.Traverser(inspect_merge_combine_visitor) + _ = inspect_merge_combine_traverser.visit_value_node(saved_model_future) + + # Only pack if we have more than one merge combines. + if len(inspect_merge_combine_visitor.packable_combine_extract_outputs) <= 1: + return (saved_model_future, cache_value_nodes) + + # Add flatten and packed merge nodes. 
+ pack_merge_combine_visitor = _PackMergeCombineVisitor( + packable_combine_extract_outputs=inspect_merge_combine_visitor.packable_combine_extract_outputs + ) + pack_merge_combine_traverser = nodes.Traverser(pack_merge_combine_visitor) + saved_model_future = pack_merge_combine_traverser.visit_value_node( + saved_model_future + ) + # Replace cache nodes to point to the corresponding new nodes. + cache_value_nodes = _update_cache_value_node_references( + cache_value_nodes, pack_merge_combine_traverser + ) + + # Remove redundant flatten and packed merge nodes. + remove_redundant_visitor = _RemoveRedundantPackedMergeCombineVisitor( + final_packed_merge_combine_label=pack_merge_combine_visitor.final_packed_merge_combine_label + ) + remove_redundant_traverser = nodes.Traverser(remove_redundant_visitor) + saved_model_future = remove_redundant_traverser.visit_value_node(saved_model_future) + # Replace cache nodes to point to the corresponding new nodes. + cache_value_nodes = _update_cache_value_node_references( + cache_value_nodes, remove_redundant_traverser + ) + + return (saved_model_future, cache_value_nodes) diff --git a/tensorflow_transform/beam/combiner_packing_util_test.py b/tensorflow_transform/beam/combiner_packing_util_test.py index b1ac44a..7888be5 100644 --- a/tensorflow_transform/beam/combiner_packing_util_test.py +++ b/tensorflow_transform/beam/combiner_packing_util_test.py @@ -16,41 +16,43 @@ from unittest import mock import tensorflow as tf + import tensorflow_transform as tft -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -from tensorflow_transform.beam import analysis_graph_builder -from tensorflow_transform.beam import combiner_packing_util -from tensorflow_transform import test_case +from tensorflow_transform import impl_helper, nodes, test_case +from tensorflow_transform.beam import analysis_graph_builder, combiner_packing_util def _preprocessing_fn_with_packable_analyzer_single_phase(inputs): - x, y = inputs['x'], inputs['y'] - x_mean = tft.mean(x, name='x') - x_centered = x - x_mean - y_mean = tft.mean(y, name='y') - y_centered = y - y_mean - z = inputs['z'] - z_vocab = tft.vocabulary(z, name='z') - _ = tft.experimental.approximate_vocabulary(z, top_k=10, name='z_approx') - initializer = tf.lookup.TextFileInitializer( - z_vocab, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - z_integerized = table.lookup(z) - return {'x_centered': x_centered, 'y_centered': y_centered, - 'z_integerized': z_integerized} + x, y = inputs["x"], inputs["y"] + x_mean = tft.mean(x, name="x") + x_centered = x - x_mean + y_mean = tft.mean(y, name="y") + y_centered = y - y_mean + z = inputs["z"] + z_vocab = tft.vocabulary(z, name="z") + _ = tft.experimental.approximate_vocabulary(z, top_k=10, name="z_approx") + initializer = tf.lookup.TextFileInitializer( + z_vocab, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + z_integerized = table.lookup(z) + return { + "x_centered": x_centered, + "y_centered": y_centered, + "z_integerized": z_integerized, + } _PACKABLE_ANALYZER_SINGLE_PHASE_CASE = dict( - testcase_name='with_packable_analyzer_single_phase', + testcase_name="with_packable_analyzer_single_phase", feature_spec={ - 'x': 
tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 'z': tf.io.FixedLenFeature([], tf.string) + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "z": tf.io.FixedLenFeature([], tf.string), }, preprocessing_fn=_preprocessing_fn_with_packable_analyzer_single_phase, num_phases=1, @@ -215,27 +217,28 @@ def _preprocessing_fn_with_packable_analyzer_single_phase(inputs): "CreateTensorBinding[y#mean_and_var#Placeholder]" -> CreateSavedModel; "CreateTensorBinding[y#mean_and_var#Placeholder_1]" -> CreateSavedModel; } -""") +""", +) def _preprocessing_fn_with_packable_analyzer_two_phases(inputs): - x, y = inputs['x'], inputs['y'] - x_mean = tft.mean(x, name='x') - x_square_deviations = tf.square(x - x_mean) - x_var = tft.mean(x_square_deviations, name='x_square_deviations') - x_normalized = (x - x_mean) / tf.sqrt(x_var) - y_mean = tft.mean(y, name='y') - y_square_deviations = tf.square(y - y_mean) - y_var = tft.mean(y_square_deviations, name='y_square_deviations') - y_normalized = (y - y_mean) / tf.sqrt(y_var) - return {'x_normalized': x_normalized, 'y_normalized': y_normalized} + x, y = inputs["x"], inputs["y"] + x_mean = tft.mean(x, name="x") + x_square_deviations = tf.square(x - x_mean) + x_var = tft.mean(x_square_deviations, name="x_square_deviations") + x_normalized = (x - x_mean) / tf.sqrt(x_var) + y_mean = tft.mean(y, name="y") + y_square_deviations = tf.square(y - y_mean) + y_var = tft.mean(y_square_deviations, name="y_square_deviations") + y_normalized = (y - y_mean) / tf.sqrt(y_var) + return {"x_normalized": x_normalized, "y_normalized": y_normalized} _PACKABLE_ANALYZER_TWO_PHASES_CASE = dict( - testcase_name='with_packable_analyzer_two_phases', + testcase_name="with_packable_analyzer_two_phases", feature_spec={ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32) + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), }, preprocessing_fn=_preprocessing_fn_with_packable_analyzer_two_phases, num_phases=2, @@ -384,7 +387,8 @@ def _preprocessing_fn_with_packable_analyzer_two_phases(inputs): "CreateTensorBinding[y_square_deviations#mean_and_var#Placeholder]" -> CreateSavedModel; "CreateTensorBinding[y_square_deviations#mean_and_var#Placeholder_1]" -> CreateSavedModel; } -""") +""", +) _COMBINER_PACKING_TEST_CASES = [ _PACKABLE_ANALYZER_SINGLE_PHASE_CASE, @@ -393,45 +397,55 @@ def _preprocessing_fn_with_packable_analyzer_two_phases(inputs): class CombinerPackingUtilTest(test_case.TransformTestCase): + @test_case.named_parameters(*_COMBINER_PACKING_TEST_CASES) + def test_perform_combiner_packing_optimization( + self, + feature_spec, + preprocessing_fn, + num_phases, + expected_dot_graph_str_before_packing, + expected_dot_graph_str_after_packing, + ): + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, feature_spec, use_tf_compat_v1=True + ) + ) - @test_case.named_parameters(*_COMBINER_PACKING_TEST_CASES) - def test_perform_combiner_packing_optimization( - self, feature_spec, preprocessing_fn, num_phases, - expected_dot_graph_str_before_packing, - expected_dot_graph_str_after_packing): - - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, feature_spec, use_tf_compat_v1=True)) - - def _side_effect_fn(saved_model_future, cache_value_nodes, - unused_num_phases): - return (saved_model_future, cache_value_nodes) + def 
_side_effect_fn(saved_model_future, cache_value_nodes, unused_num_phases): + return (saved_model_future, cache_value_nodes) - with mock.patch.object( - combiner_packing_util, - 'perform_combiner_packing_optimization', - side_effect=_side_effect_fn): - (transform_fn_future_before, unused_cache, - _) = analysis_graph_builder.build(graph, structured_inputs, - structured_outputs) - transform_fn_future_after, unused_cache = ( - combiner_packing_util.perform_combiner_packing_optimization( - transform_fn_future_before, unused_cache, num_phases)) - dot_string_before = nodes.get_dot_graph( - [transform_fn_future_before]).to_string() - self.assertMultiLineEqual( - msg='Prior to optimization dot graph is:\n{}'.format(dot_string_before), - first=dot_string_before, - second=expected_dot_graph_str_before_packing) - dot_string_after = nodes.get_dot_graph( - [transform_fn_future_after]).to_string() - self.WriteRenderedDotFile(dot_string_after) - self.assertMultiLineEqual( - msg='After optimization dot graph is:\n{}'.format(dot_string_after), - first=dot_string_after, - second=expected_dot_graph_str_after_packing) + with mock.patch.object( + combiner_packing_util, + "perform_combiner_packing_optimization", + side_effect=_side_effect_fn, + ): + (transform_fn_future_before, unused_cache, _) = ( + analysis_graph_builder.build( + graph, structured_inputs, structured_outputs + ) + ) + transform_fn_future_after, unused_cache = ( + combiner_packing_util.perform_combiner_packing_optimization( + transform_fn_future_before, unused_cache, num_phases + ) + ) + dot_string_before = nodes.get_dot_graph( + [transform_fn_future_before] + ).to_string() + self.assertMultiLineEqual( + msg=f"Prior to optimization dot graph is:\n{dot_string_before}", + first=dot_string_before, + second=expected_dot_graph_str_before_packing, + ) + dot_string_after = nodes.get_dot_graph([transform_fn_future_after]).to_string() + self.WriteRenderedDotFile(dot_string_after) + self.assertMultiLineEqual( + msg=f"After optimization dot graph is:\n{dot_string_after}", + first=dot_string_after, + second=expected_dot_graph_str_after_packing, + ) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/beam/common.py b/tensorflow_transform/beam/common.py index f5c7b1e..7dad2e3 100644 --- a/tensorflow_transform/beam/common.py +++ b/tensorflow_transform/beam/common.py @@ -17,16 +17,16 @@ import dataclasses import enum import os -from typing import Any, Dict, Mapping, Optional import uuid +from typing import Any, Dict, Mapping, Optional import apache_beam as beam import tensorflow as tf -from tensorflow_transform import common_types -from tensorflow_transform import nodes from tfx_bsl.telemetry import util -METRICS_NAMESPACE = util.MakeTfxNamespace(['Transform']) +from tensorflow_transform import common_types, nodes + +METRICS_NAMESPACE = util.MakeTfxNamespace(["Transform"]) # Depending on the environment, (TF 1.x vs 2.x for e.g.,) we may want to @@ -34,195 +34,209 @@ # tags are used to identify the implementation to use under the current # environment. class EnvironmentTags(enum.Enum): - TF_COMPAT_V1 = 'tf_compat_v1' - TF_V2_ONLY = 'tf_v2_only' + TF_COMPAT_V1 = "tf_compat_v1" + TF_V2_ONLY = "tf_v2_only" _ALLOWED_PTRANSFORM_TAGS = [tag.value for tag in EnvironmentTags] def get_unique_temp_path(base_temp_dir): - """Return a path to a unique temp dir from given base temp dir. + """Return a path to a unique temp dir from given base temp dir. 
- Note this doesn't create the path that it returns. + Note this doesn't create the path that it returns. - Args: - base_temp_dir: A base directory + Args: + ---- + base_temp_dir: A base directory - Returns: - The path name of a subdirectory of base_temp_dir, where the subdirectory is - unique. - """ - return os.path.join(base_temp_dir, uuid.uuid4().hex) + Returns: + ------- + The path name of a subdirectory of base_temp_dir, where the subdirectory is + unique. + """ + return os.path.join(base_temp_dir, uuid.uuid4().hex) class _PtransformWrapper: - """A wrapper around registered implementations of beam nodes.""" - _GENERAL_ENVIRONMENT_TAG = object() + """A wrapper around registered implementations of beam nodes.""" - def __init__(self): - self._ptransform_by_tag = {} + _GENERAL_ENVIRONMENT_TAG = object() - def add_ptransform(self, ptransform_class, tags): - """Add `ptransform_class` for all `tags`.""" - # Many tags can refer to the same ptransform_class, but each - # ptransform_class should be registered only once. - tags = {self._GENERAL_ENVIRONMENT_TAG} if tags is None else tags - assert (tag not in self._ptransform_by_tag for tag in tags) - for tag in tags: - self._ptransform_by_tag[tag] = ptransform_class + def __init__(self): + self._ptransform_by_tag = {} - def get_ptransform(self, tag): - """Retrieves ptransform for `tag`. + def add_ptransform(self, ptransform_class, tags): + """Add `ptransform_class` for all `tags`.""" + # Many tags can refer to the same ptransform_class, but each + # ptransform_class should be registered only once. + tags = {self._GENERAL_ENVIRONMENT_TAG} if tags is None else tags + assert (tag not in self._ptransform_by_tag for tag in tags) + for tag in tags: + self._ptransform_by_tag[tag] = ptransform_class - Args: - tag: A string key (or None) to retrieve corresponding ptransform. + def get_ptransform(self, tag): + """Retrieves ptransform for `tag`. - Returns: - A tuple of a registered beam.PTransform implementation and the tag it was - registered with. + Args: + ---- + tag: A string key (or None) to retrieve corresponding ptransform. - Raises: - KeyError: If no registered PTransform implementation could be found. + Returns: + ------- + A tuple of a registered beam.PTransform implementation and the tag it was + registered with. - """ - if tag is None or tag not in self._ptransform_by_tag: - return self._ptransform_by_tag[self._GENERAL_ENVIRONMENT_TAG], None - return self._ptransform_by_tag[tag], tag.value + Raises: + ------ + KeyError: If no registered PTransform implementation could be found. + + """ + if tag is None or tag not in self._ptransform_by_tag: + return self._ptransform_by_tag[self._GENERAL_ENVIRONMENT_TAG], None + return self._ptransform_by_tag[tag], tag.value -_PTRANSFORM_BY_OPERATION_DEF_SUBCLASS = ( - collections.defaultdict(_PtransformWrapper)) +_PTRANSFORM_BY_OPERATION_DEF_SUBCLASS = collections.defaultdict(_PtransformWrapper) def register_ptransform(operation_def_subclass, tags=None): - """Decorator to register a PTransform as the implementation for an analyzer. + """Decorator to register a PTransform as the implementation for an analyzer. - This function is used to define implementations of the analyzers defined in - tensorflow_transform/analyzer_nodes.py and also the internal operations - defined in tensorflow_transform/beam/beam_nodes.py. 
The registered PTransform - will be invoked as follows: + This function is used to define implementations of the analyzers defined in + tensorflow_transform/analyzer_nodes.py and also the internal operations + defined in tensorflow_transform/beam/beam_nodes.py. The registered PTransform + will be invoked as follows: - outputs = inputs | operation.label >> MyPTransform(operation, extra_args) + outputs = inputs | operation.label >> MyPTransform(operation, extra_args) - where operation is a the instance of the subclass that was registered, - extra_args are global arguments available to each PTransform (see - ConstructBeamPipelineVisitor.extra_args) and `inputs` is a tuple of - PCollections corresponding to the inputs of the OperationNode being - implemented. The return value `outputs` should be a a tuple of PCollections - corresponding to the outputs of the OperationNode. If the OperationNode has - a single output then the return value can also be a PCollection instead of a - tuple. + where operation is a the instance of the subclass that was registered, + extra_args are global arguments available to each PTransform (see + ConstructBeamPipelineVisitor.extra_args) and `inputs` is a tuple of + PCollections corresponding to the inputs of the OperationNode being + implemented. The return value `outputs` should be a a tuple of PCollections + corresponding to the outputs of the OperationNode. If the OperationNode has + a single output then the return value can also be a PCollection instead of a + tuple. - In some cases the implementation cannot be a PTransform and so instead the - value being registered may also be a function. The registered function will - be invoked as follows: + In some cases the implementation cannot be a PTransform and so instead the + value being registered may also be a function. The registered function will + be invoked as follows: - outputs = my_function(inputs, operation, extra_args) + outputs = my_function(inputs, operation, extra_args) - where inputs, operation, extra_args and outputs are the same as for the - PTransform case. + where inputs, operation, extra_args and outputs are the same as for the + PTransform case. - Args: - operation_def_subclass: The class of attributes that is being registered. - Should be a subclass of `tensorflow_transform.nodes.OperationDef`. - tags: A set of string tags belonging to `EnvironmentTags`. If - provided, the PTransform will be registered against all of them. + Args: + ---- + operation_def_subclass: The class of attributes that is being registered. + Should be a subclass of `tensorflow_transform.nodes.OperationDef`. + tags: A set of string tags belonging to `EnvironmentTags`. If + provided, the PTransform will be registered against all of them. - Returns: - A class decorator that registers a PTransform or function as an - implementation of the OperationDef subclass. - """ + Returns: + ------- + A class decorator that registers a PTransform or function as an + implementation of the OperationDef subclass. 
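For orientation, a minimal sketch of the registration contract described above, using hypothetical names (`_MyOperationDef`, `_MyOperationImpl`) that are not part of this change. The decorator itself only checks that the registered class is a `beam.PTransform`; the `(operation, extra_args)` constructor signature is what `ConstructBeamPipelineVisitor.visit` passes when it applies the registered implementation.

import apache_beam as beam

from tensorflow_transform import nodes
from tensorflow_transform.beam import common


class _MyOperationDef(nodes.OperationDef):
    """Hypothetical OperationDef subclass, used here only as the registration key."""


@common.register_ptransform(_MyOperationDef)
class _MyOperationImpl(beam.PTransform):
    """Applied as: outputs = inputs | operation.label >> _MyOperationImpl(operation, extra_args)."""

    def __init__(self, operation, extra_args):
        self._operation = operation
        self._extra_args = extra_args

    def expand(self, inputs):
        # `inputs` is a tuple of PCollections; returning a single PCollection
        # is allowed when the OperationDef has exactly one output.
        (pcoll,) = inputs
        return pcoll | "Identity" >> beam.Map(lambda x: x)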
+ """ - def register(ptransform_class): - assert isinstance(ptransform_class, type) - assert issubclass(ptransform_class, beam.PTransform) - assert tags is None or (tag in _ALLOWED_PTRANSFORM_TAGS for tag in tags) - _PTRANSFORM_BY_OPERATION_DEF_SUBCLASS[ - operation_def_subclass].add_ptransform(ptransform_class, tags) - return ptransform_class + def register(ptransform_class): + assert isinstance(ptransform_class, type) + assert issubclass(ptransform_class, beam.PTransform) + assert tags is None or (tag in _ALLOWED_PTRANSFORM_TAGS for tag in tags) + _PTRANSFORM_BY_OPERATION_DEF_SUBCLASS[operation_def_subclass].add_ptransform( + ptransform_class, tags + ) + return ptransform_class - return register + return register class ConstructBeamPipelineVisitor(nodes.Visitor): - """Visitor that constructs the beam pipeline from the node graph.""" - - @dataclasses.dataclass(frozen=True) - class ExtraArgs: - """Context required in order to construct a TFT beam pipeline.""" - # Some typing below is set to Any to avoid having to add dependencies just - # for the type definitions. - base_temp_dir: str - pipeline: beam.Pipeline - flat_pcollection: Optional[beam.PCollection] - pcollection_dict: Dict[str, beam.PCollection] - tf_config: Any - graph: Any - input_signature: Mapping[str, common_types.TensorType] - input_specs: Mapping[str, Any] - input_tensor_adapter_config: Any - use_tf_compat_v1: bool - cache_pcoll_dict: Optional[Dict[str, beam.PCollection]] - preprocessing_fn: Any - analyzers_fingerprint: Mapping[str, Any] - save_options: tf.saved_model.SaveOptions - - def __init__(self, extra_args): - self._extra_args = extra_args - - def visit(self, operation, inputs): - try: - ptransform_wrapper = ( - _PTRANSFORM_BY_OPERATION_DEF_SUBCLASS[operation.__class__]) - environment_tag = ( - EnvironmentTags.TF_COMPAT_V1 - if self._extra_args.use_tf_compat_v1 else EnvironmentTags.TF_V2_ONLY) - ptransform, tag = ptransform_wrapper.get_ptransform(environment_tag) - except KeyError: - raise ValueError('No implementation for {} was registered'.format( - operation)) - - # TODO(zoyahav): Consider extracting a single PCollection before passing to - # ptransform if len(inputs) == 1. - if tag is None: - tagged_label = operation.label - else: - tagged_label = '{label}[{tag}]'.format(label=operation.label, tag=tag) - try: - outputs = ((inputs or beam.pvalue.PBegin(self._extra_args.pipeline)) - | tagged_label >> ptransform(operation, self._extra_args)) - except Exception as e: - raise RuntimeError('Failed to apply: {}'.format(tagged_label)) from e - - if isinstance(outputs, beam.pvalue.PCollection): - return (outputs,) - else: - return outputs - - def validate_value(self, value): - if not isinstance(value, beam.pvalue.PCollection): - raise TypeError('Expected a PCollection, got {} of type {}'.format( - value, type(value))) + """Visitor that constructs the beam pipeline from the node graph.""" + + @dataclasses.dataclass(frozen=True) + class ExtraArgs: + """Context required in order to construct a TFT beam pipeline.""" + + # Some typing below is set to Any to avoid having to add dependencies just + # for the type definitions. 
+ base_temp_dir: str + pipeline: beam.Pipeline + flat_pcollection: Optional[beam.PCollection] + pcollection_dict: Dict[str, beam.PCollection] + tf_config: Any + graph: Any + input_signature: Mapping[str, common_types.TensorType] + input_specs: Mapping[str, Any] + input_tensor_adapter_config: Any + use_tf_compat_v1: bool + cache_pcoll_dict: Optional[Dict[str, beam.PCollection]] + preprocessing_fn: Any + analyzers_fingerprint: Mapping[str, Any] + save_options: tf.saved_model.SaveOptions + + def __init__(self, extra_args): + self._extra_args = extra_args + + def visit(self, operation, inputs): + try: + ptransform_wrapper = _PTRANSFORM_BY_OPERATION_DEF_SUBCLASS[ + operation.__class__ + ] + environment_tag = ( + EnvironmentTags.TF_COMPAT_V1 + if self._extra_args.use_tf_compat_v1 + else EnvironmentTags.TF_V2_ONLY + ) + ptransform, tag = ptransform_wrapper.get_ptransform(environment_tag) + except KeyError: + raise ValueError(f"No implementation for {operation} was registered") + + # TODO(zoyahav): Consider extracting a single PCollection before passing to + # ptransform if len(inputs) == 1. + if tag is None: + tagged_label = operation.label + else: + tagged_label = f"{operation.label}[{tag}]" + try: + outputs = ( + inputs or beam.pvalue.PBegin(self._extra_args.pipeline) + ) | tagged_label >> ptransform(operation, self._extra_args) + except Exception as e: + raise RuntimeError(f"Failed to apply: {tagged_label}") from e + + if isinstance(outputs, beam.pvalue.PCollection): + return (outputs,) + else: + return outputs + + def validate_value(self, value): + if not isinstance(value, beam.pvalue.PCollection): + raise TypeError( + f"Expected a PCollection, got {value} of type {type(value)}" + ) class IncrementCounter(beam.PTransform): - """A PTransform that increments a counter once per PCollection. - - The output PCollection is the same as the input PCollection. - """ + """A PTransform that increments a counter once per PCollection. - def __init__(self, counter_name): - self._counter_name = counter_name - - def _make_and_increment_counter(self, unused_element): - del unused_element - beam.metrics.Metrics.counter(METRICS_NAMESPACE, self._counter_name).inc() - return None + The output PCollection is the same as the input PCollection. + """ - def expand(self, pcoll): - _ = ( - pcoll.pipeline - | 'CreateSole' >> beam.Create([None]) - | 'Count' >> beam.Map(self._make_and_increment_counter)) - return pcoll + def __init__(self, counter_name): + self._counter_name = counter_name + + def _make_and_increment_counter(self, unused_element): + del unused_element + beam.metrics.Metrics.counter(METRICS_NAMESPACE, self._counter_name).inc() + return + + def expand(self, pcoll): + _ = ( + pcoll.pipeline + | "CreateSole" >> beam.Create([None]) + | "Count" >> beam.Map(self._make_and_increment_counter) + ) + return pcoll diff --git a/tensorflow_transform/beam/context.py b/tensorflow_transform/beam/context.py index 770fb14..3d7f9e5 100644 --- a/tensorflow_transform/beam/context.py +++ b/tensorflow_transform/beam/context.py @@ -19,174 +19,185 @@ from typing import Iterable, Optional import tensorflow as tf + from tensorflow_transform import tf2_utils class Context: - """Context manager for tensorflow-transform. - - All the attributes in this context are kept on a thread local state. - - Attributes: - temp_dir: (Optional) The temporary directory used within in this block. - desired_batch_size: (Optional) A batch size to batch elements by. If not - provided, a batch size will be computed automatically. 
- passthrough_keys: (Optional) A set of strings that are keys to - instances that should pass through the pipeline and be hidden from - the preprocessing_fn. This should only be used in cases where additional - information should be attached to instances in the pipeline which should - not be part of the transformation graph, instance keys is one such - example. - use_deep_copy_optimization: (Optional) If True, makes deep copies of - PCollections that are used in multiple TFT phases. - force_tf_compat_v1: (Optional) If True, TFT's public APIs - (e.g. AnalyzeDataset) will use Tensorflow in compat.v1 mode irrespective - of installed version of Tensorflow. Defaults to `False`. - save_options: (Optional) If set, the tf.saved_model.SaveOptions to save - the transform_fn with. Only applies for TF2. - - Note that the temp dir should be accessible to worker jobs, e.g. if running - with the Cloud Dataflow runner, the temp dir should be on GCS and should have - permissions that allow both launcher and workers to access it. - """ - - @dataclasses.dataclass(frozen=True) - class _State: - """A named tuple to store attributes of `Context`.""" - temp_dir: Optional[str] = None - desired_batch_size: Optional[int] = None - passthrough_keys: Optional[Iterable[str]] = None - use_deep_copy_optimization: Optional[bool] = None - force_tf_compat_v1: Optional[bool] = None - save_options: Optional[tf.saved_model.SaveOptions] = None + """Context manager for tensorflow-transform. + + All the attributes in this context are kept on a thread local state. + + Attributes + ---------- + temp_dir: (Optional) The temporary directory used within in this block. + desired_batch_size: (Optional) A batch size to batch elements by. If not + provided, a batch size will be computed automatically. + passthrough_keys: (Optional) A set of strings that are keys to + instances that should pass through the pipeline and be hidden from + the preprocessing_fn. This should only be used in cases where additional + information should be attached to instances in the pipeline which should + not be part of the transformation graph, instance keys is one such + example. + use_deep_copy_optimization: (Optional) If True, makes deep copies of + PCollections that are used in multiple TFT phases. + force_tf_compat_v1: (Optional) If True, TFT's public APIs + (e.g. AnalyzeDataset) will use Tensorflow in compat.v1 mode irrespective + of installed version of Tensorflow. Defaults to `False`. + save_options: (Optional) If set, the tf.saved_model.SaveOptions to save + the transform_fn with. Only applies for TF2. + + Note that the temp dir should be accessible to worker jobs, e.g. if running + with the Cloud Dataflow runner, the temp dir should be on GCS and should have + permissions that allow both launcher and workers to access it. 
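A short usage sketch of the inheritance behaviour described above (the directories are illustrative and not part of this change): attributes an inner `Context` block leaves unset are taken from the enclosing block, and `create_base_temp_dir()` resolves against the innermost frame.

import tempfile

import tensorflow_transform.beam as tft_beam

outer_dir = tempfile.mkdtemp()
inner_dir = tempfile.mkdtemp()

with tft_beam.Context(temp_dir=outer_dir, desired_batch_size=128):
    with tft_beam.Context(temp_dir=inner_dir):
        # desired_batch_size is inherited from the outer block; temp_dir is
        # overridden by the inner block.
        assert tft_beam.Context.get_desired_batch_size() == 128
        assert tft_beam.Context.create_base_temp_dir().startswith(inner_dir)
    # Back in the outer block, temp_dir reverts to the outer setting.
    assert tft_beam.Context.create_base_temp_dir().startswith(outer_dir)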
+ """ + + @dataclasses.dataclass(frozen=True) + class _State: + """A named tuple to store attributes of `Context`.""" + + temp_dir: Optional[str] = None + desired_batch_size: Optional[int] = None + passthrough_keys: Optional[Iterable[str]] = None + use_deep_copy_optimization: Optional[bool] = None + force_tf_compat_v1: Optional[bool] = None + save_options: Optional[tf.saved_model.SaveOptions] = None + + @classmethod + def make_empty(cls): + """Return `_State` object with all fields set to `None`.""" + return cls(*(None,) * len(dataclasses.fields(cls))) + + class _StateStack: + """Stack of states for this context manager (found in thread-local storage).""" + + def __init__(self): + self.frames = [] + + # TODO(b/36359436) Ensure tf.Transform code only uses consistent filesystem + # operations on Cloud. + _TEMP_SUBDIR = "tftransform_tmp" + + _thread_local = threading.local() + + def __init__( + self, + temp_dir: Optional[str] = None, + desired_batch_size: Optional[int] = None, + passthrough_keys: Optional[Iterable[str]] = None, + use_deep_copy_optimization: Optional[bool] = None, + force_tf_compat_v1: Optional[bool] = None, + save_options: Optional[tf.saved_model.SaveOptions] = None, + ): + state = getattr(self._thread_local, "state", None) + if not state: + self._thread_local.state = self._StateStack() + self._thread_local.state.frames.append( + self._State(*(None,) * len(dataclasses.fields(self._State))) + ) + + self._temp_dir = temp_dir + self._desired_batch_size = desired_batch_size + self._passthrough_keys = passthrough_keys + self._use_deep_copy_optimization = use_deep_copy_optimization + self._force_tf_compat_v1 = force_tf_compat_v1 + self._save_options = save_options + + def __enter__(self): + # Previous State's properties are inherited if not explicitly specified. + last_frame = self._get_topmost_state_frame() + self._thread_local.state.frames.append( + self._State( + temp_dir=self._temp_dir + if self._temp_dir is not None + else last_frame.temp_dir, + desired_batch_size=self._desired_batch_size + if self._desired_batch_size is not None + else last_frame.desired_batch_size, + passthrough_keys=self._passthrough_keys + if self._passthrough_keys is not None + else last_frame.passthrough_keys, + use_deep_copy_optimization=self._use_deep_copy_optimization + if self._use_deep_copy_optimization is not None + else last_frame.use_deep_copy_optimization, + force_tf_compat_v1=self._force_tf_compat_v1 + if self._force_tf_compat_v1 is not None + else last_frame.force_tf_compat_v1, + save_options=self._save_options or last_frame.save_options, + ) + ) + + def __exit__(self, *exn_info): + self._thread_local.state.frames.pop() @classmethod - def make_empty(cls): - """Return `_State` object with all fields set to `None`.""" - return cls(*(None,) * len(dataclasses.fields(cls))) + def _get_topmost_state_frame(cls) -> "Context._State": + if hasattr(cls._thread_local, "state") and cls._thread_local.state.frames: + return cls._thread_local.state.frames[-1] + return cls._State.make_empty() - class _StateStack: - """Stack of states for this context manager (found in thread-local storage). - """ + @classmethod + def create_base_temp_dir(cls) -> str: + """Generate a temporary location.""" + state = cls._get_topmost_state_frame() + if not state.temp_dir: + raise ValueError( + "A tf.Transform function that required a temp dir was called but no " + "temp dir was set. To set a temp dir use the impl.Context context " + "manager." 
+ ) + base_temp_dir = os.path.join(state.temp_dir, cls._TEMP_SUBDIR) + + # TODO(b/35363519): Perhaps use Beam IO eventually? + tf.io.gfile.makedirs(base_temp_dir) + return base_temp_dir + + @classmethod + def get_desired_batch_size(cls) -> Optional[int]: + """Retrieves a user set fixed batch size, None if not set.""" + state = cls._get_topmost_state_frame() + if state.desired_batch_size is not None: + tf.compat.v1.logging.info( + "Using fixed batch size: %d", state.desired_batch_size + ) + return state.desired_batch_size + return None + + @classmethod + def get_passthrough_keys(cls) -> Iterable[str]: + """Retrieves a user set passthrough_keys, None if not set.""" + state = cls._get_topmost_state_frame() + if state.passthrough_keys is not None: + return state.passthrough_keys + return set() + + @classmethod + def get_use_deep_copy_optimization(cls) -> bool: + """Retrieves a user set use_deep_copy_optimization, None if not set.""" + state = cls._get_topmost_state_frame() + if state.use_deep_copy_optimization is not None: + return state.use_deep_copy_optimization + return False + + @classmethod + def _get_force_tf_compat_v1(cls) -> bool: + """Retrieves flag force_tf_compat_v1.""" + state = cls._get_topmost_state_frame() + if state.force_tf_compat_v1 is not None: + return state.force_tf_compat_v1 + return False + + @classmethod + def get_use_tf_compat_v1(cls) -> bool: + """Computes use_tf_compat_v1 from TF environment and force_tf_compat_v1.""" + force_tf_compat_v1 = cls._get_force_tf_compat_v1() + return tf2_utils.use_tf_compat_v1(force_tf_compat_v1) - def __init__(self): - self.frames = [] - - # TODO(b/36359436) Ensure tf.Transform code only uses consistent filesystem - # operations on Cloud. - _TEMP_SUBDIR = 'tftransform_tmp' - - _thread_local = threading.local() - - def __init__(self, - temp_dir: Optional[str] = None, - desired_batch_size: Optional[int] = None, - passthrough_keys: Optional[Iterable[str]] = None, - use_deep_copy_optimization: Optional[bool] = None, - force_tf_compat_v1: Optional[bool] = None, - save_options: Optional[tf.saved_model.SaveOptions] = None): - state = getattr(self._thread_local, 'state', None) - if not state: - self._thread_local.state = self._StateStack() - self._thread_local.state.frames.append( - self._State(*(None,) * len(dataclasses.fields(self._State)))) - - self._temp_dir = temp_dir - self._desired_batch_size = desired_batch_size - self._passthrough_keys = passthrough_keys - self._use_deep_copy_optimization = use_deep_copy_optimization - self._force_tf_compat_v1 = force_tf_compat_v1 - self._save_options = save_options - - def __enter__(self): - # Previous State's properties are inherited if not explicitly specified. 
- last_frame = self._get_topmost_state_frame() - self._thread_local.state.frames.append( - self._State( - temp_dir=self._temp_dir - if self._temp_dir is not None else last_frame.temp_dir, - desired_batch_size=self._desired_batch_size - if self._desired_batch_size is not None else - last_frame.desired_batch_size, - passthrough_keys=self._passthrough_keys if - self._passthrough_keys is not None else last_frame.passthrough_keys, - use_deep_copy_optimization=self._use_deep_copy_optimization - if self._use_deep_copy_optimization is not None else - last_frame.use_deep_copy_optimization, - force_tf_compat_v1=self._force_tf_compat_v1 - if self._force_tf_compat_v1 is not None else - last_frame.force_tf_compat_v1, - save_options=self._save_options or last_frame.save_options)) - - def __exit__(self, *exn_info): - self._thread_local.state.frames.pop() - - @classmethod - def _get_topmost_state_frame(cls) -> 'Context._State': - if hasattr(cls._thread_local, 'state') and cls._thread_local.state.frames: - return cls._thread_local.state.frames[-1] - return cls._State.make_empty() - - @classmethod - def create_base_temp_dir(cls) -> str: - """Generate a temporary location.""" - state = cls._get_topmost_state_frame() - if not state.temp_dir: - raise ValueError( - 'A tf.Transform function that required a temp dir was called but no ' - 'temp dir was set. To set a temp dir use the impl.Context context ' - 'manager.') - base_temp_dir = os.path.join(state.temp_dir, cls._TEMP_SUBDIR) - - # TODO(b/35363519): Perhaps use Beam IO eventually? - tf.io.gfile.makedirs(base_temp_dir) - return base_temp_dir - - @classmethod - def get_desired_batch_size(cls) -> Optional[int]: - """Retrieves a user set fixed batch size, None if not set.""" - state = cls._get_topmost_state_frame() - if state.desired_batch_size is not None: - tf.compat.v1.logging.info('Using fixed batch size: %d', - state.desired_batch_size) - return state.desired_batch_size - return None - - @classmethod - def get_passthrough_keys(cls) -> Iterable[str]: - """Retrieves a user set passthrough_keys, None if not set.""" - state = cls._get_topmost_state_frame() - if state.passthrough_keys is not None: - return state.passthrough_keys - return set() - - @classmethod - def get_use_deep_copy_optimization(cls) -> bool: - """Retrieves a user set use_deep_copy_optimization, None if not set.""" - state = cls._get_topmost_state_frame() - if state.use_deep_copy_optimization is not None: - return state.use_deep_copy_optimization - return False - - @classmethod - def _get_force_tf_compat_v1(cls) -> bool: - """Retrieves flag force_tf_compat_v1.""" - state = cls._get_topmost_state_frame() - if state.force_tf_compat_v1 is not None: - return state.force_tf_compat_v1 - return False - - @classmethod - def get_use_tf_compat_v1(cls) -> bool: - """Computes use_tf_compat_v1 from TF environment and force_tf_compat_v1.""" - force_tf_compat_v1 = cls._get_force_tf_compat_v1() - return tf2_utils.use_tf_compat_v1(force_tf_compat_v1) - - @classmethod - def get_save_options(cls) -> Optional[tf.saved_model.SaveOptions]: - """Retrieves a user set save_options, None if not set.""" - state = cls._get_topmost_state_frame() - if state.save_options is not None: - tf.compat.v1.logging.info('Using save_options: %s', state.save_options) - return state.save_options - return None + @classmethod + def get_save_options(cls) -> Optional[tf.saved_model.SaveOptions]: + """Retrieves a user set save_options, None if not set.""" + state = cls._get_topmost_state_frame() + if state.save_options is not None: + 
tf.compat.v1.logging.info("Using save_options: %s", state.save_options) + return state.save_options + return None diff --git a/tensorflow_transform/beam/context_test.py b/tensorflow_transform/beam/context_test.py index 8115af9..5be7dbe 100644 --- a/tensorflow_transform/beam/context_test.py +++ b/tensorflow_transform/beam/context_test.py @@ -20,25 +20,26 @@ class ContextTest(tft_unit.TransformTestCase): - - def testNestedContextCreateBaseTempDir(self): - - level_1_dir = self.get_temp_dir() - with tft_beam.Context(temp_dir=level_1_dir): - self.assertEqual( - os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), - tft_beam.Context.create_base_temp_dir()) - level_2_dir = self.get_temp_dir() - with tft_beam.Context(temp_dir=level_2_dir): - self.assertEqual( - os.path.join(level_2_dir, tft_beam.Context._TEMP_SUBDIR), - tft_beam.Context.create_base_temp_dir()) - self.assertEqual( - os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), - tft_beam.Context.create_base_temp_dir()) - with self.assertRaises(ValueError): - tft_beam.Context.create_base_temp_dir() - - -if __name__ == '__main__': - tft_unit.main() + def testNestedContextCreateBaseTempDir(self): + level_1_dir = self.get_temp_dir() + with tft_beam.Context(temp_dir=level_1_dir): + self.assertEqual( + os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), + tft_beam.Context.create_base_temp_dir(), + ) + level_2_dir = self.get_temp_dir() + with tft_beam.Context(temp_dir=level_2_dir): + self.assertEqual( + os.path.join(level_2_dir, tft_beam.Context._TEMP_SUBDIR), + tft_beam.Context.create_base_temp_dir(), + ) + self.assertEqual( + os.path.join(level_1_dir, tft_beam.Context._TEMP_SUBDIR), + tft_beam.Context.create_base_temp_dir(), + ) + with self.assertRaises(ValueError): + tft_beam.Context.create_base_temp_dir() + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/deep_copy.py b/tensorflow_transform/beam/deep_copy.py index 6c67cee..1207502 100644 --- a/tensorflow_transform/beam/deep_copy.py +++ b/tensorflow_transform/beam/deep_copy.py @@ -33,212 +33,230 @@ from apache_beam.pvalue import PCollection from apache_beam.transforms import resources - -_MATERIALIZATION_BARRIER_TRANSFORMS = set([ - beam.GroupByKey, - # CombinePerKey is included here to allow combiner lifting to occur. - beam.CombinePerKey, -]) +_MATERIALIZATION_BARRIER_TRANSFORMS = set( + [ + beam.GroupByKey, + # CombinePerKey is included here to allow combiner lifting to occur. + beam.CombinePerKey, + ] +) def _is_at_materialization_boundary(pcollection): - """Determines whether a PCollection is at a materialization boundary.""" - # Ascend the hierarchy of composite PTransforms. In the Beam pipeline - # graph, each AppliedPTransform has its "composite parent" stored in its - # .parent field. Here, we check to see whether, at any of the composite - # levels, the transform is a materialization boundary. - current = pcollection.producer - while current: - if (current.transform and - current.transform.__class__ in _MATERIALIZATION_BARRIER_TRANSFORMS): - return True - if pcollection in current.parent.outputs.values(): - current = current.parent - else: - break - return False + """Determines whether a PCollection is at a materialization boundary.""" + # Ascend the hierarchy of composite PTransforms. In the Beam pipeline + # graph, each AppliedPTransform has its "composite parent" stored in its + # .parent field. Here, we check to see whether, at any of the composite + # levels, the transform is a materialization boundary. 
+ current = pcollection.producer + while current: + if ( + current.transform + and current.transform.__class__ in _MATERIALIZATION_BARRIER_TRANSFORMS + ): + return True + if pcollection in current.parent.outputs.values(): + current = current.parent + else: + break + return False def _get_items_to_clone(pcollection): - """Get dependency-sorted list of PCollections and PTransforms to clone. - - This method returns a list of items, each of which is either a PCollection or - PTransform, that need to be cloned when creating a deep copy. This list is - sorted in dependency order, i.e. each PCollection or PTransform in the list - occurs before any of its downstream consumers. - - Args: - pcollection: PCollection to be deep-copied. - - Returns: - A dependency-sorted list of PCollections and PTransforms to clone. - - Raises: - ValueError: if the input PCollection is invalid. - """ - assert isinstance(pcollection, PCollection) - # List of items (either PCollection or PTransform, in reverse dependency - # order (i.e. here, consumers occur before producers). - reversed_to_clone = [] - # Queue of PCollections to be processed in traversal of pipeline graph. - to_process = queue.Queue() - # Set of items (PCollections and PTransforms) already seen during pipeline - # graph traversal. - seen = set() - - to_process.put(pcollection) - seen.add(pcollection) - while not to_process.empty(): - current_pcollection = to_process.get() - - # Stop if we have reached the beginning of the pipeline, or at a - # materialization boundary. - if (isinstance(current_pcollection, pvalue.PBegin) or - _is_at_materialization_boundary(current_pcollection)): - continue - - reversed_to_clone.append(current_pcollection) - applied_transform = current_pcollection.producer - if applied_transform is None: - raise ValueError( - 'PCollection node has invalid producer: %s' % current_pcollection) - - # Visit the input PCollection(s), and also add other outputs of that applied - # PTransform. - if applied_transform in seen: - continue - for output in applied_transform.outputs.values(): - assert isinstance(output, PCollection), output - if output not in seen: - reversed_to_clone.append(output) - seen.add(applied_transform) - reversed_to_clone.append(applied_transform) - for input_pcollection in applied_transform.inputs: - if input_pcollection not in seen: - to_process.put(input_pcollection) - seen.add(input_pcollection) - - return list(reversed(reversed_to_clone)) + """Get dependency-sorted list of PCollections and PTransforms to clone. + + This method returns a list of items, each of which is either a PCollection or + PTransform, that need to be cloned when creating a deep copy. This list is + sorted in dependency order, i.e. each PCollection or PTransform in the list + occurs before any of its downstream consumers. + + Args: + ---- + pcollection: PCollection to be deep-copied. + + Returns: + ------- + A dependency-sorted list of PCollections and PTransforms to clone. + + Raises: + ------ + ValueError: if the input PCollection is invalid. + """ + assert isinstance(pcollection, PCollection) + # List of items (either PCollection or PTransform, in reverse dependency + # order (i.e. here, consumers occur before producers). + reversed_to_clone = [] + # Queue of PCollections to be processed in traversal of pipeline graph. + to_process = queue.Queue() + # Set of items (PCollections and PTransforms) already seen during pipeline + # graph traversal. 
+ seen = set() + + to_process.put(pcollection) + seen.add(pcollection) + while not to_process.empty(): + current_pcollection = to_process.get() + + # Stop if we have reached the beginning of the pipeline, or at a + # materialization boundary. + if isinstance( + current_pcollection, pvalue.PBegin + ) or _is_at_materialization_boundary(current_pcollection): + continue + + reversed_to_clone.append(current_pcollection) + applied_transform = current_pcollection.producer + if applied_transform is None: + raise ValueError( + "PCollection node has invalid producer: %s" % current_pcollection + ) + + # Visit the input PCollection(s), and also add other outputs of that applied + # PTransform. + if applied_transform in seen: + continue + for output in applied_transform.outputs.values(): + assert isinstance(output, PCollection), output + if output not in seen: + reversed_to_clone.append(output) + seen.add(applied_transform) + reversed_to_clone.append(applied_transform) + for input_pcollection in applied_transform.inputs: + if input_pcollection not in seen: + to_process.put(input_pcollection) + seen.add(input_pcollection) + + return list(reversed(reversed_to_clone)) def _clone_items(pipeline, to_clone): - """Clones dependency-sorted list of PCollections and PTransforms. - - Returns mappings of PCollection and PTransform replacements. - - Args: - pipeline: The beam.Pipeline. - to_clone: A dependency-sorted list of PCollections and PTransforms. - - Returns: - pcollection_replacements: a dict mapping original to cloned PCollections. - - Raises: - ValueError: if a clone is requested of an invalid object. - """ - pcollection_replacements = {} - ptransform_replacements = {} - for item in to_clone: - if isinstance(item, pvalue.PCollection): - assert item not in pcollection_replacements - copied = pvalue.PCollection(pipeline, tag=item.tag, - element_type=item.element_type, - windowing=item.windowing) - copied.producer = item.producer - # Update copied PCollection producer if its producer was copied as well. - if copied.producer in ptransform_replacements: - original_producer = copied.producer - copied.producer = ptransform_replacements[original_producer] - # Update producer outputs, - for tag, output in original_producer.outputs.items(): - if output == item: - copied.producer.outputs[tag] = copied - assert copied.producer.transform is not None - pcollection_replacements[item] = copied - elif isinstance(item, beam_pipeline.AppliedPTransform): - assert item.transform is not None - assert item not in ptransform_replacements - # The Beam pipeline graph keeps track of composite PTransforms by having - # AppliedPTransform.parts be a list of "children" AppliedPTransforms that - # are part of the "parent" AppliedPTransform. Any of these "composite - # wrapper" AppliedPTransforms does not actually produce output independent - # of the child non-composite transform. We therefore shouldn't ever clone - # AppliedPTransforms with non-empty parts, since such AppliedPTransforms - # are not reachable by tracing outputs in the pipeline graph. - assert not item.parts, ( - 'Reached invalid composite AppliedPTransform: %r.' % item) - - # TODO(b/217271822): Implement resource hint 'close to resources' for - # Beam/Dataflow, as when CSE makes it to Dataflow, 'close to resources' - # cannot be recognized. Once this is fixed, we can change the tag prefix - # to 'beam'. - # TODO(b/238243699): Obviate the need for setting 'close to resources' - # hints. 
- close_to_resources_available = resources.ResourceHint.is_registered( - 'close_to_resources') - - if close_to_resources_available: - # Assign close_to_resources resource hint to the orginal PTransforms. - # The reason of adding this annotation is to prevent root Reads that are - # generated from deep copy being merged due to common subexpression - # elimination (CSE). - item.resource_hints['beam:resources:close_to_resources:v1'] = ( - b'/fake/DeepCopy.Original[0]') - - # Assign new label. - count = 0 - copy_suffix = f'Copy{count}' - new_label = f'{item.full_label}.{copy_suffix}' - while new_label in pipeline.applied_labels: - count += 1 - copy_suffix = f'Copy{count}' - new_label = f'{item.full_label}.{copy_suffix}' - pipeline.applied_labels.add(new_label) - - # Update inputs. - new_inputs = { - tag: pcollection_replacements.get(old_input, old_input) - for tag, old_input in item.main_inputs.items() - } - - # Create the copy. Note that in the copy, copied.outputs will start out - # empty. Any outputs that are used will be repopulated in the PCollection - # copy branch above. - maybe_int = lambda s: int(s) if re.match(r'^\d+$', s) else s - semver = tuple(maybe_int(s) for s in beam.__version__.split('.')) - if semver >= (2, 63): - extra_args = { - 'environment_id': item.environment_id, - 'annotations': item.annotations, - } - else: - extra_args = {} - copied = beam_pipeline.AppliedPTransform( - item.parent, item.transform, new_label, new_inputs, **extra_args - ) - - # Add a 'close to resource' resource hint to the copied PTransforms. The - # PTransforms that are generated from each deep copy have the same unique - # 'close to resource' resource hint. This is to make sure that the - # PTransforms that are cloned from each deep copy can be fused together, - # but not across copies nor with the original. - if close_to_resources_available: - copied.resource_hints['beam:resources:close_to_resources:v1'] = ( - f'/fake/DeepCopy.{copy_suffix}[0]'.encode()) - - ptransform_replacements[item] = copied - - # Update composite transform parent to include this copy. - # TODO(b/111366378): Reconcile the composite PTransform nesting hierarchy, - # especially in the case where copied PTransforms should be copied in an - # "all-or-nothing" manner. This would allow the deep copy operation to be - # safe in the case runners replace well-known composite PTransforms in - # their entirety during execution. - copied.parent.parts.append(copied) - else: - raise ValueError('Invalid object to clone: %s' % item) - - return pcollection_replacements + """Clones dependency-sorted list of PCollections and PTransforms. + + Returns mappings of PCollection and PTransform replacements. + + Args: + ---- + pipeline: The beam.Pipeline. + to_clone: A dependency-sorted list of PCollections and PTransforms. + + Returns: + ------- + pcollection_replacements: a dict mapping original to cloned PCollections. + + Raises: + ------ + ValueError: if a clone is requested of an invalid object. + """ + pcollection_replacements = {} + ptransform_replacements = {} + for item in to_clone: + if isinstance(item, pvalue.PCollection): + assert item not in pcollection_replacements + copied = pvalue.PCollection( + pipeline, + tag=item.tag, + element_type=item.element_type, + windowing=item.windowing, + ) + copied.producer = item.producer + # Update copied PCollection producer if its producer was copied as well. 
+            if copied.producer in ptransform_replacements:
+                original_producer = copied.producer
+                copied.producer = ptransform_replacements[original_producer]
+                # Update producer outputs.
+                for tag, output in original_producer.outputs.items():
+                    if output == item:
+                        copied.producer.outputs[tag] = copied
+            assert copied.producer.transform is not None
+            pcollection_replacements[item] = copied
+        elif isinstance(item, beam_pipeline.AppliedPTransform):
+            assert item.transform is not None
+            assert item not in ptransform_replacements
+            # The Beam pipeline graph keeps track of composite PTransforms by having
+            # AppliedPTransform.parts be a list of "children" AppliedPTransforms that
+            # are part of the "parent" AppliedPTransform. Any of these "composite
+            # wrapper" AppliedPTransforms does not actually produce output independent
+            # of the child non-composite transform. We therefore shouldn't ever clone
+            # AppliedPTransforms with non-empty parts, since such AppliedPTransforms
+            # are not reachable by tracing outputs in the pipeline graph.
+            assert not item.parts, (
+                "Reached invalid composite AppliedPTransform: %r." % item
+            )
+
+            # TODO(b/217271822): Implement resource hint 'close to resources' for
+            # Beam/Dataflow, as when CSE makes it to Dataflow, 'close to resources'
+            # cannot be recognized. Once this is fixed, we can change the tag prefix
+            # to 'beam'.
+            # TODO(b/238243699): Obviate the need for setting 'close to resources'
+            # hints.
+            close_to_resources_available = resources.ResourceHint.is_registered(
+                "close_to_resources"
+            )
+
+            if close_to_resources_available:
+                # Assign close_to_resources resource hint to the original PTransforms.
+                # The reason for adding this annotation is to prevent root Reads that are
+                # generated from deep copy being merged due to common subexpression
+                # elimination (CSE).
+                item.resource_hints["beam:resources:close_to_resources:v1"] = (
+                    b"/fake/DeepCopy.Original[0]"
+                )
+
+            # Assign new label.
+            count = 0
+            copy_suffix = f"Copy{count}"
+            new_label = f"{item.full_label}.{copy_suffix}"
+            while new_label in pipeline.applied_labels:
+                count += 1
+                copy_suffix = f"Copy{count}"
+                new_label = f"{item.full_label}.{copy_suffix}"
+            pipeline.applied_labels.add(new_label)
+
+            # Update inputs.
+            new_inputs = {
+                tag: pcollection_replacements.get(old_input, old_input)
+                for tag, old_input in item.main_inputs.items()
+            }
+
+            # Create the copy. Note that in the copy, copied.outputs will start out
+            # empty. Any outputs that are used will be repopulated in the PCollection
+            # copy branch above.
+            maybe_int = lambda s: int(s) if re.match(r"^\d+$", s) else s
+            semver = tuple(maybe_int(s) for s in beam.__version__.split("."))
+            if semver >= (2, 63):
+                extra_args = {
+                    "environment_id": item.environment_id,
+                    "annotations": item.annotations,
+                }
+            else:
+                extra_args = {}
+            copied = beam_pipeline.AppliedPTransform(
+                item.parent, item.transform, new_label, new_inputs, **extra_args
+            )
+
+            # Add a 'close to resource' resource hint to the copied PTransforms. The
+            # PTransforms that are generated from each deep copy have the same unique
+            # 'close to resource' resource hint. This is to make sure that the
+            # PTransforms that are cloned from each deep copy can be fused together,
+            # but not across copies nor with the original.
+            if close_to_resources_available:
+                copied.resource_hints["beam:resources:close_to_resources:v1"] = (
+                    f"/fake/DeepCopy.{copy_suffix}[0]".encode()
+                )
+
+            ptransform_replacements[item] = copied
+
+            # Update composite transform parent to include this copy.
+ # TODO(b/111366378): Reconcile the composite PTransform nesting hierarchy, + # especially in the case where copied PTransforms should be copied in an + # "all-or-nothing" manner. This would allow the deep copy operation to be + # safe in the case runners replace well-known composite PTransforms in + # their entirety during execution. + copied.parent.parts.append(copied) + else: + raise ValueError("Invalid object to clone: %s" % item) + + return pcollection_replacements # TODO(ccy): When this method is written as a PTransform, the resulting Beam @@ -246,19 +264,21 @@ def _clone_items(pipeline, to_clone): # runner cannot interpret. When this is fixed, we should express this method # as a DeepCopy PTransform. def deep_copy(pcollection): - """Create a deep copy of a PCollection up to materialization boundaries.""" - if not isinstance(pcollection, pvalue.PCollection): - raise ValueError('Input to deep_copy must be a PCollection.') - - # AppliedPTransform.update_input_refcounts() is a vestigial method that - # uses an incorrect heuristic; it will be removed in a future version of - # Beam, since its results aren't used anyway. Until then, we work around - # this (see https://issues.apache.org/jira/browse/BEAM-4593). - if getattr(beam_pipeline.AppliedPTransform, - 'update_input_refcounts', None) is not None: - beam_pipeline.AppliedPTransform.update_input_refcounts = lambda _: None - - to_clone = _get_items_to_clone(pcollection) - pcollection_replacements = _clone_items(pcollection.pipeline, to_clone) - - return pcollection_replacements[pcollection] + """Create a deep copy of a PCollection up to materialization boundaries.""" + if not isinstance(pcollection, pvalue.PCollection): + raise ValueError("Input to deep_copy must be a PCollection.") + + # AppliedPTransform.update_input_refcounts() is a vestigial method that + # uses an incorrect heuristic; it will be removed in a future version of + # Beam, since its results aren't used anyway. Until then, we work around + # this (see https://issues.apache.org/jira/browse/BEAM-4593). + if ( + getattr(beam_pipeline.AppliedPTransform, "update_input_refcounts", None) + is not None + ): + beam_pipeline.AppliedPTransform.update_input_refcounts = lambda _: None + + to_clone = _get_items_to_clone(pcollection) + pcollection_replacements = _clone_items(pcollection.pipeline, to_clone) + + return pcollection_replacements[pcollection] diff --git a/tensorflow_transform/beam/deep_copy_test.py b/tensorflow_transform/beam/deep_copy_test.py index 374463d..9f8be00 100644 --- a/tensorflow_transform/beam/deep_copy_test.py +++ b/tensorflow_transform/beam/deep_copy_test.py @@ -14,321 +14,353 @@ """Unit tests for tensorflow_transform.beam.deep_copy.""" import collections +import unittest import apache_beam as beam from apache_beam import pvalue from apache_beam.transforms import resources -from tensorflow_transform.beam import deep_copy -from tensorflow_transform.beam import test_helpers -import unittest + +from tensorflow_transform.beam import deep_copy, test_helpers # pylint: disable=g-long-lambda class DeepCopyTest(unittest.TestCase): - - @staticmethod - def _MakeBeamPipeline(): - return beam.Pipeline(**test_helpers.make_test_beam_pipeline_kwargs()) - - # _CountingIdentityFn and _InitializeCounts are declared as class-level - # methods to avoid Beam serialization issues, which would occur if an - # individual object instance were referenced in a lambda. 
In such a case, - # the object would be serialized and deserialized, so that mutations would - # not be propagated correctly for the subsequent verification step. - @staticmethod - def _CountingIdentityFn(label, x): - DeepCopyTest._counts[label] += 1 - return x - - @staticmethod - def _MakeAdd1CountingIdentityFn(label): - - def Add1CountingIdentityFn(x_y): - (x, y) = x_y - return DeepCopyTest._CountingIdentityFn(label, (x + 1, y)) - - return Add1CountingIdentityFn - - @staticmethod - def _InitializeCounts(): - DeepCopyTest._counts = collections.defaultdict(int) - - def setUp(self): - DeepCopyTest._InitializeCounts() - - def testBasicDeepCopy(self): - with DeepCopyTest._MakeBeamPipeline() as p: - grouped = (p - | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')]) - | beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'PreGroup', x)) - | beam.GroupByKey()) - modified = ( - grouped - | - 'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1')) - | - 'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2'))) - copied = deep_copy.deep_copy(modified) - - # pylint: disable=expression-not-assigned - modified | 'Add3' >> beam.Map( - DeepCopyTest._MakeAdd1CountingIdentityFn('Add3')) - # pylint: enable=expression-not-assigned - - # Check labels. - self.assertEqual(copied.producer.full_label, 'Add2.Copy0') - self.assertEqual(copied.producer.inputs[0].producer.full_label, - 'Add1.Copy0') - - # Check that deep copy was performed. - self.assertIsNot(copied.producer.inputs[0], modified.producer.inputs[0]) - - # Check that copy stops at materialization boundary. - self.assertIs(copied.producer.inputs[0].producer.inputs[0], - modified.producer.inputs[0].producer.inputs[0]) - - # Check counts of processed items. - self.assertEqual(DeepCopyTest._counts['PreGroup'], 3) - self.assertEqual(DeepCopyTest._counts['Add1'], 6) - self.assertEqual(DeepCopyTest._counts['Add2'], 6) - self.assertEqual(DeepCopyTest._counts['Add3'], 3) - - def testMultipleCopies(self): - with DeepCopyTest._MakeBeamPipeline() as p: - grouped = (p - | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')]) - | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn( - 'PreGroup', x)) - | beam.GroupByKey()) - modified = ( - grouped - | - 'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1')) - | - 'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2'))) - - num_copies = 6 - - for i in range(num_copies): - copied = deep_copy.deep_copy(modified) - self.assertEqual(copied.producer.full_label, 'Add2.Copy%d' % i) - self.assertEqual(copied.producer.inputs[0].producer.full_label, - 'Add1.Copy%d' % i) - - self.assertEqual(DeepCopyTest._counts['PreGroup'], 3) - self.assertEqual(DeepCopyTest._counts['Add1'], 3 * (num_copies + 1)) - self.assertEqual(DeepCopyTest._counts['Add2'], 3 * (num_copies + 1)) - - def testFlatten(self): - with DeepCopyTest._MakeBeamPipeline() as p: - create_1 = p | 'Create1' >> beam.Create([(1, 'a'), (2, 'b')]) - create_2 = p | 'Create2' >> beam.Create([(3, 'c')]) - created = (create_1, create_2) | 'Flatten1' >> beam.Flatten() - grouped1 = (created - | 'PreGroup1' >> beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'PreGroup1', x)) - | 'GBK1' >> beam.GroupByKey()) - grouped2 = (p - | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')]) - | 'PreGroup2' >> beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'PreGroup2', x)) - | 'GBK2' >> beam.GroupByKey()) - modified1 = ( - grouped1 - | - 'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1'))) - modified2 = ( - grouped2 - | - 
'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2'))) - flattened = (modified1, modified2) | 'Flatten2' >> beam.Flatten() - modified3 = ( - flattened - | - 'Add3' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add3'))) - - copied = deep_copy.deep_copy(modified3) - - # Check that deep copy was performed. - self.assertIsNot(copied.producer.inputs[0], modified3.producer.inputs[0]) - self.assertIsNot(copied.producer.inputs[0].producer.inputs[0], - modified3.producer.inputs[0].producer.inputs[0]) - self.assertIsNot(copied.producer.inputs[0].producer.inputs[1], - modified3.producer.inputs[0].producer.inputs[1]) - - # Check that copy stops at materialization boundary. - self.assertIs( - copied.producer.inputs[0].producer.inputs[0].producer.inputs[0], - modified3.producer.inputs[0].producer.inputs[0].producer.inputs[0]) - self.assertIs( - copied.producer.inputs[0].producer.inputs[1].producer.inputs[0], - modified3.producer.inputs[0].producer.inputs[1].producer.inputs[0]) - - # Check counts of processed items. - self.assertEqual(DeepCopyTest._counts['PreGroup1'], 3) - self.assertEqual(DeepCopyTest._counts['PreGroup2'], 3) - self.assertEqual(DeepCopyTest._counts['Add1'], 6) - self.assertEqual(DeepCopyTest._counts['Add2'], 6) - self.assertEqual(DeepCopyTest._counts['Add3'], 12) - - def testEachPTransformCopiedOnce(self): - with DeepCopyTest._MakeBeamPipeline() as p: - created = p | 'Create1' >> beam.Create([(1, 'a'), (2, 'b')]) - modified1 = (created - | 'Transform1' >> beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'Transform1', x))) - partition_fn = lambda element, partitions: element[0] % partitions - p1, p2 = (modified1 - | 'Partition' >> beam.Partition(partition_fn, 2)) - merged = (p1, p2) | 'Flatten1' >> beam.Flatten() - modified2 = (merged - | 'Transform2' >> beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'Transform2', x))) - - copied = deep_copy.deep_copy(modified2) - - # Check that deep copy was performed. - self.assertIsNot(copied.producer.inputs[0], modified2.producer.inputs[0]) - self.assertIsNot(copied.producer.inputs[0].producer.inputs[0], - modified2.producer.inputs[0].producer.inputs[0]) - self.assertIsNot(copied.producer.inputs[0].producer.inputs[1], - modified2.producer.inputs[0].producer.inputs[1]) - - # Check counts of processed items. - self.assertEqual(DeepCopyTest._counts['Transform1'], 4) - self.assertEqual(DeepCopyTest._counts['Transform2'], 4) - - def testCombineGlobally(self): - with DeepCopyTest._MakeBeamPipeline() as p: - combined = (p - | beam.Create([1, 2, 3]) - | beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'PreCombine', x)) - | beam.WindowInto(beam.window.FixedWindows(5, 0)) - | beam.CombineGlobally( - beam.transforms.combiners.MeanCombineFn() - ).without_defaults() - | beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'PostCombine', x))) - copied = deep_copy.deep_copy(combined) - - # Check that deep copy was performed. - self.assertIsNot(combined, copied) - self.assertIsNot(combined.producer.inputs[0], copied.producer.inputs[0]) - self.assertEqual(combined.producer.inputs[0].producer.full_label, - 'CombineGlobally(MeanCombineFn)/UnKey') - self.assertEqual(copied.producer.inputs[0].producer.full_label, - 'CombineGlobally(MeanCombineFn)/UnKey.Copy0') - - # Check that deep copy stops at materialization boundary. 
- self.assertIs(combined.producer.inputs[0].producer.inputs[0], - copied.producer.inputs[0].producer.inputs[0]) - self.assertEqual( - str(combined.producer.inputs[0].producer.inputs[0]), - ('PCollection[CombineGlobally(MeanCombineFn)/CombinePerKey/Combine/' - 'ParDo(CombineValuesDoFn).None]')) - self.assertIs(combined.producer.inputs[0].producer.inputs[0].producer, - copied.producer.inputs[0].producer.inputs[0].producer) - self.assertEqual( - copied.producer.inputs[0].producer.inputs[0].producer.full_label, - ('CombineGlobally(MeanCombineFn)/CombinePerKey/Combine/' - 'ParDo(CombineValuesDoFn)')) - - # Check counts of processed items. - self.assertEqual(DeepCopyTest._counts['PreCombine'], 3) - self.assertEqual(DeepCopyTest._counts['PostCombine'], 2) - - def testSideInputNotCopied(self): - with DeepCopyTest._MakeBeamPipeline() as p: - side = (p - | 'CreateSide' >> beam.Create(['s1', 's2', 's3']) - | beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'SideInput', x))) - main = (p - | 'CreateMain' >> beam.Create([1, 2, 3]) - | beam.Map( - lambda x: DeepCopyTest._CountingIdentityFn( - 'Main', x)) - | beam.Map(lambda e, s: (e, list(s)), - pvalue.AsList(side))) - copied = deep_copy.deep_copy(main) - - # Check that deep copy was performed. - self.assertIsNot(main, copied) - self.assertIsNot(main.producer, copied.producer) - - # Check that deep copy stops at the side input materialization boundary. - self.assertIs(main.producer.side_inputs[0], - copied.producer.side_inputs[0]) - self.assertIs(main.producer.side_inputs[0].pvalue, side) - - # Check counts of processed items. - self.assertEqual(DeepCopyTest._counts['SideInput'], 3) - self.assertEqual(DeepCopyTest._counts['Main'], 6) - - def testDeepCopyTags(self): - if not resources.ResourceHint.is_registered('tags'): - self.skipTest('Resource hint tags are not available.') - - with DeepCopyTest._MakeBeamPipeline() as p: - grouped = ( - p | beam.Create([(1, 'a'), (2, 'b'), (3, 'c')]) - | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn('PreGroup', x))) - - modified = ( - grouped - | - 'Add1' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add1')) - | - 'Add2' >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn('Add2'))) - - num_copies = 6 - - for i in range(num_copies): - copied = deep_copy.deep_copy(modified) - # Check labels. - self.assertEqual(copied.producer.full_label, 'Add2.Copy%d' % i) - self.assertEqual(copied.producer.inputs[0].producer.full_label, - 'Add1.Copy%d' % i) - - # Check resource hints. - self.assertEqual(modified.producer.resource_hints, { - 'beam:resources:close_to_resources:v1': - b'/fake/DeepCopy.Original[0]' - }) - self.assertEqual(modified.producer.inputs[0].producer.resource_hints, { - 'beam:resources:close_to_resources:v1': - b'/fake/DeepCopy.Original[0]' - }) - self.assertEqual(copied.producer.resource_hints, { - 'beam:resources:close_to_resources:v1': - b'/fake/DeepCopy.Copy%d[0]' % i - }) - self.assertEqual(copied.producer.inputs[0].producer.resource_hints, { - 'beam:resources:close_to_resources:v1': - b'/fake/DeepCopy.Copy%d[0]' % i - }) - - # pylint: disable=expression-not-assigned - modified | 'Add3' >> beam.Map( - DeepCopyTest._MakeAdd1CountingIdentityFn('Add3')) - # pylint: enable=expression-not-assigned - - # Check counts of processed items. Without the materialization boundary, - # e.g. GroupByKey, PreGroup is also copied. 
- self.assertEqual(DeepCopyTest._counts['PreGroup'], 3 * (num_copies + 1)) - self.assertEqual(DeepCopyTest._counts['Add1'], 3 * (num_copies + 1)) - self.assertEqual(DeepCopyTest._counts['Add2'], 3 * (num_copies + 1)) - self.assertEqual(DeepCopyTest._counts['Add3'], 3) - -if __name__ == '__main__': - unittest.main() + @staticmethod + def _MakeBeamPipeline(): + return beam.Pipeline(**test_helpers.make_test_beam_pipeline_kwargs()) + + # _CountingIdentityFn and _InitializeCounts are declared as class-level + # methods to avoid Beam serialization issues, which would occur if an + # individual object instance were referenced in a lambda. In such a case, + # the object would be serialized and deserialized, so that mutations would + # not be propagated correctly for the subsequent verification step. + @staticmethod + def _CountingIdentityFn(label, x): + DeepCopyTest._counts[label] += 1 + return x + + @staticmethod + def _MakeAdd1CountingIdentityFn(label): + def Add1CountingIdentityFn(x_y): + (x, y) = x_y + return DeepCopyTest._CountingIdentityFn(label, (x + 1, y)) + + return Add1CountingIdentityFn + + @staticmethod + def _InitializeCounts(): + DeepCopyTest._counts = collections.defaultdict(int) + + def setUp(self): + DeepCopyTest._InitializeCounts() + + def testBasicDeepCopy(self): + with DeepCopyTest._MakeBeamPipeline() as p: + grouped = ( + p + | beam.Create([(1, "a"), (2, "b"), (3, "c")]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreGroup", x)) + | beam.GroupByKey() + ) + modified = ( + grouped + | "Add1" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add1")) + | "Add2" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add2")) + ) + copied = deep_copy.deep_copy(modified) + + # pylint: disable=expression-not-assigned + modified | "Add3" >> beam.Map( + DeepCopyTest._MakeAdd1CountingIdentityFn("Add3") + ) + # pylint: enable=expression-not-assigned + + # Check labels. + self.assertEqual(copied.producer.full_label, "Add2.Copy0") + self.assertEqual( + copied.producer.inputs[0].producer.full_label, "Add1.Copy0" + ) + + # Check that deep copy was performed. + self.assertIsNot(copied.producer.inputs[0], modified.producer.inputs[0]) + + # Check that copy stops at materialization boundary. + self.assertIs( + copied.producer.inputs[0].producer.inputs[0], + modified.producer.inputs[0].producer.inputs[0], + ) + + # Check counts of processed items. 
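+            # PreGroup stays at 3 because the deep copy stops at the GroupByKey
+            # materialization boundary, so only the original branch feeds it. Add1
+            # and Add2 run on both the original and the cloned branch (3 + 3),
+            # while Add3 is applied only to the original `modified` PCollection.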
+ self.assertEqual(DeepCopyTest._counts["PreGroup"], 3) + self.assertEqual(DeepCopyTest._counts["Add1"], 6) + self.assertEqual(DeepCopyTest._counts["Add2"], 6) + self.assertEqual(DeepCopyTest._counts["Add3"], 3) + + def testMultipleCopies(self): + with DeepCopyTest._MakeBeamPipeline() as p: + grouped = ( + p + | beam.Create([(1, "a"), (2, "b"), (3, "c")]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreGroup", x)) + | beam.GroupByKey() + ) + modified = ( + grouped + | "Add1" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add1")) + | "Add2" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add2")) + ) + + num_copies = 6 + + for i in range(num_copies): + copied = deep_copy.deep_copy(modified) + self.assertEqual(copied.producer.full_label, "Add2.Copy%d" % i) + self.assertEqual( + copied.producer.inputs[0].producer.full_label, "Add1.Copy%d" % i + ) + + self.assertEqual(DeepCopyTest._counts["PreGroup"], 3) + self.assertEqual(DeepCopyTest._counts["Add1"], 3 * (num_copies + 1)) + self.assertEqual(DeepCopyTest._counts["Add2"], 3 * (num_copies + 1)) + + def testFlatten(self): + with DeepCopyTest._MakeBeamPipeline() as p: + create_1 = p | "Create1" >> beam.Create([(1, "a"), (2, "b")]) + create_2 = p | "Create2" >> beam.Create([(3, "c")]) + created = (create_1, create_2) | "Flatten1" >> beam.Flatten() + grouped1 = ( + created + | "PreGroup1" + >> beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreGroup1", x)) + | "GBK1" >> beam.GroupByKey() + ) + grouped2 = ( + p + | beam.Create([(1, "a"), (2, "b"), (3, "c")]) + | "PreGroup2" + >> beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreGroup2", x)) + | "GBK2" >> beam.GroupByKey() + ) + modified1 = grouped1 | "Add1" >> beam.Map( + DeepCopyTest._MakeAdd1CountingIdentityFn("Add1") + ) + modified2 = grouped2 | "Add2" >> beam.Map( + DeepCopyTest._MakeAdd1CountingIdentityFn("Add2") + ) + flattened = (modified1, modified2) | "Flatten2" >> beam.Flatten() + modified3 = flattened | "Add3" >> beam.Map( + DeepCopyTest._MakeAdd1CountingIdentityFn("Add3") + ) + + copied = deep_copy.deep_copy(modified3) + + # Check that deep copy was performed. + self.assertIsNot(copied.producer.inputs[0], modified3.producer.inputs[0]) + self.assertIsNot( + copied.producer.inputs[0].producer.inputs[0], + modified3.producer.inputs[0].producer.inputs[0], + ) + self.assertIsNot( + copied.producer.inputs[0].producer.inputs[1], + modified3.producer.inputs[0].producer.inputs[1], + ) + + # Check that copy stops at materialization boundary. + self.assertIs( + copied.producer.inputs[0].producer.inputs[0].producer.inputs[0], + modified3.producer.inputs[0].producer.inputs[0].producer.inputs[0], + ) + self.assertIs( + copied.producer.inputs[0].producer.inputs[1].producer.inputs[0], + modified3.producer.inputs[0].producer.inputs[1].producer.inputs[0], + ) + + # Check counts of processed items. 
+ self.assertEqual(DeepCopyTest._counts["PreGroup1"], 3) + self.assertEqual(DeepCopyTest._counts["PreGroup2"], 3) + self.assertEqual(DeepCopyTest._counts["Add1"], 6) + self.assertEqual(DeepCopyTest._counts["Add2"], 6) + self.assertEqual(DeepCopyTest._counts["Add3"], 12) + + def testEachPTransformCopiedOnce(self): + with DeepCopyTest._MakeBeamPipeline() as p: + created = p | "Create1" >> beam.Create([(1, "a"), (2, "b")]) + modified1 = created | "Transform1" >> beam.Map( + lambda x: DeepCopyTest._CountingIdentityFn("Transform1", x) + ) + partition_fn = lambda element, partitions: element[0] % partitions + p1, p2 = modified1 | "Partition" >> beam.Partition(partition_fn, 2) + merged = (p1, p2) | "Flatten1" >> beam.Flatten() + modified2 = merged | "Transform2" >> beam.Map( + lambda x: DeepCopyTest._CountingIdentityFn("Transform2", x) + ) + + copied = deep_copy.deep_copy(modified2) + + # Check that deep copy was performed. + self.assertIsNot(copied.producer.inputs[0], modified2.producer.inputs[0]) + self.assertIsNot( + copied.producer.inputs[0].producer.inputs[0], + modified2.producer.inputs[0].producer.inputs[0], + ) + self.assertIsNot( + copied.producer.inputs[0].producer.inputs[1], + modified2.producer.inputs[0].producer.inputs[1], + ) + + # Check counts of processed items. + self.assertEqual(DeepCopyTest._counts["Transform1"], 4) + self.assertEqual(DeepCopyTest._counts["Transform2"], 4) + + def testCombineGlobally(self): + with DeepCopyTest._MakeBeamPipeline() as p: + combined = ( + p + | beam.Create([1, 2, 3]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreCombine", x)) + | beam.WindowInto(beam.window.FixedWindows(5, 0)) + | beam.CombineGlobally( + beam.transforms.combiners.MeanCombineFn() + ).without_defaults() + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PostCombine", x)) + ) + copied = deep_copy.deep_copy(combined) + + # Check that deep copy was performed. + self.assertIsNot(combined, copied) + self.assertIsNot(combined.producer.inputs[0], copied.producer.inputs[0]) + self.assertEqual( + combined.producer.inputs[0].producer.full_label, + "CombineGlobally(MeanCombineFn)/UnKey", + ) + self.assertEqual( + copied.producer.inputs[0].producer.full_label, + "CombineGlobally(MeanCombineFn)/UnKey.Copy0", + ) + + # Check that deep copy stops at materialization boundary. + self.assertIs( + combined.producer.inputs[0].producer.inputs[0], + copied.producer.inputs[0].producer.inputs[0], + ) + self.assertEqual( + str(combined.producer.inputs[0].producer.inputs[0]), + ( + "PCollection[CombineGlobally(MeanCombineFn)/CombinePerKey/Combine/" + "ParDo(CombineValuesDoFn).None]" + ), + ) + self.assertIs( + combined.producer.inputs[0].producer.inputs[0].producer, + copied.producer.inputs[0].producer.inputs[0].producer, + ) + self.assertEqual( + copied.producer.inputs[0].producer.inputs[0].producer.full_label, + ( + "CombineGlobally(MeanCombineFn)/CombinePerKey/Combine/" + "ParDo(CombineValuesDoFn)" + ), + ) + + # Check counts of processed items. 
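+            # The copy stops at the materialization boundary inside the composite
+            # CombineGlobally (the CombinePerKey output), so PreCombine sees each
+            # element exactly once (3), while the single combined mean is consumed
+            # by both the original and the cloned PostCombine map (2).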
+ self.assertEqual(DeepCopyTest._counts["PreCombine"], 3) + self.assertEqual(DeepCopyTest._counts["PostCombine"], 2) + + def testSideInputNotCopied(self): + with DeepCopyTest._MakeBeamPipeline() as p: + side = ( + p + | "CreateSide" >> beam.Create(["s1", "s2", "s3"]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("SideInput", x)) + ) + main = ( + p + | "CreateMain" >> beam.Create([1, 2, 3]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("Main", x)) + | beam.Map(lambda e, s: (e, list(s)), pvalue.AsList(side)) + ) + copied = deep_copy.deep_copy(main) + + # Check that deep copy was performed. + self.assertIsNot(main, copied) + self.assertIsNot(main.producer, copied.producer) + + # Check that deep copy stops at the side input materialization boundary. + self.assertIs(main.producer.side_inputs[0], copied.producer.side_inputs[0]) + self.assertIs(main.producer.side_inputs[0].pvalue, side) + + # Check counts of processed items. + self.assertEqual(DeepCopyTest._counts["SideInput"], 3) + self.assertEqual(DeepCopyTest._counts["Main"], 6) + + def testDeepCopyTags(self): + if not resources.ResourceHint.is_registered("tags"): + self.skipTest("Resource hint tags are not available.") + + with DeepCopyTest._MakeBeamPipeline() as p: + grouped = ( + p + | beam.Create([(1, "a"), (2, "b"), (3, "c")]) + | beam.Map(lambda x: DeepCopyTest._CountingIdentityFn("PreGroup", x)) + ) + + modified = ( + grouped + | "Add1" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add1")) + | "Add2" >> beam.Map(DeepCopyTest._MakeAdd1CountingIdentityFn("Add2")) + ) + + num_copies = 6 + + for i in range(num_copies): + copied = deep_copy.deep_copy(modified) + # Check labels. + self.assertEqual(copied.producer.full_label, "Add2.Copy%d" % i) + self.assertEqual( + copied.producer.inputs[0].producer.full_label, "Add1.Copy%d" % i + ) + + # Check resource hints. + self.assertEqual( + modified.producer.resource_hints, + { + "beam:resources:close_to_resources:v1": b"/fake/DeepCopy.Original[0]" + }, + ) + self.assertEqual( + modified.producer.inputs[0].producer.resource_hints, + { + "beam:resources:close_to_resources:v1": b"/fake/DeepCopy.Original[0]" + }, + ) + self.assertEqual( + copied.producer.resource_hints, + { + "beam:resources:close_to_resources:v1": b"/fake/DeepCopy.Copy%d[0]" + % i + }, + ) + self.assertEqual( + copied.producer.inputs[0].producer.resource_hints, + { + "beam:resources:close_to_resources:v1": b"/fake/DeepCopy.Copy%d[0]" + % i + }, + ) + + # pylint: disable=expression-not-assigned + modified | "Add3" >> beam.Map( + DeepCopyTest._MakeAdd1CountingIdentityFn("Add3") + ) + # pylint: enable=expression-not-assigned + + # Check counts of processed items. Without the materialization boundary, + # e.g. GroupByKey, PreGroup is also copied. + self.assertEqual(DeepCopyTest._counts["PreGroup"], 3 * (num_copies + 1)) + self.assertEqual(DeepCopyTest._counts["Add1"], 3 * (num_copies + 1)) + self.assertEqual(DeepCopyTest._counts["Add2"], 3 * (num_copies + 1)) + self.assertEqual(DeepCopyTest._counts["Add3"], 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/beam/experimental/analyzer_impls.py b/tensorflow_transform/beam/experimental/analyzer_impls.py index ad975f6..4ae536a 100644 --- a/tensorflow_transform/beam/experimental/analyzer_impls.py +++ b/tensorflow_transform/beam/experimental/analyzer_impls.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Beam implementations of experimental tf.Transform canonical analyzers.""" + import apache_beam as beam class PTransformAnalyzer(beam.PTransform): - """A PTransform analyzer's base class which provides a temp dir if needed.""" + """A PTransform analyzer's base class which provides a temp dir if needed.""" - def __init__(self): - self._base_temp_dir = None + def __init__(self): + self._base_temp_dir = None - @property - def base_temp_dir(self): - return self._base_temp_dir + @property + def base_temp_dir(self): + return self._base_temp_dir - @base_temp_dir.setter - def base_temp_dir(self, val): - self._base_temp_dir = val + @base_temp_dir.setter + def base_temp_dir(self, val): + self._base_temp_dir = val diff --git a/tensorflow_transform/beam/impl.py b/tensorflow_transform/beam/impl.py index 28b4c70..ef9eb20 100644 --- a/tensorflow_transform/beam/impl.py +++ b/tensorflow_transform/beam/impl.py @@ -43,54 +43,44 @@ import datetime import os -from absl import logging import apache_beam as beam -from apache_beam.runners.portability import fn_api_runner -from apache_beam.typehints import Any -from apache_beam.typehints import Dict -from apache_beam.typehints import Iterable -from apache_beam.typehints import List -from apache_beam.typehints import Optional -from apache_beam.typehints import Set -from apache_beam.typehints import Tuple -from apache_beam.typehints import Union -from apache_beam.utils import shared import numpy as np import pyarrow as pa import tensorflow as tf -from tensorflow_transform import annotators -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import graph_context -from tensorflow_transform import graph_tools -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -from tensorflow_transform import schema_inference -from tensorflow_transform.beam import analysis_graph_builder -from tensorflow_transform.beam import analyzer_cache -from tensorflow_transform.beam import beam_nodes -from tensorflow_transform.beam import common as beam_common -from tensorflow_transform.beam import context -from tensorflow_transform.beam import deep_copy -from tensorflow_transform.beam.tft_beam_io import beam_metadata_io -from tensorflow_transform.coders import example_proto_coder -from tensorflow_transform.saved import saved_transform_io -from tensorflow_transform.saved import saved_transform_io_v2 -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import metadata_io -from tensorflow_transform.tf_metadata import schema_utils +from absl import logging +from apache_beam.runners.portability import fn_api_runner +from apache_beam.typehints import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from apache_beam.utils import shared +from tensorflow_metadata.proto.v0 import schema_pb2 from tfx_bsl import beam as tfx_bsl_beam from tfx_bsl.coders import example_coder from tfx_bsl.telemetry import collection as telemetry from tfx_bsl.telemetry import util as telemetry_util -from tfx_bsl.tfxio import tensor_representation_util -from tfx_bsl.tfxio import tensor_to_arrow -from tfx_bsl.tfxio import tf_example_record -from tfx_bsl.tfxio.tensor_adapter import TensorAdapter -from tfx_bsl.tfxio.tensor_adapter import TensorAdapterConfig - -from tensorflow_metadata.proto.v0 import schema_pb2 - +from tfx_bsl.tfxio import tensor_representation_util, tensor_to_arrow, tf_example_record +from tfx_bsl.tfxio.tensor_adapter import TensorAdapter, TensorAdapterConfig + 
+from tensorflow_transform import ( + annotators, + common, + common_types, + graph_context, + graph_tools, + impl_helper, + nodes, + schema_inference, +) +from tensorflow_transform.beam import ( + analysis_graph_builder, + analyzer_cache, + beam_nodes, + context, + deep_copy, +) +from tensorflow_transform.beam import common as beam_common +from tensorflow_transform.beam.tft_beam_io import beam_metadata_io +from tensorflow_transform.coders import example_proto_coder +from tensorflow_transform.saved import saved_transform_io, saved_transform_io_v2 +from tensorflow_transform.tf_metadata import dataset_metadata, metadata_io, schema_utils tfx_bsl_beam.fix_code_type_pickling() @@ -98,7 +88,7 @@ Context = context.Context -_CREATE_SAVED_MODEL_COUNTER_NAME = 'saved_models_created' +_CREATE_SAVED_MODEL_COUNTER_NAME = "saved_models_created" # For some runners, we rely on Beam to manage concurrency, i.e. we expect it to # run one session per CPU--so we don't want to proliferate TF threads. @@ -112,13 +102,13 @@ # the former for now. use_per_session_threads=True, inter_op_parallelism_threads=2, - intra_op_parallelism_threads=2) + intra_op_parallelism_threads=2, +) _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE = { # TODO(katsiapis): Perhaps remove this entry once b/69922446 and b/30837990 # are resolved. beam.runners.DataflowRunner: _FIXED_PARALLELISM_TF_CONFIG, - beam.runners.DirectRunner: _FIXED_PARALLELISM_TF_CONFIG, fn_api_runner.FnApiRunner: _FIXED_PARALLELISM_TF_CONFIG, } @@ -144,24 +134,29 @@ # the mentioned bug above is resolved. # TODO(zoyahav): Make this a PTransform. def _clear_shared_state_after_barrier(pipeline, input_barrier): - """Clears any shared state from within a pipeline context. + """Clears any shared state from within a pipeline context. - This will only be cleared once input_barrier becomes available. + This will only be cleared once input_barrier becomes available. - Args: - pipeline: A `beam.Pipeline` object. - input_barrier: A `PCollection` which the pipeline should wait for. + Args: + ---- + pipeline: A `beam.Pipeline` object. + input_barrier: A `PCollection` which the pipeline should wait for. - Returns: - An empty `PCollection`. - """ - empty_pcoll = input_barrier | 'MakeCheapBarrier' >> beam.FlatMap( - lambda x: None) - return (pipeline - | 'PrepareToClearSharedKeepAlives' >> beam.Create([None]) - | 'WaitAndClearSharedKeepAlives' >> beam.Map( - lambda x, empty_side_input: shared.Shared().acquire(lambda: None), - beam.pvalue.AsIter(empty_pcoll))) + Returns: + ------- + An empty `PCollection`. + """ + empty_pcoll = input_barrier | "MakeCheapBarrier" >> beam.FlatMap(lambda x: None) + return ( + pipeline + | "PrepareToClearSharedKeepAlives" >> beam.Create([None]) + | "WaitAndClearSharedKeepAlives" + >> beam.Map( + lambda x, empty_side_input: shared.Shared().acquire(lambda: None), + beam.pvalue.AsIter(empty_pcoll), + ) + ) # TODO(b/36223892): Verify that these type hints work and make needed fixes. @@ -170,789 +165,878 @@ def _clear_shared_state_after_barrier(pipeline, input_barrier): _TransformFnPathType, ) @beam.typehints.with_output_types( - Dict[str, Union[np.ndarray, tf.compat.v1.SparseTensorValue]]) + Dict[str, Union[np.ndarray, tf.compat.v1.SparseTensorValue]] +) class _RunMetaGraphDoFn(beam.DoFn): - """Maps a PCollection of dicts to a PCollection of dicts via a TF graph. - - The TF graph may contain more inputs than the schema provided. 
In that case, - a subset of the inputs will be fed, which may cause an error if the excluded - inputs are required to produce the included outputs. - """ - - class _GraphStateCommon: - """A container for a shared graph state.""" - - def __init__(self, saved_model_dir, input_tensor_keys, output_tensor_keys, - callable_get_outputs): - self.saved_model_dir = saved_model_dir - self.inputs_tensor_keys = input_tensor_keys - self.outputs_tensor_keys = output_tensor_keys - self.callable_get_outputs = callable_get_outputs - - # Thread-safe. - class _GraphStateCompatV1(_GraphStateCommon): - """A container for a shared TF1 graph state.""" - - def __init__(self, saved_model_dir, input_tensor_names, exclude_outputs, - tf_config): - with tf.compat.v1.Graph().as_default() as graph: - self._session = tf.compat.v1.Session(graph=graph, config=tf_config) - with self._session.as_default(): - inputs, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - saved_model_dir, {})) - self._session.run(tf.compat.v1.global_variables_initializer()) - self._session.run(tf.compat.v1.tables_initializer()) - graph.finalize() - - if set(input_tensor_names).difference(inputs.keys()): - raise ValueError( - 'Input tensor names contained tensors not in graph: %s' % - input_tensor_names) - if set(exclude_outputs).difference(outputs.keys()): - raise ValueError('Excluded outputs contained keys not in graph: %s' % - exclude_outputs) - non_excluded_output_keys = sorted( - set(outputs.keys()).difference(exclude_outputs)) - fetches = [outputs[key] for key in non_excluded_output_keys] - tensor_inputs = graph_tools.get_dependent_inputs(graph, inputs, fetches) - inputs_tensor_keys = sorted(tensor_inputs.keys()) - outputs_tensor_keys = non_excluded_output_keys - - tensor_inputs_list = [tensor_inputs[key] for key in inputs_tensor_keys] - callable_get_outputs = self._session.make_callable( - fetches, feed_list=tensor_inputs_list) - super().__init__(saved_model_dir, inputs_tensor_keys, - outputs_tensor_keys, callable_get_outputs) - - # Thread-safe. - class _GraphStateV2(_GraphStateCommon): - """A container for a shared TF2 graph state.""" - - def __init__(self, saved_model_dir, input_tensor_names, exclude_outputs): - saved_model_loader = saved_transform_io_v2.SavedModelLoader( - saved_model_dir) - callable_get_outputs = saved_model_loader.apply_transform_model - outputs_tensor_keys = set( - saved_model_loader.structured_outputs.keys()).difference( - exclude_outputs) - saved_model_loader.finalize(input_tensor_names, outputs_tensor_keys) - super().__init__(saved_model_dir, input_tensor_names, outputs_tensor_keys, - callable_get_outputs) - - # Initialized in process(). - _graph_state: _GraphStateCommon - # Initialized in setup(). - _tensor_adapter: TensorAdapter - # i-th element in this list contains the index of the column corresponding - # to self._passthrough_keys[i]. - _passthrough_column_indices: List[int] - - def __init__( - self, - tf_config, - shared_graph_state_handle, - passthrough_keys, - use_tf_compat_v1, - input_tensor_adapter_config, - exclude_outputs=None, - ): - """Initialize. + """Maps a PCollection of dicts to a PCollection of dicts via a TF graph. - Args: - tf_config: A tf.ConfigProto to use in sessions. None implies use - Tensorflow defaults. - shared_graph_state_handle: an instance of shared.Shared() that allows us - to load the graph once and share it across multiple threads in the - current process. 
- passthrough_keys: A set of strings that are keys to instances that should - pass through the pipeline and be hidden from the preprocessing_fn. - use_tf_compat_v1: Boolean to indicate whether TFT APIs should use TF in - compat.v1 mode. - input_tensor_adapter_config: Tensor Adapter config. - exclude_outputs: (Optional) A list of names of outputs to exclude. + The TF graph may contain more inputs than the schema provided. In that case, + a subset of the inputs will be fed, which may cause an error if the excluded + inputs are required to produce the included outputs. """ - super().__init__() - self._use_tf_compat_v1 = use_tf_compat_v1 - self._input_tensor_adapter_config = input_tensor_adapter_config - self._exclude_outputs = ( - exclude_outputs if exclude_outputs is not None else []) - self._tf_config = tf_config - passthrough_keys = set(passthrough_keys) - schema_keys = self._get_input_tensor_names() - if passthrough_keys - schema_keys != passthrough_keys: - raise ValueError( - 'passthrough_keys overlap with schema keys: {}, {}'.format( - passthrough_keys, schema_keys)) - self._passthrough_keys = sorted(passthrough_keys) - - # The shared graph state handle allows us to load the graph once and share - # it across multiple threads in the current process. - self._shared_graph_state_handle = shared_graph_state_handle - - # Metrics. - self._graph_load_seconds_distribution = beam.metrics.Metrics.distribution( - beam_common.METRICS_NAMESPACE, 'graph_load_seconds') - self._batch_size_distribution = beam.metrics.Metrics.distribution( - beam_common.METRICS_NAMESPACE, 'batch_size') - self._num_instances = beam.metrics.Metrics.counter( - beam_common.METRICS_NAMESPACE, 'num_instances') - - def _get_input_tensor_names(self): - return set(self._input_tensor_adapter_config.tensor_representations.keys()) - - def _update_metrics(self, batch): - self._batch_size_distribution.update(batch.num_rows) - self._num_instances.inc(batch.num_rows) - - def _make_feed_dict(self, batch): - # If self._use_tf_compat_v1 is True, do not produce eager tensors. - produce_eager_tensors = not self._use_tf_compat_v1 - return self._tensor_adapter.ToBatchTensors( - batch, produce_eager_tensors=produce_eager_tensors) - - def _get_passthrough_data_from_recordbatch( - self, batch: pa.RecordBatch - ) -> Dict[str, pa.Array]: - result = {} - for passthrough_key, column_index in zip( - self._passthrough_keys, self._passthrough_column_indices - ): - if column_index >= 0: - # The key is present in the input batch. - passthrough_data_column = batch.column(column_index) - # The passthrough column should be of (large_)list type with - # each sub-list being either null or of length 1. - assert (pa.types.is_list(passthrough_data_column.type) or - pa.types.is_large_list(passthrough_data_column.type)) - result[passthrough_key] = passthrough_data_column - return result - def _handle_batch(self, batch): - self._update_metrics(batch) - # No need to remove (and cannot remove) the passthrough columns here: - # 1) The TensorAdapter expects the RecordBatch to be of the same schema as - # statically determined by the TFXIO implementation the yields the - # TensorAdapter. - # 2) It's not possible to leak passthrough columns through TensorAdapter - # because they are not going to be converted to Tensors. - - feed_dict = self._make_feed_dict(batch) - try: - if self._use_tf_compat_v1: - # Use self._graph_state.inputs_tensor_keys and not the dictionary keys - # to maintain order of the feed list. 
- feed_list = [ - feed_dict[name] for name in self._graph_state.inputs_tensor_keys - ] - outputs_list = self._graph_state.callable_get_outputs(*feed_list) - assert len(self._graph_state.outputs_tensor_keys) == len(outputs_list) - result = { - key: value for key, value in zip( - self._graph_state.outputs_tensor_keys, outputs_list) - } - else: - result = self._graph_state.callable_get_outputs(feed_dict) - assert len(self._graph_state.outputs_tensor_keys) == len(result) - except Exception as e: - raise ValueError( - """An error occurred while trying to apply the transformation: "{}". - Batch instances: {}, - Fetching the values for the following Tensor keys: {}.""".format( - str(e), batch, self._graph_state.outputs_tensor_keys)) from e - - result.update(self._get_passthrough_data_from_recordbatch(batch)) + class _GraphStateCommon: + """A container for a shared graph state.""" + + def __init__( + self, + saved_model_dir, + input_tensor_keys, + output_tensor_keys, + callable_get_outputs, + ): + self.saved_model_dir = saved_model_dir + self.inputs_tensor_keys = input_tensor_keys + self.outputs_tensor_keys = output_tensor_keys + self.callable_get_outputs = callable_get_outputs + + # Thread-safe. + class _GraphStateCompatV1(_GraphStateCommon): + """A container for a shared TF1 graph state.""" + + def __init__( + self, saved_model_dir, input_tensor_names, exclude_outputs, tf_config + ): + with tf.compat.v1.Graph().as_default() as graph: + self._session = tf.compat.v1.Session(graph=graph, config=tf_config) + with self._session.as_default(): + inputs, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + saved_model_dir, {} + ) + ) + self._session.run(tf.compat.v1.global_variables_initializer()) + self._session.run(tf.compat.v1.tables_initializer()) + graph.finalize() + + if set(input_tensor_names).difference(inputs.keys()): + raise ValueError( + "Input tensor names contained tensors not in graph: %s" + % input_tensor_names + ) + if set(exclude_outputs).difference(outputs.keys()): + raise ValueError( + "Excluded outputs contained keys not in graph: %s" + % exclude_outputs + ) + non_excluded_output_keys = sorted( + set(outputs.keys()).difference(exclude_outputs) + ) + fetches = [outputs[key] for key in non_excluded_output_keys] + tensor_inputs = graph_tools.get_dependent_inputs(graph, inputs, fetches) + inputs_tensor_keys = sorted(tensor_inputs.keys()) + outputs_tensor_keys = non_excluded_output_keys + + tensor_inputs_list = [tensor_inputs[key] for key in inputs_tensor_keys] + callable_get_outputs = self._session.make_callable( + fetches, feed_list=tensor_inputs_list + ) + super().__init__( + saved_model_dir, + inputs_tensor_keys, + outputs_tensor_keys, + callable_get_outputs, + ) + + # Thread-safe. + class _GraphStateV2(_GraphStateCommon): + """A container for a shared TF2 graph state.""" + + def __init__(self, saved_model_dir, input_tensor_names, exclude_outputs): + saved_model_loader = saved_transform_io_v2.SavedModelLoader(saved_model_dir) + callable_get_outputs = saved_model_loader.apply_transform_model + outputs_tensor_keys = set( + saved_model_loader.structured_outputs.keys() + ).difference(exclude_outputs) + saved_model_loader.finalize(input_tensor_names, outputs_tensor_keys) + super().__init__( + saved_model_dir, + input_tensor_names, + outputs_tensor_keys, + callable_get_outputs, + ) - return result + # Initialized in process(). + _graph_state: _GraphStateCommon + # Initialized in setup(). 
+ _tensor_adapter: TensorAdapter + # i-th element in this list contains the index of the column corresponding + # to self._passthrough_keys[i]. + _passthrough_column_indices: List[int] + + def __init__( + self, + tf_config, + shared_graph_state_handle, + passthrough_keys, + use_tf_compat_v1, + input_tensor_adapter_config, + exclude_outputs=None, + ): + """Initialize. + + Args: + ---- + tf_config: A tf.ConfigProto to use in sessions. None implies use + Tensorflow defaults. + shared_graph_state_handle: an instance of shared.Shared() that allows us + to load the graph once and share it across multiple threads in the + current process. + passthrough_keys: A set of strings that are keys to instances that should + pass through the pipeline and be hidden from the preprocessing_fn. + use_tf_compat_v1: Boolean to indicate whether TFT APIs should use TF in + compat.v1 mode. + input_tensor_adapter_config: Tensor Adapter config. + exclude_outputs: (Optional) A list of names of outputs to exclude. + """ + super().__init__() + self._use_tf_compat_v1 = use_tf_compat_v1 + self._input_tensor_adapter_config = input_tensor_adapter_config + self._exclude_outputs = exclude_outputs if exclude_outputs is not None else [] + self._tf_config = tf_config + passthrough_keys = set(passthrough_keys) + schema_keys = self._get_input_tensor_names() + if passthrough_keys - schema_keys != passthrough_keys: + raise ValueError( + f"passthrough_keys overlap with schema keys: {passthrough_keys}, {schema_keys}" + ) + self._passthrough_keys = sorted(passthrough_keys) - def _make_graph_state(self, saved_model_dir): - start = datetime.datetime.now() - if self._use_tf_compat_v1: - result = self._GraphStateCompatV1(saved_model_dir, - self._get_input_tensor_names(), - self._exclude_outputs, self._tf_config) - else: - result = self._GraphStateV2(saved_model_dir, - self._get_input_tensor_names(), - self._exclude_outputs) - self._graph_load_seconds_distribution.update( - int((datetime.datetime.now() - start).total_seconds())) - return result + # The shared graph state handle allows us to load the graph once and share + # it across multiple threads in the current process. + self._shared_graph_state_handle = shared_graph_state_handle - def setup(self): - assert self._input_tensor_adapter_config is not None - self._tensor_adapter = TensorAdapter(self._input_tensor_adapter_config) - arrow_schema = self._input_tensor_adapter_config.arrow_schema - self._passthrough_column_indices = [ - arrow_schema.get_field_index(k) for k in self._passthrough_keys - ] + # Metrics. + self._graph_load_seconds_distribution = beam.metrics.Metrics.distribution( + beam_common.METRICS_NAMESPACE, "graph_load_seconds" + ) + self._batch_size_distribution = beam.metrics.Metrics.distribution( + beam_common.METRICS_NAMESPACE, "batch_size" + ) + self._num_instances = beam.metrics.Metrics.counter( + beam_common.METRICS_NAMESPACE, "num_instances" + ) - def process(self, batch, saved_model_dir): - """Runs the given graph to realize the outputs. + def _get_input_tensor_names(self): + return set(self._input_tensor_adapter_config.tensor_representations.keys()) - Runs the graph in a TF session for computing the output values of the - `Tensor`s, `SparseTensor`s, or `RaggedTensor`s, given an input row of data - (input `Tensor`s, `SparseTensor`s, or `RaggedTensor`s). 
+ def _update_metrics(self, batch): + self._batch_size_distribution.update(batch.num_rows) + self._num_instances.inc(batch.num_rows) - Args: - batch: the batch of elements being processed by the DoFn - saved_model_dir: Directory containing saved model. + def _make_feed_dict(self, batch): + # If self._use_tf_compat_v1 is True, do not produce eager tensors. + produce_eager_tensors = not self._use_tf_compat_v1 + return self._tensor_adapter.ToBatchTensors( + batch, produce_eager_tensors=produce_eager_tensors + ) - Yields: - A representation of output features as a dict mapping keys (logical column - names) to values. - """ - if not hasattr(self, '_graph_state'): - # If available, acquire will return a cached _GraphStateCommon, since - # calling _make_graph_state is expensive. - self._graph_state = self._shared_graph_state_handle.acquire( - lambda: self._make_graph_state(saved_model_dir)) + def _get_passthrough_data_from_recordbatch( + self, batch: pa.RecordBatch + ) -> Dict[str, pa.Array]: + result = {} + for passthrough_key, column_index in zip( + self._passthrough_keys, self._passthrough_column_indices + ): + if column_index >= 0: + # The key is present in the input batch. + passthrough_data_column = batch.column(column_index) + # The passthrough column should be of (large_)list type with + # each sub-list being either null or of length 1. + assert pa.types.is_list( + passthrough_data_column.type + ) or pa.types.is_large_list(passthrough_data_column.type) + result[passthrough_key] = passthrough_data_column + return result + + def _handle_batch(self, batch): + self._update_metrics(batch) + # No need to remove (and cannot remove) the passthrough columns here: + # 1) The TensorAdapter expects the RecordBatch to be of the same schema as + # statically determined by the TFXIO implementation the yields the + # TensorAdapter. + # 2) It's not possible to leak passthrough columns through TensorAdapter + # because they are not going to be converted to Tensors. + + feed_dict = self._make_feed_dict(batch) + try: + if self._use_tf_compat_v1: + # Use self._graph_state.inputs_tensor_keys and not the dictionary keys + # to maintain order of the feed list. + feed_list = [ + feed_dict[name] for name in self._graph_state.inputs_tensor_keys + ] + outputs_list = self._graph_state.callable_get_outputs(*feed_list) + assert len(self._graph_state.outputs_tensor_keys) == len(outputs_list) + result = { + key: value + for key, value in zip( + self._graph_state.outputs_tensor_keys, outputs_list + ) + } + else: + result = self._graph_state.callable_get_outputs(feed_dict) + assert len(self._graph_state.outputs_tensor_keys) == len(result) + except Exception as e: + raise ValueError( + f"""An error occurred while trying to apply the transformation: "{str(e)}". 
+ Batch instances: {batch}, + Fetching the values for the following Tensor keys: {self._graph_state.outputs_tensor_keys}.""" + ) from e + + result.update(self._get_passthrough_data_from_recordbatch(batch)) + + return result + + def _make_graph_state(self, saved_model_dir): + start = datetime.datetime.now() + if self._use_tf_compat_v1: + result = self._GraphStateCompatV1( + saved_model_dir, + self._get_input_tensor_names(), + self._exclude_outputs, + self._tf_config, + ) + else: + result = self._GraphStateV2( + saved_model_dir, self._get_input_tensor_names(), self._exclude_outputs + ) + self._graph_load_seconds_distribution.update( + int((datetime.datetime.now() - start).total_seconds()) + ) + return result + + def setup(self): + assert self._input_tensor_adapter_config is not None + self._tensor_adapter = TensorAdapter(self._input_tensor_adapter_config) + arrow_schema = self._input_tensor_adapter_config.arrow_schema + self._passthrough_column_indices = [ + arrow_schema.get_field_index(k) for k in self._passthrough_keys + ] + + def process(self, batch, saved_model_dir): + """Runs the given graph to realize the outputs. + + Runs the graph in a TF session for computing the output values of the + `Tensor`s, `SparseTensor`s, or `RaggedTensor`s, given an input row of data + (input `Tensor`s, `SparseTensor`s, or `RaggedTensor`s). + + Args: + ---- + batch: the batch of elements being processed by the DoFn + saved_model_dir: Directory containing saved model. + + Yields: + ------ + A representation of output features as a dict mapping keys (logical column + names) to values. + """ + if not hasattr(self, "_graph_state"): + # If available, acquire will return a cached _GraphStateCommon, since + # calling _make_graph_state is expensive. + self._graph_state = self._shared_graph_state_handle.acquire( + lambda: self._make_graph_state(saved_model_dir) + ) - # This should remain true throughout the lifetime of this DoFn, regardless - # of whether or not self._graph_state was cached. - assert self._graph_state.saved_model_dir == saved_model_dir + # This should remain true throughout the lifetime of this DoFn, regardless + # of whether or not self._graph_state was cached. + assert self._graph_state.saved_model_dir == saved_model_dir - yield self._handle_batch(batch) + yield self._handle_batch(batch) def _warn_about_tf_compat_v1(): - """Warns about using tf.compat.v1.""" - logging.warning( - 'Tensorflow Transform is running in tf.compat.v1 mode. This could be ' - 'either because TF2 was disabled or `Context.force_tf_compat_v1=True`. ' - 'Features such as tf.function may not work as intended.') + """Warns about using tf.compat.v1.""" + logging.warning( + "Tensorflow Transform is running in tf.compat.v1 mode. This could be " + "either because TF2 was disabled or `Context.force_tf_compat_v1=True`. " + "Features such as tf.function may not work as intended." + ) def _maybe_slice_large_record_batch( record_batch: pa.RecordBatch, ) -> Iterable[pa.RecordBatch]: - """Slices large batches into smaller chunks.""" - if record_batch.nbytes > _MAX_TRANSFORMED_BATCH_BYTES_SIZE: - if record_batch.num_rows < 2: - logging.warning( - 'Transformed data row may be too large: %d bytes. ' - 'Consider reshaping outputs to distribute elements over a larger ' - 'number of rows to allow automatic slicing.', - record_batch.nbytes, - ) - yield record_batch - return - # Note that slicing is a zero-copy operation, so the produced batches will - # still share memory with the original one up to the materialization - # boundary. 
- mid_point = record_batch.num_rows // 2 - yield from _maybe_slice_large_record_batch( - record_batch.slice(offset=0, length=mid_point) - ) - yield from _maybe_slice_large_record_batch( - record_batch.slice(offset=mid_point) - ) - else: - yield record_batch + """Slices large batches into smaller chunks.""" + if record_batch.nbytes > _MAX_TRANSFORMED_BATCH_BYTES_SIZE: + if record_batch.num_rows < 2: + logging.warning( + "Transformed data row may be too large: %d bytes. " + "Consider reshaping outputs to distribute elements over a larger " + "number of rows to allow automatic slicing.", + record_batch.nbytes, + ) + yield record_batch + return + # Note that slicing is a zero-copy operation, so the produced batches will + # still share memory with the original one up to the materialization + # boundary. + mid_point = record_batch.num_rows // 2 + yield from _maybe_slice_large_record_batch( + record_batch.slice(offset=0, length=mid_point) + ) + yield from _maybe_slice_large_record_batch(record_batch.slice(offset=mid_point)) + else: + yield record_batch def _convert_to_record_batch( batch_dict: Dict[str, Union[common_types.TensorValueType, pa.Array]], converter: tensor_to_arrow.TensorsToRecordBatchConverter, passthrough_keys: Set[str], - input_metadata: Union[ - TensorAdapterConfig, dataset_metadata.DatasetMetadata - ], + input_metadata: Union[TensorAdapterConfig, dataset_metadata.DatasetMetadata], validate_varlen_sparse_values: bool = False, ) -> Iterable[Tuple[pa.RecordBatch, Dict[str, pa.Array]]]: - """Convert batch of ndarrays to pyarrow.RecordBatches.""" - - # Making a copy of batch_dict because mutating PCollection elements is not - # allowed. - if passthrough_keys: - batch_dict = copy.copy(batch_dict) - passthrough_data = { - key: batch_dict.pop(key) for key in passthrough_keys if key in batch_dict - } - - if validate_varlen_sparse_values: - for name, representation in converter.tensor_representations().items(): - if representation.WhichOneof('kind') == 'varlen_sparse_tensor': - impl_helper.validate_varlen_sparse_value(name, batch_dict[name]) - - record_batch = converter.convert(batch_dict) - arrow_columns, arrow_schema = record_batch.columns, record_batch.schema - - batch_size = len(arrow_columns[0]) - # This dict will contain pass-through data with batch size of 1 if it doesn't - # match batch size of the transformed data. - unary_passthrough_features = {} - for key, data in passthrough_data.items(): - # Only raising a ValueError in case pass-through data has more than one - # distinct value. If it has one value and batch_size>1 then it will have to - # be handled by the user. - # TODO(b/38376110): Restrict to matching batch dimensions and clean this up - # once the internal feature key is deprecated. - if len(data) not in (batch_size, 1): - # The passthrough column should be of list type with each - # sub-list being either null or of length 1. - data_set = set( - None if elem is None else elem[0] for elem in data.to_pylist()) - if len(data_set) == 1: - elem = data_set.pop() - data = pa.array([None if elem is None else [elem]], type=data.type) - else: - raise ValueError( - 'Cannot pass-through data when input and output batch sizes ' - 'are different ({} vs. 
{})'.format(len(data), batch_size)) - if len(data) == batch_size: - arrow_schema = arrow_schema.append(input_metadata.arrow_schema.field(key)) - arrow_columns.append(data) - else: - unary_passthrough_features[key] = data - for reccord_batch in _maybe_slice_large_record_batch( - pa.RecordBatch.from_arrays(arrow_columns, schema=arrow_schema) - ): - yield reccord_batch, unary_passthrough_features + """Convert batch of ndarrays to pyarrow.RecordBatches.""" + # Making a copy of batch_dict because mutating PCollection elements is not + # allowed. + if passthrough_keys: + batch_dict = copy.copy(batch_dict) + passthrough_data = { + key: batch_dict.pop(key) for key in passthrough_keys if key in batch_dict + } + + if validate_varlen_sparse_values: + for name, representation in converter.tensor_representations().items(): + if representation.WhichOneof("kind") == "varlen_sparse_tensor": + impl_helper.validate_varlen_sparse_value(name, batch_dict[name]) + + record_batch = converter.convert(batch_dict) + arrow_columns, arrow_schema = record_batch.columns, record_batch.schema + + batch_size = len(arrow_columns[0]) + # This dict will contain pass-through data with batch size of 1 if it doesn't + # match batch size of the transformed data. + unary_passthrough_features = {} + for key, data in passthrough_data.items(): + # Only raising a ValueError in case pass-through data has more than one + # distinct value. If it has one value and batch_size>1 then it will have to + # be handled by the user. + # TODO(b/38376110): Restrict to matching batch dimensions and clean this up + # once the internal feature key is deprecated. + if len(data) not in (batch_size, 1): + # The passthrough column should be of list type with each + # sub-list being either null or of length 1. + data_set = set( + None if elem is None else elem[0] for elem in data.to_pylist() + ) + if len(data_set) == 1: + elem = data_set.pop() + data = pa.array([None if elem is None else [elem]], type=data.type) + else: + raise ValueError( + "Cannot pass-through data when input and output batch sizes " + f"are different ({len(data)} vs. {batch_size})" + ) + if len(data) == batch_size: + arrow_schema = arrow_schema.append(input_metadata.arrow_schema.field(key)) + arrow_columns.append(data) + else: + unary_passthrough_features[key] = data + for reccord_batch in _maybe_slice_large_record_batch( + pa.RecordBatch.from_arrays(arrow_columns, schema=arrow_schema) + ): + yield reccord_batch, unary_passthrough_features def _transformed_batch_to_instance_dicts( transformed_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]], schema: schema_pb2.Schema, ): - """Converts batch of transformed data to unbatched instance dicts.""" - record_batch, unary_passthrough_features = transformed_batch - result = impl_helper.record_batch_to_instance_dicts(record_batch, schema) - # Convert unary passthrough data to Python primitives. - for key, value in unary_passthrough_features.items(): - value = value.to_pylist() - value = None if value[0] is None else value[0] - for instance in result: - instance[key] = value - return result + """Converts batch of transformed data to unbatched instance dicts.""" + record_batch, unary_passthrough_features = transformed_batch + result = impl_helper.record_batch_to_instance_dicts(record_batch, schema) + # Convert unary passthrough data to Python primitives. 
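`_maybe_slice_large_record_batch`, used just above, recursively halves any `RecordBatch` whose `nbytes` exceeds a cap and relies on `RecordBatch.slice` being zero-copy. A standalone sketch of the same idea, with a deliberately tiny threshold so the splitting is visible:

```python
import pyarrow as pa

MAX_BATCH_BYTES = 64  # deliberately small so the recursion is observable


def slice_large_batch(batch: pa.RecordBatch):
    """Yields sub-batches no larger than MAX_BATCH_BYTES where possible."""
    if batch.nbytes > MAX_BATCH_BYTES and batch.num_rows >= 2:
        mid = batch.num_rows // 2
        # slice() is zero-copy: the chunks still share buffers with `batch`.
        yield from slice_large_batch(batch.slice(offset=0, length=mid))
        yield from slice_large_batch(batch.slice(offset=mid))
    else:
        # Either small enough, or a single row that cannot be split further.
        yield batch


batch = pa.RecordBatch.from_arrays(
    [pa.array(list(range(100)), type=pa.int64())], names=["x"]
)
for chunk in slice_large_batch(batch):
    print(chunk.num_rows, chunk.nbytes)
```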
+ for key, value in unary_passthrough_features.items(): + value = value.to_pylist() + value = None if value[0] is None else value[0] + for instance in result: + instance[key] = value + return result @dataclasses.dataclass(frozen=True) class _TensorBinding: - value: Any - tensor_name: str - dtype_enum: int - is_asset_filepath: bool + value: Any + tensor_name: str + dtype_enum: int + is_asset_filepath: bool @beam_common.register_ptransform(beam_nodes.CreateTensorBinding) @beam.typehints.with_input_types(common_types.InstanceValueType) @beam.typehints.with_output_types(_TensorBinding) class _CreateTensorBindingsImpl(beam.PTransform): - """Maps a PCollection of data to a PCollection of `_TensorBinding`s.""" - - def __init__(self, operation, extra_args): - del extra_args - self._dtype_enum = operation.dtype_enum - self._tensor_name = operation.tensor_name - self._is_asset_file = operation.is_asset_filepath - - def expand(self, inputs): - pcoll, = inputs - return pcoll | 'ToTensorBinding' >> beam.Map( - _TensorBinding, self._tensor_name, self._dtype_enum, - self._is_asset_file) + """Maps a PCollection of data to a PCollection of `_TensorBinding`s.""" + + def __init__(self, operation, extra_args): + del extra_args + self._dtype_enum = operation.dtype_enum + self._tensor_name = operation.tensor_name + self._is_asset_file = operation.is_asset_filepath + + def expand(self, inputs): + (pcoll,) = inputs + return pcoll | "ToTensorBinding" >> beam.Map( + _TensorBinding, self._tensor_name, self._dtype_enum, self._is_asset_file + ) def _get_tensor_replacement_map(graph, *tensor_bindings): - """Get Tensor replacement map.""" - tensor_replacement_map = {} + """Get Tensor replacement map.""" + tensor_replacement_map = {} - for tensor_binding in tensor_bindings: - assert isinstance(tensor_binding, _TensorBinding), tensor_binding - replacement_tensor = tf.constant( - tensor_binding.value, tf.dtypes.as_dtype(tensor_binding.dtype_enum)) - if graph is not None and tensor_binding.is_asset_filepath: - graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - replacement_tensor) - tensor_replacement_map[tensor_binding.tensor_name] = replacement_tensor - return tensor_replacement_map + for tensor_binding in tensor_bindings: + assert isinstance(tensor_binding, _TensorBinding), tensor_binding + replacement_tensor = tf.constant( + tensor_binding.value, tf.dtypes.as_dtype(tensor_binding.dtype_enum) + ) + if graph is not None and tensor_binding.is_asset_filepath: + graph.add_to_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS, replacement_tensor + ) + tensor_replacement_map[tensor_binding.tensor_name] = replacement_tensor + return tensor_replacement_map -def _replace_tensors_with_constant_values(saved_model_dir, base_temp_dir, - *tensor_bindings): - """Replaces specified `Tensor`s with constant values. +def _replace_tensors_with_constant_values( + saved_model_dir, base_temp_dir, *tensor_bindings +): + """Replaces specified `Tensor`s with constant values. - Constants are accepted as Python values; these are automatically - wrapped in `tf.constant()`. + Constants are accepted as Python values; these are automatically + wrapped in `tf.constant()`. - This method creates its own temp dir, and is therefore idempotent - since any retry will use a different temp dir. + This method creates its own temp dir, and is therefore idempotent + since any retry will use a different temp dir. - Args: - saved_model_dir: A SavedModel directory providing a transform - graph. 
The MetaGraphDef and signature are selected from the - SavedModel using keys defined in `../constants.py` ('transform' - and 'transform_signature', respectively). - base_temp_dir: Base temp dir for storage of new model. - *tensor_bindings: An iterable of `_TensorBinding`s. + Args: + ---- + saved_model_dir: A SavedModel directory providing a transform + graph. The MetaGraphDef and signature are selected from the + SavedModel using keys defined in `../constants.py` ('transform' + and 'transform_signature', respectively). + base_temp_dir: Base temp dir for storage of new model. + *tensor_bindings: An iterable of `_TensorBinding`s. - Returns: - The directory name containing the updated SavedModel. + Returns: + ------- + The directory name containing the updated SavedModel. Raises: - RuntimeError: if there is no default graph available to which to - apply the transform. - """ - with tf.compat.v1.Graph().as_default() as graph: - tensor_replacement_map = ( - _get_tensor_replacement_map(graph, *tensor_bindings)) - - with tf.compat.v1.Session(graph=graph) as session: - temp_dir = beam_common.get_unique_temp_path(base_temp_dir) - input_tensors, output_tensors = ( - saved_transform_io.partially_apply_saved_transform_internal( - saved_model_dir, {}, tensor_replacement_map)) - session.run(tf.compat.v1.global_variables_initializer()) - saved_transform_io.write_saved_transform_from_session( - session, input_tensors, output_tensors, temp_dir) - return temp_dir + ------ + RuntimeError: if there is no default graph available to which to + apply the transform. + """ + with tf.compat.v1.Graph().as_default() as graph: + tensor_replacement_map = _get_tensor_replacement_map(graph, *tensor_bindings) + + with tf.compat.v1.Session(graph=graph) as session: + temp_dir = beam_common.get_unique_temp_path(base_temp_dir) + input_tensors, output_tensors = ( + saved_transform_io.partially_apply_saved_transform_internal( + saved_model_dir, {}, tensor_replacement_map + ) + ) + session.run(tf.compat.v1.global_variables_initializer()) + saved_transform_io.write_saved_transform_from_session( + session, input_tensors, output_tensors, temp_dir + ) + return temp_dir @beam_common.register_ptransform( - beam_nodes.CreateSavedModel, - tags={beam_common.EnvironmentTags.TF_COMPAT_V1}) + beam_nodes.CreateSavedModel, tags={beam_common.EnvironmentTags.TF_COMPAT_V1} +) @beam.typehints.with_input_types(_TensorBinding) @beam.typehints.with_output_types(_TransformFnPathType) class _CreateSavedModelImpl(beam.PTransform): - """Create a SavedModel from a TF Graph.""" - - def __init__(self, operation, extra_args): - self._base_temp_dir = extra_args.base_temp_dir - self._graph = extra_args.graph - self._input_signature = extra_args.input_signature - self._table_initializers = operation.table_initializers - self._output_signature = operation.output_signature - - def expand(self, inputs): - unbound_saved_model_dir = beam_common.get_unique_temp_path( - self._base_temp_dir) - with self._graph.as_default(): - with tf.compat.v1.Session(graph=self._graph) as session: - table_initializers_ref = tf.compat.v1.get_collection_ref( - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) - original_table_initializers = list(table_initializers_ref) - del table_initializers_ref[:] - table_initializers_ref.extend(self._table_initializers) - # Initialize all variables so they can be saved. 
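`_get_tensor_replacement_map` above turns each `_TensorBinding` back into a `tf.constant`, keyed by the name of the tensor it replaces, using the dtype enum carried on the binding. A small illustration of that conversion, with made-up tensor names and values:

```python
import tensorflow as tf

# Hypothetical (tensor_name, value, dtype_enum) triples standing in for
# _TensorBinding instances.
bindings = [
    ("vocab_size:0", 42, tf.int64.as_datatype_enum),
    ("mean_value:0", 3.5, tf.float32.as_datatype_enum),
]

# Rebuild each deferred value as a constant of the recorded dtype.
replacement_map = {
    name: tf.constant(value, tf.dtypes.as_dtype(dtype_enum))
    for name, value, dtype_enum in bindings
}

for name, tensor in replacement_map.items():
    print(name, tensor.dtype, tensor.numpy())
```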
- session.run(tf.compat.v1.global_variables_initializer()) - saved_transform_io.write_saved_transform_from_session( - session, self._input_signature, self._output_signature, - unbound_saved_model_dir) - del table_initializers_ref[:] - table_initializers_ref.extend(original_table_initializers) - return (inputs - | 'BindTensors' >> _BindTensors(self._base_temp_dir, - unbound_saved_model_dir) - | 'Count' >> - beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME)) - - -def _create_v2_saved_model(tensor_replacement_map, base_temp_dir, - preprocessing_fn, input_signature, - baseline_analyzers_fingerprint, - output_keys_to_name_map, save_options): - """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function. - - The SavedModel written contains a method called `transform_fn` that - represents the traced `preprocessing_fn`. Additionally, if this is the final - SavedModel being written out, it will contain a method called `metadata_fn` - that provides deferred schema annotations. - - Args: - tensor_replacement_map: A map from placeholder tensor names to their - evaluated replacement tensors. - base_temp_dir: Base path to write SavedModel and temporary artifacts to. - preprocessing_fn: A user defined python function to be traced. - input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`. - baseline_analyzers_fingerprint: A mapping from analyzer name to a set of - paths that define its fingerprint. - output_keys_to_name_map: A map from output dictionary keys to the names of - the tensors that they represent. - save_options: The tf.saved_model.SaveOptions to save the model with. - - Returns: - Path to which SavedModel was written. - """ - saved_model_dir = beam_common.get_unique_temp_path(base_temp_dir) - impl_helper.trace_and_write_v2_saved_model(saved_model_dir, preprocessing_fn, - input_signature, base_temp_dir, - baseline_analyzers_fingerprint, - tensor_replacement_map, - output_keys_to_name_map, - save_options) - return saved_model_dir + """Create a SavedModel from a TF Graph.""" + + def __init__(self, operation, extra_args): + self._base_temp_dir = extra_args.base_temp_dir + self._graph = extra_args.graph + self._input_signature = extra_args.input_signature + self._table_initializers = operation.table_initializers + self._output_signature = operation.output_signature + + def expand(self, inputs): + unbound_saved_model_dir = beam_common.get_unique_temp_path(self._base_temp_dir) + with self._graph.as_default(): + with tf.compat.v1.Session(graph=self._graph) as session: + table_initializers_ref = tf.compat.v1.get_collection_ref( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS + ) + original_table_initializers = list(table_initializers_ref) + del table_initializers_ref[:] + table_initializers_ref.extend(self._table_initializers) + # Initialize all variables so they can be saved. 
+ session.run(tf.compat.v1.global_variables_initializer()) + saved_transform_io.write_saved_transform_from_session( + session, + self._input_signature, + self._output_signature, + unbound_saved_model_dir, + ) + del table_initializers_ref[:] + table_initializers_ref.extend(original_table_initializers) + return ( + inputs + | "BindTensors" + >> _BindTensors(self._base_temp_dir, unbound_saved_model_dir) + | "Count" >> beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME) + ) + + +def _create_v2_saved_model( + tensor_replacement_map, + base_temp_dir, + preprocessing_fn, + input_signature, + baseline_analyzers_fingerprint, + output_keys_to_name_map, + save_options, +): + """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function. + + The SavedModel written contains a method called `transform_fn` that + represents the traced `preprocessing_fn`. Additionally, if this is the final + SavedModel being written out, it will contain a method called `metadata_fn` + that provides deferred schema annotations. + + Args: + ---- + tensor_replacement_map: A map from placeholder tensor names to their + evaluated replacement tensors. + base_temp_dir: Base path to write SavedModel and temporary artifacts to. + preprocessing_fn: A user defined python function to be traced. + input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`. + baseline_analyzers_fingerprint: A mapping from analyzer name to a set of + paths that define its fingerprint. + output_keys_to_name_map: A map from output dictionary keys to the names of + the tensors that they represent. + save_options: The tf.saved_model.SaveOptions to save the model with. + + Returns: + ------- + Path to which SavedModel was written. + """ + saved_model_dir = beam_common.get_unique_temp_path(base_temp_dir) + impl_helper.trace_and_write_v2_saved_model( + saved_model_dir, + preprocessing_fn, + input_signature, + base_temp_dir, + baseline_analyzers_fingerprint, + tensor_replacement_map, + output_keys_to_name_map, + save_options, + ) + return saved_model_dir @beam_common.register_ptransform( - beam_nodes.CreateSavedModel, tags={beam_common.EnvironmentTags.TF_V2_ONLY}) + beam_nodes.CreateSavedModel, tags={beam_common.EnvironmentTags.TF_V2_ONLY} +) @beam.typehints.with_input_types(_TensorBinding) @beam.typehints.with_output_types(str) class _CreateSavedModelImplV2(beam.PTransform): - """Create a SavedModel from a TF Graph.""" - - def __init__(self, operation, extra_args): - self._base_temp_dir = extra_args.base_temp_dir - self._preprocessing_fn = extra_args.preprocessing_fn - self._input_signature = extra_args.input_specs - self._output_signature = operation.output_signature - self._analyzers_fingerprint = extra_args.analyzers_fingerprint - self._save_options = extra_args.save_options - - def _maybe_get_output_tensor_names_dict(self): - # output_signature will contain CompositeTensors only if this is the final - # SavedModel export. In this scenario, we do not need the output_signature - # anymore as we will output everything that the preprocessing_fn returns. 
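`_create_v2_saved_model` delegates the real work to `impl_helper.trace_and_write_v2_saved_model`. As a rough sketch of the underlying TF2 mechanics only (not TFT's implementation), tracing a function against an input signature and exporting it with `tf.saved_model.SaveOptions` looks roughly like this; the module, signature name and feature are illustrative:

```python
import tempfile

import tensorflow as tf


class TransformModule(tf.Module):
    """Toy module whose transform_fn is traced from a fixed input signature."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
    def transform_fn(self, x):
        # Stand-in for a traced preprocessing_fn output dictionary.
        return {"x_scaled": x / 10.0}


module = TransformModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module,
    export_dir,
    signatures={"transform_fn": module.transform_fn},
    options=tf.saved_model.SaveOptions(),
)

reloaded = tf.saved_model.load(export_dir)
print(reloaded.transform_fn(tf.constant([1.0, 2.0, 3.0])))
```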
- if all(isinstance(v, tf.Tensor) for v in self._output_signature.values()): - return {k: v.name for k, v in self._output_signature.items()} - else: - return {} - - def expand(self, inputs): - pipeline = (inputs[0] if isinstance(inputs, tuple) else inputs).pipeline - - input_pcoll = pipeline | 'CreateSole' >> beam.Create([None]) - if not isinstance(inputs, beam.pvalue.PBegin): - input_pcoll |= ('ReplaceWithConstants' >> beam.Map( - lambda _, *args: _get_tensor_replacement_map(None, *args), - *[beam.pvalue.AsSingleton(pcoll) for pcoll in inputs])) + """Create a SavedModel from a TF Graph.""" + + def __init__(self, operation, extra_args): + self._base_temp_dir = extra_args.base_temp_dir + self._preprocessing_fn = extra_args.preprocessing_fn + self._input_signature = extra_args.input_specs + self._output_signature = operation.output_signature + self._analyzers_fingerprint = extra_args.analyzers_fingerprint + self._save_options = extra_args.save_options + + def _maybe_get_output_tensor_names_dict(self): + # output_signature will contain CompositeTensors only if this is the final + # SavedModel export. In this scenario, we do not need the output_signature + # anymore as we will output everything that the preprocessing_fn returns. + if all(isinstance(v, tf.Tensor) for v in self._output_signature.values()): + return {k: v.name for k, v in self._output_signature.items()} + else: + return {} + + def expand(self, inputs): + pipeline = (inputs[0] if isinstance(inputs, tuple) else inputs).pipeline + + input_pcoll = pipeline | "CreateSole" >> beam.Create([None]) + if not isinstance(inputs, beam.pvalue.PBegin): + input_pcoll |= "ReplaceWithConstants" >> beam.Map( + lambda _, *args: _get_tensor_replacement_map(None, *args), + *[beam.pvalue.AsSingleton(pcoll) for pcoll in inputs], + ) - return ( - input_pcoll - | 'CreateSavedModel' >> beam.Map( - _create_v2_saved_model, self._base_temp_dir, self._preprocessing_fn, - self._input_signature, self._analyzers_fingerprint, - self._maybe_get_output_tensor_names_dict(), self._save_options) - | 'Count' >> - beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME)) + return ( + input_pcoll + | "CreateSavedModel" + >> beam.Map( + _create_v2_saved_model, + self._base_temp_dir, + self._preprocessing_fn, + self._input_signature, + self._analyzers_fingerprint, + self._maybe_get_output_tensor_names_dict(), + self._save_options, + ) + | "Count" >> beam_common.IncrementCounter(_CREATE_SAVED_MODEL_COUNTER_NAME) + ) class _BindTensors(beam.PTransform): - """PTransform to bind tensor in a SavedModel.""" + """PTransform to bind tensor in a SavedModel.""" - def __init__(self, base_temp_dir, unbound_saved_model_dir): - self._base_temp_dir = base_temp_dir - self._unbound_saved_model_dir = unbound_saved_model_dir + def __init__(self, base_temp_dir, unbound_saved_model_dir): + self._base_temp_dir = base_temp_dir + self._unbound_saved_model_dir = unbound_saved_model_dir - def expand(self, inputs): - pipeline = (inputs[0] if isinstance(inputs, tuple) else inputs).pipeline - saved_model_dir_pcoll = pipeline | 'CreateSavedModel' >> beam.Create( - [self._unbound_saved_model_dir]) + def expand(self, inputs): + pipeline = (inputs[0] if isinstance(inputs, tuple) else inputs).pipeline + saved_model_dir_pcoll = pipeline | "CreateSavedModel" >> beam.Create( + [self._unbound_saved_model_dir] + ) - if isinstance(inputs, beam.pvalue.PBegin): - return saved_model_dir_pcoll + if isinstance(inputs, beam.pvalue.PBegin): + return saved_model_dir_pcoll - return saved_model_dir_pcoll | 
'ReplaceWithConstants' >> beam.Map( - _replace_tensors_with_constant_values, self._base_temp_dir, - *[beam.pvalue.AsSingleton(pcoll) for pcoll in inputs]) + return saved_model_dir_pcoll | "ReplaceWithConstants" >> beam.Map( + _replace_tensors_with_constant_values, + self._base_temp_dir, + *[beam.pvalue.AsSingleton(pcoll) for pcoll in inputs], + ) @beam_common.register_ptransform(beam_nodes.ExtractInputForSavedModel) class _ExtractInputForSavedModelImpl(beam.PTransform): - """Returns a PCollection for analysis based on the specified dataset_key.""" - - def __init__(self, operation, extra_args): - self._dataset_key = operation.dataset_key - self._flat_pcollection = extra_args.flat_pcollection - self._pcollection_dict = extra_args.pcollection_dict - - def expand(self, pbegin): - # TODO(b/151921205): we have to do an identity map for unmodified - # PCollections below because otherwise we get an error from beam. - identity_map = 'Identity' >> beam.Map(lambda x: x) - if self._dataset_key.is_flattened_dataset_key(): - if self._flat_pcollection: - return self._flat_pcollection | identity_map - else: - return ( - list(self._pcollection_dict.values()) - | 'FlattenAnalysisInputs' >> beam.Flatten(pipeline=pbegin.pipeline)) - else: - return self._pcollection_dict[self._dataset_key] | identity_map + """Returns a PCollection for analysis based on the specified dataset_key.""" + + def __init__(self, operation, extra_args): + self._dataset_key = operation.dataset_key + self._flat_pcollection = extra_args.flat_pcollection + self._pcollection_dict = extra_args.pcollection_dict + + def expand(self, pbegin): + # TODO(b/151921205): we have to do an identity map for unmodified + # PCollections below because otherwise we get an error from beam. + identity_map = "Identity" >> beam.Map(lambda x: x) + if self._dataset_key.is_flattened_dataset_key(): + if self._flat_pcollection: + return self._flat_pcollection | identity_map + else: + return list( + self._pcollection_dict.values() + ) | "FlattenAnalysisInputs" >> beam.Flatten(pipeline=pbegin.pipeline) + else: + return self._pcollection_dict[self._dataset_key] | identity_map @beam_common.register_ptransform(beam_nodes.ApplySavedModel) class _ApplySavedModelImpl(beam.PTransform): - """PTransform to apply a SavedModel to data.""" - - def __init__(self, operation, extra_args): - self._use_tf_compat_v1 = extra_args.use_tf_compat_v1 - self._input_tensor_adapter_config = extra_args.input_tensor_adapter_config - self._tf_config = extra_args.tf_config - self._phase = operation.phase - - def expand(self, inputs): - saved_model_dir_pcol, input_values_pcol = inputs - - # We don't deep_copy pcollections used for the first phase, or when - # the user defined `Context` disables it. - if self._phase > 0 and Context.get_use_deep_copy_optimization(): - # Obviates unnecessary data materialization when the input data source is - # safe to read more than once. 
- logging.info('Deep copying inputs for phase: %d', self._phase) - input_values_pcol = deep_copy.deep_copy(input_values_pcol) - - def _convert_to_numpy(input_dict): - """Converts eager tensors to numpy arrays.""" - return { - k: np.asarray(v) if isinstance(v, tf.Tensor) else v - for k, v in input_dict.items() - } - - result = ( - input_values_pcol | 'ApplySavedModel' >> beam.ParDo( + """PTransform to apply a SavedModel to data.""" + + def __init__(self, operation, extra_args): + self._use_tf_compat_v1 = extra_args.use_tf_compat_v1 + self._input_tensor_adapter_config = extra_args.input_tensor_adapter_config + self._tf_config = extra_args.tf_config + self._phase = operation.phase + + def expand(self, inputs): + saved_model_dir_pcol, input_values_pcol = inputs + + # We don't deep_copy pcollections used for the first phase, or when + # the user defined `Context` disables it. + if self._phase > 0 and Context.get_use_deep_copy_optimization(): + # Obviates unnecessary data materialization when the input data source is + # safe to read more than once. + logging.info("Deep copying inputs for phase: %d", self._phase) + input_values_pcol = deep_copy.deep_copy(input_values_pcol) + + def _convert_to_numpy(input_dict): + """Converts eager tensors to numpy arrays.""" + return { + k: np.asarray(v) if isinstance(v, tf.Tensor) else v + for k, v in input_dict.items() + } + + result = input_values_pcol | "ApplySavedModel" >> beam.ParDo( _RunMetaGraphDoFn( self._tf_config, use_tf_compat_v1=self._use_tf_compat_v1, input_tensor_adapter_config=self._input_tensor_adapter_config, shared_graph_state_handle=shared.Shared(), - passthrough_keys=Context.get_passthrough_keys()), - saved_model_dir=beam.pvalue.AsSingleton(saved_model_dir_pcol))) - if not self._use_tf_compat_v1: - result |= 'ConvertToNumpy' >> beam.Map(_convert_to_numpy) - return result + passthrough_keys=Context.get_passthrough_keys(), + ), + saved_model_dir=beam.pvalue.AsSingleton(saved_model_dir_pcol), + ) + if not self._use_tf_compat_v1: + result |= "ConvertToNumpy" >> beam.Map(_convert_to_numpy) + return result @beam_common.register_ptransform(beam_nodes.ExtractFromDict) -@beam.typehints.with_input_types(Dict[str, - Union[np.ndarray, - tf.compat.v1.SparseTensorValue]]) +@beam.typehints.with_input_types( + Dict[str, Union[np.ndarray, tf.compat.v1.SparseTensorValue]] +) class _ExtractFromDictImpl(beam.PTransform): - """Implements ExtractFromDict by extracting the configured keys.""" + """Implements ExtractFromDict by extracting the configured keys.""" - def __init__(self, operation, extra_args): - del extra_args - self._keys = operation.keys + def __init__(self, operation, extra_args): + del extra_args + self._keys = operation.keys - def expand(self, inputs): - pcoll, = inputs + def expand(self, inputs): + (pcoll,) = inputs - def extract_keys(input_dict, keys): - return (tuple(input_dict[k] for k in keys) - if isinstance(keys, tuple) else input_dict[keys]) + def extract_keys(input_dict, keys): + return ( + tuple(input_dict[k] for k in keys) + if isinstance(keys, tuple) + else input_dict[keys] + ) - if isinstance(self._keys, tuple): - output_type = Tuple[(Any,) * len(self._keys)] - else: - output_type = Any - return pcoll | 'ExtractKeys' >> beam.Map( - extract_keys, keys=self._keys).with_output_types(output_type) + if isinstance(self._keys, tuple): + output_type = Tuple[(Any,) * len(self._keys)] + else: + output_type = Any + return pcoll | "ExtractKeys" >> beam.Map( + extract_keys, keys=self._keys + ).with_output_types(output_type) 
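`_BindTensors`, `_CreateSavedModelImplV2` and `_ApplySavedModelImpl` all feed single-element PCollections (the tensor bindings, the SavedModel path) into a `Map`/`ParDo` as `beam.pvalue.AsSingleton` side inputs. The pattern in isolation, with hypothetical element values:

```python
import apache_beam as beam

with beam.Pipeline() as p:
    # One-element PCollection standing in for the deferred SavedModel path.
    model_dir = p | "CreateModelDir" >> beam.Create(["/tmp/hypothetical_model"])
    records = p | "CreateRecords" >> beam.Create([{"x": 1}, {"x": 2}])

    # AsSingleton turns the one-element PCollection into a side input that is
    # handed to every invocation of the mapped function.
    _ = (
        records
        | "Tag" >> beam.Map(
            lambda rec, d: {**rec, "model_dir": d},
            d=beam.pvalue.AsSingleton(model_dir),
        )
        | "Print" >> beam.Map(print)
    )
```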
@beam_common.register_ptransform(beam_nodes.Flatten) class _Flatten(beam.PTransform): - """PTransform to flatten PCollections.""" + """PTransform to flatten PCollections.""" - def __init__(self, operation, extra_args): - del operation, extra_args # unused + def __init__(self, operation, extra_args): + del operation, extra_args # unused - def expand(self, inputs): - return inputs | beam.Flatten() + def expand(self, inputs): + return inputs | beam.Flatten() def _infer_metadata_from_saved_model( - saved_model_dir: str, - use_tf_compat_v1: bool) -> dataset_metadata.DatasetMetadata: - """Infers a DatasetMetadata for outputs of a SavedModel.""" - if use_tf_compat_v1: - return _infer_metadata_from_saved_model_v1(saved_model_dir) - else: - return _infer_metadata_from_saved_model_v2(saved_model_dir) + saved_model_dir: str, use_tf_compat_v1: bool +) -> dataset_metadata.DatasetMetadata: + """Infers a DatasetMetadata for outputs of a SavedModel.""" + if use_tf_compat_v1: + return _infer_metadata_from_saved_model_v1(saved_model_dir) + else: + return _infer_metadata_from_saved_model_v2(saved_model_dir) def _infer_metadata_from_saved_model_v1( - saved_model_dir: str) -> dataset_metadata.DatasetMetadata: - """Infers a DatasetMetadata for outputs of a TF1 SavedModel.""" - with tf.compat.v1.Graph().as_default() as graph: - with tf.compat.v1.Session(graph=graph) as session: - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - saved_model_dir, {})) + saved_model_dir: str, +) -> dataset_metadata.DatasetMetadata: + """Infers a DatasetMetadata for outputs of a TF1 SavedModel.""" + with tf.compat.v1.Graph().as_default() as graph: + with tf.compat.v1.Session(graph=graph) as session: + _, outputs = saved_transform_io.partially_apply_saved_transform_internal( + saved_model_dir, {} + ) - session.run(tf.compat.v1.global_variables_initializer()) - session.run(tf.compat.v1.tables_initializer()) - return dataset_metadata.DatasetMetadata( - schema=schema_inference.infer_feature_schema(outputs, graph, session)) + session.run(tf.compat.v1.global_variables_initializer()) + session.run(tf.compat.v1.tables_initializer()) + return dataset_metadata.DatasetMetadata( + schema=schema_inference.infer_feature_schema(outputs, graph, session) + ) def _infer_metadata_from_saved_model_v2( - saved_model_dir: str) -> dataset_metadata.DatasetMetadata: - """Infers a DatasetMetadata for outputs of a TF2 SavedModel.""" - - metadata_path = os.path.join(saved_model_dir, impl_helper.METADATA_DIR_NAME) - return metadata_io.read_metadata(metadata_path) + saved_model_dir: str, +) -> dataset_metadata.DatasetMetadata: + """Infers a DatasetMetadata for outputs of a TF2 SavedModel.""" + metadata_path = os.path.join(saved_model_dir, impl_helper.METADATA_DIR_NAME) + return metadata_io.read_metadata(metadata_path) class _InstrumentAPI(beam.PTransform): - """PTransform that adds metrics for API usage.""" - - def __init__(self, tf_graph, force_tf_compat_v1, use_tf_compat_v1): - - def _get_counter_from_graph_collection(collection_name): - collection = tf_graph.get_collection(collection_name) - if len(collection) > 1: - raise ValueError( - "Expected TF graph collection '{}' to contain at most one element. 
" - 'Encountered {}.'.format(collection_name, len(collection))) - return collection[0] if collection else {} - - self._analyzer_use_counter = _get_counter_from_graph_collection( - common.ANALYZER_COLLECTION) - self._mapper_use_counter = _get_counter_from_graph_collection( - common.MAPPER_COLLECTION) - self._force_tf_compat_v1 = force_tf_compat_v1 - self._use_tf_compat_v1 = use_tf_compat_v1 - - def expand(self, pipeline): - - def _make_and_increment_counters(unused_element, analyzer_counter, - mapper_counter, force_tf_compat_v1, - use_tf_compat_v1): - del unused_element - beam.metrics.Metrics.counter(beam_common.METRICS_NAMESPACE, - 'requested_tf_compat_v1').inc( - int(force_tf_compat_v1)) - beam.metrics.Metrics.counter(beam_common.METRICS_NAMESPACE, - 'running_tf_compat_v1').inc( - int(use_tf_compat_v1)) - for counter_prefix, counter in (('tft_analyzer_{}', analyzer_counter), - ('tft_mapper_{}', mapper_counter)): - for name, count in counter.items(): - beam.metrics.Metrics.counter(beam_common.METRICS_NAMESPACE, - counter_prefix.format(name)).inc(count) - - _ = ( - pipeline - | 'CreateSoleAPIUse' >> beam.Create([None]) - | 'CountAPIUse' >> - beam.Map(_make_and_increment_counters, self._analyzer_use_counter, - self._mapper_use_counter, self._force_tf_compat_v1, - self._use_tf_compat_v1)) + """PTransform that adds metrics for API usage.""" + + def __init__(self, tf_graph, force_tf_compat_v1, use_tf_compat_v1): + def _get_counter_from_graph_collection(collection_name): + collection = tf_graph.get_collection(collection_name) + if len(collection) > 1: + raise ValueError( + f"Expected TF graph collection '{collection_name}' to contain at most one element. " + f"Encountered {len(collection)}." + ) + return collection[0] if collection else {} + + self._analyzer_use_counter = _get_counter_from_graph_collection( + common.ANALYZER_COLLECTION + ) + self._mapper_use_counter = _get_counter_from_graph_collection( + common.MAPPER_COLLECTION + ) + self._force_tf_compat_v1 = force_tf_compat_v1 + self._use_tf_compat_v1 = use_tf_compat_v1 + + def expand(self, pipeline): + def _make_and_increment_counters( + unused_element, + analyzer_counter, + mapper_counter, + force_tf_compat_v1, + use_tf_compat_v1, + ): + del unused_element + beam.metrics.Metrics.counter( + beam_common.METRICS_NAMESPACE, "requested_tf_compat_v1" + ).inc(int(force_tf_compat_v1)) + beam.metrics.Metrics.counter( + beam_common.METRICS_NAMESPACE, "running_tf_compat_v1" + ).inc(int(use_tf_compat_v1)) + for counter_prefix, counter in ( + ("tft_analyzer_{}", analyzer_counter), + ("tft_mapper_{}", mapper_counter), + ): + for name, count in counter.items(): + beam.metrics.Metrics.counter( + beam_common.METRICS_NAMESPACE, counter_prefix.format(name) + ).inc(count) + + _ = ( + pipeline + | "CreateSoleAPIUse" >> beam.Create([None]) + | "CountAPIUse" + >> beam.Map( + _make_and_increment_counters, + self._analyzer_use_counter, + self._mapper_use_counter, + self._force_tf_compat_v1, + self._use_tf_compat_v1, + ) + ) @beam.typehints.with_input_types(common_types.InstanceDictType) @beam.typehints.with_output_types(pa.RecordBatch) class _InstanceDictInputToTFXIOInput(beam.PTransform): - """PTransform that turns instance dicts into RecordBatches.""" - - def __init__(self, schema, desired_batch_size): - self._schema = schema - self._tfxio = tf_example_record.TFExampleBeamRecord( - physical_format='inmem', - telemetry_descriptors=['StandaloneTFTransform'], - schema=schema) - self._desired_batch_size = desired_batch_size + """PTransform that turns instance dicts 
into RecordBatches.""" + + def __init__(self, schema, desired_batch_size): + self._schema = schema + self._tfxio = tf_example_record.TFExampleBeamRecord( + physical_format="inmem", + telemetry_descriptors=["StandaloneTFTransform"], + schema=schema, + ) + self._desired_batch_size = desired_batch_size - def tensor_adapter_config(self): - return self._tfxio.TensorAdapterConfig() + def tensor_adapter_config(self): + return self._tfxio.TensorAdapterConfig() - def expand(self, instance_dict_pcoll): - return ( - instance_dict_pcoll - | 'EncodeInstanceDictsAsTfExample' >> beam.Map( - example_proto_coder.ExampleProtoCoder(self._schema).encode) - | 'TfExampleToRecordBatch' >> self._tfxio.BeamSource( - batch_size=self._desired_batch_size)) + def expand(self, instance_dict_pcoll): + return ( + instance_dict_pcoll + | "EncodeInstanceDictsAsTfExample" + >> beam.Map(example_proto_coder.ExampleProtoCoder(self._schema).encode) + | "TfExampleToRecordBatch" + >> self._tfxio.BeamSource(batch_size=self._desired_batch_size) + ) def _make_output_cache( @@ -962,363 +1046,399 @@ def _make_output_cache( analyzer_cache.DatasetKey, analyzer_cache.DatasetCacheMetadata ], ) -> Optional[analyzer_cache.BeamAnalysisCache]: - """Triggers dataset cache encoding and composes analysis cache output.""" - if cache_value_nodes is None: - return None - cache_dict = collections.defaultdict(dict) - for dataset_key, dataset_cache in cache_value_nodes.items(): - for cache_key, value_node in dataset_cache.items(): - cache_dict[dataset_key][cache_key] = traverser.visit_value_node( - value_node - ) - return { - dataset_key: analyzer_cache.DatasetCache(cache, - dataset_metrics[dataset_key]) - for dataset_key, cache in cache_dict.items() - } + """Triggers dataset cache encoding and composes analysis cache output.""" + if cache_value_nodes is None: + return None + cache_dict = collections.defaultdict(dict) + for dataset_key, dataset_cache in cache_value_nodes.items(): + for cache_key, value_node in dataset_cache.items(): + cache_dict[dataset_key][cache_key] = traverser.visit_value_node(value_node) + return { + dataset_key: analyzer_cache.DatasetCache(cache, dataset_metrics[dataset_key]) + for dataset_key, cache in cache_dict.items() + } class _AnalyzeDatasetCommon(beam.PTransform): - """Common implementation for AnalyzeDataset, with or without cache.""" + """Common implementation for AnalyzeDataset, with or without cache.""" + + def __init__(self, preprocessing_fn, pipeline=None): + """Init method. + + Args: + ---- + preprocessing_fn: A function that accepts and returns a dictionary from + strings to `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. + pipeline: (Optional) a beam Pipeline. + """ + self._preprocessing_fn = preprocessing_fn + self.pipeline = pipeline + self._save_options = Context.get_save_options() + self._use_tf_compat_v1 = Context.get_use_tf_compat_v1() + if self._use_tf_compat_v1: + _warn_about_tf_compat_v1() + + def _extract_input_pvalues(self, dataset): + # This method returns all nested pvalues to inform beam of nested pvalues. + flat_data, data_dict, dataset_cache_dict, metadata = dataset + pvalues = [] + # flat_data should be None when performing analysis with cache. + if flat_data is not None: + pvalues.append(flat_data) + if data_dict: + for value in data_dict.values(): + # Dataset PCollections can be None if it's fully covered by cache and so + # there's no need in reading it. 
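`_InstanceDictInputToTFXIOInput` above first serializes each instance dict as a `tf.train.Example` (via `ExampleProtoCoder`) before the TFXIO source batches the records into `RecordBatch`es. A simplified, schema-free sketch of that first encoding step, not the coder TFT actually uses:

```python
import tensorflow as tf


def encode_instance(instance):
    """Encodes a flat dict of python values as a serialized tf.train.Example."""
    feature = {}
    for name, value in instance.items():
        values = value if isinstance(value, (list, tuple)) else [value]
        if all(isinstance(v, int) for v in values):
            feature[name] = tf.train.Feature(
                int64_list=tf.train.Int64List(value=values)
            )
        elif all(isinstance(v, float) for v in values):
            feature[name] = tf.train.Feature(
                float_list=tf.train.FloatList(value=values)
            )
        else:
            feature[name] = tf.train.Feature(
                bytes_list=tf.train.BytesList(
                    value=[str(v).encode("utf-8") for v in values]
                )
            )
    return tf.train.Example(
        features=tf.train.Features(feature=feature)
    ).SerializeToString()


print(len(encode_instance({"x": 1.0, "s": "hello"})))
```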
+ if value is not None: + pvalues.append(value) + if dataset_cache_dict is not None: + for cache_dict in dataset_cache_dict.values(): + for cache_pcoll in cache_dict.values(): + pvalues.append(cache_pcoll) + if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(metadata.deferred_metadata) + assert ( + self.pipeline is not None or pvalues + ), "If there is no data, a pipeline must be provided" + return dataset, pvalues + + def expand(self, dataset): + """Analyze the dataset. + + Args: + ---- + dataset: A dataset. + + Returns: + ------- + A TransformFn containing the deferred transform function. + + Raises: + ------ + ValueError: If preprocessing_fn has no outputs. + """ + ( + flattened_pcoll, + input_values_pcoll_dict, + dataset_cache_dict, + input_metadata, + ) = dataset + input_values_pcoll_dict = input_values_pcoll_dict or dict() + + if isinstance(input_metadata, dataset_metadata.DatasetMetadata): + if Context.get_passthrough_keys(): + raise ValueError( + f"passthrough_keys is set to {Context.get_passthrough_keys()} but it is not supported" + "with instance dicts + DatasetMetadata input. Follow " + "the guide to switch to the TFXIO format." + ) + logging.warning( + "You are passing instance dicts and DatasetMetadata to TFT which " + "will not provide optimal performance. Consider following the TFT " + "guide to upgrade to the TFXIO format (Apache Arrow RecordBatch)." + ) + to_tfxio_ptransform = _InstanceDictInputToTFXIOInput( + input_metadata.schema, Context.get_desired_batch_size() + ) + input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config() + if flattened_pcoll is not None: + flattened_pcoll |= "InstanceDictToRecordBatch" >> to_tfxio_ptransform + for key in input_values_pcoll_dict.keys(): + if input_values_pcoll_dict[key] is not None: + input_values_pcoll_dict[key] |= ( + f"InstanceDictToRecordBatch[{key}]" >> to_tfxio_ptransform + ) + else: + input_tensor_adapter_config = input_metadata + assert input_tensor_adapter_config is not None + + specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs() + + if not specs: + raise ValueError("The input metadata is empty.") + + base_temp_dir = Context.create_base_temp_dir() + # TODO(b/149997088): Do not pass base_temp_dir here as this graph does not + # need to be serialized to SavedModel. + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + self._preprocessing_fn, specs, self._use_tf_compat_v1, base_temp_dir + ) + ) - def __init__(self, preprocessing_fn, pipeline=None): - """Init method. + # At this point we check that the preprocessing_fn has at least one + # output. This is because if we allowed the output of preprocessing_fn to + # be empty, we wouldn't be able to determine how many instances to + # "unbatch" the output into. + if not isinstance(structured_outputs, dict): + raise ValueError( + "A `preprocessing_fn` must return a " + "Dict[str, Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]]. " + f"Got: {structured_outputs}" + ) + if not structured_outputs: + raise ValueError("The preprocessing function returned an empty dict") + + if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES): + raise ValueError( + "The preprocessing function contained trainable variables " "{}".format( + graph.get_collection_ref(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) + ) + ) - Args: - preprocessing_fn: A function that accepts and returns a dictionary from - strings to `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. 
- pipeline: (Optional) a beam Pipeline. - """ - self._preprocessing_fn = preprocessing_fn - self.pipeline = pipeline - self._save_options = Context.get_save_options() - self._use_tf_compat_v1 = Context.get_use_tf_compat_v1() - if self._use_tf_compat_v1: - _warn_about_tf_compat_v1() - - def _extract_input_pvalues(self, dataset): - # This method returns all nested pvalues to inform beam of nested pvalues. - flat_data, data_dict, dataset_cache_dict, metadata = dataset - pvalues = [] - # flat_data should be None when performing analysis with cache. - if flat_data is not None: - pvalues.append(flat_data) - if data_dict: - for value in data_dict.values(): - # Dataset PCollections can be None if it's fully covered by cache and so - # there's no need in reading it. - if value is not None: - pvalues.append(value) - if dataset_cache_dict is not None: - for cache_dict in dataset_cache_dict.values(): - for cache_pcoll in cache_dict.values(): - pvalues.append(cache_pcoll) - if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(metadata.deferred_metadata) - assert (self.pipeline is not None or - pvalues), 'If there is no data, a pipeline must be provided' - return dataset, pvalues - - def expand(self, dataset): - """Analyze the dataset. + pipeline = ( + self.pipeline + or ( + flattened_pcoll + or next(v for v in input_values_pcoll_dict.values() if v is not None) + ).pipeline + ) - Args: - dataset: A dataset. + # Add a stage that inspects graph collections for API use counts and logs + # them as a beam metric. + _ = pipeline | "InstrumentAPI" >> _InstrumentAPI( + graph, Context._get_force_tf_compat_v1(), self._use_tf_compat_v1 + ) # pylint: disable=protected-access + + dataset_metrics = {} + if flattened_pcoll is not None: + _ = ( + flattened_pcoll + | "InstrumentInputBytes[AnalysisFlattenedPColl]" + >> telemetry.TrackRecordBatchBytes( + beam_common.METRICS_NAMESPACE, "analysis_input_bytes" + ) + ) + else: + for idx, key in enumerate(sorted(input_values_pcoll_dict.keys())): + infix = f"AnalysisIndex{idx}" + input_value = input_values_pcoll_dict[key] + if input_value is not None: + dataset_metrics[key] = ( + input_value + | f"GetRecordBatchSize[{infix}]" + >> beam.Map(lambda rb: rb.nbytes) + | f"SumTotalBytes[{infix}]" >> beam.CombineGlobally(sum) + | f"ConstructMetadata[{infix}]" + >> beam.Map(analyzer_cache.DatasetCacheMetadata) + ) + _ = ( + input_value + | f"InstrumentInputBytes[{infix}]" + >> telemetry.TrackRecordBatchBytes( + beam_common.METRICS_NAMESPACE, "analysis_input_bytes" + ) + ) + + # Gather telemetry on types of input features. + _ = ( + pipeline + | "CreateAnalyzeInputTensorRepresentations" + >> beam.Create([input_tensor_adapter_config.tensor_representations]) + | "InstrumentAnalyzeInputTensors" + >> telemetry.TrackTensorRepresentations( + telemetry_util.AppendToNamespace( + beam_common.METRICS_NAMESPACE, ["analyze_input_tensors"] + ) + ) + ) - Returns: - A TransformFn containing the deferred transform function. + asset_map = annotators.get_asset_annotations(graph) + # TF.HUB can error when unapproved collections are present. So we explicitly + # clear out the collections in the graph. + annotators.clear_asset_annotations(graph) - Raises: - ValueError: If preprocessing_fn has no outputs. 
- """ - (flattened_pcoll, input_values_pcoll_dict, dataset_cache_dict, - input_metadata) = dataset - input_values_pcoll_dict = input_values_pcoll_dict or dict() - - if isinstance(input_metadata, dataset_metadata.DatasetMetadata): - if Context.get_passthrough_keys(): - raise ValueError('passthrough_keys is set to {} but it is not supported' - 'with instance dicts + DatasetMetadata input. Follow ' - 'the guide to switch to the TFXIO format.'.format( - Context.get_passthrough_keys())) - logging.warning( - 'You are passing instance dicts and DatasetMetadata to TFT which ' - 'will not provide optimal performance. Consider following the TFT ' - 'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).') - to_tfxio_ptransform = _InstanceDictInputToTFXIOInput( - input_metadata.schema, Context.get_desired_batch_size()) - input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config() - if flattened_pcoll is not None: - flattened_pcoll |= 'InstanceDictToRecordBatch' >> to_tfxio_ptransform - for key in input_values_pcoll_dict.keys(): - if input_values_pcoll_dict[key] is not None: - input_values_pcoll_dict[key] |= ( - 'InstanceDictToRecordBatch[{}]'.format(key) >> - to_tfxio_ptransform) - else: - input_tensor_adapter_config = input_metadata - assert input_tensor_adapter_config is not None - - specs = TensorAdapter(input_tensor_adapter_config).OriginalTypeSpecs() - - if not specs: - raise ValueError('The input metadata is empty.') - - base_temp_dir = Context.create_base_temp_dir() - # TODO(b/149997088): Do not pass base_temp_dir here as this graph does not - # need to be serialized to SavedModel. - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function(self._preprocessing_fn, specs, - self._use_tf_compat_v1, - base_temp_dir)) - - # At this point we check that the preprocessing_fn has at least one - # output. This is because if we allowed the output of preprocessing_fn to - # be empty, we wouldn't be able to determine how many instances to - # "unbatch" the output into. - if not isinstance(structured_outputs, dict): - raise ValueError( - 'A `preprocessing_fn` must return a ' - 'Dict[str, Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor]]. ' - f'Got: {structured_outputs}') - if not structured_outputs: - raise ValueError('The preprocessing function returned an empty dict') - - if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES): - raise ValueError( - 'The preprocessing function contained trainable variables ' - '{}'.format( - graph.get_collection_ref( - tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))) - - pipeline = self.pipeline or (flattened_pcoll or next( - v for v in input_values_pcoll_dict.values() if v is not None)).pipeline - - # Add a stage that inspects graph collections for API use counts and logs - # them as a beam metric. 
- _ = (pipeline | 'InstrumentAPI' >> _InstrumentAPI( - graph, Context._get_force_tf_compat_v1(), self._use_tf_compat_v1)) # pylint: disable=protected-access - - dataset_metrics = {} - if flattened_pcoll is not None: - _ = ( - flattened_pcoll - | 'InstrumentInputBytes[AnalysisFlattenedPColl]' >> - telemetry.TrackRecordBatchBytes(beam_common.METRICS_NAMESPACE, - 'analysis_input_bytes')) - else: - for idx, key in enumerate(sorted(input_values_pcoll_dict.keys())): - infix = f'AnalysisIndex{idx}' - input_value = input_values_pcoll_dict[key] - if input_value is not None: - dataset_metrics[key] = ( - input_value - | f'GetRecordBatchSize[{infix}]' >> beam.Map(lambda rb: rb.nbytes) - | f'SumTotalBytes[{infix}]' >> beam.CombineGlobally(sum) - | f'ConstructMetadata[{infix}]' >> beam.Map( - analyzer_cache.DatasetCacheMetadata)) - _ = ( - input_value - | f'InstrumentInputBytes[{infix}]' - >> telemetry.TrackRecordBatchBytes( - beam_common.METRICS_NAMESPACE, 'analysis_input_bytes')) - - # Gather telemetry on types of input features. - _ = ( - pipeline - | 'CreateAnalyzeInputTensorRepresentations' - >> beam.Create([input_tensor_adapter_config.tensor_representations]) - | 'InstrumentAnalyzeInputTensors' - >> telemetry.TrackTensorRepresentations( - telemetry_util.AppendToNamespace( - beam_common.METRICS_NAMESPACE, ['analyze_input_tensors'] + analyzers_fingerprint = ( + graph_tools.get_analyzers_fingerprint(graph, structured_inputs) + if not self._use_tf_compat_v1 + else None + ) + + tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get( + type(pipeline.runner) + ) + extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs( + base_temp_dir=base_temp_dir, + tf_config=tf_config, + pipeline=pipeline, + flat_pcollection=flattened_pcoll, + pcollection_dict=input_values_pcoll_dict, + graph=graph, + input_signature=structured_inputs, + input_specs=specs, + input_tensor_adapter_config=input_tensor_adapter_config, + use_tf_compat_v1=self._use_tf_compat_v1, + cache_pcoll_dict=dataset_cache_dict, + preprocessing_fn=self._preprocessing_fn, + analyzers_fingerprint=analyzers_fingerprint, + save_options=self._save_options, + ) + + (transform_fn_future, cache_value_nodes, detached_sideeffect_leafs) = ( + analysis_graph_builder.build( + graph, + structured_inputs, + structured_outputs, + input_values_pcoll_dict.keys(), + cache_dict=dataset_cache_dict, ) ) - ) + traverser = nodes.Traverser( + beam_common.ConstructBeamPipelineVisitor(extra_args) + ) + transform_fn_pcoll = traverser.visit_value_node(transform_fn_future) - asset_map = annotators.get_asset_annotations(graph) - # TF.HUB can error when unapproved collections are present. So we explicitly - # clear out the collections in the graph. 
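The per-dataset cache metadata above is produced by mapping each `RecordBatch` to its `nbytes` and summing globally. The same two-step pattern on its own:

```python
import apache_beam as beam
import pyarrow as pa


def make_batch(values):
    return pa.RecordBatch.from_arrays(
        [pa.array(values, type=pa.int64())], names=["x"]
    )


with beam.Pipeline() as p:
    _ = (
        p
        | "CreateBatches" >> beam.Create([make_batch([1, 2, 3]), make_batch([4, 5])])
        | "GetRecordBatchSize" >> beam.Map(lambda rb: rb.nbytes)
        | "SumTotalBytes" >> beam.CombineGlobally(sum)
        | "Print" >> beam.Map(print)
    )
```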
- annotators.clear_asset_annotations(graph) - - analyzers_fingerprint = graph_tools.get_analyzers_fingerprint( - graph, structured_inputs) if not self._use_tf_compat_v1 else None - - tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get( - type(pipeline.runner)) - extra_args = beam_common.ConstructBeamPipelineVisitor.ExtraArgs( - base_temp_dir=base_temp_dir, - tf_config=tf_config, - pipeline=pipeline, - flat_pcollection=flattened_pcoll, - pcollection_dict=input_values_pcoll_dict, - graph=graph, - input_signature=structured_inputs, - input_specs=specs, - input_tensor_adapter_config=input_tensor_adapter_config, - use_tf_compat_v1=self._use_tf_compat_v1, - cache_pcoll_dict=dataset_cache_dict, - preprocessing_fn=self._preprocessing_fn, - analyzers_fingerprint=analyzers_fingerprint, - save_options=self._save_options) - - (transform_fn_future, cache_value_nodes, - detached_sideeffect_leafs) = analysis_graph_builder.build( - graph, - structured_inputs, - structured_outputs, - input_values_pcoll_dict.keys(), - cache_dict=dataset_cache_dict) - traverser = nodes.Traverser( - beam_common.ConstructBeamPipelineVisitor(extra_args)) - transform_fn_pcoll = traverser.visit_value_node(transform_fn_future) - - # Cause side-effect nodes to get executed. - for node in detached_sideeffect_leafs: - traverser.visit_value_node(node) - - output_cache_pcoll_dict = _make_output_cache(cache_value_nodes, traverser, - dataset_metrics) - - # Infer metadata. We take the inferred metadata and apply overrides that - # refer to values of tensors in the graph. The override tensors must - # be "constant" in that they don't depend on input data. The tensors can - # depend on analyzer outputs though. This allows us to set metadata that - # depends on analyzer outputs. _infer_metadata_from_saved_model will use the - # analyzer outputs stored in `transform_fn` to compute the metadata in a - # deferred manner, once the analyzer outputs are known. - if self._use_tf_compat_v1: - schema = schema_inference.infer_feature_schema(structured_outputs, graph) - else: - # Use metadata_fn here as func_graph outputs may be wrapped in an identity - # op and hence may not return the same tensors that were annotated. - tf_graph_context = graph_context.TFGraphContext( - module_to_export=tf.Module(), - temp_dir=base_temp_dir, - evaluated_replacements={}) - concrete_metadata_fn = schema_inference.get_traced_metadata_fn( - preprocessing_fn=self._preprocessing_fn, - structured_inputs=structured_inputs, - tf_graph_context=tf_graph_context, - evaluate_schema_overrides=False) - schema = schema_inference.infer_feature_schema_v2( - structured_outputs, - concrete_metadata_fn, - evaluate_schema_overrides=False) - deferred_metadata = ( - transform_fn_pcoll - | 'ComputeDeferredMetadata[compat_v1={}]'.format(self._use_tf_compat_v1) - >> beam.Map(_infer_metadata_from_saved_model, self._use_tf_compat_v1)) - - full_metadata = beam_metadata_io.BeamDatasetMetadata( - dataset_metadata.DatasetMetadata(schema=schema), deferred_metadata, - asset_map) - - _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll) - - return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict + # Cause side-effect nodes to get executed. + for node in detached_sideeffect_leafs: + traverser.visit_value_node(node) + + output_cache_pcoll_dict = _make_output_cache( + cache_value_nodes, traverser, dataset_metrics + ) + + # Infer metadata. We take the inferred metadata and apply overrides that + # refer to values of tensors in the graph. 
The override tensors must + # be "constant" in that they don't depend on input data. The tensors can + # depend on analyzer outputs though. This allows us to set metadata that + # depends on analyzer outputs. _infer_metadata_from_saved_model will use the + # analyzer outputs stored in `transform_fn` to compute the metadata in a + # deferred manner, once the analyzer outputs are known. + if self._use_tf_compat_v1: + schema = schema_inference.infer_feature_schema(structured_outputs, graph) + else: + # Use metadata_fn here as func_graph outputs may be wrapped in an identity + # op and hence may not return the same tensors that were annotated. + tf_graph_context = graph_context.TFGraphContext( + module_to_export=tf.Module(), + temp_dir=base_temp_dir, + evaluated_replacements={}, + ) + concrete_metadata_fn = schema_inference.get_traced_metadata_fn( + preprocessing_fn=self._preprocessing_fn, + structured_inputs=structured_inputs, + tf_graph_context=tf_graph_context, + evaluate_schema_overrides=False, + ) + schema = schema_inference.infer_feature_schema_v2( + structured_outputs, + concrete_metadata_fn, + evaluate_schema_overrides=False, + ) + deferred_metadata = ( + transform_fn_pcoll + | f"ComputeDeferredMetadata[compat_v1={self._use_tf_compat_v1}]" + >> beam.Map(_infer_metadata_from_saved_model, self._use_tf_compat_v1) + ) + + full_metadata = beam_metadata_io.BeamDatasetMetadata( + dataset_metadata.DatasetMetadata(schema=schema), + deferred_metadata, + asset_map, + ) + + _clear_shared_state_after_barrier(pipeline, transform_fn_pcoll) + + return (transform_fn_pcoll, full_metadata), output_cache_pcoll_dict class AnalyzeDatasetWithCache(_AnalyzeDatasetCommon): - r"""Takes a preprocessing_fn and computes the relevant statistics. - - WARNING: This is experimental. - - Operates similarly to AnalyzeDataset, by computing the required statistics - except this will not re-compute statistics when they are already cached, and - will write out cache for statistics that it does compute whenever possible. - - Example use: - - >>> span_0_key = tft_beam.analyzer_cache.DatasetKey('span-0') - >>> cache_dir = tempfile.mkdtemp() - >>> output_path = os.path.join(tempfile.mkdtemp(), 'result') - >>> def preprocessing_fn(inputs): - ... x = inputs['x'] - ... return {'x_mean': tft.mean(x, name='x') + tf.zeros_like(x)} - >>> feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} - >>> input_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> input_data_dict_0 = {span_0_key: [{'x': x} for x in range(6)]} - >>> input_data_dict_1 = {span_0_key: [{'x': x} for x in range(6, 11)]} - >>> empty_input_cache = {} - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... with beam.Pipeline() as p: - ... # Iteration #0: - ... transform_fn, output_cache = ( - ... (input_data_dict_0, empty_input_cache, input_metadata) - ... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)) - ... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS( - ... p, cache_dir) - ... - ... # Iteration #1: - ... input_cache = p | tft_beam.analyzer_cache.ReadAnalysisCacheFromFS( - ... cache_dir, [span_0_key]) - ... transform_fn, output_cache = ( - ... (input_data_dict_1, input_cache, input_metadata) - ... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)) - ... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS( - ... p, cache_dir) - ... - ... # Applying the accumulated transformation: - ... transform_data = p | beam.Create(input_data_dict_0[span_0_key]) - ... transformed_dataset = ( - ... 
((transform_data, input_metadata), transform_fn) - ... | tft_beam.TransformDataset()) - ... transformed_data, transformed_metadata = transformed_dataset - ... (transformed_data - ... | beam.combiners.Sample.FixedSizeGlobally(1) - ... | beam.io.WriteToText(output_path, shard_name_template='')) - >>> with open(output_path) as f: - ... f.read() - - "[{'x_mean': 5.0}]\n" - """ - - def _make_parent_dataset(self, dataset): - if len(dataset) > 3: - raise ValueError('This API no longer requires flattened_pcoll') - return (None,) + dataset - - def _extract_input_pvalues(self, dataset): - # This method returns all nested pvalues to inform beam of nested pvalues. - super_dataset = self._make_parent_dataset(dataset) - _, pvalues = super()._extract_input_pvalues(super_dataset) - return dataset, pvalues - - def expand(self, dataset): - input_values_pcoll_dict = dataset[1] or dict() - analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys()) - return super().expand(self._make_parent_dataset(dataset)) + r"""Takes a preprocessing_fn and computes the relevant statistics. + + WARNING: This is experimental. + + Operates similarly to AnalyzeDataset, by computing the required statistics + except this will not re-compute statistics when they are already cached, and + will write out cache for statistics that it does compute whenever possible. + + Example use: + + >>> span_0_key = tft_beam.analyzer_cache.DatasetKey('span-0') + >>> cache_dir = tempfile.mkdtemp() + >>> output_path = os.path.join(tempfile.mkdtemp(), 'result') + >>> def preprocessing_fn(inputs): + ... x = inputs['x'] + ... return {'x_mean': tft.mean(x, name='x') + tf.zeros_like(x)} + >>> feature_spec = {'x': tf.io.FixedLenFeature([], tf.float32)} + >>> input_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> input_data_dict_0 = {span_0_key: [{'x': x} for x in range(6)]} + >>> input_data_dict_1 = {span_0_key: [{'x': x} for x in range(6, 11)]} + >>> empty_input_cache = {} + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... with beam.Pipeline() as p: + ... # Iteration #0: + ... transform_fn, output_cache = ( + ... (input_data_dict_0, empty_input_cache, input_metadata) + ... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)) + ... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS( + ... p, cache_dir) + ... + ... # Iteration #1: + ... input_cache = p | tft_beam.analyzer_cache.ReadAnalysisCacheFromFS( + ... cache_dir, [span_0_key]) + ... transform_fn, output_cache = ( + ... (input_data_dict_1, input_cache, input_metadata) + ... | tft_beam.AnalyzeDatasetWithCache(preprocessing_fn)) + ... output_cache | tft_beam.analyzer_cache.WriteAnalysisCacheToFS( + ... p, cache_dir) + ... + ... # Applying the accumulated transformation: + ... transform_data = p | beam.Create(input_data_dict_0[span_0_key]) + ... transformed_dataset = ( + ... ((transform_data, input_metadata), transform_fn) + ... | tft_beam.TransformDataset()) + ... transformed_data, transformed_metadata = transformed_dataset + ... (transformed_data + ... | beam.combiners.Sample.FixedSizeGlobally(1) + ... | beam.io.WriteToText(output_path, shard_name_template='')) + >>> with open(output_path) as f: + ... f.read() + + "[{'x_mean': 5.0}]\n" + """ + + def _make_parent_dataset(self, dataset): + if len(dataset) > 3: + raise ValueError("This API no longer requires flattened_pcoll") + return (None,) + dataset + + def _extract_input_pvalues(self, dataset): + # This method returns all nested pvalues to inform beam of nested pvalues. 
+ super_dataset = self._make_parent_dataset(dataset) + _, pvalues = super()._extract_input_pvalues(super_dataset) + return dataset, pvalues + + def expand(self, dataset): + input_values_pcoll_dict = dataset[1] or dict() + analyzer_cache.validate_dataset_keys(input_values_pcoll_dict.keys()) + return super().expand(self._make_parent_dataset(dataset)) class AnalyzeDataset(_AnalyzeDatasetCommon): - """Takes a preprocessing_fn and computes the relevant statistics. - - AnalyzeDataset accepts a preprocessing_fn in its constructor. When its - `expand` method is called on a dataset, it computes all the relevant - statistics required to run the transformation described by the - preprocessing_fn, and returns a TransformFn representing the application of - the preprocessing_fn. - """ - - def _extract_input_pvalues(self, dataset): - # This method returns all nested pvalues to inform beam of nested pvalues. - data, metadata = dataset - pvalues = [data] - if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(metadata.deferred_metadata) - return dataset, pvalues - - def expand(self, dataset): - input_values, input_metadata = dataset - result, cache = super().expand((input_values, None, None, input_metadata)) - assert not cache - return result + """Takes a preprocessing_fn and computes the relevant statistics. + + AnalyzeDataset accepts a preprocessing_fn in its constructor. When its + `expand` method is called on a dataset, it computes all the relevant + statistics required to run the transformation described by the + preprocessing_fn, and returns a TransformFn representing the application of + the preprocessing_fn. + """ + def _extract_input_pvalues(self, dataset): + # This method returns all nested pvalues to inform beam of nested pvalues. + data, metadata = dataset + pvalues = [data] + if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(metadata.deferred_metadata) + return dataset, pvalues -@beam.typehints.with_input_types( - Union[common_types.InstanceDictType, pa.RecordBatch] -) + def expand(self, dataset): + input_values, input_metadata = dataset + result, cache = super().expand((input_values, None, None, input_metadata)) + assert not cache + return result + + +@beam.typehints.with_input_types(Union[common_types.InstanceDictType, pa.RecordBatch]) # This PTransfrom outputs multiple PCollections and the output typehint is # checked against each of them. That is why it needs to represent elements of # all PCollections at the same time. @@ -1336,105 +1456,114 @@ def expand(self, dataset): ] ) class AnalyzeAndTransformDataset(beam.PTransform): - """Combination of AnalyzeDataset and TransformDataset. - - ```python - transformed, transform_fn = AnalyzeAndTransformDataset( - preprocessing_fn).expand(dataset) - ``` - - should be equivalent to + """Combination of AnalyzeDataset and TransformDataset. - ```python - transform_fn = AnalyzeDataset(preprocessing_fn).expand(dataset) - transformed = TransformDataset().expand((dataset, transform_fn)) - ``` + ```python + transformed, transform_fn = AnalyzeAndTransformDataset( + preprocessing_fn).expand(dataset) + ``` - but may be more efficient since it avoids multiple passes over the data. - """ + should be equivalent to - def __init__(self, preprocessing_fn, output_record_batches=False): - """Init method. 
+ ```python + transform_fn = AnalyzeDataset(preprocessing_fn).expand(dataset) + transformed = TransformDataset().expand((dataset, transform_fn)) + ``` - Args: - preprocessing_fn: A function that accepts and returns a dictionary from - strings to `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. - output_record_batches: (Optional) A bool. If `True`, - `AnalyzeAndTransformDataset` outputs `pyarrow.RecordBatch`es; - otherwise, outputs instance dicts. + but may be more efficient since it avoids multiple passes over the data. """ - self._preprocessing_fn = preprocessing_fn - self._output_record_batches = output_record_batches - - def _extract_input_pvalues(self, dataset): - # This method returns all nested pvalues to inform beam of nested pvalues. - data, metadata = dataset - pvalues = [data] - if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(metadata.deferred_metadata) - return dataset, pvalues - - def expand(self, dataset): - """Transform the dataset by applying the preprocessing_fn. - - Args: - dataset: A dataset. - Returns: - A (Dataset, TransformFn) pair containing the preprocessed dataset and - the graph that maps the input to the output data. - """ - # Expand is currently implemented by composing AnalyzeDataset and - # TransformDataset. Future versions however could do somthing more optimal, - # e.g. caching the values of expensive computations done in AnalyzeDataset. - transform_fn = ( - dataset | 'AnalyzeDataset' >> AnalyzeDataset(self._preprocessing_fn)) + def __init__(self, preprocessing_fn, output_record_batches=False): + """Init method. + + Args: + ---- + preprocessing_fn: A function that accepts and returns a dictionary from + strings to `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. + output_record_batches: (Optional) A bool. If `True`, + `AnalyzeAndTransformDataset` outputs `pyarrow.RecordBatch`es; + otherwise, outputs instance dicts. + """ + self._preprocessing_fn = preprocessing_fn + self._output_record_batches = output_record_batches + + def _extract_input_pvalues(self, dataset): + # This method returns all nested pvalues to inform beam of nested pvalues. + data, metadata = dataset + pvalues = [data] + if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(metadata.deferred_metadata) + return dataset, pvalues + + def expand(self, dataset): + """Transform the dataset by applying the preprocessing_fn. + + Args: + ---- + dataset: A dataset. + + Returns: + ------- + A (Dataset, TransformFn) pair containing the preprocessed dataset and + the graph that maps the input to the output data. + """ + # Expand is currently implemented by composing AnalyzeDataset and + # TransformDataset. Future versions however could do somthing more optimal, + # e.g. caching the values of expensive computations done in AnalyzeDataset. + transform_fn = dataset | "AnalyzeDataset" >> AnalyzeDataset( + self._preprocessing_fn + ) - if Context.get_use_deep_copy_optimization(): - data, metadata = dataset + if Context.get_use_deep_copy_optimization(): + data, metadata = dataset - # obviates unnecessary data materialization when the input data source is - # safe to read more than once. - logging.info( - 'Deep copying the dataset before applying transformation') - dataset = (deep_copy.deep_copy(data), metadata) + # obviates unnecessary data materialization when the input data source is + # safe to read more than once. 
+ logging.info("Deep copying the dataset before applying transformation") + dataset = (deep_copy.deep_copy(data), metadata) - transformed_dataset = ( - (dataset, transform_fn) - | 'TransformDataset' >> - TransformDataset(output_record_batches=self._output_record_batches)) - return transformed_dataset, transform_fn + transformed_dataset = ( + dataset, + transform_fn, + ) | "TransformDataset" >> TransformDataset( + output_record_batches=self._output_record_batches + ) + return transformed_dataset, transform_fn def _remove_columns_from_metadata(metadata, excluded_columns): - """Remove columns from metadata without mutating original metadata.""" - generated = schema_utils.schema_as_feature_spec(metadata.schema) - new_feature_spec = { - name: spec - for name, spec in generated.feature_spec.items() - if name not in excluded_columns - } - new_domains = { - name: spec - for name, spec in generated.domains.items() - if name not in excluded_columns - } - return dataset_metadata.DatasetMetadata.from_feature_spec( - new_feature_spec, new_domains) + """Remove columns from metadata without mutating original metadata.""" + generated = schema_utils.schema_as_feature_spec(metadata.schema) + new_feature_spec = { + name: spec + for name, spec in generated.feature_spec.items() + if name not in excluded_columns + } + new_domains = { + name: spec + for name, spec in generated.domains.items() + if name not in excluded_columns + } + return dataset_metadata.DatasetMetadata.from_feature_spec( + new_feature_spec, new_domains + ) class _MaybeInferTensorRepresentationsDoFn(beam.DoFn): - """Tries to infer TensorRepresentations from a Schema.""" - - def process( - self, schema: schema_pb2.Schema - ) -> Iterable[Dict[str, schema_pb2.TensorRepresentation]]: - try: - yield (tensor_representation_util - .InferTensorRepresentationsFromMixedSchema(schema)) - except ValueError: - # Ignore any inference errors since the output is only used for metrics. - yield {} + """Tries to infer TensorRepresentations from a Schema.""" + + def process( + self, schema: schema_pb2.Schema + ) -> Iterable[Dict[str, schema_pb2.TensorRepresentation]]: + try: + yield ( + tensor_representation_util.InferTensorRepresentationsFromMixedSchema( + schema + ) + ) + except ValueError: + # Ignore any inference errors since the output is only used for metrics. + yield {} @beam.typehints.with_input_types( @@ -1459,226 +1588,251 @@ def process( ] ) class TransformDataset(beam.PTransform): - """Applies the transformation computed by transforming a Dataset. - - TransformDataset's `expand` method is called on a (dataset, transform_fn) - pair. It applies the transform_fn to each row of the input dataset and - returns the resulting dataset. - - args: - exclude_outputs: (Optional) Output features that should not be produced. - output_record_batches: (Optional) A bool. If `True`, `TransformDataset` - outputs `pyarrow.RecordBatch`es; otherwise, outputs instance dicts. - """ - - def __init__(self, exclude_outputs=None, output_record_batches=False): - self._exclude_outputs = exclude_outputs - self._output_record_batches = output_record_batches - self._use_tf_compat_v1 = Context.get_use_tf_compat_v1() - if self._use_tf_compat_v1: - _warn_about_tf_compat_v1() - - def _extract_input_pvalues(self, dataset_and_transform_fn): - # This method returns all nested pvalues to inform beam of nested pvalues. 
- (data, input_metadata), (transform_fn, output_metadata) = ( - dataset_and_transform_fn) - pvalues = [data, transform_fn] - if isinstance(input_metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(input_metadata.deferred_metadata) - if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(output_metadata.deferred_metadata) - return dataset_and_transform_fn, pvalues - - def expand(self, dataset_and_transform_fn): - """Transforms the dataset using the transform_fn. + """Applies the transformation computed by transforming a Dataset. - Args: - dataset_and_transform_fn: A tuple of dataset and preprocessing - function. + TransformDataset's `expand` method is called on a (dataset, transform_fn) + pair. It applies the transform_fn to each row of the input dataset and + returns the resulting dataset. - Returns: - A dataset transformed according to the transform_fn. + Args: + ---- + exclude_outputs: (Optional) Output features that should not be produced. + output_record_batches: (Optional) A bool. If `True`, `TransformDataset` + outputs `pyarrow.RecordBatch`es; otherwise, outputs instance dicts. """ - (input_values, input_metadata), (transform_fn, output_metadata) = ( - dataset_and_transform_fn) - if isinstance(input_metadata, dataset_metadata.DatasetMetadata): - if Context.get_passthrough_keys(): - raise ValueError('passthrough_keys is set to {} but it is not ' - 'supported with instance dicts + DatasetMetadata ' - 'input. Follow the guide to switch to the TFXIO ' - 'format.'.format(Context.get_passthrough_keys())) - logging.warning( - 'You are passing instance dicts and DatasetMetadata to TFT which ' - 'will not provide optimal performance. Consider following the TFT ' - 'guide to upgrade to the TFXIO format (Apache Arrow RecordBatch).') - to_tfxio_ptransform = _InstanceDictInputToTFXIOInput( - input_metadata.schema, Context.get_desired_batch_size()) - input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config() - input_values |= 'InstanceDictToRecordBatch' >> to_tfxio_ptransform - else: - input_tensor_adapter_config = input_metadata - - # If exclude_outputs is set, update the output metadata. - if self._exclude_outputs is not None: - if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): - new_metadata = _remove_columns_from_metadata( - output_metadata.dataset_metadata, self._exclude_outputs) - new_deferred_metadata = ( - output_metadata.deferred_metadata - | 'RemoveColumns' - >> beam.Map(_remove_columns_from_metadata, self._exclude_outputs) + + def __init__(self, exclude_outputs=None, output_record_batches=False): + self._exclude_outputs = exclude_outputs + self._output_record_batches = output_record_batches + self._use_tf_compat_v1 = Context.get_use_tf_compat_v1() + if self._use_tf_compat_v1: + _warn_about_tf_compat_v1() + + def _extract_input_pvalues(self, dataset_and_transform_fn): + # This method returns all nested pvalues to inform beam of nested pvalues. + (data, input_metadata), (transform_fn, output_metadata) = ( + dataset_and_transform_fn + ) + pvalues = [data, transform_fn] + if isinstance(input_metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(input_metadata.deferred_metadata) + if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(output_metadata.deferred_metadata) + return dataset_and_transform_fn, pvalues + + def expand(self, dataset_and_transform_fn): + """Transforms the dataset using the transform_fn. 
+ + Args: + ---- + dataset_and_transform_fn: A tuple of dataset and preprocessing + function. + + Returns: + ------- + A dataset transformed according to the transform_fn. + """ + (input_values, input_metadata), (transform_fn, output_metadata) = ( + dataset_and_transform_fn + ) + if isinstance(input_metadata, dataset_metadata.DatasetMetadata): + if Context.get_passthrough_keys(): + raise ValueError( + f"passthrough_keys is set to {Context.get_passthrough_keys()} but it is not " + "supported with instance dicts + DatasetMetadata " + "input. Follow the guide to switch to the TFXIO " + "format." + ) + logging.warning( + "You are passing instance dicts and DatasetMetadata to TFT which " + "will not provide optimal performance. Consider following the TFT " + "guide to upgrade to the TFXIO format (Apache Arrow RecordBatch)." + ) + to_tfxio_ptransform = _InstanceDictInputToTFXIOInput( + input_metadata.schema, Context.get_desired_batch_size() + ) + input_tensor_adapter_config = to_tfxio_ptransform.tensor_adapter_config() + input_values |= "InstanceDictToRecordBatch" >> to_tfxio_ptransform + else: + input_tensor_adapter_config = input_metadata + + # If exclude_outputs is set, update the output metadata. + if self._exclude_outputs is not None: + if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): + new_metadata = _remove_columns_from_metadata( + output_metadata.dataset_metadata, self._exclude_outputs + ) + new_deferred_metadata = ( + output_metadata.deferred_metadata + | "RemoveColumns" + >> beam.Map(_remove_columns_from_metadata, self._exclude_outputs) + ) + output_metadata = beam_metadata_io.BeamDatasetMetadata( + new_metadata, new_deferred_metadata, output_metadata.asset_map + ) + else: + output_metadata = _remove_columns_from_metadata( + output_metadata, self._exclude_outputs + ) + + if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): + deferred_schema = ( + output_metadata.deferred_metadata + | "GetDeferredSchema" >> beam.Map(lambda m: m.schema) + ) + output_dataset_metadata = output_metadata.dataset_metadata + else: + deferred_schema = self.pipeline | "CreateDeferredSchema" >> beam.Create( + [output_metadata.schema] + ) + output_dataset_metadata = output_metadata + output_dataset_metadata._output_record_batches = self._output_record_batches # pylint: disable=protected-access + + # Increment input metrics. + _ = ( + input_values + | "InstrumentInputBytes[Transform]" + >> telemetry.TrackRecordBatchBytes( + beam_common.METRICS_NAMESPACE, "transform_input_bytes" + ) ) - output_metadata = beam_metadata_io.BeamDatasetMetadata( - new_metadata, new_deferred_metadata, output_metadata.asset_map) - else: - output_metadata = _remove_columns_from_metadata( - output_metadata, self._exclude_outputs) - - if isinstance(output_metadata, beam_metadata_io.BeamDatasetMetadata): - deferred_schema = ( - output_metadata.deferred_metadata - | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema)) - output_dataset_metadata = output_metadata.dataset_metadata - else: - deferred_schema = ( - self.pipeline - | 'CreateDeferredSchema' >> beam.Create([output_metadata.schema])) - output_dataset_metadata = output_metadata - output_dataset_metadata._output_record_batches = self._output_record_batches # pylint: disable=protected-access - - # Increment input metrics. 
- _ = ( - input_values - | 'InstrumentInputBytes[Transform]' >> telemetry.TrackRecordBatchBytes( - beam_common.METRICS_NAMESPACE, 'transform_input_bytes')) - - _ = ( - self.pipeline | 'CreateTransformInputTensorRepresentations' >> - beam.Create([input_tensor_adapter_config.tensor_representations]) - | 'InstrumentTransformInputTensors' >> - telemetry.TrackTensorRepresentations( - telemetry_util.AppendToNamespace(beam_common.METRICS_NAMESPACE, - ['transform_input_tensors']))) - - tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get( - type(self.pipeline.runner)) - output_batches = input_values | 'Transform' >> beam.ParDo( - _RunMetaGraphDoFn( - tf_config, - input_tensor_adapter_config=input_tensor_adapter_config, - use_tf_compat_v1=self._use_tf_compat_v1, - shared_graph_state_handle=shared.Shared(), - passthrough_keys=Context.get_passthrough_keys(), - exclude_outputs=self._exclude_outputs, - ), - saved_model_dir=beam.pvalue.AsSingleton(transform_fn), - ) - # Since we are using a deferred schema, obtain a pcollection containing - # the converter that will be created from it. - converter_pcol = deferred_schema | 'MakeTensorToArrowConverter' >> beam.Map( - impl_helper.make_tensor_to_arrow_converter - ) + _ = ( + self.pipeline + | "CreateTransformInputTensorRepresentations" + >> beam.Create([input_tensor_adapter_config.tensor_representations]) + | "InstrumentTransformInputTensors" + >> telemetry.TrackTensorRepresentations( + telemetry_util.AppendToNamespace( + beam_common.METRICS_NAMESPACE, ["transform_input_tensors"] + ) + ) + ) - # Increment output data metrics. - _ = ( - converter_pcol - | 'MapToTensorRepresentations' - >> beam.Map(lambda converter: converter.tensor_representations()) - | 'InstrumentTransformOutputTensors' - >> telemetry.TrackTensorRepresentations( - telemetry_util.AppendToNamespace( - beam_common.METRICS_NAMESPACE, ['transform_output_tensors'] + tf_config = _DEFAULT_TENSORFLOW_CONFIG_BY_BEAM_RUNNER_TYPE.get( + type(self.pipeline.runner) + ) + output_batches = input_values | "Transform" >> beam.ParDo( + _RunMetaGraphDoFn( + tf_config, + input_tensor_adapter_config=input_tensor_adapter_config, + use_tf_compat_v1=self._use_tf_compat_v1, + shared_graph_state_handle=shared.Shared(), + passthrough_keys=Context.get_passthrough_keys(), + exclude_outputs=self._exclude_outputs, + ), + saved_model_dir=beam.pvalue.AsSingleton(transform_fn), + ) + + # Since we are using a deferred schema, obtain a pcollection containing + # the converter that will be created from it. + converter_pcol = deferred_schema | "MakeTensorToArrowConverter" >> beam.Map( + impl_helper.make_tensor_to_arrow_converter + ) + + # Increment output data metrics. + _ = ( + converter_pcol + | "MapToTensorRepresentations" + >> beam.Map(lambda converter: converter.tensor_representations()) + | "InstrumentTransformOutputTensors" + >> telemetry.TrackTensorRepresentations( + telemetry_util.AppendToNamespace( + beam_common.METRICS_NAMESPACE, ["transform_output_tensors"] + ) ) ) - ) - output_data = output_batches | 'ConvertToRecordBatch' >> beam.FlatMap( - _convert_to_record_batch, - converter=beam.pvalue.AsSingleton(converter_pcol), - passthrough_keys=Context.get_passthrough_keys(), - input_metadata=input_metadata, - # TODO(b/254822532): Consider always doing the validation. 
- validate_varlen_sparse_values=not self._output_record_batches, - ) + output_data = output_batches | "ConvertToRecordBatch" >> beam.FlatMap( + _convert_to_record_batch, + converter=beam.pvalue.AsSingleton(converter_pcol), + passthrough_keys=Context.get_passthrough_keys(), + input_metadata=input_metadata, + # TODO(b/254822532): Consider always doing the validation. + validate_varlen_sparse_values=not self._output_record_batches, + ) - if not self._output_record_batches: - logging.warning( - 'You are outputting instance dicts from `TransformDataset` which ' - 'will not provide optimal performance. Consider setting ' - '`output_record_batches=True` to upgrade to the TFXIO format (Apache ' - 'Arrow RecordBatch). Encoding functionality in this module works ' - 'with both formats.' - ) - output_data |= 'ConvertAndUnbatchToInstanceDicts' >> beam.FlatMap( - _transformed_batch_to_instance_dicts, - schema=beam.pvalue.AsSingleton(deferred_schema), - ) + if not self._output_record_batches: + logging.warning( + "You are outputting instance dicts from `TransformDataset` which " + "will not provide optimal performance. Consider setting " + "`output_record_batches=True` to upgrade to the TFXIO format (Apache " + "Arrow RecordBatch). Encoding functionality in this module works " + "with both formats." + ) + output_data |= "ConvertAndUnbatchToInstanceDicts" >> beam.FlatMap( + _transformed_batch_to_instance_dicts, + schema=beam.pvalue.AsSingleton(deferred_schema), + ) - _clear_shared_state_after_barrier(self.pipeline, output_data) + _clear_shared_state_after_barrier(self.pipeline, output_data) - return (output_data, output_metadata) + return (output_data, output_metadata) class EncodeTransformedDataset(beam.PTransform): - """Encodes transformed data into serialized tf.Examples. - - Should operate on the output of `TransformDataset`, this can operate on either - record batch or instance dict data. - The expected input is a (transformed_data, transformed_metadata) tuple. - - Example use: - - >>> def preprocessing_fn(inputs): - ... return {'x_scaled': tft.scale_to_z_score(inputs['x'], name='x')} - >>> raw_data = [dict(x=1), dict(x=2), dict(x=3)] - >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> output_path = os.path.join(tempfile.mkdtemp(), 'result') - >>> with beam.Pipeline() as p: - ... with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... data_pcoll = p | beam.Create(raw_data) - ... transformed_dataset, transform_fn = ( - ... (data_pcoll, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - ... _ = ( - ... transformed_dataset - ... | tft_beam.EncodeTransformedDataset() - ... | beam.io.WriteToTFRecord(output_path, shard_name_template='')) - >>> result_feature_spec ={'x_scaled': tf.io.FixedLenFeature([], tf.float32)} - >>> list(tf.data.TFRecordDataset([output_path]) - ... .map(lambda x: tf.io.parse_example(x, result_feature_spec)) - ... .as_numpy_iterator()) - [{'x_scaled': -1.2247448}, {'x_scaled': 0.0}, {'x_scaled': 1.2247448}] - """ - - def _extract_input_pvalues(self, transformed_data_and_metadata): - # This method lets beam know that metadata is not a pvalue. 
- return transformed_data_and_metadata, [transformed_data_and_metadata[0]] - - def expand(self, transformed_data_and_metadata): - - transformed_data, transformed_metadata = transformed_data_and_metadata - - deferred_schema = ( - transformed_metadata.deferred_metadata - | 'GetDeferredSchema' >> beam.Map(lambda m: m.schema)) - - if transformed_metadata.dataset_metadata._output_record_batches: # pylint: disable=protected-access - transformed_data_coder_pcol = ( - deferred_schema | 'RecordBatchToExamplesEncoder' >> beam.Map( - example_coder.RecordBatchToExamplesEncoder)) - encode_ptransform = 'EncodeRecordBatches' >> beam.FlatMap( - # Dropping passthrough features. - lambda elem, coder: coder.encode(elem[0]), - coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol)) - else: - transformed_data_coder_pcol = ( - deferred_schema - | 'ExampleProtoCoder' >> beam.Map( - example_proto_coder.ExampleProtoCoder)) - encode_ptransform = 'EncodeInstances' >> beam.Map( - lambda data, data_coder: data_coder.encode(data), - data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol)) - - return transformed_data | encode_ptransform + """Encodes transformed data into serialized tf.Examples. + + Should operate on the output of `TransformDataset`, this can operate on either + record batch or instance dict data. + The expected input is a (transformed_data, transformed_metadata) tuple. + + Example use: + + >>> def preprocessing_fn(inputs): + ... return {'x_scaled': tft.scale_to_z_score(inputs['x'], name='x')} + >>> raw_data = [dict(x=1), dict(x=2), dict(x=3)] + >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> output_path = os.path.join(tempfile.mkdtemp(), 'result') + >>> with beam.Pipeline() as p: + ... with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... data_pcoll = p | beam.Create(raw_data) + ... transformed_dataset, transform_fn = ( + ... (data_pcoll, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + ... _ = ( + ... transformed_dataset + ... | tft_beam.EncodeTransformedDataset() + ... | beam.io.WriteToTFRecord(output_path, shard_name_template='')) + >>> result_feature_spec ={'x_scaled': tf.io.FixedLenFeature([], tf.float32)} + >>> list(tf.data.TFRecordDataset([output_path]) + ... .map(lambda x: tf.io.parse_example(x, result_feature_spec)) + ... .as_numpy_iterator()) + [{'x_scaled': -1.2247448}, {'x_scaled': 0.0}, {'x_scaled': 1.2247448}] + """ + + def _extract_input_pvalues(self, transformed_data_and_metadata): + # This method lets beam know that metadata is not a pvalue. + return transformed_data_and_metadata, [transformed_data_and_metadata[0]] + + def expand(self, transformed_data_and_metadata): + transformed_data, transformed_metadata = transformed_data_and_metadata + + deferred_schema = ( + transformed_metadata.deferred_metadata + | "GetDeferredSchema" >> beam.Map(lambda m: m.schema) + ) + + if transformed_metadata.dataset_metadata._output_record_batches: # pylint: disable=protected-access + transformed_data_coder_pcol = ( + deferred_schema + | "RecordBatchToExamplesEncoder" + >> beam.Map(example_coder.RecordBatchToExamplesEncoder) + ) + encode_ptransform = "EncodeRecordBatches" >> beam.FlatMap( + # Dropping passthrough features. 
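+                # Each element is a (RecordBatch, passthrough features) pair
+                # produced by `TransformDataset`; only the RecordBatch is encoded.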
+ lambda elem, coder: coder.encode(elem[0]), + coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol), + ) + else: + transformed_data_coder_pcol = ( + deferred_schema + | "ExampleProtoCoder" >> beam.Map(example_proto_coder.ExampleProtoCoder) + ) + encode_ptransform = "EncodeInstances" >> beam.Map( + lambda data, data_coder: data_coder.encode(data), + data_coder=beam.pvalue.AsSingleton(transformed_data_coder_pcol), + ) + + return transformed_data | encode_ptransform diff --git a/tensorflow_transform/beam/impl_output_record_batches_test.py b/tensorflow_transform/beam/impl_output_record_batches_test.py index a4c01bd..beaeaed 100644 --- a/tensorflow_transform/beam/impl_output_record_batches_test.py +++ b/tensorflow_transform/beam/impl_output_record_batches_test.py @@ -18,186 +18,188 @@ import numpy as np import pyarrow as pa import tensorflow as tf +from tfx_bsl.tfxio import tensor_adapter + from tensorflow_transform import impl_helper -from tensorflow_transform.beam import impl -from tensorflow_transform.beam import impl_test -from tensorflow_transform.beam import tft_unit +from tensorflow_transform.beam import impl, impl_test, tft_unit from tensorflow_transform.tf_metadata import schema_utils -from tfx_bsl.tfxio import tensor_adapter _LARGE_BATCH_SIZE = 1 << 10 class BeamImplOutputRecordBatchesTest(impl_test.BeamImplTest): + def _OutputRecordBatches(self): + return True - def _OutputRecordBatches(self): - return True - - def _MakeTransformOutputAssertFn(self, expected, sort=False): - # Merge expected instance dicts. - merged_expected = collections.defaultdict(list) - for instance_dict in expected: - for key, value in instance_dict.items(): - # Scalars must be wrapped in a list. - if (hasattr(value, '__iter__') and not isinstance(value, - (str, bytes)) or - value is None): - maybe_wrapped_value = value - else: - maybe_wrapped_value = [value] - merged_expected[key].append(maybe_wrapped_value) + def _MakeTransformOutputAssertFn(self, expected, sort=False): + # Merge expected instance dicts. + merged_expected = collections.defaultdict(list) + for instance_dict in expected: + for key, value in instance_dict.items(): + # Scalars must be wrapped in a list. + if ( + hasattr(value, "__iter__") + and not isinstance(value, (str, bytes)) + or value is None + ): + maybe_wrapped_value = value + else: + maybe_wrapped_value = [value] + merged_expected[key].append(maybe_wrapped_value) - def _assert_fn(actual): - # Merge output RecordBatches. - merged_actual = collections.defaultdict(list) - for record_batch, _ in actual: - for key, value in record_batch.to_pydict().items(): - merged_actual[key].extend(value) - if sort: - for value in merged_actual.values(): - value.sort() - for value in merged_expected.values(): - value.sort() - self.assertDictEqual(merged_expected, merged_actual) + def _assert_fn(actual): + # Merge output RecordBatches. 
+ merged_actual = collections.defaultdict(list) + for record_batch, _ in actual: + for key, value in record_batch.to_pydict().items(): + merged_actual[key].extend(value) + if sort: + for value in merged_actual.values(): + value.sort() + for value in merged_expected.values(): + value.sort() + self.assertDictEqual(merged_expected, merged_actual) - return _assert_fn + return _assert_fn - def testConvertToRecordBatchPassthroughData(self): - passthrough_key1 = '__passthrough_with_batch_length__' - passthrough_key2 = '__passthrough_with_one_value__' - passthrough_key3 = '__passthrough_with_one_distinct_value_none__' - passthrough_key4 = '__passthrough_with_one_distinct_value_not_none__' - batch_dict = { - 'a': - np.array([100, 1, 10], np.int64), - passthrough_key1: - pa.array([[1], None, [0]], pa.large_list(pa.int64())), - passthrough_key2: - pa.array([None], pa.large_list(pa.float32())), - passthrough_key3: - pa.array([None, None], pa.large_list(pa.large_binary())), - passthrough_key4: - pa.array([[10], [10]], pa.large_list(pa.int64())) - } - schema = schema_utils.schema_from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.int64)}) - converter = impl_helper.make_tensor_to_arrow_converter(schema) - passthrough_keys = { - passthrough_key1, passthrough_key2, passthrough_key3, passthrough_key4 - } - arrow_schema = pa.schema([ - ('a', pa.large_list(pa.int64())), - (passthrough_key1, batch_dict[passthrough_key1].type), - (passthrough_key2, batch_dict[passthrough_key2].type), - (passthrough_key3, batch_dict[passthrough_key3].type), - (passthrough_key4, batch_dict[passthrough_key4].type) - ]) - # Note that we only need `input_metadata.arrow_schema`. - input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {}) - converted = list(impl._convert_to_record_batch( - batch_dict, converter, passthrough_keys, input_metadata)) - self.assertLen(converted, 1) - record_batch, unary_features = converted[0] - expected_record_batch = { - 'a': [[100], [1], [10]], - passthrough_key1: [[1], None, [0]] - } - self.assertDictEqual(expected_record_batch, record_batch.to_pydict()) - expected_unary_features = { - passthrough_key2: [None], - passthrough_key3: [None], - passthrough_key4: [[10]] - } - unary_features = {k: v.to_pylist() for k, v in unary_features.items()} - self.assertDictEqual(expected_unary_features, unary_features) + def testConvertToRecordBatchPassthroughData(self): + passthrough_key1 = "__passthrough_with_batch_length__" + passthrough_key2 = "__passthrough_with_one_value__" + passthrough_key3 = "__passthrough_with_one_distinct_value_none__" + passthrough_key4 = "__passthrough_with_one_distinct_value_not_none__" + batch_dict = { + "a": np.array([100, 1, 10], np.int64), + passthrough_key1: pa.array([[1], None, [0]], pa.large_list(pa.int64())), + passthrough_key2: pa.array([None], pa.large_list(pa.float32())), + passthrough_key3: pa.array([None, None], pa.large_list(pa.large_binary())), + passthrough_key4: pa.array([[10], [10]], pa.large_list(pa.int64())), + } + schema = schema_utils.schema_from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.int64)} + ) + converter = impl_helper.make_tensor_to_arrow_converter(schema) + passthrough_keys = { + passthrough_key1, + passthrough_key2, + passthrough_key3, + passthrough_key4, + } + arrow_schema = pa.schema( + [ + ("a", pa.large_list(pa.int64())), + (passthrough_key1, batch_dict[passthrough_key1].type), + (passthrough_key2, batch_dict[passthrough_key2].type), + (passthrough_key3, batch_dict[passthrough_key3].type), + (passthrough_key4, 
batch_dict[passthrough_key4].type), + ] + ) + # Note that we only need `input_metadata.arrow_schema`. + input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {}) + converted = list( + impl._convert_to_record_batch( + batch_dict, converter, passthrough_keys, input_metadata + ) + ) + self.assertLen(converted, 1) + record_batch, unary_features = converted[0] + expected_record_batch = { + "a": [[100], [1], [10]], + passthrough_key1: [[1], None, [0]], + } + self.assertDictEqual(expected_record_batch, record_batch.to_pydict()) + expected_unary_features = { + passthrough_key2: [None], + passthrough_key3: [None], + passthrough_key4: [[10]], + } + unary_features = {k: v.to_pylist() for k, v in unary_features.items()} + self.assertDictEqual(expected_unary_features, unary_features) - # Test pass-through data when input and output batch sizes are different and - # the number of its unique values is >1. - passthrough_key5 = '__passthrough_with_wrong_batch_size__' - passthrough_keys.add(passthrough_key5) - batch_dict[passthrough_key5] = pa.array([[1], [2]], - pa.large_list(pa.int64())) - input_metadata.arrow_schema = input_metadata.arrow_schema.append( - pa.field(passthrough_key5, batch_dict[passthrough_key5].type)) - with self.assertRaisesRegex( - ValueError, - 'Cannot pass-through data when ' - 'input and output batch sizes are different', - ): - _ = list( - impl._convert_to_record_batch( - batch_dict, converter, passthrough_keys, input_metadata - ) - ) + # Test pass-through data when input and output batch sizes are different and + # the number of its unique values is >1. + passthrough_key5 = "__passthrough_with_wrong_batch_size__" + passthrough_keys.add(passthrough_key5) + batch_dict[passthrough_key5] = pa.array([[1], [2]], pa.large_list(pa.int64())) + input_metadata.arrow_schema = input_metadata.arrow_schema.append( + pa.field(passthrough_key5, batch_dict[passthrough_key5].type) + ) + with self.assertRaisesRegex( + ValueError, + "Cannot pass-through data when " + "input and output batch sizes are different", + ): + _ = list( + impl._convert_to_record_batch( + batch_dict, converter, passthrough_keys, input_metadata + ) + ) - @tft_unit.named_parameters( - dict( - testcase_name='NoPassthroughData', - passthrough_data={}, - expected_unary_features={}, - ), - dict( - testcase_name='WithPassthroughData', - passthrough_data={ - '__passthrough_with_batch_length__': pa.array( - [[1]] * _LARGE_BATCH_SIZE, pa.large_list(pa.int64()) - ), - '__passthrough_with_one_value__': pa.array( - [None], pa.large_list(pa.float32()) - ), - }, - expected_unary_features={ - '__passthrough_with_one_value__': pa.array( - [None], pa.large_list(pa.float32()) - ), - }, - ), - ) - def testConvertToLargeRecordBatch( - self, passthrough_data, expected_unary_features - ): - """Tests slicing of large transformed batches during conversion.""" - # Any Beam test pipeline handling elements this large crashes the program - # with OOM (even with 28GB memory available), so we test the conversion - # pretty narrowly. - - # 2^31 elements in total. 
- num_values = 1 << 21 - batch_dict = { - 'a': np.zeros([_LARGE_BATCH_SIZE, num_values], np.float32), - **passthrough_data, - } - schema = schema_utils.schema_from_feature_spec( - {'a': tf.io.FixedLenFeature([num_values], tf.float32)} - ) - converter = impl_helper.make_tensor_to_arrow_converter(schema) - arrow_schema = pa.schema( - [ - ('a', pa.large_list(pa.float32())), - ] - + [(key, value.type) for key, value in passthrough_data.items()] + @tft_unit.named_parameters( + dict( + testcase_name="NoPassthroughData", + passthrough_data={}, + expected_unary_features={}, + ), + dict( + testcase_name="WithPassthroughData", + passthrough_data={ + "__passthrough_with_batch_length__": pa.array( + [[1]] * _LARGE_BATCH_SIZE, pa.large_list(pa.int64()) + ), + "__passthrough_with_one_value__": pa.array( + [None], pa.large_list(pa.float32()) + ), + }, + expected_unary_features={ + "__passthrough_with_one_value__": pa.array( + [None], pa.large_list(pa.float32()) + ), + }, + ), ) - input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {}) - actual_num_rows = 0 - actual_num_batches = 0 - # Features are either going to be in the `record_batch` or in - # `unary_features`. - record_batch_features = set(batch_dict.keys()) - set( - expected_unary_features.keys() - ) - for record_batch, unary_features in impl._convert_to_record_batch( - batch_dict, converter, set(passthrough_data.keys()), input_metadata - ): - self.assertEqual(set(record_batch.schema.names), record_batch_features) - self.assertEqual(unary_features, expected_unary_features) - self.assertLessEqual( - record_batch.nbytes, impl._MAX_TRANSFORMED_BATCH_BYTES_SIZE - ) - actual_num_rows += record_batch.num_rows - actual_num_batches += 1 - self.assertEqual(actual_num_rows, _LARGE_BATCH_SIZE) - self.assertGreater(actual_num_batches, 1) + def testConvertToLargeRecordBatch(self, passthrough_data, expected_unary_features): + """Tests slicing of large transformed batches during conversion.""" + # Any Beam test pipeline handling elements this large crashes the program + # with OOM (even with 28GB memory available), so we test the conversion + # pretty narrowly. + + # 2^31 elements in total. + num_values = 1 << 21 + batch_dict = { + "a": np.zeros([_LARGE_BATCH_SIZE, num_values], np.float32), + **passthrough_data, + } + schema = schema_utils.schema_from_feature_spec( + {"a": tf.io.FixedLenFeature([num_values], tf.float32)} + ) + converter = impl_helper.make_tensor_to_arrow_converter(schema) + arrow_schema = pa.schema( + [ + ("a", pa.large_list(pa.float32())), + ] + + [(key, value.type) for key, value in passthrough_data.items()] + ) + input_metadata = tensor_adapter.TensorAdapterConfig(arrow_schema, {}) + actual_num_rows = 0 + actual_num_batches = 0 + # Features are either going to be in the `record_batch` or in + # `unary_features`. 
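+        # Unary features are pass-through columns that are returned alongside the
+        # RecordBatch (as single-element arrays) instead of as batch columns.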
+ record_batch_features = set(batch_dict.keys()) - set( + expected_unary_features.keys() + ) + for record_batch, unary_features in impl._convert_to_record_batch( + batch_dict, converter, set(passthrough_data.keys()), input_metadata + ): + self.assertEqual(set(record_batch.schema.names), record_batch_features) + self.assertEqual(unary_features, expected_unary_features) + self.assertLessEqual( + record_batch.nbytes, impl._MAX_TRANSFORMED_BATCH_BYTES_SIZE + ) + actual_num_rows += record_batch.num_rows + actual_num_batches += 1 + self.assertEqual(actual_num_rows, _LARGE_BATCH_SIZE) + self.assertGreater(actual_num_batches, 1) -if __name__ == '__main__': - tft_unit.main() +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/impl_test.py b/tensorflow_transform/beam/impl_test.py index dde3034..a8c26e9 100644 --- a/tensorflow_transform/beam/impl_test.py +++ b/tensorflow_transform/beam/impl_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2017 Google Inc. All Rights Reserved. # @@ -17,4368 +16,4835 @@ import itertools import math import os +import unittest from typing import Optional, Tuple -import apache_beam as beam -from apache_beam.testing import util as beam_test_util -import numpy as np -import pyarrow as pa -import tensorflow as tf -import tensorflow_transform as tft -from tensorflow_transform import analyzers -from tensorflow_transform import common -from tensorflow_transform import pretrained_models -from tensorflow_transform import schema_inference -import tensorflow_transform.beam as tft_beam -from tensorflow_transform.beam.tft_beam_io import transform_fn_io -from tensorflow_transform.beam import tft_unit -from tensorflow_transform.keras_lib import tf_keras -from tfx_bsl.tfxio import tensor_adapter +import apache_beam as beam +import numpy as np +import pyarrow as pa +import tensorflow as tf +from apache_beam.testing import util as beam_test_util +from google.protobuf import text_format +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_adapter + +import tensorflow_transform as tft +import tensorflow_transform.beam as tft_beam +from tensorflow_transform import analyzers, common, pretrained_models, schema_inference +from tensorflow_transform.beam import tft_unit +from tensorflow_transform.beam.tft_beam_io import transform_fn_io +from tensorflow_transform.keras_lib import tf_keras + +if common.IS_ANNOTATIONS_PB_AVAILABLE: + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) + + +_SCALE_TO_Z_SCORE_TEST_CASES = [ + dict( + testcase_name="int16", + input_data=np.array([[1], [1], [2], [2]], np.int16), + output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="int32", + input_data=np.array([[1], [1], [2], [2]], np.int32), + output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="int64", + input_data=np.array([[1], [1], [2], [2]], np.int64), + output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="float32", + input_data=np.array([[1], [1], [2], [2]], np.float32), + output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="float64", + input_data=np.array([[1], [1], [2], [2]], np.float64), + output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float64), + elementwise=False, + ), + dict( + testcase_name="vector", + input_data=np.array([[1, 2], 
[3, 4]], np.float32), + output_data=np.array([[-3, -1], [1, 3]] / np.sqrt(5.0), np.float32), + elementwise=False, + ), + dict( + testcase_name="vector_elementwise", + input_data=np.array([[1, 2], [3, 4]], np.float32), + output_data=np.array([[-1.0, -1.0], [1.0, 1.0]], np.float32), + elementwise=True, + ), + dict( + testcase_name="zero_variance", + input_data=np.array([[3], [3], [3], [3]], np.float32), + output_data=np.array([[0], [0], [0], [0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="zero_variance_elementwise", + input_data=np.array([[3, 4], [3, 4]], np.float32), + output_data=np.array([[0, 0], [0, 0]], np.float32), + elementwise=True, + ), +] + +_SCALE_TO_Z_SCORE_NAN_TEST_CASES = [ + dict( + testcase_name="with_nans", + input_data=np.array([[1], [np.nan], [np.nan], [2]], np.float32), + output_data=np.array([[-1.0], [np.nan], [np.nan], [1.0]], np.float32), + elementwise=False, + ), + dict( + testcase_name="with_nans_elementwise", + input_data=np.array([[1, np.nan], [np.nan, 2]], np.float32), + output_data=np.array([[0, np.nan], [np.nan, 0]], np.float32), + elementwise=True, + ), +] + + +def _sigmoid(x): + return 1 / (1 + np.exp(-x)) + + +def sum_output_dtype(input_dtype): + """Returns the output dtype for tft.sum.""" + return input_dtype if input_dtype.is_floating else tf.int64 + + +def _mean_output_dtype(input_dtype): + """Returns the output dtype for tft.mean (and similar functions).""" + return tf.float64 if input_dtype == tf.float64 else tf.float32 + + +class BeamImplTest(tft_unit.TransformTestCase): + def setUp(self): + super().setUp() + tf.compat.v1.logging.info("Starting test case: %s", self._testMethodName) + self._context = tft_beam.Context(use_deep_copy_optimization=True) + self._context.__enter__() + + def tearDown(self): + super().tearDown() + self._context.__exit__() + + def _OutputRecordBatches(self): + return False + + def _SkipIfOutputRecordBatches(self): + if self._OutputRecordBatches(): + raise unittest.SkipTest( + "Test is disabled when TFT outputs `pa.RecordBatch`es to avoid " + "duplicated testing: it does not exercise `TransformDataset` or " + "`AnalyzeAndTransformDataset`." + ) + + # Overrides that automatically pass the proper value for + # `output_record_batches`. 
+ def assertAnalyzeAndTransformResults(self, *args, **kwargs): + kwargs["output_record_batches"] = self._OutputRecordBatches() + return super().assertAnalyzeAndTransformResults(*args, **kwargs) + + def assertAnalyzerOutputs(self, *args, **kwargs): + kwargs["output_record_batches"] = self._OutputRecordBatches() + return super().assertAnalyzerOutputs(*args, **kwargs) + + def _MakeTransformOutputAssertFn(self, expected, sort=False): + def _assert_fn(actual): + if sort: + dict_key_fn = lambda d: sorted(d.items()) + expected_sorted = sorted(expected, key=dict_key_fn) + actual_sorted = sorted(actual, key=dict_key_fn) + self.assertCountEqual(expected_sorted, actual_sorted) + else: + self.assertCountEqual(expected, actual) + + return _assert_fn + + def testApplySavedModelSingleInput(self): + def save_model_with_single_input(instance, export_dir): + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with instance.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput1" + ) + initializer = tf.compat.v1.constant_initializer([1, 2, 3]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [3], dtype=tf.int64) + output1 = tf.add(v1, input1, name="myadd1") + inputs = {"single_input": input1} + outputs = {"single_output": output1} + signature_def_map = { + "serving_default": tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + } + sess.run(tf.compat.v1.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, + [tf.saved_model.SERVING], + signature_def_map=signature_def_map, + ) + builder.save(False) + + export_dir = os.path.join(self.get_temp_dir(), "saved_model_single") + + def preprocessing_fn(inputs): + x = inputs["x"] + output_col = pretrained_models.apply_saved_model( + export_dir, x, tags=[tf.saved_model.SERVING] + ) + return {"out": output_col} + + save_model_with_single_input(self, export_dir) + input_data = [ + {"x": [1, 2, 3]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([3], tf.int64), + } + ) + # [1, 2, 3] + [1, 2, 3] = [2, 4, 6] + expected_data = [{"out": [2, 4, 6]}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"out": tf.io.FixedLenFeature([3], tf.int64)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testApplySavedModelWithHashTable(self): + def save_model_with_hash_table(instance, export_dir): + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with instance.test_session(graph=graph) as sess: + key = tf.constant("test_key", shape=[1]) + value = tf.constant("test_value", shape=[1]) + table = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(key, value), "__MISSING__" + ) + + input1 = tf.compat.v1.placeholder( + dtype=tf.string, shape=[1], name="myinput" + ) + output1 = tf.reshape(table.lookup(input1), shape=[1]) + inputs = {"input": input1} + outputs = {"output": output1} + + signature_def_map = { + "serving_default": tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + } + + sess.run(tf.compat.v1.tables_initializer()) + builder.add_meta_graph_and_variables( + sess, + [tf.saved_model.SERVING], + signature_def_map=signature_def_map, 
+ ) + builder.save(False) + + export_dir = os.path.join(self.get_temp_dir(), "saved_model_hash_table") + + def preprocessing_fn(inputs): + x = inputs["x"] + output_col = pretrained_models.apply_saved_model( + export_dir, x, tags=[tf.saved_model.SERVING] + ) + return {"out": output_col} + + save_model_with_hash_table(self, export_dir) + input_data = [{"x": ["test_key"]}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([1], tf.string), + } + ) + expected_data = [{"out": b"test_value"}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"out": tf.io.FixedLenFeature([], tf.string)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testApplySavedModelMultiInputs(self): + def save_model_with_multi_inputs(instance, export_dir): + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with instance.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput1" + ) + input2 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput2" + ) + input3 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput3" + ) + initializer = tf.compat.v1.constant_initializer([1, 2, 3]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [3], dtype=tf.int64) + o1 = tf.add(v1, input1, name="myadd1") + o2 = tf.subtract(o1, input2, name="mysubtract1") + output1 = tf.add(o2, input3, name="myadd2") + inputs = {"name1": input1, "name2": input2, "name3": input3} + outputs = {"single_output": output1} + signature_def_map = { + "serving_default": tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + } + sess.run(tf.compat.v1.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, + [tf.saved_model.SERVING], + signature_def_map=signature_def_map, + ) + builder.save(False) + + export_dir = os.path.join(self.get_temp_dir(), "saved_model_multi") + + def preprocessing_fn(inputs): + x = inputs["x"] + y = inputs["y"] + z = inputs["z"] + sum_column = pretrained_models.apply_saved_model( + export_dir, + {"name1": x, "name3": z, "name2": y}, + tags=[tf.saved_model.SERVING], + ) + return {"sum": sum_column} + + save_model_with_multi_inputs(self, export_dir) + input_data = [ + {"x": [1, 2, 3], "y": [2, 3, 4], "z": [1, 1, 1]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([3], tf.int64), + "y": tf.io.FixedLenFeature([3], tf.int64), + "z": tf.io.FixedLenFeature([3], tf.int64), + } + ) + # [1, 2, 3] + [1, 2, 3] - [2, 3, 4] + [1, 1, 1] = [1, 2, 3] + expected_data = [{"sum": [1, 2, 3]}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"sum": tf.io.FixedLenFeature([3], tf.int64)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testApplyFunctionWithCheckpoint(self): + def tensor_fn(input1, input2): + initializer = tf.compat.v1.constant_initializer([1, 2, 3]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [3], dtype=tf.int64) + v2 = tf.compat.v1.get_variable("v2", [3], dtype=tf.int64) + o1 = tf.add(v1, v2, name="add1") + o2 = tf.subtract(o1, input1, name="sub1") + o3 = 
tf.subtract(o2, input2, name="sub2") + return o3 + + def save_checkpoint(instance, checkpoint_path): + with tf.compat.v1.Graph().as_default() as graph: + with instance.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput1" + ) + input2 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinput2" + ) + tensor_fn(input1, input2) + saver = tf.compat.v1.train.Saver() + sess.run(tf.compat.v1.global_variables_initializer()) + saver.save(sess, checkpoint_path) + + checkpoint_path = os.path.join(self.get_temp_dir(), "chk") + + def preprocessing_fn(inputs): + x = inputs["x"] + y = inputs["y"] + out_value = pretrained_models.apply_function_with_checkpoint( + tensor_fn, [x, y], checkpoint_path + ) + return {"out": out_value} + + save_checkpoint(self, checkpoint_path) + input_data = [ + {"x": [2, 2, 2], "y": [-1, -3, 1]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([3], tf.int64), + "y": tf.io.FixedLenFeature([3], tf.int64), + } + ) + # [1, 2, 3] + [1, 2, 3] - [2, 2, 2] - [-1, -3, 1] = [1, 5, 3] + expected_data = [{"out": [1, 5, 3]}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"out": tf.io.FixedLenFeature([3], tf.int64)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.named_parameters( + dict(testcase_name="NoDeepCopy", with_deep_copy=False), + dict(testcase_name="WithDeepCopy", with_deep_copy=True), + ) + def testMultipleLevelsOfAnalyzers(self, with_deep_copy): + # Test a preprocessing function similar to scale_to_0_1 except that it + # involves multiple interleavings of analyzers and transforms. + def preprocessing_fn(inputs): + scaled_to_0 = inputs["x"] - tft.min(inputs["x"]) + scaled_to_0_1 = scaled_to_0 / tft.max(scaled_to_0) + return {"x_scaled": scaled_to_0_1} + + input_data = [{"x": 4}, {"x": 1}, {"x": 5}, {"x": 2}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + expected_data = [ + {"x_scaled": 0.75}, + {"x_scaled": 0.0}, + {"x_scaled": 1.0}, + {"x_scaled": 0.25}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([], tf.float32)} + ) + with tft_beam.Context(use_deep_copy_optimization=with_deep_copy): + # NOTE: In order to correctly test deep_copy here, we can't pass test_data + # to assertAnalyzeAndTransformResults. + # Not passing test_data to assertAnalyzeAndTransformResults means that + # tft.AnalyzeAndTransform is called, exercising the right code path. + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testRawFeedDictInput(self): + # Test the ability to feed raw data into AnalyzeDataset and TransformDataset + # by using subclasses of these transforms which create batches of size 1. + def preprocessing_fn(inputs): + sequence_example = inputs["sequence_example"] + + # Ordinarily this would have shape (batch_size,) since 'sequence_example' + # was defined as a FixedLenFeature with shape (). But since we specified + # desired_batch_size, we can assume that the shape is (1,), and reshape + # to (). + sequence_example = tf.reshape(sequence_example, ()) + + # Parse the sequence example. 
+ feature_spec = { + "x": tf.io.FixedLenSequenceFeature( + shape=[], dtype=tf.string, default_value=None + ) + } + _, sequences = tf.io.parse_single_sequence_example( + sequence_example, sequence_features=feature_spec + ) + + # Create a batch based on the sequence "x". + return {"x": sequences["x"]} + + def text_sequence_example_to_binary(text_proto): + proto = text_format.Merge(text_proto, tf.train.SequenceExample()) + return proto.SerializeToString() + + sequence_examples = [ + """ + feature_lists: { + feature_list: { + key: "x" + value: { + feature: {bytes_list: {value: 'ab'}} + feature: {bytes_list: {value: ''}} + feature: {bytes_list: {value: 'c'}} + feature: {bytes_list: {value: 'd'}} + } + } + } + """, + """ + feature_lists: { + feature_list: { + key: "x" + value: { + feature: {bytes_list: {value: 'ef'}} + feature: {bytes_list: {value: 'g'}} + } + } + } + """, + ] + input_data = [ + {"sequence_example": text_sequence_example_to_binary(sequence_example)} + for sequence_example in sequence_examples + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"sequence_example": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [ + {"x": b"ab"}, + {"x": b""}, + {"x": b"c"}, + {"x": b"d"}, + {"x": b"ef"}, + {"x": b"g"}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + desired_batch_size=1, + ) + + def testTransformWithExcludedOutputs(self): + def preprocessing_fn(inputs): + return { + "x_scaled": tft.scale_to_0_1(inputs["x"]), + "y_scaled": tft.scale_to_0_1(inputs["y"]), + } + + # Run AnalyzeAndTransform on some input data and compare with expected + # output. + input_data = [{"x": 5, "y": 1}, {"x": 1, "y": 2}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + } + ) + with tft_beam.Context(temp_dir=self.get_temp_dir()): + transform_fn = (input_data, input_metadata) | tft_beam.AnalyzeDataset( + preprocessing_fn + ) + + # Take the transform function and use TransformDataset to apply it to + # some eval data, with missing 'y' column. + eval_data = [{"x": 6}] + eval_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + transformed_eval_data, transformed_eval_metadata = ( + (eval_data, eval_metadata), + transform_fn, + ) | tft_beam.TransformDataset( + exclude_outputs=["y_scaled"], + output_record_batches=self._OutputRecordBatches(), + ) + + if self._OutputRecordBatches(): + expected_transformed_eval_data = {"x_scaled": [[1.25]]} + self.assertLen(transformed_eval_data, 1) + # Contains RecordBatch and unary pass-through features dict. 
+ self.assertLen(transformed_eval_data[0], 2) + self.assertDictEqual( + transformed_eval_data[0][0].to_pydict(), expected_transformed_eval_data + ) + self.assertDictEqual(transformed_eval_data[0][1], {}) + else: + expected_transformed_eval_data = [{"x_scaled": 1.25}] + self.assertDataCloseOrEqual( + transformed_eval_data, expected_transformed_eval_data + ) + expected_transformed_eval_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([], tf.float32)} + ) + self.assertEqual( + transformed_eval_metadata.dataset_metadata, + expected_transformed_eval_metadata, + ) + + def testMapWithCond(self): + def preprocessing_fn(inputs): + return { + "a": tf.cond( + pred=tf.constant(True), + true_fn=lambda: inputs["a"], + false_fn=lambda: inputs["b"], + ) + } + + input_data = [ + {"a": 4, "b": 3}, + {"a": 1, "b": 2}, + {"a": 5, "b": 6}, + {"a": 2, "b": 3}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.float32), + "b": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_data = [{"a": 4}, {"a": 1}, {"a": 5}, {"a": 2}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.float32)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testPyFuncs(self): + def my_multiply(x, y): + return x * y + + def my_add(x, y): + return x + y + + def my_list_return(x, y): + return [x, y, 2 * x, 2 * y] + + def preprocessing_fn(inputs): + result = { + "a+b": tft.apply_pyfunc( + my_add, tf.float32, True, "add", inputs["a"], inputs["b"] + ), + "a+c": tft.apply_pyfunc( + my_add, tf.float32, True, "add", inputs["a"], inputs["c"] + ), + "ab": tft.apply_pyfunc( + my_multiply, tf.float32, False, "multiply", inputs["a"], inputs["b"] + ), + "sum_scaled": tft.scale_to_0_1( + tft.apply_pyfunc( + my_add, tf.float32, True, "add", inputs["a"], inputs["c"] + ) + ), + "list": tf.reduce_sum( + tft.apply_pyfunc( + my_list_return, + [tf.float32, tf.float32, tf.float32, tf.float32], + True, + "my_list_return", + inputs["a"], + inputs["b"], + ), + axis=0, + ), + } + for value in result.values(): + value.set_shape( + [ + 1, + ] + ) + return result + + input_data = [ + {"a": 4, "b": 3, "c": 2}, + {"a": 1, "b": 2, "c": 3}, + {"a": 5, "b": 6, "c": 7}, + {"a": 2, "b": 3, "c": 4}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.float32), + "b": tf.io.FixedLenFeature([], tf.float32), + "c": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_data = [ + {"ab": 12, "a+b": 7, "a+c": 6, "list": 21, "sum_scaled": 0.25}, + {"ab": 2, "a+b": 3, "a+c": 4, "list": 9, "sum_scaled": 0}, + {"ab": 30, "a+b": 11, "a+c": 12, "list": 33, "sum_scaled": 1}, + {"ab": 6, "a+b": 5, "a+c": 6, "list": 15, "sum_scaled": 0.25}, + ] + # When calling tf.py_func, the output shape is set to unknown. 
+ expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "ab": tf.io.FixedLenFeature([], tf.float32), + "a+b": tf.io.FixedLenFeature([], tf.float32), + "a+c": tf.io.FixedLenFeature([], tf.float32), + "list": tf.io.FixedLenFeature([], tf.float32), + "sum_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + force_tf_compat_v1=True, + ) + + def testAssertsNoReturnPyFunc(self): + # Asserts that apply_pyfunc raises an exception if the passed function does + # not return anything. + self._SkipIfOutputRecordBatches() + + def bad_func(): + return None + + with self.assertRaises(ValueError): + tft.apply_pyfunc(bad_func, [], False, "bad_func") + + def testWithMoreThanDesiredBatchSize(self): + def preprocessing_fn(inputs): + return { + "ab": tf.multiply(inputs["a"], inputs["b"]), + "i": tft.compute_and_apply_vocabulary(inputs["c"]), + } + + batch_size = 100 + num_instances = batch_size + 1 + # pylint: disable=g-complex-comprehension + input_data = [ + { + "a": 2, + "b": i, + "c": "%.10i" % i, # Front-padded to facilitate lexicographic sorting. + } + for i in range(num_instances) + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.float32), + "b": tf.io.FixedLenFeature([], tf.float32), + "c": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_data = [ + { + "ab": 2 * i, + "i": (len(input_data) - 1) - i, # Due to reverse lexicographic sorting. + } + for i in range(len(input_data)) + ] + # pylint: enable=g-complex-comprehension + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "ab": tf.io.FixedLenFeature([], tf.float32), + "i": tf.io.FixedLenFeature([], tf.int64), + }, + { + "i": schema_pb2.IntDomain( + min=-1, max=num_instances - 1, is_categorical=True + ) + }, + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + desired_batch_size=batch_size, + force_tf_compat_v1=True, + ) + + def testWithUnicode(self): + def preprocessing_fn(inputs): + return { + "a b": tf.compat.v1.strings.join( + [inputs["a"], inputs["b"]], separator=" " + ) + } + + input_data = [{"a": "Hello", "b": "world"}, {"a": "Hello", "b": "κόσμε"}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.string), + "b": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_data = [{"a b": b"Hello world"}, {"a b": "Hello κόσμε".encode()}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"a b": tf.io.FixedLenFeature([], tf.string)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testNpArrayInput(self): + def preprocessing_fn(inputs): + return { + "a b": tf.compat.v1.strings.join( + [inputs["a"], inputs["b"]], separator=" " + ) + } + + input_data = [ + { + "a": np.array("Hello", dtype=object), + "b": np.array("world", dtype=object), + }, + { + "a": np.array("Hello", dtype=object), + "b": np.array("κόσμε", dtype=object), + }, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.string), + "b": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_data = [ + {"a b": np.array(b"Hello world", dtype=object)}, + {"a b": np.array("Hello κόσμε".encode(), dtype=object)}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"a b": 
tf.io.FixedLenFeature([], tf.string)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.parameters((True,), (False,)) + def testScaleUnitInterval(self, elementwise): + def preprocessing_fn(inputs): + outputs = {} + stacked_input = tf.stack([inputs["x"], inputs["y"]], axis=1) + result = tft.scale_to_0_1(stacked_input, elementwise=elementwise) + outputs["x_scaled"], outputs["y_scaled"] = tf.unstack(result, axis=1) + return outputs + + input_data = [ + {"x": 4, "y": 5}, + {"x": 1, "y": 2}, + {"x": 5, "y": 6}, + {"x": 2, "y": 3}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + } + ) + if elementwise: + expected_data = [ + {"x_scaled": 0.75, "y_scaled": 0.75}, + {"x_scaled": 0.0, "y_scaled": 0.0}, + {"x_scaled": 1.0, "y_scaled": 1.0}, + {"x_scaled": 0.25, "y_scaled": 0.25}, + ] + else: + expected_data = [ + {"x_scaled": 0.6, "y_scaled": 0.8}, + {"x_scaled": 0.0, "y_scaled": 0.2}, + {"x_scaled": 0.8, "y_scaled": 1.0}, + {"x_scaled": 0.2, "y_scaled": 0.4}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([], tf.float32), + "y_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.parameters((True,), (False,)) + def testScaleUnitIntervalPerKey(self, elementwise): + def preprocessing_fn(inputs): + outputs = {} + stacked_input = tf.stack([inputs["x"], inputs["y"]], axis=1) + result = tft.scale_to_0_1_per_key(stacked_input, inputs["key"], elementwise) + outputs["x_scaled"], outputs["y_scaled"] = tf.unstack(result, axis=1) + return outputs + + input_data = [ + {"x": 4, "y": 5, "key": "a"}, + {"x": 1, "y": 2, "key": "a"}, + {"x": 5, "y": 6, "key": "a"}, + {"x": 2, "y": 3, "key": "a"}, + {"x": 25, "y": -25, "key": "b"}, + {"x": 5, "y": 0, "key": "b"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + if elementwise: + expected_data = [ + {"x_scaled": 0.75, "y_scaled": 0.75}, + {"x_scaled": 0.0, "y_scaled": 0.0}, + {"x_scaled": 1.0, "y_scaled": 1.0}, + {"x_scaled": 0.25, "y_scaled": 0.25}, + {"x_scaled": 1.0, "y_scaled": 0.0}, + {"x_scaled": 0.0, "y_scaled": 1.0}, + ] + else: + expected_data = [ + {"x_scaled": 0.6, "y_scaled": 0.8}, + {"x_scaled": 0.0, "y_scaled": 0.2}, + {"x_scaled": 0.8, "y_scaled": 1.0}, + {"x_scaled": 0.2, "y_scaled": 0.4}, + {"x_scaled": 1.0, "y_scaled": 0.0}, + {"x_scaled": 0.6, "y_scaled": 0.5}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([], tf.float32), + "y_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.parameters((True,), (False,)) + def testScaleMinMax(self, elementwise): + def preprocessing_fn(inputs): + outputs = {} + stacked_input = tf.stack([inputs["x"], inputs["y"]], axis=1) + result = tft.scale_by_min_max( + stacked_input, output_min=-1, output_max=1, elementwise=elementwise + ) + outputs["x_scaled"], outputs["y_scaled"] = tf.unstack(result, axis=1) + return outputs + + input_data 
= [ + {"x": 4, "y": 8}, + {"x": 1, "y": 5}, + {"x": 5, "y": 9}, + {"x": 2, "y": 6}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + } + ) + if elementwise: + expected_data = [ + {"x_scaled": 0.5, "y_scaled": 0.5}, + {"x_scaled": -1.0, "y_scaled": -1.0}, + {"x_scaled": 1.0, "y_scaled": 1.0}, + {"x_scaled": -0.5, "y_scaled": -0.5}, + ] + else: + expected_data = [ + {"x_scaled": -0.25, "y_scaled": 0.75}, + {"x_scaled": -1.0, "y_scaled": 0.0}, + {"x_scaled": 0.0, "y_scaled": 1.0}, + {"x_scaled": -0.75, "y_scaled": 0.25}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([], tf.float32), + "y_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_empty_filename", + elementwise=False, + key_vocabulary_filename="", + ), + dict( + testcase_name="_nonempty_filename", + elementwise=False, + key_vocabulary_filename="per_key", + ), + dict( + testcase_name="_none_filename", + elementwise=False, + key_vocabulary_filename=None, + ), + dict( + testcase_name="_elementwise_none_filename", + elementwise=True, + key_vocabulary_filename=None, + ), + ) + def testScaleMinMaxPerKey(self, elementwise, key_vocabulary_filename): + def preprocessing_fn(inputs): + outputs = {} + stacked_input = tf.stack([inputs["x"], inputs["y"]], axis=1) + result = tft.scale_by_min_max_per_key( + stacked_input, + inputs["key"], + output_min=-1, + output_max=1, + elementwise=elementwise, + key_vocabulary_filename=key_vocabulary_filename, + ) + outputs["x_scaled"], outputs["y_scaled"] = tf.unstack(result, axis=1) + return outputs + + input_data = [ + {"x": 4, "y": 8, "key": "a"}, + {"x": 1, "y": 5, "key": "a"}, + {"x": 5, "y": 9, "key": "a"}, + {"x": 2, "y": 6, "key": "a"}, + {"x": -2, "y": 0, "key": "b"}, + {"x": 0, "y": 2, "key": "b"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + if elementwise: + expected_data = [ + {"x_scaled": 0.5, "y_scaled": 0.5}, + {"x_scaled": -1.0, "y_scaled": -1.0}, + {"x_scaled": 1.0, "y_scaled": 1.0}, + {"x_scaled": -0.5, "y_scaled": -0.5}, + {"x_scaled": -1.0, "y_scaled": -1.0}, + {"x_scaled": 1.0, "y_scaled": 1.0}, + ] + else: + expected_data = [ + {"x_scaled": -0.25, "y_scaled": 0.75}, + {"x_scaled": -1.0, "y_scaled": 0.0}, + {"x_scaled": 0.0, "y_scaled": 1.0}, + {"x_scaled": -0.75, "y_scaled": 0.25}, + {"x_scaled": -1.0, "y_scaled": 0.0}, + {"x_scaled": 0.0, "y_scaled": 1.0}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([], tf.float32), + "y_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + if key_vocabulary_filename: + per_key_vocab_contents = { + key_vocabulary_filename: [(b"a", [-1.0, 9.0]), (b"b", [2.0, 2.0])] + } + else: + per_key_vocab_contents = None + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents=per_key_vocab_contents, + ) + + def testScalePerKeySparse(self): + def preprocessing_fn(inputs): + return { + "scaled_by_min_max": tft.scale_by_min_max_per_key( + inputs["x"], inputs["key"], 
output_min=-1, output_max=1 + ), + "scaled_to_0_1": tft.scale_to_0_1_per_key(inputs["x"], inputs["key"]), + "scaled_to_z_score": tft.scale_to_z_score_per_key( + inputs["x"], inputs["key"] + ), + } + + input_data = [ + {"val": [4, 8], "s": ["a", "a"]}, + {"val": [1, 5], "s": ["a", "a"]}, + {"val": [5, 9], "s": ["a", "a"]}, + {"val": [2, 6], "s": ["a", "a"]}, + {"val": [-2, 0], "s": ["b", "b"]}, + {"val": [0, 2], "s": ["b", "b"]}, + ] + indices = [([x % 2] * 2, [x % 3] * 2) for x in range(len(input_data))] + indices_x = [{"idx_x_0": a, "idx_x_1": b} for a, b in indices] + indices_key = [{"idx_key_0": a, "idx_key_1": b} for a, b in indices] + input_data = [ + {**a, **b, **c} for a, b, c in zip(input_data, indices_x, indices_key) + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature( + ["idx_x_0", "idx_x_1"], "val", tf.float32, (2, 3) + ), + "key": tf.io.SparseFeature( + ["idx_key_0", "idx_key_1"], "s", tf.string, (2, 3) + ), + } + ) + + output_names = ["scaled_by_min_max", "scaled_to_0_1", "scaled_to_z_score"] + expected_indices_prefix = [ + (("$sparse_indices_0", a), ("$sparse_indices_1", b)) for a, b in indices + ] + expected_indices = [] + for idx0, idx1 in expected_indices_prefix: + instance = {} + for n in output_names: + instance.update({n + idx0[0]: idx0[1]}) + instance.update({n + idx1[0]: idx1[1]}) + expected_indices.append(instance) + + expected_data = [ + { + "scaled_by_min_max$sparse_values": [-0.25, 0.75], + "scaled_to_0_1$sparse_values": np.array([3.0 / 8.0, 7.0 / 8]), + "scaled_to_z_score$sparse_values": np.array( + [-1.0 / math.sqrt(6.5), 3.0 / math.sqrt(6.5)] + ), + }, + { + "scaled_by_min_max$sparse_values": [-1.0, 0.0], + "scaled_to_0_1$sparse_values": np.array([0.0, 0.5]), + "scaled_to_z_score$sparse_values": np.array( + [-4.0 / math.sqrt(6.5), 0.0] + ), + }, + { + "scaled_by_min_max$sparse_values": [0.0, 1.0], + "scaled_to_0_1$sparse_values": np.array([0.5, 1.0]), + "scaled_to_z_score$sparse_values": np.array( + [0.0, 4.0 / math.sqrt(6.5)] + ), + }, + { + "scaled_by_min_max$sparse_values": [-0.75, 0.25], + "scaled_to_0_1$sparse_values": np.array([1.0 / 8.0, 5.0 / 8.0]), + "scaled_to_z_score$sparse_values": np.array( + [-3.0 / math.sqrt(6.5), 1.0 / math.sqrt(6.5)] + ), + }, + { + "scaled_by_min_max$sparse_values": np.array([-1.0, 0.0]), + "scaled_to_0_1$sparse_values": np.array([0.0, 0.5]), + "scaled_to_z_score$sparse_values": np.array([-2.0 / math.sqrt(2), 0.0]), + }, + { + "scaled_by_min_max$sparse_values": [0.0, 1.0], + "scaled_to_0_1$sparse_values": np.array([0.5, 1.0]), + "scaled_to_z_score$sparse_values": np.array([0.0, 2.0 / math.sqrt(2)]), + }, + ] + expected_data = [{**a, **b} for a, b in zip(expected_data, expected_indices)] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.named_parameters( + dict( + testcase_name="sparse_key", + input_data=[ + {"idx": [0, 1], "val": [-4, 4], "key_idx": [0, 1], "key": ["a", "a"]}, + {"idx": [0, 1], "val": [2, 1], "key_idx": [0, 1], "key": ["a", "b"]}, + {"idx": [0, 1], "val": [-1, 4], "key_idx": [0, 1], "key": ["b", "a"]}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature( + "idx", "val", tft_unit.canonical_numeric_dtype(tf.float32), 4 + ), + "key": tf.io.SparseFeature("key_idx", "key", tf.string, 4), + } + ), + expected_data=[ + {"x_scaled": [0.0, 1.0, 0, 0]}, + {"x_scaled": [0.75, 1.0, 0, 0]}, + {"x_scaled": [0.0, 1.0, 0, 0]}, + ], + ), + 
dict( + testcase_name="dense_key", + input_data=[ + {"idx": [0, 1], "val": [-4, 4], "key": "a"}, + {"idx": [0, 1], "val": [2, 1], "key": "a"}, + {"idx": [0, 1], "val": [-1, 4], "key": "b"}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature( + "idx", "val", tft_unit.canonical_numeric_dtype(tf.float32), 4 + ), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + expected_data=[ + {"x_scaled": [0.0, 1.0, 0, 0]}, + {"x_scaled": [0.75, 0.625, 0, 0]}, + {"x_scaled": [0.0, 1.0, 0, 0]}, + ], + ), + ) + def testScaleMinMaxSparsePerKey(self, input_data, input_metadata, expected_data): + def preprocessing_fn(inputs): + x_scaled = tf.sparse.to_dense( + tft.scale_to_0_1_per_key(inputs["x"], inputs["key"]) + ) + x_scaled.set_shape([None, 4]) + return {"x_scaled": x_scaled} + + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([4], tf.float32)} + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict( + testcase_name="dense_key", + input_data=[ + { + "x_val": [-4, 4], + "x_row_lengths": [0, 2], + "key": "a", + }, + { + "x_val": [0, 1], + "x_row_lengths": [1, 1], + "key": "a", + }, + { + "x_val": [-4, 1, 1], + "x_row_lengths": [3], + "key": "b", + }, + ], + make_key_spec=lambda: tf.io.FixedLenFeature([], tf.string), + expected_data=[ + { + "scaled_by_min_max$ragged_values": [-1.0, 1.0], + "scaled_by_min_max$row_lengths_1": [0, 2], + "scaled_to_0_1$ragged_values": [0.0, 1.0], + "scaled_to_0_1$row_lengths_1": [0, 2], + "scaled_to_z_score$ragged_values": [-1.4852968, 1.310556], + "scaled_to_z_score$row_lengths_1": [0, 2], + }, + { + "scaled_by_min_max$ragged_values": [0.0, 0.25], + "scaled_by_min_max$row_lengths_1": [1, 1], + "scaled_to_0_1$ragged_values": [0.5, 0.625], + "scaled_to_0_1$row_lengths_1": [1, 1], + "scaled_to_z_score$ragged_values": [-0.0873704, 0.26211122], + "scaled_to_z_score$row_lengths_1": [1, 1], + }, + { + "scaled_by_min_max$ragged_values": [-1.0, 1.0, 1.0], + "scaled_by_min_max$row_lengths_1": [3], + "scaled_to_0_1$ragged_values": [0.0, 1.0, 1.0], + "scaled_to_0_1$row_lengths_1": [3], + "scaled_to_z_score$ragged_values": [ + -1.4142135, + 0.7071068, + 0.7071068, + ], + "scaled_to_z_score$row_lengths_1": [3], + }, + ], + ), + dict( + testcase_name="ragged_key", + input_data=[ + { + "x_val": [-4, 4], + "x_row_lengths": [0, 2], + "key_val": ["a", "a"], + "key_row_lengths": [0, 2], + }, + { + "x_val": [0, 1], + "x_row_lengths": [1, 1], + "key_val": ["a", "b"], + "key_row_lengths": [1, 1], + }, + { + "x_val": [-4, 1, 1], + "x_row_lengths": [3], + "key_val": ["b", "a", "b"], + "key_row_lengths": [3], + }, + ], + make_key_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tf.string, + value_key="key_val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "key_row_lengths" + ) # pytype: disable=attribute-error + ], + ), + expected_data=[ + { + "scaled_by_min_max$ragged_values": [-1.0, 1.0], + "scaled_by_min_max$row_lengths_1": [0, 2], + "scaled_to_0_1$ragged_values": [0.0, 1.0], + "scaled_to_0_1$row_lengths_1": [0, 2], + "scaled_to_z_score$ragged_values": [-1.4852968, 1.310556], + "scaled_to_z_score$row_lengths_1": [0, 2], + }, + { + "scaled_by_min_max$ragged_values": [0.0, 1.0], + "scaled_by_min_max$row_lengths_1": [1, 1], + "scaled_to_0_1$ragged_values": [0.5, 1.0], + "scaled_to_0_1$row_lengths_1": [1, 1], + 
"scaled_to_z_score$ragged_values": [-0.0873704, 0.7071068], + "scaled_to_z_score$row_lengths_1": [1, 1], + }, + { + "scaled_by_min_max$ragged_values": [-1.0, 0.25, 1.0], + "scaled_by_min_max$row_lengths_1": [3], + "scaled_to_0_1$ragged_values": [0.0, 0.625, 1.0], + "scaled_to_0_1$row_lengths_1": [3], + "scaled_to_z_score$ragged_values": [ + -1.4142135, + 0.26211122, + 0.7071068, + ], + "scaled_to_z_score$row_lengths_1": [3], + }, + ], + ), + ], + [ + dict(testcase_name="int16", input_dtype=tf.int16), + dict(testcase_name="int32", input_dtype=tf.int32), + dict(testcase_name="int64", input_dtype=tf.int64), + dict(testcase_name="float32", input_dtype=tf.float32), + dict(testcase_name="float64", input_dtype=tf.float64), + ], + ) + ) + def testScalePerKeyRagged( + self, input_data, make_key_spec, expected_data, input_dtype + ): + make_x_spec = lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tft_unit.canonical_numeric_dtype(input_dtype), + value_key="x_val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "x_row_lengths" + ) # pytype: disable=attribute-error + ], + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tft_unit.make_feature_spec_wrapper(make_x_spec), + "key": tft_unit.make_feature_spec_wrapper(make_key_spec), + } + ) + + def preprocessing_fn(inputs): + scaled_to_z_score = tft.scale_to_z_score_per_key( + tf.cast(inputs["x"], input_dtype), inputs["key"] + ) + self.assertEqual(scaled_to_z_score.dtype, _mean_output_dtype(input_dtype)) + return { + "scaled_by_min_max": tft.scale_by_min_max_per_key( + tf.cast(inputs["x"], input_dtype), + inputs["key"], + output_min=-1, + output_max=1, + ), + "scaled_to_0_1": tft.scale_to_0_1_per_key( + tf.cast(inputs["x"], input_dtype), inputs["key"] + ), + "scaled_to_z_score": tf.cast(scaled_to_z_score, tf.float32), + } + + expected_specs = {} + for output_name in ("scaled_by_min_max", "scaled_to_0_1", "scaled_to_z_score"): + expected_specs[output_name] = tf.io.RaggedFeature( + tf.float32, + value_key=f"{output_name}$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error + f"{output_name}$row_lengths_1" + ) + ], + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec(expected_specs) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testScaleMinMaxConstant(self): + def preprocessing_fn(inputs): + return {"x_scaled": tft.scale_by_min_max(inputs["x"], 0, 10)} + + input_data = [{"x": 4}, {"x": 4}, {"x": 4}, {"x": 4}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + expected_data = [ + {"x_scaled": 9.8201379}, + {"x_scaled": 9.8201379}, + {"x_scaled": 9.8201379}, + {"x_scaled": 9.8201379}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([], tf.float32)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testScaleMinMaxConstantElementwise(self): + def preprocessing_fn(inputs): + outputs = {} + stacked_input = tf.stack([inputs["x"], inputs["y"]], axis=1) + result = tft.scale_by_min_max( + stacked_input, output_min=0, output_max=10, elementwise=True + ) + outputs["x_scaled"], outputs["y_scaled"] = tf.unstack(result, axis=1) + return outputs + + input_data = [ + {"x": 4, "y": 1}, + {"x": 4, "y": 1}, + {"x": 4, "y": 2}, + {"x": 4, "y": 2}, + ] + input_metadata = 
tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_data = [ + {"x_scaled": 9.8201379, "y_scaled": 0}, + {"x_scaled": 9.8201379, "y_scaled": 0}, + {"x_scaled": 9.8201379, "y_scaled": 10}, + {"x_scaled": 9.8201379, "y_scaled": 10}, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([], tf.float32), + "y_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testScaleMinMaxError(self): + def preprocessing_fn(inputs): + return {"x_scaled": tft.scale_by_min_max(inputs["x"], 2, 1)} + + input_data = [{"x": 1}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + expected_data = [{"x_scaled": float("nan")}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([], tf.float32)} + ) + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, "output_min must be less than output_max" + ): + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testScaleMinMaxWithEmptyInputs(self): + # x is repeated `multiple` times to test elementwise mapping. + multiple = 3 + + def preprocessing_fn(inputs): + return { + "x_scaled": tft.scale_by_min_max(inputs["x"]), + "x_scaled_elementwise": tft.scale_by_min_max( + tf.tile(inputs["x"], [1, multiple]), elementwise=True + ), + } + + input_data = [] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.float32)} + ) + test_data = [{"x": [100]}, {"x": [1]}, {"x": [12]}] + expected_data = [ + {"x_scaled": [v], "x_scaled_elementwise": [v] * multiple} + for v in [1.0, 0.7310585, 0.9999938] + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([1], tf.float32), + "x_scaled_elementwise": tf.io.FixedLenFeature([multiple], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + test_data=test_data, + ) + + @tft_unit.named_parameters( + *(_SCALE_TO_Z_SCORE_TEST_CASES + _SCALE_TO_Z_SCORE_NAN_TEST_CASES) + ) + def testScaleToZScore(self, input_data, output_data, elementwise): + def preprocessing_fn(inputs): + x = inputs["x"] + x_cast = tf.cast(x, tf.as_dtype(input_data.dtype)) + x_scaled = tft.scale_to_z_score(x_cast, elementwise=elementwise) + self.assertEqual(x_scaled.dtype, tf.as_dtype(output_data.dtype)) + return {"x_scaled": tf.cast(x_scaled, tf.float32)} + + input_data_dicts = [{"x": x} for x in input_data] + expected_data_dicts = [{"x_scaled": x_scaled} for x_scaled in output_data] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + input_data.shape[1:], + tft_unit.canonical_numeric_dtype(tf.as_dtype(input_data.dtype)), + ), + } + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature(output_data.shape[1:], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + expected_data_dicts, + expected_metadata, + ) + + @tft_unit.parameters( + *itertools.product( + [ + tf.int16, + tf.int32, + tf.int64, + tf.float32, + tf.float64, 
+ ], + (True, False), + ) + ) + def testScaleToZScoreSparse(self, input_dtype, elementwise): + def preprocessing_fn(inputs): + z_score = tf.sparse.to_dense( + tft.scale_to_z_score( + tf.cast(inputs["x"], input_dtype), elementwise=elementwise + ), + default_value=np.nan, + ) + z_score.set_shape([None, 4]) + self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) + return {"x_scaled": tf.cast(z_score, tf.float32)} + + input_data = [ + {"idx": [0, 1], "val": [-4, 10]}, + {"idx": [0, 1], "val": [2, 4]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature( + "idx", "val", tft_unit.canonical_numeric_dtype(input_dtype), 4 + ) + } + ) + if elementwise: + # Mean(x) = [-1, 7] + # Var(x) = [9, 9] + # StdDev(x) = [3, 3] + expected_data = [ + { + "x_scaled": [ + -1.0, + 1.0, + float("nan"), + float("nan"), + ] # [(-4 +1 ) / 3, (10 -7) / 3] + }, + { + "x_scaled": [ + 1.0, + -1.0, + float("nan"), + float("nan"), + ] # [(2 + 1) / 3, (4 - 7) / 3] + }, + ] + else: + # Mean = 3 + # Var = 25 + # Std Dev = 5 + expected_data = [ + { + "x_scaled": [ + -1.4, + 1.4, + float("nan"), + float("nan"), + ] # [(-4 - 3) / 5, (10 - 3) / 5] + }, + { + "x_scaled": [ + -0.2, + 0.2, + float("nan"), + float("nan"), + ] # [(2 - 3) / 5, (4 - 3) / 5] + }, + ] + if input_dtype.is_floating: + input_data.append({"idx": [0, 1], "val": [np.nan, np.nan]}) + expected_data.append( + {"x_scaled": [float("nan"), float("nan"), float("nan"), float("nan")]} + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([4], tf.float32)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.parameters( + (tf.int16,), + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + ) + def testScaleToZScoreSparsePerDenseKey(self, input_dtype): + # TODO(b/131852830) Add elementwise tests. 
+ def preprocessing_fn(inputs): + def scale_to_z_score_per_key(tensor, key): + z_score = tft.scale_to_z_score_per_key( + tf.cast(tensor, input_dtype), key=key, elementwise=False + ) + self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) + return tf.cast(z_score, tf.float32) + + return { + "x_scaled": scale_to_z_score_per_key(inputs["x"], inputs["key"]), + "y_scaled": scale_to_z_score_per_key(inputs["y"], inputs["key"]), + } + + np_dtype = input_dtype.as_numpy_dtype + input_data = [ + { + "x": np.array([-4, 2], dtype=np_dtype), + "y": np.array([0, 0], dtype=np_dtype), + "key": "a", + }, + { + "x": np.array([10, 4], dtype=np_dtype), + "y": np.array([0, 0], dtype=np_dtype), + "key": "a", + }, + { + "x": np.array([1, -1], dtype=np_dtype), + "y": np.array([0, 0], dtype=np_dtype), + "key": "b", + }, + ] + # Mean(x) = 3, Mean(y) = 0 + # Var(x) = (-7^2 + -1^2 + 7^2 + 1^2) / 4 = 25, Var(y) = 0 + # StdDev(x) = 5, StdDev(y) = 0 + # 'b': + # Mean(x) = 0, Mean(y) = 0 + # Var(x) = 1, Var(y) = 0 + # StdDev(x) = 1, StdDev(y) = 0 + expected_data = [ + { + "x_scaled": [-1.4, -0.2], # [(-4 - 3) / 5, (2 - 3) / 5] + "y_scaled": [0.0, 0.0], + }, + { + "x_scaled": [1.4, 0.2], # [(10 - 3) / 5, (4 - 3) / 5] + "y_scaled": [0.0, 0.0], + }, + { + "x_scaled": [1.0, -1.0], # [(1 - 0) / 1, (-1 - 0) / 1] + "y_scaled": [0.0, 0.0], + }, + ] + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.VarLenFeature(tft_unit.canonical_numeric_dtype(input_dtype)), + "y": tf.io.VarLenFeature(tft_unit.canonical_numeric_dtype(input_dtype)), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.VarLenFeature(tf.float32), + "y_scaled": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.named_parameters( + dict(testcase_name="_empty_filename", key_vocabulary_filename=""), + dict(testcase_name="_nonempty_filename", key_vocabulary_filename="per_key"), + dict(testcase_name="_none_filename", key_vocabulary_filename=None), + ) + def testScaleToZScorePerKey(self, key_vocabulary_filename): + # TODO(b/131852830) Add elementwise tests. 
+ def preprocessing_fn(inputs): + def scale_to_z_score_per_key(tensor, key, var_name=""): + if key_vocabulary_filename is None: + filename = None + else: + filename = key_vocabulary_filename + var_name + z_score = tft.scale_to_z_score_per_key( + tf.cast(tensor, tf.float32), + key=key, + elementwise=False, + key_vocabulary_filename=filename, + ) + self.assertEqual(z_score.dtype, tf.float32) + return z_score + + return { + "x_scaled": scale_to_z_score_per_key(inputs["x"], inputs["key"], "x"), + "y_scaled": scale_to_z_score_per_key(inputs["y"], inputs["key"], "y"), + "s_scaled": scale_to_z_score_per_key(inputs["s"], inputs["key"], "s"), + } + + np_dtype = np.float32 + input_data = [ + { + "x": np.array([-4], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": 3, + "key": "a", + }, + { + "x": np.array([10], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": -3, + "key": "a", + }, + { + "x": np.array([1], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": 3, + "key": "b", + }, + { + "x": np.array([2], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": 3, + "key": "a", + }, + { + "x": np.array([4], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": -3, + "key": "a", + }, + { + "x": np.array([-1], dtype=np_dtype), + "y": np.array([0], dtype=np_dtype), + "s": -3, + "key": "b", + }, + { + "x": np.array([np.nan], dtype=np_dtype), + "y": np.array([np.nan], dtype=np_dtype), + "s": np.nan, + "key": "b", + }, + ] + # 'a': + # Mean(x) = 3, Mean(y) = 0 + # Var(x) = (-7^2 + -1^2 + 7^2 + 1^2) / 4 = 25, Var(y) = 0 + # StdDev(x) = 5, StdDev(y) = 0 + # 'b': + # Mean(x) = 0, Mean(y) = 0 + # Var(x) = 1, Var(y) = 0 + # StdDev(x) = 1, StdDev(y) = 0 + expected_data = [ + { + "x_scaled": [-1.4], # [(-4 - 3) / 5, (2 - 3) / 5] + "y_scaled": [0.0], + "s_scaled": 1.0, + }, + { + "x_scaled": [1.4], # [(10 - 3) / 5, (4 - 3) / 5] + "y_scaled": [0.0], + "s_scaled": -1.0, + }, + { + "x_scaled": [1.0], # [(1 - 0) / 1, (-1 - 0) / 1] + "y_scaled": [0.0], + "s_scaled": 1.0, + }, + { + "x_scaled": [-0.2], # [(-4 - 3) / 5, (2 - 3) / 5] + "y_scaled": [0.0], + "s_scaled": 1.0, + }, + { + "x_scaled": [0.2], # [(10 - 3) / 5, (4 - 3) / 5] + "y_scaled": [0.0], + "s_scaled": -1.0, + }, + { + "x_scaled": [-1.0], # [(1 - 0) / 1, (-1 - 0) / 1] + "y_scaled": [0.0], + "s_scaled": -1.0, + }, + { + "x_scaled": [np.nan], + "y_scaled": [np.nan], + "s_scaled": np.nan, + }, + ] + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(tf.float32) + ), + "y": tf.io.FixedLenFeature( + [1], tft_unit.canonical_numeric_dtype(tf.float32) + ), + "s": tf.io.FixedLenFeature( + [], tft_unit.canonical_numeric_dtype(tf.float32) + ), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([1], tf.float32), + "y_scaled": tf.io.FixedLenFeature([1], tf.float32), + "s_scaled": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_float", + input_data=[ + { + "x": [-4, 0], + "key": "a", + }, + { + "x": [10, 0], + "key": "a", + }, + { + "x": [2, 0], + "key": "a", + }, + { + "x": [4, 0], + "key": "a", + }, + { + "x": [1, 0], + "key": "b", + }, + { + "x": [-1, 0], + "key": "b", + }, + { + "x": [np.nan, np.nan], + "key": "b", + }, + ], + # Elementwise = True + # 
Mean [a, b] = [[ 3.0, 0.0], [0.0, 0.0]] + # Variance [a, b] = [[25.0, 0.0], [1.0, 0.0]] + # StdDev [a, b] = [[ 5.0, 0.0], [1.0, 0.0]] + expected_data=[ + { + "x_scaled": [-1.4, 0.0], # [(-4 - 3) / 5, (0 - 0) / 0] + }, + { + "x_scaled": [1.4, 0.0] # [(10 - 3) / 5, (0 - 0) / 0] + }, + { + "x_scaled": [-0.2, 0.0] # [(2 - 3) / 5, (0 - 0) / 0] + }, + { + "x_scaled": [0.2, 0.0], # [(4 - 3) / 5, (0 - 0) / 0] + }, + { + "x_scaled": [1.0, 0.0] # [(1 - 0) / 1, (0 - 0) / 0] + }, + { + "x_scaled": [-1.0, 0.0] # [(-1 - 0) / 1, (0 - 0) / 0] + }, + {"x_scaled": [np.nan, np.nan]}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + expected_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([2], tf.float32), + } + ), + ), + dict( + testcase_name="float_3dims", + input_data=[ + { + "x": [[-4, -8], [-12, -16]], + "key": "a", + }, + { + "x": [[10, 20], [30, 40]], + "key": "a", + }, + { + "x": [[2, 4], [6, 8]], + "key": "a", + }, + { + "x": [[4, 8], [12, 16]], + "key": "a", + }, + { + "x": [[1, 2], [3, 4]], + "key": "b", + }, + ], + expected_data=[ + { + "x_scaled": [[-1.4, -1.4], [-1.4, -1.4]], + }, + { + "x_scaled": [[1.4, 1.4], [1.4, 1.4]], + }, + { + "x_scaled": [[-0.2, -0.2], [-0.2, -0.2]], + }, + { + "x_scaled": [[0.2, 0.2], [0.2, 0.2]], + }, + { + "x_scaled": [[0.0, 0.0], [0.0, 0.0]], + }, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2, 2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + expected_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([2, 2], tf.float32), + } + ), + ), + ) + def testScaleToZScorePerKeyElementwise( + self, input_data, expected_data, input_metadata, expected_metadata + ): + def preprocessing_fn(inputs): + outputs = {} + outputs["x_scaled"] = tft.scale_to_z_score_per_key( + tf.cast(inputs["x"], tf.float32), + key=inputs["key"], + elementwise=True, + key_vocabulary_filename=None, + ) + self.assertEqual(outputs["x_scaled"].dtype, tf.float32) + return outputs + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + @tft_unit.parameters( + (tf.int16,), + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + ) + def testScaleToZScoreSparsePerKey(self, input_dtype): + # TODO(b/131852830) Add elementwise tests. 
+ def preprocessing_fn(inputs): + z_score = tf.sparse.to_dense( + tft.scale_to_z_score_per_key( + tf.cast(inputs["x"], input_dtype), inputs["key"], elementwise=False + ), + default_value=np.nan, + ) + z_score.set_shape([None, 4]) + self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) + return {"x_scaled": tf.cast(z_score, tf.float32)} + + input_data = [ + {"idx": [0, 1], "val": [-4, 10], "key_idx": [0, 1], "key": ["a", "a"]}, + {"idx": [0, 1], "val": [2, 1], "key_idx": [0, 1], "key": ["a", "b"]}, + {"idx": [0, 1], "val": [-1, 4], "key_idx": [0, 1], "key": ["b", "a"]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "key": tf.io.SparseFeature("key_idx", "key", tf.string, 4), + "x": tf.io.SparseFeature( + "idx", "val", tft_unit.canonical_numeric_dtype(input_dtype), 4 + ), + } + ) + # 'a': + # Mean = 3 + # Var = 25 + # Std Dev = 5 + # 'b': + # Mean = 0 + # Var = 1 + # Std Dev = 1 + expected_data = [ + { + "x_scaled": [ + -1.4, + 1.4, + float("nan"), + float("nan"), + ] # [(-4 - 3) / 5, (10 - 3) / 5] + }, + { + "x_scaled": [ + -0.2, + 1.0, + float("nan"), + float("nan"), + ] # [(2 - 3) / 5, (1 - 0) / 1] + }, + { + "x_scaled": [ + -1.0, + 0.2, + float("nan"), + float("nan"), + ] # [(-1 - 0) / 1, (4 - 3) / 5] + }, + ] + if input_dtype.is_floating: + input_data.append( + { + "idx": [0, 1], + "val": [np.nan, np.nan], + "key_idx": [0, 1], + "key": ["a", "b"], + } + ) + expected_data.append( + {"x_scaled": [float("nan"), float("nan"), float("nan"), float("nan")]} + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + {"x_scaled": tf.io.FixedLenFeature([4], tf.float32)} + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testScaleToZScoreWithEmptyInputs(self): + # x is repeated `multiple` times to test elementwise mapping. + multiple = 3 + + def preprocessing_fn(inputs): + return { + "x_scaled": tft.scale_to_z_score(inputs["x"]), + "x_scaled_elementwise": tft.scale_to_z_score( + tf.tile(inputs["x"], [1, multiple]), elementwise=True + ), + } + + input_data = [] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.float32)} + ) + test_data = [{"x": [100]}, {"x": [1]}, {"x": [12]}] + expected_data = [ + {"x_scaled": [v], "x_scaled_elementwise": [v] * multiple} + for v in [100.0, 1.0, 12.0] + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_scaled": tf.io.FixedLenFeature([1], tf.float32), + "x_scaled_elementwise": tf.io.FixedLenFeature([multiple], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + test_data=test_data, + ) + + def testMeanAndVar(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + mean, var = analyzers._mean_and_var(inputs["x"]) + return {"mean": mean, "var": var} + + # NOTE: We force 11 batches: data has 110 elements and we request a batch + # size of 10. 
+ input_data = [{"x": [x if x < 101 else np.nan]} for x in range(1, 111)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.float32)} + ) + expected_outputs = {"mean": np.float32(50.5), "var": np.float32(833.25)} + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=10, + ) + + def testMeanAndVarPerKey(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + key_vocab, mean, var = analyzers._mean_and_var_per_key( + inputs["x"], inputs["key"] + ) + return { + "key_vocab": key_vocab, + "mean": mean, + "var": tf.round(100 * var) / 100.0, + } + + # NOTE: We force 12 batches: data has 120 elements and we request a batch + # size of 10. + input_data = [ + {"x": [x], "key": "a" if x < 50 else "b"} for x in range(1, 101) + ] + [{"x": [np.nan], "key": "a"}] * 20 + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([1], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_outputs = { + "key_vocab": np.array([b"a", b"b"], object), + "mean": np.array([25, 75], np.float32), + "var": np.array([200, 216.67], np.float32), + } + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=10, + ) + + def testMeanAndVarPerKeyElementwise(self): + def analyzer_fn(inputs): + key_vocab, mean, var = analyzers._mean_and_var_per_key( + inputs["x"], inputs["key"], reduce_instance_dims=False + ) + return { + "key_vocab": key_vocab, + "mean": mean, + "var": tf.round(100 * var) / 100.0, + } + + input_data = input_data = [ + { + "x": [-4, -1], + "key": "a", + }, + { + "x": [10, 0], + "key": "a", + }, + { + "x": [2, 0], + "key": "a", + }, + { + "x": [4, -1], + "key": "a", + }, + { + "x": [10, 0], + "key": "b", + }, + { + "x": [0, 10], + "key": "b", + }, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + expected_outputs = { + "key_vocab": np.array([b"a", b"b"], object), + "mean": np.array([[3.0, -0.5], [5.0, 5.0]], np.float32), + "var": np.array([[25.0, 0.25], [25.0, 25.0]], np.float32), + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_dense_2d", + input_data=[ + {"x": [4, 8], "key": "a"}, + {"x": [1, 5], "key": "a"}, + {"x": [5, 9], "key": "a"}, + {"x": [2, 6], "key": "a"}, + {"x": [-2, 0], "key": "b"}, + {"x": [0, 2], "key": "b"}, + {"x": [2, 4], "key": "b"}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + reduce_instance_dims=True, + expected_outputs={ + "key_vocab": np.array([b"a", b"b"], object), + "min_x_value": np.array([1, -2], np.float32), + "max_x_value": np.array([9, 4], np.float32), + }, + ), + dict( + testcase_name="_dense_2d_elementwise", + input_data=[ + {"x": [4, 8], "key": "a"}, + {"x": [1, 5], "key": "a"}, + {"x": [5, 9], "key": "a"}, + {"x": [2, 6], "key": "a"}, + {"x": [-2, 0], "key": "b"}, + {"x": [0, 2], "key": "b"}, + {"x": [2, 4], "key": "b"}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + reduce_instance_dims=False, + expected_outputs={ + "key_vocab": np.array([b"a", b"b"], object), + 
"min_x_value": np.array([[1, 5], [-2, 0]], np.float32), + "max_x_value": np.array([[5, 9], [2, 4]], np.float32), + }, + ), + dict( + testcase_name="_dense_3d", + input_data=[ + {"x": [[1, 5], [1, 1]], "key": "a"}, + {"x": [[5, 1], [5, 5]], "key": "a"}, + {"x": [[2, 2], [2, 5]], "key": "a"}, + {"x": [[3, -3], [3, 3]], "key": "b"}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2, 2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + reduce_instance_dims=True, + expected_outputs={ + "key_vocab": np.array([b"a", b"b"], object), + "min_x_value": np.array([1, -3], np.float32), + "max_x_value": np.array([5, 3], np.float32), + }, + ), + dict( + testcase_name="_dense_3d_elementwise", + input_data=[ + {"x": [[1, 5], [1, 1]], "key": "a"}, + {"x": [[5, 1], [5, 5]], "key": "a"}, + {"x": [[2, 2], [2, 5]], "key": "a"}, + {"x": [[3, -3], [3, 3]], "key": "b"}, + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([2, 2], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ), + reduce_instance_dims=False, + expected_outputs={ + "key_vocab": np.array([b"a", b"b"], object), + "min_x_value": np.array( + [[[1, 1], [1, 1]], [[3, -3], [3, 3]]], np.float32 + ), + "max_x_value": np.array( + [[[5, 5], [5, 5]], [[3, -3], [3, 3]]], np.float32 + ), + }, + ), + ) + def testMinAndMaxPerKey( + self, input_data, input_metadata, reduce_instance_dims, expected_outputs + ): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + key_vocab, min_x_value, max_x_value = analyzers._min_and_max_per_key( + x=inputs["x"], + key=inputs["key"], + reduce_instance_dims=reduce_instance_dims, + ) + return { + "key_vocab": key_vocab, + "min_x_value": min_x_value, + "max_x_value": max_x_value, + } + + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.parameters((True,), (False,)) + def testPerKeyWithOOVKeys(self, use_vocabulary): + def preprocessing_fn(inputs): + result = {} + result["x_scaled"] = tft.scale_to_0_1_per_key( + inputs["x"], + inputs["key"], + elementwise=False, + key_vocabulary_filename="a" if use_vocabulary else None, + ) + result["x_z_score"] = tft.scale_to_z_score_per_key( + inputs["x"], + inputs["key"], + elementwise=False, + key_vocabulary_filename="b" if use_vocabulary else None, + ) + # TODO(b/179891014): Add key_vocabulary_filename to bucketize_per_key once + # implemented. 
+ result["x_bucketized"] = tft.bucketize_per_key( + inputs["x"], inputs["key"], 3 + ) + return result + + input_data = [ + dict(x=4, key="a"), + dict(x=1, key="a"), + dict(x=5, key="a"), + dict(x=2, key="a"), + dict(x=25, key="b"), + dict(x=5, key="b"), + ] + test_data = input_data + [dict(x=5, key="oov")] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "key": tf.io.FixedLenFeature([], tf.string), + } + ) + + expected_data = [ + { + "x_scaled": 0.75, + "x_z_score": 0.6324555, + "x_bucketized": 2, + }, + { + "x_scaled": 0.0, + "x_z_score": -1.264911, + "x_bucketized": 0, + }, + { + "x_scaled": 1.0, + "x_z_score": 1.264911, + "x_bucketized": 2, + }, + { + "x_scaled": 0.25, + "x_z_score": -0.6324555, + "x_bucketized": 1, + }, + { + "x_scaled": 1.0, + "x_z_score": 1.0, + "x_bucketized": 2, + }, + { + "x_scaled": 0.0, + "x_z_score": -1.0, + "x_bucketized": 1, + }, + { + "x_scaled": _sigmoid(5), + "x_z_score": 5.0, + "x_bucketized": -1, + }, + ] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + test_data=test_data, + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_string", + input_data=[{"key": "a" if x < 25 else "b"} for x in range(100)], + input_metadata=tft.DatasetMetadata.from_feature_spec( + {"key": tf.io.FixedLenFeature([], tf.string)} + ), + expected_outputs={ + "elements": np.array([b"a", b"b"], object), + "counts": np.array([25, 75], np.int64), + }, + ), + dict( + testcase_name="_int", + input_data=[{"key": 0 if x < 25 else 1} for x in range(100)], + input_metadata=tft.DatasetMetadata.from_feature_spec( + {"key": tf.io.FixedLenFeature([], tf.int64)} + ), + expected_outputs={ + "elements": np.array([0, 1], np.int64), + "counts": np.array([25, 75], np.int64), + }, + ), + dict( + testcase_name="_int_sparse", + input_data=[{"key": [0] if x < 25 else [1]} for x in range(100)], + input_metadata=tft.DatasetMetadata.from_feature_spec( + {"key": tf.io.VarLenFeature(tf.int64)} + ), + expected_outputs={ + "elements": np.array([0, 1], np.int64), + "counts": np.array([25, 75], np.int64), + }, + ), + dict( + testcase_name="_3d_sparse", + input_data=[ + { # pylint: disable=g-complex-comprehension + "key": [0, 1] if x < 25 else [1], + "idx0": [0, 1] if x < 25 else [0], + "idx1": [0, 1] if x < 25 else [0], + } + for x in range(100) + ], + input_metadata=tft.DatasetMetadata.from_feature_spec( + {"key": tf.io.SparseFeature(["idx0", "idx1"], "key", tf.int64, [2, 2])} + ), + expected_outputs={ + "elements": np.array([0, 1], np.int64), + "counts": np.array([25, 100], np.int64), + }, + ), + ) + def testCountPerKey(self, input_data, input_metadata, expected_outputs): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + elements, counts = analyzers.count_per_key(inputs["key"]) + return {"elements": elements, "counts": counts} + + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_uniform", + input_data=[{"x": [x]} for x in range(10, 100)], + make_feature_spec=lambda: tf.io.FixedLenFeature([1], tf.int64), + boundaries=10 * np.arange(11, dtype=np.float32), + categorical=False, + expected_outputs={ + "hist": 10 * np.array([0] + [1] * 9, np.int64), + "boundaries": 10 * np.arange(11, dtype=np.float32).reshape((1, 11)), + }, + ), + dict( + testcase_name="_categorical_string", + input_data=[{"x": [str(x % 10) + "_"]} for x in range(1, 101)], + make_feature_spec=lambda: 
tf.io.FixedLenFeature([1], tf.string), + boundaries=None, + categorical=True, + expected_outputs={ + "hist": 10 * np.ones(10, np.int64), + "boundaries": np.asarray( + sorted([tf.compat.as_bytes(str(x % 10) + "_") for x in range(10)]), + dtype=object, + ), + }, + ), + dict( + testcase_name="_categorical_int", + input_data=[{"x": [(x % 10)]} for x in range(1, 101)], + make_feature_spec=lambda: tf.io.FixedLenFeature([1], tf.int64), + boundaries=None, + categorical=True, + expected_outputs={ + "hist": 10 * np.ones(10, np.int64), + "boundaries": np.arange(10), + }, + ), + dict( + testcase_name="_sparse", + input_data=[ + { # pylint: disable=g-complex-comprehension + "val": [(x % 10)], + "idx0": [(x % 2)], + "idx1": [((x + 1) % 2)], + } + for x in range(1, 101) + ], + make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda + ["idx0", "idx1"], "val", tf.int64, [2, 2] + ), + boundaries=None, + categorical=True, + expected_outputs={ + "hist": 10 * np.ones(10, np.int64), + "boundaries": np.arange(10), + }, + ), + dict( + testcase_name="_ragged", + input_data=[ + { # pylint: disable=g-complex-comprehension + "val": [x % 10, 9 - (x % 10)], + "row_lengths": [0, 1, 1], + } + for x in range(1, 101) + ], + make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tf.int64, + value_key="val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ) # pytype: disable=attribute-error + ], + ), + boundaries=None, + categorical=True, + expected_outputs={ + "hist": 20 * np.ones(10, np.int64), + "boundaries": np.arange(10), + }, + ), + ) + def testHistograms( + self, input_data, make_feature_spec, boundaries, categorical, expected_outputs + ): + self._SkipIfOutputRecordBatches() + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tft_unit.make_feature_spec_wrapper(make_feature_spec)} + ) + + def analyzer_fn(inputs): + counts, bucket_boundaries = analyzers.histogram( + inputs["x"], categorical=categorical, boundaries=boundaries + ) + if not categorical: + bucket_boundaries = tf.math.round(bucket_boundaries) + return {"hist": counts, "boundaries": bucket_boundaries} + + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testProbCategoricalInt(self): + def preprocessing_fn(inputs): + return { + "probs": tft.estimated_probability_density( + inputs["x"], categorical=True + ) + } + + # NOTE: We force 10 batches: data has 100 elements and we request a batch + # size of 10. + input_data = [{"x": [x % 10]} for x in range(1, 101)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.int64)} + ) + expected_outputs = [ + {"probs": np.array(np.ones(1) / 10.0, np.float32)} for _ in range(100) + ] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_outputs, + desired_batch_size=10, + ) + + def testProbCategorical(self): + def preprocessing_fn(inputs): + return { + "probs": tft.estimated_probability_density( + inputs["x"], categorical=True + ) + } + + # NOTE: We force 10 batches: data has 100 elements and we request a batch + # size of 10. 
+ input_data = [{"x": [str(x % 10) + "_"]} for x in range(1, 101)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.string)} + ) + expected_outputs = [ + {"probs": np.array(np.ones(1) / 10.0, np.float32)} for _ in range(100) + ] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_outputs, + desired_batch_size=10, + ) + + def testProbTenBoundaries(self): + # If we draw uniformly from a range (0, 100], the expected density is 0.01. + def preprocessing_fn(inputs): + return { + "probs": tft.estimated_probability_density( + inputs["x"], boundaries=list(range(0, 101, 10)) + ) + } + + # NOTE: We force 10 batches: data has 100 elements and we request a batch + # size of 10. + input_data = [{"x": [x]} for x in range(100)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.int64)} + ) + expected_outputs = [ + {"probs": np.array(np.ones(1) / (100.0), np.float32)} for _ in range(100) + ] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_outputs, + desired_batch_size=10, + ) + + @tft_unit.named_parameters( + { + "testcase_name": "uniform", + "boundaries": 6, + "input_data": [{"x": [x]} for x in range(100)], + "expected_outputs": [ + {"probs": np.array(np.ones(1) / 99.0, np.float32)} for _ in range(100) + ], + }, + { + "testcase_name": "nonuniform_with_zeros", + "boundaries": 5, + "input_data": [ + {"x": [x]} + for x in list(range(25)) + + (list(range(50, 75)) + list(range(50, 75)) + list(range(75, 100))) + ], + "expected_outputs": [ + { + "probs": np.ones((1), np.float32) + / 99.0 + * (2.0 if 24 < i < 75 else 1.0) + } + for i in range(100) + ], + }, + { + "testcase_name": "empty", + "boundaries": 5, + "input_data": [], + "expected_outputs": [], + }, + ) + def testProbUnknownBoundaries(self, input_data, expected_outputs, boundaries): + # Test 1 has 100 points over a range of 99; test 2 is an uneven distribution + def preprocessing_fn(inputs): + return { + "probs": tft.estimated_probability_density( + inputs["x"], boundaries=boundaries + ) + } + + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.int64)} + ) + + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_outputs + ) + + @tft_unit.named_parameters( + dict( + testcase_name="Int64In", + input_dtype=tf.int64, + output_dtypes={ + "min": tf.int64, + "max": tf.int64, + "sum": tf.int64, + "size": tf.int64, + "mean": tf.float32, + "var": tf.float32, + }, + ), + dict( + testcase_name="Int32In", + input_dtype=tf.int32, + output_dtypes={ + "min": tf.int32, + "max": tf.int32, + "sum": tf.int64, + "size": tf.int64, + "mean": tf.float32, + "var": tf.float32, + }, + ), + dict( + testcase_name="Int16In", + input_dtype=tf.int16, + output_dtypes={ + "min": tf.int16, + "max": tf.int16, + "sum": tf.int64, + "size": tf.int64, + "mean": tf.float32, + "var": tf.float32, + }, + ), + dict( + testcase_name="Float64In", + input_dtype=tf.float64, + output_dtypes={ + "min": tf.float64, + "max": tf.float64, + "sum": tf.float64, + "size": tf.int64, + "mean": tf.float64, + "var": tf.float64, + }, + ), + dict( + testcase_name="Float32In", + input_dtype=tf.float32, + output_dtypes={ + "min": tf.float32, + "max": tf.float32, + "sum": tf.float32, + "size": tf.int64, + "mean": tf.float32, + "var": tf.float32, + }, + ), + dict( + testcase_name="Float16In", + input_dtype=tf.float16, + output_dtypes={ + 
"min": tf.float16, + "max": tf.float16, + "sum": tf.float32, + "size": tf.int64, + "mean": tf.float16, + "var": tf.float16, + }, + ), + ) + def testNumericAnalyzersWithScalarInputs(self, input_dtype, output_dtypes): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + a = tf.cast(inputs["a"], input_dtype) + + def assert_and_cast_dtype(tensor, out_dtype): + self.assertEqual(tensor.dtype, out_dtype) + return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) + + return { + "min": assert_and_cast_dtype(tft.min(a), output_dtypes["min"]), + "max": assert_and_cast_dtype(tft.max(a), output_dtypes["max"]), + "sum": assert_and_cast_dtype(tft.sum(a), output_dtypes["sum"]), + "size": assert_and_cast_dtype(tft.size(a), output_dtypes["size"]), + "mean": assert_and_cast_dtype(tft.mean(a), output_dtypes["mean"]), + "var": assert_and_cast_dtype(tft.var(a), output_dtypes["var"]), + } + + input_data = [{"a": 4}, {"a": 1}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature( + [], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + expected_outputs = { + "min": np.array( + 1, tft_unit.canonical_numeric_dtype(output_dtypes["min"]).as_numpy_dtype + ), + "max": np.array( + 4, tft_unit.canonical_numeric_dtype(output_dtypes["max"]).as_numpy_dtype + ), + "sum": np.array( + 5, tft_unit.canonical_numeric_dtype(output_dtypes["sum"]).as_numpy_dtype + ), + "size": np.array( + 2, + tft_unit.canonical_numeric_dtype(output_dtypes["size"]).as_numpy_dtype, + ), + "mean": np.array( + 2.5, + tft_unit.canonical_numeric_dtype(output_dtypes["mean"]).as_numpy_dtype, + ), + "var": np.array( + 2.25, + tft_unit.canonical_numeric_dtype(output_dtypes["var"]).as_numpy_dtype, + ), + } + + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict( + testcase_name="sparse", + input_data=[ + { + "idx0": [0, 1], + "idx1": [0, 1], + "val": [0, 1], + }, + { + "idx0": [1, 2], + "idx1": [1, 3], + "val": [2, 3], + }, + ], + make_feature_spec=lambda dtype: tf.io.SparseFeature( # pylint: disable=g-long-lambda + ["idx0", "idx1"], "val", dtype, (3, 4) + ), + expected_outputs={ + "min": 0.0, + "max": 3.0, + "sum": 6.0, + "size": 4, + "mean": 1.5, + "var": 1.25, + }, + reduce_instance_dims=True, + ), + dict( + testcase_name="sparse_elementwise", + input_data=[ + { + "idx0": [0, 1], + "idx1": [0, 1], + "val": [0, 1], + }, + { + "idx0": [1, 2], + "idx1": [1, 3], + "val": [2, 3], + }, + ], + make_feature_spec=lambda dtype: tf.io.SparseFeature( # pylint: disable=g-long-lambda + ["idx0", "idx1"], "val", dtype, (3, 4) + ), + expected_outputs={ + # We use np.nan in place of missing values here but replace + # them accordingly to the dtype in the test. 
+ "min": [ + [0.0, np.nan, np.nan, np.nan], + [np.nan, 1.0, np.nan, np.nan], + [np.nan, np.nan, np.nan, 3.0], + ], + "max": [ + [0.0, np.nan, np.nan, np.nan], + [np.nan, 2.0, np.nan, np.nan], + [np.nan, np.nan, np.nan, 3.0], + ], + "sum": [ + [0.0, 0.0, 0.0, 0.0], + [0.0, 3.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 3.0], + ], + "size": [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 1]], + "mean": [ + [0.0, np.nan, np.nan, np.nan], + [np.nan, 1.5, np.nan, np.nan], + [np.nan, np.nan, np.nan, 3.0], + ], + "var": [ + [0.0, np.nan, np.nan, np.nan], + [np.nan, 0.25, np.nan, np.nan], + [np.nan, np.nan, np.nan, 0.0], + ], + }, + reduce_instance_dims=False, + ), + dict( + testcase_name="ragged", + input_data=[ + { + "val": [0.0, 2.0, 3.0], + "row_lengths": [0, 3], + }, + { + "val": [3.0, 3.0, 1.0], + "row_lengths": [3], + }, + ], + make_feature_spec=lambda dtype: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + dtype, + value_key="val", + partitions=[tf.io.RaggedFeature.RowLengths("row_lengths")], + ), # pytype: disable=attribute-error + expected_outputs={ + "min": 0.0, + "max": 3.0, + "sum": 12.0, + "size": 6, + "mean": 2.0, + "var": 1.333333, + }, + reduce_instance_dims=True, + ), + ], + [ + dict(testcase_name="int16", input_dtype=tf.int16), + dict(testcase_name="int32", input_dtype=tf.int32), + dict(testcase_name="int64", input_dtype=tf.int64), + dict(testcase_name="float32", input_dtype=tf.float32), + dict(testcase_name="tf.float64", input_dtype=tf.float64), + dict(testcase_name="tf.uint8", input_dtype=tf.uint8), + dict(testcase_name="tf.uint16", input_dtype=tf.uint16), + ], + ) + ) + def testNumericAnalyzersWithCompositeInputs( + self, + input_data, + make_feature_spec, + expected_outputs, + reduce_instance_dims, + input_dtype, + ): + self._SkipIfOutputRecordBatches() + output_dtype = tft_unit.canonical_numeric_dtype(input_dtype) + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tft_unit.make_feature_spec_wrapper(make_feature_spec, output_dtype)} + ) + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"], reduce_instance_dims), + "max": tft.max(inputs["a"], reduce_instance_dims), + "sum": tft.sum(inputs["a"], reduce_instance_dims), + "size": tft.size(inputs["a"], reduce_instance_dims), + "mean": tft.mean(inputs["a"], reduce_instance_dims), + "var": tft.var(inputs["a"], reduce_instance_dims), + } + + input_val_dtype = input_dtype.as_numpy_dtype + # Cast input values to appropriate type. + for instance in input_data: + instance["val"] = np.array(instance["val"], input_val_dtype) + if not reduce_instance_dims: + if input_dtype.is_floating: + missing_value_max = float("nan") + missing_value_min = float("nan") + else: + missing_value_max = np.iinfo(output_dtype.as_numpy_dtype).min + missing_value_min = np.iinfo(output_dtype.as_numpy_dtype).max + # Replace NaNs with proper missing values. 
+ for row in expected_outputs["min"]: + for idx in range(len(row)): + if np.isnan(row[idx]): + row[idx] = missing_value_min + for row in expected_outputs["max"]: + for idx in range(len(row)): + if np.isnan(row[idx]): + row[idx] = missing_value_max + for op in ("min", "max", "sum"): + expected_outputs[op] = np.array( + expected_outputs[op], output_dtype.as_numpy_dtype + ) + expected_outputs["size"] = np.array(expected_outputs["size"], np.int64) + expected_outputs["mean"] = np.array(expected_outputs["mean"], np.float32) + expected_outputs["var"] = np.array(expected_outputs["var"], np.float32) + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.named_parameters( + dict( + testcase_name="sparse", + input_data=[ + { + "idx0": [0, 1], + "idx1": [0, 1], + "val": np.array([0, 1], dtype=np.int64), + }, + { + "idx0": [1, 2], + "idx1": [1, 3], + "val": np.array([2, 3], dtype=np.int64), + }, + ], + make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda + ["idx0", "idx1"], "val", tf.int64, (3, 4) + ), + elementwise=False, + expected_outputs=[ + { + "scale_to_0_1$sparse_indices_0": np.array([0, 1]), + "scale_to_0_1$sparse_indices_1": np.array([0, 1]), + "scale_to_z_score$sparse_indices_0": np.array([0, 1]), + "scale_to_z_score$sparse_indices_1": np.array([0, 1]), + "scale_by_min_max$sparse_indices_0": np.array([0, 1]), + "scale_by_min_max$sparse_indices_1": np.array([0, 1]), + "scale_to_0_1$sparse_values": np.array( + [0.0, 1.0 / 3.0], dtype=np.float32 + ), + "scale_to_z_score$sparse_values": np.array( + [-1.5 / np.sqrt(1.25), -0.5 / np.sqrt(1.25)], dtype=np.float32 + ), + "scale_by_min_max$sparse_values": np.array( + [0.0, 1.0 / 3.0], dtype=np.float32 + ), + }, + { + "scale_to_0_1$sparse_indices_0": np.array([1, 2]), + "scale_to_0_1$sparse_indices_1": np.array([1, 3]), + "scale_to_z_score$sparse_indices_0": np.array([1, 2]), + "scale_to_z_score$sparse_indices_1": np.array([1, 3]), + "scale_by_min_max$sparse_indices_0": np.array([1, 2]), + "scale_by_min_max$sparse_indices_1": np.array([1, 3]), + "scale_to_0_1$sparse_values": np.array( + [2.0 / 3.0, 1.0], dtype=np.float32 + ), + "scale_to_z_score$sparse_values": np.array( + [0.5 / np.sqrt(1.25), 1.5 / np.sqrt(1.25)], dtype=np.float32 + ), + "scale_by_min_max$sparse_values": np.array( + [2.0 / 3.0, 1.0], dtype=np.float32 + ), + }, + ], + ), + dict( + testcase_name="sparse_elementwise", + input_data=[ + { + "idx0": [0, 1], + "idx1": [0, 1], + "val": np.array([0, 1], dtype=np.int64), + }, + { + "idx0": [1, 2], + "idx1": [1, 3], + "val": np.array([2, 3], dtype=np.int64), + }, + ], + make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda + ["idx0", "idx1"], "val", tf.int64, (3, 4) + ), + elementwise=True, + expected_outputs=[ + { + "scale_to_0_1$sparse_indices_0": np.array([0, 1]), + "scale_to_0_1$sparse_indices_1": np.array([0, 1]), + "scale_to_z_score$sparse_indices_0": np.array([0, 1]), + "scale_to_z_score$sparse_indices_1": np.array([0, 1]), + "scale_by_min_max$sparse_indices_0": np.array([0, 1]), + "scale_by_min_max$sparse_indices_1": np.array([0, 1]), + "scale_to_0_1$sparse_values": np.array( + [0.5, 0.0], dtype=np.float32 + ), + "scale_to_z_score$sparse_values": np.array( + [0, -1], dtype=np.float32 + ), + "scale_by_min_max$sparse_values": np.array( + [0.5, 0.0], dtype=np.float32 + ), + }, + { + "scale_to_0_1$sparse_indices_0": np.array([1, 2]), + "scale_to_0_1$sparse_indices_1": np.array([1, 3]), + "scale_to_z_score$sparse_indices_0": np.array([1, 
2]), + "scale_to_z_score$sparse_indices_1": np.array([1, 3]), + "scale_by_min_max$sparse_indices_0": np.array([1, 2]), + "scale_by_min_max$sparse_indices_1": np.array([1, 3]), + "scale_to_0_1$sparse_values": np.array( + [1.0, _sigmoid(3)], dtype=np.float32 + ), + "scale_to_z_score$sparse_values": np.array( + [1, 0], dtype=np.float32 + ), + "scale_by_min_max$sparse_values": np.array( + [1.0, _sigmoid(3)], dtype=np.float32 + ), + }, + ], + ), + dict( + testcase_name="ragged", + input_data=[ + { + "val": [0.0, 2.0, 3.0], + "row_lengths": [0, 3], + }, + { + "val": [3.0, 3.0, 1.0], + "row_lengths": [3], + }, + ], + make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tf.float32, + value_key="val", + partitions=[tf.io.RaggedFeature.RowLengths("row_lengths")], + ), # pytype: disable=attribute-error + elementwise=False, + expected_outputs=[ + { + "scale_by_min_max$ragged_values": [0.0, 0.6666667, 1.0], + "scale_to_z_score$row_lengths_1": [0, 3], + "scale_to_0_1$row_lengths_1": [0, 3], + "scale_to_0_1$ragged_values": [0.0, 0.6666667, 1.0], + "scale_to_z_score$ragged_values": [-1.7320509, 0.0, 0.86602545], + "scale_by_min_max$row_lengths_1": [0, 3], + }, + { + "scale_to_0_1$row_lengths_1": [3], + "scale_by_min_max$row_lengths_1": [3], + "scale_to_z_score$ragged_values": [ + 0.86602545, + 0.86602545, + -0.86602545, + ], + "scale_to_z_score$row_lengths_1": [3], + "scale_to_0_1$ragged_values": [1.0, 1.0, 0.33333334], + "scale_by_min_max$ragged_values": [1.0, 1.0, 0.33333334], + }, + ], + ), + dict( + testcase_name="ragged_uniform", + input_data=[ + { + "val": [0.0, 2.0, 3.0, 11.0, 2.0, 7.0], + }, + { + "val": [3.0, 1.0, 2.0], + }, + ], + make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tf.float32, + value_key="val", + partitions=[ + tf.io.RaggedFeature.UniformRowLength( + 3 + ), # pytype: disable=attribute-error + ], + ), + elementwise=False, + expected_outputs=[ + { + "scale_by_min_max$ragged_values": [ + 0.0, + 0.18181819, + 0.27272728, + 1.0, + 0.18181819, + 0.6363636, + ], + "scale_to_z_score$ragged_values": [ + -1.0645443, + -0.4464218, + -0.13736054, + 2.3351295, + -0.4464218, + 1.0988845, + ], + "scale_to_0_1$ragged_values": [ + 0.0, + 0.18181819, + 0.27272728, + 1.0, + 0.18181819, + 0.6363636, + ], + }, + { + "scale_to_0_1$ragged_values": [0.27272728, 0.09090909, 0.18181819], + "scale_by_min_max$ragged_values": [ + 0.27272728, + 0.09090909, + 0.18181819, + ], + "scale_to_z_score$ragged_values": [ + -0.13736054, + -0.7554831, + -0.4464218, + ], + }, + ], + ), + dict( + testcase_name="2d_ragged_uniform", + input_data=[ + { + "val": [0.0, 2.0, 3.0, 1.0, 2.0, 7.0], + "row_lengths": [0, 2, 0, 1], + }, + { + "val": [3.0, 3.0, 1.0, 2.0], + "row_lengths": [2], + }, + ], + make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda + tf.float32, + value_key="val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.UniformRowLength( + 2 + ), # pytype: disable=attribute-error + ], + # Note that row splits are always encoded as int64 since we only + # support this integral type in outputs. We modify the default + # `row_splits_dtype` (tf.int32) here to make sure it still works. 
+ row_splits_dtype=tf.int64, + ), + elementwise=False, + expected_outputs=[ + { + "scale_by_min_max$ragged_values": [ + 0.0, + 0.285714, + 0.428571, + 0.142857, + 0.285714, + 1.0, + ], + "scale_by_min_max$row_lengths_1": [0, 2, 0, 1], + "scale_to_z_score$row_lengths_1": [0, 2, 0, 1], + "scale_to_z_score$ragged_values": [ + -1.3333334, + -0.22222228, + 0.33333328, + -0.77777785, + -0.22222228, + 2.5555556, + ], + "scale_to_0_1$row_lengths_1": [0, 2, 0, 1], + "scale_to_0_1$ragged_values": [ + 0.0, + 0.2857143, + 0.42857143, + 0.14285715, + 0.2857143, + 1.0, + ], + }, + { + "scale_to_0_1$ragged_values": [ + 0.42857143, + 0.42857143, + 0.14285715, + 0.2857143, + ], + "scale_to_0_1$row_lengths_1": [2], + "scale_by_min_max$ragged_values": [ + 0.42857143, + 0.42857143, + 0.14285715, + 0.2857143, + ], + "scale_by_min_max$row_lengths_1": [2], + "scale_to_z_score$ragged_values": [ + 0.33333328, + 0.33333328, + -0.77777785, + -0.22222228, + ], + "scale_to_z_score$row_lengths_1": [2], + }, + ], + ), + ) + def testNumericMappersWithCompositeInputs( + self, input_data, make_feature_spec, elementwise, expected_outputs + ): + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tft_unit.make_feature_spec_wrapper(make_feature_spec)} + ) + + def preprocessing_fn(inputs): + return { + "scale_to_0_1": tft.scale_to_0_1(inputs["a"], elementwise=elementwise), + "scale_to_z_score": tft.scale_to_z_score( + inputs["a"], elementwise=elementwise + ), + "scale_by_min_max": tft.scale_by_min_max( + inputs["a"], elementwise=elementwise + ), + } + + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_outputs + ) + + def testNumericAnalyzersWithInputsAndAxis(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"], reduce_instance_dims=False), + "max": tft.max(inputs["a"], reduce_instance_dims=False), + "sum": tft.sum(inputs["a"], reduce_instance_dims=False), + "size": tft.size(inputs["a"], reduce_instance_dims=False), + "mean": tft.mean(inputs["a"], reduce_instance_dims=False), + "var": tft.var(inputs["a"], reduce_instance_dims=False), + } + + input_data = [{"a": [8, 9, 3, 4]}, {"a": [1, 2, 10, 11]}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([4], tf.int64)} + ) + expected_outputs = { + "min": np.array([1, 2, 3, 4], np.int64), + "max": np.array([8, 9, 10, 11], np.int64), + "sum": np.array([9, 11, 13, 15], np.int64), + "size": np.array([2, 2, 2, 2], np.int64), + "mean": np.array([4.5, 5.5, 6.5, 7.5], np.float32), + "var": np.array([12.25, 12.25, 12.25, 12.25], np.float32), + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testNumericAnalyzersWithNDInputsAndAxis(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"], reduce_instance_dims=False), + "max": tft.max(inputs["a"], reduce_instance_dims=False), + "sum": tft.sum(inputs["a"], reduce_instance_dims=False), + "size": tft.size(inputs["a"], reduce_instance_dims=False), + "mean": tft.mean(inputs["a"], reduce_instance_dims=False), + "var": tft.var(inputs["a"], reduce_instance_dims=False), + } + + input_data = [{"a": [[8, 9], [3, 4]]}, {"a": [[1, 2], [10, 11]]}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([2, 2], tf.int64)} + ) + expected_outputs = { + "min": np.array([[1, 2], [3, 4]], np.int64), + "max": np.array([[8, 9], [10, 11]], np.int64), + "sum": np.array([[9, 
11], [13, 15]], np.int64), + "size": np.array([[2, 2], [2, 2]], np.int64), + "mean": np.array([[4.5, 5.5], [6.5, 7.5]], np.float32), + "var": np.array([[12.25, 12.25], [12.25, 12.25]], np.float32), + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testNumericAnalyzersWithShape1NDInputsAndAxis(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"], reduce_instance_dims=False), + "max": tft.max(inputs["a"], reduce_instance_dims=False), + "sum": tft.sum(inputs["a"], reduce_instance_dims=False), + "size": tft.size(inputs["a"], reduce_instance_dims=False), + "mean": tft.mean(inputs["a"], reduce_instance_dims=False), + "var": tft.var(inputs["a"], reduce_instance_dims=False), + } + + input_data = [{"a": [[8, 9]]}, {"a": [[1, 2]]}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([1, 2], tf.int64)} + ) + expected_outputs = { + "min": np.array([[1, 2]], np.int64), + "max": np.array([[8, 9]], np.int64), + "sum": np.array([[9, 11]], np.int64), + "size": np.array([[2, 2]], np.int64), + "mean": np.array([[4.5, 5.5]], np.float32), + "var": np.array([[12.25, 12.25]], np.float32), + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testNumericAnalyzersWithNDInputs(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"]), + "max": tft.max(inputs["a"]), + "sum": tft.sum(inputs["a"]), + "size": tft.size(inputs["a"]), + "mean": tft.mean(inputs["a"]), + "var": tft.var(inputs["a"]), + } + + input_data = [{"a": [[4, 5], [6, 7]]}, {"a": [[1, 2], [3, 4]]}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([2, 2], tf.int64)} + ) + expected_outputs = { + "min": np.array(1, np.int64), + "max": np.array(7, np.int64), + "sum": np.array(32, np.int64), + "size": np.array(8, np.int64), + "mean": np.array(4.0, np.float32), + "var": np.array(3.5, np.float32), + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict(testcase_name="int64", input_dtype=tf.int64), + dict(testcase_name="float32", input_dtype=tf.float32), + ], + [ + dict(testcase_name="scalar", input_shape=[]), + dict(testcase_name="ND", input_shape=[2, 3]), + ], + [ + dict(testcase_name="elementwise", reduce_instance_dims=False), + dict(testcase_name="not_elementwise", reduce_instance_dims=True), + ], + ) + ) + def testNumericAnalyzersWithEmptyInputs( + self, input_dtype, input_shape, reduce_instance_dims + ): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return { + "min": tft.min(inputs["a"], reduce_instance_dims=reduce_instance_dims), + "max": tft.max(inputs["a"], reduce_instance_dims=reduce_instance_dims), + "sum": tft.sum(inputs["a"], reduce_instance_dims=reduce_instance_dims), + "size": tft.size( + inputs["a"], reduce_instance_dims=reduce_instance_dims + ), + "mean": tft.mean( + inputs["a"], reduce_instance_dims=reduce_instance_dims + ), + "var": tft.var(inputs["a"], reduce_instance_dims=reduce_instance_dims), + } + + input_data = [] + canonical_dtype = tft_unit.canonical_numeric_dtype(input_dtype) + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature(input_shape, canonical_dtype)} + ) + input_val_dtype = input_dtype.as_numpy_dtype + output_shape = [] if reduce_instance_dims else input_shape 
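+        # Descriptive note (editorial): with an empty analysis dataset, min/max
+        # are expected to default to +/-inf for floats (or the integer dtype
+        # extremes), and sum/size/mean/var to zero, as constructed below.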
+ output_dtype = canonical_dtype.as_numpy_dtype + default_min = np.inf if input_dtype.is_floating else canonical_dtype.max + default_max = -np.inf if input_dtype.is_floating else canonical_dtype.min + expected_outputs = { + "min": np.full(output_shape, default_min, output_dtype), + "max": np.full(output_shape, default_max, output_dtype), + "sum": np.full(output_shape, 0, output_dtype), + "size": np.full(output_shape, 0, np.int64), + "mean": np.full(output_shape, 0, np.float32), + "var": np.full(output_shape, 0, np.float32), + } + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + test_data=[ + {"a": np.zeros(input_shape, input_val_dtype)}, + {"a": np.ones(input_shape, input_val_dtype)}, + ], + ) + + def testNumericMeanWithSparseTensorReduceFalseOverflow(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return {"mean": tft.mean(tf.cast(inputs["sparse"], tf.int32), False)} + + input_data = [ + {"idx": [0, 1], "val": [1, 1]}, + {"idx": [1, 3], "val": [2147483647, 3]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"sparse": tf.io.SparseFeature("idx", "val", tf.int64, 4)} + ) + expected_outputs = { + "mean": np.array([1.0, 1073741824.0, float("nan"), 3.0], np.float32) + } + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testStringToTFIDF(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]) + ) + out_index, out_values = tft.tfidf( + inputs_as_ints, + tft.get_num_buckets_for_transformed_feature(inputs_as_ints), + ) + return { + "tf_idf": out_values, + "index": out_index, + } + + input_data = [ + {"a": "hello hello world"}, + {"a": "hello goodbye hello world"}, + {"a": "I like pie pie pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + + # IDFs + # hello = 1 + log(4/3) = 1.28768 + # world = 1 + log(4/3) + # goodbye = 1 + log(4/2) = 1.69314 + # I = 1 + log(4/2) + # like = 1 + log(4/2) + # pie = 1 + log(4/2) + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + expected_transformed_data = [ + { + "tf_idf": [ + (2 / 3) * log_4_over_3_plus_1, + (1 / 3) * log_4_over_3_plus_1, + ], + "index": [0, 2], + }, + { + "tf_idf": [ + (2 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_2_plus_1, + ], + "index": [0, 2, 4], + }, + { + "tf_idf": [ + (3 / 5) * log_4_over_2_plus_1, + (1 / 5) * log_4_over_2_plus_1, + (1 / 5) * log_4_over_2_plus_1, + ], + "index": [1, 3, 5], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + expected_metadata, + ) + + def testTFIDFNoData(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]) + ) + out_index, out_values = tft.tfidf(inputs_as_ints, 6) + return {"tf_idf": out_values, "index": out_index} + + input_data = [{"a": ""}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + expected_transformed_data = [{"tf_idf": [], "index": []}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": 
tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + expected_metadata, + ) + + def testStringToTFIDFEmptyDoc(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]) + ) + out_index, out_values = tft.tfidf( + inputs_as_ints, + tft.get_num_buckets_for_transformed_feature(inputs_as_ints), + ) + return {"tf_idf": out_values, "index": out_index} + + input_data = [ + {"a": "hello hello world"}, + {"a": ""}, + {"a": "hello goodbye hello world"}, + {"a": "I like pie pie pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + + log_5_over_2_plus_1 = 1.91629073187 + log_5_over_3_plus_1 = 1.51082562376 + expected_transformed_data = [ + { + "tf_idf": [ + (2 / 3) * log_5_over_3_plus_1, + (1 / 3) * log_5_over_3_plus_1, + ], + "index": [0, 2], + }, + {"tf_idf": [], "index": []}, + { + "tf_idf": [ + (2 / 4) * log_5_over_3_plus_1, + (1 / 4) * log_5_over_3_plus_1, + (1 / 4) * log_5_over_2_plus_1, + ], + "index": [0, 2, 4], + }, + { + "tf_idf": [ + (3 / 5) * log_5_over_2_plus_1, + (1 / 5) * log_5_over_2_plus_1, + (1 / 5) * log_5_over_2_plus_1, + ], + "index": [1, 3, 5], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + expected_metadata, + ) + + def testIntToTFIDF(self): + def preprocessing_fn(inputs): + out_index, out_values = tft.tfidf(inputs["a"], 13) + return {"tf_idf": out_values, "index": out_index} + + input_data = [ + {"a": [2, 2, 0]}, + {"a": [2, 6, 2, 0]}, + {"a": [8, 10, 12, 12, 12]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.VarLenFeature(tf.int64)} + ) + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + expected_data = [ + { + "tf_idf": [ + (1 / 3) * log_4_over_3_plus_1, + (2 / 3) * log_4_over_3_plus_1, + ], + "index": [0, 2], + }, + { + "tf_idf": [ + (1 / 4) * log_4_over_3_plus_1, + (2 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_2_plus_1, + ], + "index": [0, 2, 6], + }, + { + "tf_idf": [ + (1 / 5) * log_4_over_2_plus_1, + (1 / 5) * log_4_over_2_plus_1, + (3 / 5) * log_4_over_2_plus_1, + ], + "index": [8, 10, 12], + }, + ] + expected_schema = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_data, expected_schema + ) + + def testIntToTFIDFWithoutSmoothing(self): + def preprocessing_fn(inputs): + out_index, out_values = tft.tfidf(inputs["a"], 13, smooth=False) + return {"tf_idf": out_values, "index": out_index} + + input_data = [ + {"a": [2, 2, 0]}, + {"a": [2, 6, 2, 0]}, + {"a": [8, 10, 12, 12, 12]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.VarLenFeature(tf.int64)} + ) + log_3_over_2_plus_1 = 1.4054651081 + log_3_plus_1 = 2.0986122886 + expected_data = [ + { + "tf_idf": [ + (1 / 3) * log_3_over_2_plus_1, + (2 / 3) * log_3_over_2_plus_1, + ], + "index": [0, 2], + }, + { + "tf_idf": [ + (1 / 4) * log_3_over_2_plus_1, + (2 / 4) * log_3_over_2_plus_1, + (1 / 4) * log_3_plus_1, + ], + "index": [0, 2, 6], + }, + { + "tf_idf": [ + 
(1 / 5) * log_3_plus_1, + (1 / 5) * log_3_plus_1, + (3 / 5) * log_3_plus_1, + ], + "index": [8, 10, 12], + }, + ] + expected_schema = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_data, expected_schema + ) + + def testTFIDFWithOOV(self): + def preprocessing_fn(inputs): + inputs_as_ints = tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]), top_k=3 + ) + out_index, out_values = tft.tfidf( + inputs_as_ints, + tft.get_num_buckets_for_transformed_feature(inputs_as_ints) + 1, + ) + return {"tf_idf": out_values, "index": out_index} + + input_data = [ + {"a": "hello hello world"}, + {"a": "hello goodbye hello world"}, + {"a": "I like pie pie pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + + # IDFs + # hello = 1 + log(3/3) = 1 + # pie = 1+ log(3/2) = 1.4054651081 + # world = 1 + log(3/3) = 1 + # OOV - goodbye, I, like = 1 + log(3/3) = 1 + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + expected_transformed_data = [ + { + "tf_idf": [ + (2 / 3) * log_4_over_3_plus_1, + (1 / 3) * log_4_over_3_plus_1, + ], + "index": [0, 2], + }, + { + "tf_idf": [ + (2 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_3_plus_1, + ], + "index": [0, 2, 3], + }, + { + "tf_idf": [ + (3 / 5) * log_4_over_2_plus_1, + (2 / 5) * log_4_over_3_plus_1, + ], + "index": [1, 3], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + expected_metadata, + ) + + def testTFIDFWithNegatives(self): + def preprocessing_fn(inputs): + out_index, out_values = tft.tfidf(inputs["a"], 14) + return {"tf_idf": out_values, "index": out_index} + + input_data = [ + {"a": [2, 2, -4]}, + {"a": [2, 6, 2, -1]}, + {"a": [8, 10, 12, 12, 12]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.VarLenFeature(tf.int64)} + ) + + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + # NOTE: -4 mod 14 = 10 + expected_transformed_data = [ + { + "tf_idf": [ + (2 / 3) * log_4_over_3_plus_1, + (1 / 3) * log_4_over_3_plus_1, + ], + "index": [2, 10], + }, + { + "tf_idf": [ + (2 / 4) * log_4_over_3_plus_1, + (1 / 4) * log_4_over_2_plus_1, + (1 / 4) * log_4_over_2_plus_1, + ], + "index": [2, 6, 13], + }, + { + "tf_idf": [ + (1 / 5) * log_4_over_2_plus_1, + (1 / 5) * log_4_over_3_plus_1, + (3 / 5) * log_4_over_2_plus_1, + ], + "index": [8, 10, 12], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "tf_idf": tf.io.VarLenFeature(tf.float32), + "index": tf.io.VarLenFeature(tf.int64), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + expected_metadata, + ) + + def _get_dfidf_experimental_preprocessing_fn( + self, + is_str_input: bool = False, + smooth: bool = True, + add_baseline: bool = True, + vocab_size: Optional[int] = None, + top_k: Optional[int] = None, + ): + """Returns proper preprocessing fn for df/idf under tft.experimental.""" + + def preprocessing_fn(inputs): + if is_str_input: + inputs_as_ints = tft.compute_and_apply_vocabulary( + 
tf.compat.v1.strings.split(inputs["a"]), top_k=top_k + ) + else: + inputs_as_ints = inputs["a"] + + if vocab_size is None: + computed_vocab_size = tft.get_num_buckets_for_transformed_feature( + inputs_as_ints + ) + else: + computed_vocab_size = vocab_size + + out_df_counts = tft.experimental.document_frequency( + inputs_as_ints, computed_vocab_size + ) + out_idf_weights = tft.experimental.idf( + inputs_as_ints, + computed_vocab_size, + smooth=smooth, + add_baseline=add_baseline, + ) + return {"df": out_df_counts, "idf": out_idf_weights} + + return preprocessing_fn + + @tft_unit.named_parameters( + dict(testcase_name="StrInputSmoothBasaeline", smooth=True, add_baseline=True), + dict( + testcase_name="StrInputSmoothWOBasaeline", smooth=True, add_baseline=False + ), + dict( + testcase_name="StrInputNonSmoothBasaeline", smooth=False, add_baseline=True + ), + dict( + testcase_name="StrInputNonSmoothWOBasaeline", + smooth=False, + add_baseline=False, + ), + ) + def testStringToDFIDFExperimental(self, smooth, add_baseline): + input_data = [ + {"a": "hello hello world pie"}, + {"a": "hello goodbye world pie"}, + {"a": "I like pie pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) -from google.protobuf import text_format -import unittest -from tensorflow_metadata.proto.v0 import schema_pb2 + # corpus_size = 3 + # DF smooth base IDF non-smooth base IDF with baseline + # hello 2 log(4/3) log(3/2) * + 1 + # world 2 log(4/3) log(3/2) * + 1 + # goodbye 1 log(4/2) log3 * + 1 + # I 1 log(4/2) log3 * + 1 + # like 1 log(4/2) log3 * + 1 + # pie 3 log(4/4) = 0 log(3/3)=0 * + 1 + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 + log_3_over_2 = 0.4054651081 + log_3 = 1.09861228867 + + if smooth: + base_idf1, base_idf2 = log_4_over_3, log_4_over_2 + else: + base_idf1, base_idf2 = log_3_over_2, log_3 + + baseline = 1.0 if add_baseline else 0.0 + + expected_transformed_data = [ + { + "df": [2, 2, 2, 3], + "idf": [ + baseline + base_idf1, + baseline + base_idf1, + baseline + base_idf1, + baseline, + ], + }, + { + "df": [2, 1, 2, 3], + "idf": [ + baseline + base_idf1, + baseline + base_idf2, + baseline + base_idf1, + baseline, + ], + }, + { + "df": [1, 1, 3, 3], + "idf": [baseline + base_idf2, baseline + base_idf2, baseline, baseline], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + self._get_dfidf_experimental_preprocessing_fn( + is_str_input=True, smooth=smooth, add_baseline=add_baseline + ), + expected_transformed_data, + expected_metadata, + ) -if common.IS_ANNOTATIONS_PB_AVAILABLE: - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top + def testDFIDFExperimentalNoData(self): + input_data = [{"a": ""}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + # Input data is completely empty so need to specify vocab_size explicitly + preprocessing_fn = self._get_dfidf_experimental_preprocessing_fn( + is_str_input=True, vocab_size=6 + ) + expected_transformed_data = [{"df": [], "idf": []}] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_transformed_data, + 
expected_metadata, + ) -_SCALE_TO_Z_SCORE_TEST_CASES = [ - dict(testcase_name='int16', - input_data=np.array([[1], [1], [2], [2]], np.int16), - output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), - elementwise=False), - dict(testcase_name='int32', - input_data=np.array([[1], [1], [2], [2]], np.int32), - output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), - elementwise=False), - dict(testcase_name='int64', - input_data=np.array([[1], [1], [2], [2]], np.int64), - output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), - elementwise=False), - dict(testcase_name='float32', - input_data=np.array([[1], [1], [2], [2]], np.float32), - output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float32), - elementwise=False), - dict(testcase_name='float64', - input_data=np.array([[1], [1], [2], [2]], np.float64), - output_data=np.array([[-1.0], [-1.0], [1.0], [1.0]], np.float64), - elementwise=False), - dict(testcase_name='vector', - input_data=np.array([[1, 2], [3, 4]], np.float32), - output_data=np.array([[-3, -1], [1, 3]] / np.sqrt(5.0), np.float32), - elementwise=False), - dict(testcase_name='vector_elementwise', - input_data=np.array([[1, 2], [3, 4]], np.float32), - output_data=np.array([[-1.0, -1.0], [1.0, 1.0]], np.float32), - elementwise=True), - dict(testcase_name='zero_variance', - input_data=np.array([[3], [3], [3], [3]], np.float32), - output_data=np.array([[0], [0], [0], [0]], np.float32), - elementwise=False), - dict(testcase_name='zero_variance_elementwise', - input_data=np.array([[3, 4], [3, 4]], np.float32), - output_data=np.array([[0, 0], [0, 0]], np.float32), - elementwise=True), -] + def testStringToDFIDFExperimentalEmptyDoc(self): + input_data = [ + {"a": "hello hello world"}, + {"a": ""}, + {"a": "hello goodbye world"}, + {"a": "I like pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) -_SCALE_TO_Z_SCORE_NAN_TEST_CASES = [ - dict( - testcase_name='with_nans', - input_data=np.array([[1], [np.nan], [np.nan], [2]], np.float32), - output_data=np.array([[-1.0], [np.nan], [np.nan], [1.0]], np.float32), - elementwise=False), - dict( - testcase_name='with_nans_elementwise', - input_data=np.array([[1, np.nan], [np.nan, 2]], np.float32), - output_data=np.array([[0, np.nan], [np.nan, 0]], np.float32), - elementwise=True), -] + log_5_over_2_plus_1 = 1.91629073187 + log_5_over_3_plus_1 = 1.51082562376 + expected_transformed_data = [ + { + "df": [2, 2, 2], + "idf": [log_5_over_3_plus_1, log_5_over_3_plus_1, log_5_over_3_plus_1], + }, + {"df": [], "idf": []}, + { + "df": [2, 1, 2], + "idf": [log_5_over_3_plus_1, log_5_over_2_plus_1, log_5_over_3_plus_1], + }, + { + "df": [1, 1, 1], + "idf": [log_5_over_2_plus_1, log_5_over_2_plus_1, log_5_over_2_plus_1], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + self._get_dfidf_experimental_preprocessing_fn(is_str_input=True), + expected_transformed_data, + expected_metadata, + ) + def testDFIDFExperimentalWithOOV(self): + input_data = [ + {"a": "hello world hi"}, + {"a": "hello goodbye world"}, + {"a": "I like pie pie"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) -def _sigmoid(x): - return 1 / (1 + np.exp(-x)) + preprocessing_fn_w_oov = self._get_dfidf_experimental_preprocessing_fn( + 
is_str_input=True, vocab_size=4, top_k=3 + ) + # smoothed base IDFs + # hello = log(4/3) + # pie = log(4/2) + # world = log(4/3) + # OOV - hi, goodbye, I, like = log(4/4) = 0 (OOV in all 3 out of 3 docs) + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + expected_transformed_data = [ + {"df": [2, 2, 3], "idf": [log_4_over_3_plus_1, log_4_over_3_plus_1, 1.0]}, + {"df": [2, 3, 2], "idf": [log_4_over_3_plus_1, 1.0, log_4_over_3_plus_1]}, + { + "df": [3, 3, 1, 1], + "idf": [1.0, 1.0, log_4_over_2_plus_1, log_4_over_2_plus_1], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn_w_oov, + expected_transformed_data, + expected_metadata, + ) + @tft_unit.named_parameters( + dict(testcase_name="IntInputSmoothBasaeline", smooth=True, add_baseline=True), + dict( + testcase_name="IntInputSmoothWOBasaeline", smooth=True, add_baseline=False + ), + dict( + testcase_name="IntInputNoneSmoothBasaeline", smooth=False, add_baseline=True + ), + dict( + testcase_name="IntInputNoneSmoothWOBasaeline", + smooth=False, + add_baseline=False, + ), + ) + def testIntToDFIDFExpeirmental(self, smooth, add_baseline): + input_data = [ + {"a": [2, 2, 0]}, + {"a": [2, 6, 2, 0]}, + {"a": [8, 10, 12, 12, 12]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.VarLenFeature(tf.int64)} + ) + log_4_over_2 = 0.69314718056 + log_4_over_3 = 0.28768207245 + log_3 = 1.09861228867 + log_3_over_2 = 0.4054651081 -def sum_output_dtype(input_dtype): - """Returns the output dtype for tft.sum.""" - return input_dtype if input_dtype.is_floating else tf.int64 + if smooth: + idf1, idf2 = log_4_over_2, log_4_over_3 + else: + idf1, idf2 = log_3, log_3_over_2 + + if add_baseline: + idf1 += 1.0 + idf2 += 1.0 + + expected_data = [ + { + "df": [2, 2, 2], + "idf": [idf2, idf2, idf2], + }, + { + "df": [2, 1, 2, 2], + "idf": [idf2, idf1, idf2, idf2], + }, + {"df": [1, 1, 1, 1, 1], "idf": [idf1, idf1, idf1, idf1, idf1]}, + ] + expected_schema = tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + self._get_dfidf_experimental_preprocessing_fn( + vocab_size=13, smooth=smooth, add_baseline=add_baseline + ), + expected_data, + expected_schema, + ) + def testDFIDFExperimentalWithNegatives(self): + input_data = [ + {"a": [2, 2, -4]}, + {"a": [2, 6, 2, -1]}, + {"a": [8, 10, 12, 12, 12]}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.VarLenFeature(tf.int64)} + ) -def _mean_output_dtype(input_dtype): - """Returns the output dtype for tft.mean (and similar functions).""" - return tf.float64 if input_dtype == tf.float64 else tf.float32 + log_4_over_2_plus_1 = 1.69314718056 + log_4_over_3_plus_1 = 1.28768207245 + # NOTE: -4 mod 14 = 10, -1 mod 14 = 13 + expected_transformed_data = [ + { + "df": [2, 2, 2], + "idf": [log_4_over_3_plus_1, log_4_over_3_plus_1, log_4_over_3_plus_1], + }, + { + "df": [2, 1, 2, 1], + "idf": [ + log_4_over_3_plus_1, + log_4_over_2_plus_1, + log_4_over_3_plus_1, + log_4_over_2_plus_1, + ], + }, + { + "df": [1, 2, 1, 1, 1], + "idf": [ + log_4_over_2_plus_1, + log_4_over_3_plus_1, + log_4_over_2_plus_1, + log_4_over_2_plus_1, + log_4_over_2_plus_1, + ], + }, + ] + expected_metadata = 
tft.DatasetMetadata.from_feature_spec( + { + "df": tf.io.VarLenFeature(tf.int64), + "idf": tf.io.VarLenFeature(tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + self._get_dfidf_experimental_preprocessing_fn(vocab_size=14), + expected_transformed_data, + expected_metadata, + ) + def testCovarianceTwoDimensions(self): + self._SkipIfOutputRecordBatches() -class BeamImplTest(tft_unit.TransformTestCase): + def analyzer_fn(inputs): + return {"y": tft.covariance(inputs["x"], dtype=tf.float32)} - def setUp(self): - super().setUp() - tf.compat.v1.logging.info('Starting test case: %s', self._testMethodName) - self._context = tft_beam.Context(use_deep_copy_optimization=True) - self._context.__enter__() - - def tearDown(self): - super().tearDown() - self._context.__exit__() - - def _OutputRecordBatches(self): - return False - - def _SkipIfOutputRecordBatches(self): - if self._OutputRecordBatches(): - raise unittest.SkipTest( - 'Test is disabled when TFT outputs `pa.RecordBatch`es to avoid ' - 'duplicated testing: it does not exercise `TransformDataset` or ' - '`AnalyzeAndTransformDataset`.') - - # Overrides that automatically pass the proper value for - # `output_record_batches`. - def assertAnalyzeAndTransformResults(self, *args, **kwargs): - kwargs['output_record_batches'] = self._OutputRecordBatches() - return super().assertAnalyzeAndTransformResults(*args, **kwargs) - - def assertAnalyzerOutputs(self, *args, **kwargs): - kwargs['output_record_batches'] = self._OutputRecordBatches() - return super().assertAnalyzerOutputs(*args, **kwargs) - - def _MakeTransformOutputAssertFn(self, expected, sort=False): - - def _assert_fn(actual): - if sort: - dict_key_fn = lambda d: sorted(d.items()) - expected_sorted = sorted(expected, key=dict_key_fn) - actual_sorted = sorted(actual, key=dict_key_fn) - self.assertCountEqual(expected_sorted, actual_sorted) - else: - self.assertCountEqual(expected, actual) - - return _assert_fn - - def testApplySavedModelSingleInput(self): - def save_model_with_single_input(instance, export_dir): - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with instance.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput1') - initializer = tf.compat.v1.constant_initializer([1, 2, 3]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [3], dtype=tf.int64) - output1 = tf.add(v1, input1, name='myadd1') - inputs = {'single_input': input1} - outputs = {'single_output': output1} - signature_def_map = { - 'serving_default': - tf.compat.v1.saved_model.signature_def_utils - .predict_signature_def(inputs, outputs) - } - sess.run(tf.compat.v1.global_variables_initializer()) - builder.add_meta_graph_and_variables( - sess, [tf.saved_model.SERVING], - signature_def_map=signature_def_map) - builder.save(False) - - export_dir = os.path.join(self.get_temp_dir(), 'saved_model_single') - - def preprocessing_fn(inputs): - x = inputs['x'] - output_col = pretrained_models.apply_saved_model( - export_dir, x, tags=[tf.saved_model.SERVING]) - return {'out': output_col} - - save_model_with_single_input(self, export_dir) - input_data = [ - {'x': [1, 2, 3]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([3], tf.int64), - }) - # [1, 2, 3] + [1, 2, 3] = [2, 4, 6] - expected_data = [ - {'out': [2, 4, 6]} - ] - 
expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'out': tf.io.FixedLenFeature([3], tf.int64)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testApplySavedModelWithHashTable(self): - def save_model_with_hash_table(instance, export_dir): - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with instance.test_session(graph=graph) as sess: - key = tf.constant('test_key', shape=[1]) - value = tf.constant('test_value', shape=[1]) - table = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(key, value), '__MISSING__') - - input1 = tf.compat.v1.placeholder( - dtype=tf.string, shape=[1], name='myinput') - output1 = tf.reshape(table.lookup(input1), shape=[1]) - inputs = {'input': input1} - outputs = {'output': output1} - - signature_def_map = { - 'serving_default': - tf.compat.v1.saved_model.signature_def_utils - .predict_signature_def(inputs, outputs) - } + input_data = [{"x": x} for x in [[0, 0], [4, 0], [2, -2], [2, 2]]] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([2], tf.float32)} + ) + expected_outputs = {"y": np.array([[2, 0], [0, 2]], np.float32)} + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) - sess.run(tf.compat.v1.tables_initializer()) - builder.add_meta_graph_and_variables( - sess, [tf.saved_model.SERVING], - signature_def_map=signature_def_map) - builder.save(False) - - export_dir = os.path.join(self.get_temp_dir(), 'saved_model_hash_table') - - def preprocessing_fn(inputs): - x = inputs['x'] - output_col = pretrained_models.apply_saved_model( - export_dir, x, tags=[tf.saved_model.SERVING]) - return {'out': output_col} - - save_model_with_hash_table(self, export_dir) - input_data = [ - {'x': ['test_key']} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([1], tf.string), - }) - expected_data = [ - {'out': b'test_value'} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'out': tf.io.FixedLenFeature([], tf.string)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testApplySavedModelMultiInputs(self): - - def save_model_with_multi_inputs(instance, export_dir): - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with instance.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput1') - input2 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput2') - input3 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput3') - initializer = tf.compat.v1.constant_initializer([1, 2, 3]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [3], dtype=tf.int64) - o1 = tf.add(v1, input1, name='myadd1') - o2 = tf.subtract(o1, input2, name='mysubtract1') - output1 = tf.add(o2, input3, name='myadd2') - inputs = {'name1': input1, 'name2': input2, - 'name3': input3} - outputs = {'single_output': output1} - signature_def_map = { - 'serving_default': - tf.compat.v1.saved_model.signature_def_utils - .predict_signature_def(inputs, outputs) - } - sess.run(tf.compat.v1.global_variables_initializer()) - builder.add_meta_graph_and_variables( - sess, 
[tf.saved_model.SERVING], - signature_def_map=signature_def_map) - builder.save(False) - - export_dir = os.path.join(self.get_temp_dir(), 'saved_model_multi') - - def preprocessing_fn(inputs): - x = inputs['x'] - y = inputs['y'] - z = inputs['z'] - sum_column = pretrained_models.apply_saved_model( - export_dir, { - 'name1': x, - 'name3': z, - 'name2': y - }, - tags=[tf.saved_model.SERVING]) - return {'sum': sum_column} - - save_model_with_multi_inputs(self, export_dir) - input_data = [ - {'x': [1, 2, 3], 'y': [2, 3, 4], 'z': [1, 1, 1]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([3], tf.int64), - 'y': tf.io.FixedLenFeature([3], tf.int64), - 'z': tf.io.FixedLenFeature([3], tf.int64), - }) - # [1, 2, 3] + [1, 2, 3] - [2, 3, 4] + [1, 1, 1] = [1, 2, 3] - expected_data = [ - {'sum': [1, 2, 3]} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'sum': tf.io.FixedLenFeature([3], tf.int64)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testApplyFunctionWithCheckpoint(self): - - def tensor_fn(input1, input2): - initializer = tf.compat.v1.constant_initializer([1, 2, 3]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [3], dtype=tf.int64) - v2 = tf.compat.v1.get_variable('v2', [3], dtype=tf.int64) - o1 = tf.add(v1, v2, name='add1') - o2 = tf.subtract(o1, input1, name='sub1') - o3 = tf.subtract(o2, input2, name='sub2') - return o3 - - def save_checkpoint(instance, checkpoint_path): - with tf.compat.v1.Graph().as_default() as graph: - with instance.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput1') - input2 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinput2') - tensor_fn(input1, input2) - saver = tf.compat.v1.train.Saver() - sess.run(tf.compat.v1.global_variables_initializer()) - saver.save(sess, checkpoint_path) - - checkpoint_path = os.path.join(self.get_temp_dir(), 'chk') - - def preprocessing_fn(inputs): - x = inputs['x'] - y = inputs['y'] - out_value = pretrained_models.apply_function_with_checkpoint( - tensor_fn, [x, y], checkpoint_path) - return {'out': out_value} - - save_checkpoint(self, checkpoint_path) - input_data = [ - {'x': [2, 2, 2], 'y': [-1, -3, 1]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([3], tf.int64), - 'y': tf.io.FixedLenFeature([3], tf.int64), - }) - # [1, 2, 3] + [1, 2, 3] - [2, 2, 2] - [-1, -3, 1] = [1, 5, 3] - expected_data = [ - {'out': [1, 5, 3]} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'out': tf.io.FixedLenFeature([3], tf.int64)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.named_parameters( - dict(testcase_name='NoDeepCopy', with_deep_copy=False), - dict(testcase_name='WithDeepCopy', with_deep_copy=True), - ) - def testMultipleLevelsOfAnalyzers(self, with_deep_copy): - # Test a preprocessing function similar to scale_to_0_1 except that it - # involves multiple interleavings of analyzers and transforms. 
- def preprocessing_fn(inputs): - scaled_to_0 = inputs['x'] - tft.min(inputs['x']) - scaled_to_0_1 = scaled_to_0 / tft.max(scaled_to_0) - return {'x_scaled': scaled_to_0_1} - - input_data = [{'x': 4}, {'x': 1}, {'x': 5}, {'x': 2}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - expected_data = [ - {'x_scaled': 0.75}, - {'x_scaled': 0.0}, - {'x_scaled': 1.0}, - {'x_scaled': 0.25} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([], tf.float32)}) - with tft_beam.Context(use_deep_copy_optimization=with_deep_copy): - # NOTE: In order to correctly test deep_copy here, we can't pass test_data - # to assertAnalyzeAndTransformResults. - # Not passing test_data to assertAnalyzeAndTransformResults means that - # tft.AnalyzeAndTransform is called, exercising the right code path. - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testRawFeedDictInput(self): - # Test the ability to feed raw data into AnalyzeDataset and TransformDataset - # by using subclasses of these transforms which create batches of size 1. - def preprocessing_fn(inputs): - sequence_example = inputs['sequence_example'] - - # Ordinarily this would have shape (batch_size,) since 'sequence_example' - # was defined as a FixedLenFeature with shape (). But since we specified - # desired_batch_size, we can assume that the shape is (1,), and reshape - # to (). - sequence_example = tf.reshape(sequence_example, ()) - - # Parse the sequence example. - feature_spec = { - 'x': - tf.io.FixedLenSequenceFeature( - shape=[], dtype=tf.string, default_value=None) - } - _, sequences = tf.io.parse_single_sequence_example( - sequence_example, sequence_features=feature_spec) - - # Create a batch based on the sequence "x". 
- return {'x': sequences['x']} - - def text_sequence_example_to_binary(text_proto): - proto = text_format.Merge(text_proto, tf.train.SequenceExample()) - return proto.SerializeToString() - - sequence_examples = [ - """ - feature_lists: { - feature_list: { - key: "x" - value: { - feature: {bytes_list: {value: 'ab'}} - feature: {bytes_list: {value: ''}} - feature: {bytes_list: {value: 'c'}} - feature: {bytes_list: {value: 'd'}} + def testCovarianceOneDimension(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return {"y": tft.covariance(inputs["x"], dtype=tf.float32)} + + input_data = [{"x": x} for x in [[0], [2], [4], [6]]] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.float32)} + ) + expected_outputs = {"y": np.array([[5]], np.float32)} + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testCovarianceOneDimensionWithEmptyInputs(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return {"y": tft.covariance(inputs["x"], dtype=tf.float32)} + + input_data = [] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([1], tf.float32)} + ) + test_data = [{"x": [1]}, {"x": [2]}] + expected_outputs = {"y": np.array([[0]], dtype=np.float32)} + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + test_data=test_data, + ) + + def testPCAThreeToTwoDimensions(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return {"y": tft.pca(inputs["x"], 2, dtype=tf.float32)} + + input_data = [{"x": x} for x in [[0, 0, 1], [4, 0, 1], [2, -1, 1], [2, 1, 1]]] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([3], tf.float32)} + ) + expected_outputs = {"y": np.array([[1, 0], [0, 1], [0, 0]], np.float32)} + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + def testPCAThreeToTwoDimensionsWithEmptyInputs(self): + self._SkipIfOutputRecordBatches() + + def analyzer_fn(inputs): + return {"y": tft.pca(inputs["x"], 2, dtype=tf.float32)} + + input_data = [] + test_data = [{"x": x} for x in [[0, 0, 1], [4, 0, 1], [2, -1, 1], [2, 1, 1]]] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([3], tf.float32)} + ) + expected_outputs = {"y": np.array([[1, 0], [0, 1], [0, 0]], np.float32)} + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + test_data=test_data, + ) + + class _SumCombiner(tft_beam.experimental.PTransformAnalyzer): + def __init__(self): + super().__init__() + self.base_temp_dir_in_expand = None + + def _extract_outputs(self, sums): + return [ + beam.pvalue.TaggedOutput("0", sums[0]), + beam.pvalue.TaggedOutput("1", sums[1]), + ] + + def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]): + self.base_temp_dir_in_expand = self.base_temp_dir + return ( + pcoll + | beam.FlatMap(lambda batches: list(zip(*batches))) + | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0)) + | beam.FlatMap(self._extract_outputs).with_outputs("0", "1") + ) + + def testPTransformAnalyzer(self): + self._SkipIfOutputRecordBatches() + + sum_combiner = self._SumCombiner() + + def analyzer_fn(inputs): + outputs = tft.experimental.ptransform_analyzer( + [inputs["x"], inputs["y"]], sum_combiner, [tf.int64, tf.int64], [[], []] + ) + return {"x_sum": outputs[0], "y_sum": outputs[1]} + + input_data = [{"x": 1, "y": i} 
for i in range(100)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.int64), + "y": tf.io.FixedLenFeature([], tf.int64), } - } + ) + expected_outputs = { + "x_sum": np.array(100, np.int64), + "y_sum": np.array(4950, np.int64), } - """, - """ - feature_lists: { - feature_list: { - key: "x" - value: { - feature: {bytes_list: {value: 'ef'}} - feature: {bytes_list: {value: 'g'}} + self.assertIsNone(sum_combiner.base_temp_dir_in_expand) + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + self.assertIsNotNone(sum_combiner.base_temp_dir_in_expand) + self.assertStartsWith(sum_combiner.base_temp_dir_in_expand, self.get_temp_dir()) + + @tft_unit.named_parameters( + dict(testcase_name="ArrayOutput", output_fn=lambda x: np.array(x, np.int64)), + dict(testcase_name="ListOutput", output_fn=list), + ) + def testPTransformAnalyzerMultiDimOutput(self, output_fn): + self._SkipIfOutputRecordBatches() + + class _SimpleSumCombiner(tft_beam.experimental.PTransformAnalyzer): + def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]): + return ( + pcoll + | beam.FlatMap(lambda batches: list(zip(*batches))) + | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0)) + | beam.combiners.ToList() + | beam.Map(output_fn) + ) + + sum_combiner = _SimpleSumCombiner() + + def analyzer_fn(inputs): + (outputs,) = tft.experimental.ptransform_analyzer( + [inputs["x"], inputs["y"]], sum_combiner, [tf.int64], [[1, 2]] + ) + return {"x_y_sums": outputs} + + input_data = [{"x": 1, "y": i} for i in range(100)] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.int64), + "y": tf.io.FixedLenFeature([], tf.int64), } - } - } - """ - ] - input_data = [ - {'sequence_example': text_sequence_example_to_binary(sequence_example)} - for sequence_example in sequence_examples] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'sequence_example': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [ - {'x': b'ab'}, - {'x': b''}, - {'x': b'c'}, - {'x': b'd'}, - {'x': b'ef'}, - {'x': b'g'} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata, desired_batch_size=1) - - def testTransformWithExcludedOutputs(self): - def preprocessing_fn(inputs): - return { - 'x_scaled': tft.scale_to_0_1(inputs['x']), - 'y_scaled': tft.scale_to_0_1(inputs['y']) - } - - # Run AnalyzeAndTransform on some input data and compare with expected - # output. - input_data = [{'x': 5, 'y': 1}, {'x': 1, 'y': 2}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32) - }) - with tft_beam.Context(temp_dir=self.get_temp_dir()): - transform_fn = ((input_data, input_metadata) - | tft_beam.AnalyzeDataset(preprocessing_fn)) - - # Take the transform function and use TransformDataset to apply it to - # some eval data, with missing 'y' column. 
- eval_data = [{'x': 6}] - eval_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - transformed_eval_data, transformed_eval_metadata = ( - ((eval_data, eval_metadata), transform_fn) - | tft_beam.TransformDataset( - exclude_outputs=['y_scaled'], - output_record_batches=self._OutputRecordBatches())) - - if self._OutputRecordBatches(): - expected_transformed_eval_data = {'x_scaled': [[1.25]]} - self.assertLen(transformed_eval_data, 1) - # Contains RecordBatch and unary pass-through features dict. - self.assertLen(transformed_eval_data[0], 2) - self.assertDictEqual(transformed_eval_data[0][0].to_pydict(), - expected_transformed_eval_data) - self.assertDictEqual(transformed_eval_data[0][1], {}) - else: - expected_transformed_eval_data = [{'x_scaled': 1.25}] - self.assertDataCloseOrEqual(transformed_eval_data, - expected_transformed_eval_data) - expected_transformed_eval_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([], tf.float32)}) - self.assertEqual(transformed_eval_metadata.dataset_metadata, - expected_transformed_eval_metadata) - - def testMapWithCond(self): - def preprocessing_fn(inputs): - return { - 'a': - tf.cond( - pred=tf.constant(True), - true_fn=lambda: inputs['a'], - false_fn=lambda: inputs['b']) - } - - input_data = [ - {'a': 4, 'b': 3}, - {'a': 1, 'b': 2}, - {'a': 5, 'b': 6}, - {'a': 2, 'b': 3} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.float32), - 'b': tf.io.FixedLenFeature([], tf.float32) - }) - expected_data = [ - {'a': 4}, - {'a': 1}, - {'a': 5}, - {'a': 2} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.float32)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testPyFuncs(self): - def my_multiply(x, y): - return x*y - - def my_add(x, y): - return x+y - - def my_list_return(x, y): - return [x, y, 2 * x, 2 * y] - - def preprocessing_fn(inputs): - result = { - 'a+b': - tft.apply_pyfunc(my_add, tf.float32, True, 'add', inputs['a'], - inputs['b']), - 'a+c': - tft.apply_pyfunc(my_add, tf.float32, True, 'add', inputs['a'], - inputs['c']), - 'ab': - tft.apply_pyfunc(my_multiply, tf.float32, False, 'multiply', - inputs['a'], inputs['b']), - 'sum_scaled': - tft.scale_to_0_1( - tft.apply_pyfunc(my_add, tf.float32, True, 'add', inputs['a'], - inputs['c'])), - 'list': - tf.reduce_sum( - tft.apply_pyfunc( - my_list_return, - [tf.float32, tf.float32, tf.float32, tf.float32], True, - 'my_list_return', inputs['a'], inputs['b']), - axis=0), - } - for value in result.values(): - value.set_shape([1,]) - return result - - input_data = [ - {'a': 4, 'b': 3, 'c': 2}, - {'a': 1, 'b': 2, 'c': 3}, - {'a': 5, 'b': 6, 'c': 7}, - {'a': 2, 'b': 3, 'c': 4} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.float32), - 'b': tf.io.FixedLenFeature([], tf.float32), - 'c': tf.io.FixedLenFeature([], tf.float32) - }) - expected_data = [ - {'ab': 12, 'a+b': 7, 'a+c': 6, 'list': 21, 'sum_scaled': 0.25}, - {'ab': 2, 'a+b': 3, 'a+c': 4, 'list': 9, 'sum_scaled': 0}, - {'ab': 30, 'a+b': 11, 'a+c': 12, 'list': 33, 'sum_scaled': 1}, - {'ab': 6, 'a+b': 5, 'a+c': 6, 'list': 15, 'sum_scaled': 0.25} - ] - # When calling tf.py_func, the output shape is set to unknown. 
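# A small numpy check of the 'sum_scaled' expectations above, assuming
# tft.scale_to_0_1 maps v to (v - min(v)) / (max(v) - min(v)) over the
# analysis dataset:
import numpy as np

a_plus_c = np.array([6.0, 4.0, 12.0, 6.0])  # a + c for the four instances
print((a_plus_c - a_plus_c.min()) / (a_plus_c.max() - a_plus_c.min()))
# [0.25 0.   1.   0.25], matching 'sum_scaled' in expected_data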
- expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'ab': tf.io.FixedLenFeature([], tf.float32), - 'a+b': tf.io.FixedLenFeature([], tf.float32), - 'a+c': tf.io.FixedLenFeature([], tf.float32), - 'list': tf.io.FixedLenFeature([], tf.float32), - 'sum_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata, force_tf_compat_v1=True) - - def testAssertsNoReturnPyFunc(self): - # Asserts that apply_pyfunc raises an exception if the passed function does - # not return anything. - self._SkipIfOutputRecordBatches() - - def bad_func(): - return None - - with self.assertRaises(ValueError): - tft.apply_pyfunc(bad_func, [], False, 'bad_func') - - def testWithMoreThanDesiredBatchSize(self): - def preprocessing_fn(inputs): - return { - 'ab': tf.multiply(inputs['a'], inputs['b']), - 'i': tft.compute_and_apply_vocabulary(inputs['c']) - } - - batch_size = 100 - num_instances = batch_size + 1 - # pylint: disable=g-complex-comprehension - input_data = [{ - 'a': 2, - 'b': i, - 'c': '%.10i' % i, # Front-padded to facilitate lexicographic sorting. - } for i in range(num_instances)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.float32), - 'b': tf.io.FixedLenFeature([], tf.float32), - 'c': tf.io.FixedLenFeature([], tf.string) - }) - expected_data = [{ - 'ab': 2*i, - 'i': (len(input_data) - 1) - i, # Due to reverse lexicographic sorting. - } for i in range(len(input_data))] - # pylint: enable=g-complex-comprehension - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'ab': tf.io.FixedLenFeature([], tf.float32), - 'i': tf.io.FixedLenFeature([], tf.int64), - }, { - 'i': - schema_pb2.IntDomain( - min=-1, max=num_instances - 1, is_categorical=True) - }) - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - desired_batch_size=batch_size, - force_tf_compat_v1=True) - - def testWithUnicode(self): - def preprocessing_fn(inputs): - return {'a b': tf.compat.v1.strings.join( - [inputs['a'], inputs['b']], separator=' ')} - - input_data = [{'a': 'Hello', 'b': 'world'}, {'a': 'Hello', 'b': u'κόσμε'}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.string), - 'b': tf.io.FixedLenFeature([], tf.string), - }) - expected_data = [ - {'a b': b'Hello world'}, - {'a b': u'Hello κόσμε'.encode('utf-8')} - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'a b': tf.io.FixedLenFeature([], tf.string)}) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_metadata) - - def testNpArrayInput(self): - - def preprocessing_fn(inputs): - return {'a b': tf.compat.v1.strings.join( - [inputs['a'], inputs['b']], separator=' ')} - - input_data = [{ - 'a': np.array('Hello', dtype=object), - 'b': np.array('world', dtype=object) - }, { - 'a': np.array('Hello', dtype=object), - 'b': np.array(u'κόσμε', dtype=object) - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.string), - 'b': tf.io.FixedLenFeature([], tf.string), - }) - expected_data = [{ - 'a b': np.array(b'Hello world', dtype=object) - }, { - 'a b': np.array(u'Hello κόσμε'.encode('utf-8'), dtype=object) - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'a b': tf.io.FixedLenFeature([], tf.string)}) - 
self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.parameters((True,), (False,)) - def testScaleUnitInterval(self, elementwise): - - def preprocessing_fn(inputs): - outputs = {} - stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1) - result = tft.scale_to_0_1(stacked_input, elementwise=elementwise) - outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1) - return outputs - - input_data = [{ - 'x': 4, - 'y': 5 - }, { - 'x': 1, - 'y': 2 - }, { - 'x': 5, - 'y': 6 - }, { - 'x': 2, - 'y': 3 - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32) - }) - if elementwise: - expected_data = [{ - 'x_scaled': 0.75, - 'y_scaled': 0.75 - }, { - 'x_scaled': 0.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 1.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': 0.25, - 'y_scaled': 0.25 - }] - else: - expected_data = [{ - 'x_scaled': 0.6, - 'y_scaled': 0.8 - }, { - 'x_scaled': 0.0, - 'y_scaled': 0.2 - }, { - 'x_scaled': 0.8, - 'y_scaled': 1.0 - }, { - 'x_scaled': 0.2, - 'y_scaled': 0.4 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.parameters((True,), (False,)) - def testScaleUnitIntervalPerKey(self, elementwise): - - def preprocessing_fn(inputs): - outputs = {} - stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1) - result = tft.scale_to_0_1_per_key( - stacked_input, inputs['key'], elementwise) - outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1) - return outputs - - input_data = [{ - 'x': 4, - 'y': 5, - 'key': 'a' - }, { - 'x': 1, - 'y': 2, - 'key': 'a' - }, { - 'x': 5, - 'y': 6, - 'key': 'a' - }, { - 'x': 2, - 'y': 3, - 'key': 'a' - }, { - 'x': 25, - 'y': -25, - 'key': 'b' - }, { - 'x': 5, - 'y': 0, - 'key': 'b' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - if elementwise: - expected_data = [{ - 'x_scaled': 0.75, - 'y_scaled': 0.75 - }, { - 'x_scaled': 0.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 1.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': 0.25, - 'y_scaled': 0.25 - }, { - 'x_scaled': 1.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 0.0, - 'y_scaled': 1.0 - }] - else: - expected_data = [{ - 'x_scaled': 0.6, - 'y_scaled': 0.8 - }, { - 'x_scaled': 0.0, - 'y_scaled': 0.2 - }, { - 'x_scaled': 0.8, - 'y_scaled': 1.0 - }, { - 'x_scaled': 0.2, - 'y_scaled': 0.4 - }, { - 'x_scaled': 1.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 0.6, - 'y_scaled': 0.5 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.parameters((True,), (False,)) - def testScaleMinMax(self, elementwise): - def preprocessing_fn(inputs): - outputs = {} - stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1) - result = tft.scale_by_min_max( - stacked_input, output_min=-1, output_max=1, elementwise=elementwise) - outputs['x_scaled'], outputs['y_scaled'] 
= tf.unstack(result, axis=1) - return outputs - - input_data = [{ - 'x': 4, - 'y': 8 - }, { - 'x': 1, - 'y': 5 - }, { - 'x': 5, - 'y': 9 - }, { - 'x': 2, - 'y': 6 - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32) - }) - if elementwise: - expected_data = [{ - 'x_scaled': 0.5, - 'y_scaled': 0.5 - }, { - 'x_scaled': -1.0, - 'y_scaled': -1.0 - }, { - 'x_scaled': 1.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': -0.5, - 'y_scaled': -0.5 - }] - else: - expected_data = [{ - 'x_scaled': -0.25, - 'y_scaled': 0.75 - }, { - 'x_scaled': -1.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 0.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': -0.75, - 'y_scaled': 0.25 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.named_parameters( - dict( - testcase_name='_empty_filename', - elementwise=False, - key_vocabulary_filename=''), - dict( - testcase_name='_nonempty_filename', - elementwise=False, - key_vocabulary_filename='per_key'), - dict( - testcase_name='_none_filename', - elementwise=False, - key_vocabulary_filename=None), - dict( - testcase_name='_elementwise_none_filename', - elementwise=True, - key_vocabulary_filename=None)) - def testScaleMinMaxPerKey(self, elementwise, key_vocabulary_filename): - - def preprocessing_fn(inputs): - outputs = {} - stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1) - result = tft.scale_by_min_max_per_key( - stacked_input, - inputs['key'], - output_min=-1, - output_max=1, - elementwise=elementwise, - key_vocabulary_filename=key_vocabulary_filename) - outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1) - return outputs - - input_data = [{ - 'x': 4, - 'y': 8, - 'key': 'a' - }, { - 'x': 1, - 'y': 5, - 'key': 'a' - }, { - 'x': 5, - 'y': 9, - 'key': 'a' - }, { - 'x': 2, - 'y': 6, - 'key': 'a' - }, { - 'x': -2, - 'y': 0, - 'key': 'b' - }, { - 'x': 0, - 'y': 2, - 'key': 'b' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - if elementwise: - expected_data = [{ - 'x_scaled': 0.5, - 'y_scaled': 0.5 - }, { - 'x_scaled': -1.0, - 'y_scaled': -1.0 - }, { - 'x_scaled': 1.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': -0.5, - 'y_scaled': -0.5 - }, { - 'x_scaled': -1.0, - 'y_scaled': -1.0 - }, { - 'x_scaled': 1.0, - 'y_scaled': 1.0 - }] - else: - expected_data = [{ - 'x_scaled': -0.25, - 'y_scaled': 0.75 - }, { - 'x_scaled': -1.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 0.0, - 'y_scaled': 1.0 - }, { - 'x_scaled': -0.75, - 'y_scaled': 0.25 - }, { - 'x_scaled': -1.0, - 'y_scaled': 0.0 - }, { - 'x_scaled': 0.0, - 'y_scaled': 1.0 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - if key_vocabulary_filename: - per_key_vocab_contents = { - key_vocabulary_filename: [(b'a', [-1.0, 9.0]), (b'b', [2.0, 2.0])] - } - else: - per_key_vocab_contents = None - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents=per_key_vocab_contents) - - def 
testScalePerKeySparse(self): - def preprocessing_fn(inputs): - return { - 'scaled_by_min_max': - tft.scale_by_min_max_per_key( - inputs['x'], inputs['key'], output_min=-1, output_max=1), - 'scaled_to_0_1': - tft.scale_to_0_1_per_key(inputs['x'], inputs['key']), - 'scaled_to_z_score': - tft.scale_to_z_score_per_key(inputs['x'], inputs['key']), - } - - input_data = [{ - 'val': [4, 8], - 's': ['a', 'a'] - }, { - 'val': [1, 5], - 's': ['a', 'a'] - }, { - 'val': [5, 9], - 's': ['a', 'a'] - }, { - 'val': [2, 6], - 's': ['a', 'a'] - }, { - 'val': [-2, 0], - 's': ['b', 'b'] - }, { - 'val': [0, 2], - 's': ['b', 'b'] - }] - indices = [([x % 2] * 2, [x % 3] * 2) for x in range(len(input_data))] - indices_x = [{'idx_x_0': a, 'idx_x_1': b} for a, b in indices] - indices_key = [{'idx_key_0': a, 'idx_key_1': b} for a, b in indices] - input_data = [{**a, **b, **c} - for a, b, c in zip(input_data, indices_x, indices_key)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature(['idx_x_0', 'idx_x_1'], 'val', tf.float32, - (2, 3)), - 'key': - tf.io.SparseFeature(['idx_key_0', 'idx_key_1'], 's', tf.string, - (2, 3)) - }) - - output_names = ['scaled_by_min_max', 'scaled_to_0_1', 'scaled_to_z_score'] - expected_indices_prefix = [ - (('$sparse_indices_0', a), ('$sparse_indices_1', b)) for a, b in indices - ] - expected_indices = [] - for idx0, idx1 in expected_indices_prefix: - instance = {} - for n in output_names: - instance.update({n + idx0[0]: idx0[1]}) - instance.update({n + idx1[0]: idx1[1]}) - expected_indices.append(instance) - - expected_data = [{ - 'scaled_by_min_max$sparse_values': [-0.25, 0.75], - 'scaled_to_0_1$sparse_values': - np.array([3. / 8., 7. / 8]), - 'scaled_to_z_score$sparse_values': - np.array([-1. / math.sqrt(6.5), 3. / math.sqrt(6.5)]) - }, { - 'scaled_by_min_max$sparse_values': [-1.0, 0.0], - 'scaled_to_0_1$sparse_values': np.array([0., 0.5]), - 'scaled_to_z_score$sparse_values': np.array([-4. / math.sqrt(6.5), 0.]), - }, { - 'scaled_by_min_max$sparse_values': [0.0, 1.0], - 'scaled_to_0_1$sparse_values': np.array([0.5, 1.]), - 'scaled_to_z_score$sparse_values': np.array([0., 4. / math.sqrt(6.5)]), - }, { - 'scaled_by_min_max$sparse_values': [-0.75, 0.25], - 'scaled_to_0_1$sparse_values': - np.array([1. / 8., 5. / 8.]), - 'scaled_to_z_score$sparse_values': - np.array([-3. / math.sqrt(6.5), 1. / math.sqrt(6.5)]), - }, { - 'scaled_by_min_max$sparse_values': np.array([-1., 0.]), - 'scaled_to_0_1$sparse_values': np.array([0., 0.5]), - 'scaled_to_z_score$sparse_values': np.array([-2. / math.sqrt(2), 0.]), - }, { - 'scaled_by_min_max$sparse_values': [0.0, 1.0], - 'scaled_to_0_1$sparse_values': np.array([0.5, 1.]), - 'scaled_to_z_score$sparse_values': np.array([0., 2. 
/ math.sqrt(2)]), - }] - expected_data = [{**a, **b} - for a, b in zip(expected_data, expected_indices)] - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - beam_pipeline=beam.Pipeline()) - - @tft_unit.named_parameters( - dict( - testcase_name='sparse_key', - input_data=[{ - 'idx': [0, 1], - 'val': [-4, 4], - 'key_idx': [0, 1], - 'key': ['a', 'a'] - }, { - 'idx': [0, 1], - 'val': [2, 1], - 'key_idx': [0, 1], - 'key': ['a', 'b'] - }, { - 'idx': [0, 1], - 'val': [-1, 4], - 'key_idx': [0, 1], - 'key': ['b', 'a'] - }], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature( - 'idx', 'val', - tft_unit.canonical_numeric_dtype(tf.float32), 4), - 'key': - tf.io.SparseFeature('key_idx', 'key', tf.string, 4) - }), - expected_data=[{ - 'x_scaled': [0., 1., 0, 0] - }, { - 'x_scaled': [.75, 1., 0, 0] - }, { - 'x_scaled': [0., 1., 0, 0] - }]), - dict( - testcase_name='dense_key', - input_data=[{ - 'idx': [0, 1], - 'val': [-4, 4], - 'key': 'a' - }, { - 'idx': [0, 1], - 'val': [2, 1], - 'key': 'a' - }, { - 'idx': [0, 1], - 'val': [-1, 4], - 'key': 'b' - }], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature( - 'idx', 'val', - tft_unit.canonical_numeric_dtype(tf.float32), 4), - 'key': - tf.io.FixedLenFeature([], tf.string) - }), - expected_data=[{ - 'x_scaled': [0., 1., 0, 0] - }, { - 'x_scaled': [.75, .625, 0, 0] - }, { - 'x_scaled': [0., 1., 0, 0] - }]), - ) - def testScaleMinMaxSparsePerKey( - self, input_data, input_metadata, expected_data): - def preprocessing_fn(inputs): - x_scaled = tf.sparse.to_dense( - tft.scale_to_0_1_per_key(inputs['x'], inputs['key'])) - x_scaled.set_shape([None, 4]) - return {'x_scaled': x_scaled} - - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([4], tf.float32)}) - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.named_parameters(*tft_unit.cross_named_parameters( - [ - dict( - testcase_name='dense_key', - input_data=[{ - 'x_val': [-4, 4], - 'x_row_lengths': [0, 2], - 'key': 'a', - }, { - 'x_val': [0, 1], - 'x_row_lengths': [1, 1], - 'key': 'a', - }, { - 'x_val': [-4, 1, 1], - 'x_row_lengths': [3], - 'key': 'b', - }], - make_key_spec=lambda: tf.io.FixedLenFeature([], tf.string), - expected_data=[{ - 'scaled_by_min_max$ragged_values': [-1., 1.], - 'scaled_by_min_max$row_lengths_1': [0, 2], - 'scaled_to_0_1$ragged_values': [0., 1.], - 'scaled_to_0_1$row_lengths_1': [0, 2], - 'scaled_to_z_score$ragged_values': [-1.4852968, 1.310556], - 'scaled_to_z_score$row_lengths_1': [0, 2], - }, { - 'scaled_by_min_max$ragged_values': [0., 0.25], - 'scaled_by_min_max$row_lengths_1': [1, 1], - 'scaled_to_0_1$ragged_values': [0.5, 0.625], - 'scaled_to_0_1$row_lengths_1': [1, 1], - 'scaled_to_z_score$ragged_values': [-0.0873704, 0.26211122], - 'scaled_to_z_score$row_lengths_1': [1, 1], - }, { - 'scaled_by_min_max$ragged_values': [-1., 1., 1.], - 'scaled_by_min_max$row_lengths_1': [3], - 'scaled_to_0_1$ragged_values': [0., 1., 1.], - 'scaled_to_0_1$row_lengths_1': [3], - 'scaled_to_z_score$ragged_values': - [-1.4142135, 0.7071068, 0.7071068], - 'scaled_to_z_score$row_lengths_1': [3] - }], - ), - dict( - testcase_name='ragged_key', - input_data=[{ - 'x_val': [-4, 4], - 'x_row_lengths': [0, 2], - 'key_val': ['a', 'a'], - 'key_row_lengths': [0, 2], - }, { - 'x_val': [0, 1], - 'x_row_lengths': [1, 1], - 'key_val': ['a', 'b'], - 
'key_row_lengths': [1, 1], - }, { - 'x_val': [-4, 1, 1], - 'x_row_lengths': [3], - 'key_val': ['b', 'a', 'b'], - 'key_row_lengths': [3], - }], - make_key_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tf.string, - value_key='key_val', - partitions=[ - tf.io.RaggedFeature.RowLengths('key_row_lengths') # pytype: disable=attribute-error - ]), - expected_data=[{ - 'scaled_by_min_max$ragged_values': [-1., 1.], - 'scaled_by_min_max$row_lengths_1': [0, 2], - 'scaled_to_0_1$ragged_values': [0., 1.], - 'scaled_to_0_1$row_lengths_1': [0, 2], - 'scaled_to_z_score$ragged_values': [-1.4852968, 1.310556], - 'scaled_to_z_score$row_lengths_1': [0, 2], - }, { - 'scaled_by_min_max$ragged_values': [0., 1.], - 'scaled_by_min_max$row_lengths_1': [1, 1], - 'scaled_to_0_1$ragged_values': [0.5, 1.], - 'scaled_to_0_1$row_lengths_1': [1, 1], - 'scaled_to_z_score$ragged_values': [-0.0873704, 0.7071068], - 'scaled_to_z_score$row_lengths_1': [1, 1], - }, { - 'scaled_by_min_max$ragged_values': [-1., 0.25, 1.], - 'scaled_by_min_max$row_lengths_1': [3], - 'scaled_to_0_1$ragged_values': [0., 0.625, 1.], - 'scaled_to_0_1$row_lengths_1': [3], - 'scaled_to_z_score$ragged_values': - [-1.4142135, 0.26211122, 0.7071068], - 'scaled_to_z_score$row_lengths_1': [3] - }]), - ], - [ - dict(testcase_name='int16', input_dtype=tf.int16), - dict(testcase_name='int32', input_dtype=tf.int32), - dict(testcase_name='int64', input_dtype=tf.int64), - dict(testcase_name='float32', input_dtype=tf.float32), - dict(testcase_name='float64', input_dtype=tf.float64), - ])) - def testScalePerKeyRagged(self, input_data, make_key_spec, expected_data, - input_dtype): - make_x_spec = lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tft_unit.canonical_numeric_dtype(input_dtype), - value_key='x_val', - partitions=[ - tf.io.RaggedFeature.RowLengths('x_row_lengths') # pytype: disable=attribute-error - ]) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tft_unit.make_feature_spec_wrapper(make_x_spec), - 'key': tft_unit.make_feature_spec_wrapper(make_key_spec) - }) - - def preprocessing_fn(inputs): - scaled_to_z_score = tft.scale_to_z_score_per_key( - tf.cast(inputs['x'], input_dtype), inputs['key']) - self.assertEqual(scaled_to_z_score.dtype, _mean_output_dtype(input_dtype)) - return { - 'scaled_by_min_max': - tft.scale_by_min_max_per_key( - tf.cast(inputs['x'], input_dtype), - inputs['key'], - output_min=-1, - output_max=1), - 'scaled_to_0_1': - tft.scale_to_0_1_per_key( - tf.cast(inputs['x'], input_dtype), inputs['key']), - 'scaled_to_z_score': - tf.cast(scaled_to_z_score, tf.float32), - } - - expected_specs = {} - for output_name in ('scaled_by_min_max', 'scaled_to_0_1', - 'scaled_to_z_score'): - expected_specs[output_name] = tf.io.RaggedFeature( - tf.float32, - value_key='{}$ragged_values'.format(output_name), - partitions=[ - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error - '{}$row_lengths_1'.format(output_name)) - ]) - expected_metadata = tft.DatasetMetadata.from_feature_spec(expected_specs) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testScaleMinMaxConstant(self): - - def preprocessing_fn(inputs): - return {'x_scaled': tft.scale_by_min_max(inputs['x'], 0, 10)} - - input_data = [{'x': 4}, {'x': 4}, {'x': 4}, {'x': 4}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - expected_data = [{ - 'x_scaled': 9.8201379 - }, { - 'x_scaled': 9.8201379 - 
}, { - 'x_scaled': 9.8201379 - }, { - 'x_scaled': 9.8201379 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([], tf.float32)}) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testScaleMinMaxConstantElementwise(self): - - def preprocessing_fn(inputs): - outputs = {} - stacked_input = tf.stack([inputs['x'], inputs['y']], axis=1) - result = tft.scale_by_min_max( - stacked_input, output_min=0, output_max=10, elementwise=True) - outputs['x_scaled'], outputs['y_scaled'] = tf.unstack(result, axis=1) - return outputs - - input_data = [{ - 'x': 4, - 'y': 1 - }, { - 'x': 4, - 'y': 1 - }, { - 'x': 4, - 'y': 2 - }, { - 'x': 4, - 'y': 2 - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32) - }) - expected_data = [{ - 'x_scaled': 9.8201379, - 'y_scaled': 0 - }, { - 'x_scaled': 9.8201379, - 'y_scaled': 0 - }, { - 'x_scaled': 9.8201379, - 'y_scaled': 10 - }, { - 'x_scaled': 9.8201379, - 'y_scaled': 10 - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testScaleMinMaxError(self): - - def preprocessing_fn(inputs): - return {'x_scaled': tft.scale_by_min_max(inputs['x'], 2, 1)} - - input_data = [{'x': 1}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - expected_data = [{'x_scaled': float('nan')}] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([], tf.float32)}) - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, 'output_min must be less than output_max' - ): - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testScaleMinMaxWithEmptyInputs(self): - # x is repeated `multiple` times to test elementwise mapping. 
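# The constant- and empty-input expectations in the nearby scale_by_min_max
# tests follow a sigmoid fallback: judging from the expected values, when the
# analyzed min equals the max (constant column or empty analysis data) the
# result is sigmoid(x) stretched into [output_min, output_max]. A quick check:
import numpy as np

def _sigmoid_sketch(x):  # illustrative helper, not part of the test code
    return 1.0 / (1.0 + np.exp(-x))

print(10 * _sigmoid_sketch(4.0))                      # ~9.8201379
print(_sigmoid_sketch(np.array([100.0, 1.0, 12.0])))  # close to 1.0, 0.7310585, 0.9999938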
- multiple = 3 - - def preprocessing_fn(inputs): - return { - 'x_scaled': - tft.scale_by_min_max(inputs['x']), - 'x_scaled_elementwise': - tft.scale_by_min_max( - tf.tile(inputs['x'], [1, multiple]), elementwise=True) - } - - input_data = [] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.float32)}) - test_data = [{'x': [100]}, {'x': [1]}, {'x': [12]}] - expected_data = [{'x_scaled': [v], 'x_scaled_elementwise': [v] * multiple} - for v in [1., 0.7310585, 0.9999938]] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([1], tf.float32), - 'x_scaled_elementwise': tf.io.FixedLenFeature([multiple], tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - test_data=test_data) - - @tft_unit.named_parameters(*(_SCALE_TO_Z_SCORE_TEST_CASES + - _SCALE_TO_Z_SCORE_NAN_TEST_CASES)) - def testScaleToZScore(self, input_data, output_data, elementwise): - - def preprocessing_fn(inputs): - x = inputs['x'] - x_cast = tf.cast(x, tf.as_dtype(input_data.dtype)) - x_scaled = tft.scale_to_z_score(x_cast, elementwise=elementwise) - self.assertEqual(x_scaled.dtype, tf.as_dtype(output_data.dtype)) - return {'x_scaled': tf.cast(x_scaled, tf.float32)} - - input_data_dicts = [{'x': x} for x in input_data] - expected_data_dicts = [{'x_scaled': x_scaled} for x_scaled in output_data] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature( - input_data.shape[1:], - tft_unit.canonical_numeric_dtype(tf.as_dtype( - input_data.dtype))), - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature(output_data.shape[1:], tf.float32), - }) - self.assertAnalyzeAndTransformResults( - input_data_dicts, input_metadata, - preprocessing_fn, expected_data_dicts, expected_metadata) - - @tft_unit.parameters(*itertools.product([ - tf.int16, - tf.int32, - tf.int64, - tf.float32, - tf.float64, - ], (True, False))) - def testScaleToZScoreSparse(self, input_dtype, elementwise): - def preprocessing_fn(inputs): - z_score = tf.sparse.to_dense( - tft.scale_to_z_score( - tf.cast(inputs['x'], input_dtype), elementwise=elementwise), - default_value=np.nan) - z_score.set_shape([None, 4]) - self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) - return { - 'x_scaled': tf.cast(z_score, tf.float32) - } - - input_data = [ - {'idx': [0, 1], 'val': [-4, 10]}, - {'idx': [0, 1], 'val': [2, 4]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature('idx', 'val', - tft_unit.canonical_numeric_dtype(input_dtype), - 4) - }) - if elementwise: - # Mean(x) = [-1, 7] - # Var(x) = [9, 9] - # StdDev(x) = [3, 3] - expected_data = [ - { - 'x_scaled': [-1., 1., - float('nan'), - float('nan')] # [(-4 +1 ) / 3, (10 -7) / 3] - }, - { - 'x_scaled': [1., -1., - float('nan'), - float('nan')] # [(2 + 1) / 3, (4 - 7) / 3] - } - ] - else: - # Mean = 3 - # Var = 25 - # Std Dev = 5 - expected_data = [ - { - 'x_scaled': [-1.4, 1.4, float('nan'), - float('nan')] # [(-4 - 3) / 5, (10 - 3) / 5] - }, - { - 'x_scaled': [-.2, .2, float('nan'), - float('nan')] # [(2 - 3) / 5, (4 - 3) / 5] - } - ] - if input_dtype.is_floating: - input_data.append({'idx': [0, 1], 'val': [np.nan, np.nan]}) - expected_data.append({ - 'x_scaled': [float('nan'), - float('nan'), - float('nan'), - float('nan')] - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': 
tf.io.FixedLenFeature([4], tf.float32)}) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.parameters( - (tf.int16,), - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - ) - def testScaleToZScoreSparsePerDenseKey(self, input_dtype): - # TODO(b/131852830) Add elementwise tests. - def preprocessing_fn(inputs): - - def scale_to_z_score_per_key(tensor, key): - z_score = tft.scale_to_z_score_per_key( - tf.cast(tensor, input_dtype), key=key, elementwise=False) - self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) - return tf.cast(z_score, tf.float32) - - return { - 'x_scaled': scale_to_z_score_per_key(inputs['x'], inputs['key']), - 'y_scaled': scale_to_z_score_per_key(inputs['y'], inputs['key']), - } - np_dtype = input_dtype.as_numpy_dtype - input_data = [{ - 'x': np.array([-4, 2], dtype=np_dtype), - 'y': np.array([0, 0], dtype=np_dtype), - 'key': 'a', - }, { - 'x': np.array([10, 4], dtype=np_dtype), - 'y': np.array([0, 0], dtype=np_dtype), - 'key': 'a', - }, { - 'x': np.array([1, -1], dtype=np_dtype), - 'y': np.array([0, 0], dtype=np_dtype), - 'key': 'b', - }] - # Mean(x) = 3, Mean(y) = 0 - # Var(x) = (-7^2 + -1^2 + 7^2 + 1^2) / 4 = 25, Var(y) = 0 - # StdDev(x) = 5, StdDev(y) = 0 - # 'b': - # Mean(x) = 0, Mean(y) = 0 - # Var(x) = 1, Var(y) = 0 - # StdDev(x) = 1, StdDev(y) = 0 - expected_data = [ - { - 'x_scaled': [-1.4, -.2], # [(-4 - 3) / 5, (2 - 3) / 5] - 'y_scaled': [0., 0.], - }, - { - 'x_scaled': [1.4, .2], # [(10 - 3) / 5, (4 - 3) / 5] - 'y_scaled': [0., 0.], - }, - { - 'x_scaled': [1., -1.], # [(1 - 0) / 1, (-1 - 0) / 1] - 'y_scaled': [0., 0.], - } - ] - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.VarLenFeature(tft_unit.canonical_numeric_dtype(input_dtype)), - 'y': tf.io.VarLenFeature(tft_unit.canonical_numeric_dtype(input_dtype)), - 'key': tf.io.FixedLenFeature([], tf.string), - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.VarLenFeature(tf.float32), - 'y_scaled': tf.io.VarLenFeature(tf.float32), - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.named_parameters( - dict(testcase_name='_empty_filename', - key_vocabulary_filename=''), - dict(testcase_name='_nonempty_filename', - key_vocabulary_filename='per_key'), - dict(testcase_name='_none_filename', - key_vocabulary_filename=None) - ) - def testScaleToZScorePerKey(self, key_vocabulary_filename): - # TODO(b/131852830) Add elementwise tests. 
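# A numpy verification of the per-key z-score arithmetic used above and below
# (key 'a' sees x values -4, 2, 10, 4; the comments assume the population,
# i.e. 1/N, variance):
import numpy as np

x_a = np.array([-4.0, 2.0, 10.0, 4.0])
mean, std = x_a.mean(), x_a.std()          # 3.0 and 5.0
print((x_a - mean) / std)                  # [-1.4 -0.2  1.4  0.2]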
- def preprocessing_fn(inputs): - - def scale_to_z_score_per_key(tensor, key, var_name=''): - if key_vocabulary_filename is None: - filename = None - else: - filename = key_vocabulary_filename + var_name - z_score = tft.scale_to_z_score_per_key( - tf.cast(tensor, tf.float32), key=key, elementwise=False, - key_vocabulary_filename=filename) - self.assertEqual(z_score.dtype, tf.float32) - return z_score - - return { - 'x_scaled': scale_to_z_score_per_key(inputs['x'], inputs['key'], 'x'), - 'y_scaled': scale_to_z_score_per_key(inputs['y'], inputs['key'], 'y'), - 's_scaled': scale_to_z_score_per_key(inputs['s'], inputs['key'], 's'), - } - - np_dtype = np.float32 - input_data = [ - { - 'x': np.array([-4], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': 3, - 'key': 'a', - }, - { - 'x': np.array([10], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': -3, - 'key': 'a', - }, - { - 'x': np.array([1], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': 3, - 'key': 'b', - }, - { - 'x': np.array([2], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': 3, - 'key': 'a', - }, - { - 'x': np.array([4], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': -3, - 'key': 'a', - }, - { - 'x': np.array([-1], dtype=np_dtype), - 'y': np.array([0], dtype=np_dtype), - 's': -3, - 'key': 'b', - }, - { - 'x': np.array([np.nan], dtype=np_dtype), - 'y': np.array([np.nan], dtype=np_dtype), - 's': np.nan, - 'key': 'b', - }, - ] - # 'a': - # Mean(x) = 3, Mean(y) = 0 - # Var(x) = (-7^2 + -1^2 + 7^2 + 1^2) / 4 = 25, Var(y) = 0 - # StdDev(x) = 5, StdDev(y) = 0 - # 'b': - # Mean(x) = 0, Mean(y) = 0 - # Var(x) = 1, Var(y) = 0 - # StdDev(x) = 1, StdDev(y) = 0 - expected_data = [ - { - 'x_scaled': [-1.4], # [(-4 - 3) / 5, (2 - 3) / 5] - 'y_scaled': [0.], - 's_scaled': 1., - }, - { - 'x_scaled': [1.4], # [(10 - 3) / 5, (4 - 3) / 5] - 'y_scaled': [0.], - 's_scaled': -1., - }, - { - 'x_scaled': [1.], # [(1 - 0) / 1, (-1 - 0) / 1] - 'y_scaled': [0.], - 's_scaled': 1., - }, - { - 'x_scaled': [-.2], # [(-4 - 3) / 5, (2 - 3) / 5] - 'y_scaled': [0.], - 's_scaled': 1., - }, - { - 'x_scaled': [.2], # [(10 - 3) / 5, (4 - 3) / 5] - 'y_scaled': [0.], - 's_scaled': -1., - }, - { - 'x_scaled': [-1.], # [(1 - 0) / 1, (-1 - 0) / 1] - 'y_scaled': [0.], - 's_scaled': -1., - }, - { - 'x_scaled': [np.nan], - 'y_scaled': [np.nan], - 's_scaled': np.nan, - }, - ] - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature([1], - tft_unit.canonical_numeric_dtype(tf.float32)), - 'y': - tf.io.FixedLenFeature([1], - tft_unit.canonical_numeric_dtype(tf.float32)), - 's': - tf.io.FixedLenFeature([], - tft_unit.canonical_numeric_dtype(tf.float32)), - 'key': - tf.io.FixedLenFeature([], tf.string), - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([1], tf.float32), - 'y_scaled': tf.io.FixedLenFeature([1], tf.float32), - 's_scaled': tf.io.FixedLenFeature([], tf.float32), - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.named_parameters( - dict( - testcase_name='_float', - input_data=[ - { - 'x': [-4, 0], - 'key': 'a', - }, - { - 'x': [10, 0], - 'key': 'a', - }, - { - 'x': [2, 0], - 'key': 'a', - }, - { - 'x': [4, 0], - 'key': 'a', - }, - { - 'x': [1, 0], - 'key': 'b', - }, - { - 'x': [-1, 0], - 'key': 'b', - }, - { - 'x': [np.nan, np.nan], - 'key': 'b', - }, - ], - # Elementwise = True - # Mean [a, b] = [[ 3.0, 0.0], [0.0, 0.0]] - # 
Variance [a, b] = [[25.0, 0.0], [1.0, 0.0]] - # StdDev [a, b] = [[ 5.0, 0.0], [1.0, 0.0]] - expected_data=[ - { - 'x_scaled': [-1.4, 0.0], # [(-4 - 3) / 5, (0 - 0) / 0] - }, - { - 'x_scaled': [1.4, 0.0] # [(10 - 3) / 5, (0 - 0) / 0] - }, - { - 'x_scaled': [-0.2, 0.0] # [(2 - 3) / 5, (0 - 0) / 0] - }, - { - 'x_scaled': [0.2, 0.0], # [(4 - 3) / 5, (0 - 0) / 0] - }, - { - 'x_scaled': [1.0, 0.0] # [(1 - 0) / 1, (0 - 0) / 0] - }, - { - 'x_scaled': [-1.0, 0.0] # [(-1 - 0) / 1, (0 - 0) / 0] - }, - { - 'x_scaled': [np.nan, np.nan] - }, - ], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - expected_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([2], tf.float32), - })), - dict( - testcase_name='float_3dims', - input_data=[ - { - 'x': [[-4, -8], [-12, -16]], - 'key': 'a', - }, - { - 'x': [[10, 20], [30, 40]], - 'key': 'a', - }, - { - 'x': [[2, 4], [6, 8]], - 'key': 'a', - }, - { - 'x': [[4, 8], [12, 16]], - 'key': 'a', - }, - { - 'x': [[1, 2], [3, 4]], - 'key': 'b', - }, - ], - expected_data=[ - { - 'x_scaled': [[-1.4, -1.4], [-1.4, -1.4]], - }, - { - 'x_scaled': [[1.4, 1.4], [1.4, 1.4]], - }, - { - 'x_scaled': [[-0.2, -0.2], [-0.2, -0.2]], - }, - { - 'x_scaled': [[0.2, 0.2], [0.2, 0.2]], - }, - { - 'x_scaled': [[0.0, 0.0], [0.0, 0.0]], - }, - ], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2, 2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - expected_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([2, 2], tf.float32), - })), - ) - def testScaleToZScorePerKeyElementwise(self, input_data, expected_data, - input_metadata, expected_metadata): - - def preprocessing_fn(inputs): - outputs = {} - outputs['x_scaled'] = tft.scale_to_z_score_per_key( - tf.cast(inputs['x'], tf.float32), - key=inputs['key'], - elementwise=True, - key_vocabulary_filename=None) - self.assertEqual(outputs['x_scaled'].dtype, tf.float32) - return outputs - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - @tft_unit.parameters( - (tf.int16,), - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - ) - def testScaleToZScoreSparsePerKey(self, input_dtype): - # TODO(b/131852830) Add elementwise tests. 
- def preprocessing_fn(inputs): - z_score = tf.sparse.to_dense( - tft.scale_to_z_score_per_key( - tf.cast(inputs['x'], input_dtype), - inputs['key'], - elementwise=False), - default_value=np.nan) - z_score.set_shape([None, 4]) - self.assertEqual(z_score.dtype, _mean_output_dtype(input_dtype)) - return { - 'x_scaled': tf.cast(z_score, tf.float32) - } - - input_data = [ - {'idx': [0, 1], 'val': [-4, 10], 'key_idx': [0, 1], 'key': ['a', 'a']}, - {'idx': [0, 1], 'val': [2, 1], 'key_idx': [0, 1], 'key': ['a', 'b']}, - {'idx': [0, 1], 'val': [-1, 4], 'key_idx': [0, 1], 'key': ['b', 'a']}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'key': - tf.io.SparseFeature('key_idx', 'key', tf.string, 4), - 'x': - tf.io.SparseFeature('idx', 'val', - tft_unit.canonical_numeric_dtype(input_dtype), - 4) - }) - # 'a': - # Mean = 3 - # Var = 25 - # Std Dev = 5 - # 'b': - # Mean = 0 - # Var = 1 - # Std Dev = 1 - expected_data = [ - { - 'x_scaled': [-1.4, 1.4, float('nan'), - float('nan')] # [(-4 - 3) / 5, (10 - 3) / 5] - }, - { - 'x_scaled': [-.2, 1., float('nan'), - float('nan')] # [(2 - 3) / 5, (1 - 0) / 1] - }, - { - 'x_scaled': [-1., .2, - float('nan'), - float('nan')] # [(-1 - 0) / 1, (4 - 3) / 5] + ) + expected_outputs = { + "x_y_sums": np.array([[100, 4950]], np.int64), } - ] - if input_dtype.is_floating: - input_data.append({ - 'idx': [0, 1], - 'val': [np.nan, np.nan], - 'key_idx': [0, 1], - 'key': ['a', 'b'] - }) - expected_data.append({ - 'x_scaled': [float('nan'), - float('nan'), - float('nan'), - float('nan')] - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec( - {'x_scaled': tf.io.FixedLenFeature([4], tf.float32)}) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testScaleToZScoreWithEmptyInputs(self): - # x is repeated `multiple` times to test elementwise mapping. - multiple = 3 - - def preprocessing_fn(inputs): - return { - 'x_scaled': - tft.scale_to_z_score(inputs['x']), - 'x_scaled_elementwise': - tft.scale_to_z_score( - tf.tile(inputs['x'], [1, multiple]), elementwise=True) - } - - input_data = [] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.float32)}) - test_data = [{'x': [100]}, {'x': [1]}, {'x': [12]}] - expected_data = [{'x_scaled': [v], 'x_scaled_elementwise': [v] * multiple} - for v in [100., 1., 12.]] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_scaled': tf.io.FixedLenFeature([1], tf.float32), - 'x_scaled_elementwise': tf.io.FixedLenFeature([multiple], tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - test_data=test_data) - - def testMeanAndVar(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - mean, var = analyzers._mean_and_var(inputs['x']) - return { - 'mean': mean, - 'var': var - } - - # NOTE: We force 11 batches: data has 110 elements and we request a batch - # size of 10. 
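# The mean/var expected below are simply the population statistics of 1..100;
# the trailing NaN entries in input_data are effectively ignored by the
# analyzer. A quick numpy check (np.var defaults to ddof=0):
import numpy as np

x = np.arange(1, 101, dtype=np.float32)
print(x.mean(), x.var())  # 50.5 833.25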
- input_data = [{'x': [x if x < 101 else np.nan]} for x in range(1, 111)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.float32)}) - expected_outputs = { - 'mean': np.float32(50.5), - 'var': np.float32(833.25) - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=10) - - def testMeanAndVarPerKey(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - key_vocab, mean, var = analyzers._mean_and_var_per_key( - inputs['x'], inputs['key']) - return { - 'key_vocab': key_vocab, - 'mean': mean, - 'var': tf.round(100 * var) / 100.0 - } - - # NOTE: We force 12 batches: data has 120 elements and we request a batch - # size of 10. - input_data = [{'x': [x], 'key': 'a' if x < 50 else 'b'} - for x in range(1, 101)] + [{'x': [np.nan], 'key': 'a'}] * 20 - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([1], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - expected_outputs = { - 'key_vocab': np.array([b'a', b'b'], object), - 'mean': np.array([25, 75], np.float32), - 'var': np.array([200, 216.67], np.float32) - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=10) - - def testMeanAndVarPerKeyElementwise(self): - - def analyzer_fn(inputs): - key_vocab, mean, var = analyzers._mean_and_var_per_key( - inputs['x'], inputs['key'], reduce_instance_dims=False) - return { - 'key_vocab': key_vocab, - 'mean': mean, - 'var': tf.round(100 * var) / 100.0 - } - - input_data = input_data = [{ - 'x': [-4, -1], - 'key': 'a', - }, { - 'x': [10, 0], - 'key': 'a', - }, { - 'x': [2, 0], - 'key': 'a', - }, { - 'x': [4, -1], - 'key': 'a', - }, { - 'x': [10, 0], - 'key': 'b', - }, { - 'x': [0, 10], - 'key': 'b', - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - expected_outputs = { - 'key_vocab': np.array([b'a', b'b'], object), - 'mean': np.array([[3.0, -0.5], [5.0, 5.0]], np.float32), - 'var': np.array([[25.0, 0.25], [25.0, 25.0]], np.float32) - } - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - @tft_unit.named_parameters( - dict( - testcase_name='_dense_2d', - input_data=[{ - 'x': [4, 8], - 'key': 'a' - }, { - 'x': [1, 5], - 'key': 'a' - }, { - 'x': [5, 9], - 'key': 'a' - }, { - 'x': [2, 6], - 'key': 'a' - }, { - 'x': [-2, 0], - 'key': 'b' - }, { - 'x': [0, 2], - 'key': 'b' - }, { - 'x': [2, 4], - 'key': 'b' - }], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - reduce_instance_dims=True, - expected_outputs={ - 'key_vocab': np.array([b'a', b'b'], object), - 'min_x_value': np.array([1, -2], np.float32), - 'max_x_value': np.array([9, 4], np.float32), - }), - dict( - testcase_name='_dense_2d_elementwise', - input_data=[{ - 'x': [4, 8], - 'key': 'a' - }, { - 'x': [1, 5], - 'key': 'a' - }, { - 'x': [5, 9], - 'key': 'a' - }, { - 'x': [2, 6], - 'key': 'a' - }, { - 'x': [-2, 0], - 'key': 'b' - }, { - 'x': [0, 2], - 'key': 'b' - }, { - 'x': [2, 4], - 'key': 'b' - }], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - reduce_instance_dims=False, - expected_outputs={ - 'key_vocab': np.array([b'a', b'b'], object), - 
'min_x_value': np.array([[1, 5], [-2, 0]], np.float32), - 'max_x_value': np.array([[5, 9], [2, 4]], np.float32), - }), - dict( - testcase_name='_dense_3d', - input_data=[ - { - 'x': [[1, 5], [1, 1]], - 'key': 'a' - }, - { - 'x': [[5, 1], [5, 5]], - 'key': 'a' - }, - { - 'x': [[2, 2], [2, 5]], - 'key': 'a' - }, - { - 'x': [[3, -3], [3, 3]], - 'key': 'b' - }, - ], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2, 2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - reduce_instance_dims=True, - expected_outputs={ - 'key_vocab': np.array([b'a', b'b'], object), - 'min_x_value': np.array([1, -3], np.float32), - 'max_x_value': np.array([5, 3], np.float32), - }), - dict( - testcase_name='_dense_3d_elementwise', - input_data=[ - { - 'x': [[1, 5], [1, 1]], - 'key': 'a' - }, - { - 'x': [[5, 1], [5, 5]], - 'key': 'a' - }, - { - 'x': [[2, 2], [2, 5]], - 'key': 'a' - }, - { - 'x': [[3, -3], [3, 3]], - 'key': 'b' - }, - ], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([2, 2], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string), - }), - reduce_instance_dims=False, - expected_outputs={ - 'key_vocab': - np.array([b'a', b'b'], object), - 'min_x_value': - np.array([[[1, 1], [1, 1]], [[3, -3], [3, 3]]], np.float32), - 'max_x_value': - np.array([[[5, 5], [5, 5]], [[3, -3], [3, 3]]], np.float32), - }), - ) - def testMinAndMaxPerKey(self, input_data, input_metadata, - reduce_instance_dims, expected_outputs): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - key_vocab, min_x_value, max_x_value = analyzers._min_and_max_per_key( - x=inputs['x'], - key=inputs['key'], - reduce_instance_dims=reduce_instance_dims) - return { - 'key_vocab': key_vocab, - 'min_x_value': min_x_value, - 'max_x_value': max_x_value, - } - - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - @tft_unit.parameters((True,), (False,)) - def testPerKeyWithOOVKeys(self, use_vocabulary): - def preprocessing_fn(inputs): - result = {} - result['x_scaled'] = tft.scale_to_0_1_per_key( - inputs['x'], - inputs['key'], - elementwise=False, - key_vocabulary_filename='a' if use_vocabulary else None) - result['x_z_score'] = tft.scale_to_z_score_per_key( - inputs['x'], - inputs['key'], - elementwise=False, - key_vocabulary_filename='b' if use_vocabulary else None) - # TODO(b/179891014): Add key_vocabulary_filename to bucketize_per_key once - # implemented. 
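# The 'oov' row added to test_data below exercises keys unseen during
# analysis. Judging from the expected values, per-key 0-1 scaling then
# degenerates to sigmoid(x), per-key z-scoring to the identity, and
# bucketization to the out-of-range bucket -1. Sketch of the sigmoid value:
import math

print(1.0 / (1.0 + math.exp(-5.0)))  # ~0.9933, i.e. _sigmoid(5) in expected_data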
- result['x_bucketized'] = tft.bucketize_per_key(inputs['x'], inputs['key'], - 3) - return result - - input_data = [ - dict(x=4, key='a'), - dict(x=1, key='a'), - dict(x=5, key='a'), - dict(x=2, key='a'), - dict(x=25, key='b'), - dict(x=5, key='b') - ] - test_data = input_data + [dict(x=5, key='oov')] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'key': tf.io.FixedLenFeature([], tf.string) - }) - - expected_data = [{ - 'x_scaled': 0.75, - 'x_z_score': 0.6324555, - 'x_bucketized': 2, - }, { - 'x_scaled': 0.0, - 'x_z_score': -1.264911, - 'x_bucketized': 0, - }, { - 'x_scaled': 1.0, - 'x_z_score': 1.264911, - 'x_bucketized': 2, - }, { - 'x_scaled': 0.25, - 'x_z_score': -0.6324555, - 'x_bucketized': 1, - }, { - 'x_scaled': 1.0, - 'x_z_score': 1.0, - 'x_bucketized': 2, - }, { - 'x_scaled': 0.0, - 'x_z_score': -1.0, - 'x_bucketized': 1, - }, { - 'x_scaled': _sigmoid(5), - 'x_z_score': 5.0, - 'x_bucketized': -1, - }] - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - test_data=test_data) - - @tft_unit.named_parameters( - dict( - testcase_name='_string', - input_data=[{ - 'key': 'a' if x < 25 else 'b' - } for x in range(100)], - input_metadata=tft.DatasetMetadata.from_feature_spec( - {'key': tf.io.FixedLenFeature([], tf.string)}), - expected_outputs={ - 'elements': np.array([b'a', b'b'], object), - 'counts': np.array([25, 75], np.int64) - }), - dict( - testcase_name='_int', - input_data=[{ - 'key': 0 if x < 25 else 1 - } for x in range(100)], - input_metadata=tft.DatasetMetadata.from_feature_spec( - {'key': tf.io.FixedLenFeature([], tf.int64)}), - expected_outputs={ - 'elements': np.array([0, 1], np.int64), - 'counts': np.array([25, 75], np.int64) - }), - dict( - testcase_name='_int_sparse', - input_data=[{ - 'key': [0] if x < 25 else [1] - } for x in range(100)], - input_metadata=tft.DatasetMetadata.from_feature_spec( - {'key': tf.io.VarLenFeature(tf.int64)}), - expected_outputs={ - 'elements': np.array([0, 1], np.int64), - 'counts': np.array([25, 75], np.int64) - }), - dict( - testcase_name='_3d_sparse', - input_data=[ - { # pylint: disable=g-complex-comprehension - 'key': [0, 1] if x < 25 else [1], - 'idx0': [0, 1] if x < 25 else [0], - 'idx1': [0, 1] if x < 25 else [0] - } for x in range(100) - ], - input_metadata=tft.DatasetMetadata.from_feature_spec({ - 'key': - tf.io.SparseFeature(['idx0', 'idx1'], 'key', tf.int64, [2, 2]) - }), - expected_outputs={ - 'elements': np.array([0, 1], np.int64), - 'counts': np.array([25, 100], np.int64) - }, - ), - ) - def testCountPerKey(self, input_data, input_metadata, expected_outputs): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - elements, counts = analyzers.count_per_key(inputs['key']) - return { - 'elements': elements, - 'counts': counts - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs) - - @tft_unit.named_parameters( - dict( - testcase_name='_uniform', - input_data=[{ - 'x': [x] - } for x in range(10, 100)], - make_feature_spec=lambda: tf.io.FixedLenFeature([1], tf.int64), - boundaries=10 * np.arange(11, dtype=np.float32), - categorical=False, - expected_outputs={ - 'hist': - 10 * np.array([0] + [1] * 9, np.int64), - 'boundaries': - 10 * np.arange(11, dtype=np.float32).reshape((1, 11)) - }), - dict( - testcase_name='_categorical_string', - input_data=[{ - 'x': [str(x % 10) + '_'] - } for x in range(1, 101)], - make_feature_spec=lambda: tf.io.FixedLenFeature([1], 
tf.string), - boundaries=None, - categorical=True, - expected_outputs={ - 'hist': - 10 * np.ones(10, np.int64), - 'boundaries': - np.asarray( - sorted([ - tf.compat.as_bytes(str(x % 10) + '_') - for x in range(10) - ]), - dtype=object) - }, - ), - dict( - testcase_name='_categorical_int', - input_data=[{ - 'x': [(x % 10)] - } for x in range(1, 101)], - make_feature_spec=lambda: tf.io.FixedLenFeature([1], tf.int64), - boundaries=None, - categorical=True, - expected_outputs={ - 'hist': 10 * np.ones(10, np.int64), - 'boundaries': np.arange(10) - }), - dict( - testcase_name='_sparse', - input_data=[{ # pylint: disable=g-complex-comprehension - 'val': [(x % 10)], - 'idx0': [(x % 2)], - 'idx1': [((x + 1) % 2)] - } for x in range(1, 101)], - make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda - ['idx0', 'idx1'], 'val', tf.int64, [2, 2]), - boundaries=None, - categorical=True, - expected_outputs={ - 'hist': 10 * np.ones(10, np.int64), - 'boundaries': np.arange(10) - }), - dict( - testcase_name='_ragged', - input_data=[{ # pylint: disable=g-complex-comprehension - 'val': [x % 10, 9 - (x % 10)], - 'row_lengths': [0, 1, 1], - } for x in range(1, 101)], - make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tf.int64, - value_key='val', - partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error - ]) - , - boundaries=None, - categorical=True, - expected_outputs={ - 'hist': 20 * np.ones(10, np.int64), - 'boundaries': np.arange(10) - }), - ) - def testHistograms(self, input_data, make_feature_spec, boundaries, - categorical, expected_outputs): - self._SkipIfOutputRecordBatches() - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tft_unit.make_feature_spec_wrapper(make_feature_spec)}) - - def analyzer_fn(inputs): - counts, bucket_boundaries = analyzers.histogram( - inputs['x'], categorical=categorical, boundaries=boundaries) - if not categorical: - bucket_boundaries = tf.math.round(bucket_boundaries) - return {'hist': counts, 'boundaries': bucket_boundaries} - - self.assertAnalyzerOutputs(input_data, - input_metadata, - analyzer_fn, - expected_outputs) - - def testProbCategoricalInt(self): - def preprocessing_fn(inputs): - return {'probs': tft.estimated_probability_density(inputs['x'], - categorical=True)} - - # NOTE: We force 10 batches: data has 100 elements and we request a batch - # size of 10. - input_data = [{'x': [x % 10]} for x in range(1, 101)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.int64)}) - expected_outputs = [{ - 'probs': np.array(np.ones(1) / 10.0, np.float32) - } for _ in range(100)] - self.assertAnalyzeAndTransformResults(input_data, - input_metadata, - preprocessing_fn, - expected_outputs, - desired_batch_size=10) - - def testProbCategorical(self): - def preprocessing_fn(inputs): - return {'probs': tft.estimated_probability_density(inputs['x'], - categorical=True)} - - # NOTE: We force 10 batches: data has 100 elements and we request a batch - # size of 10. 
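# The density expectations above and below are plain relative frequencies:
# 100 samples spread evenly over 10 categories give a probability of 0.1 each,
# and with bucket boundaries every 10 units (testProbTenBoundaries) the
# per-unit density is 0.1 / 10 = 0.01. A minimal check:
from collections import Counter

samples = [x % 10 for x in range(1, 101)]
print(sorted(v / len(samples) for v in Counter(samples).values()))  # ten 0.1s
print(0.1 / 10)  # 0.01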
- input_data = [{'x': [str(x % 10) + '_']} for x in range(1, 101)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.string)}) - expected_outputs = [{ - 'probs': np.array(np.ones(1) / 10.0, np.float32) - } for _ in range(100)] - self.assertAnalyzeAndTransformResults(input_data, - input_metadata, - preprocessing_fn, - expected_outputs, - desired_batch_size=10) - - def testProbTenBoundaries(self): - # If we draw uniformly from a range (0, 100], the expected density is 0.01. - def preprocessing_fn(inputs): - return {'probs': tft.estimated_probability_density( - inputs['x'], boundaries=list(range(0, 101, 10)))} - - # NOTE: We force 10 batches: data has 100 elements and we request a batch - # size of 10. - input_data = [{'x': [x]} for x in range(100)] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.int64)}) - expected_outputs = [{ - 'probs': np.array(np.ones(1) / (100.0), np.float32) - } for _ in range(100)] - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_outputs, - desired_batch_size=10) - - @tft_unit.named_parameters( - {'testcase_name': 'uniform', - 'boundaries': 6, - 'input_data': [{'x': [x]} for x in range(100)], - 'expected_outputs': [{'probs': np.array(np.ones((1)) / 99.0, np.float32) - } for _ in range(100)] - }, - {'testcase_name': 'nonuniform_with_zeros', - 'boundaries': 5, - 'input_data': [{'x': [x]} for x in list(range(25)) + ( - list(range(50, 75)) + list(range(50, 75)) + list(range(75, 100)))], - 'expected_outputs': [{'probs': np.ones((1), np.float32) / 99.0 * ( - 2.0 if 24 < i < 75 else 1.0)} for i in range(100)] - }, - {'testcase_name': 'empty', - 'boundaries': 5, - 'input_data': [], - 'expected_outputs': [] - }, - ) - def testProbUnknownBoundaries( - self, input_data, expected_outputs, boundaries): - # Test 1 has 100 points over a range of 99; test 2 is an uneven distribution - def preprocessing_fn(inputs): - return {'probs': tft.estimated_probability_density(inputs['x'], - boundaries=boundaries)} - - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.int64)}) - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_outputs) - - @tft_unit.named_parameters( - dict( - testcase_name='Int64In', - input_dtype=tf.int64, - output_dtypes={ - 'min': tf.int64, - 'max': tf.int64, - 'sum': tf.int64, - 'size': tf.int64, - 'mean': tf.float32, - 'var': tf.float32 - }), - dict( - testcase_name='Int32In', - input_dtype=tf.int32, - output_dtypes={ - 'min': tf.int32, - 'max': tf.int32, - 'sum': tf.int64, - 'size': tf.int64, - 'mean': tf.float32, - 'var': tf.float32 - }), - dict( - testcase_name='Int16In', - input_dtype=tf.int16, - output_dtypes={ - 'min': tf.int16, - 'max': tf.int16, - 'sum': tf.int64, - 'size': tf.int64, - 'mean': tf.float32, - 'var': tf.float32 - }), - dict( - testcase_name='Float64In', - input_dtype=tf.float64, - output_dtypes={ - 'min': tf.float64, - 'max': tf.float64, - 'sum': tf.float64, - 'size': tf.int64, - 'mean': tf.float64, - 'var': tf.float64 - }), - dict( - testcase_name='Float32In', - input_dtype=tf.float32, - output_dtypes={ - 'min': tf.float32, - 'max': tf.float32, - 'sum': tf.float32, - 'size': tf.int64, - 'mean': tf.float32, - 'var': tf.float32 - }), - dict( - testcase_name='Float16In', - input_dtype=tf.float16, - output_dtypes={ - 'min': tf.float16, - 'max': tf.float16, - 'sum': tf.float32, - 'size': tf.int64, - 'mean': 
tf.float16, - 'var': tf.float16 - }) - ) - def testNumericAnalyzersWithScalarInputs(self, input_dtype, output_dtypes): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - a = tf.cast(inputs['a'], input_dtype) - - def assert_and_cast_dtype(tensor, out_dtype): - self.assertEqual(tensor.dtype, out_dtype) - return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) - - return { - 'min': assert_and_cast_dtype(tft.min(a), - output_dtypes['min']), - 'max': assert_and_cast_dtype(tft.max(a), - output_dtypes['max']), - 'sum': assert_and_cast_dtype(tft.sum(a), - output_dtypes['sum']), - 'size': assert_and_cast_dtype(tft.size(a), - output_dtypes['size']), - 'mean': assert_and_cast_dtype(tft.mean(a), - output_dtypes['mean']), - 'var': assert_and_cast_dtype(tft.var(a), - output_dtypes['var']), - } - - input_data = [{'a': 4}, {'a': 1}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': - tf.io.FixedLenFeature([], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - expected_outputs = { - 'min': - np.array( - 1, - tft_unit.canonical_numeric_dtype( - output_dtypes['min']).as_numpy_dtype), - 'max': - np.array( - 4, - tft_unit.canonical_numeric_dtype( - output_dtypes['max']).as_numpy_dtype), - 'sum': - np.array( - 5, - tft_unit.canonical_numeric_dtype( - output_dtypes['sum']).as_numpy_dtype), - 'size': - np.array( - 2, - tft_unit.canonical_numeric_dtype( - output_dtypes['size']).as_numpy_dtype), - 'mean': - np.array( - 2.5, - tft_unit.canonical_numeric_dtype( - output_dtypes['mean']).as_numpy_dtype), - 'var': - np.array( - 2.25, - tft_unit.canonical_numeric_dtype( - output_dtypes['var']).as_numpy_dtype), - } - - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - @tft_unit.named_parameters(*tft_unit.cross_named_parameters( - [ - dict( - testcase_name='sparse', - input_data=[ - { - 'idx0': [0, 1], - 'idx1': [0, 1], - 'val': [0, 1], - }, - { - 'idx0': [1, 2], - 'idx1': [1, 3], - 'val': [2, 3], - }, - ], - make_feature_spec=lambda dtype: tf.io.SparseFeature( # pylint: disable=g-long-lambda - ['idx0', 'idx1'], 'val', dtype, (3, 4)), - expected_outputs={ - 'min': 0., - 'max': 3., - 'sum': 6., - 'size': 4, - 'mean': 1.5, - 'var': 1.25, - }, - reduce_instance_dims=True, - ), - dict( - testcase_name='sparse_elementwise', - input_data=[ - { - 'idx0': [0, 1], - 'idx1': [0, 1], - 'val': [0, 1], - }, - { - 'idx0': [1, 2], - 'idx1': [1, 3], - 'val': [2, 3], - }, - ], - make_feature_spec=lambda dtype: tf.io.SparseFeature( # pylint: disable=g-long-lambda - ['idx0', 'idx1'], 'val', dtype, (3, 4)), - expected_outputs={ - # We use np.nan in place of missing values here but replace - # them accordingly to the dtype in the test. 
- 'min': [[0., np.nan, np.nan, np.nan], - [np.nan, 1., np.nan, np.nan], - [np.nan, np.nan, np.nan, 3.]], - 'max': [[0., np.nan, np.nan, np.nan], - [np.nan, 2., np.nan, np.nan], - [np.nan, np.nan, np.nan, 3.]], - 'sum': [[0., 0., 0., 0.], [0., 3., 0., 0.], [0., 0., 0., 3.]], - 'size': [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 0, 1]], - 'mean': [[0., np.nan, np.nan, np.nan], - [np.nan, 1.5, np.nan, np.nan], - [np.nan, np.nan, np.nan, 3.]], - 'var': [[0., np.nan, np.nan, np.nan], - [np.nan, 0.25, np.nan, np.nan], - [np.nan, np.nan, np.nan, 0.]], - }, - reduce_instance_dims=False, - ), - dict( - testcase_name='ragged', - input_data=[ - { - 'val': [0., 2., 3.], - 'row_lengths': [0, 3], - }, - { - 'val': [3., 3., 1.], - 'row_lengths': [3], - }, - ], - make_feature_spec=lambda dtype: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - dtype, - value_key='val', - partitions=[tf.io.RaggedFeature.RowLengths('row_lengths')]), # pytype: disable=attribute-error - expected_outputs={ - 'min': 0., - 'max': 3., - 'sum': 12., - 'size': 6, - 'mean': 2., - 'var': 1.333333, - }, - reduce_instance_dims=True, - ) - ], - [ - dict(testcase_name='int16', input_dtype=tf.int16), - dict(testcase_name='int32', input_dtype=tf.int32), - dict(testcase_name='int64', input_dtype=tf.int64), - dict(testcase_name='float32', input_dtype=tf.float32), - dict(testcase_name='tf.float64', input_dtype=tf.float64), - dict(testcase_name='tf.uint8', input_dtype=tf.uint8), - dict(testcase_name='tf.uint16', input_dtype=tf.uint16), - ])) - def testNumericAnalyzersWithCompositeInputs(self, input_data, - make_feature_spec, - expected_outputs, - reduce_instance_dims, - input_dtype): - self._SkipIfOutputRecordBatches() - output_dtype = tft_unit.canonical_numeric_dtype(input_dtype) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tft_unit.make_feature_spec_wrapper(make_feature_spec, output_dtype) - }) - - def analyzer_fn(inputs): - return { - 'min': tft.min(inputs['a'], reduce_instance_dims), - 'max': tft.max(inputs['a'], reduce_instance_dims), - 'sum': tft.sum(inputs['a'], reduce_instance_dims), - 'size': tft.size(inputs['a'], reduce_instance_dims), - 'mean': tft.mean(inputs['a'], reduce_instance_dims), - 'var': tft.var(inputs['a'], reduce_instance_dims), - } - - input_val_dtype = input_dtype.as_numpy_dtype - # Cast input values to appropriate type. - for instance in input_data: - instance['val'] = np.array(instance['val'], input_val_dtype) - if not reduce_instance_dims: - if input_dtype.is_floating: - missing_value_max = float('nan') - missing_value_min = float('nan') - else: - missing_value_max = np.iinfo(output_dtype.as_numpy_dtype).min - missing_value_min = np.iinfo(output_dtype.as_numpy_dtype).max - # Replace NaNs with proper missing values. 
- for row in expected_outputs['min']: - for idx in range(len(row)): - if np.isnan(row[idx]): - row[idx] = missing_value_min - for row in expected_outputs['max']: - for idx in range(len(row)): - if np.isnan(row[idx]): - row[idx] = missing_value_max - for op in ('min', 'max', 'sum'): - expected_outputs[op] = np.array(expected_outputs[op], - output_dtype.as_numpy_dtype) - expected_outputs['size'] = np.array(expected_outputs['size'], np.int64) - expected_outputs['mean'] = np.array(expected_outputs['mean'], np.float32) - expected_outputs['var'] = np.array(expected_outputs['var'], np.float32) - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - @tft_unit.named_parameters( - dict( - testcase_name='sparse', - input_data=[ - { - 'idx0': [0, 1], - 'idx1': [0, 1], - 'val': np.array([0, 1], dtype=np.int64) - }, - { - 'idx0': [1, 2], - 'idx1': [1, 3], - 'val': np.array([2, 3], dtype=np.int64) - }, - ], - make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda - ['idx0', 'idx1'], 'val', tf.int64, (3, 4)), - elementwise=False, - expected_outputs=[{ - 'scale_to_0_1$sparse_indices_0': - np.array([0, 1]), - 'scale_to_0_1$sparse_indices_1': - np.array([0, 1]), - 'scale_to_z_score$sparse_indices_0': - np.array([0, 1]), - 'scale_to_z_score$sparse_indices_1': - np.array([0, 1]), - 'scale_by_min_max$sparse_indices_0': - np.array([0, 1]), - 'scale_by_min_max$sparse_indices_1': - np.array([0, 1]), - 'scale_to_0_1$sparse_values': - np.array([0., 1. / 3.], dtype=np.float32), - 'scale_to_z_score$sparse_values': - np.array([-1.5 / np.sqrt(1.25), -0.5 / np.sqrt(1.25)], - dtype=np.float32), - 'scale_by_min_max$sparse_values': - np.array([0., 1. / 3.], dtype=np.float32), - }, { - 'scale_to_0_1$sparse_indices_0': - np.array([1, 2]), - 'scale_to_0_1$sparse_indices_1': - np.array([1, 3]), - 'scale_to_z_score$sparse_indices_0': - np.array([1, 2]), - 'scale_to_z_score$sparse_indices_1': - np.array([1, 3]), - 'scale_by_min_max$sparse_indices_0': - np.array([1, 2]), - 'scale_by_min_max$sparse_indices_1': - np.array([1, 3]), - 'scale_to_0_1$sparse_values': - np.array([2. / 3., 1.], dtype=np.float32), - 'scale_to_z_score$sparse_values': - np.array([.5 / np.sqrt(1.25), 1.5 / np.sqrt(1.25)], - dtype=np.float32), - 'scale_by_min_max$sparse_values': - np.array([2. 
/ 3., 1.], dtype=np.float32) - }]), - dict( - testcase_name='sparse_elementwise', - input_data=[ - { - 'idx0': [0, 1], - 'idx1': [0, 1], - 'val': np.array([0, 1], dtype=np.int64) - }, - { - 'idx0': [1, 2], - 'idx1': [1, 3], - 'val': np.array([2, 3], dtype=np.int64) - }, - ], - make_feature_spec=lambda: tf.io.SparseFeature( # pylint: disable=g-long-lambda - ['idx0', 'idx1'], 'val', tf.int64, (3, 4)), - elementwise=True, - expected_outputs=[{ - 'scale_to_0_1$sparse_indices_0': - np.array([0, 1]), - 'scale_to_0_1$sparse_indices_1': - np.array([0, 1]), - 'scale_to_z_score$sparse_indices_0': - np.array([0, 1]), - 'scale_to_z_score$sparse_indices_1': - np.array([0, 1]), - 'scale_by_min_max$sparse_indices_0': - np.array([0, 1]), - 'scale_by_min_max$sparse_indices_1': - np.array([0, 1]), - 'scale_to_0_1$sparse_values': - np.array([0.5, 0.], dtype=np.float32), - 'scale_to_z_score$sparse_values': - np.array([0, -1], dtype=np.float32), - 'scale_by_min_max$sparse_values': - np.array([0.5, 0.], dtype=np.float32), - }, { - 'scale_to_0_1$sparse_indices_0': - np.array([1, 2]), - 'scale_to_0_1$sparse_indices_1': - np.array([1, 3]), - 'scale_to_z_score$sparse_indices_0': - np.array([1, 2]), - 'scale_to_z_score$sparse_indices_1': - np.array([1, 3]), - 'scale_by_min_max$sparse_indices_0': - np.array([1, 2]), - 'scale_by_min_max$sparse_indices_1': - np.array([1, 3]), - 'scale_to_0_1$sparse_values': - np.array([1., _sigmoid(3)], dtype=np.float32), - 'scale_to_z_score$sparse_values': - np.array([1, 0], dtype=np.float32), - 'scale_by_min_max$sparse_values': - np.array([1., _sigmoid(3)], dtype=np.float32), - }]), - dict( - testcase_name='ragged', - input_data=[ - { - 'val': [0., 2., 3.], - 'row_lengths': [0, 3], - }, - { - 'val': [3., 3., 1.], - 'row_lengths': [3], - }, - ], - make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tf.float32, - value_key='val', - partitions=[tf.io.RaggedFeature.RowLengths('row_lengths')]), # pytype: disable=attribute-error - elementwise=False, - expected_outputs=[{ - 'scale_by_min_max$ragged_values': [0., 0.6666667, 1.], - 'scale_to_z_score$row_lengths_1': [0, 3], - 'scale_to_0_1$row_lengths_1': [0, 3], - 'scale_to_0_1$ragged_values': [0., 0.6666667, 1.], - 'scale_to_z_score$ragged_values': [-1.7320509, 0., 0.86602545], - 'scale_by_min_max$row_lengths_1': [0, 3], - }, { - 'scale_to_0_1$row_lengths_1': [3], - 'scale_by_min_max$row_lengths_1': [3], - 'scale_to_z_score$ragged_values': [ - 0.86602545, 0.86602545, -0.86602545 - ], - 'scale_to_z_score$row_lengths_1': [3], - 'scale_to_0_1$ragged_values': [1., 1., 0.33333334], - 'scale_by_min_max$ragged_values': [1., 1., 0.33333334], - }], - ), - dict( - testcase_name='ragged_uniform', - input_data=[ - { - 'val': [0., 2., 3., 11., 2., 7.], - }, - { - 'val': [3., 1., 2.], - }, - ], - make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tf.float32, - value_key='val', - partitions=[ - tf.io.RaggedFeature.UniformRowLength(3), # pytype: disable=attribute-error - ]), - elementwise=False, - expected_outputs=[{ - 'scale_by_min_max$ragged_values': [ - 0., 0.18181819, 0.27272728, 1., 0.18181819, 0.6363636 - ], - 'scale_to_z_score$ragged_values': [ - -1.0645443, -0.4464218, -0.13736054, 2.3351295, -0.4464218, - 1.0988845 - ], - 'scale_to_0_1$ragged_values': [ - 0., 0.18181819, 0.27272728, 1., 0.18181819, 0.6363636 - ], - }, { - 'scale_to_0_1$ragged_values': [ - 0.27272728, 0.09090909, 0.18181819 - ], - 'scale_by_min_max$ragged_values': [ - 0.27272728, 0.09090909, 0.18181819 - ], - 
'scale_to_z_score$ragged_values': [ - -0.13736054, -0.7554831, -0.4464218 - ], - }], - ), - dict( - testcase_name='2d_ragged_uniform', - input_data=[ - { - 'val': [0., 2., 3., 1., 2., 7.], - 'row_lengths': [0, 2, 0, 1], - }, - { - 'val': [3., 3., 1., 2.], - 'row_lengths': [2], - }, - ], - make_feature_spec=lambda: tf.io.RaggedFeature( # pylint: disable=g-long-lambda - tf.float32, - value_key='val', - partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths'), # pytype: disable=attribute-error - tf.io.RaggedFeature.UniformRowLength(2), # pytype: disable=attribute-error - ], - # Note that row splits are always encoded as int64 since we only - # support this integral type in outputs. We modify the default - # `row_splits_dtype` (tf.int32) here to make sure it still works. - row_splits_dtype=tf.int64), - elementwise=False, - expected_outputs=[{ - 'scale_by_min_max$ragged_values': [ - 0., 0.285714, 0.428571, 0.142857, 0.285714, 1. - ], - 'scale_by_min_max$row_lengths_1': [0, 2, 0, 1], - 'scale_to_z_score$row_lengths_1': [0, 2, 0, 1], - 'scale_to_z_score$ragged_values': [ - -1.3333334, -0.22222228, 0.33333328, -0.77777785, -0.22222228, - 2.5555556 - ], - 'scale_to_0_1$row_lengths_1': [0, 2, 0, 1], - 'scale_to_0_1$ragged_values': [ - 0., 0.2857143, 0.42857143, 0.14285715, 0.2857143, 1. - ], - }, { - 'scale_to_0_1$ragged_values': [ - 0.42857143, 0.42857143, 0.14285715, 0.2857143 - ], - 'scale_to_0_1$row_lengths_1': [2], - 'scale_by_min_max$ragged_values': [ - 0.42857143, 0.42857143, 0.14285715, 0.2857143 - ], - 'scale_by_min_max$row_lengths_1': [2], - 'scale_to_z_score$ragged_values': [ - 0.33333328, 0.33333328, -0.77777785, -0.22222228 - ], - 'scale_to_z_score$row_lengths_1': [2], - }], - ), - ) - def testNumericMappersWithCompositeInputs(self, input_data, make_feature_spec, - elementwise, expected_outputs): - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tft_unit.make_feature_spec_wrapper(make_feature_spec)}) - - def preprocessing_fn(inputs): - return { - 'scale_to_0_1': - tft.scale_to_0_1(inputs['a'], elementwise=elementwise), - 'scale_to_z_score': - tft.scale_to_z_score(inputs['a'], elementwise=elementwise), - 'scale_by_min_max': - tft.scale_by_min_max(inputs['a'], elementwise=elementwise), - } - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_outputs) - - def testNumericAnalyzersWithInputsAndAxis(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return { - 'min': tft.min(inputs['a'], reduce_instance_dims=False), - 'max': tft.max(inputs['a'], reduce_instance_dims=False), - 'sum': tft.sum(inputs['a'], reduce_instance_dims=False), - 'size': tft.size(inputs['a'], reduce_instance_dims=False), - 'mean': tft.mean(inputs['a'], reduce_instance_dims=False), - 'var': tft.var(inputs['a'], reduce_instance_dims=False), - } - - input_data = [ - {'a': [8, 9, 3, 4]}, - {'a': [1, 2, 10, 11]} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([4], tf.int64)}) - expected_outputs = { - 'min': np.array([1, 2, 3, 4], np.int64), - 'max': np.array([8, 9, 10, 11], np.int64), - 'sum': np.array([9, 11, 13, 15], np.int64), - 'size': np.array([2, 2, 2, 2], np.int64), - 'mean': np.array([4.5, 5.5, 6.5, 7.5], np.float32), - 'var': np.array([12.25, 12.25, 12.25, 12.25], np.float32), - } - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - def testNumericAnalyzersWithNDInputsAndAxis(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): 
- return { - 'min': tft.min(inputs['a'], reduce_instance_dims=False), - 'max': tft.max(inputs['a'], reduce_instance_dims=False), - 'sum': tft.sum(inputs['a'], reduce_instance_dims=False), - 'size': tft.size(inputs['a'], reduce_instance_dims=False), - 'mean': tft.mean(inputs['a'], reduce_instance_dims=False), - 'var': tft.var(inputs['a'], reduce_instance_dims=False), - } - - input_data = [ - {'a': [[8, 9], [3, 4]]}, - {'a': [[1, 2], [10, 11]]}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([2, 2], tf.int64)}) - expected_outputs = { - 'min': np.array([[1, 2], [3, 4]], np.int64), - 'max': np.array([[8, 9], [10, 11]], np.int64), - 'sum': np.array([[9, 11], [13, 15]], np.int64), - 'size': np.array([[2, 2], [2, 2]], np.int64), - 'mean': np.array([[4.5, 5.5], [6.5, 7.5]], np.float32), - 'var': np.array([[12.25, 12.25], [12.25, 12.25]], np.float32), - } - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - def testNumericAnalyzersWithShape1NDInputsAndAxis(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return { - 'min': tft.min(inputs['a'], reduce_instance_dims=False), - 'max': tft.max(inputs['a'], reduce_instance_dims=False), - 'sum': tft.sum(inputs['a'], reduce_instance_dims=False), - 'size': tft.size(inputs['a'], reduce_instance_dims=False), - 'mean': tft.mean(inputs['a'], reduce_instance_dims=False), - 'var': tft.var(inputs['a'], reduce_instance_dims=False), - } - - input_data = [{'a': [[8, 9]]}, {'a': [[1, 2]]}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([1, 2], tf.int64)}) - expected_outputs = { - 'min': np.array([[1, 2]], np.int64), - 'max': np.array([[8, 9]], np.int64), - 'sum': np.array([[9, 11]], np.int64), - 'size': np.array([[2, 2]], np.int64), - 'mean': np.array([[4.5, 5.5]], np.float32), - 'var': np.array([[12.25, 12.25]], np.float32), - } - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - def testNumericAnalyzersWithNDInputs(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return { - 'min': tft.min(inputs['a']), - 'max': tft.max(inputs['a']), - 'sum': tft.sum(inputs['a']), - 'size': tft.size(inputs['a']), - 'mean': tft.mean(inputs['a']), - 'var': tft.var(inputs['a']), - } - - input_data = [ - {'a': [[4, 5], [6, 7]]}, - {'a': [[1, 2], [3, 4]]} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([2, 2], tf.int64)}) - expected_outputs = { - 'min': np.array(1, np.int64), - 'max': np.array(7, np.int64), - 'sum': np.array(32, np.int64), - 'size': np.array(8, np.int64), - 'mean': np.array(4.0, np.float32), - 'var': np.array(3.5, np.float32), - } - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - @tft_unit.named_parameters(*tft_unit.cross_named_parameters( - [ - dict(testcase_name='int64', input_dtype=tf.int64), - dict(testcase_name='float32', input_dtype=tf.float32) - ], - [ - dict(testcase_name='scalar', input_shape=[]), - dict(testcase_name='ND', input_shape=[2, 3]) - ], - [ - dict(testcase_name='elementwise', reduce_instance_dims=False), - dict(testcase_name='not_elementwise', reduce_instance_dims=True) - ])) - def testNumericAnalyzersWithEmptyInputs(self, input_dtype, input_shape, - reduce_instance_dims): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return { - 'min': - tft.min(inputs['a'], reduce_instance_dims=reduce_instance_dims), - 'max': - tft.max(inputs['a'], 
reduce_instance_dims=reduce_instance_dims), - 'sum': - tft.sum(inputs['a'], reduce_instance_dims=reduce_instance_dims), - 'size': - tft.size(inputs['a'], reduce_instance_dims=reduce_instance_dims), - 'mean': - tft.mean(inputs['a'], reduce_instance_dims=reduce_instance_dims), - 'var': - tft.var(inputs['a'], reduce_instance_dims=reduce_instance_dims), - } - - input_data = [] - canonical_dtype = tft_unit.canonical_numeric_dtype(input_dtype) - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature(input_shape, canonical_dtype)}) - input_val_dtype = input_dtype.as_numpy_dtype - output_shape = [] if reduce_instance_dims else input_shape - output_dtype = canonical_dtype.as_numpy_dtype - default_min = np.inf if input_dtype.is_floating else canonical_dtype.max - default_max = -np.inf if input_dtype.is_floating else canonical_dtype.min - expected_outputs = { - 'min': np.full(output_shape, default_min, output_dtype), - 'max': np.full(output_shape, default_max, output_dtype), - 'sum': np.full(output_shape, 0, output_dtype), - 'size': np.full(output_shape, 0, np.int64), - 'mean': np.full(output_shape, 0, np.float32), - 'var': np.full(output_shape, 0, np.float32), - } - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - test_data=[{ - 'a': np.zeros(input_shape, input_val_dtype) - }, { - 'a': np.ones(input_shape, input_val_dtype) - }]) - - def testNumericMeanWithSparseTensorReduceFalseOverflow(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'mean': tft.mean(tf.cast(inputs['sparse'], tf.int32), False)} - - input_data = [ - {'idx': [0, 1], 'val': [1, 1]}, - {'idx': [1, 3], 'val': [2147483647, 3]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'sparse': tf.io.SparseFeature('idx', 'val', tf.int64, 4)}) - expected_outputs = { - 'mean': np.array([1., 1073741824., float('nan'), 3.], np.float32) - } - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - def testStringToTFIDF(self): - def preprocessing_fn(inputs): - inputs_as_ints = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a'])) - out_index, out_values = tft.tfidf( - inputs_as_ints, - tft.get_num_buckets_for_transformed_feature(inputs_as_ints)) - return { - 'tf_idf': out_values, - 'index': out_index, - } - input_data = [{'a': 'hello hello world'}, - {'a': 'hello goodbye hello world'}, - {'a': 'I like pie pie pie'}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - # IDFs - # hello = 1 + log(4/3) = 1.28768 - # world = 1 + log(4/3) - # goodbye = 1 + log(4/2) = 1.69314 - # I = 1 + log(4/2) - # like = 1 + log(4/2) - # pie = 1 + log(4/2) - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - expected_transformed_data = [{ - 'tf_idf': [(2 / 3) * log_4_over_3_plus_1, - (1 / 3) * log_4_over_3_plus_1], - 'index': [0, 2] - }, { - 'tf_idf': [(2 / 4) * log_4_over_3_plus_1, (1 / 4) * log_4_over_3_plus_1, - (1 / 4) * log_4_over_2_plus_1], - 'index': [0, 2, 4] - }, { - 'tf_idf': [(3 / 5) * log_4_over_2_plus_1, (1 / 5) * log_4_over_2_plus_1, - (1 / 5) * log_4_over_2_plus_1], - 'index': [1, 3, 5] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, - expected_transformed_data, expected_metadata) - - def 
testTFIDFNoData(self): - def preprocessing_fn(inputs): - inputs_as_ints = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a'])) - out_index, out_values = tft.tfidf(inputs_as_ints, 6) - return { - 'tf_idf': out_values, - 'index': out_index - } - input_data = [{'a': ''}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - expected_transformed_data = [{'tf_idf': [], 'index': []}] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_transformed_data, - expected_metadata) - - def testStringToTFIDFEmptyDoc(self): - def preprocessing_fn(inputs): - inputs_as_ints = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a'])) - out_index, out_values = tft.tfidf( - inputs_as_ints, - tft.get_num_buckets_for_transformed_feature(inputs_as_ints)) - return { - 'tf_idf': out_values, - 'index': out_index - } - input_data = [{'a': 'hello hello world'}, - {'a': ''}, - {'a': 'hello goodbye hello world'}, - {'a': 'I like pie pie pie'}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - log_5_over_2_plus_1 = 1.91629073187 - log_5_over_3_plus_1 = 1.51082562376 - expected_transformed_data = [{ - 'tf_idf': [(2 / 3) * log_5_over_3_plus_1, - (1 / 3) * log_5_over_3_plus_1], - 'index': [0, 2] - }, { - 'tf_idf': [], - 'index': [] - }, { - 'tf_idf': [(2 / 4) * log_5_over_3_plus_1, (1 / 4) * log_5_over_3_plus_1, - (1 / 4) * log_5_over_2_plus_1], - 'index': [0, 2, 4] - }, { - 'tf_idf': [(3 / 5) * log_5_over_2_plus_1, (1 / 5) * log_5_over_2_plus_1, - (1 / 5) * log_5_over_2_plus_1], - 'index': [1, 3, 5] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, - expected_transformed_data, expected_metadata) - - def testIntToTFIDF(self): - def preprocessing_fn(inputs): - out_index, out_values = tft.tfidf(inputs['a'], 13) - return {'tf_idf': out_values, 'index': out_index} - input_data = [{'a': [2, 2, 0]}, - {'a': [2, 6, 2, 0]}, - {'a': [8, 10, 12, 12, 12]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.VarLenFeature(tf.int64)}) - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - expected_data = [{ - 'tf_idf': [(1 / 3) * log_4_over_3_plus_1, - (2 / 3) * log_4_over_3_plus_1], - 'index': [0, 2] - }, { - 'tf_idf': [(1 / 4) * log_4_over_3_plus_1, (2 / 4) * log_4_over_3_plus_1, - (1 / 4) * log_4_over_2_plus_1], - 'index': [0, 2, 6] - }, { - 'tf_idf': [(1 / 5) * log_4_over_2_plus_1, (1 / 5) * log_4_over_2_plus_1, - (3 / 5) * log_4_over_2_plus_1], - 'index': [8, 10, 12] - }] - expected_schema = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_schema) - - def testIntToTFIDFWithoutSmoothing(self): - def preprocessing_fn(inputs): - out_index, out_values = tft.tfidf(inputs['a'], 13, smooth=False) - return {'tf_idf': out_values, 'index': out_index} - input_data = [{'a': [2, 2, 0]}, - {'a': [2, 6, 2, 0]}, - {'a': [8, 10, 12, 12, 12]}, - ] - input_metadata = 
tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.VarLenFeature(tf.int64)}) - log_3_over_2_plus_1 = 1.4054651081 - log_3_plus_1 = 2.0986122886 - expected_data = [{ - 'tf_idf': [(1 / 3) * log_3_over_2_plus_1, - (2 / 3) * log_3_over_2_plus_1], - 'index': [0, 2] - }, { - 'tf_idf': [(1 / 4) * log_3_over_2_plus_1, (2 / 4) * log_3_over_2_plus_1, - (1 / 4) * log_3_plus_1], - 'index': [0, 2, 6] - }, { - 'tf_idf': [(1 / 5) * log_3_plus_1, (1 / 5) * log_3_plus_1, - (3 / 5) * log_3_plus_1], - 'index': [8, 10, 12] - }] - expected_schema = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_data, - expected_schema) - - def testTFIDFWithOOV(self): - def preprocessing_fn(inputs): - inputs_as_ints = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a']), top_k=3) - out_index, out_values = tft.tfidf( - inputs_as_ints, - tft.get_num_buckets_for_transformed_feature(inputs_as_ints) + 1) - return { - 'tf_idf': out_values, - 'index': out_index - } - input_data = [{'a': 'hello hello world'}, - {'a': 'hello goodbye hello world'}, - {'a': 'I like pie pie pie'}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - # IDFs - # hello = 1 + log(3/3) = 1 - # pie = 1+ log(3/2) = 1.4054651081 - # world = 1 + log(3/3) = 1 - # OOV - goodbye, I, like = 1 + log(3/3) = 1 - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - expected_transformed_data = [{ - 'tf_idf': [(2 / 3) * log_4_over_3_plus_1, - (1 / 3) * log_4_over_3_plus_1], - 'index': [0, 2] - }, { - 'tf_idf': [(2 / 4) * log_4_over_3_plus_1, (1 / 4) * log_4_over_3_plus_1, - (1 / 4) * log_4_over_3_plus_1], - 'index': [0, 2, 3] - }, { - 'tf_idf': [(3 / 5) * log_4_over_2_plus_1, - (2 / 5) * log_4_over_3_plus_1], - 'index': [1, 3] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, preprocessing_fn, expected_transformed_data, - expected_metadata) - - def testTFIDFWithNegatives(self): - def preprocessing_fn(inputs): - out_index, out_values = tft.tfidf(inputs['a'], 14) - return { - 'tf_idf': out_values, - 'index': out_index - } - input_data = [{'a': [2, 2, -4]}, - {'a': [2, 6, 2, -1]}, - {'a': [8, 10, 12, 12, 12]}, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.VarLenFeature(tf.int64)}) - - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - # NOTE: -4 mod 14 = 10 - expected_transformed_data = [{ - 'tf_idf': [(2 / 3) * log_4_over_3_plus_1, - (1 / 3) * log_4_over_3_plus_1], - 'index': [2, 10] - }, { - 'tf_idf': [(2 / 4) * log_4_over_3_plus_1, (1 / 4) * log_4_over_2_plus_1, - (1 / 4) * log_4_over_2_plus_1], - 'index': [2, 6, 13] - }, { - 'tf_idf': [(1 / 5) * log_4_over_2_plus_1, (1 / 5) * log_4_over_3_plus_1, - (3 / 5) * log_4_over_2_plus_1], - 'index': [8, 10, 12] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'tf_idf': tf.io.VarLenFeature(tf.float32), - 'index': tf.io.VarLenFeature(tf.int64) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, - expected_transformed_data, - expected_metadata) - - def _get_dfidf_experimental_preprocessing_fn(self, - is_str_input: bool = False, - smooth: bool = True, - add_baseline: 
bool = True, - vocab_size: Optional[int] = None, - top_k: Optional[int] = None): - """Returns proper preprocessing fn for df/idf under tft.experimental.""" - - def preprocessing_fn(inputs): - if is_str_input: - inputs_as_ints = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a']), top_k=top_k) - else: - inputs_as_ints = inputs['a'] - - if vocab_size is None: - computed_vocab_size = tft.get_num_buckets_for_transformed_feature( - inputs_as_ints) - else: - computed_vocab_size = vocab_size - - out_df_counts = tft.experimental.document_frequency( - inputs_as_ints, computed_vocab_size) - out_idf_weights = tft.experimental.idf( - inputs_as_ints, - computed_vocab_size, - smooth=smooth, - add_baseline=add_baseline) - return {'df': out_df_counts, 'idf': out_idf_weights} - - return preprocessing_fn - - @tft_unit.named_parameters( - dict( - testcase_name='StrInputSmoothBasaeline', - smooth=True, - add_baseline=True), - dict( - testcase_name='StrInputSmoothWOBasaeline', - smooth=True, - add_baseline=False), - dict( - testcase_name='StrInputNonSmoothBasaeline', - smooth=False, - add_baseline=True), - dict( - testcase_name='StrInputNonSmoothWOBasaeline', - smooth=False, - add_baseline=False), - ) - def testStringToDFIDFExperimental(self, smooth, add_baseline): - input_data = [{ - 'a': 'hello hello world pie' - }, { - 'a': 'hello goodbye world pie' - }, { - 'a': 'I like pie pie' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - # corpus_size = 3 - # DF smooth base IDF non-smooth base IDF with baseline - # hello 2 log(4/3) log(3/2) * + 1 - # world 2 log(4/3) log(3/2) * + 1 - # goodbye 1 log(4/2) log3 * + 1 - # I 1 log(4/2) log3 * + 1 - # like 1 log(4/2) log3 * + 1 - # pie 3 log(4/4) = 0 log(3/3)=0 * + 1 - log_4_over_2 = 0.69314718056 - log_4_over_3 = 0.28768207245 - log_3_over_2 = 0.4054651081 - log_3 = 1.09861228867 - - if smooth: - base_idf1, base_idf2 = log_4_over_3, log_4_over_2 - else: - base_idf1, base_idf2 = log_3_over_2, log_3 - - baseline = 1.0 if add_baseline else 0.0 - - expected_transformed_data = [{ - 'df': [2, 2, 2, 3], - 'idf': [ - baseline + base_idf1, baseline + base_idf1, baseline + base_idf1, - baseline - ] - }, { - 'df': [2, 1, 2, 3], - 'idf': [ - baseline + base_idf1, baseline + base_idf2, baseline + base_idf1, - baseline - ] - }, { - 'df': [1, 1, 3, 3], - 'idf': [baseline + base_idf2, baseline + base_idf2, baseline, baseline] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, - self._get_dfidf_experimental_preprocessing_fn( - is_str_input=True, smooth=smooth, add_baseline=add_baseline), - expected_transformed_data, expected_metadata) - - def testDFIDFExperimentalNoData(self): - - input_data = [{'a': ''}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - # Input data is completely empty so need to specify vocab_size explicitly - preprocessing_fn = self._get_dfidf_experimental_preprocessing_fn( - is_str_input=True, vocab_size=6) - - expected_transformed_data = [{'df': [], 'idf': []}] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, - expected_transformed_data, - expected_metadata) - - def 
testStringToDFIDFExperimentalEmptyDoc(self): - input_data = [{ - 'a': 'hello hello world' - }, { - 'a': '' - }, { - 'a': 'hello goodbye world' - }, { - 'a': 'I like pie' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - log_5_over_2_plus_1 = 1.91629073187 - log_5_over_3_plus_1 = 1.51082562376 - expected_transformed_data = [{ - 'df': [2, 2, 2], - 'idf': [log_5_over_3_plus_1, log_5_over_3_plus_1, log_5_over_3_plus_1] - }, { - 'df': [], - 'idf': [] - }, { - 'df': [2, 1, 2], - 'idf': [log_5_over_3_plus_1, log_5_over_2_plus_1, log_5_over_3_plus_1] - }, { - 'df': [1, 1, 1], - 'idf': [log_5_over_2_plus_1, log_5_over_2_plus_1, log_5_over_2_plus_1] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, - self._get_dfidf_experimental_preprocessing_fn(is_str_input=True), - expected_transformed_data, expected_metadata) - - def testDFIDFExperimentalWithOOV(self): - - input_data = [{ - 'a': 'hello world hi' - }, { - 'a': 'hello goodbye world' - }, { - 'a': 'I like pie pie' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - preprocessing_fn_w_oov = self._get_dfidf_experimental_preprocessing_fn( - is_str_input=True, vocab_size=4, top_k=3) - # smoothed base IDFs - # hello = log(4/3) - # pie = log(4/2) - # world = log(4/3) - # OOV - hi, goodbye, I, like = log(4/4) = 0 (OOV in all 3 out of 3 docs) - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - expected_transformed_data = [{ - 'df': [2, 2, 3], - 'idf': [log_4_over_3_plus_1, log_4_over_3_plus_1, 1.0] - }, { - 'df': [2, 3, 2], - 'idf': [log_4_over_3_plus_1, 1.0, log_4_over_3_plus_1] - }, { - 'df': [3, 3, 1, 1], - 'idf': [1.0, 1.0, log_4_over_2_plus_1, log_4_over_2_plus_1] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn_w_oov, - expected_transformed_data, - expected_metadata) - - @tft_unit.named_parameters( - dict( - testcase_name='IntInputSmoothBasaeline', - smooth=True, - add_baseline=True), - dict( - testcase_name='IntInputSmoothWOBasaeline', - smooth=True, - add_baseline=False), - dict( - testcase_name='IntInputNoneSmoothBasaeline', - smooth=False, - add_baseline=True), - dict( - testcase_name='IntInputNoneSmoothWOBasaeline', - smooth=False, - add_baseline=False)) - def testIntToDFIDFExpeirmental(self, smooth, add_baseline): - - input_data = [ - { - 'a': [2, 2, 0] - }, - { - 'a': [2, 6, 2, 0] - }, - { - 'a': [8, 10, 12, 12, 12] - }, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.VarLenFeature(tf.int64)}) - log_4_over_2 = 0.69314718056 - log_4_over_3 = 0.28768207245 - log_3 = 1.09861228867 - log_3_over_2 = 0.4054651081 - - if smooth: - idf1, idf2 = log_4_over_2, log_4_over_3 - else: - idf1, idf2 = log_3, log_3_over_2 - - if add_baseline: - idf1 += 1.0 - idf2 += 1.0 - - expected_data = [{ - 'df': [2, 2, 2], - 'idf': [idf2, idf2, idf2], - }, { - 'df': [2, 1, 2, 2], - 'idf': [idf2, idf1, idf2, idf2], - }, { - 'df': [1, 1, 1, 1, 1], - 'idf': [idf1, idf1, idf1, idf1, idf1] - }] - expected_schema = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - 
self.assertAnalyzeAndTransformResults( - input_data, input_metadata, - self._get_dfidf_experimental_preprocessing_fn( - vocab_size=13, smooth=smooth, add_baseline=add_baseline), - expected_data, expected_schema) - - def testDFIDFExperimentalWithNegatives(self): - input_data = [ - { - 'a': [2, 2, -4] - }, - { - 'a': [2, 6, 2, -1] - }, - { - 'a': [8, 10, 12, 12, 12] - }, - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.VarLenFeature(tf.int64)}) - - log_4_over_2_plus_1 = 1.69314718056 - log_4_over_3_plus_1 = 1.28768207245 - # NOTE: -4 mod 14 = 10, -1 mod 14 = 13 - expected_transformed_data = [{ - 'df': [2, 2, 2], - 'idf': [log_4_over_3_plus_1, log_4_over_3_plus_1, log_4_over_3_plus_1] - }, { - 'df': [2, 1, 2, 1], - 'idf': [ - log_4_over_3_plus_1, log_4_over_2_plus_1, log_4_over_3_plus_1, - log_4_over_2_plus_1 + self.assertAnalyzerOutputs( + input_data, input_metadata, analyzer_fn, expected_outputs + ) + + @unittest.skipIf( + not common.IS_ANNOTATIONS_PB_AVAILABLE, "Schema annotations are not available" + ) + def testSavedModelWithAnnotations(self): + """Test serialization/deserialization as a saved model with annotations.""" + self._SkipIfOutputRecordBatches() + + def preprocessing_fn(inputs): + # Bucketization applies annotations to the output schema + return { + "x_bucketized": tft.bucketize(inputs["x"], num_buckets=4), + "y_vocab": tft.compute_and_apply_vocabulary(inputs["y"]), + } + + input_data = [ + { + "x": 1, + "y": "foo", + }, + { + "x": 2, + "y": "bar", + }, + { + "x": 3, + "y": "foo", + }, + { + "x": 4, + "y": "foo", + }, ] - }, { - 'df': [1, 2, 1, 1, 1], - 'idf': [ - log_4_over_2_plus_1, log_4_over_3_plus_1, log_4_over_2_plus_1, - log_4_over_2_plus_1, log_4_over_2_plus_1 + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.string), + } + ) + temp_dir = self.get_temp_dir() + # Force a batch size of 1 to ensure that occurences are correctly aggregated + # across batches when computing the total vocabulary size. + with tft_beam.Context(temp_dir=temp_dir, desired_batch_size=1): + transform_fn = (input_data, input_metadata) | tft_beam.AnalyzeDataset( + preprocessing_fn + ) + # Write transform_fn to serialize annotation collections to SavedModel + _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) + + # Ensure that the annotations survive the round trip to SavedModel. + tf_transform_output = tft.TFTransformOutput(temp_dir) + schema = tf_transform_output.transformed_metadata._schema + self.assertLen(schema.feature, 2) + for feature in schema.feature: + if feature.name == "x_bucketized": + self.assertLen(feature.annotation.extra_metadata, 1) + for annotation in feature.annotation.extra_metadata: + message = annotations_pb2.BucketBoundaries() + annotation.Unpack(message) + self.assertAllClose(list(message.boundaries), [2, 3, 4]) + elif feature.name == "y_vocab": + self.assertLen(feature.annotation.extra_metadata, 0) + else: + raise ValueError(f"Unexpected feature with metadata: {feature.name}") + # Vocabularies create a top-level schema annotation for each vocab file. 
+ self.assertLen(schema.annotation.extra_metadata, 1) + message = annotations_pb2.VocabularyMetadata() + annotation = schema.annotation.extra_metadata[0] + annotation.Unpack(message) + self.assertEqual(message.unfiltered_vocabulary_size, 2) + + @unittest.skipIf( + not common.IS_ANNOTATIONS_PB_AVAILABLE, "Schema annotations are not available" + ) + def testSavedModelWithGlobalAnnotations(self): + self._SkipIfOutputRecordBatches() + + def preprocessing_fn(inputs): + # Add some arbitrary annotation data at the global schema level. + boundaries = tf.constant([[1.0]]) + message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name + sizes = tf.expand_dims([tf.size(boundaries)], axis=0) + message_proto = tf.raw_ops.EncodeProto( + sizes=sizes, + values=[tf.cast(boundaries, tf.float32)], + field_names=["boundaries"], + message_type=message_type, + )[0] + type_url = os.path.join("type.googleapis.com", message_type) + schema_inference.annotate(type_url, message_proto) + return { + "x_scaled": tft.scale_by_min_max(inputs["x"]), + } + + input_data = [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + } + ) + temp_dir = self.get_temp_dir() + with tft_beam.Context(temp_dir=temp_dir): + transform_fn = (input_data, input_metadata) | tft_beam.AnalyzeDataset( + preprocessing_fn + ) + # Write transform_fn to serialize annotation collections to SavedModel + _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) + + # Ensure that global annotations survive the round trip to SavedModel. + tf_transform_output = tft.TFTransformOutput(temp_dir) + schema = tf_transform_output.transformed_metadata._schema + self.assertLen(schema.annotation.extra_metadata, 1) + for annotation in schema.annotation.extra_metadata: + message = annotations_pb2.BucketBoundaries() + annotation.Unpack(message) + self.assertAllClose(list(message.boundaries), [1]) + + def testPipelineAPICounters(self): + self._SkipIfOutputRecordBatches() + + def preprocessing_fn(inputs): + _ = tft.vocabulary(inputs["a"]) + return { + "a_int": tft.compute_and_apply_vocabulary(inputs["a"]), + "x_scaled": tft.scale_to_0_1(inputs["x"]), + "y_scaled": tft.scale_to_0_1(inputs["y"]), + } + + with self._makeTestPipeline() as pipeline: + input_data = pipeline | "CreateTrainingData" >> beam.Create( + [{"x": 4, "y": 5, "a": "hello"}, {"x": 1, "y": 3, "a": "world"}] + ) + metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.FixedLenFeature([], tf.float32), + "a": tf.io.FixedLenFeature([], tf.string), + } + ) + with tft_beam.Context(temp_dir=self.get_temp_dir()): + _ = ( + input_data, + metadata, + ) | "AnalyzeAndTransformDataset" >> tft_beam.AnalyzeAndTransformDataset( + preprocessing_fn + ) + + metrics = pipeline.metrics + self.assertMetricsCounterEqual(metrics, "tft_analyzer_vocabulary", 1) + self.assertMetricsCounterEqual(metrics, "tft_mapper_scale_to_0_1", 2) + self.assertMetricsCounterEqual( + metrics, "tft_mapper_compute_and_apply_vocabulary", 1 + ) + # compute_and_apply_vocabulary implicitly calls apply_vocabulary. + # We check that that call is not logged. 
+ self.assertMetricsCounterEqual(metrics, "tft_mapper_apply_vocabulary", 0) + + for namespace in ( + "tfx.Transform.analyze_input_tensors", + "tfx.Transform.transform_input_tensors", + "tfx.Transform.transform_output_tensors", + ): + self.assertMetricsCounterEqual( + metrics, "dense_tensor", 3, namespaces_list=[namespace] + ) + + def testNumBytesCounter(self): + self._SkipIfOutputRecordBatches() + + test_data = [ + pa.RecordBatch.from_arrays( + [ + pa.array([[4]], type=pa.large_list(pa.float32())), + pa.array([[5]], type=pa.large_list(pa.float32())), + pa.array([["hello"]], type=pa.large_list(pa.large_binary())), + ], + ["x", "y", "a"], + ), + pa.RecordBatch.from_arrays( + [ + pa.array([[1]], type=pa.large_list(pa.float32())), + pa.array([[3]], type=pa.large_list(pa.float32())), + pa.array([["world"]], type=pa.large_list(pa.large_binary())), + ], + ["x", "y", "a"], + ), ] - }] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'df': tf.io.VarLenFeature(tf.int64), - 'idf': tf.io.VarLenFeature(tf.float32) - }) - self.assertAnalyzeAndTransformResults( - input_data, input_metadata, - self._get_dfidf_experimental_preprocessing_fn(vocab_size=14), - expected_transformed_data, expected_metadata) - - def testCovarianceTwoDimensions(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'y': tft.covariance(inputs['x'], dtype=tf.float32)} - - input_data = [{'x': x} for x in [[0, 0], [4, 0], [2, -2], [2, 2]]] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([2], tf.float32)}) - expected_outputs = {'y': np.array([[2, 0], [0, 2]], np.float32)} - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - def testCovarianceOneDimension(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'y': tft.covariance(inputs['x'], dtype=tf.float32)} - - input_data = [{'x': x} for x in [[0], [2], [4], [6]]] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.float32)}) - expected_outputs = {'y': np.array([[5]], np.float32)} - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - def testCovarianceOneDimensionWithEmptyInputs(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'y': tft.covariance(inputs['x'], dtype=tf.float32)} - - input_data = [] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([1], tf.float32)}) - test_data = [{'x': [1]}, {'x': [2]}] - expected_outputs = {'y': np.array([[0]], dtype=np.float32)} - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - test_data=test_data) - - def testPCAThreeToTwoDimensions(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'y': tft.pca(inputs['x'], 2, dtype=tf.float32)} - - input_data = [{'x': x} - for x in [[0, 0, 1], [4, 0, 1], [2, -1, 1], [2, 1, 1]]] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([3], tf.float32)}) - expected_outputs = {'y': np.array([[1, 0], [0, 1], [0, 0]], np.float32)} - self.assertAnalyzerOutputs( - input_data, input_metadata, analyzer_fn, expected_outputs) - - def testPCAThreeToTwoDimensionsWithEmptyInputs(self): - self._SkipIfOutputRecordBatches() - - def analyzer_fn(inputs): - return {'y': tft.pca(inputs['x'], 2, dtype=tf.float32)} - - input_data = [] - test_data = [{'x': x} for x in - [[0, 0, 1], [4, 0, 1], [2, -1, 1], [2, 1, 1]]] - 
input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([3], tf.float32)}) - expected_outputs = {'y': np.array([[1, 0], [0, 1], [0, 0]], np.float32)} - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - test_data=test_data) - - class _SumCombiner(tft_beam.experimental.PTransformAnalyzer): - - def __init__(self): - super().__init__() - self.base_temp_dir_in_expand = None - - def _extract_outputs(self, sums): - return [beam.pvalue.TaggedOutput('0', sums[0]), - beam.pvalue.TaggedOutput('1', sums[1])] - - def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]): - self.base_temp_dir_in_expand = self.base_temp_dir - return (pcoll - | beam.FlatMap(lambda batches: list(zip(*batches))) - | - beam.CombineGlobally(lambda values: np.sum(list(values), axis=0)) - | beam.FlatMap(self._extract_outputs).with_outputs('0', '1')) - - def testPTransformAnalyzer(self): - self._SkipIfOutputRecordBatches() - - sum_combiner = self._SumCombiner() - - def analyzer_fn(inputs): - outputs = tft.experimental.ptransform_analyzer([inputs['x'], inputs['y']], - sum_combiner, - [tf.int64, tf.int64], - [[], []]) - return {'x_sum': outputs[0], 'y_sum': outputs[1]} - - input_data = [{'x': 1, 'y': i} for i in range(100)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.int64), - 'y': tf.io.FixedLenFeature([], tf.int64) - }) - expected_outputs = { - 'x_sum': np.array(100, np.int64), - 'y_sum': np.array(4950, np.int64) - } - self.assertIsNone(sum_combiner.base_temp_dir_in_expand) - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - self.assertIsNotNone(sum_combiner.base_temp_dir_in_expand) - self.assertStartsWith(sum_combiner.base_temp_dir_in_expand, - self.get_temp_dir()) - - @tft_unit.named_parameters( - dict( - testcase_name='ArrayOutput', - output_fn=lambda x: np.array(x, np.int64)), - dict(testcase_name='ListOutput', output_fn=list), - ) - def testPTransformAnalyzerMultiDimOutput(self, output_fn): - self._SkipIfOutputRecordBatches() - - class _SimpleSumCombiner(tft_beam.experimental.PTransformAnalyzer): - - def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]): - return ( - pcoll - | beam.FlatMap(lambda batches: list(zip(*batches))) - | beam.CombineGlobally(lambda values: np.sum(list(values), axis=0)) - | beam.combiners.ToList() - | beam.Map(output_fn)) - - sum_combiner = _SimpleSumCombiner() - - def analyzer_fn(inputs): - outputs, = tft.experimental.ptransform_analyzer( - [inputs['x'], inputs['y']], sum_combiner, [tf.int64], [[1, 2]]) - return {'x_y_sums': outputs} - - input_data = [{'x': 1, 'y': i} for i in range(100)] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.int64), - 'y': tf.io.FixedLenFeature([], tf.int64) - }) - expected_outputs = { - 'x_y_sums': np.array([[100, 4950]], np.int64), - } - self.assertAnalyzerOutputs(input_data, input_metadata, analyzer_fn, - expected_outputs) - - @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE, - 'Schema annotations are not available') - def testSavedModelWithAnnotations(self): - """Test serialization/deserialization as a saved model with annotations.""" - self._SkipIfOutputRecordBatches() - - def preprocessing_fn(inputs): - # Bucketization applies annotations to the output schema - return { - 'x_bucketized': tft.bucketize(inputs['x'], num_buckets=4), - 'y_vocab': tft.compute_and_apply_vocabulary(inputs['y']), - } - - input_data = [{ - 
'x': 1, - 'y': 'foo', - }, { - 'x': 2, - 'y': 'bar', - }, { - 'x': 3, - 'y': 'foo', - }, { - 'x': 4, - 'y': 'foo', - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.string), - }) - temp_dir = self.get_temp_dir() - # Force a batch size of 1 to ensure that occurences are correctly aggregated - # across batches when computing the total vocabulary size. - with tft_beam.Context(temp_dir=temp_dir, desired_batch_size=1): - transform_fn = ((input_data, input_metadata) - | tft_beam.AnalyzeDataset(preprocessing_fn)) - # Write transform_fn to serialize annotation collections to SavedModel - _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) - - # Ensure that the annotations survive the round trip to SavedModel. - tf_transform_output = tft.TFTransformOutput(temp_dir) - schema = tf_transform_output.transformed_metadata._schema - self.assertLen(schema.feature, 2) - for feature in schema.feature: - if feature.name == 'x_bucketized': - self.assertLen(feature.annotation.extra_metadata, 1) - for annotation in feature.annotation.extra_metadata: - message = annotations_pb2.BucketBoundaries() - annotation.Unpack(message) - self.assertAllClose(list(message.boundaries), [2, 3, 4]) - elif feature.name == 'y_vocab': - self.assertLen(feature.annotation.extra_metadata, 0) - else: - raise ValueError('Unexpected feature with metadata: {}'.format( - feature.name)) - # Vocabularies create a top-level schema annotation for each vocab file. - self.assertLen(schema.annotation.extra_metadata, 1) - message = annotations_pb2.VocabularyMetadata() - annotation = schema.annotation.extra_metadata[0] - annotation.Unpack(message) - self.assertEqual(message.unfiltered_vocabulary_size, 2) - - @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE, - 'Schema annotations are not available') - def testSavedModelWithGlobalAnnotations(self): - self._SkipIfOutputRecordBatches() - - def preprocessing_fn(inputs): - # Add some arbitrary annotation data at the global schema level. - boundaries = tf.constant([[1.0]]) - message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name - sizes = tf.expand_dims([tf.size(boundaries)], axis=0) - message_proto = tf.raw_ops.EncodeProto( - sizes=sizes, values=[tf.cast(boundaries, tf.float32)], - field_names=['boundaries'], message_type=message_type)[0] - type_url = os.path.join('type.googleapis.com', message_type) - schema_inference.annotate(type_url, message_proto) - return { - 'x_scaled': tft.scale_by_min_max(inputs['x']), - } - - input_data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - }) - temp_dir = self.get_temp_dir() - with tft_beam.Context(temp_dir=temp_dir): - transform_fn = ((input_data, input_metadata) - | tft_beam.AnalyzeDataset(preprocessing_fn)) - # Write transform_fn to serialize annotation collections to SavedModel - _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) - - # Ensure that global annotations survive the round trip to SavedModel. 
- tf_transform_output = tft.TFTransformOutput(temp_dir) - schema = tf_transform_output.transformed_metadata._schema - self.assertLen(schema.annotation.extra_metadata, 1) - for annotation in schema.annotation.extra_metadata: - message = annotations_pb2.BucketBoundaries() - annotation.Unpack(message) - self.assertAllClose(list(message.boundaries), [1]) - - def testPipelineAPICounters(self): - self._SkipIfOutputRecordBatches() - - def preprocessing_fn(inputs): - _ = tft.vocabulary(inputs['a']) - return { - 'a_int': tft.compute_and_apply_vocabulary(inputs['a']), - 'x_scaled': tft.scale_to_0_1(inputs['x']), - 'y_scaled': tft.scale_to_0_1(inputs['y']) - } - - with self._makeTestPipeline() as pipeline: - input_data = pipeline | 'CreateTrainingData' >> beam.Create([{ - 'x': 4, - 'y': 5, - 'a': 'hello' - }, { - 'x': 1, - 'y': 3, - 'a': 'world' - }]) - metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.FixedLenFeature([], tf.float32), - 'a': tf.io.FixedLenFeature([], tf.string) - }) - with tft_beam.Context(temp_dir=self.get_temp_dir()): - _ = ((input_data, metadata) - | 'AnalyzeAndTransformDataset' >> - tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - - metrics = pipeline.metrics - self.assertMetricsCounterEqual(metrics, 'tft_analyzer_vocabulary', 1) - self.assertMetricsCounterEqual(metrics, 'tft_mapper_scale_to_0_1', 2) - self.assertMetricsCounterEqual(metrics, - 'tft_mapper_compute_and_apply_vocabulary', 1) - # compute_and_apply_vocabulary implicitly calls apply_vocabulary. - # We check that that call is not logged. - self.assertMetricsCounterEqual(metrics, 'tft_mapper_apply_vocabulary', 0) - - for namespace in ('tfx.Transform.analyze_input_tensors', - 'tfx.Transform.transform_input_tensors', - 'tfx.Transform.transform_output_tensors'): - self.assertMetricsCounterEqual( - metrics, 'dense_tensor', 3, namespaces_list=[namespace]) - - def testNumBytesCounter(self): - self._SkipIfOutputRecordBatches() - - test_data = [ - pa.RecordBatch.from_arrays([ - pa.array([[4]], type=pa.large_list(pa.float32())), - pa.array([[5]], type=pa.large_list(pa.float32())), - pa.array([['hello']], type=pa.large_list(pa.large_binary())) - ], ['x', 'y', 'a']), - pa.RecordBatch.from_arrays([ - pa.array([[1]], type=pa.large_list(pa.float32())), - pa.array([[3]], type=pa.large_list(pa.float32())), - pa.array([['world']], type=pa.large_list(pa.large_binary())) - ], ['x', 'y', 'a']) - ] - tensor_representations = { - name: text_format.Parse( - f'dense_tensor {{ column_name: \"{name}\" shape {{}} }}', - schema_pb2.TensorRepresentation()) for name in ('x', 'y', 'a') - } - expected_input_size = sum(rb.nbytes for rb in test_data) - - def preprocessing_fn(inputs): - _ = tft.vocabulary(inputs['a']) - return { - 'a_int': tft.compute_and_apply_vocabulary(inputs['a']), - 'x_scaled': tft.scale_to_0_1(inputs['x']), - 'y_scaled': tft.scale_to_0_1(inputs['y']) - } - - with self._makeTestPipeline() as pipeline: - input_data = pipeline | 'CreateTrainingData' >> beam.Create(test_data) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - test_data[0].schema, tensor_representations) - - with tft_beam.Context(temp_dir=self.get_temp_dir()): - _ = ((input_data, tensor_adapter_config) - | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn)) - - metrics = pipeline.metrics - self.assertMetricsCounterEqual(metrics, 'analysis_input_bytes', - expected_input_size) - - def testHandleBatchError(self): - self._SkipIfOutputRecordBatches() - - def preprocessing_fn(inputs): - 
return {'x_scaled': tft.scale_to_0_1(inputs['x'])} - - # Exception type depends on the runner being used. - with self.assertRaisesRegex( - (RuntimeError, ValueError, TypeError), '.*list'): - # TODO(b/149997088): Remove this explicit use of DirectRunner. - with beam.Pipeline() as pipeline: - metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.float32), - }) - - input_data = pipeline | 'CreateTrainingData' >> beam.Create([{ - 'x': 1 - }, { - 'x': [4, 1] - }]) - with tft_beam.Context(temp_dir=self.get_temp_dir()): - _ = ((input_data, metadata) - | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn)) - - def testPassthroughKeys(self): - passthrough_key1 = '__passthrough__' - passthrough_key2 = '__passthrough_not_in_input_record_batch__' - - def preprocessing_fn(inputs): - self.assertNotIn(passthrough_key1, inputs) - self.assertNotIn(passthrough_key2, inputs) - return {'x_scaled': tft.scale_to_0_1(inputs['x'])} - - x_data = [0., 1., 2.] - passthrough_data = [1, None, 3] - input_record_batch = pa.RecordBatch.from_arrays([ - pa.array([[x] for x in x_data], type=pa.list_(pa.float32())), - pa.array([None if p is None else [p] for p in passthrough_data], - type=pa.list_(pa.int64())), - ], ['x', passthrough_key1]) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - input_record_batch.schema, - {'x': text_format.Parse( - 'dense_tensor { column_name: "x" shape {} }', - schema_pb2.TensorRepresentation())}) - - with self._makeTestPipeline() as pipeline: - input_data = ( - pipeline | beam.Create([input_record_batch])) - with tft_beam.Context( - temp_dir=self.get_temp_dir(), - passthrough_keys=set([passthrough_key1, passthrough_key2])): - (transformed_data, - _), _ = ((input_data, tensor_adapter_config) - | tft_beam.AnalyzeAndTransformDataset( - preprocessing_fn, - output_record_batches=self._OutputRecordBatches())) - expected_data = [{'x_scaled': x / 2.0, passthrough_key1: p} - for x, p in zip(x_data, passthrough_data)] - beam_test_util.assert_that( - transformed_data, self._MakeTransformOutputAssertFn(expected_data)) - - def test3dSparseWithTFXIO(self): - x_data = [0., 1., 2.] 
- x_idx0 = [0, 0, 1] - x_idx1 = [0, 0, 1] - input_record_batch = pa.RecordBatch.from_arrays([ - pa.array([[x] for x in x_idx0], type=pa.list_(pa.int64())), - pa.array([[x] for x in x_idx1], type=pa.list_(pa.int64())), - pa.array([[x] for x in x_data], type=pa.list_(pa.float32())), - ], ['x_idx0', 'x_idx1', 'x_val']) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - input_record_batch.schema, { - 'x': - text_format.Parse( + tensor_representations = { + name: text_format.Parse( + f'dense_tensor {{ column_name: "{name}" shape {{}} }}', + schema_pb2.TensorRepresentation(), + ) + for name in ("x", "y", "a") + } + expected_input_size = sum(rb.nbytes for rb in test_data) + + def preprocessing_fn(inputs): + _ = tft.vocabulary(inputs["a"]) + return { + "a_int": tft.compute_and_apply_vocabulary(inputs["a"]), + "x_scaled": tft.scale_to_0_1(inputs["x"]), + "y_scaled": tft.scale_to_0_1(inputs["y"]), + } + + with self._makeTestPipeline() as pipeline: + input_data = pipeline | "CreateTrainingData" >> beam.Create(test_data) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + test_data[0].schema, tensor_representations + ) + + with tft_beam.Context(temp_dir=self.get_temp_dir()): + _ = ( + input_data, + tensor_adapter_config, + ) | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(preprocessing_fn) + + metrics = pipeline.metrics + self.assertMetricsCounterEqual( + metrics, "analysis_input_bytes", expected_input_size + ) + + def testHandleBatchError(self): + self._SkipIfOutputRecordBatches() + + def preprocessing_fn(inputs): + return {"x_scaled": tft.scale_to_0_1(inputs["x"])} + + # Exception type depends on the runner being used. + with self.assertRaisesRegex((RuntimeError, ValueError, TypeError), ".*list"): + # TODO(b/149997088): Remove this explicit use of DirectRunner. 
+ with beam.Pipeline() as pipeline: + metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.float32), + } + ) + + input_data = pipeline | "CreateTrainingData" >> beam.Create( + [{"x": 1}, {"x": [4, 1]}] + ) + with tft_beam.Context(temp_dir=self.get_temp_dir()): + _ = ( + input_data, + metadata, + ) | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(preprocessing_fn) + + def testPassthroughKeys(self): + passthrough_key1 = "__passthrough__" + passthrough_key2 = "__passthrough_not_in_input_record_batch__" + + def preprocessing_fn(inputs): + self.assertNotIn(passthrough_key1, inputs) + self.assertNotIn(passthrough_key2, inputs) + return {"x_scaled": tft.scale_to_0_1(inputs["x"])} + + x_data = [0.0, 1.0, 2.0] + passthrough_data = [1, None, 3] + input_record_batch = pa.RecordBatch.from_arrays( + [ + pa.array([[x] for x in x_data], type=pa.list_(pa.float32())), + pa.array( + [None if p is None else [p] for p in passthrough_data], + type=pa.list_(pa.int64()), + ), + ], + ["x", passthrough_key1], + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + input_record_batch.schema, + { + "x": text_format.Parse( + 'dense_tensor { column_name: "x" shape {} }', + schema_pb2.TensorRepresentation(), + ) + }, + ) + + with self._makeTestPipeline() as pipeline: + input_data = pipeline | beam.Create([input_record_batch]) + with tft_beam.Context( + temp_dir=self.get_temp_dir(), + passthrough_keys=set([passthrough_key1, passthrough_key2]), + ): + (transformed_data, _), _ = ( + input_data, + tensor_adapter_config, + ) | tft_beam.AnalyzeAndTransformDataset( + preprocessing_fn, output_record_batches=self._OutputRecordBatches() + ) + expected_data = [ + {"x_scaled": x / 2.0, passthrough_key1: p} + for x, p in zip(x_data, passthrough_data) + ] + beam_test_util.assert_that( + transformed_data, self._MakeTransformOutputAssertFn(expected_data) + ) + + def test3dSparseWithTFXIO(self): + x_data = [0.0, 1.0, 2.0] + x_idx0 = [0, 0, 1] + x_idx1 = [0, 0, 1] + input_record_batch = pa.RecordBatch.from_arrays( + [ + pa.array([[x] for x in x_idx0], type=pa.list_(pa.int64())), + pa.array([[x] for x in x_idx1], type=pa.list_(pa.int64())), + pa.array([[x] for x in x_data], type=pa.list_(pa.float32())), + ], + ["x_idx0", "x_idx1", "x_val"], + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + input_record_batch.schema, + { + "x": text_format.Parse( """ sparse_tensor { index_column_names: ["x_idx0", "x_idx1"] @@ -4391,39 +4857,45 @@ def test3dSparseWithTFXIO(self): size: 5 } } - }""", schema_pb2.TensorRepresentation()) - }) - expected_data = [ - { # pylint: disable=g-complex-comprehension - 'x$sparse_values': x, - 'x$sparse_indices_0': idx0, - 'x$sparse_indices_1': idx1 - } for idx0, idx1, x in zip(x_idx0, x_idx1, x_data) - ] - - materialize_path = os.path.join(self.get_temp_dir(), 'transformed_data') - transform_output_path = os.path.join(self.get_temp_dir(), - 'transform_output') - with self._makeTestPipeline() as pipeline: - input_data = (pipeline | beam.Create([input_record_batch])) - with tft_beam.Context(temp_dir=self.get_temp_dir()): - (transformed_data, transformed_metadata), transform_fn = ( - (input_data, tensor_adapter_config) - | tft_beam.AnalyzeAndTransformDataset( - lambda inputs: inputs, - output_record_batches=self._OutputRecordBatches())) - - _ = ((transformed_data, transformed_metadata) - | 'Encode' >> tft_beam.EncodeTransformedDataset() - | 'Write' >> beam.io.WriteToTFRecord( - materialize_path, shard_name_template='')) - _ = ( - transform_fn - | 
'WriteTransformFn' >> - tft.beam.WriteTransformFn(transform_output_path)) - - expected_metadata = text_format.Parse( - """ + }""", + schema_pb2.TensorRepresentation(), + ) + }, + ) + expected_data = [ + { # pylint: disable=g-complex-comprehension + "x$sparse_values": x, + "x$sparse_indices_0": idx0, + "x$sparse_indices_1": idx1, + } + for idx0, idx1, x in zip(x_idx0, x_idx1, x_data) + ] + + materialize_path = os.path.join(self.get_temp_dir(), "transformed_data") + transform_output_path = os.path.join(self.get_temp_dir(), "transform_output") + with self._makeTestPipeline() as pipeline: + input_data = pipeline | beam.Create([input_record_batch]) + with tft_beam.Context(temp_dir=self.get_temp_dir()): + (transformed_data, transformed_metadata), transform_fn = ( + input_data, + tensor_adapter_config, + ) | tft_beam.AnalyzeAndTransformDataset( + lambda inputs: inputs, + output_record_batches=self._OutputRecordBatches(), + ) + + _ = ( + (transformed_data, transformed_metadata) + | "Encode" >> tft_beam.EncodeTransformedDataset() + | "Write" + >> beam.io.WriteToTFRecord(materialize_path, shard_name_template="") + ) + _ = transform_fn | "WriteTransformFn" >> tft.beam.WriteTransformFn( + transform_output_path + ) + + expected_metadata = text_format.Parse( + """ feature { name: "x$sparse_indices_0" type: INT @@ -4457,90 +4929,104 @@ def test3dSparseWithTFXIO(self): name: "x$sparse_values" } }""", - schema_pb2.Schema(), - ) - if not tft_unit.is_external_environment(): - expected_metadata.generate_legacy_feature_spec = False - - self.assertProtoEquals(transformed_metadata.schema, expected_metadata) - - beam_test_util.assert_that( - transformed_data, self._MakeTransformOutputAssertFn(expected_data)) - - def _assert_schemas_equal_fn(schema_dict_list): - self.assertEqual(1, len(schema_dict_list)) - self.assertProtoEquals(schema_dict_list[0].schema, expected_metadata) - - beam_test_util.assert_that( - transformed_metadata.deferred_metadata, - _assert_schemas_equal_fn, - label='assert_deferred_metadata') - - with tf.Graph().as_default(): - dataset = tf.data.TFRecordDataset(materialize_path) - tft_out = tft.TFTransformOutput(transform_output_path) - transformed_feature_spec = tft_out.transformed_feature_spec() - self.assertLen(transformed_feature_spec, 1) - self.assertIn('x', transformed_feature_spec) - self.assertEqual( - transformed_feature_spec['x'], - tf.io.SparseFeature( - ['x$sparse_indices_0', 'x$sparse_indices_1'], - 'x$sparse_values', - tf.float32, - [5, 5], - already_sorted=True, - ), - ) - - transformed_feature_spec['x'] = tf.io.SparseFeature( - ['x$sparse_indices_0', 'x$sparse_indices_1'], - 'x$sparse_values', tf.float32, [5, 5], already_sorted=True) - - def parse_fn(serialized_input): - result = tf.io.parse_single_example(serialized_input, - transformed_feature_spec)['x'] - return result.indices, result.values, result.dense_shape - - dataset = dataset.map(parse_fn).batch(len(x_data)) - transformed_sparse_components = tf.data.experimental.get_single_element( - dataset) - with tf.compat.v1.Session(): - transformed_sparse_components = [ - t.eval() for t in transformed_sparse_components + schema_pb2.Schema(), + ) + if not tft_unit.is_external_environment(): + expected_metadata.generate_legacy_feature_spec = False + + self.assertProtoEquals(transformed_metadata.schema, expected_metadata) + + beam_test_util.assert_that( + transformed_data, self._MakeTransformOutputAssertFn(expected_data) + ) + + def _assert_schemas_equal_fn(schema_dict_list): + self.assertEqual(1, len(schema_dict_list)) + 
self.assertProtoEquals( + schema_dict_list[0].schema, expected_metadata + ) + + beam_test_util.assert_that( + transformed_metadata.deferred_metadata, + _assert_schemas_equal_fn, + label="assert_deferred_metadata", + ) + + with tf.Graph().as_default(): + dataset = tf.data.TFRecordDataset(materialize_path) + tft_out = tft.TFTransformOutput(transform_output_path) + transformed_feature_spec = tft_out.transformed_feature_spec() + self.assertLen(transformed_feature_spec, 1) + self.assertIn("x", transformed_feature_spec) + self.assertEqual( + transformed_feature_spec["x"], + tf.io.SparseFeature( + ["x$sparse_indices_0", "x$sparse_indices_1"], + "x$sparse_values", + tf.float32, + [5, 5], + already_sorted=True, + ), + ) + + transformed_feature_spec["x"] = tf.io.SparseFeature( + ["x$sparse_indices_0", "x$sparse_indices_1"], + "x$sparse_values", + tf.float32, + [5, 5], + already_sorted=True, + ) + + def parse_fn(serialized_input): + result = tf.io.parse_single_example( + serialized_input, transformed_feature_spec + )["x"] + return result.indices, result.values, result.dense_shape + + dataset = dataset.map(parse_fn).batch(len(x_data)) + transformed_sparse_components = tf.data.experimental.get_single_element( + dataset + ) + with tf.compat.v1.Session(): + transformed_sparse_components = [ + t.eval() for t in transformed_sparse_components + ] + expected_sparse_components = [ + np.array([[arr] for arr in zip(x_idx0, x_idx1)]), + np.array([[x] for x in x_data]), + np.array([[5, 5]] * len(x_data)), ] - expected_sparse_components = [ - np.array([[arr] for arr in zip(x_idx0, x_idx1)]), - np.array([[x] for x in x_data]), - np.array([[5, 5]] * len(x_data)) - ] - self.assertLen(transformed_sparse_components, - len(expected_sparse_components)) - for transformed, expected in zip(transformed_sparse_components, - expected_sparse_components): - self.assertAllEqual(expected[0], transformed[0]) - self.assertAllEqual(expected[1], transformed[1]) - self.assertAllEqual(expected[2], transformed[2]) - - def testRaggedWithTFXIO(self): - x_data = [[[1], [], [2, 3]], [[]]] - y_data = [[[1, 2]], [[3, 4], [], [5, 6]]] - input_record_batch = pa.RecordBatch.from_arrays([ - pa.array(x_data, type=pa.large_list(pa.large_list(pa.int64()))), - pa.array(y_data, type=pa.large_list(pa.large_list(pa.float32()))) - ], ['x', 'y']) - tensor_adapter_config = tensor_adapter.TensorAdapterConfig( - input_record_batch.schema, { - 'x': - text_format.Parse( + self.assertLen(transformed_sparse_components, len(expected_sparse_components)) + for transformed, expected in zip( + transformed_sparse_components, expected_sparse_components + ): + self.assertAllEqual(expected[0], transformed[0]) + self.assertAllEqual(expected[1], transformed[1]) + self.assertAllEqual(expected[2], transformed[2]) + + def testRaggedWithTFXIO(self): + x_data = [[[1], [], [2, 3]], [[]]] + y_data = [[[1, 2]], [[3, 4], [], [5, 6]]] + input_record_batch = pa.RecordBatch.from_arrays( + [ + pa.array(x_data, type=pa.large_list(pa.large_list(pa.int64()))), + pa.array(y_data, type=pa.large_list(pa.large_list(pa.float32()))), + ], + ["x", "y"], + ) + tensor_adapter_config = tensor_adapter.TensorAdapterConfig( + input_record_batch.schema, + { + "x": text_format.Parse( """ragged_tensor { feature_path { step: "x" } row_partition_dtype: INT64 - }""", schema_pb2.TensorRepresentation()), - 'y': - text_format.Parse( + }""", + schema_pb2.TensorRepresentation(), + ), + "y": text_format.Parse( """ragged_tensor { feature_path { step: "y" @@ -4549,257 +5035,266 @@ def testRaggedWithTFXIO(self): 
partition { uniform_row_length: 2 } - }""", schema_pb2.TensorRepresentation()) - }) + }""", + schema_pb2.TensorRepresentation(), + ), + }, + ) - def preprocessing_fn(inputs): - return { - 'x_ones': tf.ones_like(inputs['x']), - 'y_ones': tf.ones_like(inputs['y']) - } + def preprocessing_fn(inputs): + return { + "x_ones": tf.ones_like(inputs["x"]), + "y_ones": tf.ones_like(inputs["y"]), + } - expected_data = [ - { - 'x_ones$ragged_values': [1, 1, 1], - 'x_ones$row_lengths_1': [1, 0, 2], - 'y_ones$ragged_values': [1, 1], - 'y_ones$row_lengths_1': [1], - }, - { - 'x_ones$ragged_values': [], - 'x_ones$row_lengths_1': [0], - 'y_ones$ragged_values': [1, 1, 1, 1], - 'y_ones$row_lengths_1': [1, 0, 1], - }, - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_ones': - tf.io.RaggedFeature( - tf.int64, - value_key='x_ones$ragged_values', - partitions=[ - tf.io.RaggedFeature.RowLengths('x_ones$row_lengths_1') # pytype: disable=attribute-error - ]), - 'y_ones': - tf.io.RaggedFeature( - tf.float32, - value_key='y_ones$ragged_values', - partitions=[ - tf.io.RaggedFeature.RowLengths('y_ones$row_lengths_1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.UniformRowLength(2), # pytype: disable=attribute-error - ]), - }) - self.assertAnalyzeAndTransformResults([input_record_batch], - tensor_adapter_config, - preprocessing_fn, - expected_data=expected_data, - expected_metadata=expected_metadata) - - def testPipelineWithoutAutomaterialization(self): - # Other tests pass lists instead of PCollections and thus invoke - # automaterialization where each call to a beam PTransform will implicitly - # run its own pipeline. - # - # In order to test the case where PCollections are not materialized in - # between calls to the tf.Transform PTransforms, we include a test that is - # not based on automaterialization. - def preprocessing_fn(inputs): - return {'x_scaled': tft.scale_to_0_1(inputs['x'])} - - with self._makeTestPipeline() as pipeline: - input_data = pipeline | 'CreateTrainingData' >> beam.Create( - [{'x': 4}, {'x': 1}, {'x': 5}, {'x': 2}]) - metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - with tft_beam.Context(temp_dir=self.get_temp_dir()): - transform_fn = ( - (input_data, metadata) - | 'AnalyzeDataset' >> tft_beam.AnalyzeDataset(preprocessing_fn)) - - # Run transform_columns on some eval dataset. 
- eval_data = pipeline | 'CreateEvalData' >> beam.Create( - [{'x': 6}, {'x': 3}]) - transformed_eval_data, _ = ( - ((eval_data, metadata), transform_fn) - | 'TransformDataset' >> tft_beam.TransformDataset( - output_record_batches=self._OutputRecordBatches())) - expected_data = [{'x_scaled': 1.25}, {'x_scaled': 0.5}] - beam_test_util.assert_that( - transformed_eval_data, - self._MakeTransformOutputAssertFn(expected_data, sort=True)) - - def testModifyInputs(self): - - def preprocessing_fn(inputs): - inputs['x_center'] = inputs['x'] - tft.mean(inputs['x']) - return inputs - - input_data = [{'x': 1}, {'x': 2}, {'x': 3}, {'x': 4}, {'x': 5}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.float32)}) - expected_outputs = [{ - 'x': 1, - 'x_center': -2 - }, { - 'x': 2, - 'x_center': -1 - }, { - 'x': 3, - 'x_center': 0 - }, { - 'x': 4, - 'x_center': 1 - }, { - 'x': 5, - 'x_center': 2 - }] - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_outputs) - - def testEmptySchema(self): - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, 'The input metadata is empty.' - ): - self.assertAnalyzeAndTransformResults( - input_data=[{'x': x} for x in range(5)], - input_metadata=tft.DatasetMetadata.from_feature_spec({}), - preprocessing_fn=lambda inputs: inputs) # pyformat: disable - - def testLoadKerasModelInPreprocessingFn(self): - def _create_model(features, target): - inputs = [ - tf_keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features - ] - x = tf_keras.layers.Concatenate()(inputs) - x = tf_keras.layers.Dense(64, activation='relu')(x) - outputs = tf_keras.layers.Dense(1, activation='sigmoid', name=target)(x) - model = tf_keras.Model(inputs=inputs, outputs=outputs) - model.compile( - loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) - - n = 50 - model.fit( - { - f: tf.constant([np.random.uniform() for _ in range(n) - ]) for f in features - }, - {target: tf.constant([np.random.randint(2) for _ in range(n)])}, - ) - return model - - test_base_dir = os.path.join(self.get_temp_dir(), self._testMethodName) - # Create and save a test Keras model - features = ['f1', 'f2'] - target = 't' - keras_model = _create_model(features, target) - keras_model_dir = os.path.join(test_base_dir, 'keras_model') - keras_model.save(keras_model_dir) - - def preprocessing_fn(inputs): - model = tft.make_and_track_object( - lambda: tf_keras.models.load_model(keras_model_dir), name='keras') - return {'prediction': model(inputs)} - - input_data = [{'f1': 1.0, 'f2': 0.0}, {'f1': 2.0, 'f2': 3.0}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'f1': tf.io.FixedLenFeature([], tf.float32), - 'f2': tf.io.FixedLenFeature([], tf.float32) - }) - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn) - - def test_non_deterministic_preprocessing_fn_without_name(self): - - idx = 0 - - def get_features(): - nonlocal idx - features = ['f1', 'f2', 'f3'] - result = features[idx:] + features[:idx] - idx = 0 if idx == 2 else idx + 1 - return result - - def preprocessing_fn(inputs): - features = get_features() - - outputs = {} - for f in features: - outputs[f] = inputs[f] - tft.mean(inputs[f]) - return outputs - - input_data = [{'f1': 0, 'f2': 10, 'f3': 20}, {'f1': 2, 'f2': 12, 'f3': 22}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'f1': tf.io.FixedLenFeature([], tf.float32), - 'f2': tf.io.FixedLenFeature([], tf.float32), - 'f3': 
tf.io.FixedLenFeature([], tf.float32) - }) - expected_outputs = [{ - 'f1': -1, - 'f2': -1, - 'f3': -1 - }, { - 'f1': 1, - 'f2': 1, - 'f3': 1 - }] - - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - RuntimeError, 'analyzers.*appears to be non-deterministic'): - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_outputs) - - def test_non_deterministic_preprocessing_fn_with_name(self): - - idx = 0 - - def get_features(): - nonlocal idx - features = ['f1', 'f2', 'f3'] - result = features[idx:] + features[:idx] - idx = 0 if idx == 2 else idx + 1 - return result - - def preprocessing_fn(inputs): - features = get_features() - - outputs = {} - for f in features: - outputs[f] = inputs[f] - tft.mean(inputs[f], name=f) - return outputs - - input_data = [{'f1': 0, 'f2': 10, 'f3': 20}, {'f1': 2, 'f2': 12, 'f3': 22}] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'f1': tf.io.FixedLenFeature([], tf.float32), - 'f2': tf.io.FixedLenFeature([], tf.float32), - 'f3': tf.io.FixedLenFeature([], tf.float32) - }) - expected_outputs = [{ - 'f1': -1, - 'f2': -1, - 'f3': -1 - }, { - 'f1': 1, - 'f2': 1, - 'f3': 1 - }] - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_outputs) - - def test_preprocessing_fn_returns_wrong_type(self): - with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises - ValueError, - r'A `preprocessing_fn` must return a ' - r'Dict\[str, Union\[tf.Tensor, tf.SparseTensor, tf.RaggedTensor\]\]. ' - 'Got: Tensor.*', - ): - self.assertAnalyzeAndTransformResults( - input_data=[{'f1': 0}], - input_metadata=tft.DatasetMetadata.from_feature_spec( - {'f1': tf.io.FixedLenFeature([], tf.float32)}), - preprocessing_fn=lambda inputs: inputs['f1'], - expected_data=None) + expected_data = [ + { + "x_ones$ragged_values": [1, 1, 1], + "x_ones$row_lengths_1": [1, 0, 2], + "y_ones$ragged_values": [1, 1], + "y_ones$row_lengths_1": [1], + }, + { + "x_ones$ragged_values": [], + "x_ones$row_lengths_1": [0], + "y_ones$ragged_values": [1, 1, 1, 1], + "y_ones$row_lengths_1": [1, 0, 1], + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_ones": tf.io.RaggedFeature( + tf.int64, + value_key="x_ones$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "x_ones$row_lengths_1" + ) # pytype: disable=attribute-error + ], + ), + "y_ones": tf.io.RaggedFeature( + tf.float32, + value_key="y_ones$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "y_ones$row_lengths_1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.UniformRowLength( + 2 + ), # pytype: disable=attribute-error + ], + ), + } + ) + self.assertAnalyzeAndTransformResults( + [input_record_batch], + tensor_adapter_config, + preprocessing_fn, + expected_data=expected_data, + expected_metadata=expected_metadata, + ) + + def testPipelineWithoutAutomaterialization(self): + # Other tests pass lists instead of PCollections and thus invoke + # automaterialization where each call to a beam PTransform will implicitly + # run its own pipeline. + # + # In order to test the case where PCollections are not materialized in + # between calls to the tf.Transform PTransforms, we include a test that is + # not based on automaterialization. 
+ def preprocessing_fn(inputs): + return {"x_scaled": tft.scale_to_0_1(inputs["x"])} + + with self._makeTestPipeline() as pipeline: + input_data = pipeline | "CreateTrainingData" >> beam.Create( + [{"x": 4}, {"x": 1}, {"x": 5}, {"x": 2}] + ) + metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + with tft_beam.Context(temp_dir=self.get_temp_dir()): + transform_fn = ( + input_data, + metadata, + ) | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(preprocessing_fn) + + # Run transform_columns on some eval dataset. + eval_data = pipeline | "CreateEvalData" >> beam.Create( + [{"x": 6}, {"x": 3}] + ) + transformed_eval_data, _ = ( + (eval_data, metadata), + transform_fn, + ) | "TransformDataset" >> tft_beam.TransformDataset( + output_record_batches=self._OutputRecordBatches() + ) + expected_data = [{"x_scaled": 1.25}, {"x_scaled": 0.5}] + beam_test_util.assert_that( + transformed_eval_data, + self._MakeTransformOutputAssertFn(expected_data, sort=True), + ) + + def testModifyInputs(self): + def preprocessing_fn(inputs): + inputs["x_center"] = inputs["x"] - tft.mean(inputs["x"]) + return inputs + + input_data = [{"x": 1}, {"x": 2}, {"x": 3}, {"x": 4}, {"x": 5}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.float32)} + ) + expected_outputs = [ + {"x": 1, "x_center": -2}, + {"x": 2, "x_center": -1}, + {"x": 3, "x_center": 0}, + {"x": 4, "x_center": 1}, + {"x": 5, "x_center": 2}, + ] + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_outputs + ) + + def testEmptySchema(self): + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, "The input metadata is empty." + ): + self.assertAnalyzeAndTransformResults( + input_data=[{"x": x} for x in range(5)], + input_metadata=tft.DatasetMetadata.from_feature_spec({}), + preprocessing_fn=lambda inputs: inputs, + ) # pyformat: disable + + def testLoadKerasModelInPreprocessingFn(self): + def _create_model(features, target): + inputs = [ + tf_keras.Input(shape=(1,), name=f, dtype=tf.float32) for f in features + ] + x = tf_keras.layers.Concatenate()(inputs) + x = tf_keras.layers.Dense(64, activation="relu")(x) + outputs = tf_keras.layers.Dense(1, activation="sigmoid", name=target)(x) + model = tf_keras.Model(inputs=inputs, outputs=outputs) + model.compile( + loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"] + ) + + n = 50 + model.fit( + { + f: tf.constant([np.random.uniform() for _ in range(n)]) + for f in features + }, + {target: tf.constant([np.random.randint(2) for _ in range(n)])}, + ) + return model + + test_base_dir = os.path.join(self.get_temp_dir(), self._testMethodName) + # Create and save a test Keras model + features = ["f1", "f2"] + target = "t" + keras_model = _create_model(features, target) + keras_model_dir = os.path.join(test_base_dir, "keras_model") + keras_model.save(keras_model_dir) + + def preprocessing_fn(inputs): + model = tft.make_and_track_object( + lambda: tf_keras.models.load_model(keras_model_dir), name="keras" + ) + return {"prediction": model(inputs)} + + input_data = [{"f1": 1.0, "f2": 0.0}, {"f1": 2.0, "f2": 3.0}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "f1": tf.io.FixedLenFeature([], tf.float32), + "f2": tf.io.FixedLenFeature([], tf.float32), + } + ) + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn + ) + + def test_non_deterministic_preprocessing_fn_without_name(self): + idx = 0 + + 
def get_features(): + nonlocal idx + features = ["f1", "f2", "f3"] + result = features[idx:] + features[:idx] + idx = 0 if idx == 2 else idx + 1 + return result + + def preprocessing_fn(inputs): + features = get_features() + + outputs = {} + for f in features: + outputs[f] = inputs[f] - tft.mean(inputs[f]) + return outputs + + input_data = [{"f1": 0, "f2": 10, "f3": 20}, {"f1": 2, "f2": 12, "f3": 22}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "f1": tf.io.FixedLenFeature([], tf.float32), + "f2": tf.io.FixedLenFeature([], tf.float32), + "f3": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_outputs = [{"f1": -1, "f2": -1, "f3": -1}, {"f1": 1, "f2": 1, "f3": 1}] + + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + RuntimeError, "analyzers.*appears to be non-deterministic" + ): + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_outputs + ) + + def test_non_deterministic_preprocessing_fn_with_name(self): + idx = 0 + + def get_features(): + nonlocal idx + features = ["f1", "f2", "f3"] + result = features[idx:] + features[:idx] + idx = 0 if idx == 2 else idx + 1 + return result + + def preprocessing_fn(inputs): + features = get_features() + + outputs = {} + for f in features: + outputs[f] = inputs[f] - tft.mean(inputs[f], name=f) + return outputs + + input_data = [{"f1": 0, "f2": 10, "f3": 20}, {"f1": 2, "f2": 12, "f3": 22}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "f1": tf.io.FixedLenFeature([], tf.float32), + "f2": tf.io.FixedLenFeature([], tf.float32), + "f3": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_outputs = [{"f1": -1, "f2": -1, "f3": -1}, {"f1": 1, "f2": 1, "f3": 1}] + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_outputs + ) -if __name__ == '__main__': - tft_unit.main() + def test_preprocessing_fn_returns_wrong_type(self): + with self.assertRaisesRegex( # pylint: disable=g-error-prone-assert-raises + ValueError, + r"A `preprocessing_fn` must return a " + r"Dict\[str, Union\[tf.Tensor, tf.SparseTensor, tf.RaggedTensor\]\]. " + "Got: Tensor.*", + ): + self.assertAnalyzeAndTransformResults( + input_data=[{"f1": 0}], + input_metadata=tft.DatasetMetadata.from_feature_spec( + {"f1": tf.io.FixedLenFeature([], tf.float32)} + ), + preprocessing_fn=lambda inputs: inputs["f1"], + expected_data=None, + ) + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/test_helpers.py b/tensorflow_transform/beam/test_helpers.py index 08b031c..736280d 100644 --- a/tensorflow_transform/beam/test_helpers.py +++ b/tensorflow_transform/beam/test_helpers.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2017 Google Inc. All Rights Reserved. # @@ -17,6 +16,6 @@ def make_test_beam_pipeline_kwargs(): - # This is kwargs for apache_beam.Pipeline's __init__, using the default runner - # here. - return {} + # This is kwargs for apache_beam.Pipeline's __init__, using the default runner + # here. 
+ return {} diff --git a/tensorflow_transform/beam/tft_beam_io/__init__.py b/tensorflow_transform/beam/tft_beam_io/__init__.py index 1e74690..fd9cdcf 100644 --- a/tensorflow_transform/beam/tft_beam_io/__init__.py +++ b/tensorflow_transform/beam/tft_beam_io/__init__.py @@ -14,5 +14,7 @@ """Module level imports for tensorflow_transform.beam.tft_beam_io.""" from tensorflow_transform.beam.tft_beam_io.beam_metadata_io import WriteMetadata -from tensorflow_transform.beam.tft_beam_io.transform_fn_io import ReadTransformFn -from tensorflow_transform.beam.tft_beam_io.transform_fn_io import WriteTransformFn +from tensorflow_transform.beam.tft_beam_io.transform_fn_io import ( + ReadTransformFn, + WriteTransformFn, +) diff --git a/tensorflow_transform/beam/tft_beam_io/beam_metadata_io.py b/tensorflow_transform/beam/tft_beam_io/beam_metadata_io.py index 3fba073..8182f97 100644 --- a/tensorflow_transform/beam/tft_beam_io/beam_metadata_io.py +++ b/tensorflow_transform/beam/tft_beam_io/beam_metadata_io.py @@ -22,80 +22,87 @@ import apache_beam as beam import tensorflow as tf -from tensorflow_transform import output_wrapper -from tensorflow_transform.beam import common -from tensorflow_transform.tf_metadata import metadata_io + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple +from tensorflow_transform import output_wrapper +from tensorflow_transform.beam import common +from tensorflow_transform.tf_metadata import metadata_io + class BeamDatasetMetadata( tfx_namedtuple.namedtuple( - 'BeamDatasetMetadata', - ['dataset_metadata', 'deferred_metadata', 'asset_map'])): - """A class like DatasetMetadata also holding `PCollection`s and an asset_map. + "BeamDatasetMetadata", ["dataset_metadata", "deferred_metadata", "asset_map"] + ) +): + """A class like DatasetMetadata also holding `PCollection`s and an asset_map. - `deferred_metadata` is a PCollection containing a single DatasetMetadata. - `asset_map` is a Dictionary mapping asset keys to filenames. - """ + `deferred_metadata` is a PCollection containing a single DatasetMetadata. + `asset_map` is a Dictionary mapping asset keys to filenames. + """ - @property - def schema(self): - return self.dataset_metadata.schema + @property + def schema(self): + return self.dataset_metadata.schema class WriteMetadata(beam.PTransform): - """A PTransform to write Metadata to disk. + """A PTransform to write Metadata to disk. - Input can either be a DatasetMetadata or a tuple of properties. + Input can either be a DatasetMetadata or a tuple of properties. - Depending on the optional `write_to_unique_subdirectory`, writes the given - metadata to either `path` or a new unique subdirectory under `path`. + Depending on the optional `write_to_unique_subdirectory`, writes the given + metadata to either `path` or a new unique subdirectory under `path`. - Returns a singleton with the path to which the metadata was written. - """ - - # NOTE: The pipeline metadata is required by PTransform given that all the - # inputs may be non-deferred. - def __init__(self, path, pipeline, write_to_unique_subdirectory=False): - """Init method. - - Args: - path: A str, the default path that the metadata should be written to. - pipeline: A beam Pipeline. - write_to_unique_subdirectory: (Optional) A bool indicating whether to - write the metadata out to `path` or a unique subdirectory under `path`. + Returns a singleton with the path to which the metadata was written. 
""" - super().__init__() - self._path = path - self._write_to_unique_subdirectory = write_to_unique_subdirectory - self.pipeline = pipeline - - def _extract_input_pvalues(self, metadata): - pvalues = [] - if isinstance(metadata, BeamDatasetMetadata): - pvalues.append(metadata.deferred_metadata) - return metadata, pvalues - - def expand(self, metadata): - if hasattr(metadata, 'deferred_metadata'): - metadata_pcoll = metadata.deferred_metadata - else: - metadata_pcoll = self.pipeline | beam.Create([metadata]) - - asset_map = getattr(metadata, 'asset_map', {}) - - def write_metadata_output(metadata): - output_path = self._path - if self._write_to_unique_subdirectory: - output_path = common.get_unique_temp_path(self._path) - metadata_io.write_metadata(metadata, output_path) - if asset_map: - with tf.io.gfile.GFile( - os.path.join(output_path, - output_wrapper.TFTransformOutput.ASSET_MAP), 'w') as f: - f.write(json.dumps(asset_map)) - return output_path - - return metadata_pcoll | 'WriteMetadata' >> beam.Map(write_metadata_output) + + # NOTE: The pipeline metadata is required by PTransform given that all the + # inputs may be non-deferred. + def __init__(self, path, pipeline, write_to_unique_subdirectory=False): + """Init method. + + Args: + ---- + path: A str, the default path that the metadata should be written to. + pipeline: A beam Pipeline. + write_to_unique_subdirectory: (Optional) A bool indicating whether to + write the metadata out to `path` or a unique subdirectory under `path`. + """ + super().__init__() + self._path = path + self._write_to_unique_subdirectory = write_to_unique_subdirectory + self.pipeline = pipeline + + def _extract_input_pvalues(self, metadata): + pvalues = [] + if isinstance(metadata, BeamDatasetMetadata): + pvalues.append(metadata.deferred_metadata) + return metadata, pvalues + + def expand(self, metadata): + if hasattr(metadata, "deferred_metadata"): + metadata_pcoll = metadata.deferred_metadata + else: + metadata_pcoll = self.pipeline | beam.Create([metadata]) + + asset_map = getattr(metadata, "asset_map", {}) + + def write_metadata_output(metadata): + output_path = self._path + if self._write_to_unique_subdirectory: + output_path = common.get_unique_temp_path(self._path) + metadata_io.write_metadata(metadata, output_path) + if asset_map: + with tf.io.gfile.GFile( + os.path.join( + output_path, output_wrapper.TFTransformOutput.ASSET_MAP + ), + "w", + ) as f: + f.write(json.dumps(asset_map)) + return output_path + + return metadata_pcoll | "WriteMetadata" >> beam.Map(write_metadata_output) diff --git a/tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py b/tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py index 0dad77e..f584e51 100644 --- a/tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py +++ b/tensorflow_transform/beam/tft_beam_io/beam_metadata_io_test.py @@ -18,79 +18,83 @@ import apache_beam as beam import tensorflow as tf + +import tensorflow_transform.test_case as tft_test_case from tensorflow_transform import output_wrapper -from tensorflow_transform.beam.tft_beam_io import beam_metadata_io from tensorflow_transform.beam import tft_unit -from tensorflow_transform.beam.tft_beam_io import test_metadata -import tensorflow_transform.test_case as tft_test_case +from tensorflow_transform.beam.tft_beam_io import beam_metadata_io, test_metadata from tensorflow_transform.tf_metadata import metadata_io mock = tf.compat.v1.test.mock class BeamMetadataIoTest(tft_unit.TransformTestCase): - - def testWriteMetadataNonDeferred(self): - 
# Write metadata to disk using WriteMetadata PTransform. - with beam.Pipeline() as pipeline: - path = self.get_temp_dir() - _ = (test_metadata.COMPLETE_METADATA - | beam_metadata_io.WriteMetadata(path, pipeline)) - - # Load from disk and check that it is as expected. - metadata = metadata_io.read_metadata(path) - self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - def testWriteMetadataDeferred(self): - # Write metadata to disk using WriteMetadata PTransform, combining - # incomplete metadata with (deferred) complete metadata. - expected_asset_map = {'key': 'value'} - with beam.Pipeline() as pipeline: - path = self.get_temp_dir() - deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( - [test_metadata.COMPLETE_METADATA]) - metadata = beam_metadata_io.BeamDatasetMetadata( - test_metadata.INCOMPLETE_METADATA, deferred_metadata, - expected_asset_map) - _ = metadata | beam_metadata_io.WriteMetadata(path, pipeline) - - # Load from disk and check that it is as expected. - metadata = metadata_io.read_metadata(path) - self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - with tf.io.gfile.GFile( - os.path.join(path, output_wrapper.TFTransformOutput.ASSET_MAP)) as f: - asset_map = json.loads(f.read()) - self.assertDictEqual(asset_map, expected_asset_map) - - def testWriteMetadataIsRetryable(self): - tft_test_case.skip_if_external_environment( - 'Retries are currently not available on this environment.') - original_write_metadata = beam_metadata_io.metadata_io.write_metadata - write_metadata_called_list = [] - - def mock_write_metadata(metadata, path): - """Mocks metadata_io.write_metadata to fail the first time it is called by this test, thus forcing a retry which should succeed.""" - if not write_metadata_called_list: - write_metadata_called_list.append(True) - original_write_metadata(metadata, path) - raise ArithmeticError('Some error') - return original_write_metadata(metadata, path) - - # Write metadata to disk using WriteMetadata PTransform. - with mock.patch( - 'tensorflow_transform.tf_metadata.metadata_io.write_metadata', - mock_write_metadata): - with self._makeTestPipeline() as pipeline: - path = self.get_temp_dir() - _ = ( - test_metadata.COMPLETE_METADATA - | beam_metadata_io.WriteMetadata(path, pipeline)) - - # Load from disk and check that it is as expected. - metadata = metadata_io.read_metadata(path) - self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - -if __name__ == '__main__': - tf.test.main() + def testWriteMetadataNonDeferred(self): + # Write metadata to disk using WriteMetadata PTransform. + with beam.Pipeline() as pipeline: + path = self.get_temp_dir() + _ = test_metadata.COMPLETE_METADATA | beam_metadata_io.WriteMetadata( + path, pipeline + ) + + # Load from disk and check that it is as expected. + metadata = metadata_io.read_metadata(path) + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + def testWriteMetadataDeferred(self): + # Write metadata to disk using WriteMetadata PTransform, combining + # incomplete metadata with (deferred) complete metadata. + expected_asset_map = {"key": "value"} + with beam.Pipeline() as pipeline: + path = self.get_temp_dir() + deferred_metadata = pipeline | "CreateDeferredMetadata" >> beam.Create( + [test_metadata.COMPLETE_METADATA] + ) + metadata = beam_metadata_io.BeamDatasetMetadata( + test_metadata.INCOMPLETE_METADATA, deferred_metadata, expected_asset_map + ) + _ = metadata | beam_metadata_io.WriteMetadata(path, pipeline) + + # Load from disk and check that it is as expected. 
+ metadata = metadata_io.read_metadata(path) + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + with tf.io.gfile.GFile( + os.path.join(path, output_wrapper.TFTransformOutput.ASSET_MAP) + ) as f: + asset_map = json.loads(f.read()) + self.assertDictEqual(asset_map, expected_asset_map) + + def testWriteMetadataIsRetryable(self): + tft_test_case.skip_if_external_environment( + "Retries are currently not available on this environment." + ) + original_write_metadata = beam_metadata_io.metadata_io.write_metadata + write_metadata_called_list = [] + + def mock_write_metadata(metadata, path): + """Mocks metadata_io.write_metadata to fail the first time it is called by this test, thus forcing a retry which should succeed.""" + if not write_metadata_called_list: + write_metadata_called_list.append(True) + original_write_metadata(metadata, path) + raise ArithmeticError("Some error") + return original_write_metadata(metadata, path) + + # Write metadata to disk using WriteMetadata PTransform. + with mock.patch( + "tensorflow_transform.tf_metadata.metadata_io.write_metadata", + mock_write_metadata, + ): + with self._makeTestPipeline() as pipeline: + path = self.get_temp_dir() + _ = test_metadata.COMPLETE_METADATA | beam_metadata_io.WriteMetadata( + path, pipeline + ) + + # Load from disk and check that it is as expected. + metadata = metadata_io.read_metadata(path) + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_transform/beam/tft_beam_io/test_metadata.py b/tensorflow_transform/beam/tft_beam_io/test_metadata.py index a185bd5..b6f7881 100644 --- a/tensorflow_transform/beam/tft_beam_io/test_metadata.py +++ b/tensorflow_transform/beam/tft_beam_io/test_metadata.py @@ -14,19 +14,21 @@ """Test metadata for tft_beam_io tests.""" import tensorflow as tf -from tensorflow_transform.tf_metadata import dataset_metadata - from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform.tf_metadata import dataset_metadata + _FEATURE_SPEC = { - 'fixed_column': tf.io.FixedLenFeature([3], tf.string), - 'list_columm': tf.io.VarLenFeature(tf.int64), + "fixed_column": tf.io.FixedLenFeature([3], tf.string), + "list_columm": tf.io.VarLenFeature(tf.int64), } COMPLETE_METADATA = dataset_metadata.DatasetMetadata.from_feature_spec( - _FEATURE_SPEC, domains={'list_columm': schema_pb2.IntDomain(min=-1, max=5)}) + _FEATURE_SPEC, domains={"list_columm": schema_pb2.IntDomain(min=-1, max=5)} +) INCOMPLETE_METADATA = dataset_metadata.DatasetMetadata.from_feature_spec( _FEATURE_SPEC, # Values will be overridden by those in COMPLETE_METADATA - domains={'list_columm': schema_pb2.IntDomain(min=0, max=0)}) + domains={"list_columm": schema_pb2.IntDomain(min=0, max=0)}, +) diff --git a/tensorflow_transform/beam/tft_beam_io/transform_fn_io.py b/tensorflow_transform/beam/tft_beam_io/transform_fn_io.py index 8809336..a3a9f92 100644 --- a/tensorflow_transform/beam/tft_beam_io/transform_fn_io.py +++ b/tensorflow_transform/beam/tft_beam_io/transform_fn_io.py @@ -16,6 +16,7 @@ import os import apache_beam as beam + import tensorflow_transform as tft from tensorflow_transform import impl_helper from tensorflow_transform.beam import common @@ -29,121 +30,135 @@ def _copy_tree_to_unique_temp_dir(source, base_temp_dir_path): - """Copies from source to a unique sub directory under base_temp_dir_path.""" - destination = common.get_unique_temp_path(base_temp_dir_path) - _copy_tree(source, destination) - return destination + """Copies from 
source to a unique sub directory under base_temp_dir_path.""" + destination = common.get_unique_temp_path(base_temp_dir_path) + _copy_tree(source, destination) + return destination def _copy_tree(source, destination): - """Recursively copies source to destination.""" - # TODO(b/35363519): Perhaps use Beam IO eventually (which also already - # supports recursive copy)? - import tensorflow as tf # pylint: disable=g-import-not-at-top + """Recursively copies source to destination.""" + # TODO(b/35363519): Perhaps use Beam IO eventually (which also already + # supports recursive copy)? + import tensorflow as tf # pylint: disable=g-import-not-at-top - if tf.io.gfile.isdir(source): - source_dir_name = os.path.basename(os.path.normpath(source)) - if source_dir_name == impl_helper.METADATA_DIR_NAME: - return + if tf.io.gfile.isdir(source): + source_dir_name = os.path.basename(os.path.normpath(source)) + if source_dir_name == impl_helper.METADATA_DIR_NAME: + return - tf.io.gfile.makedirs(destination) - for filename in tf.io.gfile.listdir(source): - _copy_tree( - os.path.join(source, filename), os.path.join(destination, filename)) - else: - tf.io.gfile.copy(source, destination) + tf.io.gfile.makedirs(destination) + for filename in tf.io.gfile.listdir(source): + _copy_tree( + os.path.join(source, filename), os.path.join(destination, filename) + ) + else: + tf.io.gfile.copy(source, destination) class WriteTransformFn(beam.PTransform): - """Writes a TransformFn to disk. - - The internal structure is a directory containing two subdirectories. The - first is 'transformed_metadata' and contains metadata of the transformed data. - The second is 'transform_fn' and contains a SavedModel representing the - transformed data. - """ - - def __init__(self, path): - super().__init__() - self._path = path - - def _extract_input_pvalues(self, transform_fn): - saved_model_dir, metadata = transform_fn - pvalues = [saved_model_dir] - if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): - pvalues.append(metadata.deferred_metadata) - return transform_fn, pvalues - - def expand(self, transform_fn): - saved_model_dir, metadata = transform_fn - pipeline = saved_model_dir.pipeline - - # Using a temp dir within `path` ensures that the source and dstination - # paths for the rename below are in the same file system. - base_temp_dir = os.path.join(self._path, 'transform_tmp') - temp_metadata_path = ( - metadata - | 'WriteMetadataToTemp' >> beam_metadata_io.WriteMetadata( - base_temp_dir, pipeline, write_to_unique_subdirectory=True)) - - temp_transform_fn_path = ( - saved_model_dir - | 'WriteTransformFnToTemp' >> beam.Map(_copy_tree_to_unique_temp_dir, - base_temp_dir)) - - metadata_path = os.path.join(self._path, - tft.TFTransformOutput.TRANSFORMED_METADATA_DIR) - transform_fn_path = os.path.join(self._path, - tft.TFTransformOutput.TRANSFORM_FN_DIR) - - def publish_outputs(unused_element, metadata_source_path, - transform_fn_source_path): - import tensorflow as tf # pylint: disable=g-import-not-at-top - if not tf.io.gfile.exists(self._path): - tf.io.gfile.makedirs(self._path) - - if tf.io.gfile.exists(metadata_path): - tf.io.gfile.rmtree(metadata_path) - tf.io.gfile.rename(metadata_source_path, metadata_path, overwrite=True) - - if tf.io.gfile.exists(transform_fn_path): - tf.io.gfile.rmtree(transform_fn_path) - tf.io.gfile.rename( - transform_fn_source_path, transform_fn_path, overwrite=True) - - # TODO(b/211615643): Remove the exists check once importing TFIO in S3 - # addresses NotFoundError. 
- if tf.io.gfile.exists(base_temp_dir): - tf.io.gfile.rmtree(base_temp_dir) - - # TODO(KesterTong): Move this "must follows" logic into a tfx_bsl helper - # function or into Beam. - return ( - pipeline - | 'CreateSole' >> beam.Create([None]) - | 'PublishMetadataAndTransformFn' >> beam.Map( - publish_outputs, - metadata_source_path=beam.pvalue.AsSingleton(temp_metadata_path), - transform_fn_source_path=beam.pvalue.AsSingleton( - temp_transform_fn_path))) + """Writes a TransformFn to disk. + + The internal structure is a directory containing two subdirectories. The + first is 'transformed_metadata' and contains metadata of the transformed data. + The second is 'transform_fn' and contains a SavedModel representing the + transformed data. + """ + + def __init__(self, path): + super().__init__() + self._path = path + + def _extract_input_pvalues(self, transform_fn): + saved_model_dir, metadata = transform_fn + pvalues = [saved_model_dir] + if isinstance(metadata, beam_metadata_io.BeamDatasetMetadata): + pvalues.append(metadata.deferred_metadata) + return transform_fn, pvalues + + def expand(self, transform_fn): + saved_model_dir, metadata = transform_fn + pipeline = saved_model_dir.pipeline + + # Using a temp dir within `path` ensures that the source and dstination + # paths for the rename below are in the same file system. + base_temp_dir = os.path.join(self._path, "transform_tmp") + temp_metadata_path = ( + metadata + | "WriteMetadataToTemp" + >> beam_metadata_io.WriteMetadata( + base_temp_dir, pipeline, write_to_unique_subdirectory=True + ) + ) + + temp_transform_fn_path = saved_model_dir | "WriteTransformFnToTemp" >> beam.Map( + _copy_tree_to_unique_temp_dir, base_temp_dir + ) + + metadata_path = os.path.join( + self._path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR + ) + transform_fn_path = os.path.join( + self._path, tft.TFTransformOutput.TRANSFORM_FN_DIR + ) + + def publish_outputs( + unused_element, metadata_source_path, transform_fn_source_path + ): + import tensorflow as tf # pylint: disable=g-import-not-at-top + + if not tf.io.gfile.exists(self._path): + tf.io.gfile.makedirs(self._path) + + if tf.io.gfile.exists(metadata_path): + tf.io.gfile.rmtree(metadata_path) + tf.io.gfile.rename(metadata_source_path, metadata_path, overwrite=True) + + if tf.io.gfile.exists(transform_fn_path): + tf.io.gfile.rmtree(transform_fn_path) + tf.io.gfile.rename( + transform_fn_source_path, transform_fn_path, overwrite=True + ) + + # TODO(b/211615643): Remove the exists check once importing TFIO in S3 + # addresses NotFoundError. + if tf.io.gfile.exists(base_temp_dir): + tf.io.gfile.rmtree(base_temp_dir) + + # TODO(KesterTong): Move this "must follows" logic into a tfx_bsl helper + # function or into Beam. 
+ return ( + pipeline + | "CreateSole" >> beam.Create([None]) + | "PublishMetadataAndTransformFn" + >> beam.Map( + publish_outputs, + metadata_source_path=beam.pvalue.AsSingleton(temp_metadata_path), + transform_fn_source_path=beam.pvalue.AsSingleton( + temp_transform_fn_path + ), + ) + ) class ReadTransformFn(beam.PTransform): - """Reads a TransformFn written by WriteTransformFn.""" - - def __init__(self, path): - super().__init__() - self._path = path - - def expand(self, pvalue): - transform_fn_path = os.path.join(self._path, - tft.TFTransformOutput.TRANSFORM_FN_DIR) - saved_model_dir_pcoll = ( - pvalue.pipeline - | 'CreateTransformFnPath' >> beam.Create([transform_fn_path])) - - metadata = metadata_io.read_metadata( - os.path.join(self._path, - tft.TFTransformOutput.TRANSFORMED_METADATA_DIR)) - - return saved_model_dir_pcoll, metadata + """Reads a TransformFn written by WriteTransformFn.""" + + def __init__(self, path): + super().__init__() + self._path = path + + def expand(self, pvalue): + transform_fn_path = os.path.join( + self._path, tft.TFTransformOutput.TRANSFORM_FN_DIR + ) + saved_model_dir_pcoll = ( + pvalue.pipeline + | "CreateTransformFnPath" >> beam.Create([transform_fn_path]) + ) + + metadata = metadata_io.read_metadata( + os.path.join(self._path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR) + ) + + return saved_model_dir_pcoll, metadata diff --git a/tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py b/tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py index 722ca6b..75897f5 100644 --- a/tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py +++ b/tensorflow_transform/beam/tft_beam_io/transform_fn_io_test.py @@ -16,138 +16,164 @@ import os import apache_beam as beam -from apache_beam.testing import util as beam_test_util import tensorflow as tf +from apache_beam.testing import util as beam_test_util +from tensorflow.python.lib.io import ( + file_io, # pylint: disable=g-direct-tensorflow-import +) + import tensorflow_transform as tft -from tensorflow_transform.beam.tft_beam_io import beam_metadata_io -from tensorflow_transform.beam.tft_beam_io import transform_fn_io from tensorflow_transform.beam import tft_unit -from tensorflow_transform.beam.tft_beam_io import test_metadata +from tensorflow_transform.beam.tft_beam_io import ( + beam_metadata_io, + test_metadata, + transform_fn_io, +) from tensorflow_transform.tf_metadata import metadata_io -from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import - mock = tf.compat.v1.test.mock # TODO(varshaan): Remove global variable and use a class attribute. _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED = False class TransformFnIoTest(tft_unit.TransformTestCase): - - def testReadTransformFn(self): - path = self.get_temp_dir() - # NOTE: we don't need to create or write to the transform_fn directory since - # ReadTransformFn never inspects this directory. - transform_fn_dir = os.path.join( - path, tft.TFTransformOutput.TRANSFORM_FN_DIR) - transformed_metadata_dir = os.path.join( - path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR) - metadata_io.write_metadata(test_metadata.COMPLETE_METADATA, - transformed_metadata_dir) - - with beam.Pipeline() as pipeline: - saved_model_dir_pcoll, metadata = ( - pipeline | transform_fn_io.ReadTransformFn(path)) - beam_test_util.assert_that( - saved_model_dir_pcoll, - beam_test_util.equal_to([transform_fn_dir]), - label='AssertSavedModelDir') - # NOTE: metadata is currently read in a non-deferred manner. 
- self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - def testWriteTransformFn(self): - transform_output_dir = os.path.join(self.get_temp_dir(), 'output') - - with beam.Pipeline() as pipeline: - # Create an empty directory for the source saved model dir. - saved_model_dir = os.path.join(self.get_temp_dir(), 'source') - file_io.recursive_create_dir(saved_model_dir) - saved_model_dir_pcoll = ( - pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) - # Combine test metadata with a dict of PCollections resolving futures. - deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( - [test_metadata.COMPLETE_METADATA]) - metadata = beam_metadata_io.BeamDatasetMetadata( - test_metadata.INCOMPLETE_METADATA, deferred_metadata, {}) - - _ = ((saved_model_dir_pcoll, metadata) - | transform_fn_io.WriteTransformFn(transform_output_dir)) - - # Test reading with TFTransformOutput - tf_transform_output = tft.TFTransformOutput(transform_output_dir) - metadata = tf_transform_output.transformed_metadata - self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - transform_fn_dir = tf_transform_output.transform_savedmodel_dir - self.assertTrue(file_io.file_exists(transform_fn_dir)) - self.assertTrue(file_io.is_directory(transform_fn_dir)) - - def testWriteTransformFnIsIdempotent(self): - transform_output_dir = os.path.join(self.get_temp_dir(), 'output') - - def mock_write_metadata_expand(unused_self, unused_metadata): - raise ArithmeticError('Some error') - - with beam.Pipeline() as pipeline: - # Create an empty directory for the source saved model dir. - saved_model_dir = os.path.join(self.get_temp_dir(), 'source') - saved_model_dir_pcoll = ( - pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) - - with mock.patch.object(transform_fn_io.beam_metadata_io.WriteMetadata, - 'expand', mock_write_metadata_expand): - with self.assertRaisesRegex(ArithmeticError, 'Some error'): - _ = ((saved_model_dir_pcoll, object()) - | transform_fn_io.WriteTransformFn(transform_output_dir)) - - self.assertFalse(file_io.file_exists(transform_output_dir)) - - def testWriteTransformFnIsRetryable(self): - tft.test_case.skip_if_external_environment( - 'Retries are currently not available on this environment.') - original_copy_tree_to_unique_temp_dir = ( - transform_fn_io._copy_tree_to_unique_temp_dir) - - def mock_copy_tree_to_unique_temp_dir(source, base_temp_dir_path): - """Mocks transform_fn_io._copy_tree to fail the first time it is called by this test, thus forcing a retry which should succeed.""" - global _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED - if not _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED: - _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED = True - original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) - raise ArithmeticError('Some error') - return original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) - - with self._makeTestPipeline() as pipeline: - transform_output_dir = os.path.join(self.get_temp_dir(), 'output') - # Create an empty directory for the source saved model dir. - saved_model_dir = os.path.join(self.get_temp_dir(), 'source') - file_io.recursive_create_dir(saved_model_dir) - saved_model_path = os.path.join(saved_model_dir, 'saved_model') - with file_io.FileIO(saved_model_path, mode='w') as f: - f.write('some content') - saved_model_dir_pcoll = ( - pipeline | 'CreateSavedModelDir' >> beam.Create([saved_model_dir])) - # Combine test metadata with a dict of PCollections resolving futures. 
- deferred_metadata = pipeline | 'CreateDeferredMetadata' >> beam.Create( - [test_metadata.COMPLETE_METADATA]) - metadata = beam_metadata_io.BeamDatasetMetadata( - test_metadata.INCOMPLETE_METADATA, deferred_metadata, {}) - with mock.patch.object(transform_fn_io, '_copy_tree_to_unique_temp_dir', - mock_copy_tree_to_unique_temp_dir): - _ = ((saved_model_dir_pcoll, metadata) - | transform_fn_io.WriteTransformFn(transform_output_dir)) - - # Test reading with TFTransformOutput - tf_transform_output = tft.TFTransformOutput(transform_output_dir) - metadata = tf_transform_output.transformed_metadata - self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) - - transform_fn_dir = tf_transform_output.transform_savedmodel_dir - self.assertTrue(file_io.file_exists(transform_fn_dir)) - self.assertTrue(file_io.is_directory(transform_fn_dir)) - # Check temp directory created by failed run was cleaned up. - self.assertEqual(2, len(file_io.list_directory(transform_output_dir))) - - -if __name__ == '__main__': - tf.test.main() + def testReadTransformFn(self): + path = self.get_temp_dir() + # NOTE: we don't need to create or write to the transform_fn directory since + # ReadTransformFn never inspects this directory. + transform_fn_dir = os.path.join(path, tft.TFTransformOutput.TRANSFORM_FN_DIR) + transformed_metadata_dir = os.path.join( + path, tft.TFTransformOutput.TRANSFORMED_METADATA_DIR + ) + metadata_io.write_metadata( + test_metadata.COMPLETE_METADATA, transformed_metadata_dir + ) + + with beam.Pipeline() as pipeline: + saved_model_dir_pcoll, metadata = ( + pipeline | transform_fn_io.ReadTransformFn(path) + ) + beam_test_util.assert_that( + saved_model_dir_pcoll, + beam_test_util.equal_to([transform_fn_dir]), + label="AssertSavedModelDir", + ) + # NOTE: metadata is currently read in a non-deferred manner. + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + def testWriteTransformFn(self): + transform_output_dir = os.path.join(self.get_temp_dir(), "output") + + with beam.Pipeline() as pipeline: + # Create an empty directory for the source saved model dir. + saved_model_dir = os.path.join(self.get_temp_dir(), "source") + file_io.recursive_create_dir(saved_model_dir) + saved_model_dir_pcoll = pipeline | "CreateSavedModelDir" >> beam.Create( + [saved_model_dir] + ) + # Combine test metadata with a dict of PCollections resolving futures. + deferred_metadata = pipeline | "CreateDeferredMetadata" >> beam.Create( + [test_metadata.COMPLETE_METADATA] + ) + metadata = beam_metadata_io.BeamDatasetMetadata( + test_metadata.INCOMPLETE_METADATA, deferred_metadata, {} + ) + + _ = (saved_model_dir_pcoll, metadata) | transform_fn_io.WriteTransformFn( + transform_output_dir + ) + + # Test reading with TFTransformOutput + tf_transform_output = tft.TFTransformOutput(transform_output_dir) + metadata = tf_transform_output.transformed_metadata + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + transform_fn_dir = tf_transform_output.transform_savedmodel_dir + self.assertTrue(file_io.file_exists(transform_fn_dir)) + self.assertTrue(file_io.is_directory(transform_fn_dir)) + + def testWriteTransformFnIsIdempotent(self): + transform_output_dir = os.path.join(self.get_temp_dir(), "output") + + def mock_write_metadata_expand(unused_self, unused_metadata): + raise ArithmeticError("Some error") + + with beam.Pipeline() as pipeline: + # Create an empty directory for the source saved model dir. 
+ saved_model_dir = os.path.join(self.get_temp_dir(), "source") + saved_model_dir_pcoll = pipeline | "CreateSavedModelDir" >> beam.Create( + [saved_model_dir] + ) + + with mock.patch.object( + transform_fn_io.beam_metadata_io.WriteMetadata, + "expand", + mock_write_metadata_expand, + ): + with self.assertRaisesRegex(ArithmeticError, "Some error"): + _ = ( + saved_model_dir_pcoll, + object(), + ) | transform_fn_io.WriteTransformFn(transform_output_dir) + + self.assertFalse(file_io.file_exists(transform_output_dir)) + + def testWriteTransformFnIsRetryable(self): + tft.test_case.skip_if_external_environment( + "Retries are currently not available on this environment." + ) + original_copy_tree_to_unique_temp_dir = ( + transform_fn_io._copy_tree_to_unique_temp_dir + ) + + def mock_copy_tree_to_unique_temp_dir(source, base_temp_dir_path): + """Mocks transform_fn_io._copy_tree to fail the first time it is called by this test, thus forcing a retry which should succeed.""" + global _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED + if not _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED: + _COPY_TREE_TO_UNIQUE_TEMP_DIR_CALLED = True + original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) + raise ArithmeticError("Some error") + return original_copy_tree_to_unique_temp_dir(source, base_temp_dir_path) + + with self._makeTestPipeline() as pipeline: + transform_output_dir = os.path.join(self.get_temp_dir(), "output") + # Create an empty directory for the source saved model dir. + saved_model_dir = os.path.join(self.get_temp_dir(), "source") + file_io.recursive_create_dir(saved_model_dir) + saved_model_path = os.path.join(saved_model_dir, "saved_model") + with file_io.FileIO(saved_model_path, mode="w") as f: + f.write("some content") + saved_model_dir_pcoll = pipeline | "CreateSavedModelDir" >> beam.Create( + [saved_model_dir] + ) + # Combine test metadata with a dict of PCollections resolving futures. + deferred_metadata = pipeline | "CreateDeferredMetadata" >> beam.Create( + [test_metadata.COMPLETE_METADATA] + ) + metadata = beam_metadata_io.BeamDatasetMetadata( + test_metadata.INCOMPLETE_METADATA, deferred_metadata, {} + ) + with mock.patch.object( + transform_fn_io, + "_copy_tree_to_unique_temp_dir", + mock_copy_tree_to_unique_temp_dir, + ): + _ = ( + saved_model_dir_pcoll, + metadata, + ) | transform_fn_io.WriteTransformFn(transform_output_dir) + + # Test reading with TFTransformOutput + tf_transform_output = tft.TFTransformOutput(transform_output_dir) + metadata = tf_transform_output.transformed_metadata + self.assertEqual(metadata, test_metadata.COMPLETE_METADATA) + + transform_fn_dir = tf_transform_output.transform_savedmodel_dir + self.assertTrue(file_io.file_exists(transform_fn_dir)) + self.assertTrue(file_io.is_directory(transform_fn_dir)) + # Check temp directory created by failed run was cleaned up. 
+ self.assertEqual(2, len(file_io.list_directory(transform_output_dir))) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_transform/beam/tft_unit.py b/tensorflow_transform/beam/tft_unit.py index 898a28c..5197a02 100644 --- a/tensorflow_transform/beam/tft_unit.py +++ b/tensorflow_transform/beam/tft_unit.py @@ -15,24 +15,25 @@ import os import tempfile +import unittest from typing import Dict, Iterable, List, Optional, Tuple -from absl import logging import apache_beam as beam import pyarrow as pa import tensorflow as tf +from absl import logging +from tensorflow.python.util.protobuf import ( + compare, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import example_coder + import tensorflow_transform as tft import tensorflow_transform.beam as tft_beam -from tensorflow_transform.beam.tft_beam_io import transform_fn_io from tensorflow_transform import test_case from tensorflow_transform.beam import test_helpers +from tensorflow_transform.beam.tft_beam_io import transform_fn_io from tensorflow_transform.tf_metadata import dataset_metadata -from tfx_bsl.coders import example_coder - -import unittest -from tensorflow.python.util.protobuf import compare # pylint: disable=g-direct-tensorflow-import -from tensorflow_metadata.proto.v0 import schema_pb2 - parameters = test_case.parameters cross_parameters = test_case.cross_parameters @@ -49,385 +50,404 @@ def canonical_numeric_dtype(dtype): - """Returns int64 for int dtypes and float32 for float dtypes.""" - if dtype.is_floating: - return tf.float32 - elif dtype.is_integer: - return tf.int64 - else: - raise ValueError('Bad dtype {}'.format(dtype)) + """Returns int64 for int dtypes and float32 for float dtypes.""" + if dtype.is_floating: + return tf.float32 + elif dtype.is_integer: + return tf.int64 + else: + raise ValueError(f"Bad dtype {dtype}") def make_feature_spec_wrapper(make_feature_spec, *args): - """Skips test cases with RaggedFeature in TF 1.x.""" - try: - return make_feature_spec(*args) - except AttributeError as e: - if 'no attribute \'RaggedFeature\'' in repr(e): - raise unittest.SkipTest('RaggedFeature is not available in TF 1.x.') - else: - raise e + """Skips test cases with RaggedFeature in TF 1.x.""" + try: + return make_feature_spec(*args) + except AttributeError as e: + if "no attribute 'RaggedFeature'" in repr(e): + raise unittest.SkipTest("RaggedFeature is not available in TF 1.x.") + else: + raise e def _format_example_as_numpy_dict(example, feature_shape_dict): - result = example_coder.ExampleToNumpyDict(example) - for key, value in result.items(): - shape = feature_shape_dict[key] - value = value.reshape(shape) - if not shape: - value = value.squeeze(0) - result[key] = value - return result + result = example_coder.ExampleToNumpyDict(example) + for key, value in result.items(): + shape = feature_shape_dict[key] + value = value.reshape(shape) + if not shape: + value = value.squeeze(0) + result[key] = value + return result def _encode_transformed_data_batch( data: Tuple[pa.RecordBatch, Dict[str, pa.Array]], - coder: example_coder.RecordBatchToExamplesEncoder) -> List[bytes]: - """Produces a list of serialized tf.Examples from transformed data.""" - # Drop unary pass-through features that are not relevant for this testing - # framework. 
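# Editorial sketch, not part of the patch: canonical_numeric_dtype above simply
# collapses numeric dtypes to the two canonical dtypes these tests compare
# against (float32 for floating inputs, int64 for integer inputs). Assuming
# TensorFlow is imported as tf, as in this module:
#
#     assert canonical_numeric_dtype(tf.float64) == tf.float32
#     assert canonical_numeric_dtype(tf.int16) == tf.int64
#     # Non-numeric dtypes such as tf.string raise ValueError("Bad dtype ...").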
- record_batch, _ = data - return coder.encode(record_batch) + coder: example_coder.RecordBatchToExamplesEncoder, +) -> List[bytes]: + """Produces a list of serialized tf.Examples from transformed data.""" + # Drop unary pass-through features that are not relevant for this testing + # framework. + record_batch, _ = data + return coder.encode(record_batch) class TransformTestCase(test_case.TransformTestCase): - """Base test class for testing tf-transform preprocessing functions.""" - - class _TestPipeline(beam.Pipeline): - """Test pipeline class that retains pipeline metrics.""" - - @property - def has_ran(self): - return hasattr(self, '_run_result') - - @property - def metrics(self): - if not self.has_ran: - raise RuntimeError('Pipeline has to run before accessing its metrics') - return self._run_result.metrics() - - def __exit__(self, exc_type, exc_val, exc_tb): - if not exc_type: - assert not self.has_ran - self._run_result = self.run() - self._run_result.wait_until_finish() - - def _makeTestPipeline(self): - return self._TestPipeline(**test_helpers.make_test_beam_pipeline_kwargs()) - - def _getMetricsCounter(self, metrics: beam.metrics.Metrics, name: str, - namespaces_list: Iterable[str]) -> int: - metrics_filter = beam.metrics.MetricsFilter().with_name(name) - if namespaces_list: - metrics_filter = metrics_filter.with_namespaces(namespaces_list) - metric = metrics.query( - metrics_filter)['counters'] - committed = sum([r.committed for r in metric]) - attempted = sum([r.attempted for r in metric]) - self.assertEqual( - committed, - attempted, - msg=f'Attempted counter {name} from namespace {namespaces_list}') - return committed - - def assertMetricsCounterEqual( - self, - metrics: beam.metrics.Metrics, - name: str, - expected_count: int, - namespaces_list: Optional[Iterable[str]] = None): - counter_value = self._getMetricsCounter(metrics, name, namespaces_list) - self.assertEqual( - counter_value, - expected_count, - msg=f'Expected counter {name} from namespace {namespaces_list}') - - def assertMetricsCounterGreater( - self, - metrics: beam.metrics.Metrics, - name: str, - than: int, - namespaces_list: Optional[Iterable[str]] = None): - counter_value = self._getMetricsCounter(metrics, name, namespaces_list) - self.assertGreater( - counter_value, - than, - msg=f'Expected counter {name} from namespace {namespaces_list}') - - def assertAnalyzerOutputs(self, - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - test_data=None, - desired_batch_size=None, - beam_pipeline=None, - force_tf_compat_v1=False, - output_record_batches=False): - """Assert that input data and metadata is transformed as expected. - - This methods asserts transformed data and transformed metadata match - with expected_data and expected_metadata. - - Args: - input_data: Input data formatted in one of two ways: - * A sequence of dicts whose values are one of: - strings, lists of strings, numeric types or a pair of those. - Must have at least one key so that we can infer the batch size, or - * A sequence of pa.RecordBatch. - input_metadata: One of - - * DatasetMetadata describing input_data if `input_data` are dicts. - * TensorAdapterConfig otherwise. - analyzer_fn: A function taking a dict of tensors and returning a dict of - tensors. Unlike a preprocessing_fn, this should emit the results of a - call to an analyzer, while a preprocessing_fn must typically add a batch - dimension and broadcast across this batch dimension. 
- expected_outputs: A dict whose keys are the same as those of the output of - `analyzer_fn` and whose values are convertible to an ndarrays. - test_data: (optional) If this is provided then instead of calling - AnalyzeAndTransformDataset with input_data, this function will call - AnalyzeDataset with input_data and TransformDataset with test_data. - Must be provided if the input_data is empty. test_data should also - conform to input_metadata. - desired_batch_size: (Optional) A batch size to batch elements by. If not - provided, a batch size will be computed automatically. - beam_pipeline: (optional) A Beam Pipeline to use in this test. - force_tf_compat_v1: A bool. If `True`, TFT's public APIs use - Tensorflow in compat.v1 mode. - output_record_batches: (optional) A bool. If `True`, `TransformDataset` - and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es; - otherwise, they output instance dicts. - - Raises: - AssertionError: If the expected output does not match the results of - the analyzer_fn. - """ - - def preprocessing_fn(inputs): - """A helper function for validating analyzer outputs.""" - # Get tensors representing the outputs of the analyzers - analyzer_outputs = analyzer_fn(inputs) - - # Check that keys of analyzer_outputs match expected_output. - self.assertCountEqual(analyzer_outputs.keys(), expected_outputs.keys()) - - # Get batch size from any input tensor. - an_input = next(iter(inputs.values())) - if isinstance(an_input, tf.RaggedTensor): - batch_size = an_input.bounding_shape(axis=0) - else: - batch_size = tf.shape(input=an_input)[0] - - # Add a batch dimension and broadcast the analyzer outputs. - result = {} - for key, output_tensor in analyzer_outputs.items(): - # Get the expected shape, and set it. - expected_output_shape = list(expected_outputs[key].shape) - try: - output_tensor.set_shape(expected_output_shape) - except ValueError as e: - raise ValueError( - f'Error for key {key}, shapes are incompatible. Got ' - f'{output_tensor.shape}, expected {expected_output_shape}.' - ) from e - # Add a batch dimension - output_tensor = tf.expand_dims(output_tensor, 0) - # Broadcast along the batch dimension - result[key] = tf.tile( - output_tensor, - multiples=[batch_size] + [1] * len(expected_output_shape)) - - return result - - if input_data and not test_data: - # Create test dataset by repeating the first instance a number of times. - num_test_instances = 3 - test_data = [input_data[0]] * num_test_instances - expected_data = [expected_outputs] * num_test_instances - else: - # Ensure that the test dataset is specified and is not empty. 
- assert test_data - expected_data = [expected_outputs] * len(test_data) - expected_metadata = dataset_metadata.DatasetMetadata.from_feature_spec({ - key: tf.io.FixedLenFeature(value.shape, tf.as_dtype(value.dtype)) - for key, value in expected_outputs.items() - }) - - self.assertAnalyzeAndTransformResults( + """Base test class for testing tf-transform preprocessing functions.""" + + class _TestPipeline(beam.Pipeline): + """Test pipeline class that retains pipeline metrics.""" + + @property + def has_ran(self): + return hasattr(self, "_run_result") + + @property + def metrics(self): + if not self.has_ran: + raise RuntimeError("Pipeline has to run before accessing its metrics") + return self._run_result.metrics() + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + assert not self.has_ran + self._run_result = self.run() + self._run_result.wait_until_finish() + + def _makeTestPipeline(self): + return self._TestPipeline(**test_helpers.make_test_beam_pipeline_kwargs()) + + def _getMetricsCounter( + self, metrics: beam.metrics.Metrics, name: str, namespaces_list: Iterable[str] + ) -> int: + metrics_filter = beam.metrics.MetricsFilter().with_name(name) + if namespaces_list: + metrics_filter = metrics_filter.with_namespaces(namespaces_list) + metric = metrics.query(metrics_filter)["counters"] + committed = sum([r.committed for r in metric]) + attempted = sum([r.attempted for r in metric]) + self.assertEqual( + committed, + attempted, + msg=f"Attempted counter {name} from namespace {namespaces_list}", + ) + return committed + + def assertMetricsCounterEqual( + self, + metrics: beam.metrics.Metrics, + name: str, + expected_count: int, + namespaces_list: Optional[Iterable[str]] = None, + ): + counter_value = self._getMetricsCounter(metrics, name, namespaces_list) + self.assertEqual( + counter_value, + expected_count, + msg=f"Expected counter {name} from namespace {namespaces_list}", + ) + + def assertMetricsCounterGreater( + self, + metrics: beam.metrics.Metrics, + name: str, + than: int, + namespaces_list: Optional[Iterable[str]] = None, + ): + counter_value = self._getMetricsCounter(metrics, name, namespaces_list) + self.assertGreater( + counter_value, + than, + msg=f"Expected counter {name} from namespace {namespaces_list}", + ) + + def assertAnalyzerOutputs( + self, input_data, input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - test_data=test_data, - desired_batch_size=desired_batch_size, - beam_pipeline=beam_pipeline, - force_tf_compat_v1=force_tf_compat_v1, - output_record_batches=output_record_batches) - - def assertAnalyzeAndTransformResults(self, - input_data, - input_metadata, - preprocessing_fn, - expected_data=None, - expected_metadata=None, - expected_vocab_file_contents=None, - test_data=None, - desired_batch_size=None, - beam_pipeline=None, - temp_dir=None, - force_tf_compat_v1=False, - output_record_batches=False): - """Assert that input data and metadata is transformed as expected. - - This methods asserts transformed data and transformed metadata match - with expected_data and expected_metadata. - - Args: - input_data: Input data formatted in one of three ways: - * A sequence of dicts whose values are one of: - strings, lists of strings, numeric types or a pair of those. - Must have at least one key so that we can infer the batch size, or - * A sequence of pa.RecordBatch. - * A Beam source PTransform that produces either of the above. - input_metadata: One of - - * DatasetMetadata describing input_data if `input_data` are dicts. 
- * TensorAdapterConfig otherwise. - preprocessing_fn: A function taking a dict of tensors and returning - a dict of tensors. - expected_data: (optional) A dataset with the same type constraints as - input_data, but representing the output after transformation. - If supplied, transformed data is asserted to be equal. - expected_metadata: (optional) DatasetMetadata describing the transformed - data. If supplied, transformed metadata is asserted to be equal. - expected_vocab_file_contents: (optional) A dictionary from vocab filenames - to their expected content as a list of text lines or a list of tuples - of frequency and text. Values should be the expected result of calling - f.readlines() on the given asset files. - test_data: (optional) If this is provided then instead of calling - AnalyzeAndTransformDataset with input_data, this function will call - AnalyzeDataset with input_data and TransformDataset with test_data. - Note that this is the case even if input_data and test_data are equal. - test_data should also conform to input_metadata. - desired_batch_size: (optional) A batch size to batch elements by. If not - provided, a batch size will be computed automatically. - beam_pipeline: (optional) A Beam Pipeline to use in this test. - temp_dir: If set, it is used as output directory, else a new unique - directory is created. - force_tf_compat_v1: A bool. If `True`, TFT's public APIs use Tensorflow - in compat.v1 mode. - output_record_batches: (optional) A bool. If `True`, `TransformDataset` - and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es; - otherwise, they output instance dicts. - Raises: - AssertionError: if the expected data does not match the results of - transforming input_data according to preprocessing_fn, or - (if provided) if the expected metadata does not match. - """ - - expected_vocab_file_contents = expected_vocab_file_contents or {} - - # Note: we don't separately test AnalyzeDataset and TransformDataset as - # AnalyzeAndTransformDataset currently simply composes these two - # transforms. If in future versions of the code, the implementation - # differs, we should also run AnalyzeDataset and TransformDataset composed. - temp_dir = temp_dir or tempfile.mkdtemp( - prefix=self._testMethodName, dir=self.get_temp_dir()) - with beam_pipeline or self._makeTestPipeline() as pipeline: - with tft_beam.Context( - temp_dir=temp_dir, - desired_batch_size=desired_batch_size, - force_tf_compat_v1=force_tf_compat_v1, - ): - source_ptransform = ( - input_data if isinstance(input_data, beam.PTransform) else - beam.Create(input_data, reshuffle=False)) - input_data = pipeline | 'CreateInput' >> source_ptransform - if test_data is None: - (transformed_data, transformed_metadata), transform_fn = ( - input_data, - input_metadata, - ) | tft_beam.AnalyzeAndTransformDataset( - preprocessing_fn, output_record_batches=output_record_batches - ) + analyzer_fn, + expected_outputs, + test_data=None, + desired_batch_size=None, + beam_pipeline=None, + force_tf_compat_v1=False, + output_record_batches=False, + ): + """Assert that input data and metadata is transformed as expected. + + This methods asserts transformed data and transformed metadata match + with expected_data and expected_metadata. + + Args: + ---- + input_data: Input data formatted in one of two ways: + * A sequence of dicts whose values are one of: + strings, lists of strings, numeric types or a pair of those. + Must have at least one key so that we can infer the batch size, or + * A sequence of pa.RecordBatch. 
+ input_metadata: One of - + * DatasetMetadata describing input_data if `input_data` are dicts. + * TensorAdapterConfig otherwise. + analyzer_fn: A function taking a dict of tensors and returning a dict of + tensors. Unlike a preprocessing_fn, this should emit the results of a + call to an analyzer, while a preprocessing_fn must typically add a batch + dimension and broadcast across this batch dimension. + expected_outputs: A dict whose keys are the same as those of the output of + `analyzer_fn` and whose values are convertible to an ndarrays. + test_data: (optional) If this is provided then instead of calling + AnalyzeAndTransformDataset with input_data, this function will call + AnalyzeDataset with input_data and TransformDataset with test_data. + Must be provided if the input_data is empty. test_data should also + conform to input_metadata. + desired_batch_size: (Optional) A batch size to batch elements by. If not + provided, a batch size will be computed automatically. + beam_pipeline: (optional) A Beam Pipeline to use in this test. + force_tf_compat_v1: A bool. If `True`, TFT's public APIs use + Tensorflow in compat.v1 mode. + output_record_batches: (optional) A bool. If `True`, `TransformDataset` + and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es; + otherwise, they output instance dicts. + + Raises: + ------ + AssertionError: If the expected output does not match the results of + the analyzer_fn. + """ + + def preprocessing_fn(inputs): + """A helper function for validating analyzer outputs.""" + # Get tensors representing the outputs of the analyzers + analyzer_outputs = analyzer_fn(inputs) + + # Check that keys of analyzer_outputs match expected_output. + self.assertCountEqual(analyzer_outputs.keys(), expected_outputs.keys()) + + # Get batch size from any input tensor. + an_input = next(iter(inputs.values())) + if isinstance(an_input, tf.RaggedTensor): + batch_size = an_input.bounding_shape(axis=0) + else: + batch_size = tf.shape(input=an_input)[0] + + # Add a batch dimension and broadcast the analyzer outputs. + result = {} + for key, output_tensor in analyzer_outputs.items(): + # Get the expected shape, and set it. + expected_output_shape = list(expected_outputs[key].shape) + try: + output_tensor.set_shape(expected_output_shape) + except ValueError as e: + raise ValueError( + f"Error for key {key}, shapes are incompatible. Got " + f"{output_tensor.shape}, expected {expected_output_shape}." + ) from e + # Add a batch dimension + output_tensor = tf.expand_dims(output_tensor, 0) + # Broadcast along the batch dimension + result[key] = tf.tile( + output_tensor, + multiples=[batch_size] + [1] * len(expected_output_shape), + ) + + return result + + if input_data and not test_data: + # Create test dataset by repeating the first instance a number of times. + num_test_instances = 3 + test_data = [input_data[0]] * num_test_instances + expected_data = [expected_outputs] * num_test_instances else: - transform_fn = (input_data, input_metadata) | tft_beam.AnalyzeDataset( - preprocessing_fn - ) - test_data = pipeline | 'CreateTest' >> beam.Create(test_data) - transformed_data, transformed_metadata = ( - (test_data, input_metadata), - transform_fn, - ) | tft_beam.TransformDataset( - output_record_batches=output_record_batches - ) - - # Write transform_fn so we can test its assets - _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) - - transformed_data_path = os.path.join(temp_dir, 'transformed_data') + # Ensure that the test dataset is specified and is not empty. 
+ assert test_data + expected_data = [expected_outputs] * len(test_data) + expected_metadata = dataset_metadata.DatasetMetadata.from_feature_spec( + { + key: tf.io.FixedLenFeature(value.shape, tf.as_dtype(value.dtype)) + for key, value in expected_outputs.items() + } + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + test_data=test_data, + desired_batch_size=desired_batch_size, + beam_pipeline=beam_pipeline, + force_tf_compat_v1=force_tf_compat_v1, + output_record_batches=output_record_batches, + ) + + def assertAnalyzeAndTransformResults( + self, + input_data, + input_metadata, + preprocessing_fn, + expected_data=None, + expected_metadata=None, + expected_vocab_file_contents=None, + test_data=None, + desired_batch_size=None, + beam_pipeline=None, + temp_dir=None, + force_tf_compat_v1=False, + output_record_batches=False, + ): + """Assert that input data and metadata is transformed as expected. + + This methods asserts transformed data and transformed metadata match + with expected_data and expected_metadata. + + Args: + ---- + input_data: Input data formatted in one of three ways: + * A sequence of dicts whose values are one of: + strings, lists of strings, numeric types or a pair of those. + Must have at least one key so that we can infer the batch size, or + * A sequence of pa.RecordBatch. + * A Beam source PTransform that produces either of the above. + input_metadata: One of - + * DatasetMetadata describing input_data if `input_data` are dicts. + * TensorAdapterConfig otherwise. + preprocessing_fn: A function taking a dict of tensors and returning + a dict of tensors. + expected_data: (optional) A dataset with the same type constraints as + input_data, but representing the output after transformation. + If supplied, transformed data is asserted to be equal. + expected_metadata: (optional) DatasetMetadata describing the transformed + data. If supplied, transformed metadata is asserted to be equal. + expected_vocab_file_contents: (optional) A dictionary from vocab filenames + to their expected content as a list of text lines or a list of tuples + of frequency and text. Values should be the expected result of calling + f.readlines() on the given asset files. + test_data: (optional) If this is provided then instead of calling + AnalyzeAndTransformDataset with input_data, this function will call + AnalyzeDataset with input_data and TransformDataset with test_data. + Note that this is the case even if input_data and test_data are equal. + test_data should also conform to input_metadata. + desired_batch_size: (optional) A batch size to batch elements by. If not + provided, a batch size will be computed automatically. + beam_pipeline: (optional) A Beam Pipeline to use in this test. + temp_dir: If set, it is used as output directory, else a new unique + directory is created. + force_tf_compat_v1: A bool. If `True`, TFT's public APIs use Tensorflow + in compat.v1 mode. + output_record_batches: (optional) A bool. If `True`, `TransformDataset` + and `AnalyzeAndTransformDataset` output `pyarrow.RecordBatch`es; + otherwise, they output instance dicts. + + Raises: + ------ + AssertionError: if the expected data does not match the results of + transforming input_data according to preprocessing_fn, or + (if provided) if the expected metadata does not match. 
+ """ + expected_vocab_file_contents = expected_vocab_file_contents or {} + + # Note: we don't separately test AnalyzeDataset and TransformDataset as + # AnalyzeAndTransformDataset currently simply composes these two + # transforms. If in future versions of the code, the implementation + # differs, we should also run AnalyzeDataset and TransformDataset composed. + temp_dir = temp_dir or tempfile.mkdtemp( + prefix=self._testMethodName, dir=self.get_temp_dir() + ) + with beam_pipeline or self._makeTestPipeline() as pipeline: + with tft_beam.Context( + temp_dir=temp_dir, + desired_batch_size=desired_batch_size, + force_tf_compat_v1=force_tf_compat_v1, + ): + source_ptransform = ( + input_data + if isinstance(input_data, beam.PTransform) + else beam.Create(input_data, reshuffle=False) + ) + input_data = pipeline | "CreateInput" >> source_ptransform + if test_data is None: + (transformed_data, transformed_metadata), transform_fn = ( + input_data, + input_metadata, + ) | tft_beam.AnalyzeAndTransformDataset( + preprocessing_fn, output_record_batches=output_record_batches + ) + else: + transform_fn = ( + input_data, + input_metadata, + ) | tft_beam.AnalyzeDataset(preprocessing_fn) + test_data = pipeline | "CreateTest" >> beam.Create(test_data) + transformed_data, transformed_metadata = ( + (test_data, input_metadata), + transform_fn, + ) | tft_beam.TransformDataset( + output_record_batches=output_record_batches + ) + + # Write transform_fn so we can test its assets + _ = transform_fn | transform_fn_io.WriteTransformFn(temp_dir) + + transformed_data_path = os.path.join(temp_dir, "transformed_data") + if expected_data is not None: + _ = ( + (transformed_data, transformed_metadata) + | "Encode" >> tft_beam.EncodeTransformedDataset() + | "Write" + >> beam.io.tfrecordio.WriteToTFRecord( + transformed_data_path, shard_name_template="" + ) + ) + + # TODO(ebreck) Log transformed_data somewhere. + tf_transform_output = tft.TFTransformOutput(temp_dir) if expected_data is not None: - _ = ( - (transformed_data, transformed_metadata) - | 'Encode' >> tft_beam.EncodeTransformedDataset() - | 'Write' - >> beam.io.tfrecordio.WriteToTFRecord( - transformed_data_path, shard_name_template='' - ) - ) - - # TODO(ebreck) Log transformed_data somewhere. - tf_transform_output = tft.TFTransformOutput(temp_dir) - if expected_data is not None: - examples = tf.compat.v1.python_io.tf_record_iterator( - path=transformed_data_path) - shapes = { - f.name: - [s.size for s in f.shape.dim] if f.HasField('shape') else [-1] - for f in tf_transform_output.transformed_metadata.schema.feature - } - transformed_data = [ - _format_example_as_numpy_dict(e, shapes) for e in examples - ] - self.assertDataCloseOrEqual(expected_data, transformed_data) - - if expected_metadata: - # Make a copy with no annotations. - transformed_schema = schema_pb2.Schema() - transformed_schema.CopyFrom( - tf_transform_output.transformed_metadata.schema) - transformed_schema.ClearField('annotation') - for feature in transformed_schema.feature: - feature.ClearField('annotation') - - # assertProtoEqual has a size limit on the length of the - # serialized as text strings. Therefore, we first try to use - # assertProtoEqual, if that fails we try to use assertEqual, if that fails - # as well then we raise the exception from assertProtoEqual. 
- try: - compare.assertProtoEqual(self, expected_metadata.schema, - transformed_schema) - except AssertionError as compare_exception: - try: - self.assertEqual(expected_metadata.schema, transformed_schema) - except AssertionError: - raise compare_exception - - for filename, file_contents in expected_vocab_file_contents.items(): - full_filename = tf_transform_output.vocabulary_file_by_name(filename) - self.AssertVocabularyContents(full_filename, file_contents) - - def DebugPublishLatestsRenderedTFTGraph( - self, output_file: Optional[str] = None - ): - """Outputs a rendered graph which may be used for debugging. - - Requires adding the binary resource to the test target: - data = ["//third_party/graphviz:dot_binary"] - - Args: - output_file: Path to output the rendered graph file. - """ - logging.info( - 'DebugPublishLatestsRenderedTFTGraph is not currently supported.' - ) + examples = tf.compat.v1.python_io.tf_record_iterator( + path=transformed_data_path + ) + shapes = { + f.name: [s.size for s in f.shape.dim] if f.HasField("shape") else [-1] + for f in tf_transform_output.transformed_metadata.schema.feature + } + transformed_data = [ + _format_example_as_numpy_dict(e, shapes) for e in examples + ] + self.assertDataCloseOrEqual(expected_data, transformed_data) + + if expected_metadata: + # Make a copy with no annotations. + transformed_schema = schema_pb2.Schema() + transformed_schema.CopyFrom(tf_transform_output.transformed_metadata.schema) + transformed_schema.ClearField("annotation") + for feature in transformed_schema.feature: + feature.ClearField("annotation") + + # assertProtoEqual has a size limit on the length of the + # serialized as text strings. Therefore, we first try to use + # assertProtoEqual, if that fails we try to use assertEqual, if that fails + # as well then we raise the exception from assertProtoEqual. + try: + compare.assertProtoEqual( + self, expected_metadata.schema, transformed_schema + ) + except AssertionError as compare_exception: + try: + self.assertEqual(expected_metadata.schema, transformed_schema) + except AssertionError: + raise compare_exception + + for filename, file_contents in expected_vocab_file_contents.items(): + full_filename = tf_transform_output.vocabulary_file_by_name(filename) + self.AssertVocabularyContents(full_filename, file_contents) + + def DebugPublishLatestsRenderedTFTGraph(self, output_file: Optional[str] = None): + """Outputs a rendered graph which may be used for debugging. + + Requires adding the binary resource to the test target: + data = ["//third_party/graphviz:dot_binary"] + + Args: + ---- + output_file: Path to output the rendered graph file. + """ + logging.info("DebugPublishLatestsRenderedTFTGraph is not currently supported.") diff --git a/tensorflow_transform/beam/tukey_hh_params_integration_test.py b/tensorflow_transform/beam/tukey_hh_params_integration_test.py index 3f4dee0..bf73dba 100644 --- a/tensorflow_transform/beam/tukey_hh_params_integration_test.py +++ b/tensorflow_transform/beam/tukey_hh_params_integration_test.py @@ -17,615 +17,1003 @@ import apache_beam as beam import numpy as np - import tensorflow as tf + import tensorflow_transform as tft from tensorflow_transform.beam import impl as beam_impl -from tensorflow_transform.beam import impl_test # Use attributes, but no tests. -from tensorflow_transform.beam import tft_unit - +from tensorflow_transform.beam import ( + impl_test, # Use attributes, but no tests. 
+ tft_unit, +) # The input_data in _SCALE_TO_Z_SCORE_TEST_CASES (this is defined in impl_tests # to test tft.scale_to_z_score) do not have long tails; # therefore, gaussianization produces the same result of z_score. _SCALE_TO_GAUSSIAN_TEST_CASES = impl_test._SCALE_TO_Z_SCORE_TEST_CASES + [ - dict(testcase_name='gaussianization_int32', - input_data=np.array( - [516, -871, 737, 415, 584, 583, 152, 479, 576, 409, - 591, 844, -16, 508, 669, 617, 502, 532, 517, 479], - dtype=np.int32), - output_data=np.array( - [-0.09304726, -2.24682532, 1.56900163, -0.78244931, 0.48285998, - 0.47461339, -1.50929952, -0.39008015, 0.41659823, -0.81174337, - 0.54027596, 2.11624695, -1.72816411, -0.16046759, 1.13320023, - 0.74814557, -0.21014091, 0.04373742, -0.08454805, -0.39008015], - dtype=np.float32), - elementwise=False), - dict(testcase_name='gaussianization_float32', - input_data=np.array( - [516., -871., 737., 415., 584., 583., 152., 479., 576., 409., - 591., 844., -16., 508., 669., 617., 502., 532., 517., 479.], - dtype=np.float32), - output_data=np.array( - [-0.09304726, -2.24682532, 1.56900163, -0.78244931, 0.48285998, - 0.47461339, -1.50929952, -0.39008015, 0.41659823, -0.81174337, - 0.54027596, 2.11624695, -1.72816411, -0.16046759, 1.13320023, - 0.74814557, -0.21014091, 0.04373742, -0.08454805, -0.39008015], - dtype=np.float32), - elementwise=False), - dict(testcase_name='gaussianization_vector', - input_data=np.array( - [[516., -871.], [737., 415.], [584., 583.], [152., 479.], - [576., 409.], [591., 844.], [-16., 508.], [669., 617.], - [502., 532.], [517., 479.]], - dtype=np.float32), - output_data=np.array( - [[-0.09304726, -2.24682532], [1.56900163, -0.78244931], - [0.48285998, 0.47461339], [-1.50929952, -0.39008015], - [0.41659823, -0.81174337], [0.54027596, 2.11624695], - [-1.72816411, -0.16046759], [1.13320023, 0.74814557], - [-0.21014091, 0.04373742], [-0.08454805, -0.39008015]], - dtype=np.float32), - elementwise=False), - dict(testcase_name='gaussianization_vector_elementwise', - input_data=np.array( - [[516., -479.], [-871., -517.], [737., -532.], [415., -502.], - [584., -617.], [583., -669.], [152., -508.], [479., 16.], - [576., -844.], [409., -591.], [591., -409.], [844., -576.], - [-16., -479.], [508., -152.], [669., -583.], [617., -584.], - [502., -415.], [532., -737.], [517., 871.], [479., -516.]], - dtype=np.float32), - output_data=np.array( - [[-0.09304726, 0.39008015], [-2.24682532, 0.08454805], - [1.56900163, -0.04373742], [-0.78244931, 0.21014091], - [0.48285998, -0.74814557], [0.47461339, -1.13320023], - [-1.50929952, 0.16046759], [-0.39008015, 1.72816411], - [0.41659823, -2.11624695], [-0.81174337, -0.54027596], - [0.54027596, 0.81174337], [2.11624695, -0.41659823], - [-1.72816411, 0.39008015], [-0.16046759, 1.50929952], - [1.13320023, -0.47461339], [0.74814557, -0.48285998], - [-0.21014091, 0.78244931], [0.04373742, -1.56900163], - [-0.08454805, 2.24682532], [-0.39008015, 0.09304726]], - dtype=np.float32), - elementwise=True), + dict( + testcase_name="gaussianization_int32", + input_data=np.array( + [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ], + dtype=np.int32, + ), + output_data=np.array( + [ + -0.09304726, + -2.24682532, + 1.56900163, + -0.78244931, + 0.48285998, + 0.47461339, + -1.50929952, + -0.39008015, + 0.41659823, + -0.81174337, + 0.54027596, + 2.11624695, + -1.72816411, + -0.16046759, + 1.13320023, + 0.74814557, + -0.21014091, + 0.04373742, + -0.08454805, + 
-0.39008015, + ], + dtype=np.float32, + ), + elementwise=False, + ), + dict( + testcase_name="gaussianization_float32", + input_data=np.array( + [ + 516.0, + -871.0, + 737.0, + 415.0, + 584.0, + 583.0, + 152.0, + 479.0, + 576.0, + 409.0, + 591.0, + 844.0, + -16.0, + 508.0, + 669.0, + 617.0, + 502.0, + 532.0, + 517.0, + 479.0, + ], + dtype=np.float32, + ), + output_data=np.array( + [ + -0.09304726, + -2.24682532, + 1.56900163, + -0.78244931, + 0.48285998, + 0.47461339, + -1.50929952, + -0.39008015, + 0.41659823, + -0.81174337, + 0.54027596, + 2.11624695, + -1.72816411, + -0.16046759, + 1.13320023, + 0.74814557, + -0.21014091, + 0.04373742, + -0.08454805, + -0.39008015, + ], + dtype=np.float32, + ), + elementwise=False, + ), + dict( + testcase_name="gaussianization_vector", + input_data=np.array( + [ + [516.0, -871.0], + [737.0, 415.0], + [584.0, 583.0], + [152.0, 479.0], + [576.0, 409.0], + [591.0, 844.0], + [-16.0, 508.0], + [669.0, 617.0], + [502.0, 532.0], + [517.0, 479.0], + ], + dtype=np.float32, + ), + output_data=np.array( + [ + [-0.09304726, -2.24682532], + [1.56900163, -0.78244931], + [0.48285998, 0.47461339], + [-1.50929952, -0.39008015], + [0.41659823, -0.81174337], + [0.54027596, 2.11624695], + [-1.72816411, -0.16046759], + [1.13320023, 0.74814557], + [-0.21014091, 0.04373742], + [-0.08454805, -0.39008015], + ], + dtype=np.float32, + ), + elementwise=False, + ), + dict( + testcase_name="gaussianization_vector_elementwise", + input_data=np.array( + [ + [516.0, -479.0], + [-871.0, -517.0], + [737.0, -532.0], + [415.0, -502.0], + [584.0, -617.0], + [583.0, -669.0], + [152.0, -508.0], + [479.0, 16.0], + [576.0, -844.0], + [409.0, -591.0], + [591.0, -409.0], + [844.0, -576.0], + [-16.0, -479.0], + [508.0, -152.0], + [669.0, -583.0], + [617.0, -584.0], + [502.0, -415.0], + [532.0, -737.0], + [517.0, 871.0], + [479.0, -516.0], + ], + dtype=np.float32, + ), + output_data=np.array( + [ + [-0.09304726, 0.39008015], + [-2.24682532, 0.08454805], + [1.56900163, -0.04373742], + [-0.78244931, 0.21014091], + [0.48285998, -0.74814557], + [0.47461339, -1.13320023], + [-1.50929952, 0.16046759], + [-0.39008015, 1.72816411], + [0.41659823, -2.11624695], + [-0.81174337, -0.54027596], + [0.54027596, 0.81174337], + [2.11624695, -0.41659823], + [-1.72816411, 0.39008015], + [-0.16046759, 1.50929952], + [1.13320023, -0.47461339], + [0.74814557, -0.48285998], + [-0.21014091, 0.78244931], + [0.04373742, -1.56900163], + [-0.08454805, 2.24682532], + [-0.39008015, 0.09304726], + ], + dtype=np.float32, + ), + elementwise=True, + ), ] class TukeyHHParamsIntegrationTest(tft_unit.TransformTestCase): - - def setUp(self): - self._context = beam_impl.Context(use_deep_copy_optimization=True) - self._context.__enter__() - super().setUp() - - def tearDown(self): - self._context.__exit__() - super().tearDown() - - @tft_unit.named_parameters(*_SCALE_TO_GAUSSIAN_TEST_CASES) - def testGaussianize(self, input_data, output_data, elementwise): - - def preprocessing_fn(inputs): - x = inputs['x'] - x_cast = tf.cast(x, tf.as_dtype(input_data.dtype)) - x_gaussianized = tft.scale_to_gaussian(x_cast, elementwise=elementwise) - self.assertEqual(x_gaussianized.dtype, tf.as_dtype(output_data.dtype)) - return {'x_gaussianized': tf.cast(x_gaussianized, tf.float32)} - - input_data_dicts = [{'x': x} for x in input_data] - expected_data_dicts = [ - {'x_gaussianized': x_gaussianized} for x_gaussianized in output_data] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.FixedLenFeature( - input_data.shape[1:], - 
tft_unit.canonical_numeric_dtype(tf.as_dtype( - input_data.dtype))), - }) - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x_gaussianized': - tf.io.FixedLenFeature(output_data.shape[1:], tf.float32), - }) - self.assertAnalyzeAndTransformResults( - input_data_dicts, input_metadata, preprocessing_fn, expected_data_dicts, - expected_metadata, desired_batch_size=20, beam_pipeline=beam.Pipeline()) - - @tft_unit.parameters(*itertools.product([ - tf.int16, - tf.int32, - tf.int64, - tf.float32, - tf.float64, - ], (True, False))) - def testGaussianizeSparse(self, input_dtype, elementwise): - - def preprocessing_fn(inputs): - x_gaussianized = tft.scale_to_gaussian( - tf.cast(inputs['x'], input_dtype), elementwise=elementwise) - self.assertEqual(x_gaussianized.dtype, - impl_test._mean_output_dtype(input_dtype)) - return { - 'x_gaussianized': tf.cast(x_gaussianized, tf.float32) - } - - input_data_values = [516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, - 844, -16, 508, 669, 617, 502, 532, 517, 479] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({ - 'idx0': [1, 1], - 'idx1': [0, 1], - 'val': [v, -input_data_values[-1 - idx]] - }) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.SparseFeature(['idx0', 'idx1'], 'val', - tft_unit.canonical_numeric_dtype(input_dtype), - (4, 5)) - }) - if elementwise: - expected_data_values = [ - -0.09304726, -2.24682532, 1.56900163, -0.78244931, 0.48285998, - 0.47461339, -1.50929952, -0.39008015, 0.41659823, -0.81174337, - 0.54027596, 2.11624695, -1.72816411, -0.16046759, 1.13320023, - 0.74814557, -0.21014091, 0.04373742, -0.08454805, -0.39008015] - else: - expected_data_values = [ - 0.91555131, -1.54543642, 1.30767697, 0.73634456, 1.03620536, - 1.03443104, 0.26969729, 0.84990131, 1.02201077, 0.72569862, - 1.04862563, 1.49752966, -0.02838919, 0.90135672, 1.18702292, - 1.09475806, 0.89071077, 0.9439405, 0.91732564, 0.84990131] - expected_data = [] - for idx, v in enumerate(expected_data_values): - expected_data.append({ - 'x_gaussianized$sparse_values': ([v, - -expected_data_values[-1 - idx]]), - 'x_gaussianized$sparse_indices_0': [1, 1], - 'x_gaussianized$sparse_indices_1': [0, 1], - }) - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - desired_batch_size=20, - beam_pipeline=beam.Pipeline()) - - @tft_unit.parameters( - (tf.int16,), - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - ) - def testGaussianizeRagged(self, input_dtype): - tft_unit.skip_if_not_tf2('RaggedFeature is not available in TF 1.x.') - - def preprocessing_fn(inputs): - x_gaussianized = tft.scale_to_gaussian(tf.cast(inputs['x'], input_dtype)) - self.assertEqual(x_gaussianized.dtype, - impl_test._mean_output_dtype(input_dtype)) - return {'x_gaussianized': tf.cast(x_gaussianized, tf.float32)} - - input_data_values = [ - 516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, 844, -16, 508, - 669, 617, 502, 532, 517, 479 - ] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({ - 'val': [v, -input_data_values[-1 - idx]], - 'row_lengths_1': [2, 1, 0], - 'row_lengths_2': [1, 0, 1], - }) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': - tf.io.RaggedFeature( - tft_unit.canonical_numeric_dtype(input_dtype), - value_key='val', - partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths_1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.RowLengths('row_lengths_2') # pytype: 
disable=attribute-error - ]), - }) - expected_data_values = [ - 0.91555131, -1.54543642, 1.30767697, 0.73634456, 1.03620536, 1.03443104, - 0.26969729, 0.84990131, 1.02201077, 0.72569862, 1.04862563, 1.49752966, - -0.02838919, 0.90135672, 1.18702292, 1.09475806, 0.89071077, 0.9439405, - 0.91732564, 0.84990131 - ] - expected_data = [] - for idx, v in enumerate(expected_data_values): - expected_data.append({ - 'x_gaussianized$ragged_values': ([v, - -expected_data_values[-1 - idx]]), - 'x_gaussianized$row_lengths_1': [2, 1, 0], - 'x_gaussianized$row_lengths_2': [1, 0, 1] - }) - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - desired_batch_size=20, - # Runs the test deterministically on the whole batch. - beam_pipeline=beam.Pipeline()) - - @tft_unit.named_parameters( - dict( - testcase_name='tukey_int64in', - input_dtype=tf.int64, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='tukey_int32in', - input_dtype=tf.int32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='tukey_int16in', - input_dtype=tf.int16, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='tukey_float64in', - input_dtype=tf.float64, - output_dtypes={ - 'tukey_location': tf.float64, - 'tukey_scale': tf.float64, - 'tukey_hl': tf.float64, - 'tukey_hr': tf.float64 - }), - dict( - testcase_name='tukey_float32in', - input_dtype=tf.float32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }, - elementwise=True), - dict( - testcase_name='tukey_float32in_reduce', - input_dtype=tf.float32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }, - elementwise=False), - ) - def testTukeyHHAnalyzersWithDenseInputs( - self, input_dtype, output_dtypes, elementwise=True): - - def analyzer_fn(inputs): - a = tf.cast(inputs['a'], input_dtype) - - def assert_and_cast_dtype(tensor, out_dtype): - self.assertEqual(tensor.dtype, out_dtype) - return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) - - return { - 'tukey_location': assert_and_cast_dtype( - tft.tukey_location(a, reduce_instance_dims=not elementwise), - output_dtypes['tukey_location']), - 'tukey_scale': assert_and_cast_dtype( - tft.tukey_scale(a, reduce_instance_dims=not elementwise), - output_dtypes['tukey_scale']), - 'tukey_hl': assert_and_cast_dtype( - tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[0], - output_dtypes['tukey_hl']), - 'tukey_hr': assert_and_cast_dtype( - tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[1], - output_dtypes['tukey_hr']), - } - - input_data_values = [516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, - 844, -16, 508, 669, 617, 502, 532, 517, 479] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({'a': [v, -input_data_values[-1 - idx]]}) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': - tf.io.FixedLenFeature([2], - tft_unit.canonical_numeric_dtype(input_dtype)) - }) - expected_outputs = { - 'tukey_location': - np.array( + def setUp(self): + self._context = beam_impl.Context(use_deep_copy_optimization=True) + 
self._context.__enter__() + super().setUp() + + def tearDown(self): + self._context.__exit__() + super().tearDown() + + @tft_unit.named_parameters(*_SCALE_TO_GAUSSIAN_TEST_CASES) + def testGaussianize(self, input_data, output_data, elementwise): + def preprocessing_fn(inputs): + x = inputs["x"] + x_cast = tf.cast(x, tf.as_dtype(input_data.dtype)) + x_gaussianized = tft.scale_to_gaussian(x_cast, elementwise=elementwise) + self.assertEqual(x_gaussianized.dtype, tf.as_dtype(output_data.dtype)) + return {"x_gaussianized": tf.cast(x_gaussianized, tf.float32)} + + input_data_dicts = [{"x": x} for x in input_data] + expected_data_dicts = [ + {"x_gaussianized": x_gaussianized} for x_gaussianized in output_data + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature( + input_data.shape[1:], + tft_unit.canonical_numeric_dtype(tf.as_dtype(input_data.dtype)), + ), + } + ) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x_gaussianized": tf.io.FixedLenFeature( + output_data.shape[1:], tf.float32 + ), + } + ) + self.assertAnalyzeAndTransformResults( + input_data_dicts, + input_metadata, + preprocessing_fn, + expected_data_dicts, + expected_metadata, + desired_batch_size=20, + beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.parameters( + *itertools.product( + [ + tf.int16, + tf.int32, + tf.int64, + tf.float32, + tf.float64, + ], + (True, False), + ) + ) + def testGaussianizeSparse(self, input_dtype, elementwise): + def preprocessing_fn(inputs): + x_gaussianized = tft.scale_to_gaussian( + tf.cast(inputs["x"], input_dtype), elementwise=elementwise + ) + self.assertEqual( + x_gaussianized.dtype, impl_test._mean_output_dtype(input_dtype) + ) + return {"x_gaussianized": tf.cast(x_gaussianized, tf.float32)} + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append( + { + "idx0": [1, 1], + "idx1": [0, 1], + "val": [v, -input_data_values[-1 - idx]], + } + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.SparseFeature( + ["idx0", "idx1"], + "val", + tft_unit.canonical_numeric_dtype(input_dtype), + (4, 5), + ) + } + ) + if elementwise: + expected_data_values = [ + -0.09304726, + -2.24682532, + 1.56900163, + -0.78244931, + 0.48285998, + 0.47461339, + -1.50929952, + -0.39008015, + 0.41659823, + -0.81174337, + 0.54027596, + 2.11624695, + -1.72816411, + -0.16046759, + 1.13320023, + 0.74814557, + -0.21014091, + 0.04373742, + -0.08454805, + -0.39008015, + ] + else: + expected_data_values = [ + 0.91555131, + -1.54543642, + 1.30767697, + 0.73634456, + 1.03620536, + 1.03443104, + 0.26969729, + 0.84990131, + 1.02201077, + 0.72569862, + 1.04862563, + 1.49752966, + -0.02838919, + 0.90135672, + 1.18702292, + 1.09475806, + 0.89071077, + 0.9439405, + 0.91732564, + 0.84990131, + ] + expected_data = [] + for idx, v in enumerate(expected_data_values): + expected_data.append( + { + "x_gaussianized$sparse_values": ( + [v, -expected_data_values[-1 - idx]] + ), + "x_gaussianized$sparse_indices_0": [1, 1], + "x_gaussianized$sparse_indices_1": [0, 1], + } + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + desired_batch_size=20, + beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.parameters( + (tf.int16,), + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + ) + def 
testGaussianizeRagged(self, input_dtype): + tft_unit.skip_if_not_tf2("RaggedFeature is not available in TF 1.x.") + + def preprocessing_fn(inputs): + x_gaussianized = tft.scale_to_gaussian(tf.cast(inputs["x"], input_dtype)) + self.assertEqual( + x_gaussianized.dtype, impl_test._mean_output_dtype(input_dtype) + ) + return {"x_gaussianized": tf.cast(x_gaussianized, tf.float32)} + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append( + { + "val": [v, -input_data_values[-1 - idx]], + "row_lengths_1": [2, 1, 0], + "row_lengths_2": [1, 0, 1], + } + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.RaggedFeature( + tft_unit.canonical_numeric_dtype(input_dtype), + value_key="val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "row_lengths_1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.RowLengths( + "row_lengths_2" + ), # pytype: disable=attribute-error + ], + ), + } + ) + expected_data_values = [ + 0.91555131, + -1.54543642, + 1.30767697, + 0.73634456, + 1.03620536, + 1.03443104, + 0.26969729, + 0.84990131, + 1.02201077, + 0.72569862, + 1.04862563, + 1.49752966, + -0.02838919, + 0.90135672, + 1.18702292, + 1.09475806, + 0.89071077, + 0.9439405, + 0.91732564, + 0.84990131, + ] + expected_data = [] + for idx, v in enumerate(expected_data_values): + expected_data.append( + { + "x_gaussianized$ragged_values": ( + [v, -expected_data_values[-1 - idx]] + ), + "x_gaussianized$row_lengths_1": [2, 1, 0], + "x_gaussianized$row_lengths_2": [1, 0, 1], + } + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + desired_batch_size=20, + # Runs the test deterministically on the whole batch. 
+ beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.named_parameters( + dict( + testcase_name="tukey_int64in", + input_dtype=tf.int64, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="tukey_int32in", + input_dtype=tf.int32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="tukey_int16in", + input_dtype=tf.int16, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="tukey_float64in", + input_dtype=tf.float64, + output_dtypes={ + "tukey_location": tf.float64, + "tukey_scale": tf.float64, + "tukey_hl": tf.float64, + "tukey_hr": tf.float64, + }, + ), + dict( + testcase_name="tukey_float32in", + input_dtype=tf.float32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + elementwise=True, + ), + dict( + testcase_name="tukey_float32in_reduce", + input_dtype=tf.float32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + elementwise=False, + ), + ) + def testTukeyHHAnalyzersWithDenseInputs( + self, input_dtype, output_dtypes, elementwise=True + ): + def analyzer_fn(inputs): + a = tf.cast(inputs["a"], input_dtype) + + def assert_and_cast_dtype(tensor, out_dtype): + self.assertEqual(tensor.dtype, out_dtype) + return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) + + return { + "tukey_location": assert_and_cast_dtype( + tft.tukey_location(a, reduce_instance_dims=not elementwise), + output_dtypes["tukey_location"], + ), + "tukey_scale": assert_and_cast_dtype( + tft.tukey_scale(a, reduce_instance_dims=not elementwise), + output_dtypes["tukey_scale"], + ), + "tukey_hl": assert_and_cast_dtype( + tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[0], + output_dtypes["tukey_hl"], + ), + "tukey_hr": assert_and_cast_dtype( + tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[1], + output_dtypes["tukey_hr"], + ), + } + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append({"a": [v, -input_data_values[-1 - idx]]}) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature( + [2], tft_unit.canonical_numeric_dtype(input_dtype) + ) + } + ) + expected_outputs = { + "tukey_location": np.array( [526.89355, -526.89355] if elementwise else 0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_location']).as_numpy_dtype), - 'tukey_scale': - np.array( + output_dtypes["tukey_location"] + ).as_numpy_dtype, + ), + "tukey_scale": np.array( [116.73997, 116.73997] if elementwise else 572.277649, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_scale']).as_numpy_dtype), - 'tukey_hl': - np.array( + output_dtypes["tukey_scale"] + ).as_numpy_dtype, + ), + "tukey_hl": np.array( [0.6629082, 0.11148566] if elementwise else 0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_hl']).as_numpy_dtype), - 'tukey_hr': - np.array( + output_dtypes["tukey_hl"] + ).as_numpy_dtype, + ), + "tukey_hr": np.array( [0.11148566, 0.6629082] if elementwise else 
0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_hr']).as_numpy_dtype), - } - - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=20, - # Runs the test deterministically on the whole batch. - beam_pipeline=beam.Pipeline()) - - def testTukeyHHAnalyzersWithNDDenseInputs(self): - - def analyzer_fn(inputs): - a = inputs['a'] - - return { - 'tukey_location': tft.tukey_location(a, reduce_instance_dims=False), - 'tukey_scale': tft.tukey_scale(a, reduce_instance_dims=False), - 'tukey_hl': tft.tukey_h_params(a, reduce_instance_dims=False)[0], - 'tukey_hr': tft.tukey_h_params(a, reduce_instance_dims=False)[1], - } - - input_data_values = [516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, - 844, -16, 508, 669, 617, 502, 532, 517, 479] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({'a': [ - [v, -input_data_values[-1 - idx]], - [2 * v, -2 * input_data_values[-1 - idx]]]}) - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([2, 2], tf.float32)}) - expected_outputs = { - 'tukey_location': - np.array( - [[526.89355, -526.89355], [2. * 526.89355, -2. * 526.89355]], - np.float32), - 'tukey_scale': - np.array([[116.73997, 116.73997], [2. * 116.73997, 2. * 116.73997]], - np.float32), - 'tukey_hl': - np.array( - [[0.6629082, 0.11148566], [0.6629082, 0.11148566]], np.float32), - 'tukey_hr': - np.array( - [[0.11148566, 0.6629082], [0.11148566, 0.6629082]], np.float32) - } - - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=20, - # Runs the test deterministically on the whole batch. - beam_pipeline=beam.Pipeline()) - - @tft_unit.named_parameters( - dict( - testcase_name='_int64in', - input_dtype=tf.int64, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='_int32in', - input_dtype=tf.int32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='_int16in', - input_dtype=tf.int16, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }), - dict( - testcase_name='_float64in', - input_dtype=tf.float64, - output_dtypes={ - 'tukey_location': tf.float64, - 'tukey_scale': tf.float64, - 'tukey_hl': tf.float64, - 'tukey_hr': tf.float64 - }), - dict( - testcase_name='_float32in', - input_dtype=tf.float32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }, - elementwise=True - ), - dict( - testcase_name='_float32in_reduce', - input_dtype=tf.float32, - output_dtypes={ - 'tukey_location': tf.float32, - 'tukey_scale': tf.float32, - 'tukey_hl': tf.float32, - 'tukey_hr': tf.float32 - }, - elementwise=False - ), - ) - def testTukeyHHAnalyzersWithSparseInputs( - self, input_dtype, output_dtypes, elementwise=True): - - def analyzer_fn(inputs): - a = tf.cast(inputs['a'], input_dtype) - - def assert_and_cast_dtype(tensor, out_dtype): - self.assertEqual(tensor.dtype, out_dtype) - return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) - - return { - 'tukey_location': assert_and_cast_dtype( - tft.tukey_location(a, reduce_instance_dims=not elementwise), - output_dtypes['tukey_location']), - 'tukey_scale': assert_and_cast_dtype( - 
tft.tukey_scale(a, reduce_instance_dims=not elementwise), - output_dtypes['tukey_scale']), - 'tukey_hl': assert_and_cast_dtype( - tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[0], - output_dtypes['tukey_hl']), - 'tukey_hr': assert_and_cast_dtype( - tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[1], - output_dtypes['tukey_hr']), - } - - input_data_values = [516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, - 844, -16, 508, 669, 617, 502, 532, 517, 479] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({ - 'idx0': [0, 0], - 'idx1': [0, 1], - 'val': [v, -input_data_values[-1 - idx]] - }) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': - tf.io.SparseFeature(['idx0', 'idx1'], 'val', - tft_unit.canonical_numeric_dtype(input_dtype), - (2, 2)) - }) - - expected_outputs = { - 'tukey_location': - np.array( - [[526.89355, -526.89355], [0., 0.]] if elementwise else 0.0, + output_dtypes["tukey_hr"] + ).as_numpy_dtype, + ), + } + + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=20, + # Runs the test deterministically on the whole batch. + beam_pipeline=beam.Pipeline(), + ) + + def testTukeyHHAnalyzersWithNDDenseInputs(self): + def analyzer_fn(inputs): + a = inputs["a"] + + return { + "tukey_location": tft.tukey_location(a, reduce_instance_dims=False), + "tukey_scale": tft.tukey_scale(a, reduce_instance_dims=False), + "tukey_hl": tft.tukey_h_params(a, reduce_instance_dims=False)[0], + "tukey_hr": tft.tukey_h_params(a, reduce_instance_dims=False)[1], + } + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append( + { + "a": [ + [v, -input_data_values[-1 - idx]], + [2 * v, -2 * input_data_values[-1 - idx]], + ] + } + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([2, 2], tf.float32)} + ) + expected_outputs = { + "tukey_location": np.array( + [[526.89355, -526.89355], [2.0 * 526.89355, -2.0 * 526.89355]], + np.float32, + ), + "tukey_scale": np.array( + [[116.73997, 116.73997], [2.0 * 116.73997, 2.0 * 116.73997]], np.float32 + ), + "tukey_hl": np.array( + [[0.6629082, 0.11148566], [0.6629082, 0.11148566]], np.float32 + ), + "tukey_hr": np.array( + [[0.11148566, 0.6629082], [0.11148566, 0.6629082]], np.float32 + ), + } + + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=20, + # Runs the test deterministically on the whole batch. 
+ beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.named_parameters( + dict( + testcase_name="_int64in", + input_dtype=tf.int64, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="_int32in", + input_dtype=tf.int32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="_int16in", + input_dtype=tf.int16, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + ), + dict( + testcase_name="_float64in", + input_dtype=tf.float64, + output_dtypes={ + "tukey_location": tf.float64, + "tukey_scale": tf.float64, + "tukey_hl": tf.float64, + "tukey_hr": tf.float64, + }, + ), + dict( + testcase_name="_float32in", + input_dtype=tf.float32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + elementwise=True, + ), + dict( + testcase_name="_float32in_reduce", + input_dtype=tf.float32, + output_dtypes={ + "tukey_location": tf.float32, + "tukey_scale": tf.float32, + "tukey_hl": tf.float32, + "tukey_hr": tf.float32, + }, + elementwise=False, + ), + ) + def testTukeyHHAnalyzersWithSparseInputs( + self, input_dtype, output_dtypes, elementwise=True + ): + def analyzer_fn(inputs): + a = tf.cast(inputs["a"], input_dtype) + + def assert_and_cast_dtype(tensor, out_dtype): + self.assertEqual(tensor.dtype, out_dtype) + return tf.cast(tensor, tft_unit.canonical_numeric_dtype(out_dtype)) + + return { + "tukey_location": assert_and_cast_dtype( + tft.tukey_location(a, reduce_instance_dims=not elementwise), + output_dtypes["tukey_location"], + ), + "tukey_scale": assert_and_cast_dtype( + tft.tukey_scale(a, reduce_instance_dims=not elementwise), + output_dtypes["tukey_scale"], + ), + "tukey_hl": assert_and_cast_dtype( + tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[0], + output_dtypes["tukey_hl"], + ), + "tukey_hr": assert_and_cast_dtype( + tft.tukey_h_params(a, reduce_instance_dims=not elementwise)[1], + output_dtypes["tukey_hr"], + ), + } + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append( + { + "idx0": [0, 0], + "idx1": [0, 1], + "val": [v, -input_data_values[-1 - idx]], + } + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.SparseFeature( + ["idx0", "idx1"], + "val", + tft_unit.canonical_numeric_dtype(input_dtype), + (2, 2), + ) + } + ) + + expected_outputs = { + "tukey_location": np.array( + [[526.89355, -526.89355], [0.0, 0.0]] if elementwise else 0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_location']).as_numpy_dtype), - 'tukey_scale': - np.array( - [[116.73997, 116.73997], [1., 1.]] if elementwise else 572.2776, + output_dtypes["tukey_location"] + ).as_numpy_dtype, + ), + "tukey_scale": np.array( + [[116.73997, 116.73997], [1.0, 1.0]] if elementwise else 572.2776, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_scale']).as_numpy_dtype), - 'tukey_hl': - np.array( - [[0.6629082, 0.11148566], [0., 0.]] if elementwise else 0.0, + output_dtypes["tukey_scale"] + ).as_numpy_dtype, + ), + "tukey_hl": np.array( + [[0.6629082, 0.11148566], [0.0, 0.0]] if elementwise else 
0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_hl']).as_numpy_dtype), - 'tukey_hr': - np.array( - [[0.11148566, 0.6629082], [0., 0.]] if elementwise else 0.0, + output_dtypes["tukey_hl"] + ).as_numpy_dtype, + ), + "tukey_hr": np.array( + [[0.11148566, 0.6629082], [0.0, 0.0]] if elementwise else 0.0, tft_unit.canonical_numeric_dtype( - output_dtypes['tukey_hr']).as_numpy_dtype), - } - - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=20, - # Runs the test deterministically on the whole batch. - beam_pipeline=beam.Pipeline()) - - @tft_unit.parameters( - (tf.int16,), - (tf.int32,), - (tf.int64,), - (tf.float32,), - (tf.float64,), - ) - def testTukeyHHAnalyzersWithRaggedInputs(self, input_dtype): - tft_unit.skip_if_not_tf2('RaggedFeature is not available in TF 1.x.') - - output_dtype = impl_test._mean_output_dtype(input_dtype) - canonical_output_dtype = tft_unit.canonical_numeric_dtype(output_dtype) - - def analyzer_fn(inputs): - a = tf.cast(inputs['a'], input_dtype) - - def assert_and_cast_dtype(tensor): - self.assertEqual(tensor.dtype, output_dtype) - return tf.cast(tensor, canonical_output_dtype) - - return { - 'tukey_location': assert_and_cast_dtype(tft.tukey_location(a)), - 'tukey_scale': assert_and_cast_dtype(tft.tukey_scale(a)), - 'tukey_hl': assert_and_cast_dtype(tft.tukey_h_params(a)[0]), - 'tukey_hr': assert_and_cast_dtype(tft.tukey_h_params(a)[1]), - } - - input_data_values = [ - 516, -871, 737, 415, 584, 583, 152, 479, 576, 409, 591, 844, -16, 508, - 669, 617, 502, 532, 517, 479 - ] - input_data = [] - for idx, v in enumerate(input_data_values): - input_data.append({ - 'val': [v, -input_data_values[-1 - idx]], - 'row_lengths_1': [2, 0, 1], - 'row_lengths_2': [0, 1, 1] - }) - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': - tf.io.RaggedFeature( - tft_unit.canonical_numeric_dtype(input_dtype), - value_key='val', - partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths_1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.RowLengths('row_lengths_2') # pytype: disable=attribute-error - ]), - }) - - expected_outputs = { - 'tukey_location': - np.array(0.0, canonical_output_dtype.as_numpy_dtype), - 'tukey_scale': - np.array(572.2776, canonical_output_dtype.as_numpy_dtype), - 'tukey_hl': - np.array(0.0, canonical_output_dtype.as_numpy_dtype), - 'tukey_hr': - np.array(0.0, canonical_output_dtype.as_numpy_dtype), - } - - self.assertAnalyzerOutputs( - input_data, - input_metadata, - analyzer_fn, - expected_outputs, - desired_batch_size=20, - # Runs the test deterministically on the whole batch. - beam_pipeline=beam.Pipeline()) - -if __name__ == '__main__': - tft_unit.main() + output_dtypes["tukey_hr"] + ).as_numpy_dtype, + ), + } + + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=20, + # Runs the test deterministically on the whole batch. 
+ beam_pipeline=beam.Pipeline(), + ) + + @tft_unit.parameters( + (tf.int16,), + (tf.int32,), + (tf.int64,), + (tf.float32,), + (tf.float64,), + ) + def testTukeyHHAnalyzersWithRaggedInputs(self, input_dtype): + tft_unit.skip_if_not_tf2("RaggedFeature is not available in TF 1.x.") + + output_dtype = impl_test._mean_output_dtype(input_dtype) + canonical_output_dtype = tft_unit.canonical_numeric_dtype(output_dtype) + + def analyzer_fn(inputs): + a = tf.cast(inputs["a"], input_dtype) + + def assert_and_cast_dtype(tensor): + self.assertEqual(tensor.dtype, output_dtype) + return tf.cast(tensor, canonical_output_dtype) + + return { + "tukey_location": assert_and_cast_dtype(tft.tukey_location(a)), + "tukey_scale": assert_and_cast_dtype(tft.tukey_scale(a)), + "tukey_hl": assert_and_cast_dtype(tft.tukey_h_params(a)[0]), + "tukey_hr": assert_and_cast_dtype(tft.tukey_h_params(a)[1]), + } + + input_data_values = [ + 516, + -871, + 737, + 415, + 584, + 583, + 152, + 479, + 576, + 409, + 591, + 844, + -16, + 508, + 669, + 617, + 502, + 532, + 517, + 479, + ] + input_data = [] + for idx, v in enumerate(input_data_values): + input_data.append( + { + "val": [v, -input_data_values[-1 - idx]], + "row_lengths_1": [2, 0, 1], + "row_lengths_2": [0, 1, 1], + } + ) + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.RaggedFeature( + tft_unit.canonical_numeric_dtype(input_dtype), + value_key="val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "row_lengths_1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.RowLengths( + "row_lengths_2" + ), # pytype: disable=attribute-error + ], + ), + } + ) + + expected_outputs = { + "tukey_location": np.array(0.0, canonical_output_dtype.as_numpy_dtype), + "tukey_scale": np.array(572.2776, canonical_output_dtype.as_numpy_dtype), + "tukey_hl": np.array(0.0, canonical_output_dtype.as_numpy_dtype), + "tukey_hr": np.array(0.0, canonical_output_dtype.as_numpy_dtype), + } + + self.assertAnalyzerOutputs( + input_data, + input_metadata, + analyzer_fn, + expected_outputs, + desired_batch_size=20, + # Runs the test deterministically on the whole batch. + beam_pipeline=beam.Pipeline(), + ) + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/vocabulary_integration_test.py b/tensorflow_transform/beam/vocabulary_integration_test.py index 437d794..2030525 100644 --- a/tensorflow_transform/beam/vocabulary_integration_test.py +++ b/tensorflow_transform/beam/vocabulary_integration_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2017 Google Inc. All Rights Reserved. 
# @@ -19,2074 +18,2544 @@ import apache_beam as beam import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2 + import tensorflow_transform as tft -from tensorflow_transform.beam import analyzer_impls +from tensorflow_transform.beam import analyzer_impls, tft_unit from tensorflow_transform.beam import impl as beam_impl from tensorflow_transform.beam.tft_beam_io import transform_fn_io -from tensorflow_transform.beam import tft_unit - -from tensorflow_metadata.proto.v0 import schema_pb2 _COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES = [ dict( - testcase_name='sparse', + testcase_name="sparse", input_data=[ - {'val': ['hello'], 'idx0': [0], 'idx1': [0]}, - {'val': ['world'], 'idx0': [1], 'idx1': [1]}, - {'val': ['hello', 'goodbye'], 'idx0': [0, 1], 'idx1': [1, 2]}, + {"val": ["hello"], "idx0": [0], "idx1": [0]}, + {"val": ["world"], "idx0": [1], "idx1": [1]}, + {"val": ["hello", "goodbye"], "idx0": [0, 1], "idx1": [1, 2]}, { - 'val': ['hello', 'goodbye', ' '], - 'idx0': [0, 1, 1], - 'idx1': [0, 1, 2], + "val": ["hello", "goodbye", " "], + "idx0": [0, 1, 1], + "idx1": [0, 1, 2], }, ], input_metadata=tft.DatasetMetadata.from_feature_spec( - { - 'x': tf.io.SparseFeature( - ['idx0', 'idx1'], 'val', tf.string, [2, 3] - ) - } + {"x": tf.io.SparseFeature(["idx0", "idx1"], "val", tf.string, [2, 3])} ), expected_data=[ { - 'index$sparse_indices_0': [0], - 'index$sparse_indices_1': [0], - 'index$sparse_values': [0], + "index$sparse_indices_0": [0], + "index$sparse_indices_1": [0], + "index$sparse_values": [0], }, { - 'index$sparse_indices_0': [1], - 'index$sparse_indices_1': [1], - 'index$sparse_values': [2], + "index$sparse_indices_0": [1], + "index$sparse_indices_1": [1], + "index$sparse_values": [2], }, { - 'index$sparse_indices_0': [0, 1], - 'index$sparse_indices_1': [1, 2], - 'index$sparse_values': [0, 1], + "index$sparse_indices_0": [0, 1], + "index$sparse_indices_1": [1, 2], + "index$sparse_values": [0, 1], }, { - 'index$sparse_indices_0': [0, 1, 1], - 'index$sparse_indices_1': [0, 1, 2], - 'index$sparse_values': [0, 1, 3], + "index$sparse_indices_0": [0, 1, 1], + "index$sparse_indices_1": [0, 1, 2], + "index$sparse_values": [0, 1, 3], }, ], expected_vocab_contents={ - b'hello': 3, - b'goodbye': 2, - b'world': 1, - b' ': 1, + b"hello": 3, + b"goodbye": 2, + b"world": 1, + b" ": 1, }, ), dict( - testcase_name='ragged', + testcase_name="ragged", input_data=[ - {'val': ['hello', ' '], 'row_lengths': [1, 0, 1]}, - {'val': ['world'], 'row_lengths': [0, 1]}, - {'val': ['hello', 'goodbye'], 'row_lengths': [2, 0, 0]}, - {'val': ['hello', 'goodbye', ' '], 'row_lengths': [0, 2, 1]}, + {"val": ["hello", " "], "row_lengths": [1, 0, 1]}, + {"val": ["world"], "row_lengths": [0, 1]}, + {"val": ["hello", "goodbye"], "row_lengths": [2, 0, 0]}, + {"val": ["hello", "goodbye", " "], "row_lengths": [0, 2, 1]}, ], input_metadata=tft.DatasetMetadata.from_feature_spec( { - 'x': tf.io.RaggedFeature( + "x": tf.io.RaggedFeature( tf.string, - value_key='val', + value_key="val", partitions=[ - tf.io.RaggedFeature.RowLengths('row_lengths') # pytype: disable=attribute-error + tf.io.RaggedFeature.RowLengths( + "row_lengths" + ) # pytype: disable=attribute-error ], ) } ), expected_data=[ - {'index$ragged_values': [0, 2], 'index$row_lengths_1': [1, 0, 1]}, - {'index$ragged_values': [3], 'index$row_lengths_1': [0, 1]}, - {'index$ragged_values': [0, 1], 'index$row_lengths_1': [2, 0, 0]}, + {"index$ragged_values": [0, 2], "index$row_lengths_1": [1, 0, 1]}, + {"index$ragged_values": [3], 
"index$row_lengths_1": [0, 1]}, + {"index$ragged_values": [0, 1], "index$row_lengths_1": [2, 0, 0]}, { - 'index$ragged_values': [0, 1, 2], - 'index$row_lengths_1': [0, 2, 1], + "index$ragged_values": [0, 1, 2], + "index$row_lengths_1": [0, 2, 1], }, ], expected_vocab_contents={ - b'hello': 3, - b'goodbye': 2, - b' ': 2, - b'world': 1, + b"hello": 3, + b"goodbye": 2, + b" ": 2, + b"world": 1, }, ), ] class VocabularyIntegrationTest(tft_unit.TransformTestCase): + def setUp(self): + tf.compat.v1.logging.info("Starting test case: %s", self._testMethodName) + super().setUp() + + def _VocabFormat(self): + return "text" + + _WITH_LABEL_PARAMS = tft_unit.cross_named_parameters( + [ + dict( + testcase_name="_string", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + expected_vocab_file_contents=[ + (b"goodbye", 1.9753224), + (b"aaaaa", 1.6600707), + (b"hello", 1.2450531), + ], + ), + dict( + testcase_name="_int64", + x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + (b"1", 1.9753224), + (b"2", 1.6600707), + (b"3", 1.2450531), + ], + ), + ], + [ + dict( + testcase_name="with_label", + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + min_diff_from_avg=0.0, + store_frequency=True, + ), + ], + ) - def setUp(self): - tf.compat.v1.logging.info('Starting test case: %s', self._testMethodName) - super().setUp() - - def _VocabFormat(self): - return 'text' - - _WITH_LABEL_PARAMS = tft_unit.cross_named_parameters([ - dict( - testcase_name='_string', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - expected_vocab_file_contents=[(b'goodbye', 1.9753224), - (b'aaaaa', 1.6600707), - (b'hello', 1.2450531)]), - dict( - testcase_name='_int64', - x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[(b'1', 1.9753224), (b'2', 1.6600707), - (b'3', 1.2450531)]), - ], [ - dict( - testcase_name='with_label', - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - min_diff_from_avg=0.0, - store_frequency=True), - ]) - - @tft_unit.named_parameters(*([ - dict( - testcase_name='_unadjusted_mi_binary_label', - x_data=[ - b'informative', b'informative', b'informative', b'uninformative', - b'uninformative', b'uninformative', b'uninformative', - b'uninformative_rare', b'uninformative_rare' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 0, 1, 1, 0, 0, 1], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[ - (b'informative', 1.7548264), - (b'uninformative', 0.33985), - (b'uninformative_rare', 0.169925), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=False, - store_frequency=True), - dict( - testcase_name='_unadjusted_mi_multi_class_label', - x_data=[ - b'good_predictor_of_0', b'good_predictor_of_0', - b'good_predictor_of_0', b'good_predictor_of_1', - b'good_predictor_of_2', b'good_predictor_of_2', - b'good_predictor_of_2', b'good_predictor_of_1', - b'good_predictor_of_1', b'weak_predictor_of_1', - b'good_predictor_of_0', b'good_predictor_of_1', - 
b'good_predictor_of_1', b'good_predictor_of_1', - b'weak_predictor_of_1' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[0, 0, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[ - (b'good_predictor_of_2', 6.9656613), - (b'good_predictor_of_1', 6.5969828), - (b'good_predictor_of_0', 6.339692), - (b'weak_predictor_of_1', 0.684463), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=False, - store_frequency=True), - dict( - testcase_name='_unadjusted_mi_binary_label_with_weights', - x_data=[ - b'informative_1', b'informative_1', b'informative_0', - b'informative_0', b'uninformative', b'uninformative', - b'informative_by_weight', b'informative_by_weight' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 0, 0, 0, 1, 0, 1], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - # uninformative and informative_by_weight have the same co-occurrence - # relationship with the label but will have different importance - # values due to the weighting. - expected_vocab_file_contents=[ - (b'informative_0', 3.1698803), - (b'informative_1', 1.1698843), - (b'informative_by_weight', 0.6096405), - (b'uninformative', 0.169925), - ], - weight_data=[1, 1, 1, 1, 1, 1, 1, 5], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - min_diff_from_avg=0.0, - use_adjusted_mutual_info=False, - store_frequency=True), - dict( - testcase_name='_unadjusted_mi_binary_label_min_diff_from_avg', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - # All features are weak predictors, so all are adjusted to zero. 
- expected_vocab_file_contents=[ - (b'hello', 0.0), - (b'goodbye', 0.0), - (b'aaaaa', 0.0), - ], - use_adjusted_mutual_info=False, - min_diff_from_avg=2.0, - store_frequency=True), - dict( - testcase_name='_adjusted_mi_binary_label', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[ - (b'goodbye', 1.4070794), - (b'aaaaa', 0.9987448), - (b'hello', 0.5017178), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=True, - store_frequency=True), - dict( - testcase_name='_adjusted_mi_binary_label_int64_feature', - x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[ - (b'1', 1.4070794), - (b'2', 0.9987448), - (b'3', 0.5017178), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=True, - store_frequency=True), - dict( - testcase_name='_adjusted_mi_multi_class_label', - x_data=[ - b'good_predictor_of_0', b'good_predictor_of_0', - b'good_predictor_of_0', b'good_predictor_of_1', - b'good_predictor_of_2', b'good_predictor_of_2', - b'good_predictor_of_2', b'good_predictor_of_1', - b'good_predictor_of_1', b'weak_predictor_of_1', - b'good_predictor_of_0', b'good_predictor_of_1', - b'good_predictor_of_1', b'good_predictor_of_1', - b'weak_predictor_of_1' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[0, 0, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[ - (b'good_predictor_of_1', 5.4800901), - (b'good_predictor_of_2', 5.3861019), - (b'good_predictor_of_0', 4.9054722), - (b'weak_predictor_of_1', -0.9748023), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=True, - store_frequency=True), - # TODO(b/128831096): Determine correct interaction between AMI and weights - dict( - testcase_name='_adjusted_mi_binary_label_with_weights', - x_data=[ - b'informative_1', b'informative_1', b'informative_0', - b'informative_0', b'uninformative', b'uninformative', - b'informative_by_weight', b'informative_by_weight' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 0, 0, 0, 1, 0, 1], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - weight_data=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - # uninformative and informative_by_weight have the same co-occurrence - # relationship with the label but will have different importance - # values due to the weighting. 
- expected_vocab_file_contents=[ - (b'informative_0', 2.3029856), - (b'informative_1', 0.3029896), - (b'informative_by_weight', 0.1713041), - (b'uninformative', -0.6969697), - ], - min_diff_from_avg=0.0, - use_adjusted_mutual_info=True, - store_frequency=True), - dict( - testcase_name='_adjusted_mi_min_diff_from_avg', - x_data=[ - b'good_predictor_of_0', b'good_predictor_of_0', - b'good_predictor_of_0', b'good_predictor_of_1', - b'good_predictor_of_0', b'good_predictor_of_1', - b'good_predictor_of_1', b'good_predictor_of_1', - b'good_predictor_of_1', b'good_predictor_of_0', - b'good_predictor_of_1', b'good_predictor_of_1', - b'good_predictor_of_1', b'weak_predictor_of_1', - b'weak_predictor_of_1' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - # With min_diff_from_avg, the small AMI value is regularized to 0 - expected_vocab_file_contents=[ - (b'good_predictor_of_0', 1.8322128), - (b'good_predictor_of_1', 1.7554416), - (b'weak_predictor_of_1', 0), - ], - use_adjusted_mutual_info=True, - min_diff_from_avg=1.0, - store_frequency=True), - dict( - testcase_name='_labels_weight_and_frequency', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - weight_data=[ - 0.3, 0.4, 0.3, 1.2, 0.6, 0.7, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0 - ], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - expected_vocab_file_contents=[ - (b'aaaaa', 1.5637185), - (b'goodbye', 0.8699492), - (b'hello', 0.6014302), - ], - min_diff_from_avg=0.0, - store_frequency=True), - # fingerprints by which each of the tokens will be sorted if fingerprint - # shuffling is used. 
- # 'ho ho': '1b3dd735ddff70d90f3b7ba5ebf65df521d6ca4d' - # 'world': '7c211433f02071597741e6ff5a8ea34789abbf43' - # 'hello': 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d' - # 'hi': 'c22b5f9178342609428d6f51b2c5af4c0bde6a42' - # '1': '356a192b7913b04c54574d18c28d46e6395428ab' - # '2': 'da4b9237bacccdf19c0760cab7aec4a8359010b0' - # '3': '77de68daecd823babbb58edb1c8e14d7106e83bb' - dict( - testcase_name='_string_feature_with_frequency_and_shuffle', - x_data=[b'world', b'hello', b'hello'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - expected_vocab_file_contents=[(b'world', 1), (b'hello', 2)], - fingerprint_shuffle=True, - store_frequency=True), - dict( - testcase_name='_string_feature_with_frequency_and_no_shuffle', - x_data=[b'hi', b'ho ho', b'ho ho'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - expected_vocab_file_contents=[(b'ho ho', 2), (b'hi', 1)], - store_frequency=True), - dict( - testcase_name='_string_feature_with_no_frequency_and_shuffle', - x_data=[b'world', b'hello', b'hello'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - expected_vocab_file_contents=[b'world', b'hello'], - fingerprint_shuffle=True), - dict( - testcase_name='_string_feature_with_no_frequency_and_no_shuffle', - x_data=[b'world', b'hello', b'hello'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - expected_vocab_file_contents=[b'hello', b'world']), - dict( - testcase_name='_int_feature_with_frequency_and_shuffle', - x_data=[1, 2, 2, 3], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[(b'1', 1), (b'3', 1), (b'2', 2)], - fingerprint_shuffle=True, - store_frequency=True), - dict( - testcase_name='_int_feature_with_frequency_and_no_shuffle', - x_data=[2, 1, 1, 1], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[(b'1', 3), (b'2', 1)], - store_frequency=True), - dict( - testcase_name='_int_feature_with_no_frequency_and_shuffle', - x_data=[1, 2, 2, 3], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[b'1', b'3', b'2'], - fingerprint_shuffle=True), - dict( - testcase_name='_int_feature_with_no_frequency_and_no_shuffle', - x_data=[1, 2, 2, 3], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[b'2', b'3', b'1']), - dict( - testcase_name='_int_feature_with_top_k', - x_data=[111, 2, 2, 3], - top_k=2, - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - expected_vocab_file_contents=[b'2', b'3']), - ] + _WITH_LABEL_PARAMS)) - def testVocabulary(self, - x_data, - x_feature_spec, - label_data=None, - label_feature_spec=None, - weight_data=None, - weight_feature_spec=None, - expected_vocab_file_contents=None, - **kwargs): - """Test tft.Vocabulary with various inputs.""" - - input_data = [{'x': x} for x in x_data] - input_feature_spec = {'x': x_feature_spec} - - if label_data is not None: - for idx, label in enumerate(label_data): - input_data[idx]['label'] = label - input_feature_spec['label'] = label_feature_spec - - if weight_data is not None: - for idx, weight in enumerate(weight_data): - input_data[idx]['weights'] = weight - input_feature_spec['weights'] = weight_feature_spec - - input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) - - def preprocessing_fn(inputs): - x = inputs['x'] - labels = inputs.get('label') - weights = inputs.get('weights') - # Note even though the return value is not used, calling tft.vocabulary - # will generate the vocabulary as a side effect, and since we have named - # this vocabulary it 
can be looked up using public APIs. - tft.vocabulary( - x, - labels=labels, - weights=weights, - vocab_filename='my_vocab', - file_format=self._VocabFormat(), - **kwargs) - return inputs - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - input_data, # expected output data is same as input data - input_metadata, # expected output metadata is same as input metadata - expected_vocab_file_contents={'my_vocab': expected_vocab_file_contents}) - - @tft_unit.named_parameters( - *tft_unit.cross_named_parameters( - [ - dict( - testcase_name='_string', - input_data=[ - {'x': b'hello'}, - {'x': b'hello'}, - {'x': b'hello'}, - {'x': b'goodbye'}, - {'x': b'aaaaa'}, - {'x': b'aaaaa'}, - {'x': b'goodbye'}, - {'x': b'goodbye'}, - {'x': b'aaaaa'}, - {'x': b'aaaaa'}, - {'x': b'goodbye'}, - {'x': b'goodbye'}, - ], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.FixedLenFeature([], tf.string) - }, - top_k=2, - make_expected_vocab_fn=( - lambda _: [(b'goodbye', 5), (b'aaaaa', 4)] - ), - ), - dict( - testcase_name='_int', - input_data=[{'x': 1}, {'x': 2}, {'x': 2}, {'x': 3}, {'x': 1}], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.FixedLenFeature([], tf.int64) - }, - top_k=2, - make_expected_vocab_fn=lambda _: [(b'2', 2), (b'1', 2)], - ), - dict( - testcase_name='_weights', - input_data=[ - {'x': b'hello', 'weights': 1.4}, - {'x': b'hello', 'weights': 0.5}, - {'x': b'hello', 'weights': 1.12}, - {'x': b'goodbye', 'weights': 0.123}, - {'x': b'aaaaa', 'weights': 0.3}, - {'x': b'aaaaa', 'weights': 1.123}, - {'x': b'goodbye', 'weights': 0.1}, - {'x': b'goodbye', 'weights': 0.00001}, - ], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.FixedLenFeature([], tf.string), - 'weights': tf.io.FixedLenFeature([], tf.float32), - }, - top_k=2, - make_expected_vocab_fn=( - lambda _: [(b'hello', 3.02), (b'aaaaa', 1.423)] - ), - ), - dict( - testcase_name='_large_top_k', - input_data=[ - {'x': b'hello'}, - {'x': b'hello'}, - {'x': b'hello'}, - {'x': b' '}, - {'x': b'aaaaa'}, - {'x': b'aaaaa'}, - {'x': b'goodbye'}, - {'x': b'goodbye'}, - {'x': b' '}, - {'x': b''}, - {'x': b'goodbye'}, - {'x': b'goodbye'}, - ], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.FixedLenFeature([], tf.string) - }, - top_k=100, - make_expected_vocab_fn=lambda file_format: ( # pylint: disable=g-long-lambda - [ # pylint: disable=g-long-ternary - (b'goodbye', 4), - (b'hello', 3), - (b'aaaaa', 2), - (b' ', 2), - ] - if file_format == 'text' - else [ - (b'goodbye', 4), - (b'hello', 3), - (b'aaaaa', 2), - (b' ', 2), - (b'', 1), - ] - ), - ), - dict( - testcase_name='_ragged', - input_data=[ - { - 'x$ragged_values': ['hello', ' '], - 'x$row_lengths_1': [1, 0, 1], - }, - {'x$ragged_values': ['hello'], 'x$row_lengths_1': [0, 1]}, - { - 'x$ragged_values': ['hello', 'goodbye'], - 'x$row_lengths_1': [2, 0, 0], - }, - { - 'x$ragged_values': ['hello', 'hello', ' ', ' '], - 'x$row_lengths_1': [0, 2, 2], - }, - ], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.RaggedFeature( - tf.string, - value_key='x$ragged_values', - partitions=[ - tf.io.RaggedFeature.RowLengths('x$row_lengths_1') # pytype: disable=attribute-error - ], - ) - }, - top_k=2, - make_expected_vocab_fn=lambda _: [(b'hello', 5), (b' ', 3)], - ), - dict( - testcase_name='_sparse', - input_data=[ - { - 'x$sparse_indices_0': [0, 1], - 'x$sparse_indices_1': [2, 3], - 'x$sparse_values': [-4, 4], - }, - { - 'x$sparse_indices_0': [0, 
1], - 'x$sparse_indices_1': [4, 1], - 'x$sparse_values': [2, 2], - }, - { - 'x$sparse_indices_0': [0, 1], - 'x$sparse_indices_1': [0, 3], - 'x$sparse_values': [2, 4], - }, - ], - make_feature_spec=lambda: { # pylint: disable=g-long-lambda - 'x': tf.io.SparseFeature( - ['x$sparse_indices_0', 'x$sparse_indices_1'], - 'x$sparse_values', - tf.int64, - [5, 5], - ) - }, - top_k=2, - make_expected_vocab_fn=lambda _: [(b'2', 3), (b'4', 2)], - ), - dict( - testcase_name='_newline_chars', - input_data=[ - {'x': b'aaaaa\n'}, - {'x': b'\n\n'}, - {'x': b''}, - {'x': b' '}, - {'x': b' '}, - {'x': b'aaaaa\n'}, - {'x': b'aaaaa\n'}, - {'x': b'aaaaa'}, - {'x': b'goo\rdbye'}, - {'x': b' '}, - {'x': b' '}, - {'x': b'aaaaa\n'}, - ], - make_feature_spec=( - lambda: {'x': tf.io.FixedLenFeature([], tf.string)} - ), - top_k=6, - make_expected_vocab_fn=( - lambda file_format: [(b' ', 4), (b'aaaaa', 1)] # pylint: disable=g-long-lambda,g-long-ternary - if file_format == 'text' - else [ - (b'aaaaa\n', 4), - (b' ', 4), - (b'goo\rdbye', 1), - (b'aaaaa', 1), - (b'\n\n', 1), - (b'', 1), - ] - ), - ), - ], - [ - dict(testcase_name='no_frequency', store_frequency=False), - dict(testcase_name='with_frequency', store_frequency=True), - ], - [ - dict(testcase_name='no_reserved_tokens', reserved_tokens=None), - dict(testcase_name='with_reserved_tokens', reserved_tokens=['A']), - ], - ) - ) - def testApproximateVocabulary( - self, - input_data, - make_feature_spec, - top_k, - make_expected_vocab_fn, - store_frequency, - reserved_tokens, - ): - input_metadata = tft.DatasetMetadata.from_feature_spec( - tft_unit.make_feature_spec_wrapper(make_feature_spec)) - - def preprocessing_fn(inputs): - x = inputs['x'] - weights = inputs.get('weights') - # Note even though the return value is not used, calling - # tft.experimental.approximate_vocabulary will generate the vocabulary as - # a side effect, and since we have named this vocabulary it can be looked - # up using public APIs. 
- tft.experimental.approximate_vocabulary( - x, - top_k, - store_frequency=store_frequency, - weights=weights, - vocab_filename='my_approximate_vocab', - reserved_tokens=reserved_tokens, - file_format=self._VocabFormat(), - ) - return inputs - - expected_vocab_file_contents = make_expected_vocab_fn(self._VocabFormat()) - if reserved_tokens is not None: - expected_vocab_file_contents = [ - (t, -1) for t in reserved_tokens - ] + expected_vocab_file_contents - if not store_frequency: - expected_vocab_file_contents = [ - token for token, _ in expected_vocab_file_contents - ] - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_vocab_file_contents={ - 'my_approximate_vocab': expected_vocab_file_contents - }) - - @tft_unit.named_parameters( - *tft_unit.cross_named_parameters( - [ - dict(testcase_name='no_frequency', store_frequency=False), - dict(testcase_name='with_frequency', store_frequency=True), - ], - [ - dict(testcase_name='no_reserved_tokens', reserved_tokens=None), - dict(testcase_name='with_reserved_tokens', reserved_tokens=['A']), - ], - ) - ) - def testComputeAndApplyApproximateVocabulary( - self, store_frequency, reserved_tokens - ): - input_data = [{'x': 'a'}] * 2 + [{'x': 'b'}] * 3 - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - - def preprocessing_fn(inputs): - index = tft.experimental.compute_and_apply_approximate_vocabulary( - inputs['x'], - top_k=2, - file_format=self._VocabFormat(), - store_frequency=store_frequency, - reserved_tokens=reserved_tokens, - num_oov_buckets=1, - ) - return {'index': index} - - offset = len(reserved_tokens) if reserved_tokens else 0 - expected_data = [{'index': offset + idx} for idx in [1] * 2 + [0] * 3 + [2]] - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - test_data=input_data + [{'x': 'c'}]) # pyformat: disable - - def testEmptyComputeAndApplyApproximateVocabulary(self): - input_data = [{'x': ''}] * 3 - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - - def preprocessing_fn(inputs): - index = tft.experimental.compute_and_apply_approximate_vocabulary( - inputs['x'], - top_k=2, - file_format=self._VocabFormat(), - num_oov_buckets=1) - return {'index': index} - - # We only filter empty tokens for `text` format. 
- expected_data = [{'index': 1 if self._VocabFormat() == 'text' else 0}] * 3 - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data) - - def testJointVocabularyForMultipleFeatures(self): - input_data = [{ - 'a': 'hello', - 'b': 'world', - 'c': 'aaaaa' - }, { - 'a': 'good', - 'b': '', - 'c': 'hello' - }, { - 'a': 'goodbye', - 'b': 'hello', - 'c': '\n' - }, { - 'a': ' ', - 'b': 'aaaaa', - 'c': 'bbbbb' - }] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.string), - 'b': tf.io.FixedLenFeature([], tf.string), - 'c': tf.io.FixedLenFeature([], tf.string) - }) - vocab_filename = 'test_compute_and_apply_vocabulary' - - def preprocessing_fn(inputs): - deferred_vocab_and_filename = tft.vocabulary( - tf.concat([inputs['a'], inputs['b'], inputs['c']], 0), - vocab_filename=vocab_filename, - file_format=self._VocabFormat()) - return { - 'index_a': - tft.apply_vocabulary( - inputs['a'], - deferred_vocab_and_filename, - file_format=self._VocabFormat()), - 'index_b': - tft.apply_vocabulary( - inputs['b'], - deferred_vocab_and_filename, - file_format=self._VocabFormat()) - } - - expected_vocab = [ - b'hello', b'aaaaa', b'world', b'goodbye', b'good', b'bbbbb', b' ', - b'\n', b'' - ] - empty_index = len(expected_vocab) - 1 - if self._VocabFormat() == 'text': - expected_vocab = expected_vocab[:-2] - empty_index = -1 - max_index = len(expected_vocab) - 1 - expected_data = [ - # For tied frequencies, larger (lexicographic) items come first. - { - 'index_a': 0, # hello - 'index_b': 2 # world - }, - { - 'index_a': 4, # good - 'index_b': empty_index # '' - }, - { - 'index_a': 3, # goodbye - 'index_b': 0 # hello - }, - { - 'index_a': 6, # ' ' - 'index_b': 1 # aaaaa - }, - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'index_a': tf.io.FixedLenFeature([], tf.int64), - 'index_b': tf.io.FixedLenFeature([], tf.int64), - }, { - 'index_a': - schema_pb2.IntDomain( - min=-1, max=max_index, is_categorical=True), - 'index_b': - schema_pb2.IntDomain( - min=-1, max=max_index, is_categorical=True), - }) - self.assertAnalyzeAndTransformResults( + @tft_unit.named_parameters( + *( + [ + dict( + testcase_name="_unadjusted_mi_binary_label", + x_data=[ + b"informative", + b"informative", + b"informative", + b"uninformative", + b"uninformative", + b"uninformative", + b"uninformative", + b"uninformative_rare", + b"uninformative_rare", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 0, 1, 1, 0, 0, 1], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + (b"informative", 1.7548264), + (b"uninformative", 0.33985), + (b"uninformative_rare", 0.169925), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=False, + store_frequency=True, + ), + dict( + testcase_name="_unadjusted_mi_multi_class_label", + x_data=[ + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_2", + b"good_predictor_of_2", + b"good_predictor_of_2", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"weak_predictor_of_1", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"weak_predictor_of_1", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[0, 0, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + (b"good_predictor_of_2", 6.9656613), + 
(b"good_predictor_of_1", 6.5969828), + (b"good_predictor_of_0", 6.339692), + (b"weak_predictor_of_1", 0.684463), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=False, + store_frequency=True, + ), + dict( + testcase_name="_unadjusted_mi_binary_label_with_weights", + x_data=[ + b"informative_1", + b"informative_1", + b"informative_0", + b"informative_0", + b"uninformative", + b"uninformative", + b"informative_by_weight", + b"informative_by_weight", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 0, 0, 0, 1, 0, 1], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + # uninformative and informative_by_weight have the same co-occurrence + # relationship with the label but will have different importance + # values due to the weighting. + expected_vocab_file_contents=[ + (b"informative_0", 3.1698803), + (b"informative_1", 1.1698843), + (b"informative_by_weight", 0.6096405), + (b"uninformative", 0.169925), + ], + weight_data=[1, 1, 1, 1, 1, 1, 1, 5], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + min_diff_from_avg=0.0, + use_adjusted_mutual_info=False, + store_frequency=True, + ), + dict( + testcase_name="_unadjusted_mi_binary_label_min_diff_from_avg", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + # All features are weak predictors, so all are adjusted to zero. + expected_vocab_file_contents=[ + (b"hello", 0.0), + (b"goodbye", 0.0), + (b"aaaaa", 0.0), + ], + use_adjusted_mutual_info=False, + min_diff_from_avg=2.0, + store_frequency=True, + ), + dict( + testcase_name="_adjusted_mi_binary_label", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + (b"goodbye", 1.4070794), + (b"aaaaa", 0.9987448), + (b"hello", 0.5017178), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=True, + store_frequency=True, + ), + dict( + testcase_name="_adjusted_mi_binary_label_int64_feature", + x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + (b"1", 1.4070794), + (b"2", 0.9987448), + (b"3", 0.5017178), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=True, + store_frequency=True, + ), + dict( + testcase_name="_adjusted_mi_multi_class_label", + x_data=[ + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_2", + b"good_predictor_of_2", + b"good_predictor_of_2", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"weak_predictor_of_1", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"weak_predictor_of_1", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[0, 0, 0, 1, 2, 2, 2, 1, 1, 1, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[ + 
(b"good_predictor_of_1", 5.4800901), + (b"good_predictor_of_2", 5.3861019), + (b"good_predictor_of_0", 4.9054722), + (b"weak_predictor_of_1", -0.9748023), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=True, + store_frequency=True, + ), + # TODO(b/128831096): Determine correct interaction between AMI and weights + dict( + testcase_name="_adjusted_mi_binary_label_with_weights", + x_data=[ + b"informative_1", + b"informative_1", + b"informative_0", + b"informative_0", + b"uninformative", + b"uninformative", + b"informative_by_weight", + b"informative_by_weight", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 0, 0, 0, 1, 0, 1], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + weight_data=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + # uninformative and informative_by_weight have the same co-occurrence + # relationship with the label but will have different importance + # values due to the weighting. + expected_vocab_file_contents=[ + (b"informative_0", 2.3029856), + (b"informative_1", 0.3029896), + (b"informative_by_weight", 0.1713041), + (b"uninformative", -0.6969697), + ], + min_diff_from_avg=0.0, + use_adjusted_mutual_info=True, + store_frequency=True, + ), + dict( + testcase_name="_adjusted_mi_min_diff_from_avg", + x_data=[ + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_0", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"good_predictor_of_1", + b"weak_predictor_of_1", + b"weak_predictor_of_1", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + # With min_diff_from_avg, the small AMI value is regularized to 0 + expected_vocab_file_contents=[ + (b"good_predictor_of_0", 1.8322128), + (b"good_predictor_of_1", 1.7554416), + (b"weak_predictor_of_1", 0), + ], + use_adjusted_mutual_info=True, + min_diff_from_avg=1.0, + store_frequency=True, + ), + dict( + testcase_name="_labels_weight_and_frequency", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + weight_data=[ + 0.3, + 0.4, + 0.3, + 1.2, + 0.6, + 0.7, + 1.0, + 1.0, + 0.6, + 0.7, + 1.0, + 1.0, + ], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + expected_vocab_file_contents=[ + (b"aaaaa", 1.5637185), + (b"goodbye", 0.8699492), + (b"hello", 0.6014302), + ], + min_diff_from_avg=0.0, + store_frequency=True, + ), + # fingerprints by which each of the tokens will be sorted if fingerprint + # shuffling is used. 
+ # 'ho ho': '1b3dd735ddff70d90f3b7ba5ebf65df521d6ca4d' + # 'world': '7c211433f02071597741e6ff5a8ea34789abbf43' + # 'hello': 'aaf4c61ddcc5e8a2dabede0f3b482cd9aea9434d' + # 'hi': 'c22b5f9178342609428d6f51b2c5af4c0bde6a42' + # '1': '356a192b7913b04c54574d18c28d46e6395428ab' + # '2': 'da4b9237bacccdf19c0760cab7aec4a8359010b0' + # '3': '77de68daecd823babbb58edb1c8e14d7106e83bb' + dict( + testcase_name="_string_feature_with_frequency_and_shuffle", + x_data=[b"world", b"hello", b"hello"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + expected_vocab_file_contents=[(b"world", 1), (b"hello", 2)], + fingerprint_shuffle=True, + store_frequency=True, + ), + dict( + testcase_name="_string_feature_with_frequency_and_no_shuffle", + x_data=[b"hi", b"ho ho", b"ho ho"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + expected_vocab_file_contents=[(b"ho ho", 2), (b"hi", 1)], + store_frequency=True, + ), + dict( + testcase_name="_string_feature_with_no_frequency_and_shuffle", + x_data=[b"world", b"hello", b"hello"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + expected_vocab_file_contents=[b"world", b"hello"], + fingerprint_shuffle=True, + ), + dict( + testcase_name="_string_feature_with_no_frequency_and_no_shuffle", + x_data=[b"world", b"hello", b"hello"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + expected_vocab_file_contents=[b"hello", b"world"], + ), + dict( + testcase_name="_int_feature_with_frequency_and_shuffle", + x_data=[1, 2, 2, 3], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[(b"1", 1), (b"3", 1), (b"2", 2)], + fingerprint_shuffle=True, + store_frequency=True, + ), + dict( + testcase_name="_int_feature_with_frequency_and_no_shuffle", + x_data=[2, 1, 1, 1], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[(b"1", 3), (b"2", 1)], + store_frequency=True, + ), + dict( + testcase_name="_int_feature_with_no_frequency_and_shuffle", + x_data=[1, 2, 2, 3], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[b"1", b"3", b"2"], + fingerprint_shuffle=True, + ), + dict( + testcase_name="_int_feature_with_no_frequency_and_no_shuffle", + x_data=[1, 2, 2, 3], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[b"2", b"3", b"1"], + ), + dict( + testcase_name="_int_feature_with_top_k", + x_data=[111, 2, 2, 3], + top_k=2, + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + expected_vocab_file_contents=[b"2", b"3"], + ), + ] + + _WITH_LABEL_PARAMS + ) + ) + def testVocabulary( + self, + x_data, + x_feature_spec, + label_data=None, + label_feature_spec=None, + weight_data=None, + weight_feature_spec=None, + expected_vocab_file_contents=None, + **kwargs, + ): + """Test tft.Vocabulary with various inputs.""" + input_data = [{"x": x} for x in x_data] + input_feature_spec = {"x": x_feature_spec} + + if label_data is not None: + for idx, label in enumerate(label_data): + input_data[idx]["label"] = label + input_feature_spec["label"] = label_feature_spec + + if weight_data is not None: + for idx, weight in enumerate(weight_data): + input_data[idx]["weights"] = weight + input_feature_spec["weights"] = weight_feature_spec + + input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) + + def preprocessing_fn(inputs): + x = inputs["x"] + labels = inputs.get("label") + weights = inputs.get("weights") + # Note even though the return value is not used, calling tft.vocabulary + # will generate the vocabulary as a side 
effect, and since we have named + # this vocabulary it can be looked up using public APIs. + tft.vocabulary( + x, + labels=labels, + weights=weights, + vocab_filename="my_vocab", + file_format=self._VocabFormat(), + **kwargs, + ) + return inputs + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + input_data, # expected output data is same as input data + input_metadata, # expected output metadata is same as input metadata + expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict( + testcase_name="_string", + input_data=[ + {"x": b"hello"}, + {"x": b"hello"}, + {"x": b"hello"}, + {"x": b"goodbye"}, + {"x": b"aaaaa"}, + {"x": b"aaaaa"}, + {"x": b"goodbye"}, + {"x": b"goodbye"}, + {"x": b"aaaaa"}, + {"x": b"aaaaa"}, + {"x": b"goodbye"}, + {"x": b"goodbye"}, + ], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.FixedLenFeature([], tf.string) + }, + top_k=2, + make_expected_vocab_fn=(lambda _: [(b"goodbye", 5), (b"aaaaa", 4)]), + ), + dict( + testcase_name="_int", + input_data=[{"x": 1}, {"x": 2}, {"x": 2}, {"x": 3}, {"x": 1}], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.FixedLenFeature([], tf.int64) + }, + top_k=2, + make_expected_vocab_fn=lambda _: [(b"2", 2), (b"1", 2)], + ), + dict( + testcase_name="_weights", + input_data=[ + {"x": b"hello", "weights": 1.4}, + {"x": b"hello", "weights": 0.5}, + {"x": b"hello", "weights": 1.12}, + {"x": b"goodbye", "weights": 0.123}, + {"x": b"aaaaa", "weights": 0.3}, + {"x": b"aaaaa", "weights": 1.123}, + {"x": b"goodbye", "weights": 0.1}, + {"x": b"goodbye", "weights": 0.00001}, + ], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.FixedLenFeature([], tf.string), + "weights": tf.io.FixedLenFeature([], tf.float32), + }, + top_k=2, + make_expected_vocab_fn=( + lambda _: [(b"hello", 3.02), (b"aaaaa", 1.423)] + ), + ), + dict( + testcase_name="_large_top_k", + input_data=[ + {"x": b"hello"}, + {"x": b"hello"}, + {"x": b"hello"}, + {"x": b" "}, + {"x": b"aaaaa"}, + {"x": b"aaaaa"}, + {"x": b"goodbye"}, + {"x": b"goodbye"}, + {"x": b" "}, + {"x": b""}, + {"x": b"goodbye"}, + {"x": b"goodbye"}, + ], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.FixedLenFeature([], tf.string) + }, + top_k=100, + make_expected_vocab_fn=lambda file_format: ( # pylint: disable=g-long-lambda + [ # pylint: disable=g-long-ternary + (b"goodbye", 4), + (b"hello", 3), + (b"aaaaa", 2), + (b" ", 2), + ] + if file_format == "text" + else [ + (b"goodbye", 4), + (b"hello", 3), + (b"aaaaa", 2), + (b" ", 2), + (b"", 1), + ] + ), + ), + dict( + testcase_name="_ragged", + input_data=[ + { + "x$ragged_values": ["hello", " "], + "x$row_lengths_1": [1, 0, 1], + }, + {"x$ragged_values": ["hello"], "x$row_lengths_1": [0, 1]}, + { + "x$ragged_values": ["hello", "goodbye"], + "x$row_lengths_1": [2, 0, 0], + }, + { + "x$ragged_values": ["hello", "hello", " ", " "], + "x$row_lengths_1": [0, 2, 2], + }, + ], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.RaggedFeature( + tf.string, + value_key="x$ragged_values", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "x$row_lengths_1" + ) # pytype: disable=attribute-error + ], + ) + }, + top_k=2, + make_expected_vocab_fn=lambda _: [(b"hello", 5), (b" ", 3)], + ), + dict( + testcase_name="_sparse", + input_data=[ + { + "x$sparse_indices_0": [0, 1], + "x$sparse_indices_1": [2, 3], + 
"x$sparse_values": [-4, 4], + }, + { + "x$sparse_indices_0": [0, 1], + "x$sparse_indices_1": [4, 1], + "x$sparse_values": [2, 2], + }, + { + "x$sparse_indices_0": [0, 1], + "x$sparse_indices_1": [0, 3], + "x$sparse_values": [2, 4], + }, + ], + make_feature_spec=lambda: { # pylint: disable=g-long-lambda + "x": tf.io.SparseFeature( + ["x$sparse_indices_0", "x$sparse_indices_1"], + "x$sparse_values", + tf.int64, + [5, 5], + ) + }, + top_k=2, + make_expected_vocab_fn=lambda _: [(b"2", 3), (b"4", 2)], + ), + dict( + testcase_name="_newline_chars", + input_data=[ + {"x": b"aaaaa\n"}, + {"x": b"\n\n"}, + {"x": b""}, + {"x": b" "}, + {"x": b" "}, + {"x": b"aaaaa\n"}, + {"x": b"aaaaa\n"}, + {"x": b"aaaaa"}, + {"x": b"goo\rdbye"}, + {"x": b" "}, + {"x": b" "}, + {"x": b"aaaaa\n"}, + ], + make_feature_spec=( + lambda: {"x": tf.io.FixedLenFeature([], tf.string)} + ), + top_k=6, + make_expected_vocab_fn=( + lambda file_format: [(b" ", 4), (b"aaaaa", 1)] # pylint: disable=g-long-lambda,g-long-ternary + if file_format == "text" + else [ + (b"aaaaa\n", 4), + (b" ", 4), + (b"goo\rdbye", 1), + (b"aaaaa", 1), + (b"\n\n", 1), + (b"", 1), + ] + ), + ), + ], + [ + dict(testcase_name="no_frequency", store_frequency=False), + dict(testcase_name="with_frequency", store_frequency=True), + ], + [ + dict(testcase_name="no_reserved_tokens", reserved_tokens=None), + dict(testcase_name="with_reserved_tokens", reserved_tokens=["A"]), + ], + ) + ) + def testApproximateVocabulary( + self, input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents={vocab_filename: expected_vocab}) - - _EMPTY_VOCABULARY_PARAMS = tft_unit.cross_named_parameters([ - dict( - testcase_name='_string', - x_data=['a', 'b'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string)), - dict( - testcase_name='_int64', - x_data=[1, 2], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64)), - ], [ - dict( - testcase_name='empty_vocabulary', - index_data=[-1, -1], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=0, is_categorical=True), - frequency_threshold=5), - ]) - - @tft_unit.named_parameters(*([ - dict( - testcase_name='_string_feature_with_label_top_2', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[-1, -1, -1, 0, 1, 1, 0, 0, 0, 1, 1, 0], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=1, is_categorical=True), - top_k=2), - dict( - testcase_name='_string_feature_with_label_top_1', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[-1, -1, -1, 0, -1, -1, 0, 0, 0, -1, -1, 0], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=0, is_categorical=True), - top_k=1), - dict( - testcase_name='_int_feature_with_label_top_2', - x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - 
label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[-1, -1, -1, 0, 1, 1, 0, 0, 0, 1, 1, 0], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=1, is_categorical=True), - top_k=2), - dict( - testcase_name='_varlen_feature', - x_data=[[b'world', b'hello', b'hello'], [b'hello', b'world', b'foo'], - [], [b'hello']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[1, 0, 0], [0, 1, -99], [], [0]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_vector_feature', - x_data=[[b'world', b'hello', b'hello'], [b'hello', b'world', b'moo'], - [b'hello', b'hello', b'foo'], [b'world', b'foo', b'moo']], - x_feature_spec=tf.io.FixedLenFeature([3], tf.string), - index_data=[[1, 0, 0], [0, 1, -99], [0, 0, -99], [1, -99, -99]], - index_feature_spec=tf.io.FixedLenFeature([3], tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_varlen_feature_with_labels', - x_data=[[b'hello', b'world', b'bye', b'moo'], - [b'world', b'moo', b'foo'], [b'hello', b'foo', b'moo'], - [b'moo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - label_data=[1, 0, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[[0, -99, 1, -99], [-99, -99, -99], [0, -99, -99], [-99]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_vector_feature_with_labels', - x_data=[[b'world', b'hello', b'hi'], [b'hello', b'world', b'moo'], - [b'hello', b'bye', b'foo'], [b'world', b'foo', b'moo']], - x_feature_spec=tf.io.FixedLenFeature([3], tf.string), - label_data=[1, 0, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[[-99, -99, 1], [-99, -99, 0], [-99, -99, -99], - [-99, -99, 0]], - index_feature_spec=tf.io.FixedLenFeature([3], tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_varlen_integer_feature_with_labels', - x_data=[[0, 1, 3, 2], [1, 2, 4], [0, 4, 2], [2]], - x_feature_spec=tf.io.VarLenFeature(tf.int64), - label_data=[1, 0, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[[0, -99, 1, -99], [-99, -99, -99], [0, -99, -99], [-99]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_varlen_feature_with_some_empty_feature_values', - x_data=[[b'world', b'hello', b'hi', b'moo'], [], - [b'world', b'hello', b'foo'], []], - x_feature_spec=tf.io.VarLenFeature(tf.string), - label_data=[1, 0, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[[0, 1, -99, -99], [], [0, 1, -99], []], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_varlen_with_multiclass_labels', - x_data=[[1, 2, 3, 5], [1, 4, 5], [1, 2], [1, 2], [1, 3, 5], [1, 4, 3], - [1, 3]], - x_feature_spec=tf.io.VarLenFeature(tf.int64), - label_data=[1, 0, 1, 1, 4, 5, 4], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[[-1, 0, 2, 3], [-1, 1, 3], [-1, 
0], [-1, 0], [-1, 2, 3], - [-1, 1, 2], [-1, 2]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=3, is_categorical=True), - top_k=4), - dict( - testcase_name='_labels_and_weights', - x_data=[ - b'hello', b'hello', b'hello', b'goodbye', b'aaaaa', b'aaaaa', - b'goodbye', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - weight_data=[ - 0.3, 0.4, 0.3, 1.2, 0.6, 0.7, 1.0, 1.0, 0.6, 0.7, 1.0, 1.0 - ], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - index_data=[2, 2, 2, 1, 0, 0, 1, 1, 0, 0, 1, 1], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=2, - is_categorical=True)), - dict( - testcase_name='_string_feature_with_weights', - x_data=[ - b'hello', b'world', b'goodbye', b'aaaaa', b'aaaaa', b'goodbye' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - weight_data=[1.0, .5, 1.0, .26, .25, 1.5], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - index_data=[1, 3, 0, 2, 2, 0], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=3, - is_categorical=True)), - dict( - testcase_name='_int64_feature_with_weights', - x_data=[2, 1, 3, 4, 4, 3], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - weight_data=[1.0, .5, 1.0, .26, .25, 1.5], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - index_data=[1, 3, 0, 2, 2, 0], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=3, - is_categorical=True)), - dict( - testcase_name='_whitespace_newlines_and_empty_strings_text', - x_data=[ - b'hello', b'world', b'hello', b'hello', b'goodbye', b'world', - b'aaaaa', b' ', b'', b'\n', b'hi \n ho \n', '\r' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - # The empty string and strings containing newlines map to default - # value because the vocab cannot contain them. 
- index_data=[0, 1, 0, 0, 2, 1, 3, 4, -1, -1, -1, -1], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=4, is_categorical=True), - vocab_filename='my_vocab', - expected_vocab_file_contents={ - 'my_vocab': [b'hello', b'world', b'goodbye', b'aaaaa', b' '] - }, - required_format='text'), - dict( - testcase_name='_whitespace_newlines_and_empty_strings_tfrecord', - x_data=[ - b'hello', b'world', b'hello', b'hello', b'goodbye', b'world', - b'aaaaa', b' ', b'', b'\n', b'hi \n ho \n', b'\r' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - index_data=[0, 0, 0, 1, 1, 8, 3, 2, 4, 5, 6, 7], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=8, is_categorical=True), - vocab_filename='my_vocab', - expected_vocab_file_contents={ - 'my_vocab': [ - b'hello', b'world', b'hi \n ho \n', b'goodbye', b'aaaaa', - b' ', b'\r', b'\n', b'' - ] - }, - required_format='tfrecord_gzip'), - dict( - testcase_name='_whitespace_newlines_empty_oov_buckets_text', - x_data=[ - b'hello', b'world', b'hello', b'hello', b'goodbye', b'world', - b'aaaaa', b' ', b'', b'\n', b'hi \n ho \n', '\r' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - index_data=[0, 1, 0, 0, 2, 1, 3, 4, 5, 5, 5, 5], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=0, max=5, is_categorical=True), - num_oov_buckets=1, - vocab_filename='my_vocab', - expected_vocab_file_contents={ - 'my_vocab': [b'hello', b'world', b'goodbye', b'aaaaa', b' '] - }, - required_format='text'), - dict( - testcase_name='_whitespace_newlines_empty_oov_buckets_tfrecord', - x_data=[ - b'hello', b'world', b'hello', b'hello', b'goodbye', b'world', - b'aaaaa', b' ', b'', b'\n', b'hi \n ho \n', '\r' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - index_data=[0, 0, 1, 0, 1, 8, 3, 2, 4, 5, 6, 7], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=0, max=9, is_categorical=True), - num_oov_buckets=1, - vocab_filename='my_vocab', - expected_vocab_file_contents={ - 'my_vocab': [ - b'hello', b'world', b'hi \n ho \n', b'goodbye', b'aaaaa', - b' ', b'\r', b'\n', b'' - ] - }, - required_format='tfrecord_gzip'), - dict( - testcase_name='_positive_and_negative_integers', - x_data=[13, 14, 13, 13, 12, 14, 11, 10, 10, -10, -10, -20], - x_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[0, 1, 0, 0, 4, 1, 5, 2, 2, 3, 3, 6], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=6, is_categorical=True), - vocab_filename='my_vocab', - expected_vocab_file_contents={ - 'my_vocab': [b'13', b'14', b'10', b'-10', b'12', b'11', b'-20'] - }), - dict( - testcase_name='_rank_2', - x_data=[[[b'some', b'say'], [b'the', b'world']], - [[b'will', b'end'], [b'in', b'fire']], - [[b'some', b'say'], [b'in', b'ice']]], - x_feature_spec=tf.io.FixedLenFeature([2, 2], tf.string), - index_data=[[[0, 1], [5, 3]], [[4, 8], [2, 7]], [[0, 1], [2, 6]]], - index_feature_spec=tf.io.FixedLenFeature([2, 2], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=8, - is_categorical=True)), - dict( - testcase_name='_top_k', - x_data=[[b'hello', b'hello', b'world'], - [b'hello', b'goodbye', b'world'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1], [0, -99, 1], [0, -99, -99]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - 
min=-99, max=1, is_categorical=True), - default_value=-99, - top_k=2), - dict( - testcase_name='_top_k_specified_as_str', - x_data=[[b'hello', b'hello', b'world'], - [b'hello', b'goodbye', b'world'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1], [0, -9, 1], [0, -9, -9]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain(min=-9, max=1, is_categorical=True), - default_value=-9, - top_k='2'), - dict( - testcase_name='_frequency_threshold', - x_data=[[b'hello', b'hello', b'world'], - [b'hello', b'goodbye', b'world'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1], [0, 2, 1], [0, 2, -99]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=2, is_categorical=True), - default_value=-99, - frequency_threshold=2), - dict( - testcase_name='_frequency_threshold_specified_with_str', - x_data=[[b'hello', b'hello', b'world'], - [b'hello', b'goodbye', b'world'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1], [0, 2, 1], [0, 2, -9]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain(min=-9, max=2, is_categorical=True), - default_value=-9, - frequency_threshold='2'), - dict( - testcase_name='_empty_vocabulary_from_high_frequency_threshold', - x_data=[[b'hello', b'hello', b'world'], - [b'hello', b'goodbye', b'world'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[-99, -99, -99], [-99, -99, -99], [-99, -99, -99]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=0, is_categorical=True), - default_value=-99, - frequency_threshold=77), - dict( - testcase_name='_top_k_and_oov', - x_data=[[b'hello', b'hello', b'world', b'world'], - [b'hello', b'tarkus', b'toccata'], - [b'hello', b'goodbye', b'foo']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - # Generated vocab (ordered by frequency, then value) should be: - # ["hello", "world", "goodbye", "foo", "tarkus", "toccata"]. After - # applying top_k =1 this becomes ["hello"] plus three OOV buckets. - # The specific output values here depend on the hash of the words, - # and the test will break if the hash changes. 
- index_data=[[0, 0, 2, 2], [0, 3, 1], [0, 2, 1]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain(min=0, max=3, is_categorical=True), - default_value=-99, - top_k=1, - num_oov_buckets=3), - dict( - testcase_name='_key_fn', - x_data=[['a_X_1', 'a_X_1', 'a_X_2', 'b_X_1', 'b_X_2'], - ['a_X_1', 'a_X_1', 'a_X_2', 'a_X_2'], ['b_X_2']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1, -99, 2], [0, 0, 1, 1], [2]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=2, is_categorical=True), - coverage_top_k=1, - default_value=-99, - key_fn=lambda s: s.split(b'_X_')[0], - frequency_threshold=3), - dict( - testcase_name='_key_fn_and_multi_coverage_top_k', - x_data=[['a_X_1', 'a_X_1', 'a_X_2', 'b_X_1', 'b_X_2'], - ['a_X_1', 'a_X_1', 'a_X_2', 'a_X_2', 'a_X_3'], ['b_X_2']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, 1, 3, 2], [0, 0, 1, 1, -99], [2]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=3, is_categorical=True), - coverage_top_k=2, - default_value=-99, - key_fn=lambda s: s.split(b'_X_')[0], - frequency_threshold=300), - dict( - testcase_name='_key_fn_and_top_k', - x_data=[['a_X_1', 'a_X_1', 'a_X_2', 'b_X_1', 'b_X_2'], - ['a_X_1', 'a_X_1', 'a_X_2', 'a_X_2'], - ['b_X_2', 'b_X_2', 'b_X_2', 'b_X_2', 'c_X_1']], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[1, 1, -99, -99, 0], [1, 1, -99, -99], [0, 0, 0, 0, 2]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=2, is_categorical=True), - coverage_top_k=1, - default_value=-99, - key_fn=lambda s: s.split(b'_X_')[0], - top_k=2), - dict( - testcase_name='_key_fn_multi_coverage_top_k', - x_data=[ - ['0_X_a', '0_X_a', '5_X_a', '6_X_a', '6_X_a', '0_X_a'], - ['0_X_a', '2_X_a', '2_X_a', '2_X_a', '0_X_a', '5_X_a'], - ['1_X_b', '1_X_b', '3_X_b', '3_X_b', '0_X_b', '1_X_b', '1_X_b'] - ], - x_feature_spec=tf.io.VarLenFeature(tf.string), - index_data=[[0, 0, -99, -99, -99, 0], [0, 2, 2, 2, 0, -99], - [1, 1, 3, 3, -99, 1, 1]], - index_feature_spec=tf.io.VarLenFeature(tf.int64), - index_domain=schema_pb2.IntDomain( - min=-99, max=3, is_categorical=True), - coverage_top_k=2, - default_value=-99, - key_fn=lambda s: s.split(b'_X_')[1], - frequency_threshold=4), - dict( - testcase_name='_key_fn_and_labels', - x_data=[ - 'aaa', 'aaa', 'aaa', 'aab', 'aba', 'aba', 'aab', 'aab', 'aba', - 'abc', 'abc', 'aab' - ], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - label_data=[1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0], - label_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_data=[0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, -1], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=1, is_categorical=True), - coverage_top_k=1, - key_fn=lambda s: s[:2], - frequency_threshold=3), - dict( - testcase_name='_key_fn_and_weights', - x_data=['xa', 'xa', 'xb', 'ya', 'yb', 'yc'], - x_feature_spec=tf.io.FixedLenFeature([], tf.string), - weight_data=[1.0, 0.5, 3.0, 0.6, 0.25, 0.5], - weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), - index_data=[1, 1, 0, -1, -1, -1], - index_feature_spec=tf.io.FixedLenFeature([], tf.int64), - index_domain=schema_pb2.IntDomain(min=-1, max=1, is_categorical=True), - coverage_top_k=1, - key_fn=lambda s: s[0], - frequency_threshold=1.5, - coverage_frequency_threshold=1), - ] + _EMPTY_VOCABULARY_PARAMS)) - def 
testComputeAndApplyVocabulary(self, - x_data, - x_feature_spec, - index_data, - index_feature_spec, - index_domain, - label_data=None, - label_feature_spec=None, - weight_data=None, - weight_feature_spec=None, - expected_vocab_file_contents=None, - required_format=None, - **kwargs): - """Test tft.compute_and_apply_vocabulary with various inputs.""" - if required_format is not None and required_format != self._VocabFormat(): - raise tft_unit.SkipTest('Test only applicable to format: {}.'.format( - self._VocabFormat())) - - input_data = [{'x': x} for x in x_data] - input_feature_spec = {'x': x_feature_spec} - expected_data = [{'index': index} for index in index_data] - expected_feature_spec = {'index': index_feature_spec} - expected_domains = {'index': index_domain} - - if label_data is not None: - for idx, label in enumerate(label_data): - input_data[idx]['label'] = label - input_feature_spec['label'] = label_feature_spec - - if weight_data is not None: - for idx, weight in enumerate(weight_data): - input_data[idx]['weights'] = weight - input_feature_spec['weights'] = weight_feature_spec - - input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) - expected_metadata = tft.DatasetMetadata.from_feature_spec( - expected_feature_spec, expected_domains) - - def preprocessing_fn(inputs): - x = inputs['x'] - labels = inputs.get('label') - weights = inputs.get('weights') - index = tft.compute_and_apply_vocabulary( - x, - labels=labels, - weights=weights, - file_format=self._VocabFormat(), - **kwargs) - return {'index': index} - - self.assertAnalyzeAndTransformResults( + make_feature_spec, + top_k, + make_expected_vocab_fn, + store_frequency, + reserved_tokens, + ): + input_metadata = tft.DatasetMetadata.from_feature_spec( + tft_unit.make_feature_spec_wrapper(make_feature_spec) + ) + + def preprocessing_fn(inputs): + x = inputs["x"] + weights = inputs.get("weights") + # Note even though the return value is not used, calling + # tft.experimental.approximate_vocabulary will generate the vocabulary as + # a side effect, and since we have named this vocabulary it can be looked + # up using public APIs. 
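(The "public APIs" lookup referred to in the comment above is not exercised in this test. A minimal sketch of what it could look like, assuming the transform_fn produced by the pipeline has been written out with transform_fn_io.WriteTransformFn to a placeholder directory transform_fn_dir, and using only calls that appear elsewhere in this module:

    tft_output = tft.TFTransformOutput(transform_fn_dir)  # transform_fn_dir: placeholder output directory
    # Reads back the vocabulary registered below under
    # vocab_filename="my_approximate_vocab".
    tokens = tft_output.vocabulary_by_name("my_approximate_vocab")

The call below registers the vocabulary under that name as a side effect of analysis.)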
+ tft.experimental.approximate_vocabulary( + x, + top_k, + store_frequency=store_frequency, + weights=weights, + vocab_filename="my_approximate_vocab", + reserved_tokens=reserved_tokens, + file_format=self._VocabFormat(), + ) + return inputs + + expected_vocab_file_contents = make_expected_vocab_fn(self._VocabFormat()) + if reserved_tokens is not None: + expected_vocab_file_contents = [ + (t, -1) for t in reserved_tokens + ] + expected_vocab_file_contents + if not store_frequency: + expected_vocab_file_contents = [ + token for token, _ in expected_vocab_file_contents + ] + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_vocab_file_contents={ + "my_approximate_vocab": expected_vocab_file_contents + }, + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + [ + dict(testcase_name="no_frequency", store_frequency=False), + dict(testcase_name="with_frequency", store_frequency=True), + ], + [ + dict(testcase_name="no_reserved_tokens", reserved_tokens=None), + dict(testcase_name="with_reserved_tokens", reserved_tokens=["A"]), + ], + ) + ) + def testComputeAndApplyApproximateVocabulary( + self, store_frequency, reserved_tokens + ): + input_data = [{"x": "a"}] * 2 + [{"x": "b"}] * 3 + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + + def preprocessing_fn(inputs): + index = tft.experimental.compute_and_apply_approximate_vocabulary( + inputs["x"], + top_k=2, + file_format=self._VocabFormat(), + store_frequency=store_frequency, + reserved_tokens=reserved_tokens, + num_oov_buckets=1, + ) + return {"index": index} + + offset = len(reserved_tokens) if reserved_tokens else 0 + expected_data = [{"index": offset + idx} for idx in [1] * 2 + [0] * 3 + [2]] + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + test_data=input_data + [{"x": "c"}], + ) # pyformat: disable + + def testEmptyComputeAndApplyApproximateVocabulary(self): + input_data = [{"x": ""}] * 3 + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + + def preprocessing_fn(inputs): + index = tft.experimental.compute_and_apply_approximate_vocabulary( + inputs["x"], top_k=2, file_format=self._VocabFormat(), num_oov_buckets=1 + ) + return {"index": index} + + # We only filter empty tokens for `text` format. 
+ expected_data = [{"index": 1 if self._VocabFormat() == "text" else 0}] * 3 + self.assertAnalyzeAndTransformResults( + input_data, input_metadata, preprocessing_fn, expected_data + ) + + def testJointVocabularyForMultipleFeatures(self): + input_data = [ + {"a": "hello", "b": "world", "c": "aaaaa"}, + {"a": "good", "b": "", "c": "hello"}, + {"a": "goodbye", "b": "hello", "c": "\n"}, + {"a": " ", "b": "aaaaa", "c": "bbbbb"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.string), + "b": tf.io.FixedLenFeature([], tf.string), + "c": tf.io.FixedLenFeature([], tf.string), + } + ) + vocab_filename = "test_compute_and_apply_vocabulary" + + def preprocessing_fn(inputs): + deferred_vocab_and_filename = tft.vocabulary( + tf.concat([inputs["a"], inputs["b"], inputs["c"]], 0), + vocab_filename=vocab_filename, + file_format=self._VocabFormat(), + ) + return { + "index_a": tft.apply_vocabulary( + inputs["a"], + deferred_vocab_and_filename, + file_format=self._VocabFormat(), + ), + "index_b": tft.apply_vocabulary( + inputs["b"], + deferred_vocab_and_filename, + file_format=self._VocabFormat(), + ), + } + + expected_vocab = [ + b"hello", + b"aaaaa", + b"world", + b"goodbye", + b"good", + b"bbbbb", + b" ", + b"\n", + b"", + ] + empty_index = len(expected_vocab) - 1 + if self._VocabFormat() == "text": + expected_vocab = expected_vocab[:-2] + empty_index = -1 + max_index = len(expected_vocab) - 1 + expected_data = [ + # For tied frequencies, larger (lexicographic) items come first. + { + "index_a": 0, # hello + "index_b": 2, # world + }, + { + "index_a": 4, # good + "index_b": empty_index, # '' + }, + { + "index_a": 3, # goodbye + "index_b": 0, # hello + }, + { + "index_a": 6, # ' ' + "index_b": 1, # aaaaa + }, + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "index_a": tf.io.FixedLenFeature([], tf.int64), + "index_b": tf.io.FixedLenFeature([], tf.int64), + }, + { + "index_a": schema_pb2.IntDomain( + min=-1, max=max_index, is_categorical=True + ), + "index_b": schema_pb2.IntDomain( + min=-1, max=max_index, is_categorical=True + ), + }, + ) + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents={vocab_filename: expected_vocab}, + ) + + _EMPTY_VOCABULARY_PARAMS = tft_unit.cross_named_parameters( + [ + dict( + testcase_name="_string", + x_data=["a", "b"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + ), + dict( + testcase_name="_int64", + x_data=[1, 2], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + ), + ], + [ + dict( + testcase_name="empty_vocabulary", + index_data=[-1, -1], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain(min=-1, max=0, is_categorical=True), + frequency_threshold=5, + ), + ], + ) + + @tft_unit.named_parameters( + *( + [ + dict( + testcase_name="_string_feature_with_label_top_2", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[-1, -1, -1, 0, 1, 1, 0, 0, 0, 1, 1, 0], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=1, is_categorical=True + ), + top_k=2, + ), + dict( + 
testcase_name="_string_feature_with_label_top_1", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[-1, -1, -1, 0, -1, -1, 0, 0, 0, -1, -1, 0], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=0, is_categorical=True + ), + top_k=1, + ), + dict( + testcase_name="_int_feature_with_label_top_2", + x_data=[3, 3, 3, 1, 2, 2, 1, 1, 2, 2, 1, 1], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[-1, -1, -1, 0, 1, 1, 0, 0, 0, 1, 1, 0], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=1, is_categorical=True + ), + top_k=2, + ), + dict( + testcase_name="_varlen_feature", + x_data=[ + [b"world", b"hello", b"hello"], + [b"hello", b"world", b"foo"], + [], + [b"hello"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[1, 0, 0], [0, 1, -99], [], [0]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_vector_feature", + x_data=[ + [b"world", b"hello", b"hello"], + [b"hello", b"world", b"moo"], + [b"hello", b"hello", b"foo"], + [b"world", b"foo", b"moo"], + ], + x_feature_spec=tf.io.FixedLenFeature([3], tf.string), + index_data=[[1, 0, 0], [0, 1, -99], [0, 0, -99], [1, -99, -99]], + index_feature_spec=tf.io.FixedLenFeature([3], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_varlen_feature_with_labels", + x_data=[ + [b"hello", b"world", b"bye", b"moo"], + [b"world", b"moo", b"foo"], + [b"hello", b"foo", b"moo"], + [b"moo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + label_data=[1, 0, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[ + [0, -99, 1, -99], + [-99, -99, -99], + [0, -99, -99], + [-99], + ], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_vector_feature_with_labels", + x_data=[ + [b"world", b"hello", b"hi"], + [b"hello", b"world", b"moo"], + [b"hello", b"bye", b"foo"], + [b"world", b"foo", b"moo"], + ], + x_feature_spec=tf.io.FixedLenFeature([3], tf.string), + label_data=[1, 0, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[ + [-99, -99, 1], + [-99, -99, 0], + [-99, -99, -99], + [-99, -99, 0], + ], + index_feature_spec=tf.io.FixedLenFeature([3], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_varlen_integer_feature_with_labels", + x_data=[[0, 1, 3, 2], [1, 2, 4], [0, 4, 2], [2]], + x_feature_spec=tf.io.VarLenFeature(tf.int64), + label_data=[1, 0, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[ + [0, -99, 1, -99], + [-99, -99, -99], + [0, -99, -99], + [-99], + ], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + 
index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_varlen_feature_with_some_empty_feature_values", + x_data=[ + [b"world", b"hello", b"hi", b"moo"], + [], + [b"world", b"hello", b"foo"], + [], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + label_data=[1, 0, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[[0, 1, -99, -99], [], [0, 1, -99], []], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_varlen_with_multiclass_labels", + x_data=[ + [1, 2, 3, 5], + [1, 4, 5], + [1, 2], + [1, 2], + [1, 3, 5], + [1, 4, 3], + [1, 3], + ], + x_feature_spec=tf.io.VarLenFeature(tf.int64), + label_data=[1, 0, 1, 1, 4, 5, 4], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[ + [-1, 0, 2, 3], + [-1, 1, 3], + [-1, 0], + [-1, 0], + [-1, 2, 3], + [-1, 1, 2], + [-1, 2], + ], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=3, is_categorical=True + ), + top_k=4, + ), + dict( + testcase_name="_labels_and_weights", + x_data=[ + b"hello", + b"hello", + b"hello", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + weight_data=[ + 0.3, + 0.4, + 0.3, + 1.2, + 0.6, + 0.7, + 1.0, + 1.0, + 0.6, + 0.7, + 1.0, + 1.0, + ], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + index_data=[2, 2, 2, 1, 0, 0, 1, 1, 0, 0, 1, 1], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=2, is_categorical=True + ), + ), + dict( + testcase_name="_string_feature_with_weights", + x_data=[ + b"hello", + b"world", + b"goodbye", + b"aaaaa", + b"aaaaa", + b"goodbye", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + weight_data=[1.0, 0.5, 1.0, 0.26, 0.25, 1.5], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + index_data=[1, 3, 0, 2, 2, 0], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=3, is_categorical=True + ), + ), + dict( + testcase_name="_int64_feature_with_weights", + x_data=[2, 1, 3, 4, 4, 3], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + weight_data=[1.0, 0.5, 1.0, 0.26, 0.25, 1.5], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + index_data=[1, 3, 0, 2, 2, 0], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=3, is_categorical=True + ), + ), + dict( + testcase_name="_whitespace_newlines_and_empty_strings_text", + x_data=[ + b"hello", + b"world", + b"hello", + b"hello", + b"goodbye", + b"world", + b"aaaaa", + b" ", + b"", + b"\n", + b"hi \n ho \n", + "\r", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + # The empty string and strings containing newlines map to default + # value because the vocab cannot contain them. 
+ index_data=[0, 1, 0, 0, 2, 1, 3, 4, -1, -1, -1, -1], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=4, is_categorical=True + ), + vocab_filename="my_vocab", + expected_vocab_file_contents={ + "my_vocab": [b"hello", b"world", b"goodbye", b"aaaaa", b" "] + }, + required_format="text", + ), + dict( + testcase_name="_whitespace_newlines_and_empty_strings_tfrecord", + x_data=[ + b"hello", + b"world", + b"hello", + b"hello", + b"goodbye", + b"world", + b"aaaaa", + b" ", + b"", + b"\n", + b"hi \n ho \n", + b"\r", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + index_data=[0, 0, 0, 1, 1, 8, 3, 2, 4, 5, 6, 7], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=8, is_categorical=True + ), + vocab_filename="my_vocab", + expected_vocab_file_contents={ + "my_vocab": [ + b"hello", + b"world", + b"hi \n ho \n", + b"goodbye", + b"aaaaa", + b" ", + b"\r", + b"\n", + b"", + ] + }, + required_format="tfrecord_gzip", + ), + dict( + testcase_name="_whitespace_newlines_empty_oov_buckets_text", + x_data=[ + b"hello", + b"world", + b"hello", + b"hello", + b"goodbye", + b"world", + b"aaaaa", + b" ", + b"", + b"\n", + b"hi \n ho \n", + "\r", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + index_data=[0, 1, 0, 0, 2, 1, 3, 4, 5, 5, 5, 5], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=0, max=5, is_categorical=True + ), + num_oov_buckets=1, + vocab_filename="my_vocab", + expected_vocab_file_contents={ + "my_vocab": [b"hello", b"world", b"goodbye", b"aaaaa", b" "] + }, + required_format="text", + ), + dict( + testcase_name="_whitespace_newlines_empty_oov_buckets_tfrecord", + x_data=[ + b"hello", + b"world", + b"hello", + b"hello", + b"goodbye", + b"world", + b"aaaaa", + b" ", + b"", + b"\n", + b"hi \n ho \n", + "\r", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + index_data=[0, 0, 1, 0, 1, 8, 3, 2, 4, 5, 6, 7], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=0, max=9, is_categorical=True + ), + num_oov_buckets=1, + vocab_filename="my_vocab", + expected_vocab_file_contents={ + "my_vocab": [ + b"hello", + b"world", + b"hi \n ho \n", + b"goodbye", + b"aaaaa", + b" ", + b"\r", + b"\n", + b"", + ] + }, + required_format="tfrecord_gzip", + ), + dict( + testcase_name="_positive_and_negative_integers", + x_data=[13, 14, 13, 13, 12, 14, 11, 10, 10, -10, -10, -20], + x_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[0, 1, 0, 0, 4, 1, 5, 2, 2, 3, 3, 6], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=6, is_categorical=True + ), + vocab_filename="my_vocab", + expected_vocab_file_contents={ + "my_vocab": [b"13", b"14", b"10", b"-10", b"12", b"11", b"-20"] + }, + ), + dict( + testcase_name="_rank_2", + x_data=[ + [[b"some", b"say"], [b"the", b"world"]], + [[b"will", b"end"], [b"in", b"fire"]], + [[b"some", b"say"], [b"in", b"ice"]], + ], + x_feature_spec=tf.io.FixedLenFeature([2, 2], tf.string), + index_data=[[[0, 1], [5, 3]], [[4, 8], [2, 7]], [[0, 1], [2, 6]]], + index_feature_spec=tf.io.FixedLenFeature([2, 2], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=8, is_categorical=True + ), + ), + dict( + testcase_name="_top_k", + x_data=[ + [b"hello", b"hello", b"world"], + [b"hello", b"goodbye", b"world"], + [b"hello", b"goodbye", b"foo"], + ], + 
x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1], [0, -99, 1], [0, -99, -99]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=1, is_categorical=True + ), + default_value=-99, + top_k=2, + ), + dict( + testcase_name="_top_k_specified_as_str", + x_data=[ + [b"hello", b"hello", b"world"], + [b"hello", b"goodbye", b"world"], + [b"hello", b"goodbye", b"foo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1], [0, -9, 1], [0, -9, -9]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-9, max=1, is_categorical=True + ), + default_value=-9, + top_k="2", + ), + dict( + testcase_name="_frequency_threshold", + x_data=[ + [b"hello", b"hello", b"world"], + [b"hello", b"goodbye", b"world"], + [b"hello", b"goodbye", b"foo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1], [0, 2, 1], [0, 2, -99]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=2, is_categorical=True + ), + default_value=-99, + frequency_threshold=2, + ), + dict( + testcase_name="_frequency_threshold_specified_with_str", + x_data=[ + [b"hello", b"hello", b"world"], + [b"hello", b"goodbye", b"world"], + [b"hello", b"goodbye", b"foo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1], [0, 2, 1], [0, 2, -9]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-9, max=2, is_categorical=True + ), + default_value=-9, + frequency_threshold="2", + ), + dict( + testcase_name="_empty_vocabulary_from_high_frequency_threshold", + x_data=[ + [b"hello", b"hello", b"world"], + [b"hello", b"goodbye", b"world"], + [b"hello", b"goodbye", b"foo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[-99, -99, -99], [-99, -99, -99], [-99, -99, -99]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=0, is_categorical=True + ), + default_value=-99, + frequency_threshold=77, + ), + dict( + testcase_name="_top_k_and_oov", + x_data=[ + [b"hello", b"hello", b"world", b"world"], + [b"hello", b"tarkus", b"toccata"], + [b"hello", b"goodbye", b"foo"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + # Generated vocab (ordered by frequency, then value) should be: + # ["hello", "world", "goodbye", "foo", "tarkus", "toccata"]. After + # applying top_k =1 this becomes ["hello"] plus three OOV buckets. + # The specific output values here depend on the hash of the words, + # and the test will break if the hash changes. 
+ index_data=[[0, 0, 2, 2], [0, 3, 1], [0, 2, 1]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=0, max=3, is_categorical=True + ), + default_value=-99, + top_k=1, + num_oov_buckets=3, + ), + dict( + testcase_name="_key_fn", + x_data=[ + ["a_X_1", "a_X_1", "a_X_2", "b_X_1", "b_X_2"], + ["a_X_1", "a_X_1", "a_X_2", "a_X_2"], + ["b_X_2"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1, -99, 2], [0, 0, 1, 1], [2]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=2, is_categorical=True + ), + coverage_top_k=1, + default_value=-99, + key_fn=lambda s: s.split(b"_X_")[0], + frequency_threshold=3, + ), + dict( + testcase_name="_key_fn_and_multi_coverage_top_k", + x_data=[ + ["a_X_1", "a_X_1", "a_X_2", "b_X_1", "b_X_2"], + ["a_X_1", "a_X_1", "a_X_2", "a_X_2", "a_X_3"], + ["b_X_2"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[0, 0, 1, 3, 2], [0, 0, 1, 1, -99], [2]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=3, is_categorical=True + ), + coverage_top_k=2, + default_value=-99, + key_fn=lambda s: s.split(b"_X_")[0], + frequency_threshold=300, + ), + dict( + testcase_name="_key_fn_and_top_k", + x_data=[ + ["a_X_1", "a_X_1", "a_X_2", "b_X_1", "b_X_2"], + ["a_X_1", "a_X_1", "a_X_2", "a_X_2"], + ["b_X_2", "b_X_2", "b_X_2", "b_X_2", "c_X_1"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[[1, 1, -99, -99, 0], [1, 1, -99, -99], [0, 0, 0, 0, 2]], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=2, is_categorical=True + ), + coverage_top_k=1, + default_value=-99, + key_fn=lambda s: s.split(b"_X_")[0], + top_k=2, + ), + dict( + testcase_name="_key_fn_multi_coverage_top_k", + x_data=[ + ["0_X_a", "0_X_a", "5_X_a", "6_X_a", "6_X_a", "0_X_a"], + ["0_X_a", "2_X_a", "2_X_a", "2_X_a", "0_X_a", "5_X_a"], + ["1_X_b", "1_X_b", "3_X_b", "3_X_b", "0_X_b", "1_X_b", "1_X_b"], + ], + x_feature_spec=tf.io.VarLenFeature(tf.string), + index_data=[ + [0, 0, -99, -99, -99, 0], + [0, 2, 2, 2, 0, -99], + [1, 1, 3, 3, -99, 1, 1], + ], + index_feature_spec=tf.io.VarLenFeature(tf.int64), + index_domain=schema_pb2.IntDomain( + min=-99, max=3, is_categorical=True + ), + coverage_top_k=2, + default_value=-99, + key_fn=lambda s: s.split(b"_X_")[1], + frequency_threshold=4, + ), + dict( + testcase_name="_key_fn_and_labels", + x_data=[ + "aaa", + "aaa", + "aaa", + "aab", + "aba", + "aba", + "aab", + "aab", + "aba", + "abc", + "abc", + "aab", + ], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + label_data=[1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0], + label_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_data=[0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, -1], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=1, is_categorical=True + ), + coverage_top_k=1, + key_fn=lambda s: s[:2], + frequency_threshold=3, + ), + dict( + testcase_name="_key_fn_and_weights", + x_data=["xa", "xa", "xb", "ya", "yb", "yc"], + x_feature_spec=tf.io.FixedLenFeature([], tf.string), + weight_data=[1.0, 0.5, 3.0, 0.6, 0.25, 0.5], + weight_feature_spec=tf.io.FixedLenFeature([], tf.float32), + index_data=[1, 1, 0, -1, -1, -1], + index_feature_spec=tf.io.FixedLenFeature([], tf.int64), + index_domain=schema_pb2.IntDomain( + min=-1, max=1, is_categorical=True + ), + coverage_top_k=1, + key_fn=lambda s: s[0], + 
frequency_threshold=1.5, + coverage_frequency_threshold=1, + ), + ] + + _EMPTY_VOCABULARY_PARAMS + ) + ) + def testComputeAndApplyVocabulary( + self, + x_data, + x_feature_spec, + index_data, + index_feature_spec, + index_domain, + label_data=None, + label_feature_spec=None, + weight_data=None, + weight_feature_spec=None, + expected_vocab_file_contents=None, + required_format=None, + **kwargs, + ): + """Test tft.compute_and_apply_vocabulary with various inputs.""" + if required_format is not None and required_format != self._VocabFormat(): + raise tft_unit.SkipTest( + f"Test only applicable to format: {self._VocabFormat()}." + ) + + input_data = [{"x": x} for x in x_data] + input_feature_spec = {"x": x_feature_spec} + expected_data = [{"index": index} for index in index_data] + expected_feature_spec = {"index": index_feature_spec} + expected_domains = {"index": index_domain} + + if label_data is not None: + for idx, label in enumerate(label_data): + input_data[idx]["label"] = label + input_feature_spec["label"] = label_feature_spec + + if weight_data is not None: + for idx, weight in enumerate(weight_data): + input_data[idx]["weights"] = weight + input_feature_spec["weights"] = weight_feature_spec + + input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) + expected_metadata = tft.DatasetMetadata.from_feature_spec( + expected_feature_spec, expected_domains + ) + + def preprocessing_fn(inputs): + x = inputs["x"] + labels = inputs.get("label") + weights = inputs.get("weights") + index = tft.compute_and_apply_vocabulary( + x, + labels=labels, + weights=weights, + file_format=self._VocabFormat(), + **kwargs, + ) + return {"index": index} + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents=expected_vocab_file_contents, + ) + + @tft_unit.named_parameters( + *tft_unit.cross_named_parameters( + _COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES, + [ + dict(testcase_name="no_frequency", store_frequency=False), + dict(testcase_name="with_frequency", store_frequency=True), + ], + ) + ) + def testCompositeComputeAndApplyVocabulary( + self, input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata, - expected_vocab_file_contents=expected_vocab_file_contents) - - @tft_unit.named_parameters( - *tft_unit.cross_named_parameters( - _COMPOSITE_COMPUTE_AND_APPLY_VOCABULARY_TEST_CASES, - [ - dict(testcase_name='no_frequency', store_frequency=False), - dict(testcase_name='with_frequency', store_frequency=True), - ], - ) - ) - def testCompositeComputeAndApplyVocabulary( - self, - input_data, - input_metadata, - expected_data, - expected_vocab_contents, - store_frequency, - ): - def preprocessing_fn(inputs): - index = tft.compute_and_apply_vocabulary( - inputs['x'], - file_format=self._VocabFormat(), - store_frequency=store_frequency, - vocab_filename='my_vocab', - ) - return {'index': index} - - if store_frequency: - def format_pair(t: bytes, c: int) -> str: - t = t.decode('utf-8') - if t != ' ' or self._VocabFormat() != 'text': - suffix = ' ' + t + expected_vocab_contents, + store_frequency, + ): + def preprocessing_fn(inputs): + index = tft.compute_and_apply_vocabulary( + inputs["x"], + file_format=self._VocabFormat(), + store_frequency=store_frequency, + vocab_filename="my_vocab", + ) + return {"index": index} + + if store_frequency: + + def format_pair(t: bytes, c: int) -> str: + t = t.decode("utf-8") + if t != " " or self._VocabFormat() != "text": + 
suffix = " " + t + else: + suffix = " __SPACE__" + return f"{c}{suffix}" + + contents = [ + format_pair(t, c).encode("utf-8") + for t, c in expected_vocab_contents.items() + ] else: - suffix = ' __SPACE__' - return f'{c}{suffix}' - contents = [ - format_pair(t, c).encode('utf-8') - for t, c in expected_vocab_contents.items() - ] - else: - contents = [t for t in expected_vocab_contents] - - expected_vocab_file_contents = {'my_vocab': contents} - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_vocab_file_contents=expected_vocab_file_contents) - - # Example on how to use the vocab frequency as part of the transform - # function. - def testCreateVocabWithFrequency(self): - input_data = [ - {'a': 'hello', 'b': 'world', 'c': 'aaaaa'}, - {'a': 'good', 'b': '', 'c': 'hello'}, - {'a': 'goodbye', 'b': 'hello', 'c': '\n'}, - {'a': '_', 'b': 'aaaaa', 'c': 'bbbbb'} - ] - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.string), - 'b': tf.io.FixedLenFeature([], tf.string), - 'c': tf.io.FixedLenFeature([], tf.string) - }) - vocab_filename = 'test_vocab_with_frequency' - - def preprocessing_fn(inputs): - deferred_vocab_and_filename = tft.vocabulary( - tf.concat([inputs['a'], inputs['b'], inputs['c']], 0), - vocab_filename=vocab_filename, - store_frequency=True, - file_format=self._VocabFormat()) - - def _make_table_initializer(filename_tensor, is_frequency_value): - if self._VocabFormat() == 'text': - return tf.lookup.TextFileInitializer( - filename=filename_tensor, - key_dtype=tf.string, - key_index=1, - value_dtype=tf.int64, - value_index=(0 if is_frequency_value else - tf.lookup.TextFileIndex.LINE_NUMBER), - delimiter=' ') - elif self._VocabFormat() == 'tfrecord_gzip': - return tft.tf_utils.make_tfrecord_vocabulary_lookup_initializer( - filename_tensor, - return_indicator_as_value=is_frequency_value, - has_indicator=True) - - def _apply_vocab(y, deferred_vocab_filename_tensor): - initializer = _make_table_initializer(deferred_vocab_filename_tensor, - False) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - table_size = table.size() - return table.lookup(y), table_size - - def _apply_frequency(y, deferred_vocab_filename_tensor): - initializer = _make_table_initializer(deferred_vocab_filename_tensor, - True) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - return table.lookup(y), table.size() - - return { - 'index_a': - tft.apply_vocabulary( - inputs['a'], - deferred_vocab_and_filename, - lookup_fn=_apply_vocab, - file_format=self._VocabFormat()), - 'frequency_a': - tft.apply_vocabulary( - inputs['a'], - deferred_vocab_and_filename, - lookup_fn=_apply_frequency, - file_format=self._VocabFormat()), - 'index_b': - tft.apply_vocabulary( - inputs['b'], - deferred_vocab_and_filename, - lookup_fn=_apply_vocab, - file_format=self._VocabFormat()), - 'frequency_b': - tft.apply_vocabulary( - inputs['b'], - deferred_vocab_and_filename, - lookup_fn=_apply_frequency, - file_format=self._VocabFormat()), - } - - expected_vocab = [(b'hello', 3), (b'aaaaa', 2), (b'world', 1), - (b'goodbye', 1), (b'good', 1), (b'bbbbb', 1), (b'_', 1), - (b'\n', 1), (b'', 1)] - if self._VocabFormat() == 'text': - expected_vocab = expected_vocab[:-2] - empty_index = -1 - empty_frequency = -1 - else: - empty_index = 8 - empty_frequency = 1 - expected_data = [ - # For tied frequencies, larger (lexicographic) items come first. 
- { - 'index_a': 0, - 'frequency_a': 3, - 'index_b': 2, - 'frequency_b': 1 - }, - { - 'index_a': 4, - 'frequency_a': 1, - 'index_b': empty_index, - 'frequency_b': empty_frequency - }, - { - 'index_a': 3, - 'frequency_a': 1, - 'index_b': 0, - 'frequency_b': 3 - }, - { - 'index_a': 6, - 'frequency_a': 1, - 'index_b': 1, - 'frequency_b': 2 - } - ] - size = len(expected_vocab) - 1 - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'index_a': tf.io.FixedLenFeature([], tf.int64), - 'index_b': tf.io.FixedLenFeature([], tf.int64), - 'frequency_a': tf.io.FixedLenFeature([], tf.int64), - 'frequency_b': tf.io.FixedLenFeature([], tf.int64), - }, { - 'index_a': - schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), - 'index_b': - schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), - 'frequency_a': - schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), - 'frequency_b': - schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), - }) - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents={vocab_filename: expected_vocab}) - - def testVocabularyAnalyzerWithTokenization(self): - def preprocessing_fn(inputs): - return { - 'index': - tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a']), - file_format=self._VocabFormat(), - vocab_filename='my_vocab') - } - - input_data = [{'a': 'hello hello world'}, {'a': 'hello goodbye world'}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [{'index': [0, 0, 1]}, {'index': [0, 2, 1]}] - - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'index': tf.io.VarLenFeature(tf.int64), - }, { - 'index': schema_pb2.IntDomain(min=-1, max=2, is_categorical=True), - }) - expected_vocabulary = {'my_vocab': [b'hello', b'world', b'goodbye']} - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata=expected_metadata, - expected_vocab_file_contents=expected_vocabulary) - - def testVocabularyWithFrequency(self): - outfile = 'vocabulary_with_frequency' - def preprocessing_fn(inputs): - - # Force the analyzer to be executed, and store the frequency file as a - # side-effect. - _ = tft.vocabulary( - inputs['a'], - vocab_filename=outfile, - store_frequency=True, - file_format=self._VocabFormat()) - _ = tft.vocabulary( - inputs['a'], store_frequency=True, file_format=self._VocabFormat()) - _ = tft.vocabulary( - inputs['b'], store_frequency=True, file_format=self._VocabFormat()) - - # The following must not produce frequency output, just the vocab words. - _ = tft.vocabulary(inputs['b'], file_format=self._VocabFormat()) - a_int = tft.compute_and_apply_vocabulary( - inputs['a'], file_format=self._VocabFormat()) - - # Return input unchanged, this preprocessing_fn is a no-op except for - # computing uniques. 
- return {'a_int': a_int} - - input_metadata = tft.DatasetMetadata.from_feature_spec({ - 'a': tf.io.FixedLenFeature([], tf.string), - 'b': tf.io.FixedLenFeature([], tf.string) - }) - - tft_tmp_dir = os.path.join(self.get_temp_dir(), 'temp_dir') - transform_fn_dir = os.path.join(self.get_temp_dir(), 'export_transform_fn') - - with beam_impl.Context(temp_dir=tft_tmp_dir): - with self._makeTestPipeline() as pipeline: - input_data = pipeline | beam.Create([ - {'a': 'hello', 'b': 'hi'}, - {'a': 'world', 'b': 'ho ho'}, - {'a': 'hello', 'b': 'ho ho'}, - ]) - transform_fn = ( - (input_data, input_metadata) - | beam_impl.AnalyzeDataset(preprocessing_fn)) - _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) - - self.assertTrue(os.path.isdir(tft_tmp_dir)) - - tft_output = tft.TFTransformOutput(transform_fn_dir) - assets_path = os.path.join(tft_output.transform_savedmodel_dir, - tf.saved_model.ASSETS_DIRECTORY) - self.assertTrue(os.path.isdir(assets_path)) - - self.assertEqual([b'2 hello', b'1 world'], - tft_output.vocabulary_by_name(outfile)) - - self.assertEqual( - [b'2 hello', b'1 world'], - tft_output.vocabulary_by_name('vocab_frequency_vocabulary_1')) - - self.assertEqual( - [b'2 ho ho', b'1 hi'], - tft_output.vocabulary_by_name('vocab_frequency_vocabulary_2')) - - self.assertEqual([b'ho ho', b'hi'], - tft_output.vocabulary_by_name('vocab_vocabulary_3')) - - self.assertEqual([b'hello', b'world'], - tft_output.vocabulary_by_name( - 'vocab_compute_and_apply_vocabulary_vocabulary')) - - def testVocabularyWithKeyFnAndFrequency(self): - def key_fn(string): - return string.split(b'_X_')[1] - - outfile = 'vocabulary_with_frequency' - - def preprocessing_fn(inputs): - - # Force the analyzer to be executed, and store the frequency file as a - # side-effect. - - _ = tft.vocabulary( - tf.compat.v1.strings.split(inputs['a']), - coverage_top_k=1, - key_fn=key_fn, - frequency_threshold=4, - vocab_filename=outfile, - store_frequency=True, - file_format=self._VocabFormat()) - - _ = tft.vocabulary( - tf.compat.v1.strings.split(inputs['a']), - coverage_top_k=1, - key_fn=key_fn, - frequency_threshold=4, - store_frequency=True, - file_format=self._VocabFormat()) - - a_int = tft.compute_and_apply_vocabulary( - tf.compat.v1.strings.split(inputs['a']), - coverage_top_k=1, - key_fn=key_fn, - frequency_threshold=4, - store_frequency=True, - file_format=self._VocabFormat(), - ) - - # Return input unchanged, this preprocessing_fn is a no-op except for - # computing uniques. 
- return {'a_int': a_int} - - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - tft_tmp_dir = os.path.join(self.get_temp_dir(), 'temp_dir') - transform_fn_dir = os.path.join(self.get_temp_dir(), 'export_transform_fn') - - with beam_impl.Context(temp_dir=tft_tmp_dir): - with self._makeTestPipeline() as pipeline: - input_data = pipeline | beam.Create([ - {'a': '1_X_a 1_X_a 2_X_a 1_X_b 2_X_b'}, - {'a': '1_X_a 1_X_a 2_X_a 2_X_a'}, - {'a': '2_X_b 3_X_c 4_X_c'} - ]) - transform_fn = ( - (input_data, input_metadata) - | beam_impl.AnalyzeDataset(preprocessing_fn)) - _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) - - self.assertTrue(os.path.isdir(tft_tmp_dir)) - - tft_output = tft.TFTransformOutput(transform_fn_dir) - assets_path = os.path.join(tft_output.transform_savedmodel_dir, - tf.saved_model.ASSETS_DIRECTORY) - self.assertTrue(os.path.isdir(assets_path)) - - self.assertEqual([b'4 1_X_a', b'2 2_X_b', b'1 4_X_c'], - tft_output.vocabulary_by_name(outfile)) - - def testVocabularyAnnotations(self): - outfile = 'vocab.file' - # Sanitization of vocabulary file names replaces '.' with '_'. - annotation_file = 'vocab_file' - if self._VocabFormat() == 'tfrecord_gzip': - annotation_file = '{}.tfrecord.gz'.format(annotation_file) - - def preprocessing_fn(inputs): - _ = tft.vocabulary( - inputs['a'], vocab_filename=outfile, file_format=self._VocabFormat()) - tft.annotate_asset('key_1', annotation_file) - return inputs - - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'a': tf.io.FixedLenFeature([], tf.string)}) - - tft_tmp_dir = os.path.join(self.get_temp_dir(), 'temp_dir') - transform_fn_dir = os.path.join(self.get_temp_dir(), 'export_transform_fn') - - with beam_impl.Context(temp_dir=tft_tmp_dir): - with self._makeTestPipeline() as pipeline: - input_data = pipeline | beam.Create([ + contents = [t for t in expected_vocab_contents] + + expected_vocab_file_contents = {"my_vocab": contents} + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_vocab_file_contents=expected_vocab_file_contents, + ) + + # Example on how to use the vocab frequency as part of the transform + # function. 
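The test that follows demonstrates this end to end for both vocabulary file formats. As a condensed sketch of the core pattern (text format only; the feature name "a" and the vocabulary filename are illustrative, and the module's existing tf/tft imports are assumed): with store_frequency=True each line of a text-format vocabulary is "<count> <token>", so the frequency can be recovered by keying a lookup table on the token column and returning the count column.

    def preprocessing_fn(inputs):
        # Illustrative sketch; "freq_vocab" is a placeholder filename.
        vocab_path = tft.vocabulary(
            inputs["a"], vocab_filename="freq_vocab", store_frequency=True
        )
        # Key on column 1 (token), return column 0 (count).
        initializer = tf.lookup.TextFileInitializer(
            filename=vocab_path,
            key_dtype=tf.string,
            key_index=1,
            value_dtype=tf.int64,
            value_index=0,
            delimiter=" ",
        )
        table = tf.lookup.StaticHashTable(initializer, default_value=-1)
        return {"frequency_a": table.lookup(inputs["a"])}

The full test below additionally routes the lookup through tft.apply_vocabulary with a custom lookup_fn so that the same code also covers the tfrecord_gzip format.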
+ def testCreateVocabWithFrequency(self): + input_data = [ + {"a": "hello", "b": "world", "c": "aaaaa"}, + {"a": "good", "b": "", "c": "hello"}, + {"a": "goodbye", "b": "hello", "c": "\n"}, + {"a": "_", "b": "aaaaa", "c": "bbbbb"}, + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.string), + "b": tf.io.FixedLenFeature([], tf.string), + "c": tf.io.FixedLenFeature([], tf.string), + } + ) + vocab_filename = "test_vocab_with_frequency" + + def preprocessing_fn(inputs): + deferred_vocab_and_filename = tft.vocabulary( + tf.concat([inputs["a"], inputs["b"], inputs["c"]], 0), + vocab_filename=vocab_filename, + store_frequency=True, + file_format=self._VocabFormat(), + ) + + def _make_table_initializer(filename_tensor, is_frequency_value): + if self._VocabFormat() == "text": + return tf.lookup.TextFileInitializer( + filename=filename_tensor, + key_dtype=tf.string, + key_index=1, + value_dtype=tf.int64, + value_index=( + 0 + if is_frequency_value + else tf.lookup.TextFileIndex.LINE_NUMBER + ), + delimiter=" ", + ) + elif self._VocabFormat() == "tfrecord_gzip": + return tft.tf_utils.make_tfrecord_vocabulary_lookup_initializer( + filename_tensor, + return_indicator_as_value=is_frequency_value, + has_indicator=True, + ) + + def _apply_vocab(y, deferred_vocab_filename_tensor): + initializer = _make_table_initializer( + deferred_vocab_filename_tensor, False + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + table_size = table.size() + return table.lookup(y), table_size + + def _apply_frequency(y, deferred_vocab_filename_tensor): + initializer = _make_table_initializer( + deferred_vocab_filename_tensor, True + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + return table.lookup(y), table.size() + + return { + "index_a": tft.apply_vocabulary( + inputs["a"], + deferred_vocab_and_filename, + lookup_fn=_apply_vocab, + file_format=self._VocabFormat(), + ), + "frequency_a": tft.apply_vocabulary( + inputs["a"], + deferred_vocab_and_filename, + lookup_fn=_apply_frequency, + file_format=self._VocabFormat(), + ), + "index_b": tft.apply_vocabulary( + inputs["b"], + deferred_vocab_and_filename, + lookup_fn=_apply_vocab, + file_format=self._VocabFormat(), + ), + "frequency_b": tft.apply_vocabulary( + inputs["b"], + deferred_vocab_and_filename, + lookup_fn=_apply_frequency, + file_format=self._VocabFormat(), + ), + } + + expected_vocab = [ + (b"hello", 3), + (b"aaaaa", 2), + (b"world", 1), + (b"goodbye", 1), + (b"good", 1), + (b"bbbbb", 1), + (b"_", 1), + (b"\n", 1), + (b"", 1), + ] + if self._VocabFormat() == "text": + expected_vocab = expected_vocab[:-2] + empty_index = -1 + empty_frequency = -1 + else: + empty_index = 8 + empty_frequency = 1 + expected_data = [ + # For tied frequencies, larger (lexicographic) items come first. 
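The tie-breaking rule in the comment above is plain byte-wise descending order among entries with equal counts; for example, the count-1 entries of expected_vocab appear in exactly the order produced by:

    # Worked illustration of the tie-breaking order (not part of the test).
    sorted([b"world", b"goodbye", b"good", b"bbbbb", b"_", b"\n", b""], reverse=True)
    # -> [b'world', b'goodbye', b'good', b'bbbbb', b'_', b'\n', b'']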
+ {"index_a": 0, "frequency_a": 3, "index_b": 2, "frequency_b": 1}, { - 'a': 'hello', + "index_a": 4, + "frequency_a": 1, + "index_b": empty_index, + "frequency_b": empty_frequency, }, + {"index_a": 3, "frequency_a": 1, "index_b": 0, "frequency_b": 3}, + {"index_a": 6, "frequency_a": 1, "index_b": 1, "frequency_b": 2}, + ] + size = len(expected_vocab) - 1 + expected_metadata = tft.DatasetMetadata.from_feature_spec( { - 'a': 'world', + "index_a": tf.io.FixedLenFeature([], tf.int64), + "index_b": tf.io.FixedLenFeature([], tf.int64), + "frequency_a": tf.io.FixedLenFeature([], tf.int64), + "frequency_b": tf.io.FixedLenFeature([], tf.int64), }, { - 'a': 'hello', + "index_a": schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), + "index_b": schema_pb2.IntDomain(min=-1, max=size, is_categorical=True), + "frequency_a": schema_pb2.IntDomain( + min=-1, max=size, is_categorical=True + ), + "frequency_b": schema_pb2.IntDomain( + min=-1, max=size, is_categorical=True + ), }, - ]) - transform_fn = ((input_data, input_metadata) - | beam_impl.AnalyzeDataset(preprocessing_fn)) - _, metadata = transform_fn - self.assertDictEqual(metadata.asset_map, { - 'key_1': annotation_file, - outfile: annotation_file - }) - - _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) - - self.assertTrue(os.path.isdir(tft_tmp_dir)) - - tft_output = tft.TFTransformOutput(transform_fn_dir) - assets_path = os.path.join(tft_output.transform_savedmodel_dir, - tf.saved_model.ASSETS_DIRECTORY) - self.assertTrue(os.path.isdir(assets_path)) - - self.assertEqual([b'hello', b'world'], - tft_output.vocabulary_by_name('key_1')) - - def testVocabularyPreSort(self): - input_data = [ - dict(x=b'foo'), - dict(x=b'hello'), - dict(x=b'aaaaa'), - dict(x=b'goodbye'), - dict(x=b'bar'), - dict(x=b'hello'), - dict(x=b'goodbye'), - dict(x=b'hello'), - dict(x=b'hello'), - dict(x=b'goodbye'), - dict(x=b'aaaaa'), - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - expected_vocab_file_contents = [(b'hello', 4), (b'goodbye', 3), - (b'aaaaa', 2), (b'foo', 1), (b'bar', 1)] - - def preprocessing_fn(inputs): - tft.vocabulary( - inputs['x'], - vocab_filename='my_vocab', - file_format=self._VocabFormat(), - store_frequency=True) - return inputs - - with tf.compat.v1.test.mock.patch.object(analyzer_impls, - '_PRESORT_BATCH_SIZE', 2): - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - input_data, - input_metadata, - expected_vocab_file_contents={ - 'my_vocab': expected_vocab_file_contents - }) - - def testVocabularyWithUserDefinedLookupFnFeedsSecondAnalyzer(self): - input_data = [ - dict(x=b'bar'), - dict(x=b'foo'), - dict(x=b'bar'), - dict(x=b'bar'), - dict(x=b'foo'), - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [ - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'foo', x_int=1, x_int_mean=0.4), - dict(x=b'foo', x_int=1, x_int_mean=0.4), - ] - expected_vocab_file_contents = [(b'bar'), (b'foo')] - size = len(expected_vocab_file_contents) - 1 - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'x': tf.io.FixedLenFeature([], tf.string), - 'x_int': tf.io.FixedLenFeature([], tf.int64), - 'x_int_mean': tf.io.FixedLenFeature([], tf.float32) - }, - domains={ - 'x_int': schema_pb2.IntDomain( - min=-1, max=size, is_categorical=True) - }) - - def 
preprocessing_fn(inputs): - - def _make_table_initializer(filename_tensor): - if self._VocabFormat() == 'text': - return tf.lookup.TextFileInitializer( - filename=filename_tensor, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - elif self._VocabFormat() == 'tfrecord_gzip': - return tft.tf_utils.make_tfrecord_vocabulary_lookup_initializer( - filename_tensor, return_indicator_as_value=False) - - def _apply_vocab(y, deferred_vocab_filename_tensor): - initializer = _make_table_initializer(deferred_vocab_filename_tensor) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - table_size = table.size() - return table.lookup(y), table_size - - deferred_vocab_and_filename = tft.vocabulary( - inputs['x'], - vocab_filename='my_vocab', - file_format=self._VocabFormat()) - x_int = tft.apply_vocabulary( - inputs['x'], - deferred_vocab_and_filename, - lookup_fn=_apply_vocab, - file_format=self._VocabFormat()) - - x_int_mean = tf.zeros_like(x_int, dtype=tf.float32) + tft.mean(x_int) - return {'x': inputs['x'], 'x_int': x_int, 'x_int_mean': x_int_mean} - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents={'my_vocab': expected_vocab_file_contents}) - - def testVocabularyWithTableDefinedInPreprocessingFnFeedsSecondAnalyzer(self): - if self._VocabFormat() != 'text': - raise tft_unit.SkipTest('Test only applicable to text format.') - - input_data = [ - dict(x=b'bar'), - dict(x=b'foo'), - dict(x=b'bar'), - dict(x=b'bar'), - dict(x=b'foo'), - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [ - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'bar', x_int=0, x_int_mean=0.4), - dict(x=b'foo', x_int=1, x_int_mean=0.4), - dict(x=b'foo', x_int=1, x_int_mean=0.4), - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.string), - 'x_int': tf.io.FixedLenFeature([], tf.int64), - 'x_int_mean': tf.io.FixedLenFeature([], tf.float32) - }) - expected_vocab_file_contents = [(b'bar'), (b'foo')] - - def preprocessing_fn(inputs): - vocab_path = tft.vocabulary( - inputs['x'], - vocab_filename='my_vocab', - file_format=self._VocabFormat()) - initializer = tf.lookup.TextFileInitializer( - vocab_path, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - x_int = table.lookup(inputs['x']) - x_int_mean = tf.zeros_like(x_int, dtype=tf.float32) + tft.mean(x_int) - return {'x': inputs['x'], 'x_int': x_int, 'x_int_mean': x_int_mean} - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents={'my_vocab': expected_vocab_file_contents}) - - def testStringOpsWithAutomaticControlDependencies(self): - - def preprocessing_fn(inputs): - month_str = tf.strings.substr( - inputs['date'], pos=5, len=3, unit='UTF8_CHAR') - - # The table created here will add an automatic control dependency. 
- month_int = tft.compute_and_apply_vocabulary(month_str) - return {'month_int': month_int} - - input_data = [{'date': '2021-May-31'}, {'date': '2021-Jun-01'}] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'date': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [{'month_int': 0}, {'month_int': 1}] - max_index = len(expected_data) - 1 - expected_metadata = tft.DatasetMetadata.from_feature_spec( - { - 'month_int': tf.io.FixedLenFeature([], tf.int64), - }, { - 'month_int': - schema_pb2.IntDomain( - min=-1, max=max_index, is_categorical=True), - }) - - self.assertAnalyzeAndTransformResults(input_data, input_metadata, - preprocessing_fn, expected_data, - expected_metadata) - - def testVocabularyOneHotEncoding(self): - - input_data = [ - dict(x=b'bar'), - dict(x=b'foo'), - dict(x=b'bar'), - dict(x=b'bar'), - dict(x=b'foo'), - ] - input_metadata = tft.DatasetMetadata.from_feature_spec( - {'x': tf.io.FixedLenFeature([], tf.string)}) - expected_data = [ - dict(x=b'bar', x_encoded=[1], x_encoded_centered=[0.4]), - dict(x=b'bar', x_encoded=[1], x_encoded_centered=[0.4]), - dict(x=b'bar', x_encoded=[1], x_encoded_centered=[0.4]), - dict(x=b'foo', x_encoded=[0], x_encoded_centered=[-0.6]), - dict(x=b'foo', x_encoded=[0], x_encoded_centered=[-0.6]), - ] - expected_metadata = tft.DatasetMetadata.from_feature_spec({ - 'x': tf.io.FixedLenFeature([], tf.string), - 'x_encoded': tf.io.FixedLenFeature([1], tf.int64), - 'x_encoded_centered': tf.io.FixedLenFeature([1], tf.float32), - }) - expected_vocab_file_contents = [(b'bar')] - - def preprocessing_fn(inputs): - x_int = tft.compute_and_apply_vocabulary( - inputs['x'], - vocab_filename='my_vocab', - file_format=self._VocabFormat(), - frequency_threshold=3) - - depth = tft.experimental.get_vocabulary_size_by_name('my_vocab') - x_encoded = tf.one_hot( - x_int, depth=tf.cast(depth, tf.int32), dtype=tf.int64) - # Add a second phase that depends on vocabulary size. 
- x_encoded_centered = ( - tf.cast(x_encoded, dtype=tf.float32) - tft.mean(x_encoded)) - return { - 'x': inputs['x'], - 'x_encoded': x_encoded, - 'x_encoded_centered': x_encoded_centered - } - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_metadata, - expected_vocab_file_contents={'my_vocab': expected_vocab_file_contents}) - - def testVocabularyReservedTokens(self): - """Test vocabulary with reserved tokens.""" - x_data = ['hello', 'world', 'world', '42', '42', '42', 'a'] - input_data = [{'x': x} for x in x_data] - input_feature_spec = {'x': tf.io.FixedLenFeature([], tf.string)} - input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) - reserved_tokens = ['a', 'b', 'c'] - - def preprocessing_fn(inputs): - tft.vocabulary( - inputs['x'], - vocab_filename='my_vocab', - file_format=self._VocabFormat(), - store_frequency=True, - reserved_tokens=reserved_tokens, - ) - tft.compute_and_apply_vocabulary( - inputs['x'], - vocab_filename='reserved_tokens_tensor', - file_format=self._VocabFormat(), - store_frequency=True, - top_k=20, - reserved_tokens=tf.constant(reserved_tokens), - ) - tft.vocabulary( - inputs['x'], - vocab_filename='sanity', - file_format=self._VocabFormat(), - ) - outputs = inputs.copy() - shape = tf.shape(inputs['x']) - outputs['my_vocab_size'] = tf.broadcast_to( - tft.experimental.get_vocabulary_size_by_name('my_vocab'), shape - ) - outputs['reserved_tokens_tensor_size'] = tf.broadcast_to( - tft.experimental.get_vocabulary_size_by_name( - 'reserved_tokens_tensor' - ), - shape, - ) - outputs['sanity_size'] = tf.broadcast_to( - tft.experimental.get_vocabulary_size_by_name('sanity'), shape - ) - return outputs - - expected_vocab_content_from_data = [ - ('42', 3), - ('world', 2), - ('hello', 1), - ] - - expected_vocab_file_contents = [ - (t, -1) for t in reserved_tokens - ] + expected_vocab_content_from_data - - expected_data = input_data.copy() - for instance in expected_data: - instance.update({ - # Vocabulary with reserved tokens sizes are off by one due to a - # duplicate token. 
- 'my_vocab_size': ( - len(expected_vocab_content_from_data) + len(reserved_tokens) + 1 - ), - 'reserved_tokens_tensor_size': ( - len(expected_vocab_content_from_data) + len(reserved_tokens) + 1 - ), - 'sanity_size': 4, - }) - - self.assertAnalyzeAndTransformResults( - input_data, - input_metadata, - preprocessing_fn, - expected_data, - expected_vocab_file_contents={'my_vocab': expected_vocab_file_contents}, - ) + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents={vocab_filename: expected_vocab}, + ) + + def testVocabularyAnalyzerWithTokenization(self): + def preprocessing_fn(inputs): + return { + "index": tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]), + file_format=self._VocabFormat(), + vocab_filename="my_vocab", + ) + } + + input_data = [{"a": "hello hello world"}, {"a": "hello goodbye world"}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [{"index": [0, 0, 1]}, {"index": [0, 2, 1]}] + + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "index": tf.io.VarLenFeature(tf.int64), + }, + { + "index": schema_pb2.IntDomain(min=-1, max=2, is_categorical=True), + }, + ) + expected_vocabulary = {"my_vocab": [b"hello", b"world", b"goodbye"]} + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata=expected_metadata, + expected_vocab_file_contents=expected_vocabulary, + ) + + def testVocabularyWithFrequency(self): + outfile = "vocabulary_with_frequency" + + def preprocessing_fn(inputs): + # Force the analyzer to be executed, and store the frequency file as a + # side-effect. + _ = tft.vocabulary( + inputs["a"], + vocab_filename=outfile, + store_frequency=True, + file_format=self._VocabFormat(), + ) + _ = tft.vocabulary( + inputs["a"], store_frequency=True, file_format=self._VocabFormat() + ) + _ = tft.vocabulary( + inputs["b"], store_frequency=True, file_format=self._VocabFormat() + ) + + # The following must not produce frequency output, just the vocab words. + _ = tft.vocabulary(inputs["b"], file_format=self._VocabFormat()) + a_int = tft.compute_and_apply_vocabulary( + inputs["a"], file_format=self._VocabFormat() + ) + + # Return input unchanged, this preprocessing_fn is a no-op except for + # computing uniques. 
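+            # Note: the unnamed vocabulary calls above are written under
+            # auto-generated asset names (e.g. vocab_frequency_vocabulary_1,
+            # vocab_vocabulary_3), which the assertions further below read back
+            # via tft_output.vocabulary_by_name().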
+ return {"a_int": a_int} + + input_metadata = tft.DatasetMetadata.from_feature_spec( + { + "a": tf.io.FixedLenFeature([], tf.string), + "b": tf.io.FixedLenFeature([], tf.string), + } + ) + + tft_tmp_dir = os.path.join(self.get_temp_dir(), "temp_dir") + transform_fn_dir = os.path.join(self.get_temp_dir(), "export_transform_fn") + + with beam_impl.Context(temp_dir=tft_tmp_dir): + with self._makeTestPipeline() as pipeline: + input_data = pipeline | beam.Create( + [ + {"a": "hello", "b": "hi"}, + {"a": "world", "b": "ho ho"}, + {"a": "hello", "b": "ho ho"}, + ] + ) + transform_fn = (input_data, input_metadata) | beam_impl.AnalyzeDataset( + preprocessing_fn + ) + _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) + + self.assertTrue(os.path.isdir(tft_tmp_dir)) + + tft_output = tft.TFTransformOutput(transform_fn_dir) + assets_path = os.path.join( + tft_output.transform_savedmodel_dir, tf.saved_model.ASSETS_DIRECTORY + ) + self.assertTrue(os.path.isdir(assets_path)) + + self.assertEqual( + [b"2 hello", b"1 world"], tft_output.vocabulary_by_name(outfile) + ) + + self.assertEqual( + [b"2 hello", b"1 world"], + tft_output.vocabulary_by_name("vocab_frequency_vocabulary_1"), + ) + + self.assertEqual( + [b"2 ho ho", b"1 hi"], + tft_output.vocabulary_by_name("vocab_frequency_vocabulary_2"), + ) + + self.assertEqual( + [b"ho ho", b"hi"], tft_output.vocabulary_by_name("vocab_vocabulary_3") + ) + + self.assertEqual( + [b"hello", b"world"], + tft_output.vocabulary_by_name( + "vocab_compute_and_apply_vocabulary_vocabulary" + ), + ) + + def testVocabularyWithKeyFnAndFrequency(self): + def key_fn(string): + return string.split(b"_X_")[1] + + outfile = "vocabulary_with_frequency" + + def preprocessing_fn(inputs): + # Force the analyzer to be executed, and store the frequency file as a + # side-effect. + + _ = tft.vocabulary( + tf.compat.v1.strings.split(inputs["a"]), + coverage_top_k=1, + key_fn=key_fn, + frequency_threshold=4, + vocab_filename=outfile, + store_frequency=True, + file_format=self._VocabFormat(), + ) + + _ = tft.vocabulary( + tf.compat.v1.strings.split(inputs["a"]), + coverage_top_k=1, + key_fn=key_fn, + frequency_threshold=4, + store_frequency=True, + file_format=self._VocabFormat(), + ) + + a_int = tft.compute_and_apply_vocabulary( + tf.compat.v1.strings.split(inputs["a"]), + coverage_top_k=1, + key_fn=key_fn, + frequency_threshold=4, + store_frequency=True, + file_format=self._VocabFormat(), + ) + + # Return input unchanged, this preprocessing_fn is a no-op except for + # computing uniques. 
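+            # Note: key_fn buckets tokens by the substring after "_X_", and
+            # coverage_top_k=1 keeps the most frequent token of each bucket even
+            # when its count is below frequency_threshold=4, which is why
+            # 2_X_b and 4_X_c appear in the expected vocabulary below.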
+ return {"a_int": a_int} + + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + + tft_tmp_dir = os.path.join(self.get_temp_dir(), "temp_dir") + transform_fn_dir = os.path.join(self.get_temp_dir(), "export_transform_fn") + + with beam_impl.Context(temp_dir=tft_tmp_dir): + with self._makeTestPipeline() as pipeline: + input_data = pipeline | beam.Create( + [ + {"a": "1_X_a 1_X_a 2_X_a 1_X_b 2_X_b"}, + {"a": "1_X_a 1_X_a 2_X_a 2_X_a"}, + {"a": "2_X_b 3_X_c 4_X_c"}, + ] + ) + transform_fn = (input_data, input_metadata) | beam_impl.AnalyzeDataset( + preprocessing_fn + ) + _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) + + self.assertTrue(os.path.isdir(tft_tmp_dir)) + + tft_output = tft.TFTransformOutput(transform_fn_dir) + assets_path = os.path.join( + tft_output.transform_savedmodel_dir, tf.saved_model.ASSETS_DIRECTORY + ) + self.assertTrue(os.path.isdir(assets_path)) + + self.assertEqual( + [b"4 1_X_a", b"2 2_X_b", b"1 4_X_c"], tft_output.vocabulary_by_name(outfile) + ) + + def testVocabularyAnnotations(self): + outfile = "vocab.file" + # Sanitization of vocabulary file names replaces '.' with '_'. + annotation_file = "vocab_file" + if self._VocabFormat() == "tfrecord_gzip": + annotation_file = f"{annotation_file}.tfrecord.gz" + + def preprocessing_fn(inputs): + _ = tft.vocabulary( + inputs["a"], vocab_filename=outfile, file_format=self._VocabFormat() + ) + tft.annotate_asset("key_1", annotation_file) + return inputs + + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"a": tf.io.FixedLenFeature([], tf.string)} + ) + + tft_tmp_dir = os.path.join(self.get_temp_dir(), "temp_dir") + transform_fn_dir = os.path.join(self.get_temp_dir(), "export_transform_fn") + + with beam_impl.Context(temp_dir=tft_tmp_dir): + with self._makeTestPipeline() as pipeline: + input_data = pipeline | beam.Create( + [ + { + "a": "hello", + }, + { + "a": "world", + }, + { + "a": "hello", + }, + ] + ) + transform_fn = (input_data, input_metadata) | beam_impl.AnalyzeDataset( + preprocessing_fn + ) + _, metadata = transform_fn + self.assertDictEqual( + metadata.asset_map, + {"key_1": annotation_file, outfile: annotation_file}, + ) + _ = transform_fn | transform_fn_io.WriteTransformFn(transform_fn_dir) + + self.assertTrue(os.path.isdir(tft_tmp_dir)) + + tft_output = tft.TFTransformOutput(transform_fn_dir) + assets_path = os.path.join( + tft_output.transform_savedmodel_dir, tf.saved_model.ASSETS_DIRECTORY + ) + self.assertTrue(os.path.isdir(assets_path)) + + self.assertEqual([b"hello", b"world"], tft_output.vocabulary_by_name("key_1")) + + def testVocabularyPreSort(self): + input_data = [ + dict(x=b"foo"), + dict(x=b"hello"), + dict(x=b"aaaaa"), + dict(x=b"goodbye"), + dict(x=b"bar"), + dict(x=b"hello"), + dict(x=b"goodbye"), + dict(x=b"hello"), + dict(x=b"hello"), + dict(x=b"goodbye"), + dict(x=b"aaaaa"), + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + expected_vocab_file_contents = [ + (b"hello", 4), + (b"goodbye", 3), + (b"aaaaa", 2), + (b"foo", 1), + (b"bar", 1), + ] + + def preprocessing_fn(inputs): + tft.vocabulary( + inputs["x"], + vocab_filename="my_vocab", + file_format=self._VocabFormat(), + store_frequency=True, + ) + return inputs + + with tf.compat.v1.test.mock.patch.object( + analyzer_impls, "_PRESORT_BATCH_SIZE", 2 + ): + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + input_data, + input_metadata, + 
expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + def testVocabularyWithUserDefinedLookupFnFeedsSecondAnalyzer(self): + input_data = [ + dict(x=b"bar"), + dict(x=b"foo"), + dict(x=b"bar"), + dict(x=b"bar"), + dict(x=b"foo"), + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [ + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"foo", x_int=1, x_int_mean=0.4), + dict(x=b"foo", x_int=1, x_int_mean=0.4), + ] + expected_vocab_file_contents = [(b"bar"), (b"foo")] + size = len(expected_vocab_file_contents) - 1 + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.string), + "x_int": tf.io.FixedLenFeature([], tf.int64), + "x_int_mean": tf.io.FixedLenFeature([], tf.float32), + }, + domains={ + "x_int": schema_pb2.IntDomain(min=-1, max=size, is_categorical=True) + }, + ) + + def preprocessing_fn(inputs): + def _make_table_initializer(filename_tensor): + if self._VocabFormat() == "text": + return tf.lookup.TextFileInitializer( + filename=filename_tensor, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + elif self._VocabFormat() == "tfrecord_gzip": + return tft.tf_utils.make_tfrecord_vocabulary_lookup_initializer( + filename_tensor, return_indicator_as_value=False + ) + + def _apply_vocab(y, deferred_vocab_filename_tensor): + initializer = _make_table_initializer(deferred_vocab_filename_tensor) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + table_size = table.size() + return table.lookup(y), table_size + + deferred_vocab_and_filename = tft.vocabulary( + inputs["x"], vocab_filename="my_vocab", file_format=self._VocabFormat() + ) + x_int = tft.apply_vocabulary( + inputs["x"], + deferred_vocab_and_filename, + lookup_fn=_apply_vocab, + file_format=self._VocabFormat(), + ) + + x_int_mean = tf.zeros_like(x_int, dtype=tf.float32) + tft.mean(x_int) + return {"x": inputs["x"], "x_int": x_int, "x_int_mean": x_int_mean} + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + def testVocabularyWithTableDefinedInPreprocessingFnFeedsSecondAnalyzer(self): + if self._VocabFormat() != "text": + raise tft_unit.SkipTest("Test only applicable to text format.") + + input_data = [ + dict(x=b"bar"), + dict(x=b"foo"), + dict(x=b"bar"), + dict(x=b"bar"), + dict(x=b"foo"), + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [ + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"bar", x_int=0, x_int_mean=0.4), + dict(x=b"foo", x_int=1, x_int_mean=0.4), + dict(x=b"foo", x_int=1, x_int_mean=0.4), + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.string), + "x_int": tf.io.FixedLenFeature([], tf.int64), + "x_int_mean": tf.io.FixedLenFeature([], tf.float32), + } + ) + expected_vocab_file_contents = [(b"bar"), (b"foo")] + + def preprocessing_fn(inputs): + vocab_path = tft.vocabulary( + inputs["x"], vocab_filename="my_vocab", file_format=self._VocabFormat() + ) + initializer = tf.lookup.TextFileInitializer( + vocab_path, + 
key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + x_int = table.lookup(inputs["x"]) + x_int_mean = tf.zeros_like(x_int, dtype=tf.float32) + tft.mean(x_int) + return {"x": inputs["x"], "x_int": x_int, "x_int_mean": x_int_mean} + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + def testStringOpsWithAutomaticControlDependencies(self): + def preprocessing_fn(inputs): + month_str = tf.strings.substr( + inputs["date"], pos=5, len=3, unit="UTF8_CHAR" + ) + + # The table created here will add an automatic control dependency. + month_int = tft.compute_and_apply_vocabulary(month_str) + return {"month_int": month_int} + + input_data = [{"date": "2021-May-31"}, {"date": "2021-Jun-01"}] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"date": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [{"month_int": 0}, {"month_int": 1}] + max_index = len(expected_data) - 1 + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "month_int": tf.io.FixedLenFeature([], tf.int64), + }, + { + "month_int": schema_pb2.IntDomain( + min=-1, max=max_index, is_categorical=True + ), + }, + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + ) + + def testVocabularyOneHotEncoding(self): + input_data = [ + dict(x=b"bar"), + dict(x=b"foo"), + dict(x=b"bar"), + dict(x=b"bar"), + dict(x=b"foo"), + ] + input_metadata = tft.DatasetMetadata.from_feature_spec( + {"x": tf.io.FixedLenFeature([], tf.string)} + ) + expected_data = [ + dict(x=b"bar", x_encoded=[1], x_encoded_centered=[0.4]), + dict(x=b"bar", x_encoded=[1], x_encoded_centered=[0.4]), + dict(x=b"bar", x_encoded=[1], x_encoded_centered=[0.4]), + dict(x=b"foo", x_encoded=[0], x_encoded_centered=[-0.6]), + dict(x=b"foo", x_encoded=[0], x_encoded_centered=[-0.6]), + ] + expected_metadata = tft.DatasetMetadata.from_feature_spec( + { + "x": tf.io.FixedLenFeature([], tf.string), + "x_encoded": tf.io.FixedLenFeature([1], tf.int64), + "x_encoded_centered": tf.io.FixedLenFeature([1], tf.float32), + } + ) + expected_vocab_file_contents = [(b"bar")] + + def preprocessing_fn(inputs): + x_int = tft.compute_and_apply_vocabulary( + inputs["x"], + vocab_filename="my_vocab", + file_format=self._VocabFormat(), + frequency_threshold=3, + ) + + depth = tft.experimental.get_vocabulary_size_by_name("my_vocab") + x_encoded = tf.one_hot( + x_int, depth=tf.cast(depth, tf.int32), dtype=tf.int64 + ) + # Add a second phase that depends on vocabulary size. 
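+            # (tft.mean consumes x_encoded, which already depends on the
+            # vocabulary computed above, so it can only be evaluated in a
+            # second analysis pass.)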
+ x_encoded_centered = tf.cast(x_encoded, dtype=tf.float32) - tft.mean( + x_encoded + ) + return { + "x": inputs["x"], + "x_encoded": x_encoded, + "x_encoded_centered": x_encoded_centered, + } -if __name__ == '__main__': - tft_unit.main() + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_metadata, + expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + def testVocabularyReservedTokens(self): + """Test vocabulary with reserved tokens.""" + x_data = ["hello", "world", "world", "42", "42", "42", "a"] + input_data = [{"x": x} for x in x_data] + input_feature_spec = {"x": tf.io.FixedLenFeature([], tf.string)} + input_metadata = tft.DatasetMetadata.from_feature_spec(input_feature_spec) + reserved_tokens = ["a", "b", "c"] + + def preprocessing_fn(inputs): + tft.vocabulary( + inputs["x"], + vocab_filename="my_vocab", + file_format=self._VocabFormat(), + store_frequency=True, + reserved_tokens=reserved_tokens, + ) + tft.compute_and_apply_vocabulary( + inputs["x"], + vocab_filename="reserved_tokens_tensor", + file_format=self._VocabFormat(), + store_frequency=True, + top_k=20, + reserved_tokens=tf.constant(reserved_tokens), + ) + tft.vocabulary( + inputs["x"], + vocab_filename="sanity", + file_format=self._VocabFormat(), + ) + outputs = inputs.copy() + shape = tf.shape(inputs["x"]) + outputs["my_vocab_size"] = tf.broadcast_to( + tft.experimental.get_vocabulary_size_by_name("my_vocab"), shape + ) + outputs["reserved_tokens_tensor_size"] = tf.broadcast_to( + tft.experimental.get_vocabulary_size_by_name("reserved_tokens_tensor"), + shape, + ) + outputs["sanity_size"] = tf.broadcast_to( + tft.experimental.get_vocabulary_size_by_name("sanity"), shape + ) + return outputs + + expected_vocab_content_from_data = [ + ("42", 3), + ("world", 2), + ("hello", 1), + ] + + expected_vocab_file_contents = [ + (t, -1) for t in reserved_tokens + ] + expected_vocab_content_from_data + + expected_data = input_data.copy() + for instance in expected_data: + instance.update( + { + # Vocabulary with reserved tokens sizes are off by one due to a + # duplicate token. + "my_vocab_size": ( + len(expected_vocab_content_from_data) + len(reserved_tokens) + 1 + ), + "reserved_tokens_tensor_size": ( + len(expected_vocab_content_from_data) + len(reserved_tokens) + 1 + ), + "sanity_size": 4, + } + ) + + self.assertAnalyzeAndTransformResults( + input_data, + input_metadata, + preprocessing_fn, + expected_data, + expected_vocab_file_contents={"my_vocab": expected_vocab_file_contents}, + ) + + +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py b/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py index 4ac011c..1eb5214 100644 --- a/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py +++ b/tensorflow_transform/beam/vocabulary_tfrecord_gzip_integration_test.py @@ -14,16 +14,15 @@ # limitations under the License. 
"""Tests for tfrecord_gzip tft.vocabulary and tft.compute_and_apply_vocabulary.""" -from tensorflow_transform.beam import vocabulary_integration_test -from tensorflow_transform.beam import tft_unit +from tensorflow_transform.beam import tft_unit, vocabulary_integration_test class TFRecordVocabularyIntegrationTest( - vocabulary_integration_test.VocabularyIntegrationTest): + vocabulary_integration_test.VocabularyIntegrationTest +): + def _VocabFormat(self): + return "tfrecord_gzip" - def _VocabFormat(self): - return 'tfrecord_gzip' - -if __name__ == '__main__': - tft_unit.main() +if __name__ == "__main__": + tft_unit.main() diff --git a/tensorflow_transform/coders/csv_coder.py b/tensorflow_transform/coders/csv_coder.py index e683f2e..9322812 100644 --- a/tensorflow_transform/coders/csv_coder.py +++ b/tensorflow_transform/coders/csv_coder.py @@ -12,248 +12,286 @@ # See the License for the specific language governing permissions and # limitations under the License. """Coder classes for encoding CSV into tf.Transform datasets.""" + import csv import io import numpy as np import tensorflow as tf + from tensorflow_transform.tf_metadata import schema_utils def _to_string(x): - """Converts x to string. + """Converts x to string. - This will return Unicode for Py3. This is needed as a pre-processing step - before calling csv reader/writer since it only supports Unicode for Py3. + This will return Unicode for Py3. This is needed as a pre-processing step + before calling csv reader/writer since it only supports Unicode for Py3. - Args: - x: The data to be converted. + Args: + ---- + x: The data to be converted. - Returns: - Unicode representation for Py3. + Returns: + ------- + Unicode representation for Py3. - """ - return tf.compat.as_str_any(x) + """ + return tf.compat.as_str_any(x) class _FixedLenFeatureHandler: - """Handler for `FixedLenFeature` values. - - `FixedLenFeature` values will be parsed as a scalar or an array of the - corresponding dtype. In case the value is missing the default_value will - be returned. If the default value is not present a ValueError will be raised. - """ - - def __init__(self, name, feature_spec, index, encoder=None): - self._name = name - self._default_value = feature_spec.default_value - self._index = index - self._encoder = encoder - self._np_dtype = feature_spec.dtype.as_numpy_dtype - self._shape = feature_spec.shape - self._rank = len(feature_spec.shape) - self._size = 1 - for dim in feature_spec.shape: - self._size *= dim - - @property - def name(self): - return self._name - - def encode_value(self, string_list, values): - """Encode the value of this feature into the CSV line.""" - - if self._rank == 0: - flattened_values = [values] - elif self._rank == 1: - # Short-circuit the reshaping logic needed for rank > 1. - flattened_values = values - else: - flattened_values = np.asarray(values, dtype=self._np_dtype).reshape(-1) - - if len(flattened_values) != self._size: - raise ValueError( - 'FixedLenFeature "{}" got wrong number of values. Expected {} but ' - 'got {}'.format(self._name, self._size, len(flattened_values))) - - if self._encoder: - string_list[self._index] = self._encoder.encode_record(flattened_values) - else: - string_list[self._index] = _to_string(flattened_values[0]) + """Handler for `FixedLenFeature` values. + `FixedLenFeature` values will be parsed as a scalar or an array of the + corresponding dtype. In case the value is missing the default_value will + be returned. If the default value is not present a ValueError will be raised. 
+ """ -class _VarLenFeatureHandler: - """Handler for `VarLenFeature` values. - - `VarLenFeature` values will be parsed as an array of values of the - corresponding dtype. In case the value is missing an empty array - will be returned. - """ + def __init__(self, name, feature_spec, index, encoder=None): + self._name = name + self._default_value = feature_spec.default_value + self._index = index + self._encoder = encoder + self._np_dtype = feature_spec.dtype.as_numpy_dtype + self._shape = feature_spec.shape + self._rank = len(feature_spec.shape) + self._size = 1 + for dim in feature_spec.shape: + self._size *= dim + + @property + def name(self): + return self._name + + def encode_value(self, string_list, values): + """Encode the value of this feature into the CSV line.""" + if self._rank == 0: + flattened_values = [values] + elif self._rank == 1: + # Short-circuit the reshaping logic needed for rank > 1. + flattened_values = values + else: + flattened_values = np.asarray(values, dtype=self._np_dtype).reshape(-1) + + if len(flattened_values) != self._size: + raise ValueError( + f'FixedLenFeature "{self._name}" got wrong number of values. Expected {self._size} but ' + f"got {len(flattened_values)}" + ) + + if self._encoder: + string_list[self._index] = self._encoder.encode_record(flattened_values) + else: + string_list[self._index] = _to_string(flattened_values[0]) - def __init__(self, name, dtype, index, encoder=None): - self._name = name - self._np_dtype = dtype.as_numpy_dtype - self._index = index - self._encoder = encoder - @property - def name(self): - return self._name +class _VarLenFeatureHandler: + """Handler for `VarLenFeature` values. - def encode_value(self, string_list, values): - """Encode the value of this feature into the CSV line.""" - if self._encoder: - string_list[self._index] = self._encoder.encode_record(values) - else: - string_list[self._index] = _to_string(values[0]) if values else '' + `VarLenFeature` values will be parsed as an array of values of the + corresponding dtype. In case the value is missing an empty array + will be returned. + """ + def __init__(self, name, dtype, index, encoder=None): + self._name = name + self._np_dtype = dtype.as_numpy_dtype + self._index = index + self._encoder = encoder -class CsvCoder: - """A coder to encode CSV formatted data.""" - - class _WriterWrapper: - """A wrapper for csv.writer to make it picklable.""" - - def __init__(self, delimiter): - """Initializes the writer wrapper. - - Args: - delimiter: A one-character string used to separate fields. - """ - self._state = (delimiter) - self._buffer = io.StringIO() - - # Since we use self._writer to encode individual rows, we set - # lineterminator='' so that self._writer doesn't add a newline. - self._writer = csv.writer( - self._buffer, lineterminator='', delimiter=delimiter) - - def encode_record(self, record): - """Converts the record to bytes. - - Since csv writer only supports Unicode for PY3, we need to convert them - conditionally before calling csv writer. We always return result in bytes - format to be consistent with current behavior. - - Args: - record: The data to be converted. - - Returns: - Bytes representation input. - """ - self._writer.writerow([_to_string(x) for x in record]) - result = tf.compat.as_bytes(self._buffer.getvalue()) - # Reset the buffer. 
- self._buffer.seek(0) - self._buffer.truncate(0) - return result - - def __getstate__(self): - return self._state - - def __setstate__(self, state): - self.__init__(*state) - - def __init__(self, - column_names, - schema, - delimiter=',', - secondary_delimiter=None, - multivalent_columns=None): - """Initializes CsvCoder. + @property + def name(self): + return self._name - Args: - column_names: Tuple of strings. Order must match the order in the file. - schema: A `Schema` proto. - delimiter: A one-character string used to separate fields. - secondary_delimiter: A one-character string used to separate values within - the same field. - multivalent_columns: A list of names for multivalent columns that need to - be split based on secondary delimiter. - - Raises: - ValueError: If `schema` is invalid. - """ - self._column_names = column_names - self._schema = schema - self._delimiter = delimiter - self._secondary_delimiter = secondary_delimiter - self._encoder = self._WriterWrapper(delimiter) - - if multivalent_columns is None: - multivalent_columns = [] - self._multivalent_columns = multivalent_columns - - if secondary_delimiter: - secondary_encoder = self._WriterWrapper(secondary_delimiter) - elif multivalent_columns: - raise ValueError( - 'secondary_delimiter unspecified for multivalent columns "{}"'.format( - multivalent_columns)) - secondary_encoder_by_name = { - name: secondary_encoder for name in multivalent_columns - } - indices_by_name = { - name: index for index, name in enumerate(self._column_names) - } - - def index(name): - index = indices_by_name.get(name) - if index is None: - raise ValueError('Column not found: "{}"'.format(name)) - else: - return index - - self._feature_handlers = [] - for name, feature_spec in schema_utils.schema_as_feature_spec( - schema).feature_spec.items(): - if isinstance(feature_spec, tf.io.FixedLenFeature): - self._feature_handlers.append( - _FixedLenFeatureHandler(name, feature_spec, index(name), - secondary_encoder_by_name.get(name))) - elif isinstance(feature_spec, tf.io.VarLenFeature): - self._feature_handlers.append( - _VarLenFeatureHandler(name, feature_spec.dtype, index(name), - secondary_encoder_by_name.get(name))) - elif isinstance(feature_spec, tf.io.SparseFeature): - index_keys = ( - feature_spec.index_key if isinstance(feature_spec.index_key, list) - else [feature_spec.index_key]) - for key in index_keys: - self._feature_handlers.append( - _VarLenFeatureHandler(key, tf.int64, index(key), - secondary_encoder_by_name.get(name))) - self._feature_handlers.append( - _VarLenFeatureHandler(feature_spec.value_key, feature_spec.dtype, - index(feature_spec.value_key), - secondary_encoder_by_name.get(name))) - else: - raise ValueError( - 'feature_spec should be one of tf.FixedLenFeature, ' - 'tf.VarLenFeature or tf.SparseFeature: {!r} was {!r}'.format( - name, type(feature_spec))) - - def __reduce__(self): - return self.__class__, (self._column_names, self._schema, self._delimiter, - self._secondary_delimiter, - self._multivalent_columns) - - def encode(self, instance): - """Encode a tf.transform encoded dict to a csv-formatted string. + def encode_value(self, string_list, values): + """Encode the value of this feature into the CSV line.""" + if self._encoder: + string_list[self._index] = self._encoder.encode_record(values) + else: + string_list[self._index] = _to_string(values[0]) if values else "" - Args: - instance: A python dictionary where the keys are the column names and the - values are fixed len or var len encoded features. 
- Returns: - A csv-formatted string. The order of the columns is given by column_names. - """ - string_list = [None] * len(self._column_names) - for feature_handler in self._feature_handlers: - try: - feature_handler.encode_value(string_list, - instance[feature_handler.name]) - except TypeError as e: - raise TypeError('{} while encoding feature "{}"'.format( - e, feature_handler.name)) - return self._encoder.encode_record(string_list) +class CsvCoder: + """A coder to encode CSV formatted data.""" + + class _WriterWrapper: + """A wrapper for csv.writer to make it picklable.""" + + def __init__(self, delimiter): + """Initializes the writer wrapper. + + Args: + ---- + delimiter: A one-character string used to separate fields. + """ + self._state = delimiter + self._buffer = io.StringIO() + + # Since we use self._writer to encode individual rows, we set + # lineterminator='' so that self._writer doesn't add a newline. + self._writer = csv.writer( + self._buffer, lineterminator="", delimiter=delimiter + ) + + def encode_record(self, record): + """Converts the record to bytes. + + Since csv writer only supports Unicode for PY3, we need to convert them + conditionally before calling csv writer. We always return result in bytes + format to be consistent with current behavior. + + Args: + ---- + record: The data to be converted. + + Returns: + ------- + Bytes representation input. + """ + self._writer.writerow([_to_string(x) for x in record]) + result = tf.compat.as_bytes(self._buffer.getvalue()) + # Reset the buffer. + self._buffer.seek(0) + self._buffer.truncate(0) + return result + + def __getstate__(self): + return self._state + + def __setstate__(self, state): + self.__init__(*state) + + def __init__( + self, + column_names, + schema, + delimiter=",", + secondary_delimiter=None, + multivalent_columns=None, + ): + """Initializes CsvCoder. + + Args: + ---- + column_names: Tuple of strings. Order must match the order in the file. + schema: A `Schema` proto. + delimiter: A one-character string used to separate fields. + secondary_delimiter: A one-character string used to separate values within + the same field. + multivalent_columns: A list of names for multivalent columns that need to + be split based on secondary delimiter. + + Raises: + ------ + ValueError: If `schema` is invalid. 
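+
+        Example:
+        -------
+        A minimal sketch, mirroring the scalar case in csv_coder_test.py:
+
+            schema = schema_utils.schema_from_feature_spec(
+                {"x": tf.io.FixedLenFeature([], tf.int64)})
+            coder = CsvCoder(["x"], schema)
+            coder.encode({"x": 12})  # returns b"12"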
+ """ + self._column_names = column_names + self._schema = schema + self._delimiter = delimiter + self._secondary_delimiter = secondary_delimiter + self._encoder = self._WriterWrapper(delimiter) + + if multivalent_columns is None: + multivalent_columns = [] + self._multivalent_columns = multivalent_columns + + if secondary_delimiter: + secondary_encoder = self._WriterWrapper(secondary_delimiter) + elif multivalent_columns: + raise ValueError( + f'secondary_delimiter unspecified for multivalent columns "{multivalent_columns}"' + ) + secondary_encoder_by_name = { + name: secondary_encoder for name in multivalent_columns + } + indices_by_name = {name: index for index, name in enumerate(self._column_names)} + + def index(name): + index = indices_by_name.get(name) + if index is None: + raise ValueError(f'Column not found: "{name}"') + else: + return index + + self._feature_handlers = [] + for name, feature_spec in schema_utils.schema_as_feature_spec( + schema + ).feature_spec.items(): + if isinstance(feature_spec, tf.io.FixedLenFeature): + self._feature_handlers.append( + _FixedLenFeatureHandler( + name, + feature_spec, + index(name), + secondary_encoder_by_name.get(name), + ) + ) + elif isinstance(feature_spec, tf.io.VarLenFeature): + self._feature_handlers.append( + _VarLenFeatureHandler( + name, + feature_spec.dtype, + index(name), + secondary_encoder_by_name.get(name), + ) + ) + elif isinstance(feature_spec, tf.io.SparseFeature): + index_keys = ( + feature_spec.index_key + if isinstance(feature_spec.index_key, list) + else [feature_spec.index_key] + ) + for key in index_keys: + self._feature_handlers.append( + _VarLenFeatureHandler( + key, + tf.int64, + index(key), + secondary_encoder_by_name.get(name), + ) + ) + self._feature_handlers.append( + _VarLenFeatureHandler( + feature_spec.value_key, + feature_spec.dtype, + index(feature_spec.value_key), + secondary_encoder_by_name.get(name), + ) + ) + else: + raise ValueError( + "feature_spec should be one of tf.FixedLenFeature, " + f"tf.VarLenFeature or tf.SparseFeature: {name!r} was {type(feature_spec)!r}" + ) + + def __reduce__(self): + return self.__class__, ( + self._column_names, + self._schema, + self._delimiter, + self._secondary_delimiter, + self._multivalent_columns, + ) + + def encode(self, instance): + """Encode a tf.transform encoded dict to a csv-formatted string. + + Args: + ---- + instance: A python dictionary where the keys are the column names and the + values are fixed len or var len encoded features. + + Returns: + ------- + A csv-formatted string. The order of the columns is given by column_names. + """ + string_list = [None] * len(self._column_names) + for feature_handler in self._feature_handlers: + try: + feature_handler.encode_value( + string_list, instance[feature_handler.name] + ) + except TypeError as e: + raise TypeError(f'{e} while encoding feature "{feature_handler.name}"') + return self._encoder.encode_record(string_list) diff --git a/tensorflow_transform/coders/csv_coder_test.py b/tensorflow_transform/coders/csv_coder_test.py index 077ed8d..b3f7a37 100644 --- a/tensorflow_transform/coders/csv_coder_test.py +++ b/tensorflow_transform/coders/csv_coder_test.py @@ -1,4 +1,3 @@ -# coding=utf-8 # # Copyright 2017 Google Inc. All Rights Reserved. # @@ -14,283 +13,286 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Tensorflow-transform CsvCoder tests.""" + import pickle import tensorflow as tf -from tensorflow_transform.coders import csv_coder + from tensorflow_transform import test_case +from tensorflow_transform.coders import csv_coder from tensorflow_transform.tf_metadata import schema_utils _COLUMNS = [ - 'numeric1', - 'text1', - 'category1', - 'idx', - 'numeric2', - 'value', - 'numeric3', - '2d_idx0', - '2d_idx1', - '2d_val', + "numeric1", + "text1", + "category1", + "idx", + "numeric2", + "value", + "numeric3", + "2d_idx0", + "2d_idx1", + "2d_val", ] _FEATURE_SPEC = { - 'numeric1': - tf.io.FixedLenFeature([], tf.int64), - 'numeric2': - tf.io.VarLenFeature(tf.float32), - 'numeric3': - tf.io.FixedLenFeature([1], tf.int64), - 'text1': - tf.io.FixedLenFeature([], tf.string), - 'category1': - tf.io.VarLenFeature(tf.string), - 'y': - tf.io.SparseFeature('idx', 'value', tf.float32, 10), - '2dsparse': - tf.io.SparseFeature(['2d_idx0', '2d_idx1'], '2d_val', tf.float32, - [2, 10]), + "numeric1": tf.io.FixedLenFeature([], tf.int64), + "numeric2": tf.io.VarLenFeature(tf.float32), + "numeric3": tf.io.FixedLenFeature([1], tf.int64), + "text1": tf.io.FixedLenFeature([], tf.string), + "category1": tf.io.VarLenFeature(tf.string), + "y": tf.io.SparseFeature("idx", "value", tf.float32, 10), + "2dsparse": tf.io.SparseFeature( + ["2d_idx0", "2d_idx1"], "2d_val", tf.float32, [2, 10] + ), } _ENCODE_CASES = [ dict( - testcase_name='multiple_columns', + testcase_name="multiple_columns", columns=_COLUMNS, feature_spec=_FEATURE_SPEC, csv_line='12,"this is a ,text",categorical_value,1,89.0,12.0,20,1,7,17.0', instance={ - 'category1': [b'categorical_value'], - 'numeric1': 12, - 'numeric2': [89.0], - 'numeric3': [20], - 'text1': b'this is a ,text', - 'idx': [1], - 'value': [12.0], - '2d_idx0': [1], - '2d_idx1': [7], - '2d_val': [17.0], - }), + "category1": [b"categorical_value"], + "numeric1": 12, + "numeric2": [89.0], + "numeric3": [20], + "text1": b"this is a ,text", + "idx": [1], + "value": [12.0], + "2d_idx0": [1], + "2d_idx1": [7], + "2d_val": [17.0], + }, + ), dict( - testcase_name='multiple_columns_unicode', + testcase_name="multiple_columns_unicode", columns=_COLUMNS, feature_spec=_FEATURE_SPEC, - csv_line=u'12,"this is a ,text",Hello κόσμε,1,89.0,12.0,20,1,7,17.0', + csv_line='12,"this is a ,text",Hello κόσμε,1,89.0,12.0,20,1,7,17.0', instance={ - 'category1': [u'Hello κόσμε'.encode('utf-8')], - 'numeric1': 12, - 'numeric2': [89.0], - 'numeric3': [20], - 'text1': b'this is a ,text', - 'idx': [1], - 'value': [12.0], - '2d_idx0': [1], - '2d_idx1': [7], - '2d_val': [17.0], - }), + "category1": ["Hello κόσμε".encode()], + "numeric1": 12, + "numeric2": [89.0], + "numeric3": [20], + "text1": b"this is a ,text", + "idx": [1], + "value": [12.0], + "2d_idx0": [1], + "2d_idx1": [7], + "2d_val": [17.0], + }, + ), dict( - testcase_name='multiple_columns_tab_separated', + testcase_name="multiple_columns_tab_separated", columns=_COLUMNS, feature_spec=_FEATURE_SPEC, csv_line=( '12\t"this is a \ttext"\tcategorical_value\t1\t89.0\t12.0\t20\t1\t7\t17.0' ), instance={ - 'category1': [b'categorical_value'], - 'numeric1': 12, - 'numeric2': [89.0], - 'numeric3': [20], - 'text1': b'this is a \ttext', - 'idx': [1], - 'value': [12.0], - '2d_idx0': [1], - '2d_idx1': [7], - '2d_val': [17.0], + "category1": [b"categorical_value"], + "numeric1": 12, + "numeric2": [89.0], + "numeric3": [20], + "text1": b"this is a \ttext", + "idx": [1], + "value": [12.0], + "2d_idx0": [1], + "2d_idx1": [7], + "2d_val": [17.0], }, - delimiter='\t'), + 
delimiter="\t", + ), dict( - testcase_name='multiple_columns_multivalent', - columns=[ - 'numeric1', 'category1', 'idx', 'numeric2', 'value', 'numeric3' - ], + testcase_name="multiple_columns_multivalent", + columns=["numeric1", "category1", "idx", "numeric2", "value", "numeric3"], feature_spec={ - 'numeric1': tf.io.FixedLenFeature([2], tf.int64), - 'numeric2': tf.io.VarLenFeature(tf.float32), - 'numeric3': tf.io.FixedLenFeature([1], tf.int64), - 'category1': tf.io.VarLenFeature(tf.string), - 'y': tf.io.SparseFeature('idx', 'value', tf.float32, 10), + "numeric1": tf.io.FixedLenFeature([2], tf.int64), + "numeric2": tf.io.VarLenFeature(tf.float32), + "numeric3": tf.io.FixedLenFeature([1], tf.int64), + "category1": tf.io.VarLenFeature(tf.string), + "y": tf.io.SparseFeature("idx", "value", tf.float32, 10), }, - csv_line=('11|12,categorical_value|other_value,1|3,89.0|91.0,' - '12.0|15.0,20'), + csv_line=("11|12,categorical_value|other_value,1|3,89.0|91.0," "12.0|15.0,20"), instance={ - 'category1': [b'categorical_value|other_value'], - 'numeric1': [11, 12], - 'numeric2': [89.0, 91.0], - 'numeric3': [20], - 'idx': [1, 3], - 'value': [12.0, 15.0], + "category1": [b"categorical_value|other_value"], + "numeric1": [11, 12], + "numeric2": [89.0, 91.0], + "numeric3": [20], + "idx": [1, 3], + "value": [12.0, 15.0], }, - secondary_delimiter='|', - multivalent_columns=['numeric1', 'numeric2', 'y']), + secondary_delimiter="|", + multivalent_columns=["numeric1", "numeric2", "y"], + ), dict( - testcase_name='scalar_int', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}, - csv_line='12', - instance={'x': 12}), + testcase_name="scalar_int", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, + csv_line="12", + instance={"x": 12}, + ), dict( - testcase_name='scalar_float', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}, - csv_line='12', - instance={'x': 12}), + testcase_name="scalar_float", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, + csv_line="12", + instance={"x": 12}, + ), dict( - testcase_name='size_1_vector_int', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([1], tf.int64)}, - csv_line='12', - instance={'x': [12]}), + testcase_name="size_1_vector_int", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([1], tf.int64)}, + csv_line="12", + instance={"x": [12]}, + ), dict( - testcase_name='1x1_matrix_int', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([1, 1], tf.int64)}, - csv_line='12', - instance={'x': [[12]]}), + testcase_name="1x1_matrix_int", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([1, 1], tf.int64)}, + csv_line="12", + instance={"x": [[12]]}, + ), dict( - testcase_name='unquoted_text', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}, - csv_line='this is unquoted text', - instance={'x': b'this is unquoted text'}), + testcase_name="unquoted_text", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([], tf.string)}, + csv_line="this is unquoted text", + instance={"x": b"this is unquoted text"}, + ), dict( - testcase_name='quoted_text', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}, + testcase_name="quoted_text", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([], tf.string)}, csv_line='"this is a ,text"', - instance={'x': b'this is a ,text'}), + instance={"x": b"this is a ,text"}, + ), dict( - testcase_name='var_len_text', - columns=['x'], - feature_spec={'x': 
tf.io.VarLenFeature(tf.string)}, - csv_line='a test', - instance={'x': [b'a test']}), + testcase_name="var_len_text", + columns=["x"], + feature_spec={"x": tf.io.VarLenFeature(tf.string)}, + csv_line="a test", + instance={"x": [b"a test"]}, + ), dict( - testcase_name='sparse_float_one_value', - columns=['idx', 'value'], - feature_spec={'x': tf.io.SparseFeature('idx', 'value', tf.float32, 10)}, - csv_line='5,2.0', - instance={ - 'idx': [5], - 'value': [2.0] - }), + testcase_name="sparse_float_one_value", + columns=["idx", "value"], + feature_spec={"x": tf.io.SparseFeature("idx", "value", tf.float32, 10)}, + csv_line="5,2.0", + instance={"idx": [5], "value": [2.0]}, + ), dict( - testcase_name='sparse_float_no_values', - columns=['idx', 'value'], - feature_spec={'x': tf.io.SparseFeature('idx', 'value', tf.float32, 10)}, - csv_line=',', - instance={ - 'idx': [], - 'value': [] - }), + testcase_name="sparse_float_no_values", + columns=["idx", "value"], + feature_spec={"x": tf.io.SparseFeature("idx", "value", tf.float32, 10)}, + csv_line=",", + instance={"idx": [], "value": []}, + ), dict( - testcase_name='size_2_vector_int_multivalent', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([2], tf.int64)}, - csv_line='12|14', - instance={'x': [12, 14]}, - secondary_delimiter='|', - multivalent_columns=['x']), + testcase_name="size_2_vector_int_multivalent", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([2], tf.int64)}, + csv_line="12|14", + instance={"x": [12, 14]}, + secondary_delimiter="|", + multivalent_columns=["x"], + ), dict( - testcase_name='2x2_matrix_int_multivalent', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([2, 2], tf.int64)}, - csv_line='12|13|14|15', - instance={'x': [[12, 13], [14, 15]]}, - secondary_delimiter='|', - multivalent_columns=['x']), + testcase_name="2x2_matrix_int_multivalent", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([2, 2], tf.int64)}, + csv_line="12|13|14|15", + instance={"x": [[12, 13], [14, 15]]}, + secondary_delimiter="|", + multivalent_columns=["x"], + ), ] _CONSTRUCTOR_ERROR_CASES = [ dict( - testcase_name='missing_column', + testcase_name="missing_column", columns=[], - feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}, - error_msg='Column not found: '), + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, + error_msg="Column not found: ", + ), ] _ENCODE_ERROR_CASES = [ dict( - testcase_name='multivalent_size_2_vector_3_values', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([2], tf.string)}, - instance={'x': [1, 2, 3]}, - error_msg=r'FixedLenFeature \"x\" got wrong number of values', - secondary_delimiter='|', - multivalent_columns=['x']), + testcase_name="multivalent_size_2_vector_3_values", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([2], tf.string)}, + instance={"x": [1, 2, 3]}, + error_msg=r"FixedLenFeature \"x\" got wrong number of values", + secondary_delimiter="|", + multivalent_columns=["x"], + ), dict( - testcase_name='multivalent_size_2_vector_1_value', - columns=['x'], - feature_spec={'x': tf.io.FixedLenFeature([2], tf.string)}, - instance={'x': [1]}, - error_msg=r'FixedLenFeature \"x\" got wrong number of values', - secondary_delimiter='|', - multivalent_columns=['x']), + testcase_name="multivalent_size_2_vector_1_value", + columns=["x"], + feature_spec={"x": tf.io.FixedLenFeature([2], tf.string)}, + instance={"x": [1]}, + error_msg=r"FixedLenFeature \"x\" got wrong number of values", + secondary_delimiter="|", + multivalent_columns=["x"], + ), ] class 
TestCSVCoder(test_case.TransformTestCase): + @test_case.named_parameters(*_ENCODE_CASES) + def test_encode(self, columns, feature_spec, csv_line, instance, **kwargs): + schema = schema_utils.schema_from_feature_spec(feature_spec) + coder = csv_coder.CsvCoder(columns, schema, **kwargs) + self.assertEqual(coder.encode(instance), csv_line.encode("utf-8")) - @test_case.named_parameters(*_ENCODE_CASES) - def test_encode(self, columns, feature_spec, csv_line, instance, **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - coder = csv_coder.CsvCoder(columns, schema, **kwargs) - self.assertEqual(coder.encode(instance), csv_line.encode('utf-8')) - - @test_case.named_parameters(*_CONSTRUCTOR_ERROR_CASES) - def test_constructor_error(self, - columns, - feature_spec, - error_msg, - error_type=ValueError, - **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - with self.assertRaisesRegex(error_type, error_msg): - csv_coder.CsvCoder(columns, schema, **kwargs) + @test_case.named_parameters(*_CONSTRUCTOR_ERROR_CASES) + def test_constructor_error( + self, columns, feature_spec, error_msg, error_type=ValueError, **kwargs + ): + schema = schema_utils.schema_from_feature_spec(feature_spec) + with self.assertRaisesRegex(error_type, error_msg): + csv_coder.CsvCoder(columns, schema, **kwargs) - @test_case.named_parameters(*_ENCODE_ERROR_CASES) - def test_encode_error(self, - columns, - feature_spec, - instance, - error_msg, - error_type=ValueError, - **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - coder = csv_coder.CsvCoder(columns, schema, **kwargs) - with self.assertRaisesRegex(error_type, error_msg): - coder.encode(instance) + @test_case.named_parameters(*_ENCODE_ERROR_CASES) + def test_encode_error( + self, + columns, + feature_spec, + instance, + error_msg, + error_type=ValueError, + **kwargs, + ): + schema = schema_utils.schema_from_feature_spec(feature_spec) + coder = csv_coder.CsvCoder(columns, schema, **kwargs) + with self.assertRaisesRegex(error_type, error_msg): + coder.encode(instance) - def test_picklable(self): - csv_line = '12,"this is a ,text",categorical_value,1,89.0,12.0,20,1,7,17.0' - instance = { - 'category1': [b'categorical_value'], - 'numeric1': 12, - 'numeric2': [89.0], - 'numeric3': [20], - 'text1': b'this is a ,text', - 'idx': [1], - 'value': [12.0], - '2d_idx0': [1], - '2d_idx1': [7], - '2d_val': [17.0], - } - schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC) - coder = csv_coder.CsvCoder(_COLUMNS, schema) - # Repeat twice to ensure the act of encoding/decoding doesn't break - # pickling. - for _ in range(2): - coder = pickle.loads(pickle.dumps(coder)) - self.assertEqual(coder.encode(instance), csv_line.encode('utf-8')) + def test_picklable(self): + csv_line = '12,"this is a ,text",categorical_value,1,89.0,12.0,20,1,7,17.0' + instance = { + "category1": [b"categorical_value"], + "numeric1": 12, + "numeric2": [89.0], + "numeric3": [20], + "text1": b"this is a ,text", + "idx": [1], + "value": [12.0], + "2d_idx0": [1], + "2d_idx1": [7], + "2d_val": [17.0], + } + schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC) + coder = csv_coder.CsvCoder(_COLUMNS, schema) + # Repeat twice to ensure the act of encoding/decoding doesn't break + # pickling. 
+ for _ in range(2): + coder = pickle.loads(pickle.dumps(coder)) + self.assertEqual(coder.encode(instance), csv_line.encode("utf-8")) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/coders/example_proto_coder.py b/tensorflow_transform/coders/example_proto_coder.py index 1ba7e50..b6575c2 100644 --- a/tensorflow_transform/coders/example_proto_coder.py +++ b/tensorflow_transform/coders/example_proto_coder.py @@ -18,6 +18,7 @@ import numpy as np import tensorflow as tf + from tensorflow_transform.tf_metadata import schema_utils @@ -25,235 +26,259 @@ # the protocol buffer library installed in the workers (which might be different # from the one installed in the pipeline constructor). def _make_cast_fn(np_dtype): - """Return a function to extract the typed value from the feature. + """Return a function to extract the typed value from the feature. - For performance reasons it is preferred to have the cast fn - constructed once (for each handler). + For performance reasons it is preferred to have the cast fn + constructed once (for each handler). - Args: - np_dtype: The numpy type of the Tensorflow feature. + Args: + ---- + np_dtype: The numpy type of the Tensorflow feature. - Returns: - A function to extract the value field from a string depending on dtype. - """ + Returns: + ------- + A function to extract the value field from a string depending on dtype. + """ - def identity(x): - return x + def identity(x): + return x - # This is in agreement with Tensorflow conversions for Unicode values for both - # Python 2 and 3 (and also works for non-Unicode objects). It is also in - # agreement with the testWithUnicode of the Beam impl. - def utf8(s): - return s if isinstance(s, bytes) else s.encode('utf-8') + # This is in agreement with Tensorflow conversions for Unicode values for both + # Python 2 and 3 (and also works for non-Unicode objects). It is also in + # agreement with the testWithUnicode of the Beam impl. + def utf8(s): + return s if isinstance(s, bytes) else s.encode("utf-8") - vectorize = np.vectorize(utf8) + vectorize = np.vectorize(utf8) - def string_cast(x): - if isinstance(x, list) or isinstance(x, np.ndarray) and x.ndim > 0: - return map(utf8, x) - elif isinstance(x, np.ndarray): - return vectorize(x).tolist() - return utf8(x) + def string_cast(x): + if isinstance(x, list) or isinstance(x, np.ndarray) and x.ndim > 0: + return map(utf8, x) + elif isinstance(x, np.ndarray): + return vectorize(x).tolist() + return utf8(x) - if issubclass(np_dtype, np.floating) or issubclass(np_dtype, np.integer): - return identity + if issubclass(np_dtype, np.floating) or issubclass(np_dtype, np.integer): + return identity - return string_cast + return string_cast def _make_feature_value_fn(dtype): - """Return a function to extract the typed value from the feature. + """Return a function to extract the typed value from the feature. - For performance reasons it is preferred to have the feature value fn - constructed once (for each handler). + For performance reasons it is preferred to have the feature value fn + constructed once (for each handler). - Args: - dtype: The type of the Tensorflow feature. + Args: + ---- + dtype: The type of the Tensorflow feature. - Returns: - A function to extract the value field from the feature depending on dtype. - """ - if dtype.is_integer: - return lambda feature: feature.int64_list.value + Returns: + ------- + A function to extract the value field from the feature depending on dtype. 
+ """ + if dtype.is_integer: + return lambda feature: feature.int64_list.value - if dtype.is_floating: - return lambda feature: feature.float_list.value + if dtype.is_floating: + return lambda feature: feature.float_list.value - return lambda feature: feature.bytes_list.value + return lambda feature: feature.bytes_list.value class _FixedLenFeatureHandler: - """Handler for `FixedLenFeature` values. - - `FixedLenFeature` values will be parsed to a list of the corresponding - dtype. - """ - - def __init__(self, name, feature_spec): - self._name = name - self._np_dtype = feature_spec.dtype.as_numpy_dtype - self._value_fn = _make_feature_value_fn(feature_spec.dtype) - self._rank = len(feature_spec.shape) - self._size = 1 - for dim in feature_spec.shape: - self._size *= dim - - @property - def name(self): - """The name of the feature.""" - return self._name - - def initialize_encode_cache(self, example): - """Initialize fields (performance caches) that point to example's state.""" - self._cast_fn = _make_cast_fn(self._np_dtype) - self._value = self._value_fn(example.features.feature[self._name]) - - def encode_value(self, values): - """Encodes a feature into its Example proto representation.""" - del self._value[:] - if self._rank == 0: - scalar_value = values if not isinstance(values, - np.ndarray) else values.item() - self._value.append(self._cast_fn(scalar_value)) - else: - flattened_values = ( - values if self._rank == 1 else np.asarray( - values, dtype=self._np_dtype).reshape(-1)) - if len(flattened_values) != self._size: - raise ValueError('FixedLenFeature %r got wrong number of values. ' - 'Expected %d but got %d' % - (self._name, self._size, len(flattened_values))) - self._value.extend(self._cast_fn(flattened_values)) + """Handler for `FixedLenFeature` values. + `FixedLenFeature` values will be parsed to a list of the corresponding + dtype. + """ -class _VarLenFeatureHandler: - """Handler for `VarLenFeature` values. - - `VarLenFeature` values will be parsed as an array of the corresponding dtype. - """ - - def __init__(self, name, dtype): - self._name = name - self._np_dtype = dtype.as_numpy_dtype - self._value_fn = _make_feature_value_fn(dtype) - - @property - def name(self): - """The name of the feature.""" - return self._name - - def initialize_encode_cache(self, example): - """Initialize fields (performance caches) that point to example's state.""" - self._cast_fn = _make_cast_fn(self._np_dtype) - self._feature = example.features.feature[self._name] - self._value = self._value_fn(self._feature) - - def encode_value(self, values): - """Encode values as tf.train.Feature.""" - if values is None: - self._feature.Clear() - # Note after Clear(), self._value no longer points to a submessage of - # self._feature so we need to reset it. - self._value = self._value_fn(self._feature) - else: - del self._value[:] - - # Scalar must be length 1 array. 
- values = values if isinstance(values, (list, np.ndarray)) else [values] - casted = self._cast_fn(values) - self._value.extend(casted) + def __init__(self, name, feature_spec): + self._name = name + self._np_dtype = feature_spec.dtype.as_numpy_dtype + self._value_fn = _make_feature_value_fn(feature_spec.dtype) + self._rank = len(feature_spec.shape) + self._size = 1 + for dim in feature_spec.shape: + self._size *= dim + + @property + def name(self): + """The name of the feature.""" + return self._name + + def initialize_encode_cache(self, example): + """Initialize fields (performance caches) that point to example's state.""" + self._cast_fn = _make_cast_fn(self._np_dtype) + self._value = self._value_fn(example.features.feature[self._name]) + + def encode_value(self, values): + """Encodes a feature into its Example proto representation.""" + del self._value[:] + if self._rank == 0: + scalar_value = ( + values if not isinstance(values, np.ndarray) else values.item() + ) + self._value.append(self._cast_fn(scalar_value)) + else: + flattened_values = ( + values + if self._rank == 1 + else np.asarray(values, dtype=self._np_dtype).reshape(-1) + ) + if len(flattened_values) != self._size: + raise ValueError( + "FixedLenFeature %r got wrong number of values. " + "Expected %d but got %d" + % (self._name, self._size, len(flattened_values)) + ) + self._value.extend(self._cast_fn(flattened_values)) -class ExampleProtoCoder: - """A coder between maybe-serialized TF Examples and tf.Transform datasets.""" +class _VarLenFeatureHandler: + """Handler for `VarLenFeature` values. - def __init__(self, schema, serialized=True): - """Build an ExampleProtoCoder. + `VarLenFeature` values will be parsed as an array of the corresponding dtype. + """ - Args: - schema: A `Schema` proto. - serialized: Whether to encode serialized Example protos (as opposed to - in-memory Example protos). + def __init__(self, name, dtype): + self._name = name + self._np_dtype = dtype.as_numpy_dtype + self._value_fn = _make_feature_value_fn(dtype) + + @property + def name(self): + """The name of the feature.""" + return self._name + + def initialize_encode_cache(self, example): + """Initialize fields (performance caches) that point to example's state.""" + self._cast_fn = _make_cast_fn(self._np_dtype) + self._feature = example.features.feature[self._name] + self._value = self._value_fn(self._feature) + + def encode_value(self, values): + """Encode values as tf.train.Feature.""" + if values is None: + self._feature.Clear() + # Note after Clear(), self._value no longer points to a submessage of + # self._feature so we need to reset it. + self._value = self._value_fn(self._feature) + else: + del self._value[:] + + # Scalar must be length 1 array. + values = values if isinstance(values, (list, np.ndarray)) else [values] + casted = self._cast_fn(values) + self._value.extend(casted) - Raises: - ValueError: If `schema` is invalid. - """ - self._schema = schema - self._serialized = serialized - - # Using pre-allocated tf.train.Example and FeatureHandler objects for - # performance reasons. - # - # Since the output of "encode" is deep as opposed to shallow - # transformations, and since the schema always fully defines the Example's - # FeatureMap (ie all fields are always cleared/assigned or copied), the - # optimization and implementation are correct and thread-compatible. 
- self._encode_example_cache = tf.train.Example() - self._feature_handlers = [] - for name, feature_spec in schema_utils.schema_as_feature_spec( - schema).feature_spec.items(): - if isinstance(feature_spec, tf.io.FixedLenFeature): - self._feature_handlers.append( - _FixedLenFeatureHandler(name, feature_spec)) - elif isinstance(feature_spec, tf.io.VarLenFeature): - self._feature_handlers.append( - _VarLenFeatureHandler(name, feature_spec.dtype)) - elif isinstance(feature_spec, tf.io.SparseFeature): - index_keys = ( - feature_spec.index_key if isinstance(feature_spec.index_key, list) - else [feature_spec.index_key]) - for index_key in index_keys: - self._feature_handlers.append( - _VarLenFeatureHandler(index_key, tf.int64)) - self._feature_handlers.append( - _VarLenFeatureHandler(feature_spec.value_key, feature_spec.dtype)) - elif isinstance(feature_spec, tf.io.RaggedFeature): - uniform_partition = False - for partition in feature_spec.partitions: - if isinstance(partition, tf.io.RaggedFeature.RowLengths): - if uniform_partition: - raise ValueError( - 'Encountered ragged dimension after uniform for feature ' - '"{}": only inner dimensions can be uniform. Feature spec ' - 'is {}'.format(name, feature_spec)) - self._feature_handlers.append( - _VarLenFeatureHandler(partition.key, tf.int64)) - elif isinstance(partition, tf.io.RaggedFeature.UniformRowLength): - # We don't encode uniform partitions since they can be recovered - # from the shape information. - uniform_partition = True - else: - raise ValueError( - 'Only `RowLengths` and `UniformRowLength` partitions of ragged ' - 'features are supported, got {}'.format(type(partition))) - self._feature_handlers.append( - _VarLenFeatureHandler(feature_spec.value_key, feature_spec.dtype)) - else: - raise ValueError('feature_spec should be one of tf.io.FixedLenFeature, ' - 'tf.io.VarLenFeature, tf.io.SparseFeature or ' - 'tf.io.RaggedFeature: "{}" was {}'.format( - name, type(feature_spec))) - - for feature_handler in self._feature_handlers: - feature_handler.initialize_encode_cache(self._encode_example_cache) - - def __reduce__(self): - return self.__class__, (self._schema, self._serialized) - - def encode(self, instance): - """Encode a tf.transform encoded dict as tf.Example.""" - # The feature handles encode using the self._encode_example_cache. - for feature_handler in self._feature_handlers: - value = instance[feature_handler.name] - try: - feature_handler.encode_value(value) - except TypeError as e: - raise TypeError('%s while encoding feature "%s"' % - (e, feature_handler.name)) - - if self._serialized: - return self._encode_example_cache.SerializeToString() - - result = tf.train.Example() - result.CopyFrom(self._encode_example_cache) - return result + +class ExampleProtoCoder: + """A coder between maybe-serialized TF Examples and tf.Transform datasets.""" + + def __init__(self, schema, serialized=True): + """Build an ExampleProtoCoder. + + Args: + ---- + schema: A `Schema` proto. + serialized: Whether to encode serialized Example protos (as opposed to + in-memory Example protos). + + Raises: + ------ + ValueError: If `schema` is invalid. + """ + self._schema = schema + self._serialized = serialized + + # Using pre-allocated tf.train.Example and FeatureHandler objects for + # performance reasons. 
+ # + # Since the output of "encode" is deep as opposed to shallow + # transformations, and since the schema always fully defines the Example's + # FeatureMap (ie all fields are always cleared/assigned or copied), the + # optimization and implementation are correct and thread-compatible. + self._encode_example_cache = tf.train.Example() + self._feature_handlers = [] + for name, feature_spec in schema_utils.schema_as_feature_spec( + schema + ).feature_spec.items(): + if isinstance(feature_spec, tf.io.FixedLenFeature): + self._feature_handlers.append( + _FixedLenFeatureHandler(name, feature_spec) + ) + elif isinstance(feature_spec, tf.io.VarLenFeature): + self._feature_handlers.append( + _VarLenFeatureHandler(name, feature_spec.dtype) + ) + elif isinstance(feature_spec, tf.io.SparseFeature): + index_keys = ( + feature_spec.index_key + if isinstance(feature_spec.index_key, list) + else [feature_spec.index_key] + ) + for index_key in index_keys: + self._feature_handlers.append( + _VarLenFeatureHandler(index_key, tf.int64) + ) + self._feature_handlers.append( + _VarLenFeatureHandler(feature_spec.value_key, feature_spec.dtype) + ) + elif isinstance(feature_spec, tf.io.RaggedFeature): + uniform_partition = False + for partition in feature_spec.partitions: + if isinstance(partition, tf.io.RaggedFeature.RowLengths): + if uniform_partition: + raise ValueError( + "Encountered ragged dimension after uniform for feature " + f'"{name}": only inner dimensions can be uniform. Feature spec ' + f"is {feature_spec}" + ) + self._feature_handlers.append( + _VarLenFeatureHandler(partition.key, tf.int64) + ) + elif isinstance(partition, tf.io.RaggedFeature.UniformRowLength): + # We don't encode uniform partitions since they can be recovered + # from the shape information. + uniform_partition = True + else: + raise ValueError( + "Only `RowLengths` and `UniformRowLength` partitions of ragged " + f"features are supported, got {type(partition)}" + ) + self._feature_handlers.append( + _VarLenFeatureHandler(feature_spec.value_key, feature_spec.dtype) + ) + else: + raise ValueError( + "feature_spec should be one of tf.io.FixedLenFeature, " + "tf.io.VarLenFeature, tf.io.SparseFeature or " + f'tf.io.RaggedFeature: "{name}" was {type(feature_spec)}' + ) + + for feature_handler in self._feature_handlers: + feature_handler.initialize_encode_cache(self._encode_example_cache) + + def __reduce__(self): + return self.__class__, (self._schema, self._serialized) + + def encode(self, instance): + """Encode a tf.transform encoded dict as tf.Example.""" + # The feature handles encode using the self._encode_example_cache. + for feature_handler in self._feature_handlers: + value = instance[feature_handler.name] + try: + feature_handler.encode_value(value) + except TypeError as e: + raise TypeError( + '%s while encoding feature "%s"' % (e, feature_handler.name) + ) + + if self._serialized: + return self._encode_example_cache.SerializeToString() + + result = tf.train.Example() + result.CopyFrom(self._encode_example_cache) + return result diff --git a/tensorflow_transform/coders/example_proto_coder_test.py b/tensorflow_transform/coders/example_proto_coder_test.py index e22765f..512d2fd 100644 --- a/tensorflow_transform/coders/example_proto_coder_test.py +++ b/tensorflow_transform/coders/example_proto_coder_test.py @@ -22,114 +22,105 @@ # Note that this needs to happen before any non-python imports, so we do it # pretty early on. 
-if any(arg == '--proto_implementation_type=python' for arg in sys.argv): - os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' -elif any(arg == '--proto_implementation_type=cpp' for arg in sys.argv): - os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' - os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' -elif any(arg.startswith('--proto_implementation_type') for arg in sys.argv): - raise ValueError('Unexpected value for --proto_implementation_type') +if any(arg == "--proto_implementation_type=python" for arg in sys.argv): + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +elif any(arg == "--proto_implementation_type=cpp" for arg in sys.argv): + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" + os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" +elif any(arg.startswith("--proto_implementation_type") for arg in sys.argv): + raise ValueError("Unexpected value for --proto_implementation_type") # pylint: disable=g-import-not-at-top import numpy as np import tensorflow as tf -from tensorflow_transform.coders import example_proto_coder +from google.protobuf import text_format +from google.protobuf.internal import api_implementation + from tensorflow_transform import test_case +from tensorflow_transform.coders import example_proto_coder from tensorflow_transform.tf_metadata import schema_utils -from google.protobuf.internal import api_implementation -from google.protobuf import text_format # pylint: enable=g-import-not-at-top flags.DEFINE_string( - 'proto_implementation_type', 'cpp', - 'The implementation type of python proto to use when exercising this test') + "proto_implementation_type", + "cpp", + "The implementation type of python proto to use when exercising this test", +) _FEATURE_SPEC = { - 'scalar_feature_1': - tf.io.FixedLenFeature([], tf.int64), - 'scalar_feature_2': - tf.io.FixedLenFeature([], tf.int64), - 'scalar_feature_3': - tf.io.FixedLenFeature([], tf.float32), - 'varlen_feature_1': - tf.io.VarLenFeature(tf.float32), - 'varlen_feature_2': - tf.io.VarLenFeature(tf.string), - '1d_vector_feature': - tf.io.FixedLenFeature([1], tf.string), - '2d_vector_feature': - tf.io.FixedLenFeature([2, 2], tf.float32), - 'sparse_feature': - tf.io.SparseFeature('sparse_idx', 'sparse_val', tf.float32, 10), - '2d_sparse_feature': - tf.io.SparseFeature(['2d_sparse_idx0', '2d_sparse_idx1'], - '2d_sparse_val', tf.float32, [2, 10]), - 'ragged_feature': - tf.io.RaggedFeature( - tf.float32, - value_key='ragged_val', - partitions=[tf.io.RaggedFeature.RowLengths('ragged_row_lengths1')]), - '2d_ragged_feature': - tf.io.RaggedFeature( - tf.string, - value_key='2d_ragged_val', - partitions=[ - tf.io.RaggedFeature.RowLengths('2d_ragged_row_lengths1'), - tf.io.RaggedFeature.RowLengths('2d_ragged_row_lengths2') - ]), - 'ragged_uniform_feature': - tf.io.RaggedFeature( - tf.int64, - value_key='ragged_uniform_val', - partitions=[tf.io.RaggedFeature.UniformRowLength(2)]), - '2d_ragged_uniform_feature': - tf.io.RaggedFeature( - tf.int64, - value_key='2d_ragged_uniform_val', - partitions=[ - tf.io.RaggedFeature.RowLengths( - '2d_ragged_uniform_row_lengths1'), - tf.io.RaggedFeature.UniformRowLength(2) - ]), + "scalar_feature_1": tf.io.FixedLenFeature([], tf.int64), + "scalar_feature_2": tf.io.FixedLenFeature([], tf.int64), + "scalar_feature_3": tf.io.FixedLenFeature([], tf.float32), + "varlen_feature_1": tf.io.VarLenFeature(tf.float32), + "varlen_feature_2": tf.io.VarLenFeature(tf.string), + "1d_vector_feature": tf.io.FixedLenFeature([1], 
tf.string), + "2d_vector_feature": tf.io.FixedLenFeature([2, 2], tf.float32), + "sparse_feature": tf.io.SparseFeature("sparse_idx", "sparse_val", tf.float32, 10), + "2d_sparse_feature": tf.io.SparseFeature( + ["2d_sparse_idx0", "2d_sparse_idx1"], "2d_sparse_val", tf.float32, [2, 10] + ), + "ragged_feature": tf.io.RaggedFeature( + tf.float32, + value_key="ragged_val", + partitions=[tf.io.RaggedFeature.RowLengths("ragged_row_lengths1")], + ), + "2d_ragged_feature": tf.io.RaggedFeature( + tf.string, + value_key="2d_ragged_val", + partitions=[ + tf.io.RaggedFeature.RowLengths("2d_ragged_row_lengths1"), + tf.io.RaggedFeature.RowLengths("2d_ragged_row_lengths2"), + ], + ), + "ragged_uniform_feature": tf.io.RaggedFeature( + tf.int64, + value_key="ragged_uniform_val", + partitions=[tf.io.RaggedFeature.UniformRowLength(2)], + ), + "2d_ragged_uniform_feature": tf.io.RaggedFeature( + tf.int64, + value_key="2d_ragged_uniform_val", + partitions=[ + tf.io.RaggedFeature.RowLengths("2d_ragged_uniform_row_lengths1"), + tf.io.RaggedFeature.UniformRowLength(2), + ], + ), } _ENCODE_CASES = { - 'unicode': - dict( - testcase_name='unicode', - feature_spec={ - 'unicode_feature': tf.io.FixedLenFeature([], tf.string) - }, - ascii_proto="""\ + "unicode": dict( + testcase_name="unicode", + feature_spec={"unicode_feature": tf.io.FixedLenFeature([], tf.string)}, + ascii_proto="""\ features { feature { key: "unicode_feature" value { bytes_list { value: [ "Hello κόσμε" ] } } } }""", - instance={'unicode_feature': u'Hello κόσμε'}), - 'scalar_string_to_varlen': - dict( - testcase_name='scalar_string_to_varlen', - feature_spec={'varlen_string': tf.io.VarLenFeature(tf.string)}, - ascii_proto="""\ + instance={"unicode_feature": "Hello κόσμε"}, + ), + "scalar_string_to_varlen": dict( + testcase_name="scalar_string_to_varlen", + feature_spec={"varlen_string": tf.io.VarLenFeature(tf.string)}, + ascii_proto="""\ features { feature { key: "varlen_string" value { bytes_list { value: [ "foo" ] } } } }""", - instance={'varlen_string': 'foo'}), - 'scalar_int_to_varlen': - dict( - testcase_name='scalar_int_to_varlen', - feature_spec={'varlen_int': tf.io.VarLenFeature(tf.int64)}, - ascii_proto="""\ + instance={"varlen_string": "foo"}, + ), + "scalar_int_to_varlen": dict( + testcase_name="scalar_int_to_varlen", + feature_spec={"varlen_int": tf.io.VarLenFeature(tf.int64)}, + ascii_proto="""\ features { feature { key: "varlen_int" value { int64_list { value: [ 123 ] } } } }""", - instance={'varlen_int': 123}), - 'multiple_columns': - dict( - testcase_name='multiple_columns', - feature_spec=_FEATURE_SPEC, - ascii_proto="""\ + instance={"varlen_int": 123}, + ), + "multiple_columns": dict( + testcase_name="multiple_columns", + feature_spec=_FEATURE_SPEC, + ascii_proto="""\ features { feature { key: "scalar_feature_1" value { int64_list { value: [ 12 ] } } } feature { key: "varlen_feature_1" @@ -150,7 +141,7 @@ feature { key: "2d_sparse_val" value { float_list { value: [ 13.0, 23.0 ] } } } }""", - ragged_ascii_proto=""" + ragged_ascii_proto=""" feature { key: "ragged_val" value { float_list { value: [ 7.0, 13.0, 21.0 ] } } } feature { key: "ragged_row_lengths1" @@ -169,35 +160,35 @@ value { int64_list { value: [ 1, 0, 2 ] } } } } """, - instance={ - 'scalar_feature_1': 12, - 'scalar_feature_2': 12, - 'scalar_feature_3': 1.0, - 'varlen_feature_1': [89.0], - '1d_vector_feature': [b'this is a ,text'], - '2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]], - 'varlen_feature_2': [b'female'], - 'sparse_idx': [1, 4], - 'sparse_val': [12.0, 20.0], - 
'2d_sparse_idx0': [1, 1], - '2d_sparse_idx1': [3, 7], - '2d_sparse_val': [13.0, 23.0], - }, - ragged_instance={ - 'ragged_val': [7.0, 13.0, 21.0], - 'ragged_row_lengths1': [1, 2], - '2d_ragged_val': [b'aa a', b'abc', b'hi'], - '2d_ragged_row_lengths1': [0, 3], - '2d_ragged_row_lengths2': [1, 0, 2], - 'ragged_uniform_val': [1, -1, 2, 1, -1, 2], - '2d_ragged_uniform_val': [1, -1, 2, 1, -1, 2], - '2d_ragged_uniform_row_lengths1': [1, 0, 2], - }), - 'multiple_columns_ndarray': - dict( - testcase_name='multiple_columns_ndarray', - feature_spec=_FEATURE_SPEC, - ascii_proto="""\ + instance={ + "scalar_feature_1": 12, + "scalar_feature_2": 12, + "scalar_feature_3": 1.0, + "varlen_feature_1": [89.0], + "1d_vector_feature": [b"this is a ,text"], + "2d_vector_feature": [[1.0, 2.0], [3.0, 4.0]], + "varlen_feature_2": [b"female"], + "sparse_idx": [1, 4], + "sparse_val": [12.0, 20.0], + "2d_sparse_idx0": [1, 1], + "2d_sparse_idx1": [3, 7], + "2d_sparse_val": [13.0, 23.0], + }, + ragged_instance={ + "ragged_val": [7.0, 13.0, 21.0], + "ragged_row_lengths1": [1, 2], + "2d_ragged_val": [b"aa a", b"abc", b"hi"], + "2d_ragged_row_lengths1": [0, 3], + "2d_ragged_row_lengths2": [1, 0, 2], + "ragged_uniform_val": [1, -1, 2, 1, -1, 2], + "2d_ragged_uniform_val": [1, -1, 2, 1, -1, 2], + "2d_ragged_uniform_row_lengths1": [1, 0, 2], + }, + ), + "multiple_columns_ndarray": dict( + testcase_name="multiple_columns_ndarray", + feature_spec=_FEATURE_SPEC, + ascii_proto="""\ features { feature { key: "scalar_feature_1" value { int64_list { value: [ 13 ] } } } feature { key: "varlen_feature_1" value { float_list { } } } @@ -218,7 +209,7 @@ feature { key: "2d_sparse_val" value { float_list { value: [ 13.0, 23.0 ] } } } }""", - ragged_ascii_proto=""" + ragged_ascii_proto=""" feature { key: "ragged_val" value { float_list { value: [ 22.0, 22.0, 21.0 ] } } } feature { key: "ragged_row_lengths1" @@ -237,174 +228,172 @@ value { int64_list { value: [ 1, 0, 2 ] } } } } """, - instance={ - 'scalar_feature_1': np.array(13), - 'scalar_feature_2': np.int32(214), - 'scalar_feature_3': np.array(2.0), - 'varlen_feature_1': np.array([]), - '1d_vector_feature': np.array([b'this is another ,text']), - '2d_vector_feature': np.array([[9.0, 8.0], [7.0, 6.0]]), - 'varlen_feature_2': np.array([b'male']), - 'sparse_idx': np.array([2, 5]), - 'sparse_val': np.array([13.0, 21.0]), - '2d_sparse_idx0': np.array([1, 1]), - '2d_sparse_idx1': np.array([3, 7]), - '2d_sparse_val': np.array([13.0, 23.0]), - }, - ragged_instance={ - 'ragged_val': np.array([22.0, 22.0, 21.0]), - 'ragged_row_lengths1': np.array([0, 2, 1]), - '2d_ragged_val': np.array([b'oh', b'hello ', b'']), - '2d_ragged_row_lengths1': np.array([1, 2]), - '2d_ragged_row_lengths2': np.array([0, 0, 3]), - 'ragged_uniform_val': np.array([12, -11, 2, 1, -1, 12]), - '2d_ragged_uniform_val': np.array([1, -1, 23, 1, -1, 32]), - '2d_ragged_uniform_row_lengths1': np.array([1, 0, 2]), - }), - 'multiple_columns_with_missing': - dict( - testcase_name='multiple_columns_with_missing', - feature_spec={'varlen_feature': tf.io.VarLenFeature(tf.string)}, - ascii_proto="""\ + instance={ + "scalar_feature_1": np.array(13), + "scalar_feature_2": np.int32(214), + "scalar_feature_3": np.array(2.0), + "varlen_feature_1": np.array([]), + "1d_vector_feature": np.array([b"this is another ,text"]), + "2d_vector_feature": np.array([[9.0, 8.0], [7.0, 6.0]]), + "varlen_feature_2": np.array([b"male"]), + "sparse_idx": np.array([2, 5]), + "sparse_val": np.array([13.0, 21.0]), + "2d_sparse_idx0": np.array([1, 1]), + 
"2d_sparse_idx1": np.array([3, 7]), + "2d_sparse_val": np.array([13.0, 23.0]), + }, + ragged_instance={ + "ragged_val": np.array([22.0, 22.0, 21.0]), + "ragged_row_lengths1": np.array([0, 2, 1]), + "2d_ragged_val": np.array([b"oh", b"hello ", b""]), + "2d_ragged_row_lengths1": np.array([1, 2]), + "2d_ragged_row_lengths2": np.array([0, 0, 3]), + "ragged_uniform_val": np.array([12, -11, 2, 1, -1, 12]), + "2d_ragged_uniform_val": np.array([1, -1, 23, 1, -1, 32]), + "2d_ragged_uniform_row_lengths1": np.array([1, 0, 2]), + }, + ), + "multiple_columns_with_missing": dict( + testcase_name="multiple_columns_with_missing", + feature_spec={"varlen_feature": tf.io.VarLenFeature(tf.string)}, + ascii_proto="""\ features { feature { key: "varlen_feature" value {} } }""", - instance={'varlen_feature': None}), - 'multivariate_string_to_varlen': - dict( - testcase_name='multivariate_string_to_varlen', - feature_spec={'varlen_string': tf.io.VarLenFeature(tf.string)}, - ascii_proto="""\ + instance={"varlen_feature": None}, + ), + "multivariate_string_to_varlen": dict( + testcase_name="multivariate_string_to_varlen", + feature_spec={"varlen_string": tf.io.VarLenFeature(tf.string)}, + ascii_proto="""\ features { feature { key: "varlen_string" value { bytes_list { value: [ "foo", "bar" ] } } } }""", - instance={'varlen_string': [b'foo', b'bar']}), + instance={"varlen_string": [b"foo", b"bar"]}, + ), } _ENCODE_ERROR_CASES = [ dict( - testcase_name='to_few_values', + testcase_name="to_few_values", feature_spec={ - '2d_vector_feature': tf.io.FixedLenFeature([2, 2], tf.int64), + "2d_vector_feature": tf.io.FixedLenFeature([2, 2], tf.int64), }, - instance={'2d_vector_feature': [1, 2, 3]}, - error_msg='got wrong number of values'), + instance={"2d_vector_feature": [1, 2, 3]}, + error_msg="got wrong number of values", + ), dict( - testcase_name='unsupported_ragged_partition_sequence', + testcase_name="unsupported_ragged_partition_sequence", feature_spec={ - '2d_ragged_feature': - tf.io.RaggedFeature( - tf.string, - value_key='2d_ragged_val', - partitions=[ - tf.io.RaggedFeature.UniformRowLength(4), - tf.io.RaggedFeature.RowLengths('2d_ragged_row_lengths1') - ]), + "2d_ragged_feature": tf.io.RaggedFeature( + tf.string, + value_key="2d_ragged_val", + partitions=[ + tf.io.RaggedFeature.UniformRowLength(4), + tf.io.RaggedFeature.RowLengths("2d_ragged_row_lengths1"), + ], + ), }, - instance={'2d_ragged_val': [b'not', b'necessary']}, - error_msg='Encountered ragged dimension after uniform'), + instance={"2d_ragged_val": [b"not", b"necessary"]}, + error_msg="Encountered ragged dimension after uniform", + ), ] def _maybe_extend_encode_case_with_ragged(encode_case): - result = copy.deepcopy(encode_case) - ragged_ascii_proto = result.pop('ragged_ascii_proto', '}') - ragged_instance = result.pop('ragged_instance', {}) - result['ascii_proto'] = (encode_case['ascii_proto'][:-1] + ragged_ascii_proto) - result['instance'].update(ragged_instance) - return result + result = copy.deepcopy(encode_case) + ragged_ascii_proto = result.pop("ragged_ascii_proto", "}") + ragged_instance = result.pop("ragged_instance", {}) + result["ascii_proto"] = encode_case["ascii_proto"][:-1] + ragged_ascii_proto + result["instance"].update(ragged_instance) + return result def _maybe_extend_encode_cases_with_ragged(encode_cases): - for case in encode_cases.values(): - yield _maybe_extend_encode_case_with_ragged(case) + for case in encode_cases.values(): + yield _maybe_extend_encode_case_with_ragged(case) def _ascii_to_example(ascii_proto): - return 
text_format.Merge(ascii_proto, tf.train.Example()) + return text_format.Merge(ascii_proto, tf.train.Example()) def _ascii_to_binary(ascii_proto): - return _ascii_to_example(ascii_proto).SerializeToString() + return _ascii_to_example(ascii_proto).SerializeToString() def _binary_to_example(serialized_proto): - return tf.train.Example.FromString(serialized_proto) + return tf.train.Example.FromString(serialized_proto) class ExampleProtoCoderTest(test_case.TransformTestCase): + def setUp(self): + super().setUp() + # Verify that the implementation we requested via the Flag is honored. + if any(arg.startswith("--proto_implementation_type") for arg in sys.argv): + assert api_implementation.Type() == flags.FLAGS.proto_implementation_type, ( + "Expected proto implementation type " + f'"{flags.FLAGS.proto_implementation_type}", got: ' + f'"{api_implementation.Type()}"' + ) - def setUp(self): - super().setUp() - # Verify that the implementation we requested via the Flag is honored. - if any(arg.startswith('--proto_implementation_type') for arg in sys.argv): - assert ( - api_implementation.Type() == flags.FLAGS.proto_implementation_type - ), ( - 'Expected proto implementation type ' - f'"{flags.FLAGS.proto_implementation_type}", got: ' - f'"{api_implementation.Type()}"' - ) - - def assertSerializedProtosEqual(self, a, b): - np.testing.assert_equal(_binary_to_example(a), _binary_to_example(b)) + def assertSerializedProtosEqual(self, a, b): + np.testing.assert_equal(_binary_to_example(a), _binary_to_example(b)) - @test_case.named_parameters( - *_maybe_extend_encode_cases_with_ragged(_ENCODE_CASES)) - def test_encode(self, feature_spec, ascii_proto, instance, **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs) - serialized_proto = _ascii_to_binary(ascii_proto) - self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) + @test_case.named_parameters(*_maybe_extend_encode_cases_with_ragged(_ENCODE_CASES)) + def test_encode(self, feature_spec, ascii_proto, instance, **kwargs): + schema = schema_utils.schema_from_feature_spec(feature_spec) + coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs) + serialized_proto = _ascii_to_binary(ascii_proto) + self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) - @test_case.named_parameters( - *_maybe_extend_encode_cases_with_ragged(_ENCODE_CASES)) - def test_encode_non_serialized(self, feature_spec, ascii_proto, instance, - **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - coder = example_proto_coder.ExampleProtoCoder( - schema, serialized=False, **kwargs) - proto = _ascii_to_example(ascii_proto) - self.assertProtoEquals(coder.encode(instance), proto) + @test_case.named_parameters(*_maybe_extend_encode_cases_with_ragged(_ENCODE_CASES)) + def test_encode_non_serialized(self, feature_spec, ascii_proto, instance, **kwargs): + schema = schema_utils.schema_from_feature_spec(feature_spec) + coder = example_proto_coder.ExampleProtoCoder( + schema, serialized=False, **kwargs + ) + proto = _ascii_to_example(ascii_proto) + self.assertProtoEquals(coder.encode(instance), proto) - @test_case.named_parameters(*_ENCODE_ERROR_CASES) - def test_encode_error(self, - feature_spec, - instance, - error_msg, - error_type=ValueError, - **kwargs): - schema = schema_utils.schema_from_feature_spec(feature_spec) - with self.assertRaisesRegex(error_type, error_msg): - coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs) - 
coder.encode(instance) + @test_case.named_parameters(*_ENCODE_ERROR_CASES) + def test_encode_error( + self, feature_spec, instance, error_msg, error_type=ValueError, **kwargs + ): + schema = schema_utils.schema_from_feature_spec(feature_spec) + with self.assertRaisesRegex(error_type, error_msg): + coder = example_proto_coder.ExampleProtoCoder(schema, **kwargs) + coder.encode(instance) - def test_example_proto_coder_picklable(self): - encode_case = _maybe_extend_encode_case_with_ragged( - _ENCODE_CASES['multiple_columns']) - schema = schema_utils.schema_from_feature_spec(encode_case['feature_spec']) - coder = example_proto_coder.ExampleProtoCoder(schema) - ascii_proto = encode_case['ascii_proto'] - instance = encode_case['instance'] - serialized_proto = _ascii_to_binary(ascii_proto) - for _ in range(2): - coder = pickle.loads(pickle.dumps(coder)) - self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) + def test_example_proto_coder_picklable(self): + encode_case = _maybe_extend_encode_case_with_ragged( + _ENCODE_CASES["multiple_columns"] + ) + schema = schema_utils.schema_from_feature_spec(encode_case["feature_spec"]) + coder = example_proto_coder.ExampleProtoCoder(schema) + ascii_proto = encode_case["ascii_proto"] + instance = encode_case["instance"] + serialized_proto = _ascii_to_binary(ascii_proto) + for _ in range(2): + coder = pickle.loads(pickle.dumps(coder)) + self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) - def test_example_proto_coder_cache(self): - """Test that the cache remains valid after reading/writing None.""" - schema = schema_utils.schema_from_feature_spec({ - 'varlen': tf.io.VarLenFeature(tf.int64), - }) - coder = example_proto_coder.ExampleProtoCoder(schema) - ascii_protos = [ - 'features {feature {key: "varlen" value {int64_list {value: [5] }}}}', - 'features {feature {key: "varlen" value {}}}', - 'features {feature {key: "varlen" value {int64_list {value: [6] }}}}', - ] - instances = [{'varlen': [5]}, {'varlen': None}, {'varlen': [6]}] - serialized_protos = map(_ascii_to_binary, ascii_protos) - for instance, serialized_proto in zip(instances, serialized_protos): - self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) + def test_example_proto_coder_cache(self): + """Test that the cache remains valid after reading/writing None.""" + schema = schema_utils.schema_from_feature_spec( + { + "varlen": tf.io.VarLenFeature(tf.int64), + } + ) + coder = example_proto_coder.ExampleProtoCoder(schema) + ascii_protos = [ + 'features {feature {key: "varlen" value {int64_list {value: [5] }}}}', + 'features {feature {key: "varlen" value {}}}', + 'features {feature {key: "varlen" value {int64_list {value: [6] }}}}', + ] + instances = [{"varlen": [5]}, {"varlen": None}, {"varlen": [6]}] + serialized_protos = map(_ascii_to_binary, ascii_protos) + for instance, serialized_proto in zip(instances, serialized_protos): + self.assertSerializedProtosEqual(coder.encode(instance), serialized_proto) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/common.py b/tensorflow_transform/common.py index 44a1899..81d70c3 100644 --- a/tensorflow_transform/common.py +++ b/tensorflow_transform/common.py @@ -19,63 +19,68 @@ from typing import Any, Callable, Generator import tensorflow as tf +from tensorflow.python.util import ( + tf_decorator, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.util import tf_decorator # pylint: 
disable=g-direct-tensorflow-import +ANALYZER_COLLECTION = "tft_analyzer_use" +MAPPER_COLLECTION = "tft_mapper_use" -ANALYZER_COLLECTION = 'tft_analyzer_use' -MAPPER_COLLECTION = 'tft_mapper_use' - -ANNOTATION_PREFIX_URL = 'type.googleapis.com' +ANNOTATION_PREFIX_URL = "type.googleapis.com" # TODO(b/132098015): Schema annotations aren't yet supported in OSS builds. try: - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top, unused-import - IS_ANNOTATIONS_PB_AVAILABLE = True + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top, unused-import + ) + + IS_ANNOTATIONS_PB_AVAILABLE = True except ImportError: - IS_ANNOTATIONS_PB_AVAILABLE = False + IS_ANNOTATIONS_PB_AVAILABLE = False _in_logging_context = False @contextlib.contextmanager def logging_context() -> Generator[None, None, None]: - global _in_logging_context - _in_logging_context = True - try: - yield - finally: - _in_logging_context = False + global _in_logging_context + _in_logging_context = True + try: + yield + finally: + _in_logging_context = False def log_api_use( - collection_name: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]: - """Creates a decorator that logs function calls in the tensorflow graph.""" + collection_name: str, +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Creates a decorator that logs function calls in the tensorflow graph.""" - def decorator(fn): - """Logs function calls in a tensorflow graph collection.""" + def decorator(fn): + """Logs function calls in a tensorflow graph collection.""" - @functools.wraps(fn) - def wrapped_fn(*args, **kwargs): - if not _in_logging_context: - with logging_context(): - graph = tf.compat.v1.get_default_graph() - # Collection is a list that contains a single Counter of {name: count} - # Note: We aggregate counts of function calls instead having one - # collection item per call, since TFT users can use an arbitrarily - # large number of analyzers and mappers and we don't want the graph - # to get too big. - # TODO(rachelim): Make this collection serializable so it can be added - # to the SavedModel. - collection = graph.get_collection_ref(collection_name) - if not collection: - collection.append(collections.Counter()) - collection[0][fn.__name__] += 1 - return fn(*args, **kwargs) - else: - return fn(*args, **kwargs) + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + if not _in_logging_context: + with logging_context(): + graph = tf.compat.v1.get_default_graph() + # Collection is a list that contains a single Counter of {name: count} + # Note: We aggregate counts of function calls instead having one + # collection item per call, since TFT users can use an arbitrarily + # large number of analyzers and mappers and we don't want the graph + # to get too big. + # TODO(rachelim): Make this collection serializable so it can be added + # to the SavedModel. + collection = graph.get_collection_ref(collection_name) + if not collection: + collection.append(collections.Counter()) + collection[0][fn.__name__] += 1 + return fn(*args, **kwargs) + else: + return fn(*args, **kwargs) - # We use tf_decorator here so that TF can correctly introspect into - # functions for docstring generation. - return tf_decorator.make_decorator(fn, wrapped_fn) + # We use tf_decorator here so that TF can correctly introspect into + # functions for docstring generation. 
+ return tf_decorator.make_decorator(fn, wrapped_fn) - return decorator + return decorator diff --git a/tensorflow_transform/common_test.py b/tensorflow_transform/common_test.py index df9171a..91d94e1 100644 --- a/tensorflow_transform/common_test.py +++ b/tensorflow_transform/common_test.py @@ -14,60 +14,58 @@ """Tests for tensorflow_transform.common.""" import tensorflow as tf -from tensorflow_transform import common -from tensorflow_transform import test_case +from tensorflow_transform import common, test_case -class CommonTest(test_case.TransformTestCase): - - def testLogAPIUse(self): - @common.log_api_use("test_collection") - def fn0(): - return None +class CommonTest(test_case.TransformTestCase): + def testLogAPIUse(self): + @common.log_api_use("test_collection") + def fn0(): + return None - @common.log_api_use("test_collection") - def fn1(): - return None + @common.log_api_use("test_collection") + def fn1(): + return None - @common.log_api_use("another_collection") - def fn2(): - return None + @common.log_api_use("another_collection") + def fn2(): + return None - with tf.compat.v1.Graph().as_default() as graph: - fn0() - fn1() - fn2() - fn0() - fn0() + with tf.compat.v1.Graph().as_default() as graph: + fn0() + fn1() + fn2() + fn0() + fn0() - self.assertAllEqual([{"fn0": 3, "fn1": 1}], - graph.get_collection("test_collection")) - self.assertAllEqual([{"fn2": 1}], - graph.get_collection("another_collection")) + self.assertAllEqual( + [{"fn0": 3, "fn1": 1}], graph.get_collection("test_collection") + ) + self.assertAllEqual([{"fn2": 1}], graph.get_collection("another_collection")) - def testLogAPIUseWithNestedFunction(self): - """Tests that API call is not logged when called from another logged API.""" + def testLogAPIUseWithNestedFunction(self): + """Tests that API call is not logged when called from another logged API.""" - @common.log_api_use("test_collection") - def fn0(): - fn1() - return fn2() + @common.log_api_use("test_collection") + def fn0(): + fn1() + return fn2() - @common.log_api_use("test_collection") - def fn1(): - return None + @common.log_api_use("test_collection") + def fn1(): + return None - @common.log_api_use("another_collection") - def fn2(): - return None + @common.log_api_use("another_collection") + def fn2(): + return None - with tf.compat.v1.Graph().as_default() as graph: - fn0() + with tf.compat.v1.Graph().as_default() as graph: + fn0() - self.assertEqual([{"fn0": 1}], graph.get_collection("test_collection")) - self.assertAllEqual([], graph.get_collection("another_collection")) + self.assertEqual([{"fn0": 1}], graph.get_collection("test_collection")) + self.assertAllEqual([], graph.get_collection("another_collection")) if __name__ == "__main__": - test_case.main() + test_case.main() diff --git a/tensorflow_transform/common_types.py b/tensorflow_transform/common_types.py index f6a4da2..890cc35 100644 --- a/tensorflow_transform/common_types.py +++ b/tensorflow_transform/common_types.py @@ -13,36 +13,36 @@ # limitations under the License. """Common types in tf.transform.""" -from typing import Any, Dict, Iterable, List, TypeVar, Union, Optional +from typing import Any, Dict, Iterable, List, Optional, TypeVar, Union import numpy as np import tensorflow as tf -from typing_extensions import Literal - from tensorflow_metadata.proto.v0 import schema_pb2 +from typing_extensions import Literal # Demonstrational per-row data formats. 
PrimitiveType = Union[str, bytes, float, int] -InstanceValueType = Optional[ - Union[np.ndarray, np.generic, PrimitiveType, List[Any]] -] +InstanceValueType = Optional[Union[np.ndarray, np.generic, PrimitiveType, List[Any]]] InstanceDictType = Dict[str, InstanceValueType] # TODO(b/185719271): Define BucketBoundariesType at module level of mappers.py. BucketBoundariesType = Union[tf.Tensor, Iterable[Union[int, float]]] -FeatureSpecType = Union[tf.io.FixedLenFeature, tf.io.VarLenFeature, - tf.io.SparseFeature, tf.io.RaggedFeature] +FeatureSpecType = Union[ + tf.io.FixedLenFeature, tf.io.VarLenFeature, tf.io.SparseFeature, tf.io.RaggedFeature +] -DomainType = Union[schema_pb2.IntDomain, schema_pb2.FloatDomain, - schema_pb2.StringDomain] +DomainType = Union[ + schema_pb2.IntDomain, schema_pb2.FloatDomain, schema_pb2.StringDomain +] TensorType = Union[tf.Tensor, tf.SparseTensor, tf.RaggedTensor] ConsistentTensorType = TypeVar( # pylint: disable=invalid-name - 'ConsistentTensorType', tf.Tensor, tf.SparseTensor, tf.RaggedTensor) + "ConsistentTensorType", tf.Tensor, tf.SparseTensor, tf.RaggedTensor +) SparseTensorValueType = Union[tf.SparseTensor, tf.compat.v1.SparseTensorValue] -RaggedTensorValueType = Union[tf.RaggedTensor, - tf.compat.v1.ragged.RaggedTensorValue] -TensorValueType = Union[tf.Tensor, np.ndarray, SparseTensorValueType, - RaggedTensorValueType] +RaggedTensorValueType = Union[tf.RaggedTensor, tf.compat.v1.ragged.RaggedTensorValue] +TensorValueType = Union[ + tf.Tensor, np.ndarray, SparseTensorValueType, RaggedTensorValueType +] TemporaryAnalyzerOutputType = Union[tf.Tensor, tf.saved_model.Asset] -VocabularyFileFormatType = Literal['text', 'tfrecord_gzip'] +VocabularyFileFormatType = Literal["text", "tfrecord_gzip"] diff --git a/tensorflow_transform/experimental/analyzers.py b/tensorflow_transform/experimental/analyzers.py index 4a434e5..582634b 100644 --- a/tensorflow_transform/experimental/analyzers.py +++ b/tensorflow_transform/experimental/analyzers.py @@ -24,17 +24,11 @@ the computation that takes place outside of TensorFlow. 
""" -from typing import Any, Collection, List, Optional, Tuple, Union, Iterable, Sequence +from typing import Any, Collection, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import pyarrow as pa import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import analyzers -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import nodes -from tensorflow_transform import tf_utils from tfx_bsl import sketches # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` @@ -42,126 +36,144 @@ from tfx_bsl.types import tfx_namedtuple from typing_extensions import Protocol +from tensorflow_transform import ( + analyzer_nodes, + analyzers, + common, + common_types, + nodes, + tf_utils, +) + __all__ = [ - 'PTransformAnalyzerCacheCoder', - 'SimpleJsonPTransformAnalyzerCacheCoder', - 'CacheablePTransformAnalyzer', - 'ptransform_analyzer', - 'approximate_vocabulary', + "PTransformAnalyzerCacheCoder", + "SimpleJsonPTransformAnalyzerCacheCoder", + "CacheablePTransformAnalyzer", + "ptransform_analyzer", + "approximate_vocabulary", ] PTransformAnalyzerCacheCoder = analyzer_nodes.CacheCoder SimpleJsonPTransformAnalyzerCacheCoder = analyzer_nodes.JsonNumpyCacheCoder -_APPROXIMATE_VOCAB_FILENAME_PREFIX = 'approx_vocab_' -_APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX = 'approx_vocab_frequency_' +_APPROXIMATE_VOCAB_FILENAME_PREFIX = "approx_vocab_" +_APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX = "approx_vocab_frequency_" class _BeamPTransform(Protocol): - """Pytype for `beam.PTransform` without depending on beam in this module. - """ + """Pytype for `beam.PTransform` without depending on beam in this module.""" - def expand(self, pcol: Any) -> Any: - ... + def expand(self, pcol: Any) -> Any: ... - def default_label(self) -> str: - ... + def default_label(self) -> str: ... # TODO(zoyahav): Add an example for using this API. class CacheablePTransformAnalyzer( tfx_namedtuple.TypedNamedTuple( - 'PTransformCachedAnalyzer', - [('make_accumulators_ptransform', _BeamPTransform), - ('merge_accumulators_ptransform', _BeamPTransform), - ('extract_output_ptransform', _BeamPTransform), - ('cache_coder', PTransformAnalyzerCacheCoder)])): - """A PTransformAnalyzer which enables analyzer cache. - - WARNING: This should only be used if the analyzer can correctly be separated - into make_accumulators, merge_accumulators and extract_output stages. - 1. make_accumulators_ptransform: this is a `beam.PTransform` which maps data - to a more compact mergeable representation (accumulator). Mergeable here - means that it is possible to combine multiple representations produced from - a partition of the dataset into a representation of the entire dataset. - 1. merge_accumulators_ptransform: this is a `beam.PTransform` which operates - on a collection of accumulators, i.e. the results of both the - make_accumulators_ptransform and merge_accumulators_ptransform stages, - and produces a single reduced accumulator. This operation must be - associative and commutative in order to have reliably reproducible results. - 1. extract_output: this is a `beam.PTransform` which operates on the result of - the merge_accumulators_ptransform stage, and produces the outputs of the - analyzer. These outputs must be consistent with the `output_dtypes` and - `output_shapes` provided to `ptransform_analyzer`. 
- - This container also holds a `cache_coder` (`PTransformAnalyzerCacheCoder`) - which can encode outputs and decode the inputs of the - `merge_accumulators_ptransform` stage. - In many cases, `SimpleJsonPTransformAnalyzerCacheCoder` would be sufficient. - - To ensure the correctness of this analyzer, the following must hold: - merge(make({D1, ..., Dn})) == merge({make(D1), ..., make(Dn)}) - """ - __slots__ = () - - -def _apply_analyzer(ptransform: Union[_BeamPTransform, - CacheablePTransformAnalyzer], - *tensor_inputs: common_types.TensorType, - **analyzer_def_kwargs: Any) -> Tuple[tf.Tensor, ...]: - """Applies the analyzer over the whole dataset. - - Args: - ptransform: A class inheriting from analyzer_nodes.AnalyzerDef or - CacheablePTransformAnalyzer that should be applied. - *tensor_inputs: A list of input `Tensor`s, `SparseTensor`s, or - `RaggedTensor`s. - **analyzer_def_kwargs: KW arguments to use when constructing - analyzer_def_cls. - - Returns: - A list of `Tensor`s representing the values of the analysis result. - """ - input_values_node = analyzer_nodes.get_input_tensors_value_nodes( - tensor_inputs) - if isinstance(ptransform, CacheablePTransformAnalyzer): - with tf.compat.v1.name_scope('make_accumulators'): - make_accumulators_value_node = nodes.apply_multi_output_operation( - analyzer_nodes.PTransform, - input_values_node, - ptransform=ptransform.make_accumulators_ptransform, - is_partitionable=True, - **analyzer_def_kwargs) - with tf.compat.v1.name_scope('local_merge_accumulators'): - cached_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.PTransform, - *make_accumulators_value_node, - ptransform=ptransform.merge_accumulators_ptransform, - is_partitionable=True, - cache_coder=ptransform.cache_coder, - **analyzer_def_kwargs) - with tf.compat.v1.name_scope('global_merge_accumulators'): - merge_output_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.PTransform, - *cached_value_nodes, - ptransform=ptransform.merge_accumulators_ptransform, - is_partitionable=False, - **analyzer_def_kwargs) - with tf.compat.v1.name_scope('extract_output'): - output_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.PTransform, - *merge_output_value_nodes, - ptransform=ptransform.extract_output_ptransform, - is_partitionable=False, - **analyzer_def_kwargs) - else: - output_value_nodes = nodes.apply_multi_output_operation( - analyzer_nodes.PTransform, - input_values_node, - ptransform=ptransform, - is_partitionable=False, - **analyzer_def_kwargs) - return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes)) + "PTransformCachedAnalyzer", + [ + ("make_accumulators_ptransform", _BeamPTransform), + ("merge_accumulators_ptransform", _BeamPTransform), + ("extract_output_ptransform", _BeamPTransform), + ("cache_coder", PTransformAnalyzerCacheCoder), + ], + ) +): + """A PTransformAnalyzer which enables analyzer cache. + + WARNING: This should only be used if the analyzer can correctly be separated + into make_accumulators, merge_accumulators and extract_output stages. + 1. make_accumulators_ptransform: this is a `beam.PTransform` which maps data + to a more compact mergeable representation (accumulator). Mergeable here + means that it is possible to combine multiple representations produced from + a partition of the dataset into a representation of the entire dataset. + 1. merge_accumulators_ptransform: this is a `beam.PTransform` which operates + on a collection of accumulators, i.e. 
the results of both the + make_accumulators_ptransform and merge_accumulators_ptransform stages, + and produces a single reduced accumulator. This operation must be + associative and commutative in order to have reliably reproducible results. + 1. extract_output: this is a `beam.PTransform` which operates on the result of + the merge_accumulators_ptransform stage, and produces the outputs of the + analyzer. These outputs must be consistent with the `output_dtypes` and + `output_shapes` provided to `ptransform_analyzer`. + + This container also holds a `cache_coder` (`PTransformAnalyzerCacheCoder`) + which can encode outputs and decode the inputs of the + `merge_accumulators_ptransform` stage. + In many cases, `SimpleJsonPTransformAnalyzerCacheCoder` would be sufficient. + + To ensure the correctness of this analyzer, the following must hold: + merge(make({D1, ..., Dn})) == merge({make(D1), ..., make(Dn)}) + """ + + __slots__ = () + + +def _apply_analyzer( + ptransform: Union[_BeamPTransform, CacheablePTransformAnalyzer], + *tensor_inputs: common_types.TensorType, + **analyzer_def_kwargs: Any, +) -> Tuple[tf.Tensor, ...]: + """Applies the analyzer over the whole dataset. + + Args: + ---- + ptransform: A class inheriting from analyzer_nodes.AnalyzerDef or + CacheablePTransformAnalyzer that should be applied. + *tensor_inputs: A list of input `Tensor`s, `SparseTensor`s, or + `RaggedTensor`s. + **analyzer_def_kwargs: KW arguments to use when constructing + analyzer_def_cls. + + Returns: + ------- + A list of `Tensor`s representing the values of the analysis result. + """ + input_values_node = analyzer_nodes.get_input_tensors_value_nodes(tensor_inputs) + if isinstance(ptransform, CacheablePTransformAnalyzer): + with tf.compat.v1.name_scope("make_accumulators"): + make_accumulators_value_node = nodes.apply_multi_output_operation( + analyzer_nodes.PTransform, + input_values_node, + ptransform=ptransform.make_accumulators_ptransform, + is_partitionable=True, + **analyzer_def_kwargs, + ) + with tf.compat.v1.name_scope("local_merge_accumulators"): + cached_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.PTransform, + *make_accumulators_value_node, + ptransform=ptransform.merge_accumulators_ptransform, + is_partitionable=True, + cache_coder=ptransform.cache_coder, + **analyzer_def_kwargs, + ) + with tf.compat.v1.name_scope("global_merge_accumulators"): + merge_output_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.PTransform, + *cached_value_nodes, + ptransform=ptransform.merge_accumulators_ptransform, + is_partitionable=False, + **analyzer_def_kwargs, + ) + with tf.compat.v1.name_scope("extract_output"): + output_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.PTransform, + *merge_output_value_nodes, + ptransform=ptransform.extract_output_ptransform, + is_partitionable=False, + **analyzer_def_kwargs, + ) + else: + output_value_nodes = nodes.apply_multi_output_operation( + analyzer_nodes.PTransform, + input_values_node, + ptransform=ptransform, + is_partitionable=False, + **analyzer_def_kwargs, + ) + return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes)) # TODO(b/164921571): Support output assets in tfrecord format. 
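# --- Illustrative sketch (hedged): a minimal CacheablePTransformAnalyzer that
# counts input rows, showing the make/merge/extract contract described in the
# docstring above. It assumes `tft.experimental` re-exports the names defined
# in this module and that `apache_beam` is installed; the class names,
# `preprocessing_fn`, and the feature key "x" (a float feature) are
# illustrative assumptions, not part of tensorflow_transform.
import apache_beam as beam
import numpy as np
import tensorflow as tf
import tensorflow_transform as tft


class _MakeCountAccumulators(beam.PTransform):
    # Each input element is a tuple of ndarray batches (one per input tensor);
    # the accumulator here is a plain Python int (the batch size), which a
    # JSON/numpy cache coder can encode.
    def expand(self, pcoll):
        return pcoll | "BatchSize" >> beam.Map(lambda batches: int(batches[0].size))


class _MergeCountAccumulators(beam.PTransform):
    # Summation is associative and commutative, so the required property
    # merge(make(D1 + D2)) == merge({make(D1), make(D2)}) holds, and this stage
    # can safely be applied both locally (per cache span) and globally.
    def expand(self, pcoll):
        return pcoll | "Sum" >> beam.CombineGlobally(sum)


class _ExtractCount(beam.PTransform):
    # The single output must be consistent with the `output_dtypes` and
    # `output_shapes` passed to `ptransform_analyzer`: a scalar int64 here.
    def expand(self, pcoll):
        return pcoll | "ToNdarray" >> beam.Map(lambda n: np.array(n, np.int64))


def preprocessing_fn(inputs):
    # Assumed usage: wire the three stages plus a cache coder into the
    # namedtuple defined above, then hand it to ptransform_analyzer.
    counter = tft.experimental.CacheablePTransformAnalyzer(
        make_accumulators_ptransform=_MakeCountAccumulators(),
        merge_accumulators_ptransform=_MergeCountAccumulators(),
        extract_output_ptransform=_ExtractCount(),
        cache_coder=tft.experimental.SimpleJsonPTransformAnalyzerCacheCoder(),
    )
    (row_count,) = tft.experimental.ptransform_analyzer(
        inputs=[inputs["x"]],
        ptransform=counter,
        output_dtypes=[tf.int64],
        output_shapes=[[]],
    )
    return {"x_over_count": inputs["x"] / tf.cast(row_count, tf.float32)}
# --- end of illustrative sketch ---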
@@ -172,136 +184,147 @@ def ptransform_analyzer( output_dtypes: Collection[tf.dtypes.DType], output_shapes: Collection[List[int]], output_asset_default_values: Optional[Collection[Optional[bytes]]] = None, - name: Optional[str] = None): - # pylint: disable=line-too-long - """Applies a user-provided PTransform over the whole dataset. - - WARNING: This is experimental. - - Note that in order to have asset files copied correctly, any outputs that - represent asset filenames must be added to the `tf.GraphKeys.ASSET_FILEPATHS` - collection by the caller if using Transform's APIs in compat v1 mode. - - Example: - - >>> class MeanPerKey(beam.PTransform): - ... def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]) -> Tuple[beam.PCollection[np.ndarray], beam.PCollection[np.ndarray]]: - ... def extract_output(key_value_pairs): - ... keys, values = zip(*key_value_pairs) - ... return [beam.TaggedOutput('keys', keys), - ... beam.TaggedOutput('values', values)] - ... return tuple( - ... pcoll - ... | 'ZipAndFlatten' >> beam.FlatMap(lambda batches: list(zip(*batches))) - ... | 'MeanPerKey' >> beam.CombinePerKey(beam.combiners.MeanCombineFn()) - ... | 'ToList' >> beam.combiners.ToList() - ... | 'Extract' >> beam.FlatMap(extract_output).with_outputs( - ... 'keys', 'values')) - >>> def preprocessing_fn(inputs): - ... outputs = tft.experimental.ptransform_analyzer( - ... inputs=[inputs['s'], inputs['x']], - ... ptransform=MeanPerKey(), - ... output_dtypes=[tf.string, tf.float32], - ... output_shapes=[[2], [2]]) - ... (keys, means) = outputs - ... mean_a = tf.reshape(tf.gather(means, tf.where(keys == 'a')), []) - ... return { 'x/mean_a': inputs['x'] / mean_a } - >>> raw_data = [dict(x=1, s='a'), dict(x=8, s='b'), dict(x=3, s='a')] - >>> feature_spec = dict( - ... x=tf.io.FixedLenFeature([], tf.float32), - ... s=tf.io.FixedLenFeature([], tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'x/mean_a': 0.5}, {'x/mean_a': 4.0}, {'x/mean_a': 1.5}] - - Args: - inputs: An ordered collection of input `Tensor`s. - ptransform: A Beam PTransform that accepts a Beam PCollection where each - element is a tuple of `ndarray`s. Each element in the tuple contains a - batch of values for the corresponding input tensor of the analyzer and - maintain their shapes and dtypes. - It returns a `PCollection`, or a tuple of `PCollections`, each containing - a single element which is an `ndarray` or a list of primitive types. The - contents of these output `PCollection`s must be consistent with the given - values of `output_dtypes` and `output_shapes`. - It may inherit from `tft_beam.experimental.PTransformAnalyzer` if access - to a temp base directory is needed. - Alternatively, it could be an instance of - `tft.experimental.CacheablePTransformAnalyzer` in order to enable cache - for this analyzer, when analyzer cache is enabled for this pipeline. - output_dtypes: An ordered collection of TensorFlow dtypes of the output of - the analyzer. - output_shapes: An ordered collection of shapes of the output of the - analyzer. Must have the same length as output_dtypes. - output_asset_default_values: (Optional) An ordered collection of optional - `bytes` aligned with output_dtypes/output_shapes. 
Every item in this - collection which is not `None` indicates that the output is a TF asset - path, and its value would be used as the default value of this asset file - prior to analysis. - name: (Optional) Similar to a TF op name. Used to define a unique scope for - this analyzer, which can be used for debugging info. - - Returns: - A list of output `Tensor`s. These will have `dtype` and `shape` as - specified by `output_dtypes` and `output_shapes`. - - Raises: - ValueError: If output_dtypes and output_shapes have different lengths. - """ - # pylint: enable=line-too-long - if len(output_dtypes) != len(output_shapes): - raise ValueError('output_dtypes ({}) and output_shapes ({}) had different' - ' lengths'.format(output_dtypes, output_shapes)) - if output_asset_default_values is not None: - if len(output_asset_default_values) != len(output_dtypes): - raise ValueError( - 'output_dtypes ({}) and output_asset_default_values ({}) had ' - 'different lengths'.format(output_dtypes, - output_asset_default_values)) - output_asset_default_values = [ - analyzer_nodes.TemporaryAssetInfo(value, 'text') - for value in output_asset_default_values - ] - else: - output_asset_default_values = [None] * len(output_dtypes) - with tf.compat.v1.name_scope(name, 'ptransform'): - output_tensor_infos = [ - analyzer_nodes.TensorInfo(dtype, shape, default_asset_content) - for dtype, shape, default_asset_content in zip( - output_dtypes, output_shapes, output_asset_default_values) - ] - return _apply_analyzer( - ptransform, *inputs, output_tensor_info_list=output_tensor_infos) - - -def _get_approx_vocab_filename(vocab_filename: Optional[str], - store_frequency: bool) -> str: - """Returns a sanitized vocabulary filename with appropriate prefix applied. - - Args: - vocab_filename: The file name for the approximate vocabulary file. If None, - the "approximate_vocabulary" scope name in the context of this graph will - be used as the file name. - store_frequency: A bool that is true when the vocabulary for which this - generates a filename stores term frequency. False otherwise. - - Returns: - A valid filename. - """ - if vocab_filename is not None: - prefix = None - elif store_frequency: - prefix = _APPROXIMATE_VOCAB_FILENAME_PREFIX - else: - prefix = _APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX - - # Make the file name path safe. - return analyzers.sanitized_vocab_filename(vocab_filename, prefix=prefix) + name: Optional[str] = None, +): + # pylint: disable=line-too-long + """Applies a user-provided PTransform over the whole dataset. + + WARNING: This is experimental. + + Note that in order to have asset files copied correctly, any outputs that + represent asset filenames must be added to the `tf.GraphKeys.ASSET_FILEPATHS` + collection by the caller if using Transform's APIs in compat v1 mode. + + Example: + ------- + >>> class MeanPerKey(beam.PTransform): + ... def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]) -> Tuple[beam.PCollection[np.ndarray], beam.PCollection[np.ndarray]]: + ... def extract_output(key_value_pairs): + ... keys, values = zip(*key_value_pairs) + ... return [beam.TaggedOutput('keys', keys), + ... beam.TaggedOutput('values', values)] + ... return tuple( + ... pcoll + ... | 'ZipAndFlatten' >> beam.FlatMap(lambda batches: list(zip(*batches))) + ... | 'MeanPerKey' >> beam.CombinePerKey(beam.combiners.MeanCombineFn()) + ... | 'ToList' >> beam.combiners.ToList() + ... | 'Extract' >> beam.FlatMap(extract_output).with_outputs( + ... 
'keys', 'values')) + >>> def preprocessing_fn(inputs): + ... outputs = tft.experimental.ptransform_analyzer( + ... inputs=[inputs['s'], inputs['x']], + ... ptransform=MeanPerKey(), + ... output_dtypes=[tf.string, tf.float32], + ... output_shapes=[[2], [2]]) + ... (keys, means) = outputs + ... mean_a = tf.reshape(tf.gather(means, tf.where(keys == 'a')), []) + ... return { 'x/mean_a': inputs['x'] / mean_a } + >>> raw_data = [dict(x=1, s='a'), dict(x=8, s='b'), dict(x=3, s='a')] + >>> feature_spec = dict( + ... x=tf.io.FixedLenFeature([], tf.float32), + ... s=tf.io.FixedLenFeature([], tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'x/mean_a': 0.5}, {'x/mean_a': 4.0}, {'x/mean_a': 1.5}] + + Args: + ---- + inputs: An ordered collection of input `Tensor`s. + ptransform: A Beam PTransform that accepts a Beam PCollection where each + element is a tuple of `ndarray`s. Each element in the tuple contains a + batch of values for the corresponding input tensor of the analyzer and + maintain their shapes and dtypes. + It returns a `PCollection`, or a tuple of `PCollections`, each containing + a single element which is an `ndarray` or a list of primitive types. The + contents of these output `PCollection`s must be consistent with the given + values of `output_dtypes` and `output_shapes`. + It may inherit from `tft_beam.experimental.PTransformAnalyzer` if access + to a temp base directory is needed. + Alternatively, it could be an instance of + `tft.experimental.CacheablePTransformAnalyzer` in order to enable cache + for this analyzer, when analyzer cache is enabled for this pipeline. + output_dtypes: An ordered collection of TensorFlow dtypes of the output of + the analyzer. + output_shapes: An ordered collection of shapes of the output of the + analyzer. Must have the same length as output_dtypes. + output_asset_default_values: (Optional) An ordered collection of optional + `bytes` aligned with output_dtypes/output_shapes. Every item in this + collection which is not `None` indicates that the output is a TF asset + path, and its value would be used as the default value of this asset file + prior to analysis. + name: (Optional) Similar to a TF op name. Used to define a unique scope for + this analyzer, which can be used for debugging info. + + Returns: + ------- + A list of output `Tensor`s. These will have `dtype` and `shape` as + specified by `output_dtypes` and `output_shapes`. + + Raises: + ------ + ValueError: If output_dtypes and output_shapes have different lengths. 
+ """ + # pylint: enable=line-too-long + if len(output_dtypes) != len(output_shapes): + raise ValueError( + f"output_dtypes ({output_dtypes}) and output_shapes ({output_shapes}) had different" + " lengths" + ) + if output_asset_default_values is not None: + if len(output_asset_default_values) != len(output_dtypes): + raise ValueError( + f"output_dtypes ({output_dtypes}) and output_asset_default_values ({output_asset_default_values}) had " + "different lengths" + ) + output_asset_default_values = [ + analyzer_nodes.TemporaryAssetInfo(value, "text") + for value in output_asset_default_values + ] + else: + output_asset_default_values = [None] * len(output_dtypes) + with tf.compat.v1.name_scope(name, "ptransform"): + output_tensor_infos = [ + analyzer_nodes.TensorInfo(dtype, shape, default_asset_content) + for dtype, shape, default_asset_content in zip( + output_dtypes, output_shapes, output_asset_default_values + ) + ] + return _apply_analyzer( + ptransform, *inputs, output_tensor_info_list=output_tensor_infos + ) + + +def _get_approx_vocab_filename( + vocab_filename: Optional[str], store_frequency: bool +) -> str: + """Returns a sanitized vocabulary filename with appropriate prefix applied. + + Args: + ---- + vocab_filename: The file name for the approximate vocabulary file. If None, + the "approximate_vocabulary" scope name in the context of this graph will + be used as the file name. + store_frequency: A bool that is true when the vocabulary for which this + generates a filename stores term frequency. False otherwise. + + Returns: + ------- + A valid filename. + """ + if vocab_filename is not None: + prefix = None + elif store_frequency: + prefix = _APPROXIMATE_VOCAB_FILENAME_PREFIX + else: + prefix = _APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX + + # Make the file name path safe. + return analyzers.sanitized_vocab_filename(vocab_filename, prefix=prefix) @common.log_api_use(common.ANALYZER_COLLECTION) @@ -314,122 +337,126 @@ def approximate_vocabulary( reserved_tokens: Optional[Union[Sequence[str], tf.Tensor]] = None, weights: Optional[tf.Tensor] = None, file_format: common_types.VocabularyFileFormatType = analyzers.DEFAULT_VOCABULARY_FILE_FORMAT, - name: Optional[str] = None + name: Optional[str] = None, ) -> common_types.TemporaryAnalyzerOutputType: - r"""Computes the unique values of a `Tensor` over the whole dataset. - - Approximately computes the unique values taken by `x`, which can be a - `Tensor`, `SparseTensor`, or `RaggedTensor` of any size. The unique values - will be aggregated over all dimensions of `x` and all instances. - - This analyzer provides an approximate alternative to `tft.vocabulary` that can - be more efficient with smaller `top_k` and/or smaller number of unique - elements in `x`. As a rule of thumb, `approximate_vocabulary` becomes more - efficient than `tft.vocabulary` if `top_k` or the number of unique elements in - `x` is smaller than 2*10^5. Moreover, this analyzer is subject to combiner - packing optimization that does not apply to `tft.vocabulary`. Caching is also - more efficient with the approximate implementation since the filtration - happens before writing out cache. Output artifact of `approximate_vocabulary` - is consistent with `tft.vocabulary` and can be used in `tft.apply_vocabulary` - mapper. - - Implementation of this analyzer is based on the Misra-Gries algorithm [1]. It - stores at most `top_k` elements with lower bound frequency estimates at a - time. 
The algorithm keeps track of the approximation error `delta` such that - for any item x with true frequency X: - - frequency[x] <= X <= frequency[x] + delta, - delta <= (m - m') / (top_k + 1), - - where m is the total frequency of the items in the dataset and m' is the sum - of the lower bound estimates in `frequency` [2]. For datasets that are Zipfian - distributed with parameter `a`, the algorithm provides an expected value of - delta = m / (top_k ^ a) [3]. - - [1] - https://www.cs.utexas.edu/users/misra/scannedPdf.dir/FindRepeatedElements.pdf - [2] http://www.cohenwang.com/edith/bigdataclass2013/lectures/lecture1.pdf - [3] http://dimacs.rutgers.edu/~graham/pubs/papers/countersj.pdf - - In case `file_format` is 'text' and one of the tokens contains the '\n' or - '\r' characters or is empty it will be discarded. - - If an integer `Tensor` is provided, its semantic type should be categorical - not a continuous/numeric, since computing a vocabulary over a continuous - feature is not appropriate. - - The unique values are sorted by decreasing frequency and then reverse - lexicographical order (e.g. [('a', 5), ('c', 3), ('b', 3)]). This is true even - if `x` is numerical dtype (e.g. [('3', 5), ('2', 3), ('111', 3)]). - - Args: - x: A categorical/discrete input `Tensor`, `SparseTensor`, or `RaggedTensor` - with dtype tf.string or tf.int[8|16|32|64]. - top_k: Limit the generated vocabulary to the first `top_k` elements. Note - that if `top_k` is larger than the number of unique elements in `x`, then - the result will be exact. - vocab_filename: The file name for the vocabulary file. If None, a file name - will be chosen based on the current scope. If not None, should be unique - within a given preprocessing function. NOTE: To make your pipelines - resilient to implementation details please set `vocab_filename` when you - are using the vocab_filename on a downstream component. - store_frequency: If True, frequency of the words is stored in the vocabulary - file. Each line in the file will be of the form 'frequency word'. NOTE: if - this is True then the computed vocabulary cannot be used with - `tft.apply_vocabulary` directly, since frequencies are added to the - beginning of each row of the vocabulary, which the mapper will not ignore. - reserved_tokens: (Optional) A list of tokens that should appear in the - vocabulary regardless of their appearance in the input. These tokens would - maintain their order, and have a reserved spot at the beginning of the - vocabulary. Note: this field has no affect on cache. - weights: (Optional) Weights `Tensor` for the vocabulary. It must have the - same shape as x. - file_format: (Optional) A str. The format of the resulting vocabulary file. - Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires - tensorflow>=2.4. The default value is 'text'. - name: (Optional) A name for this operation. - - Returns: - The path name for the vocabulary file containing the unique values of `x`. - - Raises: - ValueError: If `top_k` is negative. - If `file_format` is not in the list of allowed formats. - If x.dtype is not string or integral. - """ - - if top_k <= 0: - raise ValueError('top_k must be positive, but got: %r' % top_k) - elif top_k > analyzers.LARGE_VOCAB_TOP_K: - raise ValueError('Provided top_k threshold is too large for the ' - 'approximate calculation: if the expected number of ' - 'unique elements is larger than top_k, tft.vocabulary may ' - 'be more efficient. 
Maximum allowed top_k is {}'.format( - analyzers.LARGE_VOCAB_TOP_K)) - - if file_format not in analyzers.ALLOWED_VOCABULARY_FILE_FORMATS: - raise ValueError( - '"{}" is not an accepted file_format. It should be one of: {}'.format( - file_format, analyzers.ALLOWED_VOCABULARY_FILE_FORMATS)) - - if x.dtype != tf.string and not x.dtype.is_integer: - raise ValueError('expected tf.string or integer but got %r' % x.dtype) - - with tf.compat.v1.name_scope(name, 'approximate_vocabulary'): - vocabulary_key = vocab_filename - vocab_filename = _get_approx_vocab_filename(vocab_filename, store_frequency) - analyzer_inputs = _get_approximate_vocabulary_analyzer_inputs( - x=x, file_format=file_format, weights=weights) - return _approximate_vocabulary_analyzer_nodes( - analyzer_inputs=analyzer_inputs, - input_dtype=x.dtype.name, - vocab_filename=vocab_filename, - top_k=top_k, - store_frequency=store_frequency, - reserved_tokens=reserved_tokens, - file_format=file_format, - vocabulary_key=vocabulary_key, - ) + r"""Computes the unique values of a `Tensor` over the whole dataset. + + Approximately computes the unique values taken by `x`, which can be a + `Tensor`, `SparseTensor`, or `RaggedTensor` of any size. The unique values + will be aggregated over all dimensions of `x` and all instances. + + This analyzer provides an approximate alternative to `tft.vocabulary` that can + be more efficient with smaller `top_k` and/or smaller number of unique + elements in `x`. As a rule of thumb, `approximate_vocabulary` becomes more + efficient than `tft.vocabulary` if `top_k` or the number of unique elements in + `x` is smaller than 2*10^5. Moreover, this analyzer is subject to combiner + packing optimization that does not apply to `tft.vocabulary`. Caching is also + more efficient with the approximate implementation since the filtration + happens before writing out cache. Output artifact of `approximate_vocabulary` + is consistent with `tft.vocabulary` and can be used in `tft.apply_vocabulary` + mapper. + + Implementation of this analyzer is based on the Misra-Gries algorithm [1]. It + stores at most `top_k` elements with lower bound frequency estimates at a + time. The algorithm keeps track of the approximation error `delta` such that + for any item x with true frequency X: + + frequency[x] <= X <= frequency[x] + delta, + delta <= (m - m') / (top_k + 1), + + where m is the total frequency of the items in the dataset and m' is the sum + of the lower bound estimates in `frequency` [2]. For datasets that are Zipfian + distributed with parameter `a`, the algorithm provides an expected value of + delta = m / (top_k ^ a) [3]. + + [1] + https://www.cs.utexas.edu/users/misra/scannedPdf.dir/FindRepeatedElements.pdf + [2] http://www.cohenwang.com/edith/bigdataclass2013/lectures/lecture1.pdf + [3] http://dimacs.rutgers.edu/~graham/pubs/papers/countersj.pdf + + In case `file_format` is 'text' and one of the tokens contains the '\n' or + '\r' characters or is empty it will be discarded. + + If an integer `Tensor` is provided, its semantic type should be categorical + not a continuous/numeric, since computing a vocabulary over a continuous + feature is not appropriate. + + The unique values are sorted by decreasing frequency and then reverse + lexicographical order (e.g. [('a', 5), ('c', 3), ('b', 3)]). This is true even + if `x` is numerical dtype (e.g. [('3', 5), ('2', 3), ('111', 3)]). + + Args: + ---- + x: A categorical/discrete input `Tensor`, `SparseTensor`, or `RaggedTensor` + with dtype tf.string or tf.int[8|16|32|64]. 
+ top_k: Limit the generated vocabulary to the first `top_k` elements. Note + that if `top_k` is larger than the number of unique elements in `x`, then + the result will be exact. + vocab_filename: The file name for the vocabulary file. If None, a file name + will be chosen based on the current scope. If not None, should be unique + within a given preprocessing function. NOTE: To make your pipelines + resilient to implementation details please set `vocab_filename` when you + are using the vocab_filename on a downstream component. + store_frequency: If True, frequency of the words is stored in the vocabulary + file. Each line in the file will be of the form 'frequency word'. NOTE: if + this is True then the computed vocabulary cannot be used with + `tft.apply_vocabulary` directly, since frequencies are added to the + beginning of each row of the vocabulary, which the mapper will not ignore. + reserved_tokens: (Optional) A list of tokens that should appear in the + vocabulary regardless of their appearance in the input. These tokens would + maintain their order, and have a reserved spot at the beginning of the + vocabulary. Note: this field has no affect on cache. + weights: (Optional) Weights `Tensor` for the vocabulary. It must have the + same shape as x. + file_format: (Optional) A str. The format of the resulting vocabulary file. + Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires + tensorflow>=2.4. The default value is 'text'. + name: (Optional) A name for this operation. + + Returns: + ------- + The path name for the vocabulary file containing the unique values of `x`. + + Raises: + ------ + ValueError: If `top_k` is negative. + If `file_format` is not in the list of allowed formats. + If x.dtype is not string or integral. + """ + if top_k <= 0: + raise ValueError("top_k must be positive, but got: %r" % top_k) + elif top_k > analyzers.LARGE_VOCAB_TOP_K: + raise ValueError( + "Provided top_k threshold is too large for the " + "approximate calculation: if the expected number of " + "unique elements is larger than top_k, tft.vocabulary may " + f"be more efficient. Maximum allowed top_k is {analyzers.LARGE_VOCAB_TOP_K}" + ) + + if file_format not in analyzers.ALLOWED_VOCABULARY_FILE_FORMATS: + raise ValueError( + f'"{file_format}" is not an accepted file_format. It should be one of: {analyzers.ALLOWED_VOCABULARY_FILE_FORMATS}' + ) + + if x.dtype != tf.string and not x.dtype.is_integer: + raise ValueError("expected tf.string or integer but got %r" % x.dtype) + + with tf.compat.v1.name_scope(name, "approximate_vocabulary"): + vocabulary_key = vocab_filename + vocab_filename = _get_approx_vocab_filename(vocab_filename, store_frequency) + analyzer_inputs = _get_approximate_vocabulary_analyzer_inputs( + x=x, file_format=file_format, weights=weights + ) + return _approximate_vocabulary_analyzer_nodes( + analyzer_inputs=analyzer_inputs, + input_dtype=x.dtype.name, + vocab_filename=vocab_filename, + top_k=top_k, + store_frequency=store_frequency, + reserved_tokens=reserved_tokens, + file_format=file_format, + vocabulary_key=vocabulary_key, + ) def _approximate_vocabulary_analyzer_nodes( @@ -442,118 +469,118 @@ def _approximate_vocabulary_analyzer_nodes( vocabulary_key: str, reserved_tokens: Optional[Union[Sequence[str], tf.Tensor]] = None, ) -> common_types.TemporaryAnalyzerOutputType: - """Internal helper for analyzing vocab. See `vocabulary` doc string.""" - # TODO(b/208879020): Add vocabulary size annotation for this analyzer. 
- analyzers.register_vocab( - vocab_filename, vocabulary_key=vocabulary_key, file_format=file_format) - - outputs_value_nodes = analyzers.apply_cacheable_combine_operation( - _VocabularyCombiner( - top_k, input_dtype, output_pylist=reserved_tokens is not None - ), - *analyzer_inputs - ) - - flattened_outputs_value_node = nodes.apply_operation( - analyzer_nodes.FlattenLists, *outputs_value_nodes) - - extra_apply_order_and_write_op_args = [] - if reserved_tokens is not None: - tf_utils.register_vocabulary_reserved_tokens( - vocab_filename, reserved_tokens + """Internal helper for analyzing vocab. See `vocabulary` doc string.""" + # TODO(b/208879020): Add vocabulary size annotation for this analyzer. + analyzers.register_vocab( + vocab_filename, vocabulary_key=vocabulary_key, file_format=file_format ) - extra_apply_order_and_write_op_args.append( - nodes.apply_operation( - analyzer_nodes.ExtractVocabularyReservedTokens, name=vocab_filename - ) + + outputs_value_nodes = analyzers.apply_cacheable_combine_operation( + _VocabularyCombiner( + top_k, input_dtype, output_pylist=reserved_tokens is not None + ), + *analyzer_inputs, ) - vocab_filename_node = nodes.apply_operation( - analyzer_nodes.VocabularyOrderAndWrite, - flattened_outputs_value_node, - *extra_apply_order_and_write_op_args, - vocab_filename=vocab_filename, - store_frequency=store_frequency, - input_dtype=input_dtype, - file_format=file_format, - fingerprint_shuffle=False, - input_is_sorted=True - ) + flattened_outputs_value_node = nodes.apply_operation( + analyzer_nodes.FlattenLists, *outputs_value_nodes + ) - return analyzer_nodes.wrap_as_tensor(vocab_filename_node) + extra_apply_order_and_write_op_args = [] + if reserved_tokens is not None: + tf_utils.register_vocabulary_reserved_tokens(vocab_filename, reserved_tokens) + extra_apply_order_and_write_op_args.append( + nodes.apply_operation( + analyzer_nodes.ExtractVocabularyReservedTokens, name=vocab_filename + ) + ) + + vocab_filename_node = nodes.apply_operation( + analyzer_nodes.VocabularyOrderAndWrite, + flattened_outputs_value_node, + *extra_apply_order_and_write_op_args, + vocab_filename=vocab_filename, + store_frequency=store_frequency, + input_dtype=input_dtype, + file_format=file_format, + fingerprint_shuffle=False, + input_is_sorted=True, + ) + + return analyzer_nodes.wrap_as_tensor(vocab_filename_node) class _MisraGriesSketchCoder(analyzer_nodes.CacheCoder): - """Cache coder for the approximate vocabulary accumulator.""" + """Cache coder for the approximate vocabulary accumulator.""" - def encode_cache(self, accumulator: sketches.MisraGriesSketch) -> bytes: - return accumulator.Serialize() + def encode_cache(self, accumulator: sketches.MisraGriesSketch) -> bytes: + return accumulator.Serialize() - def decode_cache(self, - encoded_accumulator: bytes) -> sketches.MisraGriesSketch: - return sketches.MisraGriesSketch.Deserialize(encoded_accumulator) + def decode_cache(self, encoded_accumulator: bytes) -> sketches.MisraGriesSketch: + return sketches.MisraGriesSketch.Deserialize(encoded_accumulator) class _VocabularyCombiner(analyzer_nodes.Combiner): - """Approximately computes unique values on the PCollection.""" - - def __init__( - self, - top_k: int, - input_dtype: tf.dtypes.DType, - output_pylist: bool = False, - ): - self._top_k = top_k - self._input_dtype = input_dtype - self._output_pylist = output_pylist - - def create_accumulator(self) -> sketches.MisraGriesSketch: - return sketches.MisraGriesSketch( - self._top_k, - 
order_on_tie=sketches.MisraGriesSketch.OrderOnTie.ReverseLexicographical - ) + """Approximately computes unique values on the PCollection.""" + + def __init__( + self, + top_k: int, + input_dtype: tf.dtypes.DType, + output_pylist: bool = False, + ): + self._top_k = top_k + self._input_dtype = input_dtype + self._output_pylist = output_pylist + + def create_accumulator(self) -> sketches.MisraGriesSketch: + return sketches.MisraGriesSketch( + self._top_k, + order_on_tie=sketches.MisraGriesSketch.OrderOnTie.ReverseLexicographical, + ) - def add_input( - self, accumulator: sketches.MisraGriesSketch, - next_input: Tuple[np.ndarray, np.ndarray]) -> sketches.MisraGriesSketch: - items, weights = next_input - if items.size: - accumulator.AddValues(pa.array(items), pa.array(weights, pa.float32())) - return accumulator - - def merge_accumulators( - self, accumulators: Iterable[sketches.MisraGriesSketch] - ) -> sketches.MisraGriesSketch: - # Make sure that `accumulators` is an iterator (so that the position is - # remembered). - accumulators = iter(accumulators) - result = next(accumulators) - for accumulator in accumulators: - result.Merge(accumulator) - return result - - def extract_output(self, - accumulator: sketches.MisraGriesSketch) -> np.ndarray: - estimate = accumulator.Estimate() - estimate.validate() - result = np.dstack(list(reversed(estimate.flatten()))) - if not result.size: - result = np.array( - [[analyzers.get_empy_vocabulary_dummy_value(self._input_dtype)]], - dtype=object, - ) - # TODO(b/282952880): Avoid converting to pylist when we can always rely on - # top_k sorted inputs. - if self._output_pylist: - return result.tolist() - return result - - def output_tensor_infos(self) -> List[analyzer_nodes.TensorInfo]: - return [analyzer_nodes.TensorInfo(tf.string, [None, 2], None)] - - @property - def accumulator_coder(self) -> _MisraGriesSketchCoder: - return _MisraGriesSketchCoder() + def add_input( + self, + accumulator: sketches.MisraGriesSketch, + next_input: Tuple[np.ndarray, np.ndarray], + ) -> sketches.MisraGriesSketch: + items, weights = next_input + if items.size: + accumulator.AddValues(pa.array(items), pa.array(weights, pa.float32())) + return accumulator + + def merge_accumulators( + self, accumulators: Iterable[sketches.MisraGriesSketch] + ) -> sketches.MisraGriesSketch: + # Make sure that `accumulators` is an iterator (so that the position is + # remembered). + accumulators = iter(accumulators) + result = next(accumulators) + for accumulator in accumulators: + result.Merge(accumulator) + return result + + def extract_output(self, accumulator: sketches.MisraGriesSketch) -> np.ndarray: + estimate = accumulator.Estimate() + estimate.validate() + result = np.dstack(list(reversed(estimate.flatten()))) + if not result.size: + result = np.array( + [[analyzers.get_empy_vocabulary_dummy_value(self._input_dtype)]], + dtype=object, + ) + # TODO(b/282952880): Avoid converting to pylist when we can always rely on + # top_k sorted inputs. 
+ if self._output_pylist: + return result.tolist() + return result + + def output_tensor_infos(self) -> List[analyzer_nodes.TensorInfo]: + return [analyzer_nodes.TensorInfo(tf.string, [None, 2], None)] + + @property + def accumulator_coder(self) -> _MisraGriesSketchCoder: + return _MisraGriesSketchCoder() def _get_approximate_vocabulary_analyzer_inputs( @@ -561,24 +588,26 @@ def _get_approximate_vocabulary_analyzer_inputs( file_format: common_types.VocabularyFileFormatType, weights: Optional[common_types.TensorType], ) -> Tuple[common_types.TensorType, common_types.TensorType]: - """Helper for constructing approximate vocabulary inputs from tensors. - - Args: - x: `Tensor`, `SparseTensor`, or `RaggedTensor` to compute vocabulary over. - file_format: The format of the resulting vocabulary file. - 'tfrecord_gzip' requires tensorflow>=2.4. - weights: Optional `Tensor` of weights. - - Returns: - A list of batch-reduced `Tensor`s to feed to vocabulary analysis. - """ - filter_regex = analyzers.get_vocab_newline_characters_regex( - x.dtype, file_format) - reduced_batch = tf_utils.reduce_batch_weighted_counts( - x, weights=weights, force=True, filter_regex=filter_regex) - assert reduced_batch.summed_positive_per_x_and_y is None - if weights is None: - assert reduced_batch.summed_weights_per_x is None - return (reduced_batch.unique_x, reduced_batch.counts_per_x) - else: - return (reduced_batch.unique_x, reduced_batch.summed_weights_per_x) + """Helper for constructing approximate vocabulary inputs from tensors. + + Args: + ---- + x: `Tensor`, `SparseTensor`, or `RaggedTensor` to compute vocabulary over. + file_format: The format of the resulting vocabulary file. + 'tfrecord_gzip' requires tensorflow>=2.4. + weights: Optional `Tensor` of weights. + + Returns: + ------- + A list of batch-reduced `Tensor`s to feed to vocabulary analysis. + """ + filter_regex = analyzers.get_vocab_newline_characters_regex(x.dtype, file_format) + reduced_batch = tf_utils.reduce_batch_weighted_counts( + x, weights=weights, force=True, filter_regex=filter_regex + ) + assert reduced_batch.summed_positive_per_x_and_y is None + if weights is None: + assert reduced_batch.summed_weights_per_x is None + return (reduced_batch.unique_x, reduced_batch.counts_per_x) + else: + return (reduced_batch.unique_x, reduced_batch.summed_weights_per_x) diff --git a/tensorflow_transform/experimental/annotators.py b/tensorflow_transform/experimental/annotators.py index 1164fbf..5bf98e3 100644 --- a/tensorflow_transform/experimental/annotators.py +++ b/tensorflow_transform/experimental/annotators.py @@ -16,111 +16,121 @@ from typing import Sequence, Union import tensorflow as tf -from tensorflow_transform import annotators -from tensorflow_transform import schema_inference - -from tensorflow.python.framework import ops # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.framework import ( + ops, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_transform import annotators, schema_inference __all__ = [ - 'get_vocabulary_size_by_name', - 'annotate_sparse_output_shape', - 'annotate_true_sparse_output', + "get_vocabulary_size_by_name", + "annotate_sparse_output_shape", + "annotate_true_sparse_output", ] def get_vocabulary_size_by_name(vocab_filename: str) -> tf.Tensor: - # pyformat: disable - """Gets the size of a vocabulary created using `tft.vocabulary`. - - This is the number of keys in the output `vocab_filename` and does not include - number of OOV buckets. 
- - Args: - vocab_filename: The name of the vocabulary file whose size is to be - retrieved. - - Example: - - >>> def preprocessing_fn(inputs): - ... num_oov_buckets = 1 - ... x_int = tft.compute_and_apply_vocabulary( - ... inputs['x'], vocab_filename='my_vocab', - ... num_oov_buckets=num_oov_buckets) - ... depth = ( - ... tft.experimental.get_vocabulary_size_by_name('my_vocab') + - ... num_oov_buckets) - ... x_encoded = tf.one_hot( - ... x_int, depth=tf.cast(depth, tf.int32), dtype=tf.int64) - ... return {'x_encoded': x_encoded} - >>> raw_data = [dict(x='foo'), dict(x='foo'), dict(x='bar')] - >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'x_encoded': array([1, 0, 0])}, {'x_encoded': array([1, 0, 0])}, - {'x_encoded': array([0, 1, 0])}] - - Returns: - An integer tensor containing the size of the requested vocabulary. - - Raises: - ValueError: if no vocabulary size found for the given `vocab_filename`. - - """ - # pyformat: enable - vocabulary_sizes_coll = ops.get_default_graph().get_collection( - annotators.VOCABULARY_SIZE_BY_NAME_COLLECTION) - - result = dict(vocabulary_sizes_coll).get(vocab_filename, None) - - if result is None: - raise ValueError( - f'Vocabulary size not found for {vocab_filename}. If this vocabulary ' - 'was created using `tft.vocabulary`, this should be the same as the ' - '`vocab_filename` argument passed to it.') - - return result + # pyformat: disable + """Gets the size of a vocabulary created using `tft.vocabulary`. + + This is the number of keys in the output `vocab_filename` and does not include + number of OOV buckets. + + Args: + ---- + vocab_filename: The name of the vocabulary file whose size is to be + retrieved. + + Example: + ------- + >>> def preprocessing_fn(inputs): + ... num_oov_buckets = 1 + ... x_int = tft.compute_and_apply_vocabulary( + ... inputs['x'], vocab_filename='my_vocab', + ... num_oov_buckets=num_oov_buckets) + ... depth = ( + ... tft.experimental.get_vocabulary_size_by_name('my_vocab') + + ... num_oov_buckets) + ... x_encoded = tf.one_hot( + ... x_int, depth=tf.cast(depth, tf.int32), dtype=tf.int64) + ... return {'x_encoded': x_encoded} + >>> raw_data = [dict(x='foo'), dict(x='foo'), dict(x='bar')] + >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'x_encoded': array([1, 0, 0])}, {'x_encoded': array([1, 0, 0])}, + {'x_encoded': array([0, 1, 0])}] + + Returns: + ------- + An integer tensor containing the size of the requested vocabulary. + + Raises: + ------ + ValueError: if no vocabulary size found for the given `vocab_filename`. 
+ + """ + # pyformat: enable + vocabulary_sizes_coll = ops.get_default_graph().get_collection( + annotators.VOCABULARY_SIZE_BY_NAME_COLLECTION + ) + + result = dict(vocabulary_sizes_coll).get(vocab_filename, None) + + if result is None: + raise ValueError( + f"Vocabulary size not found for {vocab_filename}. If this vocabulary " + "was created using `tft.vocabulary`, this should be the same as the " + "`vocab_filename` argument passed to it." + ) + + return result def annotate_sparse_output_shape( - tensor: tf.SparseTensor, shape: Union[Sequence[int], tf.Tensor]): - """Annotates a sparse output to have a given dense_shape. - - Args: - tensor: An `SparseTensor` to be annotated. - shape: A dense_shape to annotate `tensor` with. Note that this shape does - not include batch_size. - """ - if not isinstance(shape, tf.Tensor): - if (tensor.shape.rank > 1 and tensor.shape.rank - 1 != len(shape)) or ( - tensor.shape.rank == 1 and len(shape) != 1): - raise ValueError( - f'Annotated shape {shape} was expected to have rank' - f' {tensor.shape.rank - 1}') - if not all(a is None or a <= b for a, b in zip(tensor.shape[1:], shape)): - raise ValueError( - f'Shape {shape} cannot contain annotated tensor {tensor}') - shape = tf.convert_to_tensor(shape, dtype=tf.int64) - elif shape.shape.rank > 1 or ( - shape.shape.rank == 1 and shape.shape[0] != tensor.shape.rank - 1): - raise ValueError( - f'Annotation shape has rank {shape.shape.rank} but expected to have' - f' rank {tensor.shape.rank - 1}') - if shape.shape.rank < 1: - shape = tf.expand_dims(shape, -1) - # There's currently no way to override SparseTensor.dense_shape directly, - # unless composing and returning a new SparseTensor. - tensor._dense_shape = tf.concat( # pylint: disable=protected-access - [tf.expand_dims(tensor.dense_shape[0], -1), tf.cast(shape, tf.int64)], - axis=0) - schema_inference.annotate_sparse_output_shape(tensor, shape) + tensor: tf.SparseTensor, shape: Union[Sequence[int], tf.Tensor] +): + """Annotates a sparse output to have a given dense_shape. + + Args: + ---- + tensor: An `SparseTensor` to be annotated. + shape: A dense_shape to annotate `tensor` with. Note that this shape does + not include batch_size. + """ + if not isinstance(shape, tf.Tensor): + if (tensor.shape.rank > 1 and tensor.shape.rank - 1 != len(shape)) or ( + tensor.shape.rank == 1 and len(shape) != 1 + ): + raise ValueError( + f"Annotated shape {shape} was expected to have rank" + f" {tensor.shape.rank - 1}" + ) + if not all(a is None or a <= b for a, b in zip(tensor.shape[1:], shape)): + raise ValueError(f"Shape {shape} cannot contain annotated tensor {tensor}") + shape = tf.convert_to_tensor(shape, dtype=tf.int64) + elif shape.shape.rank > 1 or ( + shape.shape.rank == 1 and shape.shape[0] != tensor.shape.rank - 1 + ): + raise ValueError( + f"Annotation shape has rank {shape.shape.rank} but expected to have" + f" rank {tensor.shape.rank - 1}" + ) + if shape.shape.rank < 1: + shape = tf.expand_dims(shape, -1) + # There's currently no way to override SparseTensor.dense_shape directly, + # unless composing and returning a new SparseTensor. 
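    # Illustrative note (not part of the patch): the concat below keeps the batch
    # dimension from tensor.dense_shape[0] and appends the annotated per-example
    # shape, e.g. a batch of 32 examples annotated with shape [10, 5] would end up
    # with dense_shape [32, 10, 5].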
+ tensor._dense_shape = tf.concat( # pylint: disable=protected-access + [tf.expand_dims(tensor.dense_shape[0], -1), tf.cast(shape, tf.int64)], axis=0 + ) + schema_inference.annotate_sparse_output_shape(tensor, shape) def annotate_true_sparse_output(tensor: tf.SparseTensor): - """Annotates a sparse output to be truely sparse and not varlen.""" - schema_inference.annotate_true_sparse_output(tensor) + """Annotates a sparse output to be truely sparse and not varlen.""" + schema_inference.annotate_true_sparse_output(tensor) diff --git a/tensorflow_transform/experimental/mappers.py b/tensorflow_transform/experimental/mappers.py index 322eec5..9bea710 100644 --- a/tensorflow_transform/experimental/mappers.py +++ b/tensorflow_transform/experimental/mappers.py @@ -26,14 +26,11 @@ are batches from the dataset, whose batch size may vary. """ -from typing import Any, Optional, Union, Sequence +from typing import Any, Optional, Sequence, Union import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import mappers -from tensorflow_transform import tf_utils + +from tensorflow_transform import analyzers, common, common_types, mappers, tf_utils from tensorflow_transform.experimental import analyzers as experimental_analyzers @@ -51,349 +48,370 @@ def compute_and_apply_approximate_vocabulary( reserved_tokens: Optional[Union[Sequence[str], tf.Tensor]] = None, name: Optional[str] = None, ) -> common_types.ConsistentTensorType: - """Generates an approximate vocabulary for `x` and maps it to an integer. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type tf.string or - tf.int[8|16|32|64]. - default_value: The value to use for out-of-vocabulary values, unless - 'num_oov_buckets' is greater than zero. - top_k: Limit the generated vocabulary to the first `top_k` elements. If set - to None, the full vocabulary is generated. - num_oov_buckets: Any lookup of an out-of-vocabulary token will return a - bucket ID based on its hash if `num_oov_buckets` is greater than zero. - Otherwise it is assigned the `default_value`. - vocab_filename: The file name for the vocabulary file. If None, a name based - on the scope name in the context of this graph will be used as the file - name. If not None, should be unique within a given preprocessing function. - NOTE in order to make your pipelines resilient to implementation details - please set `vocab_filename` when you are using the vocab_filename on a - downstream component. - weights: (Optional) Weights `Tensor` for the vocabulary. It must have the - same shape as x. - file_format: (Optional) A str. The format of the resulting vocabulary file. - Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires - tensorflow>=2.4. The default value is 'text'. - store_frequency: If True, frequency of the words is stored in the vocabulary - file. In the case labels are provided, the mutual information is stored in - the file instead. Each line in the file will be of the form 'frequency - word'. NOTE: if True and text_format is 'text' then spaces will be - replaced to avoid information loss. - reserved_tokens: (Optional) A list of tokens that should appear in the - vocabulary regardless of their appearance in the input. These tokens would - maintain their order, and have a reserved spot at the beginning of the - vocabulary. Note: this field has no affect on cache. - name: (Optional) A name for this operation. 
- - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is - mapped to an integer. Each unique string value that appears in the - vocabulary is mapped to a different integer and integers are consecutive - starting from zero. String value not in the vocabulary is assigned - `default_value`. Alternatively, if `num_oov_buckets` is specified, out of - vocabulary strings are hashed to values in - [vocab_size, vocab_size + num_oov_buckets) for an overall range of - [0, vocab_size + num_oov_buckets). - - Raises: - ValueError: If `top_k` is negative. - If `file_format` is not in the list of allowed formats. - If x.dtype is not string or integral. - """ - with tf.compat.v1.name_scope(name, - 'compute_and_apply_approximate_vocabulary'): - if store_frequency and file_format == 'text': - x = tf_utils.maybe_format_vocabulary_input(x) - deferred_vocab_and_filename = experimental_analyzers.approximate_vocabulary( - x=x, - top_k=top_k, - vocab_filename=vocab_filename, - weights=weights, - file_format=file_format, - store_frequency=store_frequency, - reserved_tokens=reserved_tokens, - name=name, - ) - return mappers._apply_vocabulary_internal( # pylint: disable=protected-access - x, - deferred_vocab_and_filename, - default_value, - num_oov_buckets, - lookup_fn=None, - file_format=file_format, - store_frequency=store_frequency, - name=None, - ) + """Generates an approximate vocabulary for `x` and maps it to an integer. + + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type tf.string or + tf.int[8|16|32|64]. + default_value: The value to use for out-of-vocabulary values, unless + 'num_oov_buckets' is greater than zero. + top_k: Limit the generated vocabulary to the first `top_k` elements. If set + to None, the full vocabulary is generated. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + vocab_filename: The file name for the vocabulary file. If None, a name based + on the scope name in the context of this graph will be used as the file + name. If not None, should be unique within a given preprocessing function. + NOTE in order to make your pipelines resilient to implementation details + please set `vocab_filename` when you are using the vocab_filename on a + downstream component. + weights: (Optional) Weights `Tensor` for the vocabulary. It must have the + same shape as x. + file_format: (Optional) A str. The format of the resulting vocabulary file. + Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires + tensorflow>=2.4. The default value is 'text'. + store_frequency: If True, frequency of the words is stored in the vocabulary + file. In the case labels are provided, the mutual information is stored in + the file instead. Each line in the file will be of the form 'frequency + word'. NOTE: if True and text_format is 'text' then spaces will be + replaced to avoid information loss. + reserved_tokens: (Optional) A list of tokens that should appear in the + vocabulary regardless of their appearance in the input. These tokens would + maintain their order, and have a reserved spot at the beginning of the + vocabulary. Note: this field has no affect on cache. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is + mapped to an integer. 
Each unique string value that appears in the + vocabulary is mapped to a different integer and integers are consecutive + starting from zero. String value not in the vocabulary is assigned + `default_value`. Alternatively, if `num_oov_buckets` is specified, out of + vocabulary strings are hashed to values in + [vocab_size, vocab_size + num_oov_buckets) for an overall range of + [0, vocab_size + num_oov_buckets). + + Raises: + ------ + ValueError: If `top_k` is negative. + If `file_format` is not in the list of allowed formats. + If x.dtype is not string or integral. + """ + with tf.compat.v1.name_scope(name, "compute_and_apply_approximate_vocabulary"): + if store_frequency and file_format == "text": + x = tf_utils.maybe_format_vocabulary_input(x) + deferred_vocab_and_filename = experimental_analyzers.approximate_vocabulary( + x=x, + top_k=top_k, + vocab_filename=vocab_filename, + weights=weights, + file_format=file_format, + store_frequency=store_frequency, + reserved_tokens=reserved_tokens, + name=name, + ) + return mappers._apply_vocabulary_internal( # pylint: disable=protected-access + x, + deferred_vocab_and_filename, + default_value, + num_oov_buckets, + lookup_fn=None, + file_format=file_format, + store_frequency=store_frequency, + name=None, + ) @common.log_api_use(common.MAPPER_COLLECTION) -def document_frequency(x: tf.SparseTensor, - vocab_size: int, - name: Optional[str] = None) -> tf.SparseTensor: - """Maps the terms in x to their document frequency in the same order. - - The document frequency of a term is the number of documents that contain the - term in the entire dataset. Each unique vocab term has a unique document - frequency. - - Example usage: - - >>> def preprocessing_fn(inputs): - ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) - ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) - ... return { - ... 'df': tft.experimental.document_frequency(integerized, vocab_size), - ... 'integerized': integerized, - ... } - >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), - ... dict(x=["yum", "yum", "pie"])] - >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'df': array([1, 1, 2, 2, 2]), 'integerized': array([3, 2, 0, 0, 0])}, - {'df': array([1, 1, 2]), 'integerized': array([1, 1, 0])}] - - ``` - example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] - in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2]], - values=[1, 2, 0, 0, 0, 3, 3, 0]) - out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2]], - values=[1, 1, 2, 2, 2, 1, 1, 2]) - ``` - - Args: - x: A 2D `SparseTensor` representing int64 values (most likely that are the - result of calling `compute_and_apply_vocabulary` on a tokenized string). - vocab_size: An int - the count of vocab used to turn the string into int64s - including any OOV buckets. - name: (Optional) A name for this operation. - - Returns: - `SparseTensor`s with indices [index_in_batch, index_in_local_sequence] and - values document_frequency. Same shape as the input `x`. - - Raises: - ValueError if `x` does not have 2 dimensions. 
- """ - if x.get_shape().ndims != 2: - raise ValueError('tft.tfidf requires a 2D SparseTensor input. ' - 'Input had {} dimensions.'.format(x.get_shape().ndims)) - - with tf.compat.v1.name_scope(name, 'df'): - cleaned_input = tf_utils.to_vocab_range(x, vocab_size) - - # all_df is a (1, vocab_size)-shaped sparse tensor storing number of docs - # containing each term in the entire dataset. - all_df = _to_global_document_frequency(cleaned_input, vocab_size) - - # df_values is a batch_size * sequence_size sparse tensor storing the - # document frequency of each term, following the same order as the terms - # within each document. - df_values = tf.gather(tf.squeeze(all_df), cleaned_input.values) - - return tf.SparseTensor( - indices=cleaned_input.indices, - values=df_values, - dense_shape=cleaned_input.dense_shape) +def document_frequency( + x: tf.SparseTensor, vocab_size: int, name: Optional[str] = None +) -> tf.SparseTensor: + """Maps the terms in x to their document frequency in the same order. + + The document frequency of a term is the number of documents that contain the + term in the entire dataset. Each unique vocab term has a unique document + frequency. + + Example usage: + + >>> def preprocessing_fn(inputs): + ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) + ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) + ... return { + ... 'df': tft.experimental.document_frequency(integerized, vocab_size), + ... 'integerized': integerized, + ... } + >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), + ... dict(x=["yum", "yum", "pie"])] + >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'df': array([1, 1, 2, 2, 2]), 'integerized': array([3, 2, 0, 0, 0])}, + {'df': array([1, 1, 2]), 'integerized': array([1, 1, 0])}] + + ``` + example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] + in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1, 2, 0, 0, 0, 3, 3, 0]) + out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1, 1, 2, 2, 2, 1, 1, 2]) + ``` + + Args: + ---- + x: A 2D `SparseTensor` representing int64 values (most likely that are the + result of calling `compute_and_apply_vocabulary` on a tokenized string). + vocab_size: An int - the count of vocab used to turn the string into int64s + including any OOV buckets. + name: (Optional) A name for this operation. + + Returns: + ------- + `SparseTensor`s with indices [index_in_batch, index_in_local_sequence] and + values document_frequency. Same shape as the input `x`. + + Raises: + ------ + ValueError if `x` does not have 2 dimensions. + """ + if x.get_shape().ndims != 2: + raise ValueError( + "tft.tfidf requires a 2D SparseTensor input. " + f"Input had {x.get_shape().ndims} dimensions." + ) + + with tf.compat.v1.name_scope(name, "df"): + cleaned_input = tf_utils.to_vocab_range(x, vocab_size) + + # all_df is a (1, vocab_size)-shaped sparse tensor storing number of docs + # containing each term in the entire dataset. 
+ all_df = _to_global_document_frequency(cleaned_input, vocab_size) + + # df_values is a batch_size * sequence_size sparse tensor storing the + # document frequency of each term, following the same order as the terms + # within each document. + df_values = tf.gather(tf.squeeze(all_df), cleaned_input.values) + + return tf.SparseTensor( + indices=cleaned_input.indices, + values=df_values, + dense_shape=cleaned_input.dense_shape, + ) @common.log_api_use(common.MAPPER_COLLECTION) -def idf(x: tf.SparseTensor, - vocab_size: int, - smooth: bool = True, - add_baseline: bool = True, - name: Optional[str] = None) -> tf.SparseTensor: - """Maps the terms in x to their inverse document frequency in the same order. - - The inverse document frequency of a term, by default, is calculated as - 1 + log ((corpus size + 1) / (count of documents containing term + 1)). - - Example usage: - - >>> def preprocessing_fn(inputs): - ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) - ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) - ... idf_weights = tft.experimental.idf(integerized, vocab_size) - ... return { - ... 'idf': idf_weights, - ... 'integerized': integerized, - ... } - >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), - ... dict(x=["yum", "yum", "pie"])] - >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> # 1 + log(3/2) = 1.4054651 - >>> transformed_data - [{'idf': array([1.4054651, 1.4054651, 1., 1., 1.], dtype=float32), - 'integerized': array([3, 2, 0, 0, 0])}, - {'idf': array([1.4054651, 1.4054651, 1.], dtype=float32), - 'integerized': array([1, 1, 0])}] - - ``` - example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] - in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2]], - values=[1, 2, 0, 0, 0, 3, 3, 0]) - out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2]], - values=[1 + log(3/2), 1 + log(3/2), 1, 1, 1, - 1 + log(3/2), 1 + log(3/2), 1]) - ``` - - Args: - x: A 2D `SparseTensor` representing int64 values (most likely that are the - result of calling `compute_and_apply_vocabulary` on a tokenized string). - vocab_size: An int - the count of vocab used to turn the string into int64s - including any OOV buckets. - smooth: A bool indicating if the inverse document frequency should be - smoothed. If True, which is the default, then the idf is calculated as 1 + - log((corpus size + 1) / (document frequency of term + 1)). Otherwise, the - idf is 1 + log((corpus size) / (document frequency of term)), which could - result in a division by zero error. - add_baseline: A bool indicating if the inverse document frequency should be - added with a constant baseline 1.0. If True, which is the default, then - the idf is calculated as 1 + log(*). Otherwise, the idf is log(*) without - the constant 1 baseline. Keeping the baseline reduces the discrepancy in - idf between commonly seen terms and rare terms. - name: (Optional) A name for this operation. - - Returns: - `SparseTensor`s with indices [index_in_batch, index_in_local_sequence] and - values inverse document frequency. Same shape as the input `x`. 
- - Raises: - ValueError if `x` does not have 2 dimensions. - """ - if x.get_shape().ndims != 2: - raise ValueError('tft.tfidf requires a 2D SparseTensor input. ' - 'Input had {} dimensions.'.format(x.get_shape().ndims)) - - with tf.compat.v1.name_scope(name, 'idf'): - cleaned_input = tf_utils.to_vocab_range(x, vocab_size) - - batch_sizes = tf.expand_dims(tf.shape(input=cleaned_input)[0], 0) - - # all_df is a (1, vocab_size)-shaped tensor storing number of documents - # containing each term in the entire dataset. - all_df = _to_global_document_frequency(cleaned_input, vocab_size) - - # all_idf is a (1, vocab_size)-shaped tensor storing the inverse document - # frequency of each term in the entire dataset. - all_idf = tf_utils.document_frequency_to_idf( - all_df, - analyzers.sum(batch_sizes), - smooth=smooth, - add_baseline=add_baseline) - - # idf_values is a batch_size * sequence_size sparse tensor storing the - # inverse document frequency of each term, following the same order as the - # terms within each document. - idf_values = tf.gather( - tf.reshape(all_idf, [-1]), tf.cast(cleaned_input.values, dtype=tf.int64) +def idf( + x: tf.SparseTensor, + vocab_size: int, + smooth: bool = True, + add_baseline: bool = True, + name: Optional[str] = None, +) -> tf.SparseTensor: + """Maps the terms in x to their inverse document frequency in the same order. + + The inverse document frequency of a term, by default, is calculated as + 1 + log ((corpus size + 1) / (count of documents containing term + 1)). + + Example usage: + + >>> def preprocessing_fn(inputs): + ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) + ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) + ... idf_weights = tft.experimental.idf(integerized, vocab_size) + ... return { + ... 'idf': idf_weights, + ... 'integerized': integerized, + ... } + >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), + ... dict(x=["yum", "yum", "pie"])] + >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> # 1 + log(3/2) = 1.4054651 + >>> transformed_data + [{'idf': array([1.4054651, 1.4054651, 1., 1., 1.], dtype=float32), + 'integerized': array([3, 2, 0, 0, 0])}, + {'idf': array([1.4054651, 1.4054651, 1.], dtype=float32), + 'integerized': array([1, 1, 0])}] + + ``` + example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] + in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1, 2, 0, 0, 0, 3, 3, 0]) + out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1 + log(3/2), 1 + log(3/2), 1, 1, 1, + 1 + log(3/2), 1 + log(3/2), 1]) + ``` + + Args: + ---- + x: A 2D `SparseTensor` representing int64 values (most likely that are the + result of calling `compute_and_apply_vocabulary` on a tokenized string). + vocab_size: An int - the count of vocab used to turn the string into int64s + including any OOV buckets. + smooth: A bool indicating if the inverse document frequency should be + smoothed. If True, which is the default, then the idf is calculated as 1 + + log((corpus size + 1) / (document frequency of term + 1)). 
Otherwise, the + idf is 1 + log((corpus size) / (document frequency of term)), which could + result in a division by zero error. + add_baseline: A bool indicating if the inverse document frequency should be + added with a constant baseline 1.0. If True, which is the default, then + the idf is calculated as 1 + log(*). Otherwise, the idf is log(*) without + the constant 1 baseline. Keeping the baseline reduces the discrepancy in + idf between commonly seen terms and rare terms. + name: (Optional) A name for this operation. + + Returns: + ------- + `SparseTensor`s with indices [index_in_batch, index_in_local_sequence] and + values inverse document frequency. Same shape as the input `x`. + + Raises: + ------ + ValueError if `x` does not have 2 dimensions. + """ + if x.get_shape().ndims != 2: + raise ValueError( + "tft.tfidf requires a 2D SparseTensor input. " + f"Input had {x.get_shape().ndims} dimensions." + ) + + with tf.compat.v1.name_scope(name, "idf"): + cleaned_input = tf_utils.to_vocab_range(x, vocab_size) + + batch_sizes = tf.expand_dims(tf.shape(input=cleaned_input)[0], 0) + + # all_df is a (1, vocab_size)-shaped tensor storing number of documents + # containing each term in the entire dataset. + all_df = _to_global_document_frequency(cleaned_input, vocab_size) + + # all_idf is a (1, vocab_size)-shaped tensor storing the inverse document + # frequency of each term in the entire dataset. + all_idf = tf_utils.document_frequency_to_idf( + all_df, analyzers.sum(batch_sizes), smooth=smooth, add_baseline=add_baseline + ) + + # idf_values is a batch_size * sequence_size sparse tensor storing the + # inverse document frequency of each term, following the same order as the + # terms within each document. + idf_values = tf.gather( + tf.reshape(all_idf, [-1]), tf.cast(cleaned_input.values, dtype=tf.int64) + ) + + return tf.SparseTensor( + indices=cleaned_input.indices, + values=idf_values, + dense_shape=cleaned_input.dense_shape, + ) + + +def _to_term_document_one_hot( + x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor] +) -> tf.SparseTensor: + """Creates a one-hot SparseTensor of term existence for every doc/term pair. + + Converts a -indexed, -valued + sparse tensor to one-hot tensor to represent the + existence of each vocab term in each document of a batch. For example, when x + has the dense form: + [[3, 2, 3], # first example of the batch has vocab term 2 and 3 + [1, 1]], # second example of the batch has vocab term 1 + with vocab_size=4, the dense form of the out one-hot tensor is + [[0, 0, 1, 1], + [0, 1, 0, 0]] + + Args: + ---- + x: a SparseTensor of int64 representing string indices in vocab. The indices + are and the values are . + Typically, x is the output of tft.compute_and_apply_vocabulary. + vocab_size: A scalar int64 Tensor - the count of vocab used to turn the + string into int64s including any OOV buckets. + + Returns: + ------- + a SparseTensor with size (batch_size, vocab_size), indices being + , and int32 values being 1 for + all mentioned terms or 0 if not shown in each document. + """ + vocab_size = tf.convert_to_tensor(value=vocab_size, dtype=tf.int64) + + # Combine batch indices (first column of x's indices) and vocab indices ( + # x's values) as new indices (, ). + batch_indices = x.indices[:, 0] # sparse tensor indices are int64 + vocab_indices = tf.cast(x.values, dtype=tf.int64) + + # Dedup (, ) pairs. This is because document + # frequency only cares the existence of a term in a document, not the + # occurrence frequency within that document. 
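+    # For example (illustrative values): with vocab_size=4 the multiplier below
+    # is 5, so the pair (batch=1, vocab=3) is flattened to 1 * 5 + 3 = 8 and
+    # recovered after deduplication as 8 // 5 = 1 and 8 % 5 = 3.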
+ # Hashing (, ) pairs for dedup. + multiplier = vocab_size + 1 + unique_flatten_indices, _ = tf.raw_ops.UniqueV2( + x=batch_indices * multiplier + vocab_indices, axis=[0] + ) + unique_batch_indices = tf.cast( + tf.math.divide(unique_flatten_indices, multiplier), dtype=tf.int64 + ) + unique_vocab_indices = tf.math.mod(unique_flatten_indices, multiplier) + unique_batch_vocab_indices = tf.transpose( + tf.stack([unique_batch_indices, unique_vocab_indices]) ) - return tf.SparseTensor( - indices=cleaned_input.indices, - values=idf_values, - dense_shape=cleaned_input.dense_shape) + # If term i shows at least once in document j, then doc_freq = 1 + one_hot_values = tf.ones_like(unique_flatten_indices, dtype=tf.int32) + # New shape of the one hot tensor is batch_size * vocab_size + new_shape = tf.stack([x.dense_shape[0], vocab_size]) + new_shape.set_shape([2]) -def _to_term_document_one_hot( - x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor]) -> tf.SparseTensor: - """Creates a one-hot SparseTensor of term existence for every doc/term pair. - - Converts a -indexed, -valued - sparse tensor to one-hot tensor to represent the - existence of each vocab term in each document of a batch. For example, when x - has the dense form: - [[3, 2, 3], # first example of the batch has vocab term 2 and 3 - [1, 1]], # second example of the batch has vocab term 1 - with vocab_size=4, the dense form of the out one-hot tensor is - [[0, 0, 1, 1], - [0, 1, 0, 0]] - - Args: - x: a SparseTensor of int64 representing string indices in vocab. The indices - are and the values are . - Typically, x is the output of tft.compute_and_apply_vocabulary. - vocab_size: A scalar int64 Tensor - the count of vocab used to turn the - string into int64s including any OOV buckets. - - Returns: - a SparseTensor with size (batch_size, vocab_size), indices being - , and int32 values being 1 for - all mentioned terms or 0 if not shown in each document. - """ - vocab_size = tf.convert_to_tensor(value=vocab_size, dtype=tf.int64) - - # Combine batch indices (first column of x's indices) and vocab indices ( - # x's values) as new indices (, ). - batch_indices = x.indices[:, 0] # sparse tensor indices are int64 - vocab_indices = tf.cast(x.values, dtype=tf.int64) - - # Dedup (, ) pairs. This is because document - # frequency only cares the existence of a term in a document, not the - # occurrence frequency within that document. - # Hashing (, ) pairs for dedup. 
- multiplier = vocab_size + 1 - unique_flatten_indices, _ = tf.raw_ops.UniqueV2( - x=batch_indices * multiplier + vocab_indices, axis=[0]) - unique_batch_indices = tf.cast( - tf.math.divide(unique_flatten_indices, multiplier), dtype=tf.int64) - unique_vocab_indices = tf.math.mod(unique_flatten_indices, multiplier) - unique_batch_vocab_indices = tf.transpose( - tf.stack([unique_batch_indices, unique_vocab_indices])) - - # If term i shows at least once in document j, then doc_freq = 1 - one_hot_values = tf.ones_like(unique_flatten_indices, dtype=tf.int32) - - # New shape of the one hot tensor is batch_size * vocab_size - new_shape = tf.stack([x.dense_shape[0], vocab_size]) - new_shape.set_shape([2]) - - return tf.SparseTensor( - indices=unique_batch_vocab_indices, - values=one_hot_values, - dense_shape=new_shape) + return tf.SparseTensor( + indices=unique_batch_vocab_indices, values=one_hot_values, dense_shape=new_shape + ) def _to_global_document_frequency( - x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor]) -> tf.Tensor: - """Summerizes term/doc one-hot tensor to get document frequency of each term. - - Args: - x: a SparseTensor of size (batch_size, vocab_size) and values 0/1 to - indicate to existence of each term in each document. x is expected to be - the output of _to_term_document_one_hot. - vocab_size: A scalar int64 Tensor - the count of vocab used to turn the - string into int64s including any OOV buckets. - - Returns: - a tensor with indices as (1, ) and values as the count of - documents in the entire dataset that contain each vocab term. - """ - # term_doc_freq is the one-hot encoding of term existence in each document. - # It is a (batch_size, vocab_size)-shaped, 0/1 valued sparse tensor. - term_doc_one_hot = _to_term_document_one_hot(x, vocab_size) - - # Reduce sum the one-hot tensor within each mini batch to get one - # (1, vocab_size)-shaped sparse tensor for each mini batch, with the value - # being the count of documents containing each term in that batch. - count_docs_with_term = tf.sparse.reduce_sum( - term_doc_one_hot, axis=0, keepdims=True) - - # Sum up all batches to get a (1, vocab_size)-shaped sparse tensor storing - # count of documents containing each term across the entire dataset. - return analyzers.sum(count_docs_with_term, reduce_instance_dims=False) + x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor] +) -> tf.Tensor: + """Summerizes term/doc one-hot tensor to get document frequency of each term. + + Args: + ---- + x: a SparseTensor of size (batch_size, vocab_size) and values 0/1 to + indicate to existence of each term in each document. x is expected to be + the output of _to_term_document_one_hot. + vocab_size: A scalar int64 Tensor - the count of vocab used to turn the + string into int64s including any OOV buckets. + + Returns: + ------- + a tensor with indices as (1, ) and values as the count of + documents in the entire dataset that contain each vocab term. + """ + # term_doc_freq is the one-hot encoding of term existence in each document. + # It is a (batch_size, vocab_size)-shaped, 0/1 valued sparse tensor. + term_doc_one_hot = _to_term_document_one_hot(x, vocab_size) + + # Reduce sum the one-hot tensor within each mini batch to get one + # (1, vocab_size)-shaped sparse tensor for each mini batch, with the value + # being the count of documents containing each term in that batch. 
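+    # For the docstring example of _to_term_document_one_hot, this per-batch sum
+    # is the dense row [[0, 1, 1, 1]] (the column-wise sums of
+    # [[0, 0, 1, 1], [0, 1, 0, 0]]).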
+ count_docs_with_term = tf.sparse.reduce_sum(term_doc_one_hot, axis=0, keepdims=True) + + # Sum up all batches to get a (1, vocab_size)-shaped sparse tensor storing + # count of documents containing each term across the entire dataset. + return analyzers.sum(count_docs_with_term, reduce_instance_dims=False) diff --git a/tensorflow_transform/gaussianization.py b/tensorflow_transform/gaussianization.py index 320acb6..a07b3dc 100644 --- a/tensorflow_transform/gaussianization.py +++ b/tensorflow_transform/gaussianization.py @@ -24,330 +24,387 @@ def tukey_hh_l_mean_and_scale(h_params): - """Computes L-mean and L-scale for a Tukey HH distribution. - - Args: - h_params: An np.array with dimension 2 on the first axis. The slice - h_params[0, ...] contains the left parameter of the distribution and - h_params[1, ...] the right parameter. Each entry h must in 0 <= h < 1. - - Returns: - The tuple (L_mean, L_scale) containing the first two L-moments for the - given parameters. Each entry has the same shape as h_params, except for - the first axis, which is removed. - """ - one_div_sqrt2pi = 1.0 / np.sqrt(2.0 * np.pi) - hl = h_params[0, ...] - hr = h_params[1, ...] - dtype = h_params.dtype - l_1 = one_div_sqrt2pi * (1.0 / (hl - 1.0) + 1.0 / (1.0 - hr)) - l_2 = one_div_sqrt2pi * ( - (np.sqrt(2.0 - hl) + np.sqrt(2.0 - hr) - hl * np.sqrt(2.0 - hl) - - hr * np.sqrt(2 - hr)) / - ((hl - 1.0) * (hr - 1.0) * np.sqrt((hl - 2.0) * (hr - 2.0)))) - return (l_1.astype(dtype), l_2.astype(dtype)) + """Computes L-mean and L-scale for a Tukey HH distribution. + + Args: + ---- + h_params: An np.array with dimension 2 on the first axis. The slice + h_params[0, ...] contains the left parameter of the distribution and + h_params[1, ...] the right parameter. Each entry h must in 0 <= h < 1. + + Returns: + ------- + The tuple (L_mean, L_scale) containing the first two L-moments for the + given parameters. Each entry has the same shape as h_params, except for + the first axis, which is removed. + """ + one_div_sqrt2pi = 1.0 / np.sqrt(2.0 * np.pi) + hl = h_params[0, ...] + hr = h_params[1, ...] + dtype = h_params.dtype + l_1 = one_div_sqrt2pi * (1.0 / (hl - 1.0) + 1.0 / (1.0 - hr)) + l_2 = one_div_sqrt2pi * ( + ( + np.sqrt(2.0 - hl) + + np.sqrt(2.0 - hr) + - hl * np.sqrt(2.0 - hl) + - hr * np.sqrt(2 - hr) + ) + / ((hl - 1.0) * (hr - 1.0) * np.sqrt((hl - 2.0) * (hr - 2.0))) + ) + return (l_1.astype(dtype), l_2.astype(dtype)) def _tukey_hh_l_skewness_and_kurtosis(h_params): - """Computes L-skewness and L-kurtosis for a Tukey HH distribution. - - Args: - h_params: An np.array with dimension 2 on the first axis. The slice - h_params[0, ...] contains the left parameter of the distribution and - h_params[1, ...] the right parameter. - - Returns: - The tuple (L_skewness, L_kurtosis) for the given parameters. Each entry - has the same shape as h_params, except for the first axis, which is - removed. - """ - def skewness_num(h1, h2): - return (12 * np.sqrt(2.0 - h1) * (h2 - 2.0) * (h2 - 1.0) * - np.arctan(1.0 / np.sqrt(2.0 - h1))) - - def skewness_den(h): - return h * np.sqrt(2 - h) - np.sqrt(2 - h) - - def kurtosis_den_part(h): - return h * np.sqrt(2.0 - h) - np.sqrt(2.0 - h) - - hl = h_params[0, ...] - hr = h_params[1, ...] 
- dtype = h_params.dtype - skewness = (skewness_num(hl, hr) - - np.pi * (hl - hr) * (hl - 2.0) * (hr - 2.0) - - skewness_num(hr, hl)) / ( - 2 * np.pi * np.sqrt((hl - 2.0) * (hr - 2.0)) * - (skewness_den(hl) + skewness_den(hr))) - kurtosis_num_1 = ( - hr * np.sqrt((hl - 4.0) * (hl - 2.0) * (hl - 1.0) * (hr - 2.0)) - - 2.0 * np.sqrt((hl - 4.0) * (hl - 1.0))) - kurtosis_num_2 = (hl * (hl - 3.0) * np.sqrt((hl - 4.0) * (hl - 1.0)) + - np.sqrt((hl - 4.0) * (hl - 2.0) * (hl - 1.0) * (hr - 2.0))) - kurtosis_num_3 = (30.0 * (hl - 1.0) * - np.sqrt((hl - 4.0) * (hl - 2.0) * (hr - 2.0) / (hl - 1.0)) * - (hr - 1.0) * np.arctan(np.sqrt(1.0 + 2.0 / (hl - 4.0)))) - kurtosis_num_4 = (30.0 * (hl - 2) * - np.sqrt((hl - 4.0) * (hl - 1.0)) * (hl - 1.0) * - np.arctan(np.sqrt(1.0 + 2.0 / (hr - 4.0)))) - kurtosis_den = (np.pi * np.sqrt((4.0 - hl) * (2.0 - hl) * (1.0 - hl)) * - (kurtosis_den_part(hl) + kurtosis_den_part(hr))) - kurtosis = (6.0 * np.pi * (kurtosis_num_1 - kurtosis_num_2) + - kurtosis_num_3 + kurtosis_num_4) / kurtosis_den - return (skewness.astype(dtype), kurtosis.astype(dtype)) + """Computes L-skewness and L-kurtosis for a Tukey HH distribution. + + Args: + ---- + h_params: An np.array with dimension 2 on the first axis. The slice + h_params[0, ...] contains the left parameter of the distribution and + h_params[1, ...] the right parameter. + + Returns: + ------- + The tuple (L_skewness, L_kurtosis) for the given parameters. Each entry + has the same shape as h_params, except for the first axis, which is + removed. + """ + + def skewness_num(h1, h2): + return ( + 12 + * np.sqrt(2.0 - h1) + * (h2 - 2.0) + * (h2 - 1.0) + * np.arctan(1.0 / np.sqrt(2.0 - h1)) + ) + + def skewness_den(h): + return h * np.sqrt(2 - h) - np.sqrt(2 - h) + + def kurtosis_den_part(h): + return h * np.sqrt(2.0 - h) - np.sqrt(2.0 - h) + + hl = h_params[0, ...] + hr = h_params[1, ...] + dtype = h_params.dtype + skewness = ( + skewness_num(hl, hr) + - np.pi * (hl - hr) * (hl - 2.0) * (hr - 2.0) + - skewness_num(hr, hl) + ) / ( + 2 + * np.pi + * np.sqrt((hl - 2.0) * (hr - 2.0)) + * (skewness_den(hl) + skewness_den(hr)) + ) + kurtosis_num_1 = hr * np.sqrt( + (hl - 4.0) * (hl - 2.0) * (hl - 1.0) * (hr - 2.0) + ) - 2.0 * np.sqrt((hl - 4.0) * (hl - 1.0)) + kurtosis_num_2 = hl * (hl - 3.0) * np.sqrt((hl - 4.0) * (hl - 1.0)) + np.sqrt( + (hl - 4.0) * (hl - 2.0) * (hl - 1.0) * (hr - 2.0) + ) + kurtosis_num_3 = ( + 30.0 + * (hl - 1.0) + * np.sqrt((hl - 4.0) * (hl - 2.0) * (hr - 2.0) / (hl - 1.0)) + * (hr - 1.0) + * np.arctan(np.sqrt(1.0 + 2.0 / (hl - 4.0))) + ) + kurtosis_num_4 = ( + 30.0 + * (hl - 2) + * np.sqrt((hl - 4.0) * (hl - 1.0)) + * (hl - 1.0) + * np.arctan(np.sqrt(1.0 + 2.0 / (hr - 4.0))) + ) + kurtosis_den = ( + np.pi + * np.sqrt((4.0 - hl) * (2.0 - hl) * (1.0 - hl)) + * (kurtosis_den_part(hl) + kurtosis_den_part(hr)) + ) + kurtosis = ( + 6.0 * np.pi * (kurtosis_num_1 - kurtosis_num_2) + + kurtosis_num_3 + + kurtosis_num_4 + ) / kurtosis_den + return (skewness.astype(dtype), kurtosis.astype(dtype)) def _binary_search(error_fn, low_value, high_value): - """Binary search for a function given start and end interval. - - This is a simple binary search over the values of the function error_fn given - the interval [low_value, high_value]. We expect that the starting condition is - error_fn(low_value) < 0 and error_fn(high_value) > 0 and we bisect the - interval until the exit conditions are met. 
The result is the final interval - [low_value, high_value] that is normally much smaller than the initial one, - but still satisfying the starting condition. - - Args: - error_fn: Function mapping values to errors. - low_value: Lower interval endpoint. We expect f(low_value) < 0. - high_value: Higher interval endpoint. We expect f(high_value) > 0. - - Returns: - The final interval endpoints (low_value, high_value) after the sequence of - bisections. - """ - # Exit conditions. - stop_iter_step = 10 # Max number of iterations. - stop_error_step = 1e-6 # Minimum function variation. - stop_value_step = 1e-6 # Minimum variable variation. - - current_iter = 0 - while True: - current_value = (low_value + high_value) / 2.0 - current_error = error_fn(current_value) - if current_error < 0.0: - low_value = current_value - else: - high_value = current_value - current_iter += 1 - if (current_iter > stop_iter_step or - np.abs(current_error) < stop_error_step or - high_value - low_value < stop_value_step): - break - return low_value, high_value + """Binary search for a function given start and end interval. + + This is a simple binary search over the values of the function error_fn given + the interval [low_value, high_value]. We expect that the starting condition is + error_fn(low_value) < 0 and error_fn(high_value) > 0 and we bisect the + interval until the exit conditions are met. The result is the final interval + [low_value, high_value] that is normally much smaller than the initial one, + but still satisfying the starting condition. + + Args: + ---- + error_fn: Function mapping values to errors. + low_value: Lower interval endpoint. We expect f(low_value) < 0. + high_value: Higher interval endpoint. We expect f(high_value) > 0. + + Returns: + ------- + The final interval endpoints (low_value, high_value) after the sequence of + bisections. + """ + # Exit conditions. + stop_iter_step = 10 # Max number of iterations. + stop_error_step = 1e-6 # Minimum function variation. + stop_value_step = 1e-6 # Minimum variable variation. + + current_iter = 0 + while True: + current_value = (low_value + high_value) / 2.0 + current_error = error_fn(current_value) + if current_error < 0.0: + low_value = current_value + else: + high_value = current_value + current_iter += 1 + if ( + current_iter > stop_iter_step + or np.abs(current_error) < stop_error_step + or high_value - low_value < stop_value_step + ): + break + return low_value, high_value def _params_to_errors(h, delta_h, l_skewness_and_kurtosis): - """Maps parameters to errors on L-skewness and L-kurtosis. - - Args: - h: Value of right parameter of the Tukey HH distribution. - delta_h: Different between right and left parameter of the Tukey HH - distribution. - l_skewness_and_kurtosis: np.array containing the target values of - L-skewness and L-kurtosis. - - Returns: - An np.array containing the difference between the values of L-skewness and - L-kurtosis corresponding to the parameters hl = h - delta_h, hr =h and the - target values. - """ - dtype = l_skewness_and_kurtosis.dtype - h_params = np.array([h - delta_h, h], dtype=dtype) - current_l_skewness_and_kurtosis = np.array( - _tukey_hh_l_skewness_and_kurtosis(h_params), dtype=dtype) - return current_l_skewness_and_kurtosis - l_skewness_and_kurtosis + """Maps parameters to errors on L-skewness and L-kurtosis. + + Args: + ---- + h: Value of right parameter of the Tukey HH distribution. + delta_h: Different between right and left parameter of the Tukey HH + distribution. 
+ l_skewness_and_kurtosis: np.array containing the target values of + L-skewness and L-kurtosis. + + Returns: + ------- + An np.array containing the difference between the values of L-skewness and + L-kurtosis corresponding to the parameters hl = h - delta_h, hr =h and the + target values. + """ + dtype = l_skewness_and_kurtosis.dtype + h_params = np.array([h - delta_h, h], dtype=dtype) + current_l_skewness_and_kurtosis = np.array( + _tukey_hh_l_skewness_and_kurtosis(h_params), dtype=dtype + ) + return current_l_skewness_and_kurtosis - l_skewness_and_kurtosis def compute_tukey_hh_params(l_skewness_and_kurtosis): - """Computes the H paramesters of a Tukey HH distribution. - - Given the L-skewness and L-kurtosis of a Tukey HH distribution we compute - the H parameters of the distribution. - - Args: - l_skewness_and_kurtosis: A np.array with shape (2,) containing L-skewness - and L-kurtosis. - - Returns: - An np.array with the same type and shape of the argument containing the - left and right H parameters of the distribution. - """ - - # Exit conditions for the search loop. - stop_iter_step = 20 # Max number of iteration for the search loop. - stop_error_step = 1e-6 # Minimum function variation. - stop_value_step = 1e-6 # Minimum variable variation. - - dtype = l_skewness_and_kurtosis.dtype - - # Returns zero parameters (i.e. treat as gaussian) if L-kurtosis is smaller - # than for a gaussian. - - result = np.zeros_like(l_skewness_and_kurtosis) - if l_skewness_and_kurtosis[1] < 0.1226017: + """Computes the H paramesters of a Tukey HH distribution. + + Given the L-skewness and L-kurtosis of a Tukey HH distribution we compute + the H parameters of the distribution. + + Args: + ---- + l_skewness_and_kurtosis: A np.array with shape (2,) containing L-skewness + and L-kurtosis. + + Returns: + ------- + An np.array with the same type and shape of the argument containing the + left and right H parameters of the distribution. + """ + # Exit conditions for the search loop. + stop_iter_step = 20 # Max number of iteration for the search loop. + stop_error_step = 1e-6 # Minimum function variation. + stop_value_step = 1e-6 # Minimum variable variation. + + dtype = l_skewness_and_kurtosis.dtype + + # Returns zero parameters (i.e. treat as gaussian) if L-kurtosis is smaller + # than for a gaussian. + + result = np.zeros_like(l_skewness_and_kurtosis) + if l_skewness_and_kurtosis[1] < 0.1226017: + return result + + # If L-skewness is negative, swap the parameters. + + swap_params = False + if l_skewness_and_kurtosis[0] < 0.0: + l_skewness_and_kurtosis[0] = -l_skewness_and_kurtosis[0] + swap_params = True + + l_skewness_and_kurtosis[1] = np.minimum(l_skewness_and_kurtosis[1], 1.0 - 1.0e-5) + + # If L-skewness is zero, left and right parameters are equal and there is a + # a closed form to compute them from L-kurtosis. We start from this value + # and then change them to match simultaneously L-skeweness and L-kurtosis. + # For that, we parametrize the search space with the array + # [h_rigth, h_right - h_left], i.e. the value of the right parameter and the + # difference right minus left paramerters. In the search iteration, we + # alternate between updates on the first and the second entry of the search + # parameters. + + initial_h = 3.0 - 1.0 / np.cos(np.pi / 15.0 * (l_skewness_and_kurtosis[1] - 6.0)) + search_params = np.array([initial_h, 0.0], dtype=dtype) + + # Current lower and upper bounds for the search parameters. 
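+    # h_right is searched within [initial_h, 1) and the left/right difference
+    # within [0, initial_h]; the alternating searches below tighten these bounds.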
+ + min_search_params = np.array([initial_h, 0.0], dtype=dtype) + max_search_params = np.array([1.0 - 1.0e-7, initial_h], dtype=dtype) + + current_iter = 0 + previous_search_params = np.zeros_like(search_params) + while current_iter < stop_iter_step: + # Search for L-skewness at constant h. Increase delta_h. + error_skewness = lambda x: _params_to_errors( # pylint: disable=g-long-lambda + search_params[0], x, l_skewness_and_kurtosis + )[0] + if error_skewness(max_search_params[1]) > 0.0: + low_delta_h, high_delta_h = _binary_search( + error_skewness, min_search_params[1], max_search_params[1] + ) + search_params[1] = high_delta_h + max_search_params[1] = high_delta_h # The new delta is an upperbound. + upperbound_delta_found = True + else: + search_params[1] = max_search_params[1] + min_search_params[1] = max_search_params[1] # No solution: lowerbound. + upperbound_delta_found = False + + # Search for L-kurtosis at constant possibly overestimated delta. + error_kurtosis = lambda x: _params_to_errors( # pylint: disable=g-long-lambda + x, search_params[1], l_skewness_and_kurtosis + )[1] + low_h, high_h = _binary_search( + error_kurtosis, min_search_params[0], max_search_params[0] + ) + if upperbound_delta_found: + search_params[0] = high_h + max_search_params[0] = high_h # Delta overestimated: upperbound for h. + else: + search_params[0] = low_h + min_search_params[0] = low_h # Delta underestimated: lowerbound for h. + max_search_params[1] = low_h # Delta not found, search on full range. + + if upperbound_delta_found: # If not found, we repeat the first 2 steps. + # Otherwise, Search for delta at constant overestimated h. + error_skewness = lambda x: _params_to_errors( # pylint: disable=g-long-lambda + search_params[0], x, l_skewness_and_kurtosis + )[0] + low_delta_h, high_delta_h = _binary_search( + error_skewness, min_search_params[1], max_search_params[1] + ) + search_params[1] = low_delta_h + min_search_params[1] = low_delta_h + + # Search for h at constant delta. + error_kurtosis = lambda x: _params_to_errors( # pylint: disable=g-long-lambda + x, search_params[1], l_skewness_and_kurtosis + )[1] + low_h, high_h = _binary_search( + error_kurtosis, min_search_params[0], max_search_params[0] + ) + search_params[0] = low_h + min_search_params[0] = low_h + + current_error = _params_to_errors( + search_params[0], search_params[1], l_skewness_and_kurtosis + ) + delta_search_params = search_params - previous_search_params + current_iter += 1 + previous_search_params = search_params.copy() + if np.all(np.abs(current_error) < stop_error_step) or np.all( + np.abs(delta_search_params) < stop_value_step + ): + break + + result[0] = search_params[0] - search_params[1] + result[1] = search_params[0] + if swap_params: + result = result[::-1] return result - # If L-skewness is negative, swap the parameters. - - swap_params = False - if l_skewness_and_kurtosis[0] < 0.0: - l_skewness_and_kurtosis[0] = -l_skewness_and_kurtosis[0] - swap_params = True - - l_skewness_and_kurtosis[1] = np.minimum( - l_skewness_and_kurtosis[1], 1.0 - 1.0e-5) - - # If L-skewness is zero, left and right parameters are equal and there is a - # a closed form to compute them from L-kurtosis. We start from this value - # and then change them to match simultaneously L-skeweness and L-kurtosis. - # For that, we parametrize the search space with the array - # [h_rigth, h_right - h_left], i.e. the value of the right parameter and the - # difference right minus left paramerters. 
In the search iteration, we - # alternate between updates on the first and the second entry of the search - # parameters. - - initial_h = 3.0 - 1.0 / np.cos( - np.pi / 15.0 * (l_skewness_and_kurtosis[1] - 6.0)) - search_params = np.array([initial_h, 0.0], dtype=dtype) - - # Current lower and upper bounds for the search parameters. - - min_search_params = np.array([initial_h, 0.0], dtype=dtype) - max_search_params = np.array([1.0 - 1.0e-7, initial_h], dtype=dtype) - - current_iter = 0 - previous_search_params = np.zeros_like(search_params) - while current_iter < stop_iter_step: - # Search for L-skewness at constant h. Increase delta_h. - error_skewness = lambda x: _params_to_errors( # pylint: disable=g-long-lambda - search_params[0], x, l_skewness_and_kurtosis)[0] - if error_skewness(max_search_params[1]) > 0.0: - low_delta_h, high_delta_h = _binary_search( - error_skewness, min_search_params[1], max_search_params[1]) - search_params[1] = high_delta_h - max_search_params[1] = high_delta_h # The new delta is an upperbound. - upperbound_delta_found = True - else: - search_params[1] = max_search_params[1] - min_search_params[1] = max_search_params[1] # No solution: lowerbound. - upperbound_delta_found = False - - # Search for L-kurtosis at constant possibly overestimated delta. - error_kurtosis = lambda x: _params_to_errors( # pylint: disable=g-long-lambda - x, search_params[1], l_skewness_and_kurtosis)[1] - low_h, high_h = _binary_search( - error_kurtosis, min_search_params[0], max_search_params[0]) - if upperbound_delta_found: - search_params[0] = high_h - max_search_params[0] = high_h # Delta overestimated: upperbound for h. - else: - search_params[0] = low_h - min_search_params[0] = low_h # Delta underestimated: lowerbound for h. - max_search_params[1] = low_h # Delta not found, search on full range. - - if upperbound_delta_found: # If not found, we repeat the first 2 steps. - # Otherwise, Search for delta at constant overestimated h. - error_skewness = lambda x: _params_to_errors( # pylint: disable=g-long-lambda - search_params[0], x, l_skewness_and_kurtosis)[0] - low_delta_h, high_delta_h = _binary_search( - error_skewness, min_search_params[1], max_search_params[1]) - search_params[1] = low_delta_h - min_search_params[1] = low_delta_h - - # Search for h at constant delta. - error_kurtosis = lambda x: _params_to_errors( # pylint: disable=g-long-lambda - x, search_params[1], l_skewness_and_kurtosis)[1] - low_h, high_h = _binary_search( - error_kurtosis, min_search_params[0], max_search_params[0]) - search_params[0] = low_h - min_search_params[0] = low_h - - current_error = _params_to_errors( - search_params[0], search_params[1], l_skewness_and_kurtosis) - delta_search_params = search_params - previous_search_params - current_iter += 1 - previous_search_params = search_params.copy() - if (np.all(np.abs(current_error) < stop_error_step) or - np.all(np.abs(delta_search_params) < stop_value_step)): - break - - result[0] = search_params[0] - search_params[1] - result[1] = search_params[0] - if swap_params: - result = result[::-1] - return result - def lambert_w(x): - """Computes the Lambert W function of a `Tensor`. - - Computes the principal branch of the Lambert W function, i.e. the value w such - that w * exp(w) = x for a a given x. For the principal branch, x must be real - x >= -1 / e, and w >= -1. - - Args: - x: A `Tensor` containing the values for which the principal branch of - the Lambert W function is computed. 
- - Returns: - A `Tensor` with the same shape and dtype as x containing the value of the - Lambert W function. - """ - dtype = x.dtype - e = tf.constant(np.exp(1.0), dtype) - inv_e = tf.constant(np.exp(-1.0), dtype) - s = (np.exp(1) - 1.0) / (np.exp(2) - 1.0) - slope = tf.constant(s, dtype) - c = tf.constant(1 / np.exp(1) * (1 - s), dtype) - log_s = tf.math.log(x) - w_init = tf.where( - x < inv_e, - x, - tf.where(x < e, - slope * x + c, - (log_s + (1.0 / log_s - 1.0) * tf.math.log(log_s)))) - - def newton_update(count, w): - expw = tf.math.exp(w) - wexpw = w * expw - return count + 1, w - (wexpw - x) / (expw + wexpw) - - count = tf.constant(0, tf.int32) - num_iter = tf.constant(8) - (unused_final_count, w) = tf.while_loop( - lambda count, w: tf.less(count, num_iter), - newton_update, - [count, w_init]) - return w + """Computes the Lambert W function of a `Tensor`. + + Computes the principal branch of the Lambert W function, i.e. the value w such + that w * exp(w) = x for a a given x. For the principal branch, x must be real + x >= -1 / e, and w >= -1. + + Args: + ---- + x: A `Tensor` containing the values for which the principal branch of + the Lambert W function is computed. + + Returns: + ------- + A `Tensor` with the same shape and dtype as x containing the value of the + Lambert W function. + """ + dtype = x.dtype + e = tf.constant(np.exp(1.0), dtype) + inv_e = tf.constant(np.exp(-1.0), dtype) + s = (np.exp(1) - 1.0) / (np.exp(2) - 1.0) + slope = tf.constant(s, dtype) + c = tf.constant(1 / np.exp(1) * (1 - s), dtype) + log_s = tf.math.log(x) + w_init = tf.where( + x < inv_e, + x, + tf.where( + x < e, slope * x + c, (log_s + (1.0 / log_s - 1.0) * tf.math.log(log_s)) + ), + ) + + def newton_update(count, w): + expw = tf.math.exp(w) + wexpw = w * expw + return count + 1, w - (wexpw - x) / (expw + wexpw) + + count = tf.constant(0, tf.int32) + num_iter = tf.constant(8) + (unused_final_count, w) = tf.while_loop( + lambda count, w: tf.less(count, num_iter), newton_update, [count, w_init] + ) + return w def inverse_tukey_hh(x, hl, hr): - """Compute the inverse of the Tukey HH function. - - The Tukey HH function transforms a standard Gaussian distribution into the - Tukey HH distribution and it's defined as: - - x = u * exp(hl * u ^ 2) for u < 0 and x = u * exp(hr * u ^ 2) for u >= 0. - - Given the values of x, this function computes the corresponding values of u. - - Args: - x: The input `Tensor`. - hl: The "left" parameter of the distribution. It must have the same dtype - and shape of x (or a broadcastable shape) or be a scalar. - hr: The "right" parameter of the distribution. It must have the same dtype - and shape of x (or a broadcastable shape) or be a scalar. - - Returns: - The inverse of the Tukey HH function. - """ - def one_side(x, h): - h_x_square = tf.multiply(h, tf.square(x)) - return tf.where( - # Prevents the 0 / 0 form for small values of x.. - tf.less(h_x_square, 1.0e-7), - x, # The error is < 1e-14 for this case. - tf.sqrt(tf.divide(lambert_w(h_x_square), h))) - - return tf.where(tf.less(x, 0.0), -one_side(-x, hl), one_side(x, hr)) + """Compute the inverse of the Tukey HH function. + + The Tukey HH function transforms a standard Gaussian distribution into the + Tukey HH distribution and it's defined as: + + x = u * exp(hl * u ^ 2) for u < 0 and x = u * exp(hr * u ^ 2) for u >= 0. + + Given the values of x, this function computes the corresponding values of u. + + Args: + ---- + x: The input `Tensor`. + hl: The "left" parameter of the distribution. 
It must have the same dtype + and shape of x (or a broadcastable shape) or be a scalar. + hr: The "right" parameter of the distribution. It must have the same dtype + and shape of x (or a broadcastable shape) or be a scalar. + + Returns: + ------- + The inverse of the Tukey HH function. + """ + + def one_side(x, h): + h_x_square = tf.multiply(h, tf.square(x)) + return tf.where( + # Prevents the 0 / 0 form for small values of x.. + tf.less(h_x_square, 1.0e-7), + x, # The error is < 1e-14 for this case. + tf.sqrt(tf.divide(lambert_w(h_x_square), h)), + ) + + return tf.where(tf.less(x, 0.0), -one_side(-x, hl), one_side(x, hr)) diff --git a/tensorflow_transform/gaussianization_test.py b/tensorflow_transform/gaussianization_test.py index 2435386..f2405e4 100644 --- a/tensorflow_transform/gaussianization_test.py +++ b/tensorflow_transform/gaussianization_test.py @@ -14,243 +14,313 @@ """Tests for tensorflow_transform.gaussianization.""" import numpy as np -from tensorflow_transform import gaussianization -from tensorflow_transform import test_case + +from tensorflow_transform import gaussianization, test_case _MEAN_SCALE_SCALAR_TEST = dict( - testcase_name='_scalar', + testcase_name="_scalar", h_params=np.array([0.1, 0.2], np.float32), - expected_outputs=[ - np.float32(0.05540865005575452), - np.float32(0.6932738015273474) - ], + expected_outputs=[np.float32(0.05540865005575452), np.float32(0.6932738015273474)], ) _MEAN_SCALE_ND_TEST = dict( - testcase_name='_nd', + testcase_name="_nd", h_params=np.array( - [[[[0.0], [0.1], [0.5]], [[0.7], [0.8], [0.9]]], - [[[0.0], [0.7], [0.6]], [[0.3], [0.2], [0.0]]]], np.float32), + [ + [[[0.0], [0.1], [0.5]], [[0.7], [0.8], [0.9]]], + [[[0.0], [0.7], [0.6]], [[0.3], [0.2], [0.0]]], + ], + np.float32, + ), expected_outputs=[ - np.array([[[0.], [0.8865384], [0.19947124]], - [[-0.75989], [-1.4960338], [-3.5904799]]], np.float32), - np.array([[[0.5641896], [1.4878997], [1.4943897]], - [[1.6034254], [2.1926064], [4.085859]]], np.float32) + np.array( + [ + [[0.0], [0.8865384], [0.19947124]], + [[-0.75989], [-1.4960338], [-3.5904799]], + ], + np.float32, + ), + np.array( + [ + [[0.5641896], [1.4878997], [1.4943897]], + [[1.6034254], [2.1926064], [4.085859]], + ], + np.float32, + ), ], ) _L_SKEWNESS_KURTOSIS_SCALAR_TEST = dict( - testcase_name='_scalar', + testcase_name="_scalar", h_params=np.array([0.1, 0.2], np.float32), - expected_outputs=[ - np.float32(0.05989154619056726), - np.float32(0.21460719619685548) - ], + expected_outputs=[np.float32(0.05989154619056726), np.float32(0.21460719619685548)], ) _L_SKEWNESS_KURTOSIS_ND_TEST = dict( - testcase_name='_nd', + testcase_name="_nd", h_params=np.array( - [[[[0.0], [0.1], [0.5]], [[0.7], [0.8], [0.9]]], - [[[0.0], [0.7], [0.6]], [[0.3], [0.2], [0.0]]]], np.float32), + [ + [[[0.0], [0.1], [0.5]], [[0.7], [0.8], [0.9]]], + [[[0.0], [0.7], [0.6]], [[0.3], [0.2], [0.0]]], + ], + np.float32, + ), expected_outputs=[ - np.array([[[0.], [0.5209037], [0.11905935]], - [[-0.4226278], [-0.6249933], [-0.833552]]], np.float32), - np.array([[[0.12260159], [0.54675657], [0.5140212]], - [[0.55600286], [0.66664696], [0.81815743]]], np.float32), + np.array( + [ + [[0.0], [0.5209037], [0.11905935]], + [[-0.4226278], [-0.6249933], [-0.833552]], + ], + np.float32, + ), + np.array( + [ + [[0.12260159], [0.54675657], [0.5140212]], + [[0.55600286], [0.66664696], [0.81815743]], + ], + np.float32, + ), ], ) -_COMPUTE_TUKEY_H_PARAMS_REGULAR_TESTS = [dict( - testcase_name='_regular_1', - l_skewness_and_kurtosis=np.array( - 
[0.05989154619056726, 0.21460719619685548], np.float32), - expected_output=np.array([0.1, 0.2], np.float32), -), dict( - testcase_name='_regular_2', - l_skewness_and_kurtosis=np.array([0.1, 0.2], np.float32), - expected_output=np.array([0.03056329, 0.20497137], np.float32) -), dict( - testcase_name='_regular_3', - l_skewness_and_kurtosis=np.array([0.8, 0.99], np.float32), - expected_output=np.array([0.9635793, 0.99589026], np.float32) -), dict( - testcase_name='_regular_4', - l_skewness_and_kurtosis=np.array([0.6, 0.7], np.float32), - expected_output=np.array([0.3535486, 0.82437974], np.float32) -)] +_COMPUTE_TUKEY_H_PARAMS_REGULAR_TESTS = [ + dict( + testcase_name="_regular_1", + l_skewness_and_kurtosis=np.array( + [0.05989154619056726, 0.21460719619685548], np.float32 + ), + expected_output=np.array([0.1, 0.2], np.float32), + ), + dict( + testcase_name="_regular_2", + l_skewness_and_kurtosis=np.array([0.1, 0.2], np.float32), + expected_output=np.array([0.03056329, 0.20497137], np.float32), + ), + dict( + testcase_name="_regular_3", + l_skewness_and_kurtosis=np.array([0.8, 0.99], np.float32), + expected_output=np.array([0.9635793, 0.99589026], np.float32), + ), + dict( + testcase_name="_regular_4", + l_skewness_and_kurtosis=np.array([0.6, 0.7], np.float32), + expected_output=np.array([0.3535486, 0.82437974], np.float32), + ), +] _COMPUTE_TUKEY_H_PARAMS_NEG_SKEWNESS_TEST = dict( - testcase_name='_neg_skewness', + testcase_name="_neg_skewness", l_skewness_and_kurtosis=np.array( - [-0.05989154619056726, 0.21460719619685548], np.float32), - expected_output=np.array([0.2, 0.1], np.float32) + [-0.05989154619056726, 0.21460719619685548], np.float32 + ), + expected_output=np.array([0.2, 0.1], np.float32), ) -_COMPUTE_TUKEY_H_PARAMS_PATOLOGICAL_TESTS = [dict( - # For this test, the values of skewness and kurtosis are valid, but not - # achievable by a Tukey HH distribution. The solution is the closest - # possible. - testcase_name='_patological', - l_skewness_and_kurtosis=np.array( - [0.7, 0.5], np.float32), - expected_output=np.array([0.0, 0.65736556], np.float32) -), dict( - testcase_name='_pat_invalid_skewness', - l_skewness_and_kurtosis=np.array( - [1.0, 0.5], np.float32), - expected_output=np.array([0.0, 0.65736556], np.float32) -), dict( - testcase_name='_pat_invalid_kurtosis', - l_skewness_and_kurtosis=np.array( - [0.5, 1.5], np.float32), - expected_output=np.array( - [00.9999859847861059, 0.9999950120303265], np.float32) -)] +_COMPUTE_TUKEY_H_PARAMS_PATOLOGICAL_TESTS = [ + dict( + # For this test, the values of skewness and kurtosis are valid, but not + # achievable by a Tukey HH distribution. The solution is the closest + # possible. 
+ testcase_name="_patological", + l_skewness_and_kurtosis=np.array([0.7, 0.5], np.float32), + expected_output=np.array([0.0, 0.65736556], np.float32), + ), + dict( + testcase_name="_pat_invalid_skewness", + l_skewness_and_kurtosis=np.array([1.0, 0.5], np.float32), + expected_output=np.array([0.0, 0.65736556], np.float32), + ), + dict( + testcase_name="_pat_invalid_kurtosis", + l_skewness_and_kurtosis=np.array([0.5, 1.5], np.float32), + expected_output=np.array([00.9999859847861059, 0.9999950120303265], np.float32), + ), +] -_LAMBERT_W_SCALAR_TESTS = [dict( - testcase_name='lambert_w_scalar_0', - samples=np.float32(0.0), - expected_output=np.float32(0.0) -), dict( - testcase_name='lambert_w_scalar_small', - samples=np.float32(1.0e-4), - expected_output=np.float32(9.999000e-05) -), dict( - testcase_name='lambert_w_scalar_e', - samples=np.float32(np.exp(1.0)), - expected_output=np.float32(1.0) -), dict( - testcase_name='lambert_w_scalar_large', - samples=np.float32(10.0 * np.exp(10.0)), - expected_output=np.float32(10.0) -)] +_LAMBERT_W_SCALAR_TESTS = [ + dict( + testcase_name="lambert_w_scalar_0", + samples=np.float32(0.0), + expected_output=np.float32(0.0), + ), + dict( + testcase_name="lambert_w_scalar_small", + samples=np.float32(1.0e-4), + expected_output=np.float32(9.999000e-05), + ), + dict( + testcase_name="lambert_w_scalar_e", + samples=np.float32(np.exp(1.0)), + expected_output=np.float32(1.0), + ), + dict( + testcase_name="lambert_w_scalar_large", + samples=np.float32(10.0 * np.exp(10.0)), + expected_output=np.float32(10.0), + ), +] -_LAMBERT_W_ND_TESTS = [dict( - testcase_name='lambert_w_1D', - samples=np.linspace(0.0, 10, 8, dtype=np.float32), - expected_output=np.array( - [0., 0.70550971, 1.02506557, 1.24009733, 1.40379211, 1.53656406, - 1.6485427, 1.745528], np.float32) -), dict( - testcase_name='lambert_w_3D', - samples=np.linspace(0.0, 10, 8, dtype=np.float32).reshape((2, 4, 1)), - expected_output=np.array( - [0., 0.70550971, 1.02506557, 1.24009733, 1.40379211, 1.53656406, - 1.6485427, 1.745528], np.float32).reshape((2, 4, 1)) -)] +_LAMBERT_W_ND_TESTS = [ + dict( + testcase_name="lambert_w_1D", + samples=np.linspace(0.0, 10, 8, dtype=np.float32), + expected_output=np.array( + [ + 0.0, + 0.70550971, + 1.02506557, + 1.24009733, + 1.40379211, + 1.53656406, + 1.6485427, + 1.745528, + ], + np.float32, + ), + ), + dict( + testcase_name="lambert_w_3D", + samples=np.linspace(0.0, 10, 8, dtype=np.float32).reshape((2, 4, 1)), + expected_output=np.array( + [ + 0.0, + 0.70550971, + 1.02506557, + 1.24009733, + 1.40379211, + 1.53656406, + 1.6485427, + 1.745528, + ], + np.float32, + ).reshape((2, 4, 1)), + ), +] -_INVERSE_TUKEY_HH_SCALAR_TESTS = [dict( - testcase_name='inverse_tukey_scalar_0', - samples=np.float32(0.0), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.float32(0.0) -), dict( - testcase_name='inverse_tukey_small_positive', - samples=np.float32(1.0e-4), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.float32(1.0e-4) -), dict( - testcase_name='inverse_tukey_small_negative', - samples=np.float32(-1.0e-4), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.float32(-1.0e-4) -), dict( - testcase_name='inverse_tukey_large_positive', - samples=np.float32(5.0 * np.exp(25.0)), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.float32(5.0) -), dict( - testcase_name='inverse_tukey_large_negative', - samples=np.float32(-5.0 * np.exp(0.5 * 25.0)), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.float32(-5.0) -)] 
+_INVERSE_TUKEY_HH_SCALAR_TESTS = [ + dict( + testcase_name="inverse_tukey_scalar_0", + samples=np.float32(0.0), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.float32(0.0), + ), + dict( + testcase_name="inverse_tukey_small_positive", + samples=np.float32(1.0e-4), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.float32(1.0e-4), + ), + dict( + testcase_name="inverse_tukey_small_negative", + samples=np.float32(-1.0e-4), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.float32(-1.0e-4), + ), + dict( + testcase_name="inverse_tukey_large_positive", + samples=np.float32(5.0 * np.exp(25.0)), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.float32(5.0), + ), + dict( + testcase_name="inverse_tukey_large_negative", + samples=np.float32(-5.0 * np.exp(0.5 * 25.0)), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.float32(-5.0), + ), +] def _tukey_hh(x, hl, hr): - return np.where( - x > 0.0, - x * np.exp(0.5 * hr * np.square(x)), - x * np.exp(0.5 * hl * np.square(x))) + return np.where( + x > 0.0, + x * np.exp(0.5 * hr * np.square(x)), + x * np.exp(0.5 * hl * np.square(x)), + ) -_INVERSE_TUKEY_HH_ND_TESTS = [dict( - testcase_name='inverse_tukey_1D', - samples=np.array( - _tukey_hh(np.linspace(-5.0, 5.0, 20), 1.0, 2.0), np.float32), - hl=np.float32(1.0), - hr=np.float32(2.0), - expected_output=np.linspace(-5.0, 5.0, 20, dtype=np.float32) -), dict( - testcase_name='inverse_tukey_3D', - samples=np.array( - _tukey_hh(np.linspace(-5.0, 5.0, 100).reshape((10, 5, 2)), - np.linspace(1.0, 1.5, 10).reshape((1, 5, 2)), - np.linspace(2.0, 2.5, 10).reshape((1, 5, 2))), np.float32), - hl=np.linspace(1.0, 1.5, 10, dtype=np.float32).reshape((1, 5, 2)), - hr=np.linspace(2.0, 2.5, 10, dtype=np.float32).reshape((1, 5, 2)), - expected_output=np.linspace( - -5.0, 5.0, 100, dtype=np.float32).reshape((10, 5, 2)) -)] +_INVERSE_TUKEY_HH_ND_TESTS = [ + dict( + testcase_name="inverse_tukey_1D", + samples=np.array(_tukey_hh(np.linspace(-5.0, 5.0, 20), 1.0, 2.0), np.float32), + hl=np.float32(1.0), + hr=np.float32(2.0), + expected_output=np.linspace(-5.0, 5.0, 20, dtype=np.float32), + ), + dict( + testcase_name="inverse_tukey_3D", + samples=np.array( + _tukey_hh( + np.linspace(-5.0, 5.0, 100).reshape((10, 5, 2)), + np.linspace(1.0, 1.5, 10).reshape((1, 5, 2)), + np.linspace(2.0, 2.5, 10).reshape((1, 5, 2)), + ), + np.float32, + ), + hl=np.linspace(1.0, 1.5, 10, dtype=np.float32).reshape((1, 5, 2)), + hr=np.linspace(2.0, 2.5, 10, dtype=np.float32).reshape((1, 5, 2)), + expected_output=np.linspace(-5.0, 5.0, 100, dtype=np.float32).reshape( + (10, 5, 2) + ), + ), +] -class GaussianizationTest(test_case.TransformTestCase): - @test_case.named_parameters( - _MEAN_SCALE_SCALAR_TEST, - _MEAN_SCALE_ND_TEST - ) - def test_tukey_hh_l_mean_and_scale(self, h_params, expected_outputs): - outputs = gaussianization.tukey_hh_l_mean_and_scale(h_params) - self.assertEqual(len(outputs), len(expected_outputs)) - for output, expected_output in zip(outputs, expected_outputs): - self.assertEqual(output.dtype, expected_output.dtype) - self.assertAllEqual(output.shape, expected_output.shape) - self.assertAllClose(output, expected_output) +class GaussianizationTest(test_case.TransformTestCase): + @test_case.named_parameters(_MEAN_SCALE_SCALAR_TEST, _MEAN_SCALE_ND_TEST) + def test_tukey_hh_l_mean_and_scale(self, h_params, expected_outputs): + outputs = gaussianization.tukey_hh_l_mean_and_scale(h_params) + self.assertEqual(len(outputs), len(expected_outputs)) + for output, 
expected_output in zip(outputs, expected_outputs): + self.assertEqual(output.dtype, expected_output.dtype) + self.assertAllEqual(output.shape, expected_output.shape) + self.assertAllClose(output, expected_output) - @test_case.named_parameters( - _L_SKEWNESS_KURTOSIS_SCALAR_TEST, - _L_SKEWNESS_KURTOSIS_ND_TEST - ) - def test_tukey_hh_l_skewness_and_kurtosis(self, h_params, expected_outputs): - outputs = gaussianization._tukey_hh_l_skewness_and_kurtosis(h_params) - self.assertEqual(len(outputs), len(expected_outputs)) - for output, expected_output in zip(outputs, expected_outputs): - self.assertEqual(output.dtype, expected_output.dtype) - self.assertAllEqual(output.shape, expected_output.shape) - self.assertAllClose(output, expected_output) + @test_case.named_parameters( + _L_SKEWNESS_KURTOSIS_SCALAR_TEST, _L_SKEWNESS_KURTOSIS_ND_TEST + ) + def test_tukey_hh_l_skewness_and_kurtosis(self, h_params, expected_outputs): + outputs = gaussianization._tukey_hh_l_skewness_and_kurtosis(h_params) + self.assertEqual(len(outputs), len(expected_outputs)) + for output, expected_output in zip(outputs, expected_outputs): + self.assertEqual(output.dtype, expected_output.dtype) + self.assertAllEqual(output.shape, expected_output.shape) + self.assertAllClose(output, expected_output) - @test_case.named_parameters(*( - [_COMPUTE_TUKEY_H_PARAMS_NEG_SKEWNESS_TEST] + - _COMPUTE_TUKEY_H_PARAMS_REGULAR_TESTS + - _COMPUTE_TUKEY_H_PARAMS_PATOLOGICAL_TESTS)) - def test_compute_tukey_hh_params( - self, l_skewness_and_kurtosis, expected_output): - output = gaussianization.compute_tukey_hh_params(l_skewness_and_kurtosis) - self.assertEqual(output.dtype, expected_output.dtype) - self.assertAllEqual(output.shape, expected_output.shape) - self.assertAllClose(output, expected_output, rtol=1e-5, atol=1e-5) + @test_case.named_parameters( + *( + [_COMPUTE_TUKEY_H_PARAMS_NEG_SKEWNESS_TEST] + + _COMPUTE_TUKEY_H_PARAMS_REGULAR_TESTS + + _COMPUTE_TUKEY_H_PARAMS_PATOLOGICAL_TESTS + ) + ) + def test_compute_tukey_hh_params(self, l_skewness_and_kurtosis, expected_output): + output = gaussianization.compute_tukey_hh_params(l_skewness_and_kurtosis) + self.assertEqual(output.dtype, expected_output.dtype) + self.assertAllEqual(output.shape, expected_output.shape) + self.assertAllClose(output, expected_output, rtol=1e-5, atol=1e-5) - @test_case.named_parameters(*_LAMBERT_W_SCALAR_TESTS + _LAMBERT_W_ND_TESTS) - def test_lambert_w(self, samples, expected_output): - output = gaussianization.lambert_w(samples) - self.assertEqual(output.dtype, expected_output.dtype) - self.assertAllEqual(output.shape, expected_output.shape) - self.assertAllClose(output, expected_output) + @test_case.named_parameters(*_LAMBERT_W_SCALAR_TESTS + _LAMBERT_W_ND_TESTS) + def test_lambert_w(self, samples, expected_output): + output = gaussianization.lambert_w(samples) + self.assertEqual(output.dtype, expected_output.dtype) + self.assertAllEqual(output.shape, expected_output.shape) + self.assertAllClose(output, expected_output) - @test_case.named_parameters( - *_INVERSE_TUKEY_HH_SCALAR_TESTS + _INVERSE_TUKEY_HH_ND_TESTS) - def test_inverse_tukey_hh(self, samples, hl, hr, expected_output): - output = gaussianization.inverse_tukey_hh(samples, hl, hr) - self.assertEqual(output.dtype, expected_output.dtype) - self.assertAllEqual(output.shape, expected_output.shape) - self.assertAllClose(output, expected_output) + @test_case.named_parameters( + *_INVERSE_TUKEY_HH_SCALAR_TESTS + _INVERSE_TUKEY_HH_ND_TESTS + ) + def test_inverse_tukey_hh(self, samples, hl, hr, expected_output): 
+ output = gaussianization.inverse_tukey_hh(samples, hl, hr) + self.assertEqual(output.dtype, expected_output.dtype) + self.assertAllEqual(output.shape, expected_output.shape) + self.assertAllClose(output, expected_output) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/graph_context.py b/tensorflow_transform/graph_context.py index d93ad9a..674f6cf 100644 --- a/tensorflow_transform/graph_context.py +++ b/tensorflow_transform/graph_context.py @@ -18,111 +18,124 @@ from typing import Any, Dict, Optional import tensorflow as tf + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple class TFGraphContext: - """A context manager to pass global state to a TF graph when it is traced. - - All the attributes in this context are kept on a thread local state. - - Attributes: - module_to_export: A tf.Module object that can be exported to a SavedModel - and will be used to track objects created within this TF graph. - temp_dir: The base path of the directory to write out any temporary files - in this context block. If None, the TF graph in this context will be - traced with placeholders for asset filepaths and is not serializable to a - SavedModel. - evaluated_replacements: A subset of placeholders/temporary asset files in - `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in - previous TFT phases. - - Note that the temp dir should be accessible to worker jobs, e.g. if running - with the Cloud Dataflow runner, the temp dir should be on GCS and should have - permissions that allow both launcher and workers to access it. - """ - - class _State( - tfx_namedtuple.namedtuple('_State', [ - 'module_to_export', - 'temp_dir', - 'evaluated_replacements', - ])): - """A named tuple storing state passed to this context manager.""" + """A context manager to pass global state to a TF graph when it is traced. + + All the attributes in this context are kept on a thread local state. + + Attributes + ---------- + module_to_export: A tf.Module object that can be exported to a SavedModel + and will be used to track objects created within this TF graph. + temp_dir: The base path of the directory to write out any temporary files + in this context block. If None, the TF graph in this context will be + traced with placeholders for asset filepaths and is not serializable to a + SavedModel. + evaluated_replacements: A subset of placeholders/temporary asset files in + `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in + previous TFT phases. + + Note that the temp dir should be accessible to worker jobs, e.g. if running + with the Cloud Dataflow runner, the temp dir should be on GCS and should have + permissions that allow both launcher and workers to access it. 
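+
+    A minimal usage sketch (the module object and temp path are illustrative):
+
+        with TFGraphContext(module_to_export=tf.Module(), temp_dir="/tmp/tft_assets"):
+            ...  # trace the preprocessing graph; analyzer temp assets go under this dir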
+ """ + + class _State( + tfx_namedtuple.namedtuple( + "_State", + [ + "module_to_export", + "temp_dir", + "evaluated_replacements", + ], + ) + ): + """A named tuple storing state passed to this context manager.""" + + @classmethod + def make_empty(cls): + """Return `_State` object with all fields set to `None`.""" + return cls(*(None,) * len(cls._fields)) + + _TEMP_SUBDIR = "analyzer_temporary_assets" + + _thread_local = threading.local() + + def __init__( + self, + module_to_export: tf.Module, + temp_dir: Optional[str] = None, + evaluated_replacements: Optional[Dict[str, Any]] = None, + ): + self._module_to_export = module_to_export + self._temp_dir = temp_dir + self._evaluated_replacements = evaluated_replacements + + def __enter__(self): + assert getattr(self._thread_local, "current_state", None) is None + self._thread_local.current_state = self._State( + module_to_export=self._module_to_export, + temp_dir=self._temp_dir, + evaluated_replacements=self._evaluated_replacements, + ) + + def __exit__(self, *exn_info): + self._thread_local.current_state = None + + @property + def module_to_export(self): + return self._module_to_export @classmethod - def make_empty(cls): - """Return `_State` object with all fields set to `None`.""" - return cls(*(None,) * len(cls._fields)) - - _TEMP_SUBDIR = 'analyzer_temporary_assets' - - _thread_local = threading.local() - - def __init__(self, - module_to_export: tf.Module, - temp_dir: Optional[str] = None, - evaluated_replacements: Optional[Dict[str, Any]] = None): - self._module_to_export = module_to_export - self._temp_dir = temp_dir - self._evaluated_replacements = evaluated_replacements - - def __enter__(self): - assert getattr(self._thread_local, 'current_state', None) is None - self._thread_local.current_state = self._State( - module_to_export=self._module_to_export, - temp_dir=self._temp_dir, - evaluated_replacements=self._evaluated_replacements) - - def __exit__(self, *exn_info): - self._thread_local.current_state = None - - @property - def module_to_export(self): - return self._module_to_export - - @classmethod - def _get_current_state(cls) -> 'TFGraphContext._State': - if hasattr(cls._thread_local, 'current_state'): - return cls._thread_local.current_state - return cls._State.make_empty() - - @classmethod - def get_or_create_temp_dir(cls) -> Optional[str]: - """Generate a temporary location.""" - current_state = cls._get_current_state() - if current_state.temp_dir is None: - return None - if not current_state.temp_dir: - raise ValueError('A temp dir was requested, but empty temp_dir was set. ' - 'Use the TFGraphContext context manager.') - result = os.path.join(current_state.temp_dir, cls._TEMP_SUBDIR) - tf.io.gfile.makedirs(result) - return result - - @classmethod - def get_evaluated_replacements(cls) -> Optional[Dict[str, Any]]: - """Retrieves the value of evaluated_replacements if set. - - None otherwise. - - Returns: - A dictionary from graph tensor names to evaluated values for these - tensors. The keys are a subset of placeholders/temporary asset files in - `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in - previous TFT phases. 
- """ - return cls._get_current_state().evaluated_replacements + def _get_current_state(cls) -> "TFGraphContext._State": + if hasattr(cls._thread_local, "current_state"): + return cls._thread_local.current_state + return cls._State.make_empty() + + @classmethod + def get_or_create_temp_dir(cls) -> Optional[str]: + """Generate a temporary location.""" + current_state = cls._get_current_state() + if current_state.temp_dir is None: + return None + if not current_state.temp_dir: + raise ValueError( + "A temp dir was requested, but empty temp_dir was set. " + "Use the TFGraphContext context manager." + ) + result = os.path.join(current_state.temp_dir, cls._TEMP_SUBDIR) + tf.io.gfile.makedirs(result) + return result - @classmethod - def get_module_to_export(cls) -> Optional[tf.Module]: - """Retrieves the value of module_to_export. + @classmethod + def get_evaluated_replacements(cls) -> Optional[Dict[str, Any]]: + """Retrieves the value of evaluated_replacements if set. - None if called outside a TFGraphContext scope. + None otherwise. - Returns: - A tf.Module object - """ - return cls._get_current_state().module_to_export + Returns + ------- + A dictionary from graph tensor names to evaluated values for these + tensors. The keys are a subset of placeholders/temporary asset files in + `analyzer_nodes.TENSOR_REPLACEMENTS` that have been evaluated in + previous TFT phases. + """ + return cls._get_current_state().evaluated_replacements + + @classmethod + def get_module_to_export(cls) -> Optional[tf.Module]: + """Retrieves the value of module_to_export. + + None if called outside a TFGraphContext scope. + + Returns + ------- + A tf.Module object + """ + return cls._get_current_state().module_to_export diff --git a/tensorflow_transform/graph_tools.py b/tensorflow_transform/graph_tools.py index 9a3f552..093b9b0 100644 --- a/tensorflow_transform/graph_tools.py +++ b/tensorflow_transform/graph_tools.py @@ -27,928 +27,1005 @@ import copy import hashlib import itertools -from typing import Iterable, List, Mapping, Optional, Set, Union import uuid -from absl import logging +from typing import Iterable, List, Mapping, Optional, Set, Union import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import common_types -from tensorflow_transform import nodes -from tensorflow_transform import tf_utils -# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` -# once the Spark issue is resolved. -from tfx_bsl.types import tfx_namedtuple +from absl import logging + # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.framework import composite_tensor +from tensorflow.python.framework import composite_tensor, function_def_to_graph from tensorflow.python.framework import func_graph as tf_func_graph -from tensorflow.python.framework import function_def_to_graph from tensorflow.python.util import object_identity + +# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` +# once the Spark issue is resolved. 
+from tfx_bsl.types import tfx_namedtuple + +from tensorflow_transform import analyzer_nodes, common_types, nodes, tf_utils + # pylint: enable=g-direct-tensorflow-import _INITIALIZABLE_TABLE_OP_TYPES = [ - 'CuckooTable', - 'CuckooTableV2', - 'HashTable', - 'HashTableV2', - 'IndexTable', - 'IndexTableV2', + "CuckooTable", + "CuckooTableV2", + "HashTable", + "HashTableV2", + "IndexTable", + "IndexTableV2", ] _TABLE_INIT_OP_TYPES = [ - 'InitializeTable', - 'InitializeTableV2', - 'InitializeTableFromTextFile', - 'InitializeTableFromTextFileV2', - 'InitializeTableFromDataset', - 'LookupTableImport', - 'LookupTableImportV2', + "InitializeTable", + "InitializeTableV2", + "InitializeTableFromTextFile", + "InitializeTableFromTextFileV2", + "InitializeTableFromDataset", + "LookupTableImport", + "LookupTableImportV2", # If a TF 2 SavedModel/Hub module with tables is loaded inside the # pre-processing fn, a StatefulPartitionedCall is added to the # TABLE_INITIALIZERS collection. - 'StatefulPartitionedCall', + "StatefulPartitionedCall", ] def _decompose_tensor_or_op(tensor_or_op): - """Yields the raw components of a `tf.CompositeTensor`. + """Yields the raw components of a `tf.CompositeTensor`. - If tensor_or_op is a `tf.Operation`, or `tf.Tensor`, then - _decompose_tensor_or_op will act as a pass through. + If tensor_or_op is a `tf.Operation`, or `tf.Tensor`, then + _decompose_tensor_or_op will act as a pass through. - Args: - tensor_or_op: `tf.Tensor`, `tf.CompositeTensor`, or `tf.Operation`. + Args: + ---- + tensor_or_op: `tf.Tensor`, `tf.CompositeTensor`, or `tf.Operation`. - Yields: - A tf.Tensor or tf.Operation, depending on what tensor_or_op is. - """ - if isinstance(tensor_or_op, composite_tensor.CompositeTensor): - for component in tf.nest.flatten(tensor_or_op, expand_composites=True): - yield component - else: - yield tensor_or_op + Yields: + ------ + A tf.Tensor or tf.Operation, depending on what tensor_or_op is. + """ + if isinstance(tensor_or_op, composite_tensor.CompositeTensor): + for component in tf.nest.flatten(tensor_or_op, expand_composites=True): + yield component + else: + yield tensor_or_op def retrieve_sources(sinks, ignore_control_dependencies=False): - """Captures subgraph between sources and sinks. - - Walk a Graph backwards from `sinks` and return any sources encountered in the - subgraph. This util is refactored from `_map_subgraph` in - tensorflow/.../ops/op_selector.py. - - Arguments: - sinks: An iterable of Operations where the subgraph terminates. - ignore_control_dependencies: (Optional) If `True`, ignore any - `control_inputs` for all ops while walking the graph. - - Returns: - The set of placeholders upon which `sinks` depend. This could also contain - placeholders representing `captures` in the graph. - """ - stop_at_tensors = object_identity.ObjectIdentitySet() - ops_to_visit = object_identity.ObjectIdentitySet(sinks) - visited_ops = object_identity.ObjectIdentitySet() - potential_extra_sources = object_identity.ObjectIdentitySet() - while ops_to_visit: - op = ops_to_visit.pop() - visited_ops.add(op) - - if op.type == 'Placeholder': - potential_extra_sources.update(op.outputs) - - input_ops = [t.op for t in op.inputs if t not in stop_at_tensors] - if not ignore_control_dependencies: - input_ops = itertools.chain(input_ops, op.control_inputs) - for input_op in input_ops: - if input_op not in visited_ops: - ops_to_visit.add(input_op) - - return potential_extra_sources + """Captures subgraph between sources and sinks. 
+ + Walk a Graph backwards from `sinks` and return any sources encountered in the + subgraph. This util is refactored from `_map_subgraph` in + tensorflow/.../ops/op_selector.py. + + Arguments: + --------- + sinks: An iterable of Operations where the subgraph terminates. + ignore_control_dependencies: (Optional) If `True`, ignore any + `control_inputs` for all ops while walking the graph. + + Returns: + ------- + The set of placeholders upon which `sinks` depend. This could also contain + placeholders representing `captures` in the graph. + """ + stop_at_tensors = object_identity.ObjectIdentitySet() + ops_to_visit = object_identity.ObjectIdentitySet(sinks) + visited_ops = object_identity.ObjectIdentitySet() + potential_extra_sources = object_identity.ObjectIdentitySet() + while ops_to_visit: + op = ops_to_visit.pop() + visited_ops.add(op) + + if op.type == "Placeholder": + potential_extra_sources.update(op.outputs) + + input_ops = [t.op for t in op.inputs if t not in stop_at_tensors] + if not ignore_control_dependencies: + input_ops = itertools.chain(input_ops, op.control_inputs) + for input_op in input_ops: + if input_op not in visited_ops: + ops_to_visit.add(input_op) + + return potential_extra_sources def get_func_graph_for_name(graph, func_name): - """Returns the FuncGraph associated to the given func_name if possible.""" - outer_graph = graph - while graph is not None: - func = graph._get_function(str(func_name)) # pylint: disable=protected-access - if func is not None: - if hasattr(func, 'graph'): - return func.graph - # `outer_graph` may not be the same as `ops.get_default_graph()` e.g. - # in the case of nested if ops or when the gradient is being computed - # from inside a Defun. We build the `func_graph` with `outer_graph` as its - # outer graph. - with outer_graph.as_default(): - # This is a _DefinedFunction. - func_graph = ( - function_def_to_graph.function_def_to_graph(func.definition)) - if func_graph is not None: - return func_graph - if hasattr(graph, 'outer_graph'): - graph = graph.outer_graph - else: - raise ValueError( - 'Function {} does not exist in the graph.'.format(func_name)) + """Returns the FuncGraph associated to the given func_name if possible.""" + outer_graph = graph + while graph is not None: + func = graph._get_function(str(func_name)) # pylint: disable=protected-access + if func is not None: + if hasattr(func, "graph"): + return func.graph + # `outer_graph` may not be the same as `ops.get_default_graph()` e.g. + # in the case of nested if ops or when the gradient is being computed + # from inside a Defun. We build the `func_graph` with `outer_graph` as its + # outer graph. + with outer_graph.as_default(): + # This is a _DefinedFunction. 
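`retrieve_sources` is a plain backward walk: starting from the sink operations it follows `op.inputs` (and, unless disabled, `op.control_inputs`) and records the outputs of any `Placeholder` it reaches. A small sketch of calling it on a toy TF1-style graph; the graph and tensor names are invented, only `graph_tools.retrieve_sources` itself comes from the code above:

    import tensorflow as tf

    from tensorflow_transform import graph_tools

    graph = tf.compat.v1.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, name="x")
        w = tf.compat.v1.placeholder(tf.float32, name="w")  # never consumed below
        z = x * 2.0

    sources = graph_tools.retrieve_sources([z.op])
    print([t.name for t in sources])  # expected: ['x:0']; 'w' is never reached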
+ func_graph = function_def_to_graph.function_def_to_graph( + func.definition + ) + if func_graph is not None: + return func_graph + if hasattr(graph, "outer_graph"): + graph = graph.outer_graph + else: + raise ValueError(f"Function {func_name} does not exist in the graph.") class _UnexpectedPlaceholderError(Exception): - - def __init__(self, op, func_graph_name): - tensor = op.outputs[0] - msg = 'An unexpected placeholder was encountered ({})'.format(tensor) - super().__init__(msg) - self.tensor = tensor - self.func_graph_name = func_graph_name + def __init__(self, op, func_graph_name): + tensor = op.outputs[0] + msg = f"An unexpected placeholder was encountered ({tensor})" + super().__init__(msg) + self.tensor = tensor + self.func_graph_name = func_graph_name class _UnexpectedTableError(Exception): - - def __init__(self, op, func_graph_name): - msg = 'An unexpected initializable table was encountered ({})'.format(op) - super().__init__(msg) - self.op = op - self.func_graph_name = func_graph_name + def __init__(self, op, func_graph_name): + msg = f"An unexpected initializable table was encountered ({op})" + super().__init__(msg) + self.op = op + self.func_graph_name = func_graph_name def _reraise_unexpected_error(func): - """A decorator that reraises certain exceptions with modified msg and type.""" - - def wrapper(self, tensor_or_op): - """Wrapper when calling func to re-raise exceptions.""" - try: - return func(self, tensor_or_op) - except _UnexpectedPlaceholderError as e: - context = (f' tf.function name: `{e.func_graph_name}`' - if e.func_graph_name else '') - raise ValueError( - 'The tensor_or_op {} depended on a placeholder ({}) that was not ' - 'in the input_signature. This may have be caused by manually ' - 'adding a placeholder to the graph.{}'.format(tensor_or_op, e.tensor, - context)) from e - except _UnexpectedTableError as e: - if e.func_graph_name: - raise ValueError( - 'The tensor_or_op {} depended on an initializable table ({}) that ' - 'is part of a tf.function graph ({}), this is not supported. This' - ' may be a result of initializing a table in a tf.function' - ''.format(tensor_or_op, e.op, e.func_graph_name)) from e - else: - raise ValueError( - 'The tensor_or_op {} depended on an initializable table ({}) that ' - 'was not tracked by the graph analysis. This may be caused by ' - 'adding an initializable table without adding its initializer to ' - 'the collection tf.GraphKeys.TABLE_INITIALIZERS'.format( - tensor_or_op, e.op)) from e - - return wrapper + """A decorator that reraises certain exceptions with modified msg and type.""" + + def wrapper(self, tensor_or_op): + """Wrapper when calling func to re-raise exceptions.""" + try: + return func(self, tensor_or_op) + except _UnexpectedPlaceholderError as e: + context = ( + f" tf.function name: `{e.func_graph_name}`" if e.func_graph_name else "" + ) + raise ValueError( + f"The tensor_or_op {tensor_or_op} depended on a placeholder ({e.tensor}) that was not " + "in the input_signature. This may have be caused by manually " + f"adding a placeholder to the graph.{context}" + ) from e + except _UnexpectedTableError as e: + if e.func_graph_name: + raise ValueError( + f"The tensor_or_op {tensor_or_op} depended on an initializable table ({e.op}) that " + f"is part of a tf.function graph ({e.func_graph_name}), this is not supported. 
This" + " may be a result of initializing a table in a tf.function" + ) from e + else: + raise ValueError( + f"The tensor_or_op {tensor_or_op} depended on an initializable table ({e.op}) that " + "was not tracked by the graph analysis. This may be caused by " + "adding an initializable table without adding its initializer to " + "the collection tf.GraphKeys.TABLE_INITIALIZERS" + ) from e + + return wrapper _AnalysisResult = tfx_namedtuple.namedtuple( - '_AnalysisResult', ['is_ready_to_run', 'path', 'dependent_sources']) + "_AnalysisResult", ["is_ready_to_run", "path", "dependent_sources"] +) -_SourceInfo = tfx_namedtuple.namedtuple('_SourceInfo', - ['is_ready_to_run', 'name']) +_SourceInfo = tfx_namedtuple.namedtuple("_SourceInfo", ["is_ready_to_run", "name"]) class _GraphAnalyzer: - """Class that analyzes a graph to determine readiness of tensors.""" + """Class that analyzes a graph to determine readiness of tensors.""" + + def __init__(self, source_info_dict, translate_path_fn, graph): + """Init method for _GraphAnalyzer. + + Args: + ---- + source_info_dict: A dict from `Tensor Reference` or `Operation` to + `_SourceInfo`. + translate_path_fn: A function with the signature: (identifier, parents) -> + Any which will be used to construct a unique path for a given `Tensor`. + graph: A `tf.Graph` which the given tensors belong to. + """ + self._memoized_analyze_tensor_result = {} + self._source_info_dict = source_info_dict + self._translate_path_fn = translate_path_fn + self._graph = graph + + def _get_parents(self, tensor_or_op): + """Get the parents of the given `tensor_or_op`.""" + if tf_utils.hashable_tensor_or_op(tensor_or_op) in self._source_info_dict: + return [] + + # func_graph_name is not None only if the graph is a FuncGraph. + func_graph_name = getattr(self._graph, "name", None) + if isinstance(tensor_or_op, tf.Operation): + if tensor_or_op.type in _INITIALIZABLE_TABLE_OP_TYPES: + raise _UnexpectedTableError(tensor_or_op, func_graph_name) + if tensor_or_op.type == "Placeholder": + # If we're not in the context of a tf.function, this is an error. + if func_graph_name is None: + raise _UnexpectedPlaceholderError(tensor_or_op, func_graph_name) + # If we're in the context of a tf.function and this op is part of its + # inputs, that's expected. + if tensor_or_op not in [x.op for x in self._graph.inputs]: + raise _UnexpectedPlaceholderError(tensor_or_op, func_graph_name) + parents = list( + itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs) + ) + elif isinstance(tensor_or_op, tf.Tensor): + parents = [tensor_or_op.op] + else: + raise TypeError( + f"Expected Tensor or Operation, got {tensor_or_op} of type {type(tensor_or_op)}" + ) + return parents + + def _compute_analysis_results_for_func_attributes( + self, tensor_or_op, parent_analysis_results + ): + """Analyzes `FuncGraph`s if tensor_or_op has them as attributes. + + This functionality is added to support `Operation`s such as PartitionedCall + (tf.function call) and control flow ops which use `func` attributes. + + These func attributes are references to `FuncGraph`s which can also be + analyzed, and the result of their analysis can be used as additional + information for the current node (`tensor_or_op`). + + Since `FuncGraph`s are completely different graphs than the one that this + _GraphAnalyzer is analyzing, their analysis wouldn't be taken into account + when analysing the current graph even though they will affect the runtime + results of running it. 
This is why we have to manually analyze those + sub-graphs as well as the main graph when computing graph information such + as dependent_inputs, unique_path, etc. + + Args: + ---- + tensor_or_op: A `Tensor` or `Operation` object. + parent_analysis_results: A list of `_AnalysisResult`s, results of analysis + of the parents of tensor_or_op. + + Returns: + ------- + A list of `_AnalysisResult`s, the results of analysis of `tensor_or_op`'s + func attributes. All `Tensor`s in dependent_sources belong to self._graph. + """ + if not isinstance(tensor_or_op, tf.Operation): + return [] + func_attributes = [ + attr.name for attr in tensor_or_op.op_def.attr if attr.type == "func" + ] + func_names = [tensor_or_op.get_attr(str(n)).name for n in func_attributes] + func_graphs = [get_func_graph_for_name(self._graph, n) for n in func_names] + + result = [] + for func_graph in func_graphs: + if not hasattr(func_graph, "inputs"): + # Since the body of the graph is not visible we insert a random string + # to the path in order to reflect that we don't know its full contents. + result.append( + _AnalysisResult( + is_ready_to_run=True, + path=self._translate_path_fn(uuid.uuid4().hex), + dependent_sources={}, + ) + ) + continue + op_inputs = list( + itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs) + ) + assert len(op_inputs) == len(parent_analysis_results), ( + op_inputs, + parent_analysis_results, + ) + func_graph_inputs_ready = [ + (next_input, r.is_ready_to_run) + for (next_input, r) in zip(func_graph.inputs, parent_analysis_results) + ] + infos = { + tf_utils.hashable_tensor_or_op(t): _SourceInfo( + ready, f"FuncGraphInput[{idx}]" + ) + for idx, (t, ready) in enumerate(func_graph_inputs_ready) + } + func_graph_analyzer = _GraphAnalyzer( + infos, self._translate_path_fn, func_graph + ) + analyzed_list = [ + func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs + ] + + if len(tensor_or_op.inputs) == len(func_graph.inputs): + tensor_pairs = zip(tensor_or_op.inputs, func_graph.inputs) + else: + # Control flow ops such as while store this information in captures. + tensor_pairs = func_graph.captures + tensor_map = {tf_utils.hashable_tensor_or_op(b): a for a, b in tensor_pairs} + + # Make sure that the dependent sources Tensors are translated from the + # FuncGraph to the outer graph in order to align with the rest of the + # traversal. + for analysis in analyzed_list: + translated_dependent_sources = { + tf_utils.hashable_tensor_or_op(tensor_map[s]) + for s in analysis.dependent_sources + if s in tensor_map + } + result.append( + analysis._replace(dependent_sources=translated_dependent_sources) + ) + return result + + def _compute_analysis_result(self, tensor_or_op, parent_analysis_results): + """Compute analysis result for a tensor or op with its parent results.""" + hashable = tf_utils.hashable_tensor_or_op(tensor_or_op) + if hashable in self._source_info_dict: + source_info = self._source_info_dict[hashable] + # source_info.name may be None but that just means that it relies on an + # output of a previous analyzer, so that's ok. 
+ return _AnalysisResult( + is_ready_to_run=source_info.is_ready_to_run, + path=self._translate_path_fn(source_info.name), + dependent_sources={hashable}, + ) + + func_graphs_analysis_results = ( + self._compute_analysis_results_for_func_attributes( + tensor_or_op, parent_analysis_results + ) + ) + + result = _AnalysisResult( + is_ready_to_run=all( + analysis_result.is_ready_to_run + for analysis_result in ( + parent_analysis_results + func_graphs_analysis_results + ) + ), + path=self._translate_path_fn( + tensor_or_op, + parents=[ + parent_analysis_result.path + for parent_analysis_result in parent_analysis_results + ] + + [func_result.path for func_result in func_graphs_analysis_results], + ), + dependent_sources=set(), + ) + for parent_analysis_result in parent_analysis_results: + result.dependent_sources.update(parent_analysis_result.dependent_sources) + for func_result in func_graphs_analysis_results: + result.dependent_sources.update(func_result.dependent_sources) + return result + + def analyze_tensor(self, tensor_or_op): + """Analyzes the `tensor_or_op` for its dependencies and readiness. + + Computes the transitive dependencies of a tensor or operation and decides + whether it is ready to run using iterative DFS. `source_info_dict` are used + as terminal nodes. An error is thrown if a table or placeholder is reached: + they must be set using source_info_dict. This function is memoized using the + _memoized_analyze_tensor_result cache. Cycles are ignored (so a cycle is + considered ready to run). + + Args: + ---- + tensor_or_op: A `Tensor` or `Operation`. + + Returns: + ------- + An _AnalysisResult which includes whether this op or tensor is ready to + run, a path from it to its sources and its dependent sources from + `source_info_dict`. + + Raises: + ------ + _UnexpectedTableError: If an initializable table op is encountered. + _UnexpectedPlaceholderError: If a placeholder is encountered. + """ + stack = collections.deque() + # Note that because tensors are no longer hashable, we need to convert to + # their reference in order to use them in sets or dicts. + stack.append(tf_utils.hashable_tensor_or_op(tensor_or_op)) + # Contains the nodes of the path starting from tensor_or_op to current + # visiting node, used for loop detection. We assume that any loop is a + # valid while loop and so it will be able to run as long as all the other + # parents are ready. + path = set() + while stack: + current = stack[-1] + if current in self._memoized_analyze_tensor_result: + stack.pop() + continue + path.add(current) + parents = self._get_parents(tf_utils.deref_tensor_or_op(current)) + parents = [ + parent + for parent in map(tf_utils.hashable_tensor_or_op, parents) + if parent not in path + ] + if all( + parent in self._memoized_analyze_tensor_result for parent in parents + ): + parent_results = [ + self._memoized_analyze_tensor_result[parent] for parent in parents + ] + current_result = self._compute_analysis_result( + tf_utils.deref_tensor_or_op(current), parent_results + ) + self._memoized_analyze_tensor_result[current] = current_result + path.discard(stack.pop()) + else: + stack.extend(parents) + return self._memoized_analyze_tensor_result[ + tf_utils.hashable_tensor_or_op(tensor_or_op) + ] - def __init__(self, source_info_dict, translate_path_fn, graph): - """Init method for _GraphAnalyzer. + def ready_to_run(self, tensor_or_op): + """Determine if a given tensor or op is ready to run. 
+ + A tensor is ready to run if every tensor in all its transitive dependencies + are set to `True` in `known_ready`. + + Note that if a placeholder is encountered, this will result in an error as + it is assumed that all placeholders are keys in `known_ready`. This is + to avoid unexpected behavior when the user creates placeholders (as opposed + to placeholders created by the tf.Transform framework). + + Similarly encountering a Table op is an error because a table should be + a key in `known_ready` (in the case of analyzing the main session run) or + should not be encountered (in the case of analyzing the graph init run). + + Args: + ---- + tensor_or_op: A `Tensor`, `SparseTensor`, `RaggedTensor` or `Operation` + + Returns: + ------- + A bool indicating whether then tensor is ready to run. + + Raises: + ------ + ValueError: If a placeholder or table is encountered. + _UnexpectedTableError: If an initializable table op is encountered. + _UnexpectedPlaceholderError: If a placeholder is encountered. + """ + if not isinstance( + tensor_or_op, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Operation) + ): + raise TypeError( + f"Expected Tensor, SparseTensor, RaggedTensor, or Operation got {tensor_or_op} of type {type(tensor_or_op)}" + ) + return all( + self.analyze_tensor(component).is_ready_to_run + for component in _decompose_tensor_or_op(tensor_or_op) + ) + + def get_unique_path(self, tensor): + """Gets the analyzed path from the tensor to its root(s). + + This path is defined recursively as: + Path(root) := translate_path_fn(root) + Path(x) := translate_path_fn( + x, + [translate_path_fn(p) for p in parents(x)]) + + When root is defined as a tensor that has no parents. + + Args: + ---- + tensor: A `Tensor` for which a path should be computed. + + Returns: + ------- + The result of translate_path_fn on the computed path as described above. + + Raises: + ------ + TypeError: if the given tensor is not of type `Tensor` + _UnexpectedTableError: If an initializable table op is encountered. + _UnexpectedPlaceholderError: If a placeholder is encountered. + """ + if not isinstance(tensor, tf.Tensor): + raise TypeError(f"Expected Tensor got {tensor} of type {type(tensor)}") + return self.analyze_tensor(tensor).path - Args: - source_info_dict: A dict from `Tensor Reference` or `Operation` to - `_SourceInfo`. - translate_path_fn: A function with the signature: (identifier, parents) -> - Any which will be used to construct a unique path for a given `Tensor`. - graph: A `tf.Graph` which the given tensors belong to. - """ - self._memoized_analyze_tensor_result = {} - self._source_info_dict = source_info_dict - self._translate_path_fn = translate_path_fn - self._graph = graph - - def _get_parents(self, tensor_or_op): - """Get the parents of the given `tensor_or_op`.""" - if tf_utils.hashable_tensor_or_op(tensor_or_op) in self._source_info_dict: - return [] - - # func_graph_name is not None only if the graph is a FuncGraph. - func_graph_name = getattr(self._graph, 'name', None) - if isinstance(tensor_or_op, tf.Operation): - if tensor_or_op.type in _INITIALIZABLE_TABLE_OP_TYPES: - raise _UnexpectedTableError(tensor_or_op, func_graph_name) - if tensor_or_op.type == 'Placeholder': - # If we're not in the context of a tf.function, this is an error. - if func_graph_name is None: - raise _UnexpectedPlaceholderError(tensor_or_op, func_graph_name) - # If we're in the context of a tf.function and this op is part of its - # inputs, that's expected. 
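`analyze_tensor` above replaces recursion with an explicit stack: a node is resolved only once all of its non-cyclic parents are memoized, terminal nodes come from `source_info_dict`, and anything already on the current path is treated as a while-loop back edge and skipped. A plain-Python sketch of the same control flow over a dict-of-parents graph; the graph, node names and readiness values here are invented:

    def analyze(parents_of, source_ready, node):
        """Return readiness of `node`: ready iff all reachable sources are ready."""
        memo = {}
        stack, path = [node], set()
        while stack:
            current = stack[-1]
            if current in memo:
                stack.pop()
                continue
            path.add(current)
            if current in source_ready:  # terminal node with known readiness
                memo[current] = source_ready[current]
                path.discard(stack.pop())
                continue
            # Drop parents already on the path: assume they form a valid loop.
            parents = [p for p in parents_of.get(current, ()) if p not in path]
            if all(p in memo for p in parents):
                memo[current] = all(memo[p] for p in parents)
                path.discard(stack.pop())
            else:
                stack.extend(parents)
        return memo[node]

    parents_of = {"z": ["x", "y"], "y": ["x"]}
    print(analyze(parents_of, {"x": True}, "z"))  # True: every path ends at the ready source "x"

The `memo` dict plays the role of `_memoized_analyze_tensor_result`, so repeated queries over a large graph stay roughly linear in its size.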
- if tensor_or_op not in [x.op for x in self._graph.inputs]: - raise _UnexpectedPlaceholderError(tensor_or_op, func_graph_name) - parents = list( - itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs)) - elif isinstance(tensor_or_op, tf.Tensor): - parents = [tensor_or_op.op] - else: - raise TypeError('Expected Tensor or Operation, got {} of type {}'.format( - tensor_or_op, type(tensor_or_op))) - return parents - def _compute_analysis_results_for_func_attributes(self, tensor_or_op, - parent_analysis_results): - """Analyzes `FuncGraph`s if tensor_or_op has them as attributes. +def _set_unique_value_in_dict(input_dict, key, value): + assert value not in input_dict.values(), value + input_dict[tf_utils.hashable_tensor_or_op(key)] = value - This functionality is added to support `Operation`s such as PartitionedCall - (tf.function call) and control flow ops which use `func` attributes. - These func attributes are references to `FuncGraph`s which can also be - analyzed, and the result of their analysis can be used as additional - information for the current node (`tensor_or_op`). +class InitializableGraphAnalyzer: + """Determines which tensors will be ready when running the graph. - Since `FuncGraph`s are completely different graphs than the one that this - _GraphAnalyzer is analyzing, their analysis wouldn't be taken into account - when analysing the current graph even though they will affect the runtime - results of running it. This is why we have to manually analyze those - sub-graphs as well as the main graph when computing graph information such - as dependent_inputs, unique_path, etc. + Determines which tensors from `fetches` are ready to run, using following + algorithm. - Args: - tensor_or_op: A `Tensor` or `Operation` object. - parent_analysis_results: A list of `_AnalysisResult`s, results of analysis - of the parents of tensor_or_op. + 1. Determine which table initializers are ready to run. A table initializer + is an element of the TABLE_INITIALIZERS collection and it is ready to run + if all the tensors it depends on are set to ready in + `replaced_tensors_ready`. - Returns: - A list of `_AnalysisResult`s, the results of analysis of `tensor_or_op`'s - func attributes. All `Tensor`s in dependent_sources belong to self._graph. + 2. Determine which of `fetches` are ready to run. A fetch is ready to run if + it only depends on tensors in `feeds` and tensors that are set to ready in + `replaced_tensors_ready`. """ - if not isinstance(tensor_or_op, tf.Operation): - return [] - func_attributes = [ - attr.name for attr in tensor_or_op.op_def.attr if attr.type == 'func' - ] - func_names = [tensor_or_op.get_attr(str(n)).name for n in func_attributes] - func_graphs = [get_func_graph_for_name(self._graph, n) for n in func_names] - result = [] - for func_graph in func_graphs: - if not hasattr(func_graph, 'inputs'): - # Since the body of the graph is not visible we insert a random string - # to the path in order to reflect that we don't know its full contents. 
- result.append( - _AnalysisResult( - is_ready_to_run=True, - path=self._translate_path_fn(uuid.uuid4().hex), - dependent_sources={})) - continue - op_inputs = list( - itertools.chain(tensor_or_op.inputs, tensor_or_op.control_inputs)) - assert len(op_inputs) == len(parent_analysis_results), ( - op_inputs, parent_analysis_results) - func_graph_inputs_ready = [ - (next_input, r.is_ready_to_run) - for (next_input, r) in zip(func_graph.inputs, parent_analysis_results) - ] - infos = { - tf_utils.hashable_tensor_or_op(t): - _SourceInfo(ready, 'FuncGraphInput[{}]'.format(idx)) - for idx, (t, ready) in enumerate(func_graph_inputs_ready) - } - func_graph_analyzer = _GraphAnalyzer(infos, self._translate_path_fn, - func_graph) - analyzed_list = [ - func_graph_analyzer.analyze_tensor(t) for t in func_graph.outputs - ] - - if len(tensor_or_op.inputs) == len(func_graph.inputs): - tensor_pairs = zip(tensor_or_op.inputs, func_graph.inputs) - else: - # Control flow ops such as while store this information in captures. - tensor_pairs = func_graph.captures - tensor_map = { - tf_utils.hashable_tensor_or_op(b): a for a, b in tensor_pairs - } - - # Make sure that the dependent sources Tensors are translated from the - # FuncGraph to the outer graph in order to align with the rest of the - # traversal. - for analysis in analyzed_list: - translated_dependent_sources = { - tf_utils.hashable_tensor_or_op(tensor_map[s]) - for s in analysis.dependent_sources - if s in tensor_map + def __init__( + self, graph, input_signature, replaced_tensors_ready, translate_path_fn=None + ): + """Init method for InitializableGraphAnalyzer. + + Args: + ---- + graph: a `Graph`. + input_signature: A dict whose keys are strings and values are `Tensor`s, + `SparseTensor`s, or `RaggedTensor`s. + replaced_tensors_ready: a list of `Tensor`, `SparseTensor`s, or + `RaggedTensor`s, bool pairs indicating whether the `Tensor`, + `SparseTensor`s, or `RaggedTensor`s is ready in this phase. + translate_path_fn: (Optional) A function with the signature: (identifier, + optional(parents)) -> Any which will be used to construct a unique path + for a given `Tensor`. + + Raises: + ------ + ValueError: If unexpected placeholders or tables are encountered, or table + initializers do not have the expected structure in the graph. + """ + if translate_path_fn is None: + translate_path_fn = lambda x, parents=None: None + + self._ready_table_initializers = [] + self._input_signature = input_signature + replaced_tensors_ready = { + tf_utils.hashable_tensor_or_op(t): ready + for t, ready in replaced_tensors_ready } - result.append( - analysis._replace(dependent_sources=translated_dependent_sources)) - return result - - def _compute_analysis_result(self, tensor_or_op, parent_analysis_results): - """Compute analysis result for a tensor or op with its parent results.""" - hashable = tf_utils.hashable_tensor_or_op(tensor_or_op) - if hashable in self._source_info_dict: - source_info = self._source_info_dict[hashable] - # source_info.name may be None but that just means that it relies on an - # output of a previous analyzer, so that's ok. 
- return _AnalysisResult( - is_ready_to_run=source_info.is_ready_to_run, - path=self._translate_path_fn(source_info.name), - dependent_sources={hashable}) - - func_graphs_analysis_results = ( - self._compute_analysis_results_for_func_attributes( - tensor_or_op, parent_analysis_results)) - - result = _AnalysisResult( - is_ready_to_run=all( - analysis_result.is_ready_to_run - for analysis_result in (parent_analysis_results + - func_graphs_analysis_results)), - path=self._translate_path_fn( - tensor_or_op, - parents=[ - parent_analysis_result.path - for parent_analysis_result in parent_analysis_results - ] + - [func_result.path for func_result in func_graphs_analysis_results]), - dependent_sources=set()) - for parent_analysis_result in parent_analysis_results: - result.dependent_sources.update(parent_analysis_result.dependent_sources) - for func_result in func_graphs_analysis_results: - result.dependent_sources.update(func_result.dependent_sources) - return result - def analyze_tensor(self, tensor_or_op): - """Analyzes the `tensor_or_op` for its dependencies and readiness. + initial_source_infos_dict = self._make_source_infos_dict( + {}, replaced_tensors_ready + ) + + # Determine which table initializers are ready, based on the replaced + # tensors. Since no input tensors are fed during table initialization, we do + # not set the value of any tensors in `input_signature`. + graph_analyzer_for_table_init = _GraphAnalyzer( + initial_source_infos_dict, translate_path_fn, graph + ) + complete_source_info_dict = self._make_source_infos_dict( + input_signature, replaced_tensors_ready + ) + + for table_init_op_or_tensor in graph.get_collection( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS + ): + # Handle the case when an initializer was lifted out of the graph context. + if table_init_op_or_tensor is None: + continue + + if isinstance(graph, tf_func_graph.FuncGraph): + self._log_warning( + "Tables initialized inside a tf.function will be" + " re-initialized on every invocation of the function." + " This re-initialization can have significant impact" + " on performance. Consider lifting them out of the" + f" graph context using `tf.init_scope`.: {table_init_op_or_tensor.name}" + ) + + table_init_op, table_input_ops = self._get_table_init_op_and_inputs( + table_init_op_or_tensor + ) + source_info = self._get_table_init_op_source_info( + table_init_op, graph_analyzer_for_table_init, translate_path_fn + ) + + for key in table_input_ops: + complete_source_info_dict[tf_utils.hashable_tensor_or_op(key)] = ( + source_info + ) + if source_info.is_ready_to_run: + self._ready_table_initializers.append(table_init_op_or_tensor) + + # Now determine which tensors are ready to run once the table has been + # initialized. + self._graph_analyzer = _GraphAnalyzer( + complete_source_info_dict, translate_path_fn, graph + ) + + def _log_warning(self, message: str): + logging.warning(message) + + def _get_table_init_op_and_inputs(self, table_init_op_or_tensor): + """Get a tuple of table init op and keys for its input ops.""" + # If a TF2 exported SavedModel with a table is loaded inside the + # preprocessing_fn, the TABLE_INITIALIZERS collection of the outer graph + # contains a Tensor whose parent op is of type StatefulPartitionedCall. + # The nested func graph for this StatefulPartitionedCall contains the + # table initializer. 
+ if ( + isinstance(table_init_op_or_tensor, tf.Tensor) + and table_init_op_or_tensor.op.type == "StatefulPartitionedCall" + ): + result = ( + table_init_op_or_tensor.op, + [input_t.op for input_t in table_init_op_or_tensor.op.inputs], + ) + else: + assert isinstance(table_init_op_or_tensor, tf.Operation) + # We are using the table init op information and the table op information, + # since that is a unique description of the table op. + table_ops = [] + for input_t in table_init_op_or_tensor.inputs: + # One of the inputs to the initializer op should be the table op. If + # no table op is found, (as in the case of a StatefulPartitionedCall) + # all inputs are added to the source dict. + if input_t.dtype == tf.resource: + table_ops.append(input_t.op) + assert len(table_ops) == 1 + result = (table_init_op_or_tensor, [table_ops[0]]) + return result + + def _make_source_infos_dict(self, input_signature, replaced_tensors_ready): + """Builds a dictionary from source tensors to _SourceInfos. + + This dictionary stores information about the sources of the graph. + Each tensor in replaced_tensors_ready is a source whose readiness is known + and has no name. Each tensor (or component of a tensor) in input_signature + is ready to run and has a name determined by the signature. + + Args: + ---- + input_signature: A dict whose keys are strings and values are `Tensor`s, + `SparseTensor`s, or `RaggedTensor`s. + replaced_tensors_ready: a dict from `Tensor`, `SparseTensor`s, or + `RaggedTensor`s to bool indicating whether the tensor is ready in this + phase. + + Returns: + ------- + a dictionary from source tensors to _SourceInfos. + """ + result = {} + for tensor_or_op, is_ready in replaced_tensors_ready.items(): + for component in _decompose_tensor_or_op( + tf_utils.deref_tensor_or_op(tensor_or_op) + ): + result[tf_utils.hashable_tensor_or_op(component)] = _SourceInfo( + is_ready, None + ) + + for name, tensor in input_signature.items(): + if isinstance(tensor, tf.Tensor): + _set_unique_value_in_dict( + result, tensor, _SourceInfo(True, f"{name}$tensor") + ) + elif isinstance(tensor, composite_tensor.CompositeTensor): + for idx, tensor_component in enumerate(_decompose_tensor_or_op(tensor)): + _set_unique_value_in_dict( + result, + tensor_component, + _SourceInfo(True, f"{name}$composite_tensor_{idx}"), + ) + else: + raise TypeError( + f"Expected Tensor, or CompositeTensor, got {tensor} of type {type(tensor)}" + ) + return result + + def _get_table_init_op_source_info( + self, table_init_op, graph_analyzer, translate_path_fn + ): + """Gets a _SourceInfo for a given table init op.""" + if table_init_op.type not in _TABLE_INIT_OP_TYPES: + raise ValueError( + f"Table initializer {table_init_op} did not have expected op type" + ) + if not table_init_op.inputs: + raise ValueError( + f"Table initializer {table_init_op} did not have expected number if inputs " + "(expected >= 1 inputs, got 0)" + ) + table_op = table_init_op.inputs[0].op + table_init_inputs = table_init_op.inputs[1:] + try: + ready = all(map(graph_analyzer.ready_to_run, table_init_inputs)) + path = translate_path_fn( + table_op, + parents=list(map(graph_analyzer.get_unique_path, table_init_inputs)), + ) + except _UnexpectedPlaceholderError as e: + if e.func_graph_name: + raise e + raise ValueError( + f"The table initializer {table_init_op} depended on a placeholder ({e.tensor}). 
Note " + "placeholders will not be fed during table initialization" + ) from e + except _UnexpectedTableError as e: + if e.func_graph_name: + raise e + raise ValueError( + f"The table initializer {table_init_op} depended on an initializable table ({e.op}). " + "Note tables are initialized in one pass so a table initializer " + "cannot depend on the output of an initializeable table" + ) from e + return _SourceInfo(ready, path) + + @property + def ready_table_initializers(self): + return self._ready_table_initializers + + @_reraise_unexpected_error + def ready_to_run(self, tensor_or_op): + """Determine if a given tensor or op is ready to run.""" + return self._graph_analyzer.ready_to_run(tensor_or_op) + + @_reraise_unexpected_error + def get_unique_path(self, tensor): + """Gets the analyzed path from the tensor to its root(s). + + This path is defined recursively as: + Path(root) := translate_path_fn(root) + Path(x) := translate_path_fn( + x, + [translate_path_fn(p) for p in parents(x)]) + + When root is defined as a tensor that has no parents. + + Args: + ---- + tensor: A `Tensor` for which a path should be computed. + + Returns: + ------- + The result of translate_path_fn on the computed path as described above. + """ + return self._graph_analyzer.get_unique_path(tensor) + + @_reraise_unexpected_error + def get_dependent_inputs(self, tensor_or_op): + """Gets the inputs that the given `tensor_or_op` transitively depends on. + + Args: + ---- + tensor_or_op: A `Tensor`, `SparseTensor`, `RaggedTensor` or `Operation`. + + Returns: + ------- + A dict of name to `Tensor`, `SparseTensor`, or `RaggedTensor` (sub-dict of + `input_signature`) that the given `tensor_or_op` depends on. + + Raises: + ------ + TypeError: If `tensor_or_op` is of an unsupported type. + """ + if not isinstance( + tensor_or_op, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Operation) + ): + raise TypeError( + f"Expected Tensor, SparseTensor, RaggedTensor or Operation got {tensor_or_op} of " + f"type {type(tensor_or_op)}" + ) + + dependents = set() + for component in _decompose_tensor_or_op(tensor_or_op): + dependents.update( + self._graph_analyzer.analyze_tensor(component).dependent_sources + ) + + result = {} + for name, tensor in self._input_signature.items(): + if any( + tf_utils.hashable_tensor_or_op(component) in dependents + for component in _decompose_tensor_or_op(tensor) + ): + result[name] = tensor + return result - Computes the transitive dependencies of a tensor or operation and decides - whether it is ready to run using iterative DFS. `source_info_dict` are used - as terminal nodes. An error is thrown if a table or placeholder is reached: - they must be set using source_info_dict. This function is memoized using the - _memoized_analyze_tensor_result cache. Cycles are ignored (so a cycle is - considered ready to run). - Args: - tensor_or_op: A `Tensor` or `Operation`. +class _QuietInitializableGraphAnalyzer(InitializableGraphAnalyzer): + """A `InitializableGraphAnalyzer` which doesn't log any warnings.""" - Returns: - An _AnalysisResult which includes whether this op or tensor is ready to - run, a path from it to its sources and its dependent sources from - `source_info_dict`. + def _log_warning(self, message: str): + pass - Raises: - _UnexpectedTableError: If an initializable table op is encountered. - _UnexpectedPlaceholderError: If a placeholder is encountered. 
- """ - stack = collections.deque() - # Note that because tensors are no longer hashable, we need to convert to - # their reference in order to use them in sets or dicts. - stack.append(tf_utils.hashable_tensor_or_op(tensor_or_op)) - # Contains the nodes of the path starting from tensor_or_op to current - # visiting node, used for loop detection. We assume that any loop is a - # valid while loop and so it will be able to run as long as all the other - # parents are ready. - path = set() - while stack: - current = stack[-1] - if current in self._memoized_analyze_tensor_result: - stack.pop() - continue - path.add(current) - parents = self._get_parents(tf_utils.deref_tensor_or_op(current)) - parents = [parent for parent in map(tf_utils.hashable_tensor_or_op, - parents) if parent not in path] - if all( - parent in self._memoized_analyze_tensor_result for parent in parents): - parent_results = [ - self._memoized_analyze_tensor_result[parent] for parent in parents - ] - current_result = self._compute_analysis_result( - tf_utils.deref_tensor_or_op(current), parent_results) - self._memoized_analyze_tensor_result[current] = current_result - path.discard(stack.pop()) - else: - stack.extend(parents) - return self._memoized_analyze_tensor_result[tf_utils.hashable_tensor_or_op( - tensor_or_op)] - - def ready_to_run(self, tensor_or_op): - """Determine if a given tensor or op is ready to run. - - A tensor is ready to run if every tensor in all its transitive dependencies - are set to `True` in `known_ready`. - - Note that if a placeholder is encountered, this will result in an error as - it is assumed that all placeholders are keys in `known_ready`. This is - to avoid unexpected behavior when the user creates placeholders (as opposed - to placeholders created by the tf.Transform framework). - - Similarly encountering a Table op is an error because a table should be - a key in `known_ready` (in the case of analyzing the main session run) or - should not be encountered (in the case of analyzing the graph init run). - - Args: - tensor_or_op: A `Tensor`, `SparseTensor`, `RaggedTensor` or `Operation` - Returns: - A bool indicating whether then tensor is ready to run. - - Raises: - ValueError: If a placeholder or table is encountered. - _UnexpectedTableError: If an initializable table op is encountered. - _UnexpectedPlaceholderError: If a placeholder is encountered. - """ - if not isinstance( - tensor_or_op, - (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Operation)): - raise TypeError( - 'Expected Tensor, SparseTensor, RaggedTensor, or Operation got {} of type {}' - .format(tensor_or_op, type(tensor_or_op))) - return all( - self.analyze_tensor(component).is_ready_to_run - for component in _decompose_tensor_or_op(tensor_or_op)) - - def get_unique_path(self, tensor): - """Gets the analyzed path from the tensor to its root(s). - - This path is defined recursively as: - Path(root) := translate_path_fn(root) - Path(x) := translate_path_fn( - x, - [translate_path_fn(p) for p in parents(x)]) - - When root is defined as a tensor that has no parents. +def get_dependent_inputs(graph, input_tensors, output_tensors): + """Returns tensors in input_tensors that (transitively) produce output_tensors. Args: - tensor: A `Tensor` for which a path should be computed. + ---- + graph: A `tf.Graph`. It could be the (intermediate) output tf graph in any + transform phase (including phase 0 where no tensor replacement has yet + happened). 
+ input_tensors: A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or + `tf.RaggedTensor`. Logical name doesn't have any implications in this + method and can be anything. In some cases it is the feature name + corresponding to the input tensor. + output_tensors: A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or + `tf.RaggedTensor`, or a list of `tf.Tensor`, `tf.SparseTensor`, or + `tf.RaggedTensor`. Returns: - The result of translate_path_fn on the computed path as described above. - - Raises: - TypeError: if the given tensor is not of type `Tensor` - _UnexpectedTableError: If an initializable table op is encountered. - _UnexpectedPlaceholderError: If a placeholder is encountered. + ------- + A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or + `tf.RaggedTensor` that are filtered from input_tensors (transitively) + producing output_tensors """ - if not isinstance(tensor, tf.Tensor): - raise TypeError('Expected Tensor got {} of type {}'.format( - tensor, type(tensor))) - return self.analyze_tensor(tensor).path - - -def _set_unique_value_in_dict(input_dict, key, value): - assert value not in input_dict.values(), value - input_dict[tf_utils.hashable_tensor_or_op(key)] = value - - -class InitializableGraphAnalyzer: - """Determines which tensors will be ready when running the graph. - - Determines which tensors from `fetches` are ready to run, using following - algorithm. - - 1. Determine which table initializers are ready to run. A table initializer - is an element of the TABLE_INITIALIZERS collection and it is ready to run - if all the tensors it depends on are set to ready in - `replaced_tensors_ready`. - - 2. Determine which of `fetches` are ready to run. A fetch is ready to run if - it only depends on tensors in `feeds` and tensors that are set to ready in - `replaced_tensors_ready`. - """ - - def __init__(self, - graph, - input_signature, - replaced_tensors_ready, - translate_path_fn=None): - """Init method for InitializableGraphAnalyzer. - - Args: - graph: a `Graph`. - input_signature: A dict whose keys are strings and values are `Tensor`s, - `SparseTensor`s, or `RaggedTensor`s. - replaced_tensors_ready: a list of `Tensor`, `SparseTensor`s, or - `RaggedTensor`s, bool pairs indicating whether the `Tensor`, - `SparseTensor`s, or `RaggedTensor`s is ready in this phase. - translate_path_fn: (Optional) A function with the signature: (identifier, - optional(parents)) -> Any which will be used to construct a unique path - for a given `Tensor`. - - Raises: - ValueError: If unexpected placeholders or tables are encountered, or table - initializers do not have the expected structure in the graph. - """ - - if translate_path_fn is None: - translate_path_fn = lambda x, parents=None: None - - self._ready_table_initializers = [] - self._input_signature = input_signature - replaced_tensors_ready = {tf_utils.hashable_tensor_or_op(t): ready - for t, ready in replaced_tensors_ready} - - initial_source_infos_dict = self._make_source_infos_dict( - {}, replaced_tensors_ready) - - # Determine which table initializers are ready, based on the replaced - # tensors. Since no input tensors are fed during table initialization, we do - # not set the value of any tensors in `input_signature`. 
- graph_analyzer_for_table_init = _GraphAnalyzer(initial_source_infos_dict, - translate_path_fn, graph) - complete_source_info_dict = self._make_source_infos_dict( - input_signature, replaced_tensors_ready) - - for table_init_op_or_tensor in graph.get_collection( - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS): - # Handle the case when an initializer was lifted out of the graph context. - if table_init_op_or_tensor is None: - continue - - if isinstance(graph, tf_func_graph.FuncGraph): - self._log_warning('Tables initialized inside a tf.function will be' - ' re-initialized on every invocation of the function.' - ' This re-initialization can have significant impact' - ' on performance. Consider lifting them out of the' - ' graph context using `tf.init_scope`.: {}'.format( - table_init_op_or_tensor.name)) - - table_init_op, table_input_ops = ( - self._get_table_init_op_and_inputs(table_init_op_or_tensor)) - source_info = self._get_table_init_op_source_info( - table_init_op, graph_analyzer_for_table_init, translate_path_fn) - - for key in table_input_ops: - complete_source_info_dict[tf_utils.hashable_tensor_or_op( - key)] = source_info - if source_info.is_ready_to_run: - self._ready_table_initializers.append(table_init_op_or_tensor) - - # Now determine which tensors are ready to run once the table has been - # initialized. - self._graph_analyzer = _GraphAnalyzer(complete_source_info_dict, - translate_path_fn, graph) - - def _log_warning(self, message: str): - logging.warning(message) - - def _get_table_init_op_and_inputs(self, table_init_op_or_tensor): - """Get a tuple of table init op and keys for its input ops.""" - # If a TF2 exported SavedModel with a table is loaded inside the - # preprocessing_fn, the TABLE_INITIALIZERS collection of the outer graph - # contains a Tensor whose parent op is of type StatefulPartitionedCall. - # The nested func graph for this StatefulPartitionedCall contains the - # table initializer. - if (isinstance(table_init_op_or_tensor, tf.Tensor) and - table_init_op_or_tensor.op.type == 'StatefulPartitionedCall'): - result = (table_init_op_or_tensor.op, - [input_t.op for input_t in table_init_op_or_tensor.op.inputs]) + if isinstance(output_tensors, list): + output_container = output_tensors else: - assert isinstance(table_init_op_or_tensor, tf.Operation) - # We are using the table init op information and the table op information, - # since that is a unique description of the table op. - table_ops = [] - for input_t in table_init_op_or_tensor.inputs: - # One of the inputs to the initializer op should be the table op. If - # no table op is found, (as in the case of a StatefulPartitionedCall) - # all inputs are added to the source dict. - if input_t.dtype == tf.resource: - table_ops.append(input_t.op) - assert len(table_ops) == 1 - result = (table_init_op_or_tensor, [table_ops[0]]) - return result + output_container = output_tensors.values() + + # Since this method may be called before all tensor replacements are ready, to + # fulfill the precondition of InitializableGraphAnalyzer, we fake the + # readiness of tensor replacements. Note that the readiness of replacement + # tensors doesn't affect the correctness of dependencies tracing. 
+ tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) + sink_tensors_ready = [(sink.tensor, False) for sink in tensor_sinks] + graph_analyzer = _QuietInitializableGraphAnalyzer( + graph, input_tensors, sink_tensors_ready + ) + dependent_inputs = {} + for output_tensor in output_container: + dependent_inputs.update(graph_analyzer.get_dependent_inputs(output_tensor)) + return { + name: tensor + for name, tensor in input_tensors.items() + if name in dependent_inputs + } - def _make_source_infos_dict(self, input_signature, replaced_tensors_ready): - """Builds a dictionary from source tensors to _SourceInfos. - This dictionary stores information about the sources of the graph. - Each tensor in replaced_tensors_ready is a source whose readiness is known - and has no name. Each tensor (or component of a tensor) in input_signature - is ready to run and has a name determined by the signature. - - Args: - input_signature: A dict whose keys are strings and values are `Tensor`s, - `SparseTensor`s, or `RaggedTensor`s. - replaced_tensors_ready: a dict from `Tensor`, `SparseTensor`s, or - `RaggedTensor`s to bool indicating whether the tensor is ready in this - phase. - - Returns: - a dictionary from source tensors to _SourceInfos. - """ - result = {} - for tensor_or_op, is_ready in replaced_tensors_ready.items(): - for component in _decompose_tensor_or_op( - tf_utils.deref_tensor_or_op(tensor_or_op)): - result[tf_utils.hashable_tensor_or_op(component)] = _SourceInfo( - is_ready, None) - - for name, tensor in input_signature.items(): - if isinstance(tensor, tf.Tensor): - _set_unique_value_in_dict(result, tensor, - _SourceInfo(True, '{}$tensor'.format(name))) - elif isinstance(tensor, composite_tensor.CompositeTensor): - for idx, tensor_component in enumerate(_decompose_tensor_or_op(tensor)): - _set_unique_value_in_dict( - result, tensor_component, - _SourceInfo(True, '{}$composite_tensor_{}'.format(name, idx))) - else: - raise TypeError( - 'Expected Tensor, or CompositeTensor, got {} of type {}'.format( - tensor, type(tensor))) +def _serialize_op_attr(op_attr): + """Deterministicly serializes tf.Operation attrs since it is a map.""" + sorted_attributes = sorted(op_attr.items(), key=lambda kv: kv[0]) + if "f" in op_attr: + # This is a tf.Function node, and it includes attributes that are + # inconsistent across runs such as _gradient_op_type, config_proto, so we + # only keep input and output types since other information will arrive from + # the FuncGraph attributes. + sorted_attributes = [kv for kv in sorted_attributes if kv[0] in ("Tin", "Tout")] + result = [] + for key, attr_value in sorted_attributes: + result.append(key) + attr_value = copy.deepcopy(attr_value) + if attr_value.list.func: + raise ValueError( + "Unable to serialize op attributes that contain a `list.func` field" + ) + if attr_value.HasField("func"): + # There should be a separate call for the FuncGraph attributes. 
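The module-level `get_dependent_inputs` above is the convenience wrapper most callers want: it fakes readiness for pending tensor replacements and returns the subset of the input dict that the outputs actually consume. A small sketch; the graph and the feature names are invented:

    import tensorflow as tf

    from tensorflow_transform import graph_tools

    graph = tf.compat.v1.Graph()
    with graph.as_default():
        a = tf.compat.v1.placeholder(tf.float32, name="a")
        b = tf.compat.v1.placeholder(tf.float32, name="b")
        out = a * 2.0  # never reads `b`

    pruned = graph_tools.get_dependent_inputs(graph, {"a": a, "b": b}, {"out": out})
    print(sorted(pruned))  # expected: ['a']; `b` is dropped since `out` does not depend on it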
+ attr_value.ClearField("func") + result.append(attr_value.SerializeToString()) return result - def _get_table_init_op_source_info(self, table_init_op, graph_analyzer, - translate_path_fn): - """Gets a _SourceInfo for a given table init op.""" - - if table_init_op.type not in _TABLE_INIT_OP_TYPES: - raise ValueError( - 'Table initializer {} did not have expected op type'.format( - table_init_op)) - if not table_init_op.inputs: - raise ValueError( - 'Table initializer {} did not have expected number if inputs ' - '(expected >= 1 inputs, got 0)'.format(table_init_op)) - table_op = table_init_op.inputs[0].op - table_init_inputs = table_init_op.inputs[1:] - try: - ready = all(map(graph_analyzer.ready_to_run, table_init_inputs)) - path = translate_path_fn( - table_op, - parents=list(map(graph_analyzer.get_unique_path, table_init_inputs))) - except _UnexpectedPlaceholderError as e: - if e.func_graph_name: - raise e - raise ValueError( - 'The table initializer {} depended on a placeholder ({}). Note ' - 'placeholders will not be fed during table initialization'.format( - table_init_op, e.tensor)) from e - except _UnexpectedTableError as e: - if e.func_graph_name: - raise e - raise ValueError( - 'The table initializer {} depended on an initializable table ({}). ' - 'Note tables are initialized in one pass so a table initializer ' - 'cannot depend on the output of an initializeable table'.format( - table_init_op, e.op)) from e - return _SourceInfo(ready, path) - - @property - def ready_table_initializers(self): - return self._ready_table_initializers - - @_reraise_unexpected_error - def ready_to_run(self, tensor_or_op): - """Determine if a given tensor or op is ready to run.""" - return self._graph_analyzer.ready_to_run(tensor_or_op) - - @_reraise_unexpected_error - def get_unique_path(self, tensor): - """Gets the analyzed path from the tensor to its root(s). - - This path is defined recursively as: - Path(root) := translate_path_fn(root) - Path(x) := translate_path_fn( - x, - [translate_path_fn(p) for p in parents(x)]) - - When root is defined as a tensor that has no parents. - - Args: - tensor: A `Tensor` for which a path should be computed. - Returns: - The result of translate_path_fn on the computed path as described above. - """ - return self._graph_analyzer.get_unique_path(tensor) +def describe_path_as_analyzer_cache_hash( + x: Optional[Union[tf.Operation, tf.Tensor, str]], + parents: Optional[List[bytes]] = None, +) -> Optional[bytes]: + """Constructs a hash to describe a unique TF graph path. - @_reraise_unexpected_error - def get_dependent_inputs(self, tensor_or_op): - """Gets the inputs that the given `tensor_or_op` transitively depends on. + Note: We do not rely on names for hashing since it can be fragile. Args: - tensor_or_op: A `Tensor`, `SparseTensor`, `RaggedTensor` or `Operation`. + ---- + x: The current TF graph node. + parents: Results of previous calls to this function, where x was an ancestor + to the current node x. Returns: - A dict of name to `Tensor`, `SparseTensor`, or `RaggedTensor` (sub-dict of - `input_signature`) that the given `tensor_or_op` depends on. - - Raises: - TypeError: If `tensor_or_op` is of an unsupported type. + ------- + A bytes hash of the path from x to its sources. None if x is None. 
""" - if not isinstance( - tensor_or_op, - (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Operation)): - raise TypeError( - 'Expected Tensor, SparseTensor, RaggedTensor or Operation got {} of ' - 'type {}'.format(tensor_or_op, type(tensor_or_op))) - - dependents = set() - for component in _decompose_tensor_or_op(tensor_or_op): - dependents.update( - self._graph_analyzer.analyze_tensor(component).dependent_sources) - - result = {} - for name, tensor in self._input_signature.items(): - if any( - tf_utils.hashable_tensor_or_op(component) in dependents - for component in _decompose_tensor_or_op(tensor)): - result[name] = tensor - return result - - -class _QuietInitializableGraphAnalyzer(InitializableGraphAnalyzer): - """A `InitializableGraphAnalyzer` which doesn't log any warnings.""" - - def _log_warning(self, message: str): - pass - - -def get_dependent_inputs(graph, input_tensors, output_tensors): - """Returns tensors in input_tensors that (transitively) produce output_tensors. - - Args: - graph: A `tf.Graph`. It could be the (intermediate) output tf graph in any - transform phase (including phase 0 where no tensor replacement has yet - happened). - input_tensors: A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or - `tf.RaggedTensor`. Logical name doesn't have any implications in this - method and can be anything. In some cases it is the feature name - corresponding to the input tensor. - output_tensors: A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or - `tf.RaggedTensor`, or a list of `tf.Tensor`, `tf.SparseTensor`, or - `tf.RaggedTensor`. - - Returns: - A dict of logical name to `tf.Tensor`, `tf.SparseTensor`, or - `tf.RaggedTensor` that are filtered from input_tensors (transitively) - producing output_tensors - """ - if isinstance(output_tensors, list): - output_container = output_tensors - else: - output_container = output_tensors.values() - - # Since this method may be called before all tensor replacements are ready, to - # fulfill the precondition of InitializableGraphAnalyzer, we fake the - # readiness of tensor replacements. Note that the readiness of replacement - # tensors doesn't affect the correctness of dependencies tracing. - tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) - sink_tensors_ready = [(sink.tensor, False) for sink in tensor_sinks] - graph_analyzer = _QuietInitializableGraphAnalyzer(graph, input_tensors, - sink_tensors_ready) - dependent_inputs = {} - for output_tensor in output_container: - dependent_inputs.update(graph_analyzer.get_dependent_inputs(output_tensor)) - return { - name: tensor - for name, tensor in input_tensors.items() - if name in dependent_inputs - } - + # This may happen in cases where tensors are outputs of previous analyzers, + # we don't need to describe a path for those. + if x is None: + assert parents is None + return None + parents = parents or [] + if any(p is None for p in parents): + return None + + if isinstance(x, tf.Operation): + values = _serialize_op_attr(x.node_def.attr) + elif isinstance(x, tf.Tensor): + # No need to add x.op to the hash since that should be included in parents. 
+ values = [tf.compat.as_str_any(x.value_index)] + else: + assert isinstance(x, (str, bytes)) + values = [x] -def _serialize_op_attr(op_attr): - """Deterministicly serializes tf.Operation attrs since it is a map.""" - sorted_attributes = sorted(op_attr.items(), key=lambda kv: kv[0]) - if 'f' in op_attr: - # This is a tf.Function node, and it includes attributes that are - # inconsistent across runs such as _gradient_op_type, config_proto, so we - # only keep input and output types since other information will arrive from - # the FuncGraph attributes. - sorted_attributes = [ - kv for kv in sorted_attributes if kv[0] in ('Tin', 'Tout') - ] - result = [] - for key, attr_value in sorted_attributes: - result.append(key) - attr_value = copy.deepcopy(attr_value) - if attr_value.list.func: - raise ValueError( - 'Unable to serialize op attributes that contain a `list.func` field') - if attr_value.HasField('func'): - # There should be a separate call for the FuncGraph attributes. - attr_value.ClearField('func') - result.append(attr_value.SerializeToString()) - return result + h = hashlib.sha1() + for value in values: + encoded = tf.compat.as_bytes(value) + h.update(encoded) + for p in parents: + h.update(p) -def describe_path_as_analyzer_cache_hash( - x: Optional[Union[tf.Operation, tf.Tensor, str]], - parents: Optional[List[bytes]] = None) -> Optional[bytes]: - """Constructs a hash to describe a unique TF graph path. - - Note: We do not rely on names for hashing since it can be fragile. - - Args: - x: The current TF graph node. - parents: Results of previous calls to this function, where x was an ancestor - to the current node x. - - Returns: - A bytes hash of the path from x to its sources. None if x is None. - """ - # This may happen in cases where tensors are outputs of previous analyzers, - # we don't need to describe a path for those. - if x is None: - assert parents is None - return None - parents = parents or [] - if any(p is None for p in parents): - return None - - if isinstance(x, tf.Operation): - values = _serialize_op_attr(x.node_def.attr) - elif isinstance(x, tf.Tensor): - # No need to add x.op to the hash since that should be included in parents. 
- values = [tf.compat.as_str_any(x.value_index)] - else: - assert isinstance(x, (str, bytes)) - values = [x] - - h = hashlib.sha1() - for value in values: - encoded = tf.compat.as_bytes(value) - h.update(encoded) - - for p in parents: - h.update(p) - - return h.digest() + return h.digest() class SourcedTensorsVisitor(nodes.Visitor): - """Visitor used to extract tensors that are inputs to `TensorSource` nodes.""" + """Visitor used to extract tensors that are inputs to `TensorSource` nodes.""" - def __init__(self): - self.sourced_tensors = [] + def __init__(self): + self.sourced_tensors = [] - def visit(self, operation_def, input_values): - if isinstance(operation_def, analyzer_nodes.TensorSource): - for tensor in operation_def.tensors: - self.sourced_tensors.append(tensor) - return nodes.OperationNode(operation_def, input_values).outputs + def visit(self, operation_def, input_values): + if isinstance(operation_def, analyzer_nodes.TensorSource): + for tensor in operation_def.tensors: + self.sourced_tensors.append(tensor) + return nodes.OperationNode(operation_def, input_values).outputs - def validate_value(self, value): - assert isinstance(value, nodes.ValueNode) + def validate_value(self, value): + assert isinstance(value, nodes.ValueNode) def _retrieve_source_keys( sourced_tensors: Iterable[tf.Tensor], - structured_inputs: Mapping[str, common_types.TensorType]) -> Set[str]: - """Retrieve input keys that sourced_tensors depend on.""" - result = set() - sinks = [t.op for t in sourced_tensors] - sources = retrieve_sources(sinks, ignore_control_dependencies=True) - hashable_sources = [tf_utils.hashable_tensor_or_op(s) for s in sources] - for key, value in structured_inputs.items(): - components = ([ - tf_utils.hashable_tensor_or_op(v) - for v in _decompose_tensor_or_op(value) - ]) - if any([s in components for s in hashable_sources]): - result.add(key) - return result + structured_inputs: Mapping[str, common_types.TensorType], +) -> Set[str]: + """Retrieve input keys that sourced_tensors depend on.""" + result = set() + sinks = [t.op for t in sourced_tensors] + sources = retrieve_sources(sinks, ignore_control_dependencies=True) + hashable_sources = [tf_utils.hashable_tensor_or_op(s) for s in sources] + for key, value in structured_inputs.items(): + components = [ + tf_utils.hashable_tensor_or_op(v) for v in _decompose_tensor_or_op(value) + ] + if any([s in components for s in hashable_sources]): + result.add(key) + return result AnalyzersFingerprint = tfx_namedtuple.TypedNamedTuple( - 'AnalyzersFingerprint', [('source_keys', Set[str]), - ('unique_path_hash', Set[bytes])]) + "AnalyzersFingerprint", + [("source_keys", Set[str]), ("unique_path_hash", Set[bytes])], +) def get_analyzers_fingerprint( graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType] ) -> Mapping[str, AnalyzersFingerprint]: - """Computes fingerprints for all analyzers in `graph`. - - Args: - graph: a TF Graph. - structured_inputs: a dict from keys to batches of placeholder graph tensors. - - Returns: - A mapping from analyzer name to a set of paths that define its fingerprint. - """ - result = {} - tensor_sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS) - # The value for the keys in this dictionary are unused and can be arbitrary. 
- sink_tensors_ready = { - tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False - for tensor_sink in tensor_sinks - } - graph_analyzer = InitializableGraphAnalyzer( - graph, structured_inputs, list(sink_tensors_ready.items()), - describe_path_as_analyzer_cache_hash) - for tensor_sink in tensor_sinks: - # Retrieve tensors that are inputs to the analyzer's value node. - visitor = SourcedTensorsVisitor() - nodes.Traverser(visitor).visit_value_node(tensor_sink.future) - source_keys = _retrieve_source_keys(visitor.sourced_tensors, - structured_inputs) - paths = set() - for tensor in visitor.sourced_tensors: - # Obtain fingerprint for each tensor that is an input to the value node. - path = graph_analyzer.get_unique_path(tensor) - if path is not None: - paths.add(path) - result[str(tensor_sink.tensor.name)] = AnalyzersFingerprint( - source_keys, paths) - return result + """Computes fingerprints for all analyzers in `graph`. + + Args: + ---- + graph: a TF Graph. + structured_inputs: a dict from keys to batches of placeholder graph tensors. + + Returns: + ------- + A mapping from analyzer name to a set of paths that define its fingerprint. + """ + result = {} + tensor_sinks = graph.get_collection(analyzer_nodes.ALL_REPLACEMENTS) + # The value for the keys in this dictionary are unused and can be arbitrary. + sink_tensors_ready = { + tf_utils.hashable_tensor_or_op(tensor_sink.tensor): False + for tensor_sink in tensor_sinks + } + graph_analyzer = InitializableGraphAnalyzer( + graph, + structured_inputs, + list(sink_tensors_ready.items()), + describe_path_as_analyzer_cache_hash, + ) + for tensor_sink in tensor_sinks: + # Retrieve tensors that are inputs to the analyzer's value node. + visitor = SourcedTensorsVisitor() + nodes.Traverser(visitor).visit_value_node(tensor_sink.future) + source_keys = _retrieve_source_keys(visitor.sourced_tensors, structured_inputs) + paths = set() + for tensor in visitor.sourced_tensors: + # Obtain fingerprint for each tensor that is an input to the value node. + path = graph_analyzer.get_unique_path(tensor) + if path is not None: + paths.add(path) + result[str(tensor_sink.tensor.name)] = AnalyzersFingerprint(source_keys, paths) + return result diff --git a/tensorflow_transform/graph_tools_test.py b/tensorflow_transform/graph_tools_test.py index 93cc576..5785c8f 100644 --- a/tensorflow_transform/graph_tools_test.py +++ b/tensorflow_transform/graph_tools_test.py @@ -18,1251 +18,1357 @@ import tempfile import tensorflow as tf -from tensorflow_transform import graph_tools -from tensorflow_transform import test_case + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.python.ops import control_flow_ops + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. 
from tfx_bsl.types import tfx_namedtuple -# pylint: disable=g-direct-tensorflow-import -from tensorflow.python.ops import control_flow_ops +from tensorflow_transform import graph_tools, test_case + # pylint: disable=g-enable-tensorflow-import mock = tf.compat.v1.test.mock def _create_lookup_table_from_file(filename): - initializer = tf.lookup.TextFileInitializer( - filename, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - return tf.lookup.StaticHashTable(initializer, default_value=-1) + initializer = tf.lookup.TextFileInitializer( + filename, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + return tf.lookup.StaticHashTable(initializer, default_value=-1) def _create_graph_with_y_function_of_x(): - x = tf.compat.v1.placeholder(tf.int64) - y = x + 1 - return {'x': x, 'y': y} + x = tf.compat.v1.placeholder(tf.int64) + y = x + 1 + return {"x": x, "y": y} def _create_graph_with_tf_function(): - x = tf.compat.v1.placeholder(tf.int64) - y = tf.compat.v1.placeholder(tf.int64) + x = tf.compat.v1.placeholder(tf.int64) + y = tf.compat.v1.placeholder(tf.int64) - @tf.function - def foo(x, y): - return x * 2, x + y, y * 2 + @tf.function + def foo(x, y): + return x * 2, x + y, y * 2 - a, b, c = foo(x, y) - return {'x': x, 'y': y, 'z': a + b + c, 'r': a, 'q': foo(x, x)[0]} + a, b, c = foo(x, y) + return {"x": x, "y": y, "z": a + b + c, "r": a, "q": foo(x, x)[0]} def _create_graph_with_tf2_saved_model_with_table(): + def export_saved_model_v2(): + root = tf.Module() - def export_saved_model_v2(): - root = tf.Module() - - @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) - def lookup_fn(x): - table_keys = ['cat', 'dog', 'giraffe'] - root.initializer = tf.lookup.KeyValueTensorInitializer( - keys=table_keys, - values=tf.cast(tf.range(len(table_keys)), tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - root.table = tf.lookup.StaticHashTable(root.initializer, default_value=-1) - return root.table.lookup(x) + @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) + def lookup_fn(x): + table_keys = ["cat", "dog", "giraffe"] + root.initializer = tf.lookup.KeyValueTensorInitializer( + keys=table_keys, + values=tf.cast(tf.range(len(table_keys)), tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + root.table = tf.lookup.StaticHashTable(root.initializer, default_value=-1) + return root.table.lookup(x) - result = os.path.join(tempfile.mkdtemp(), 'export_path') - root.lookup_fn = lookup_fn - tf.compat.v2.saved_model.save(root, result) - return result + result = os.path.join(tempfile.mkdtemp(), "export_path") + root.lookup_fn = lookup_fn + tf.compat.v2.saved_model.save(root, result) + return result - module_path = export_saved_model_v2() + module_path = export_saved_model_v2() - def bar(x): - imported = tf.compat.v2.saved_model.load(module_path) - return imported.lookup_fn(x) + def bar(x): + imported = tf.compat.v2.saved_model.load(module_path) + return imported.lookup_fn(x) - x = tf.compat.v1.placeholder(tf.string) - a = bar(x) - return {'x': x, 'a': a} + x = tf.compat.v1.placeholder(tf.string) + a = bar(x) + return {"x": x, "a": a} def _create_graph_with_placeholder_in_tf_function(): - x = tf.compat.v1.placeholder(tf.int64) + x = tf.compat.v1.placeholder(tf.int64) - @tf.function - def foo(x): - a = tf.compat.v1.placeholder(tf.int64) - return x * a, a + 
@tf.function + def foo(x): + a = tf.compat.v1.placeholder(tf.int64) + return x * a, a - y, a = foo(x + 1) - return {'x': x, 'y': y + 1, 'z': a} + y, a = foo(x + 1) + return {"x": x, "y": y + 1, "z": a} def _create_graph_with_mixed_dependencies(): - x = tf.compat.v1.placeholder(tf.int64) - y = tf.compat.v1.placeholder(tf.int64) + x = tf.compat.v1.placeholder(tf.int64) + y = tf.compat.v1.placeholder(tf.int64) - @tf.function - def foo(x): - return x * 2 + @tf.function + def foo(x): + return x * 2 - return {'x': x, 'y': y, 'z': foo(x) + y} + return {"x": x, "y": y, "z": foo(x) + y} def _create_graph_with_chained_tf_function(): - x = tf.compat.v1.placeholder(tf.int64) + x = tf.compat.v1.placeholder(tf.int64) - @tf.function - def goo(x): - return x + 1 + @tf.function + def goo(x): + return x + 1 - @tf.function - def foo(x): - return goo(x) * 2 + @tf.function + def foo(x): + return goo(x) * 2 - return {'x': x, 'y': foo(x) / 2} + return {"x": x, "y": foo(x) / 2} def _create_graph_with_y_function_of_x_with_unused_inputs(): - x = tf.compat.v1.placeholder(tf.int64) - x2 = tf.compat.v1.placeholder(tf.int64) - x_unused = tf.compat.v1.placeholder(tf.int64) - y = x + 1 - z = x2 + 2 - return {'x': x, 'x2': x2, 'x_unused': x_unused, 'y': y, 'z': z} + x = tf.compat.v1.placeholder(tf.int64) + x2 = tf.compat.v1.placeholder(tf.int64) + x_unused = tf.compat.v1.placeholder(tf.int64) + y = x + 1 + z = x2 + 2 + return {"x": x, "x2": x2, "x_unused": x_unused, "y": y, "z": z} def _create_graph_with_y_function_of_x_sparse(): - x = tf.compat.v1.sparse_placeholder(tf.int64) - y = tf.sparse.reduce_sum(x) + 1 - return {'x': x, 'y': y} + x = tf.compat.v1.sparse_placeholder(tf.int64) + y = tf.sparse.reduce_sum(x) + 1 + return {"x": x, "y": y} def _create_graph_with_z_function_of_x_ragged(): - x = tf.compat.v1.ragged.placeholder(tf.int64, 2) - y = x.to_sparse() - z = tf.sparse.reduce_sum(y) + 1 - return {'x': x, 'y': y, 'z': z} + x = tf.compat.v1.ragged.placeholder(tf.int64, 2) + y = x.to_sparse() + z = tf.sparse.reduce_sum(y) + 1 + return {"x": x, "y": y, "z": z} def _create_graph_with_ragged_tensor(): - x1 = tf.compat.v1.placeholder(tf.int64, (1, 3, 3)) - x2 = tf.compat.v1.sparse.placeholder(tf.int64, (4, 3)) - y1 = tf.RaggedTensor.from_tensor(x1, ragged_rank=2) - y2 = tf.RaggedTensor.from_sparse(x2) + 1 - return {'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2} + x1 = tf.compat.v1.placeholder(tf.int64, (1, 3, 3)) + x2 = tf.compat.v1.sparse.placeholder(tf.int64, (4, 3)) + y1 = tf.RaggedTensor.from_tensor(x1, ragged_rank=2) + y2 = tf.RaggedTensor.from_sparse(x2) + 1 + return {"x1": x1, "x2": x2, "y1": y1, "y2": y2} def _create_graph_with_y_sparse_function_of_x_sparse(): - x = tf.compat.v1.sparse_placeholder(tf.int64) - y = tf.SparseTensor( - indices=x.indices, - values=x.values + 1, - dense_shape=x.dense_shape) - return { - 'x': x, - 'y': y, - 'z': tf.compat.v1.sparse.add(y, tf.ones(y.dense_shape, tf.int64)) - } + x = tf.compat.v1.sparse_placeholder(tf.int64) + y = tf.SparseTensor( + indices=x.indices, values=x.values + 1, dense_shape=x.dense_shape + ) + return { + "x": x, + "y": y, + "z": tf.compat.v1.sparse.add(y, tf.ones(y.dense_shape, tf.int64)), + } def _create_graph_with_y_function_of_x_and_table(): - filename = tf.raw_ops.Placeholder(dtype=tf.string, shape=()) - table = _create_lookup_table_from_file(filename) - x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,)) - y = table.lookup(x) - return {'filename': filename, 'x': x, 'y': y} + filename = tf.raw_ops.Placeholder(dtype=tf.string, shape=()) + table = 
_create_lookup_table_from_file(filename) + x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,)) + y = table.lookup(x) + return {"filename": filename, "x": x, "y": y} def _create_graph_with_y_function_of_x_and_table_in_first_phase(): - table = _create_lookup_table_from_file(tf.constant('not_a_file_name_but_ok')) - x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,)) - y = table.lookup(x) - return {'x': x, 'y': y} + table = _create_lookup_table_from_file(tf.constant("not_a_file_name_but_ok")) + x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,)) + y = table.lookup(x) + return {"x": x, "y": y} def _create_graph_with_y_function_of_x_and_untracked_table(): - filename = tf.compat.v1.placeholder(tf.string, ()) - table = _create_lookup_table_from_file(filename) + filename = tf.compat.v1.placeholder(tf.string, ()) + table = _create_lookup_table_from_file(filename) - x = tf.compat.v1.placeholder(tf.string, (None,)) - y = table.lookup(x) - del tf.compat.v1.get_collection_ref( - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)[:] - return {'filename': filename, 'x': x, 'y': y} + x = tf.compat.v1.placeholder(tf.string, (None,)) + y = table.lookup(x) + del tf.compat.v1.get_collection_ref(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)[:] + return {"filename": filename, "x": x, "y": y} def _create_graph_with_table_initialized_by_table_output(): - filename = tf.compat.v1.placeholder(tf.string, ()) - table1 = _create_lookup_table_from_file(filename) - - # Use output from the first table to initialize the second table. - keys = ['a', 'b', 'c'] - tensor_keys = tf.as_string( - table1.lookup(tf.constant(keys, tf.string))) - initializer2 = tf.lookup.KeyValueTensorInitializer( - keys=tensor_keys, - values=tf.range(len(keys), dtype=tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - table2 = tf.lookup.StaticHashTable(initializer2, default_value=-1) - x = tf.compat.v1.placeholder(tf.string, (None,)) - y = table2.lookup(x) - return {'filename': filename, 'x': x, 'y': y} + filename = tf.compat.v1.placeholder(tf.string, ()) + table1 = _create_lookup_table_from_file(filename) + + # Use output from the first table to initialize the second table. + keys = ["a", "b", "c"] + tensor_keys = tf.as_string(table1.lookup(tf.constant(keys, tf.string))) + initializer2 = tf.lookup.KeyValueTensorInitializer( + keys=tensor_keys, + values=tf.range(len(keys), dtype=tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table2 = tf.lookup.StaticHashTable(initializer2, default_value=-1) + x = tf.compat.v1.placeholder(tf.string, (None,)) + y = table2.lookup(x) + return {"filename": filename, "x": x, "y": y} def _create_graph_with_assert_equal(): - x = tf.raw_ops.Placeholder(dtype=tf.int64) - y = tf.raw_ops.Placeholder(dtype=tf.int64) - z = control_flow_ops.with_dependencies( - [tf.raw_ops.Assert(condition=tf.raw_ops.Equal(x=x, y=y), data=[x, y])], x) - return {'x': x, 'y': y, 'z': z} + x = tf.raw_ops.Placeholder(dtype=tf.int64) + y = tf.raw_ops.Placeholder(dtype=tf.int64) + z = control_flow_ops.with_dependencies( + [tf.raw_ops.Assert(condition=tf.raw_ops.Equal(x=x, y=y), data=[x, y])], x + ) + return {"x": x, "y": y, "z": z} def _create_graph_with_y_function_of_x_with_tf_while(): - x = tf.raw_ops.Placeholder(dtype=tf.int64, shape=()) - - # Subtract 10 from x using a tf.while_loop. 
- @tf.function(input_signature=[ - tf.TensorSpec([], tf.int32), - tf.TensorSpec([], tf.int64) - ]) - def stop_condition(counter, x_minus_counter): - del x_minus_counter # unused - return tf.less(counter, 10) - - @tf.function(input_signature=[ - tf.TensorSpec([], tf.int32), - tf.TensorSpec([], tf.int64) - ]) - def iteration(counter, x_minus_counter): - return tf.add(counter, 1), tf.add(x_minus_counter, -1) - initial_values = [tf.constant(0), x] - final_values = tf.raw_ops.While( - cond=stop_condition.get_concrete_function(), - body=iteration.get_concrete_function(), - input=initial_values) - - y = final_values[1] - return {'x': x, 'y': y} + x = tf.raw_ops.Placeholder(dtype=tf.int64, shape=()) + + # Subtract 10 from x using a tf.while_loop. + @tf.function( + input_signature=[tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.int64)] + ) + def stop_condition(counter, x_minus_counter): + del x_minus_counter # unused + return tf.less(counter, 10) + + @tf.function( + input_signature=[tf.TensorSpec([], tf.int32), tf.TensorSpec([], tf.int64)] + ) + def iteration(counter, x_minus_counter): + return tf.add(counter, 1), tf.add(x_minus_counter, -1) + + initial_values = [tf.constant(0), x] + final_values = tf.raw_ops.While( + cond=stop_condition.get_concrete_function(), + body=iteration.get_concrete_function(), + input=initial_values, + ) + + y = final_values[1] + return {"x": x, "y": y} def _create_graph_with_tf_function_while(): - x = tf.raw_ops.Placeholder(dtype=tf.float32, shape=()) + x = tf.raw_ops.Placeholder(dtype=tf.float32, shape=()) - @tf.function - def larger_than_100(x): - while x < 100: - x *= 2 - return x + @tf.function + def larger_than_100(x): + while x < 100: + x *= 2 + return x - return {'x': x, 'y': larger_than_100(x)} + return {"x": x, "y": larger_than_100(x)} class _Matcher(metaclass=abc.ABCMeta): + def _future_proof(self, value): + if isinstance(value, (str, bytes)): + new_to_old = {} + for new, old in new_to_old.items(): + value = value.replace(new, old) + return value - def _future_proof(self, value): - if isinstance(value, (str, bytes)): - new_to_old = {} - for new, old in new_to_old.items(): - value = value.replace(new, old) - return value - - @abc.abstractmethod - def expected_fields(self, other): - raise NotImplementedError + @abc.abstractmethod + def expected_fields(self, other): + raise NotImplementedError - @abc.abstractproperty - def expected_fields_values(self): - raise NotImplementedError + @abc.abstractproperty + def expected_fields_values(self): + raise NotImplementedError - @abc.abstractproperty - def expected_class(self): - raise NotImplementedError + @abc.abstractproperty + def expected_class(self): + raise NotImplementedError - def __eq__(self, other): - if not isinstance(other, self.expected_class): - tf.compat.v1.logging.error('Types do not match, got: %s, expected: %s', - type(other), self.expected_class) - return False + def __eq__(self, other): + if not isinstance(other, self.expected_class): + tf.compat.v1.logging.error( + "Types do not match, got: %s, expected: %s", + type(other), + self.expected_class, + ) + return False - future_expected_fields = tuple( - self._future_proof(f) for f in self.expected_fields_values) - if (self.expected_fields_values != self.expected_fields(other) and - future_expected_fields != self.expected_fields(other)): - tf.compat.v1.logging.error('Fields do not match: %s != %s', - self.expected_fields_values, - self.expected_fields(other)) - return False + future_expected_fields = tuple( + self._future_proof(f) for f in 
self.expected_fields_values + ) + if self.expected_fields_values != self.expected_fields( + other + ) and future_expected_fields != self.expected_fields(other): + tf.compat.v1.logging.error( + "Fields do not match: %s != %s", + self.expected_fields_values, + self.expected_fields(other), + ) + return False - return True + return True -class _TensorMatcher(_Matcher, - tfx_namedtuple.namedtuple('_TensorMatcher', ['name'])): - __slots__ = () +class _TensorMatcher(_Matcher, tfx_namedtuple.namedtuple("_TensorMatcher", ["name"])): + __slots__ = () - def expected_fields(self, other): - return (str(other.name),) + def expected_fields(self, other): + return (str(other.name),) - @property - def expected_fields_values(self): - return tuple(self) + @property + def expected_fields_values(self): + return tuple(self) - @property - def expected_class(self): - return tf.Tensor + @property + def expected_class(self): + return tf.Tensor -class _OpMatcher(_Matcher, tfx_namedtuple.namedtuple('_OpMatcher', ['name'])): - __slots__ = () +class _OpMatcher(_Matcher, tfx_namedtuple.namedtuple("_OpMatcher", ["name"])): + __slots__ = () - def expected_fields(self, other): - return (str(other.name),) + def expected_fields(self, other): + return (str(other.name),) - @property - def expected_fields_values(self): - return tuple(self) + @property + def expected_fields_values(self): + return tuple(self) - @property - def expected_class(self): - return tf.Operation + @property + def expected_class(self): + return tf.Operation class GraphToolsTest(test_case.TransformTestCase): - - @test_case.named_parameters( - dict( - testcase_name='_y_function_of_x_nothing_ready', - create_graph_fn=_create_graph_with_y_function_of_x, - feeds=[], - replaced_tensors_ready={'x': False}, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_unused_input_ready', - create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs, - feeds=[], - replaced_tensors_ready={ - 'x': False, - 'x2': True, - 'x_unused': True - }, - should_be_ready={ - 'y': False, - 'z': True - }, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_no_feeds_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x, - feeds=[], - replaced_tensors_ready={'x': True}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_feeds_x_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x, - feeds=['x'], - replaced_tensors_ready={}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_sparse_nothing_ready', - create_graph_fn=_create_graph_with_y_function_of_x_sparse, - feeds=[], - replaced_tensors_ready={'x': False}, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_sparse_no_feeds_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x_sparse, - feeds=[], - replaced_tensors_ready={'x': True}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_sparse_feeds_x_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x_sparse, - feeds=['x'], - replaced_tensors_ready={}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_sparse_function_of_x_sparse_nothing_ready', - create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, - feeds=[], - replaced_tensors_ready={'x': False}, - should_be_ready={'y': 
False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_sparse_function_of_x_sparse_no_feeds_y_is_ready', - create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, - feeds=[], - replaced_tensors_ready={'x': True}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_sparse_function_of_x_sparse_feeds_x_y_is_ready', - create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, - feeds=['x'], - replaced_tensors_ready={}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_with_tf_while_nothing_ready', - create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, - feeds=[], - replaced_tensors_ready={'x': False}, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_with_tf_while_no_feeds_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, - feeds=[], - replaced_tensors_ready={'x': True}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_with_tf_while_feeds_x_y_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, - feeds=['x'], - replaced_tensors_ready={}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_and_table_nothing_ready', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=[], - replaced_tensors_ready={ - 'x': False, - 'filename': False - }, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_and_table_filename_ready_y_is_not', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=[], - replaced_tensors_ready={ - 'x': False, - 'filename': True - }, - should_be_ready={'y': False}, - num_ready_table_initializers=1), - dict( - testcase_name='_y_function_of_x_and_table_x_ready_filename_is_not', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=[], - replaced_tensors_ready={ - 'x': True, - 'filename': False - }, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_and_table_everything_is_ready', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=[], - replaced_tensors_ready={ - 'x': True, - 'filename': True, - }, - should_be_ready={'y': True}, - num_ready_table_initializers=1), - dict( - testcase_name='_y_function_of_x_and_table_feeds_x_nothing_ready', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=['x'], - replaced_tensors_ready={'filename': False}, - should_be_ready={'y': False}, - num_ready_table_initializers=0), - dict( - testcase_name='_y_function_of_x_and_table_feeds_x_everything_ready', - create_graph_fn=_create_graph_with_y_function_of_x_and_table, - feeds=['x'], - replaced_tensors_ready={'filename': True}, - should_be_ready={'y': True}, - num_ready_table_initializers=1), - dict( - testcase_name='_assert_equal', - create_graph_fn=_create_graph_with_assert_equal, - feeds=['x', 'y'], - replaced_tensors_ready={ - 'x': True, - 'y': True, - }, - should_be_ready={'z': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_tf_function', - create_graph_fn=_create_graph_with_tf_function, - feeds=['x', 'y'], - replaced_tensors_ready={}, - should_be_ready={'z': True}, - num_ready_table_initializers=0), - dict( - testcase_name='_tf_function_not_ready', - create_graph_fn=_create_graph_with_tf_function, - feeds=[], 
- replaced_tensors_ready={ - 'x': True, - 'y': False, - }, - should_be_ready={ - 'z': False, - 'q': True, - }, - num_ready_table_initializers=0), - dict( - testcase_name='_chained_tf_function', - create_graph_fn=_create_graph_with_chained_tf_function, - feeds=['x'], - replaced_tensors_ready={}, - should_be_ready={'y': True}, - num_ready_table_initializers=0), - ) - def testDetermineReadyTensorsAndTableInitializers( - self, create_graph_fn, feeds, replaced_tensors_ready, should_be_ready, - num_ready_table_initializers): - """Test determine_ready_tensors_and_table_initializers. - - Args: - create_graph_fn: A function that adds ops to a graph and returns a dict - mapping tensor names to `Tensor` or `SparseTensor`s. - feeds: A list of keys in the dict returned by create_graph_fn that are fed - in the main run (but not table initialization run). - replaced_tensors_ready: A dict whose keys are keys in the dict returned by - create_graph_fn and values are a bools indicating whether that tensor - is ready to be replaced in this phase. - should_be_ready: A dict whose keys are keys in the dict returned by - create_graph_fn and value are bools indicating whether a tensor can be - calculated in this phase. - num_ready_table_initializers: The number of table initializers that are - ready to run in the table initialization run of this phase. - """ - with tf.compat.v1.Graph().as_default() as graph: - tensors = create_graph_fn() - replaced_tensors_ready = [(tensors[name], ready) - for name, ready in replaced_tensors_ready.items()] - - graph_analyzer = graph_tools.InitializableGraphAnalyzer( - graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready) - self.assertEqual( - len(graph_analyzer.ready_table_initializers), - num_ready_table_initializers) - - for name, ready in should_be_ready.items(): - tensor = tensors[name] - self.assertEqual( - graph_analyzer.ready_to_run(tensor), - ready, - msg='Expected tensor {} to be ready={}'.format(name, ready)) - - @test_case.parameters( - (_create_graph_with_y_function_of_x_and_table, - [], {'x': False}, - 'placeholders will not be fed during table initialization'), - (_create_graph_with_y_function_of_x_and_table, - [], {'x': True}, - 'placeholders will not be fed during table initialization'), - (_create_graph_with_y_function_of_x_and_table, - ['filename'], {'x': False}, - 'placeholders will not be fed during table initialization'), - (_create_graph_with_y_function_of_x_and_table, - ['filename'], {'x': True}, - 'placeholders will not be fed during table initialization'), - (_create_graph_with_y_function_of_x_and_table, - ['filename', 'x'], {}, - 'placeholders will not be fed during table initialization'), - (_create_graph_with_table_initialized_by_table_output, - ['x'], {'filename': True}, - 'tables are initialized in one pass') - ) - def testInitializableGraphAnalyzerConstructorRaises( - self, create_graph_fn, feeds, replaced_tensors_ready, - error_msg_regex): - """Test determine_ready_tensors_and_table_initializers. - - Args: - create_graph_fn: A function that adds ops to a graph and returns a dict - mapping tensor names to `Tensor` or `SparseTensor`s. - feeds: A list of keys in the dict returned by create_graph_fn that are fed - in the main run (but not table initialization run). - replaced_tensors_ready: A dict whose keys are keys in the dict returned by - create_graph_fn and values are a bools indicating whether that tensor - is ready to be replaced in this phase. - error_msg_regex: The expected error message. 
- """ - with tf.compat.v1.Graph().as_default() as graph: - tensors = create_graph_fn() - replaced_tensors_ready = [(tensors[name], ready) - for name, ready in replaced_tensors_ready.items()] - with self.assertRaisesRegex(ValueError, error_msg_regex): - graph_tools.InitializableGraphAnalyzer(graph, - {x: tensors[x] for x in feeds}, - replaced_tensors_ready) - - @test_case.parameters( - (_create_graph_with_y_function_of_x, [], {}, 'y', - 'may have be caused by manually adding a placeholder to the graph'), - (_create_graph_with_placeholder_in_tf_function, ['x'], {}, 'z', - 'manually adding a placeholder to the graph. tf.function name: `foo`'), - (_create_graph_with_y_function_of_x_and_untracked_table, ['x'], { - 'filename': True - }, 'y', 'may be caused by adding an initializable table without'), - ) - def testInitializableGraphAnalyzerReadyToRunRaises( - self, create_graph_fn, feeds, replaced_tensors_ready, fetch, - error_msg_regex): - """Test determine_ready_tensors_and_table_initializers. - - Args: - create_graph_fn: A function that adds ops to a graph and returns a dict - mapping tensor names to `Tensor` or `SparseTensor`s. - feeds: A list of keys in the dict returned by create_graph_fn that are fed - in the main run (but not table initialization run). - replaced_tensors_ready: A dict whose keys are keys in the dict returned by - create_graph_fn and values are a bools indicating whether that tensor - is ready to be replaced in this phase. - fetch: The tensor to fetch. Should be a key in the dict returned by - create_graph_fn. - error_msg_regex: The expected error message. - """ - with tf.compat.v1.Graph().as_default() as graph: - tensors = create_graph_fn() - replaced_tensors_ready = [( - tensors[name], ready) for name, ready in replaced_tensors_ready.items()] - graph_analyzer = graph_tools.InitializableGraphAnalyzer( - graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready) - with self.assertRaisesRegex(ValueError, error_msg_regex): - tensor = tensors[fetch] - graph_analyzer.ready_to_run(tensor) - - @test_case.named_parameters( - dict( - testcase_name='_y_function_of_x', - create_graph_fn=_create_graph_with_y_function_of_x, - feeds=['x'], - fetches=['y'], - expected_dependent_inputs=['x']), - dict( - testcase_name='_tf_function', - create_graph_fn=_create_graph_with_tf_function, - feeds=['x', 'y'], - fetches=['z'], - expected_dependent_inputs=['x', 'y']), - dict( - testcase_name='_tf_function_signature_forces_dependencies', - create_graph_fn=_create_graph_with_tf_function, - feeds=['x', 'y'], - fetches=['r'], - expected_dependent_inputs=['x', 'y']), - dict( - testcase_name='_tf_function_mixed_dependencies', - create_graph_fn=_create_graph_with_mixed_dependencies, - feeds=['x', 'y'], - fetches=['z'], - expected_dependent_inputs=['x', 'y']), - dict( - testcase_name='_chained_tf_function', - create_graph_fn=_create_graph_with_chained_tf_function, - feeds=['x'], - fetches=['y'], - expected_dependent_inputs=['x']), - dict( - testcase_name='_y_function_of_x_with_unused_inputs', - create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs, - feeds=['x', 'x2', 'x_unused'], - fetches=['y', 'z'], - expected_dependent_inputs=['x', 'x2']), - dict( - testcase_name='_y_function_of_sparse_x', - create_graph_fn=_create_graph_with_y_function_of_x_sparse, - feeds=['x'], - fetches=['y'], - expected_dependent_inputs=['x']), - dict( - testcase_name='_y_sparse_function_of_sparse_x', - create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, - feeds=['x'], - fetches=['y'], - 
expected_dependent_inputs=['x']), - dict( - testcase_name='_y_function_of_ragged_x', - create_graph_fn=_create_graph_with_ragged_tensor, - feeds=['x1', 'x2'], - fetches=['y1', 'y2'], - expected_dependent_inputs=['x1', 'x2']), - dict( - testcase_name='_z_function_of_x_ragged', - create_graph_fn=_create_graph_with_z_function_of_x_ragged, - feeds=['x'], - fetches=['y', 'z'], - expected_dependent_inputs=['x']), - dict( - testcase_name='z_function_of_x_y_with_control_dependencies', - create_graph_fn=_create_graph_with_assert_equal, - feeds=['x', 'y'], - fetches=['z'], - expected_dependent_inputs=['x', 'y']), - dict( - testcase_name='_y_function_of_x_with_tf_while', - create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, - feeds=['x'], - fetches=['y'], - expected_dependent_inputs=['x']), - dict( - testcase_name='_tf2_saved_model_with_table', - create_graph_fn=_create_graph_with_tf2_saved_model_with_table, - feeds=['x'], - fetches=['a'], - expected_dependent_inputs=['x']), - ) - def testGetDependentInputs(self, create_graph_fn, feeds, fetches, - expected_dependent_inputs): - with tf.compat.v1.Graph().as_default() as graph: - tensors = create_graph_fn() - got = graph_tools.get_dependent_inputs(graph, - {x: tensors[x] for x in feeds}, - {y: tensors[y] for y in fetches}) - self.assertCountEqual(expected_dependent_inputs, got.keys()) - for input_name in expected_dependent_inputs: - self.assertIs(tensors[input_name], got[input_name]) + @test_case.named_parameters( + dict( + testcase_name="_y_function_of_x_nothing_ready", + create_graph_fn=_create_graph_with_y_function_of_x, + feeds=[], + replaced_tensors_ready={"x": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_unused_input_ready", + create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs, + feeds=[], + replaced_tensors_ready={"x": False, "x2": True, "x_unused": True}, + should_be_ready={"y": False, "z": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_no_feeds_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x, + feeds=[], + replaced_tensors_ready={"x": True}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_feeds_x_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x, + feeds=["x"], + replaced_tensors_ready={}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_sparse_nothing_ready", + create_graph_fn=_create_graph_with_y_function_of_x_sparse, + feeds=[], + replaced_tensors_ready={"x": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_sparse_no_feeds_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x_sparse, + feeds=[], + replaced_tensors_ready={"x": True}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_sparse_feeds_x_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x_sparse, + feeds=["x"], + replaced_tensors_ready={}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_sparse_function_of_x_sparse_nothing_ready", + create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, + feeds=[], + replaced_tensors_ready={"x": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + 
testcase_name="_y_sparse_function_of_x_sparse_no_feeds_y_is_ready", + create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, + feeds=[], + replaced_tensors_ready={"x": True}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_sparse_function_of_x_sparse_feeds_x_y_is_ready", + create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, + feeds=["x"], + replaced_tensors_ready={}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_with_tf_while_nothing_ready", + create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, + feeds=[], + replaced_tensors_ready={"x": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_with_tf_while_no_feeds_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, + feeds=[], + replaced_tensors_ready={"x": True}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_with_tf_while_feeds_x_y_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, + feeds=["x"], + replaced_tensors_ready={}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_and_table_nothing_ready", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=[], + replaced_tensors_ready={"x": False, "filename": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_and_table_filename_ready_y_is_not", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=[], + replaced_tensors_ready={"x": False, "filename": True}, + should_be_ready={"y": False}, + num_ready_table_initializers=1, + ), + dict( + testcase_name="_y_function_of_x_and_table_x_ready_filename_is_not", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=[], + replaced_tensors_ready={"x": True, "filename": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_and_table_everything_is_ready", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=[], + replaced_tensors_ready={ + "x": True, + "filename": True, + }, + should_be_ready={"y": True}, + num_ready_table_initializers=1, + ), + dict( + testcase_name="_y_function_of_x_and_table_feeds_x_nothing_ready", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=["x"], + replaced_tensors_ready={"filename": False}, + should_be_ready={"y": False}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_y_function_of_x_and_table_feeds_x_everything_ready", + create_graph_fn=_create_graph_with_y_function_of_x_and_table, + feeds=["x"], + replaced_tensors_ready={"filename": True}, + should_be_ready={"y": True}, + num_ready_table_initializers=1, + ), + dict( + testcase_name="_assert_equal", + create_graph_fn=_create_graph_with_assert_equal, + feeds=["x", "y"], + replaced_tensors_ready={ + "x": True, + "y": True, + }, + should_be_ready={"z": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_tf_function", + create_graph_fn=_create_graph_with_tf_function, + feeds=["x", "y"], + replaced_tensors_ready={}, + should_be_ready={"z": True}, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_tf_function_not_ready", + create_graph_fn=_create_graph_with_tf_function, + feeds=[], + 
replaced_tensors_ready={ + "x": True, + "y": False, + }, + should_be_ready={ + "z": False, + "q": True, + }, + num_ready_table_initializers=0, + ), + dict( + testcase_name="_chained_tf_function", + create_graph_fn=_create_graph_with_chained_tf_function, + feeds=["x"], + replaced_tensors_ready={}, + should_be_ready={"y": True}, + num_ready_table_initializers=0, + ), + ) + def testDetermineReadyTensorsAndTableInitializers( + self, + create_graph_fn, + feeds, + replaced_tensors_ready, + should_be_ready, + num_ready_table_initializers, + ): + """Test determine_ready_tensors_and_table_initializers. + + Args: + ---- + create_graph_fn: A function that adds ops to a graph and returns a dict + mapping tensor names to `Tensor` or `SparseTensor`s. + feeds: A list of keys in the dict returned by create_graph_fn that are fed + in the main run (but not table initialization run). + replaced_tensors_ready: A dict whose keys are keys in the dict returned by + create_graph_fn and values are a bools indicating whether that tensor + is ready to be replaced in this phase. + should_be_ready: A dict whose keys are keys in the dict returned by + create_graph_fn and value are bools indicating whether a tensor can be + calculated in this phase. + num_ready_table_initializers: The number of table initializers that are + ready to run in the table initialization run of this phase. + """ + with tf.compat.v1.Graph().as_default() as graph: + tensors = create_graph_fn() + replaced_tensors_ready = [ + (tensors[name], ready) for name, ready in replaced_tensors_ready.items() + ] + + graph_analyzer = graph_tools.InitializableGraphAnalyzer( + graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready + ) + self.assertEqual( + len(graph_analyzer.ready_table_initializers), num_ready_table_initializers + ) + + for name, ready in should_be_ready.items(): + tensor = tensors[name] + self.assertEqual( + graph_analyzer.ready_to_run(tensor), + ready, + msg=f"Expected tensor {name} to be ready={ready}", + ) + + @test_case.parameters( + ( + _create_graph_with_y_function_of_x_and_table, + [], + {"x": False}, + "placeholders will not be fed during table initialization", + ), + ( + _create_graph_with_y_function_of_x_and_table, + [], + {"x": True}, + "placeholders will not be fed during table initialization", + ), + ( + _create_graph_with_y_function_of_x_and_table, + ["filename"], + {"x": False}, + "placeholders will not be fed during table initialization", + ), + ( + _create_graph_with_y_function_of_x_and_table, + ["filename"], + {"x": True}, + "placeholders will not be fed during table initialization", + ), + ( + _create_graph_with_y_function_of_x_and_table, + ["filename", "x"], + {}, + "placeholders will not be fed during table initialization", + ), + ( + _create_graph_with_table_initialized_by_table_output, + ["x"], + {"filename": True}, + "tables are initialized in one pass", + ), + ) + def testInitializableGraphAnalyzerConstructorRaises( + self, create_graph_fn, feeds, replaced_tensors_ready, error_msg_regex + ): + """Test determine_ready_tensors_and_table_initializers. + + Args: + ---- + create_graph_fn: A function that adds ops to a graph and returns a dict + mapping tensor names to `Tensor` or `SparseTensor`s. + feeds: A list of keys in the dict returned by create_graph_fn that are fed + in the main run (but not table initialization run). + replaced_tensors_ready: A dict whose keys are keys in the dict returned by + create_graph_fn and values are a bools indicating whether that tensor + is ready to be replaced in this phase. 
+ error_msg_regex: The expected error message. + """ + with tf.compat.v1.Graph().as_default() as graph: + tensors = create_graph_fn() + replaced_tensors_ready = [ + (tensors[name], ready) for name, ready in replaced_tensors_ready.items() + ] + with self.assertRaisesRegex(ValueError, error_msg_regex): + graph_tools.InitializableGraphAnalyzer( + graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready + ) + + @test_case.parameters( + ( + _create_graph_with_y_function_of_x, + [], + {}, + "y", + "may have be caused by manually adding a placeholder to the graph", + ), + ( + _create_graph_with_placeholder_in_tf_function, + ["x"], + {}, + "z", + "manually adding a placeholder to the graph. tf.function name: `foo`", + ), + ( + _create_graph_with_y_function_of_x_and_untracked_table, + ["x"], + {"filename": True}, + "y", + "may be caused by adding an initializable table without", + ), + ) + def testInitializableGraphAnalyzerReadyToRunRaises( + self, create_graph_fn, feeds, replaced_tensors_ready, fetch, error_msg_regex + ): + """Test determine_ready_tensors_and_table_initializers. + + Args: + ---- + create_graph_fn: A function that adds ops to a graph and returns a dict + mapping tensor names to `Tensor` or `SparseTensor`s. + feeds: A list of keys in the dict returned by create_graph_fn that are fed + in the main run (but not table initialization run). + replaced_tensors_ready: A dict whose keys are keys in the dict returned by + create_graph_fn and values are a bools indicating whether that tensor + is ready to be replaced in this phase. + fetch: The tensor to fetch. Should be a key in the dict returned by + create_graph_fn. + error_msg_regex: The expected error message. + """ + with tf.compat.v1.Graph().as_default() as graph: + tensors = create_graph_fn() + replaced_tensors_ready = [ + (tensors[name], ready) for name, ready in replaced_tensors_ready.items() + ] + graph_analyzer = graph_tools.InitializableGraphAnalyzer( + graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready + ) + with self.assertRaisesRegex(ValueError, error_msg_regex): + tensor = tensors[fetch] + graph_analyzer.ready_to_run(tensor) + + @test_case.named_parameters( + dict( + testcase_name="_y_function_of_x", + create_graph_fn=_create_graph_with_y_function_of_x, + feeds=["x"], + fetches=["y"], + expected_dependent_inputs=["x"], + ), + dict( + testcase_name="_tf_function", + create_graph_fn=_create_graph_with_tf_function, + feeds=["x", "y"], + fetches=["z"], + expected_dependent_inputs=["x", "y"], + ), + dict( + testcase_name="_tf_function_signature_forces_dependencies", + create_graph_fn=_create_graph_with_tf_function, + feeds=["x", "y"], + fetches=["r"], + expected_dependent_inputs=["x", "y"], + ), + dict( + testcase_name="_tf_function_mixed_dependencies", + create_graph_fn=_create_graph_with_mixed_dependencies, + feeds=["x", "y"], + fetches=["z"], + expected_dependent_inputs=["x", "y"], + ), + dict( + testcase_name="_chained_tf_function", + create_graph_fn=_create_graph_with_chained_tf_function, + feeds=["x"], + fetches=["y"], + expected_dependent_inputs=["x"], + ), + dict( + testcase_name="_y_function_of_x_with_unused_inputs", + create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs, + feeds=["x", "x2", "x_unused"], + fetches=["y", "z"], + expected_dependent_inputs=["x", "x2"], + ), + dict( + testcase_name="_y_function_of_sparse_x", + create_graph_fn=_create_graph_with_y_function_of_x_sparse, + feeds=["x"], + fetches=["y"], + expected_dependent_inputs=["x"], + ), + dict( + 
testcase_name="_y_sparse_function_of_sparse_x", + create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, + feeds=["x"], + fetches=["y"], + expected_dependent_inputs=["x"], + ), + dict( + testcase_name="_y_function_of_ragged_x", + create_graph_fn=_create_graph_with_ragged_tensor, + feeds=["x1", "x2"], + fetches=["y1", "y2"], + expected_dependent_inputs=["x1", "x2"], + ), + dict( + testcase_name="_z_function_of_x_ragged", + create_graph_fn=_create_graph_with_z_function_of_x_ragged, + feeds=["x"], + fetches=["y", "z"], + expected_dependent_inputs=["x"], + ), + dict( + testcase_name="z_function_of_x_y_with_control_dependencies", + create_graph_fn=_create_graph_with_assert_equal, + feeds=["x", "y"], + fetches=["z"], + expected_dependent_inputs=["x", "y"], + ), + dict( + testcase_name="_y_function_of_x_with_tf_while", + create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, + feeds=["x"], + fetches=["y"], + expected_dependent_inputs=["x"], + ), + dict( + testcase_name="_tf2_saved_model_with_table", + create_graph_fn=_create_graph_with_tf2_saved_model_with_table, + feeds=["x"], + fetches=["a"], + expected_dependent_inputs=["x"], + ), + ) + def testGetDependentInputs( + self, create_graph_fn, feeds, fetches, expected_dependent_inputs + ): + with tf.compat.v1.Graph().as_default() as graph: + tensors = create_graph_fn() + got = graph_tools.get_dependent_inputs( + graph, {x: tensors[x] for x in feeds}, {y: tensors[y] for y in fetches} + ) + self.assertCountEqual(expected_dependent_inputs, got.keys()) + for input_name in expected_dependent_inputs: + self.assertIs(tensors[input_name], got[input_name]) class GraphToolsTestUniquePath(test_case.TransformTestCase): + @test_case.named_parameters( + dict( + testcase_name="_y_function_of_x", + create_graph_fn=_create_graph_with_y_function_of_x, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "x": [ + mock.call("x$tensor"), + ], + "y": [ + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call("x$tensor"), + mock.call(_OpMatcher("add"), parents=["x$tensor", "add/y:0"]), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_and_tf_function", + create_graph_fn=_create_graph_with_tf_function, + feeds=["x", "y"], + replaced_tensors_ready={"x": False, "y": False}, + expected_calls_dict={ + "x": [ + mock.call("x$tensor"), + ], + "y": [ + mock.call("y$tensor"), + ], + "z": [ + mock.call("y$tensor"), + mock.call("x$tensor"), + mock.call(_OpMatcher("mul/y"), parents=[]), + mock.call(_TensorMatcher("mul/y:0"), parents=["mul/y"]), + mock.call("FuncGraphInput[0]"), + mock.call( + _OpMatcher("mul"), parents=["FuncGraphInput[0]", "mul/y:0"] + ), + mock.call(_TensorMatcher("mul:0"), parents=["mul"]), + mock.call(_OpMatcher("Identity"), parents=["mul:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call("FuncGraphInput[1]"), + mock.call( + _OpMatcher("add"), + parents=["FuncGraphInput[0]", "FuncGraphInput[1]"], + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + mock.call(_OpMatcher("Identity_1"), parents=["add:0"]), + mock.call(_TensorMatcher("Identity_1:0"), parents=["Identity_1"]), + mock.call(_OpMatcher("mul_1/y"), parents=[]), + mock.call(_TensorMatcher("mul_1/y:0"), parents=["mul_1/y"]), + mock.call( + _OpMatcher("mul_1"), parents=["FuncGraphInput[1]", "mul_1/y:0"] + ), + mock.call(_TensorMatcher("mul_1:0"), parents=["mul_1"]), + 
mock.call(_OpMatcher("Identity_2"), parents=["mul_1:0"]), + mock.call(_TensorMatcher("Identity_2:0"), parents=["Identity_2"]), + mock.call( + _OpMatcher("PartitionedCall"), + parents=[ + "x$tensor", + "y$tensor", + "Identity:0", + "Identity_1:0", + "Identity_2:0", + ], + ), + mock.call( + _TensorMatcher("PartitionedCall:2"), parents=["PartitionedCall"] + ), + mock.call( + _TensorMatcher("PartitionedCall:1"), parents=["PartitionedCall"] + ), + mock.call( + _TensorMatcher("PartitionedCall:0"), parents=["PartitionedCall"] + ), + mock.call( + _OpMatcher("add"), + parents=["PartitionedCall:0", "PartitionedCall:1"], + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + mock.call( + _OpMatcher("add_1"), parents=["add:0", "PartitionedCall:2"] + ), + mock.call(_TensorMatcher("add_1:0"), parents=["add_1"]), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_and_chained_tf_function", + create_graph_fn=_create_graph_with_chained_tf_function, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "x": [ + mock.call("x$tensor"), + ], + "y": [ + mock.call(_OpMatcher("truediv/y"), parents=[]), + mock.call(_TensorMatcher("truediv/y:0"), parents=["truediv/y"]), + mock.call(_OpMatcher("truediv/Cast_1"), parents=["truediv/y:0"]), + mock.call( + _TensorMatcher("truediv/Cast_1:0"), parents=["truediv/Cast_1"] + ), + mock.call("x$tensor"), + mock.call(_OpMatcher("mul/y"), parents=[]), + mock.call(_TensorMatcher("mul/y:0"), parents=["mul/y"]), + mock.call("FuncGraphInput[0]"), + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call("FuncGraphInput[0]"), + mock.call( + _OpMatcher("add"), parents=["FuncGraphInput[0]", "add/y:0"] + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + mock.call(_OpMatcher("Identity"), parents=["add:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call( + _OpMatcher("PartitionedCall"), + parents=["FuncGraphInput[0]", "Identity:0"], + ), + mock.call( + _TensorMatcher("PartitionedCall:0"), parents=["PartitionedCall"] + ), + mock.call( + _OpMatcher("mul"), parents=["PartitionedCall:0", "mul/y:0"] + ), + mock.call(_TensorMatcher("mul:0"), parents=["mul"]), + mock.call(_OpMatcher("Identity"), parents=["mul:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call( + _OpMatcher("PartitionedCall"), + parents=["x$tensor", "Identity:0"], + ), + mock.call( + _TensorMatcher("PartitionedCall:0"), parents=["PartitionedCall"] + ), + mock.call( + _OpMatcher("truediv/Cast"), parents=["PartitionedCall:0"] + ), + mock.call( + _TensorMatcher("truediv/Cast:0"), parents=["truediv/Cast"] + ), + mock.call( + _OpMatcher("truediv"), + parents=["truediv/Cast:0", "truediv/Cast_1:0"], + ), + mock.call(_TensorMatcher("truediv:0"), parents=["truediv"]), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_sparse", + create_graph_fn=_create_graph_with_y_function_of_x_sparse, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "y": [ + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call(_OpMatcher("range/delta"), parents=[]), + mock.call(_TensorMatcher("range/delta:0"), parents=["range/delta"]), + mock.call("x$composite_tensor_2"), + mock.call(_OpMatcher("Rank"), parents=["x$composite_tensor_2"]), + mock.call(_TensorMatcher("Rank:0"), parents=["Rank"]), + mock.call(_OpMatcher("range/start"), parents=[]), + mock.call(_TensorMatcher("range/start:0"), 
parents=["range/start"]), + mock.call( + _OpMatcher("range"), + parents=["range/start:0", "Rank:0", "range/delta:0"], + ), + mock.call(_TensorMatcher("range:0"), parents=["range"]), + mock.call("x$composite_tensor_1"), + mock.call("x$composite_tensor_0"), + mock.call( + _OpMatcher("SparseReduceSum"), + parents=[ + "x$composite_tensor_0", + "x$composite_tensor_1", + "x$composite_tensor_2", + "range:0", + ], + ), + mock.call( + _TensorMatcher("SparseReduceSum:0"), parents=["SparseReduceSum"] + ), + mock.call( + _OpMatcher("add"), parents=["SparseReduceSum:0", "add/y:0"] + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + ] + }, + ), + dict( + testcase_name="_y_sparse_function_of_x_sparse", + create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "z": [ + mock.call(_OpMatcher("ones/Const"), parents=[]), + mock.call(_TensorMatcher("ones/Const:0"), parents=["ones/Const"]), + mock.call("x$composite_tensor_2"), + mock.call( + _OpMatcher("ones"), + parents=["x$composite_tensor_2", "ones/Const:0"], + ), + mock.call(_TensorMatcher("ones:0"), parents=["ones"]), + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call("x$composite_tensor_1"), + mock.call( + _OpMatcher("add"), parents=["x$composite_tensor_1", "add/y:0"] + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + mock.call("x$composite_tensor_0"), + mock.call( + _OpMatcher("SparseTensorDenseAdd"), + parents=[ + "x$composite_tensor_0", + "add:0", + "x$composite_tensor_2", + "ones:0", + ], + ), + mock.call( + _TensorMatcher("SparseTensorDenseAdd:0"), + parents=["SparseTensorDenseAdd"], + ), + ], + }, + ), + dict( + testcase_name="_z_function_of_x_ragged", + create_graph_fn=_create_graph_with_z_function_of_x_ragged, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "z": [ + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call(_OpMatcher("range/delta"), parents=[]), + mock.call(_TensorMatcher("range/delta:0"), parents=["range/delta"]), + mock.call("x$composite_tensor_0"), + mock.call("x$composite_tensor_2"), + mock.call("x$composite_tensor_1"), + mock.call( + _OpMatcher("RaggedToSparse/RaggedTensorToSparse"), + parents=[ + "x$composite_tensor_1", + "x$composite_tensor_2", + "x$composite_tensor_0", + ], + ), + mock.call( + _TensorMatcher("RaggedToSparse/RaggedTensorToSparse:2"), + parents=["RaggedToSparse/RaggedTensorToSparse"], + ), + mock.call( + _OpMatcher("Rank"), + parents=["RaggedToSparse/RaggedTensorToSparse:2"], + ), + mock.call(_TensorMatcher("Rank:0"), parents=["Rank"]), + mock.call(_OpMatcher("range/start"), parents=[]), + mock.call(_TensorMatcher("range/start:0"), parents=["range/start"]), + mock.call( + _OpMatcher("range"), + parents=["range/start:0", "Rank:0", "range/delta:0"], + ), + mock.call(_TensorMatcher("range:0"), parents=["range"]), + mock.call( + _TensorMatcher("RaggedToSparse/RaggedTensorToSparse:1"), + parents=["RaggedToSparse/RaggedTensorToSparse"], + ), + mock.call( + _TensorMatcher("RaggedToSparse/RaggedTensorToSparse:0"), + parents=["RaggedToSparse/RaggedTensorToSparse"], + ), + mock.call( + _OpMatcher("SparseReduceSum"), + parents=[ + "RaggedToSparse/RaggedTensorToSparse:0", + "RaggedToSparse/RaggedTensorToSparse:1", + "RaggedToSparse/RaggedTensorToSparse:2", + "range:0", + ], + ), + mock.call( + _TensorMatcher("SparseReduceSum:0"), parents=["SparseReduceSum"] + 
), + mock.call( + _OpMatcher("add"), parents=["SparseReduceSum:0", "add/y:0"] + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_with_raw_ops_while", + create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "y": [ + mock.call("x$tensor"), + mock.call(_OpMatcher("Const"), parents=[]), + mock.call(_TensorMatcher("Const:0"), parents=["Const"]), + mock.call(_OpMatcher("Less/y"), parents=[]), + mock.call(_TensorMatcher("Less/y:0"), parents=["Less/y"]), + mock.call("FuncGraphInput[0]"), + mock.call( + _OpMatcher("Less"), parents=["FuncGraphInput[0]", "Less/y:0"] + ), + mock.call(_TensorMatcher("Less:0"), parents=["Less"]), + mock.call(_OpMatcher("Identity"), parents=["Less:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call(_OpMatcher("Add/y"), parents=[]), + mock.call(_TensorMatcher("Add/y:0"), parents=["Add/y"]), + mock.call("FuncGraphInput[0]"), + mock.call( + _OpMatcher("Add"), parents=["FuncGraphInput[0]", "Add/y:0"] + ), + mock.call(_TensorMatcher("Add:0"), parents=["Add"]), + mock.call(_OpMatcher("Identity"), parents=["Add:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call(_OpMatcher("Add_1/y"), parents=[]), + mock.call(_TensorMatcher("Add_1/y:0"), parents=["Add_1/y"]), + mock.call("FuncGraphInput[1]"), + mock.call( + _OpMatcher("Add_1"), parents=["FuncGraphInput[1]", "Add_1/y:0"] + ), + mock.call(_TensorMatcher("Add_1:0"), parents=["Add_1"]), + mock.call(_OpMatcher("Identity_1"), parents=["Add_1:0"]), + mock.call(_TensorMatcher("Identity_1:0"), parents=["Identity_1"]), + mock.call( + _OpMatcher("While"), + parents=[ + "Const:0", + "x$tensor", + "Identity:0", + "Identity:0", + "Identity_1:0", + ], + ), + mock.call(_TensorMatcher("While:1"), parents=["While"]), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_with_tf_while", + create_graph_fn=_create_graph_with_tf_function_while, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "y": [ + mock.call("x$tensor"), + mock.call("FuncGraphInput[0]"), + mock.call(_OpMatcher("while/maximum_iterations"), parents=[]), + mock.call( + _TensorMatcher("while/maximum_iterations:0"), + parents=["while/maximum_iterations"], + ), + mock.call(_OpMatcher("while/loop_counter"), parents=[]), + mock.call( + _TensorMatcher("while/loop_counter:0"), + parents=["while/loop_counter"], + ), + mock.call(_OpMatcher("Less/y"), parents=[]), + mock.call(_TensorMatcher("Less/y:0"), parents=["Less/y"]), + mock.call("FuncGraphInput[2]"), + mock.call( + _OpMatcher("Less"), parents=["FuncGraphInput[2]", "Less/y:0"] + ), + mock.call(_TensorMatcher("Less:0"), parents=["Less"]), + mock.call(_OpMatcher("Identity"), parents=["Less:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call(_OpMatcher("add/y"), parents=[]), + mock.call(_TensorMatcher("add/y:0"), parents=["add/y"]), + mock.call("FuncGraphInput[0]"), + mock.call( + _OpMatcher("add"), parents=["FuncGraphInput[0]", "add/y:0"] + ), + mock.call(_TensorMatcher("add:0"), parents=["add"]), + mock.call(_OpMatcher("Identity"), parents=["add:0"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call("FuncGraphInput[1]"), + mock.call(_OpMatcher("Identity_1"), parents=["FuncGraphInput[1]"]), + mock.call(_TensorMatcher("Identity_1:0"), parents=["Identity_1"]), + mock.call(_OpMatcher("mul/y"), parents=[]), + 
mock.call(_TensorMatcher("mul/y:0"), parents=["mul/y"]), + mock.call("FuncGraphInput[2]"), + mock.call( + _OpMatcher("mul"), parents=["FuncGraphInput[2]", "mul/y:0"] + ), + mock.call(_TensorMatcher("mul:0"), parents=["mul"]), + mock.call(_OpMatcher("Identity_2"), parents=["mul:0"]), + mock.call(_TensorMatcher("Identity_2:0"), parents=["Identity_2"]), + mock.call( + _OpMatcher("while"), + parents=[ + "while/loop_counter:0", + "while/maximum_iterations:0", + "FuncGraphInput[0]", + "Identity:0", + "Identity:0", + "Identity_1:0", + "Identity_2:0", + ], + ), + mock.call(_TensorMatcher("while:2"), parents=["while"]), + mock.call(_OpMatcher("Identity"), parents=["while:2"]), + mock.call(_TensorMatcher("Identity:0"), parents=["Identity"]), + mock.call( + _OpMatcher("PartitionedCall"), + parents=["x$tensor", "Identity:0"], + ), + mock.call( + _TensorMatcher("PartitionedCall:0"), parents=["PartitionedCall"] + ), + ], + }, + ), + dict( + testcase_name="_y_function_of_x_and_table", + create_graph_fn=_create_graph_with_y_function_of_x_and_table_in_first_phase, + feeds=["x"], + replaced_tensors_ready={"x": False}, + expected_calls_dict={ + "x": [ + mock.call(_OpMatcher("Const"), parents=[]), + mock.call(_TensorMatcher("Const:0"), parents=["Const"]), + mock.call(_OpMatcher("hash_table"), parents=["Const:0"]), + mock.call("x$tensor"), + ], + "y": [ + mock.call(_OpMatcher("Const"), parents=[]), + mock.call(_TensorMatcher("Const:0"), parents=["Const"]), + mock.call(_OpMatcher("hash_table"), parents=["Const:0"]), + mock.call(_OpMatcher("Const_1"), parents=[]), + mock.call(_TensorMatcher("Const_1:0"), parents=["Const_1"]), + mock.call("x$tensor"), + mock.call("hash_table"), + mock.call(_TensorMatcher("hash_table:0"), parents=["hash_table"]), + mock.call( + _OpMatcher("hash_table_Lookup/LookupTableFindV2"), + parents=["hash_table:0", "x$tensor", "Const_1:0"], + ), + mock.call( + _TensorMatcher("hash_table_Lookup/LookupTableFindV2:0"), + parents=["hash_table_Lookup/LookupTableFindV2"], + ), + ], + }, + ), + dict( + testcase_name="_with_assert_equal", + create_graph_fn=_create_graph_with_assert_equal, + feeds=["x", "y"], + replaced_tensors_ready={"x": False, "y": False}, + expected_calls_dict={ + "x": [ + mock.call("x$tensor"), + ], + "y": [ + mock.call("y$tensor"), + ], + "z": [ + mock.call("y$tensor"), + mock.call("x$tensor"), + mock.call(_OpMatcher("Equal"), parents=["x$tensor", "y$tensor"]), + mock.call(_TensorMatcher("Equal:0"), parents=["Equal"]), + mock.call( + _OpMatcher("Assert"), + parents=["Equal:0", "x$tensor", "y$tensor"], + ), + mock.call( + _OpMatcher("control_dependency"), parents=["x$tensor", "Assert"] + ), + mock.call( + _TensorMatcher("control_dependency:0"), + parents=["control_dependency"], + ), + ], + }, + ), + ) + def testGetUniquePath( + self, create_graph_fn, feeds, replaced_tensors_ready, expected_calls_dict + ): + with tf.compat.v1.Graph().as_default() as graph: + tensors = create_graph_fn() + replaced_tensors_ready = [ + (tensors[name], ready) for name, ready in replaced_tensors_ready.items() + ] + for name in expected_calls_dict: + # This is used to construct the debugging string below. 
+ actual_needed_matchers_to_pass = [] + + def describe_path_fn(x, parents=None): + if parents is None: + parents_str = "" + else: + parents_str = f", parents={list(map(_value_to_matcher, parents))}" + actual_needed_matchers_to_pass.append( + "({}{}),".format( # pylint: disable=cell-var-from-loop + _value_to_matcher(x, True), parents_str + ) + ) + + if isinstance(x, tf.Operation): + return x.node_def.name + if isinstance(x, tf.Tensor): + self.assertLessEqual(len(parents), 1) + return x.name + if isinstance(x, (str, bytes)): + return x + raise ValueError(f"Unexpected type: {x}") + + path_cb_mock = mock.MagicMock(side_effect=describe_path_fn) + + graph_analyzer = graph_tools.InitializableGraphAnalyzer( + graph, + {x: tensors[x] for x in feeds}, + replaced_tensors_ready, + path_cb_mock, + ) + + graph_analyzer.get_unique_path(tensors[name]) + + try: + path_cb_mock.assert_has_calls(expected_calls_dict[name]) + self.assertEqual( + path_cb_mock.call_count, + len(expected_calls_dict[name]), + f"Number of expected calls != number of actual calls for {name}: {path_cb_mock.call_args_list}", + ) + except AssertionError: + tf.compat.v1.logging.error( + "The following is a list of matchers for {}:\n{}".format( + name, "\n".join(actual_needed_matchers_to_pass) + ) + ) + raise - @test_case.named_parameters( - dict( - testcase_name='_y_function_of_x', - create_graph_fn=_create_graph_with_y_function_of_x, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'x': [mock.call('x$tensor'),], - 'y': [ - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call('x$tensor'), - mock.call( - _OpMatcher('add'), parents=['x$tensor', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - ] - }), - dict( - testcase_name='_y_function_of_x_and_tf_function', - create_graph_fn=_create_graph_with_tf_function, - feeds=['x', 'y'], - replaced_tensors_ready={ - 'x': False, - 'y': False - }, - expected_calls_dict={ - 'x': [mock.call('x$tensor'),], - 'y': [mock.call('y$tensor'),], - 'z': [ - mock.call('y$tensor'), - mock.call('x$tensor'), - mock.call(_OpMatcher('mul/y'), parents=[]), - mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']), - mock.call('FuncGraphInput[0]'), - mock.call( - _OpMatcher('mul'), - parents=['FuncGraphInput[0]', u'mul/y:0']), - mock.call(_TensorMatcher('mul:0'), parents=[u'mul']), - mock.call(_OpMatcher('Identity'), parents=[u'mul:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call('FuncGraphInput[1]'), - mock.call( - _OpMatcher('add'), - parents=['FuncGraphInput[0]', 'FuncGraphInput[1]']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - mock.call(_OpMatcher('Identity_1'), parents=[u'add:0']), - mock.call( - _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']), - mock.call(_OpMatcher('mul_1/y'), parents=[]), - mock.call(_TensorMatcher('mul_1/y:0'), parents=[u'mul_1/y']), - mock.call( - _OpMatcher('mul_1'), - parents=['FuncGraphInput[1]', u'mul_1/y:0']), - mock.call(_TensorMatcher('mul_1:0'), parents=[u'mul_1']), - mock.call(_OpMatcher('Identity_2'), parents=[u'mul_1:0']), - mock.call( - _TensorMatcher('Identity_2:0'), parents=[u'Identity_2']), - mock.call( - _OpMatcher('PartitionedCall'), - parents=[ - 'x$tensor', 'y$tensor', u'Identity:0', - u'Identity_1:0', u'Identity_2:0' - ]), - mock.call( - _TensorMatcher('PartitionedCall:2'), - parents=[u'PartitionedCall']), - mock.call( - _TensorMatcher('PartitionedCall:1'), - parents=[u'PartitionedCall']), - 
mock.call( - _TensorMatcher('PartitionedCall:0'), - parents=[u'PartitionedCall']), - mock.call( - _OpMatcher('add'), - parents=[u'PartitionedCall:0', u'PartitionedCall:1']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - mock.call( - _OpMatcher('add_1'), - parents=[u'add:0', u'PartitionedCall:2']), - mock.call(_TensorMatcher('add_1:0'), parents=[u'add_1']), - ] - }), - dict( - testcase_name='_y_function_of_x_and_chained_tf_function', - create_graph_fn=_create_graph_with_chained_tf_function, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'x': [mock.call('x$tensor'),], - 'y': [ - mock.call(_OpMatcher('truediv/y'), parents=[]), - mock.call( - _TensorMatcher('truediv/y:0'), parents=[u'truediv/y']), - mock.call( - _OpMatcher('truediv/Cast_1'), parents=[u'truediv/y:0']), - mock.call( - _TensorMatcher('truediv/Cast_1:0'), - parents=[u'truediv/Cast_1']), - mock.call('x$tensor'), - mock.call(_OpMatcher('mul/y'), parents=[]), - mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']), - mock.call('FuncGraphInput[0]'), - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call('FuncGraphInput[0]'), - mock.call( - _OpMatcher('add'), - parents=['FuncGraphInput[0]', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - mock.call(_OpMatcher('Identity'), parents=[u'add:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call( - _OpMatcher('PartitionedCall'), - parents=['FuncGraphInput[0]', u'Identity:0']), - mock.call( - _TensorMatcher('PartitionedCall:0'), - parents=[u'PartitionedCall']), - mock.call( - _OpMatcher('mul'), - parents=[u'PartitionedCall:0', u'mul/y:0']), - mock.call(_TensorMatcher('mul:0'), parents=[u'mul']), - mock.call(_OpMatcher('Identity'), parents=[u'mul:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call( - _OpMatcher('PartitionedCall'), - parents=['x$tensor', u'Identity:0']), - mock.call( - _TensorMatcher('PartitionedCall:0'), - parents=[u'PartitionedCall']), - mock.call( - _OpMatcher('truediv/Cast'), - parents=[u'PartitionedCall:0']), - mock.call( - _TensorMatcher('truediv/Cast:0'), - parents=[u'truediv/Cast']), - mock.call( - _OpMatcher('truediv'), - parents=[u'truediv/Cast:0', u'truediv/Cast_1:0']), - mock.call(_TensorMatcher('truediv:0'), parents=[u'truediv']), - ], - }), - dict( - testcase_name='_y_function_of_x_sparse', - create_graph_fn=_create_graph_with_y_function_of_x_sparse, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'y': [ - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call(_OpMatcher('range/delta'), parents=[]), - mock.call( - _TensorMatcher('range/delta:0'), - parents=[u'range/delta']), - mock.call('x$composite_tensor_2'), - mock.call( - _OpMatcher('Rank'), parents=['x$composite_tensor_2']), - mock.call(_TensorMatcher('Rank:0'), parents=[u'Rank']), - mock.call(_OpMatcher('range/start'), parents=[]), - mock.call( - _TensorMatcher('range/start:0'), - parents=[u'range/start']), - mock.call( - _OpMatcher('range'), - parents=[u'range/start:0', u'Rank:0', u'range/delta:0']), - mock.call(_TensorMatcher('range:0'), parents=[u'range']), - mock.call('x$composite_tensor_1'), - mock.call('x$composite_tensor_0'), - mock.call( - _OpMatcher('SparseReduceSum'), - parents=[ - 'x$composite_tensor_0', 'x$composite_tensor_1', - 'x$composite_tensor_2', u'range:0' - ]), - mock.call( - 
_TensorMatcher('SparseReduceSum:0'), - parents=[u'SparseReduceSum']), - mock.call( - _OpMatcher('add'), - parents=[u'SparseReduceSum:0', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - ] - }), - dict( - testcase_name='_y_sparse_function_of_x_sparse', - create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'z': [ - mock.call(_OpMatcher('ones/Const'), parents=[]), - mock.call( - _TensorMatcher('ones/Const:0'), parents=[u'ones/Const']), - mock.call('x$composite_tensor_2'), - mock.call( - _OpMatcher('ones'), - parents=['x$composite_tensor_2', u'ones/Const:0']), - mock.call(_TensorMatcher('ones:0'), parents=[u'ones']), - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call('x$composite_tensor_1'), - mock.call( - _OpMatcher('add'), - parents=['x$composite_tensor_1', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - mock.call('x$composite_tensor_0'), - mock.call( - _OpMatcher('SparseTensorDenseAdd'), - parents=[ - 'x$composite_tensor_0', u'add:0', - 'x$composite_tensor_2', u'ones:0' - ]), - mock.call( - _TensorMatcher('SparseTensorDenseAdd:0'), - parents=[u'SparseTensorDenseAdd']), - ], - }), - dict( - testcase_name='_z_function_of_x_ragged', - create_graph_fn=_create_graph_with_z_function_of_x_ragged, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'z': [ - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call(_OpMatcher('range/delta'), parents=[]), - mock.call( - _TensorMatcher('range/delta:0'), - parents=[u'range/delta']), - mock.call('x$composite_tensor_0'), - mock.call('x$composite_tensor_2'), - mock.call('x$composite_tensor_1'), - mock.call( - _OpMatcher('RaggedToSparse/RaggedTensorToSparse'), - parents=[ - 'x$composite_tensor_1', 'x$composite_tensor_2', - 'x$composite_tensor_0' - ]), - mock.call( - _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:2'), - parents=['RaggedToSparse/RaggedTensorToSparse']), - mock.call( - _OpMatcher('Rank'), - parents=['RaggedToSparse/RaggedTensorToSparse:2']), - mock.call(_TensorMatcher('Rank:0'), parents=[u'Rank']), - mock.call(_OpMatcher('range/start'), parents=[]), - mock.call( - _TensorMatcher('range/start:0'), - parents=[u'range/start']), - mock.call( - _OpMatcher('range'), - parents=[u'range/start:0', u'Rank:0', u'range/delta:0']), - mock.call(_TensorMatcher('range:0'), parents=[u'range']), - mock.call( - _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:1'), - parents=['RaggedToSparse/RaggedTensorToSparse']), - mock.call( - _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:0'), - parents=['RaggedToSparse/RaggedTensorToSparse']), - mock.call( - _OpMatcher('SparseReduceSum'), - parents=[ - 'RaggedToSparse/RaggedTensorToSparse:0', - 'RaggedToSparse/RaggedTensorToSparse:1', - 'RaggedToSparse/RaggedTensorToSparse:2', u'range:0' - ]), - mock.call( - _TensorMatcher('SparseReduceSum:0'), - parents=[u'SparseReduceSum']), - mock.call( - _OpMatcher('add'), - parents=[u'SparseReduceSum:0', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - ], - }), - dict( - testcase_name='_y_function_of_x_with_raw_ops_while', - create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'y': [ - mock.call('x$tensor'), - mock.call(_OpMatcher('Const'), parents=[]), - 
mock.call(_TensorMatcher('Const:0'), parents=[u'Const']), - mock.call(_OpMatcher('Less/y'), parents=[]), - mock.call(_TensorMatcher('Less/y:0'), parents=[u'Less/y']), - mock.call('FuncGraphInput[0]'), - mock.call( - _OpMatcher('Less'), - parents=[u'FuncGraphInput[0]', 'Less/y:0']), - mock.call(_TensorMatcher('Less:0'), parents=[u'Less']), - mock.call(_OpMatcher('Identity'), parents=[u'Less:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call(_OpMatcher('Add/y'), parents=[]), - mock.call(_TensorMatcher('Add/y:0'), parents=[u'Add/y']), - mock.call('FuncGraphInput[0]'), - mock.call( - _OpMatcher('Add'), - parents=[u'FuncGraphInput[0]', 'Add/y:0']), - mock.call(_TensorMatcher('Add:0'), parents=[u'Add']), - mock.call(_OpMatcher('Identity'), parents=[u'Add:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call(_OpMatcher('Add_1/y'), parents=[]), - mock.call(_TensorMatcher('Add_1/y:0'), parents=[u'Add_1/y']), - mock.call('FuncGraphInput[1]'), - mock.call( - _OpMatcher('Add_1'), - parents=[u'FuncGraphInput[1]', 'Add_1/y:0']), - mock.call(_TensorMatcher('Add_1:0'), parents=[u'Add_1']), - mock.call(_OpMatcher('Identity_1'), parents=[u'Add_1:0']), - mock.call( - _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']), - mock.call( - _OpMatcher('While'), - parents=[ - u'Const:0', 'x$tensor', 'Identity:0', 'Identity:0', - 'Identity_1:0' - ]), - mock.call(_TensorMatcher('While:1'), parents=[u'While']), - ], - }), - dict( - testcase_name='_y_function_of_x_with_tf_while', - create_graph_fn=_create_graph_with_tf_function_while, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'y': [ - mock.call('x$tensor'), - mock.call('FuncGraphInput[0]'), - mock.call(_OpMatcher('while/maximum_iterations'), parents=[]), - mock.call( - _TensorMatcher('while/maximum_iterations:0'), - parents=[u'while/maximum_iterations']), - mock.call(_OpMatcher('while/loop_counter'), parents=[]), - mock.call( - _TensorMatcher('while/loop_counter:0'), - parents=[u'while/loop_counter']), - mock.call(_OpMatcher('Less/y'), parents=[]), - mock.call(_TensorMatcher('Less/y:0'), parents=[u'Less/y']), - mock.call('FuncGraphInput[2]'), - mock.call( - _OpMatcher('Less'), - parents=['FuncGraphInput[2]', u'Less/y:0']), - mock.call(_TensorMatcher('Less:0'), parents=[u'Less']), - mock.call(_OpMatcher('Identity'), parents=[u'Less:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call(_OpMatcher('add/y'), parents=[]), - mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']), - mock.call('FuncGraphInput[0]'), - mock.call( - _OpMatcher('add'), - parents=['FuncGraphInput[0]', u'add/y:0']), - mock.call(_TensorMatcher('add:0'), parents=[u'add']), - mock.call(_OpMatcher('Identity'), parents=[u'add:0']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call('FuncGraphInput[1]'), - mock.call( - _OpMatcher('Identity_1'), parents=['FuncGraphInput[1]']), - mock.call( - _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']), - mock.call(_OpMatcher('mul/y'), parents=[]), - mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']), - mock.call('FuncGraphInput[2]'), - mock.call( - _OpMatcher('mul'), - parents=['FuncGraphInput[2]', u'mul/y:0']), - mock.call(_TensorMatcher('mul:0'), parents=[u'mul']), - mock.call(_OpMatcher('Identity_2'), parents=[u'mul:0']), - mock.call( - _TensorMatcher('Identity_2:0'), parents=[u'Identity_2']), - mock.call( - _OpMatcher('while'), - parents=[ - u'while/loop_counter:0', - 
u'while/maximum_iterations:0', 'FuncGraphInput[0]', - u'Identity:0', u'Identity:0', u'Identity_1:0', - u'Identity_2:0' - ]), - mock.call(_TensorMatcher('while:2'), parents=[u'while']), - mock.call(_OpMatcher('Identity'), parents=[u'while:2']), - mock.call( - _TensorMatcher('Identity:0'), parents=[u'Identity']), - mock.call( - _OpMatcher('PartitionedCall'), - parents=['x$tensor', u'Identity:0']), - mock.call( - _TensorMatcher('PartitionedCall:0'), - parents=[u'PartitionedCall']), - ], - }), - dict( - testcase_name='_y_function_of_x_and_table', - create_graph_fn=_create_graph_with_y_function_of_x_and_table_in_first_phase, - feeds=['x'], - replaced_tensors_ready={'x': False}, - expected_calls_dict={ - 'x': [ - mock.call(_OpMatcher('Const'), parents=[]), - mock.call(_TensorMatcher('Const:0'), parents=['Const']), - mock.call(_OpMatcher('hash_table'), parents=['Const:0']), - mock.call('x$tensor'), - ], - 'y': [ - mock.call(_OpMatcher('Const'), parents=[]), - mock.call(_TensorMatcher('Const:0'), parents=['Const']), - mock.call(_OpMatcher('hash_table'), parents=['Const:0']), - mock.call(_OpMatcher('Const_1'), parents=[]), - mock.call(_TensorMatcher('Const_1:0'), parents=['Const_1']), - mock.call('x$tensor'), - mock.call('hash_table'), - mock.call( - _TensorMatcher('hash_table:0'), parents=['hash_table']), - mock.call( - _OpMatcher('hash_table_Lookup/LookupTableFindV2'), - parents=['hash_table:0', 'x$tensor', 'Const_1:0']), - mock.call( - _TensorMatcher('hash_table_Lookup/LookupTableFindV2:0'), - parents=['hash_table_Lookup/LookupTableFindV2']), - ], - }), - dict( - testcase_name='_with_assert_equal', - create_graph_fn=_create_graph_with_assert_equal, - feeds=['x', 'y'], - replaced_tensors_ready={ - 'x': False, - 'y': False - }, - expected_calls_dict={ - 'x': [mock.call('x$tensor'),], - 'y': [mock.call('y$tensor'),], - 'z': [ - mock.call('y$tensor'), - mock.call('x$tensor'), - mock.call( - _OpMatcher('Equal'), parents=['x$tensor', 'y$tensor']), - mock.call(_TensorMatcher('Equal:0'), parents=[u'Equal']), - mock.call( - _OpMatcher('Assert'), - parents=[u'Equal:0', 'x$tensor', 'y$tensor']), - mock.call( - _OpMatcher('control_dependency'), - parents=['x$tensor', u'Assert']), - mock.call( - _TensorMatcher('control_dependency:0'), - parents=[u'control_dependency']), - ] - }), - ) - def testGetUniquePath(self, - create_graph_fn, - feeds, - replaced_tensors_ready, - expected_calls_dict): - with tf.compat.v1.Graph().as_default() as graph: - tensors = create_graph_fn() - replaced_tensors_ready = [(tensors[name], ready) - for name, ready in replaced_tensors_ready.items()] - for name in expected_calls_dict: - - # This is used to construct the debugging string below. 
- actual_needed_matchers_to_pass = [] - - def describe_path_fn(x, parents=None): - if parents is None: - parents_str = '' + +def _value_to_matcher(value, add_quotes=False): + """Returns a matcher for the value - used for debugging failures.""" + if isinstance(value, tf.Operation): + return _OpMatcher(str(value.node_def.name)) + if isinstance(value, tf.Tensor): + return _TensorMatcher(str(value.name)) + if isinstance(value, (str, bytes)): + if add_quotes: + return f"'{value}'" else: - parents_str = ', parents={}'.format( - list(map(_value_to_matcher, parents))) - actual_needed_matchers_to_pass.append('({}{}),'.format( # pylint: disable=cell-var-from-loop - _value_to_matcher(x, True), parents_str)) - - if isinstance(x, tf.Operation): - return x.node_def.name - if isinstance(x, tf.Tensor): - self.assertLessEqual(len(parents), 1) - return x.name - if isinstance(x, (str, bytes)): - return x - raise ValueError('Unexpected type: {}'.format(x)) - - path_cb_mock = mock.MagicMock(side_effect=describe_path_fn) - - graph_analyzer = graph_tools.InitializableGraphAnalyzer( - graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready, - path_cb_mock) - - graph_analyzer.get_unique_path(tensors[name]) - - try: - path_cb_mock.assert_has_calls(expected_calls_dict[name]) - self.assertEqual( - path_cb_mock.call_count, len(expected_calls_dict[name]), - 'Number of expected calls != number of actual calls for {}: {}' - .format(name, path_cb_mock.call_args_list)) - except AssertionError: - tf.compat.v1.logging.error( - 'The following is a list of matchers for {}:\n{}'.format( - name, '\n'.join(actual_needed_matchers_to_pass))) - raise + return value + raise ValueError(f"Cannot get a matcher for: {type(value)}, {value}") -def _value_to_matcher(value, add_quotes=False): - """Returns a matcher for the value - used for debugging failures.""" - if isinstance(value, tf.Operation): - return _OpMatcher(str(value.node_def.name)) - if isinstance(value, tf.Tensor): - return _TensorMatcher(str(value.name)) - if isinstance(value, (str, bytes)): - if add_quotes: - return '\'{}\''.format(value) - else: - return value - raise ValueError('Cannot get a matcher for: {}, {}'.format( - type(value), value)) - - -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/impl_helper.py b/tensorflow_transform/impl_helper.py index 874e282..8a80a8e 100644 --- a/tensorflow_transform/impl_helper.py +++ b/tensorflow_transform/impl_helper.py @@ -16,302 +16,339 @@ import functools import os import re -from typing import Callable, Dict, List, Mapping, Optional, FrozenSet +from typing import Callable, Dict, FrozenSet, List, Mapping, Optional -from absl import logging import numpy as np import pyarrow as pa import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import annotators -from tensorflow_transform import common_types -from tensorflow_transform import graph_context -from tensorflow_transform import graph_tools -from tensorflow_transform import schema_inference -from tensorflow_transform import tf2_utils -from tensorflow_transform import tf_utils -from tensorflow_transform.output_wrapper import TFTransformOutput -from tensorflow_transform.saved import saved_transform_io -from tensorflow_transform.saved import saved_transform_io_v2 -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import metadata_io -from tensorflow_transform.tf_metadata import schema_utils -from tfx_bsl.coders 
import example_coder -from tfx_bsl.tfxio import tensor_to_arrow +from absl import logging + # pylint: disable=g-direct-tensorflow-import from tensorflow.python.eager import function from tensorflow.python.framework import ops + # pylint: enable=g-direct-tensorflow-import from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.coders import example_coder +from tfx_bsl.tfxio import tensor_to_arrow +from tensorflow_transform import ( + analyzer_nodes, + annotators, + common_types, + graph_context, + graph_tools, + schema_inference, + tf2_utils, + tf_utils, +) +from tensorflow_transform.output_wrapper import TFTransformOutput +from tensorflow_transform.saved import saved_transform_io, saved_transform_io_v2 +from tensorflow_transform.tf_metadata import dataset_metadata, metadata_io, schema_utils -_VALID_SCOPE_REGEX = re.compile('^[A-Za-z0-9]*$') -_INVALID_SCOPE_CHAR = re.compile('[^A-Za-z0-9_.\\-/>]') +_VALID_SCOPE_REGEX = re.compile("^[A-Za-z0-9]*$") +_INVALID_SCOPE_CHAR = re.compile("[^A-Za-z0-9_.\\-/>]") -METADATA_DIR_NAME = '.tft_metadata' +METADATA_DIR_NAME = ".tft_metadata" _FEATURE_VALUE_KIND_TO_NP_DTYPE = { - 'float_list': np.float32, - 'int64_list': np.int64, - 'bytes_list': object, + "float_list": np.float32, + "int64_list": np.int64, + "bytes_list": object, } def batched_placeholders_from_specs(specs): - """Returns placeholders for the given tf.TypeSpecs or feature specs. - - Args: - specs: a Dict[Text, Union[tf.TypeSpec, FeatureSpec]]. Note that the values - in this dict must be of the same type. Mixing is not allowed. - - Returns: - A dictionary from strings to `Tensor`, `SparseTensor`s, or `RaggedTensor`s. - - Raises: - ValueError: when the TypeSpec or feature spec has an unsupported dtype. - """ - if not (all([_is_feature_spec(s) for s in specs.values()]) or - all([isinstance(s, tf.TypeSpec) for s in specs.values()])): - raise TypeError('Specs must be all tf.TypeSpecs or feature specs. ' - 'Mixing is not allowed. Got: {}'.format(specs)) - - result = {} - for name, spec in specs.items(): - if isinstance(spec, tf.RaggedTensorSpec): - # TODO(b/159717195): clean up protected-access - spec_dtype = spec._dtype # pylint: disable=protected-access - else: - spec_dtype = spec.dtype - if spec_dtype not in (tf.int64, tf.float32, tf.string): - raise ValueError('Feature {} ({}, {}) had invalid dtype'.format( - name, spec, type(spec))) - if isinstance(spec, tf.TypeSpec): - result[name] = _batched_placeholder_from_typespec(name, spec) - else: - result[name] = _batched_placeholder_from_feature_spec(name, spec) + """Returns placeholders for the given tf.TypeSpecs or feature specs. + + Args: + ---- + specs: a Dict[Text, Union[tf.TypeSpec, FeatureSpec]]. Note that the values + in this dict must be of the same type. Mixing is not allowed. + + Returns: + ------- + A dictionary from strings to `Tensor`, `SparseTensor`s, or `RaggedTensor`s. + + Raises: + ------ + ValueError: when the TypeSpec or feature spec has an unsupported dtype. + """ + if not ( + all([_is_feature_spec(s) for s in specs.values()]) + or all([isinstance(s, tf.TypeSpec) for s in specs.values()]) + ): + raise TypeError( + "Specs must be all tf.TypeSpecs or feature specs. " + f"Mixing is not allowed. 
Got: {specs}" + ) + + result = {} + for name, spec in specs.items(): + if isinstance(spec, tf.RaggedTensorSpec): + # TODO(b/159717195): clean up protected-access + spec_dtype = spec._dtype # pylint: disable=protected-access + else: + spec_dtype = spec.dtype + if spec_dtype not in (tf.int64, tf.float32, tf.string): + raise ValueError(f"Feature {name} ({spec}, {type(spec)}) had invalid dtype") + if isinstance(spec, tf.TypeSpec): + result[name] = _batched_placeholder_from_typespec(name, spec) + else: + result[name] = _batched_placeholder_from_feature_spec(name, spec) - return result + return result def _is_feature_spec(spec): - return isinstance(spec, (tf.io.VarLenFeature, tf.io.SparseFeature, - tf.io.FixedLenFeature, tf.io.RaggedFeature)) + return isinstance( + spec, + ( + tf.io.VarLenFeature, + tf.io.SparseFeature, + tf.io.FixedLenFeature, + tf.io.RaggedFeature, + ), + ) def _sanitize_scope_name(name): - scope_name = _INVALID_SCOPE_CHAR.sub('_', name) - if not _VALID_SCOPE_REGEX.match(scope_name): - scope_name = 'F_{}'.format(scope_name) - return scope_name + scope_name = _INVALID_SCOPE_CHAR.sub("_", name) + if not _VALID_SCOPE_REGEX.match(scope_name): + scope_name = f"F_{scope_name}" + return scope_name def _batched_placeholder_from_typespec(name, typespec): - """Creates a batched placeholder from a tf.TypeSpec.""" - if isinstance(typespec, - (tf.TensorSpec, tf.SparseTensorSpec, tf.RaggedTensorSpec)): - sanitized_name = _sanitize_scope_name(name) - with tf.name_scope(sanitized_name): - return tf.nest.map_structure( - lambda tspec: tf.raw_ops.Placeholder( # pylint: disable=g-long-lambda - dtype=tspec.dtype, - shape=tspec.shape, - name=sanitized_name), - typespec, - expand_composites=True) - - raise ValueError('Unsupported typespec: {}({}) for feature {}'.format( - typespec, type(typespec), name)) + """Creates a batched placeholder from a tf.TypeSpec.""" + if isinstance(typespec, (tf.TensorSpec, tf.SparseTensorSpec, tf.RaggedTensorSpec)): + sanitized_name = _sanitize_scope_name(name) + with tf.name_scope(sanitized_name): + return tf.nest.map_structure( + lambda tspec: tf.raw_ops.Placeholder( # pylint: disable=g-long-lambda + dtype=tspec.dtype, shape=tspec.shape, name=sanitized_name + ), + typespec, + expand_composites=True, + ) + + raise ValueError( + f"Unsupported typespec: {typespec}({type(typespec)}) for feature {name}" + ) def _batched_placeholder_from_feature_spec(name, feature_spec): - """Creates a batched placeholder from a feature spec.""" - scope_name = _sanitize_scope_name(name) - if isinstance(feature_spec, tf.io.FixedLenFeature): - return tf.compat.v1.placeholder( - feature_spec.dtype, [None] + feature_spec.shape, name=scope_name) - elif isinstance(feature_spec, tf.io.VarLenFeature): - return tf.compat.v1.sparse_placeholder( - feature_spec.dtype, [None, None], name=scope_name) - elif isinstance(feature_spec, tf.io.SparseFeature): - shape = [None] + feature_spec.size if isinstance( - feature_spec.size, list) else [None, feature_spec.size] - return tf.compat.v1.sparse_placeholder( - feature_spec.dtype, shape, name=scope_name) - - raise ValueError('Unsupported feature spec: {}({}) for feature {}'.format( - feature_spec, type(feature_spec), name)) + """Creates a batched placeholder from a feature spec.""" + scope_name = _sanitize_scope_name(name) + if isinstance(feature_spec, tf.io.FixedLenFeature): + return tf.compat.v1.placeholder( + feature_spec.dtype, [None] + feature_spec.shape, name=scope_name + ) + elif isinstance(feature_spec, tf.io.VarLenFeature): + return 
tf.compat.v1.sparse_placeholder( + feature_spec.dtype, [None, None], name=scope_name + ) + elif isinstance(feature_spec, tf.io.SparseFeature): + shape = ( + [None] + feature_spec.size + if isinstance(feature_spec.size, list) + else [None, feature_spec.size] + ) + return tf.compat.v1.sparse_placeholder( + feature_spec.dtype, shape, name=scope_name + ) + + raise ValueError( + f"Unsupported feature spec: {feature_spec}({type(feature_spec)}) for feature {name}" + ) def _example_to_dict(example: bytes) -> Dict[str, Optional[np.ndarray]]: - """Converts serialized tf.Example to Python dictionary.""" - example = tf.train.Example.FromString(example) - result = {} - # Sort the produced dict by keys to make the order deterministic. - for name, feature in sorted(example.features.feature.items()): - kind = feature.WhichOneof('kind') - # Use None if the value kind is not set (can occur for passthrough values). - result[name] = (None if kind is None else np.array( - getattr(feature, kind).value, - dtype=_FEATURE_VALUE_KIND_TO_NP_DTYPE[kind])) - return result + """Converts serialized tf.Example to Python dictionary.""" + example = tf.train.Example.FromString(example) + result = {} + # Sort the produced dict by keys to make the order deterministic. + for name, feature in sorted(example.features.feature.items()): + kind = feature.WhichOneof("kind") + # Use None if the value kind is not set (can occur for passthrough values). + result[name] = ( + None + if kind is None + else np.array( + getattr(feature, kind).value, + dtype=_FEATURE_VALUE_KIND_TO_NP_DTYPE[kind], + ) + ) + return result def record_batch_to_instance_dicts( record_batch: pa.RecordBatch, schema: schema_pb2.Schema ) -> List[common_types.InstanceDictType]: - """Converts pa.RecordBatch to list of Python dictionaries. - - Args: - record_batch: the batch to be converted. - schema: A `Schema` proto. - - Returns: - A list of dicts where each dict is an in-memory representation of an - instance. - """ - # Alternatively, we could've used `record_batch.to_pylist()`, but - # RaggedTensors would be represented as nested lists (as opposed to array of - # values + row lengths), so we make a trip through flat examples first. - coder = example_coder.RecordBatchToExamplesEncoder(schema) - examples = coder.encode(record_batch) - # Dense tensor instances must be reshaped according to their spec shape. - # Scalars are represented as Python scalars (as opposed to singleton arrays). - feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec - dense_reshape_fns = {} - def _extract_singleton_item( - singleton: np.ndarray, - ) -> common_types.PrimitiveType: - return singleton.item() - for name, spec in feature_spec.items(): - if isinstance(spec, tf.io.FixedLenFeature): - if spec.shape: - dense_reshape_fns[name] = functools.partial( - np.reshape, newshape=spec.shape - ) - else: - dense_reshape_fns[name] = _extract_singleton_item - result = [] - for example in examples: - instance_dict = _example_to_dict(example) - for name, reshape_fn in dense_reshape_fns.items(): - instance_dict[name] = reshape_fn(instance_dict[name]) - result.append(instance_dict) - return result + """Converts pa.RecordBatch to list of Python dictionaries. + + Args: + ---- + record_batch: the batch to be converted. + schema: A `Schema` proto. + + Returns: + ------- + A list of dicts where each dict is an in-memory representation of an + instance. 
+ """ + # Alternatively, we could've used `record_batch.to_pylist()`, but + # RaggedTensors would be represented as nested lists (as opposed to array of + # values + row lengths), so we make a trip through flat examples first. + coder = example_coder.RecordBatchToExamplesEncoder(schema) + examples = coder.encode(record_batch) + # Dense tensor instances must be reshaped according to their spec shape. + # Scalars are represented as Python scalars (as opposed to singleton arrays). + feature_spec = schema_utils.schema_as_feature_spec(schema).feature_spec + dense_reshape_fns = {} + + def _extract_singleton_item( + singleton: np.ndarray, + ) -> common_types.PrimitiveType: + return singleton.item() + + for name, spec in feature_spec.items(): + if isinstance(spec, tf.io.FixedLenFeature): + if spec.shape: + dense_reshape_fns[name] = functools.partial( + np.reshape, newshape=spec.shape + ) + else: + dense_reshape_fns[name] = _extract_singleton_item + result = [] + for example in examples: + instance_dict = _example_to_dict(example) + for name, reshape_fn in dense_reshape_fns.items(): + instance_dict[name] = reshape_fn(instance_dict[name]) + result.append(instance_dict) + return result def validate_varlen_sparse_value( name: str, batched_value: common_types.SparseTensorValueType ): - """Checks that the given SparseTensor is 2-D ragged and left-aligned.""" - indices = np.asarray(batched_value.indices) - if indices.shape[1] != 2: - raise ValueError(f'Encountered non 2-D varlen sparse feature {name}') - if indices.shape[0] == 0: - return - indices_diff = np.diff(indices, axis=0) - instance_index_diff, value_index_diff = indices_diff[:, 0], indices_diff[:, 1] - if np.any(instance_index_diff < 0): - raise ValueError( - f'Encountered decreasing instance indices for feature {name}: {indices}' - ) - if np.any(np.logical_and(instance_index_diff == 0, value_index_diff != 1)): - raise ValueError( - f'Encountered non-consecutive value indices for feature {name}:' - f' {indices}' - ) - (instance_boundaries,) = np.where(instance_index_diff != 0) - if np.any(indices[np.append(instance_boundaries + 1, 0), 1] != 0): - raise ValueError( - f'Encountered non-zero starting value indices for feature {name}:' - f' {indices}' - ) + """Checks that the given SparseTensor is 2-D ragged and left-aligned.""" + indices = np.asarray(batched_value.indices) + if indices.shape[1] != 2: + raise ValueError(f"Encountered non 2-D varlen sparse feature {name}") + if indices.shape[0] == 0: + return + indices_diff = np.diff(indices, axis=0) + instance_index_diff, value_index_diff = indices_diff[:, 0], indices_diff[:, 1] + if np.any(instance_index_diff < 0): + raise ValueError( + f"Encountered decreasing instance indices for feature {name}: {indices}" + ) + if np.any(np.logical_and(instance_index_diff == 0, value_index_diff != 1)): + raise ValueError( + f"Encountered non-consecutive value indices for feature {name}:" + f" {indices}" + ) + (instance_boundaries,) = np.where(instance_index_diff != 0) + if np.any(indices[np.append(instance_boundaries + 1, 0), 1] != 0): + raise ValueError( + f"Encountered non-zero starting value indices for feature {name}:" + f" {indices}" + ) def get_type_specs_from_feature_specs( feature_specs: Dict[str, common_types.FeatureSpecType], ragged_sequence_features: FrozenSet[str] = frozenset(), ) -> Dict[str, tf.TypeSpec]: - """Returns `tf.TypeSpec`s for the given feature specs. - - Returns a dictionary of type_spec with the same type and shape as defined by - `feature_specs`. 
- - Args: - feature_specs: A TensorFlow feature spec. - ragged_sequence_features: Set of names of features representing ragged - sequence tensors. - - Returns: - A dictionary from strings to `tf.TensorSpec`, `tf.SparseTensorSpec` or - `tf.RaggedTensorSpec`s. - - Raises: - ValueError: If the feature spec contains feature types not supported. - """ - result = {} - for name, feature_spec in feature_specs.items(): - if isinstance(feature_spec, tf.io.FixedLenFeature): - result[name] = tf.TensorSpec([None] + list(feature_spec.shape), - feature_spec.dtype) - elif isinstance(feature_spec, tf.io.VarLenFeature): - result[name] = tf.SparseTensorSpec([None, None], feature_spec.dtype) - elif isinstance(feature_spec, tf.io.SparseFeature): - shape = [None] + [None if dim == -1 else dim for dim in feature_spec.size] - result[name] = tf.SparseTensorSpec(shape, feature_spec.dtype) - elif isinstance(feature_spec, tf.io.RaggedFeature): - # Number of dimensions is number of partitions + 1 + 1 batch dimension. - shape = [None, None] - ragged_rank = 1 - # Ragged sequence tensors will have additional sequence dimension. - if name in ragged_sequence_features: - shape.append(None) - ragged_rank += 1 - for partition in feature_spec.partitions: - if isinstance(partition, tf.io.RaggedFeature.UniformRowLength): # pytype: disable=attribute-error - shape.append(partition.length) + """Returns `tf.TypeSpec`s for the given feature specs. + + Returns a dictionary of type_spec with the same type and shape as defined by + `feature_specs`. + + Args: + ---- + feature_specs: A TensorFlow feature spec. + ragged_sequence_features: Set of names of features representing ragged + sequence tensors. + + Returns: + ------- + A dictionary from strings to `tf.TensorSpec`, `tf.SparseTensorSpec` or + `tf.RaggedTensorSpec`s. + + Raises: + ------ + ValueError: If the feature spec contains feature types not supported. + """ + result = {} + for name, feature_spec in feature_specs.items(): + if isinstance(feature_spec, tf.io.FixedLenFeature): + result[name] = tf.TensorSpec( + [None] + list(feature_spec.shape), feature_spec.dtype + ) + elif isinstance(feature_spec, tf.io.VarLenFeature): + result[name] = tf.SparseTensorSpec([None, None], feature_spec.dtype) + elif isinstance(feature_spec, tf.io.SparseFeature): + shape = [None] + [None if dim == -1 else dim for dim in feature_spec.size] + result[name] = tf.SparseTensorSpec(shape, feature_spec.dtype) + elif isinstance(feature_spec, tf.io.RaggedFeature): + # Number of dimensions is number of partitions + 1 + 1 batch dimension. + shape = [None, None] + ragged_rank = 1 + # Ragged sequence tensors will have additional sequence dimension. 
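+ # Each non-uniform partition contributes a `None` (ragged) dimension and
+ # increments `ragged_rank`; a `UniformRowLength` partition contributes a
+ # fixed-size dimension without increasing the ragged rank.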
+ if name in ragged_sequence_features: + shape.append(None) + ragged_rank += 1 + for partition in feature_spec.partitions: + if isinstance( + partition, tf.io.RaggedFeature.UniformRowLength + ): # pytype: disable=attribute-error + shape.append(partition.length) + else: + shape.append(None) + ragged_rank += 1 + result[name] = tf.RaggedTensorSpec( + shape=shape, + dtype=feature_spec.dtype, + ragged_rank=ragged_rank, + row_splits_dtype=feature_spec.row_splits_dtype, + ) else: - shape.append(None) - ragged_rank += 1 - result[name] = tf.RaggedTensorSpec( - shape=shape, - dtype=feature_spec.dtype, - ragged_rank=ragged_rank, - row_splits_dtype=feature_spec.row_splits_dtype) - else: - raise ValueError('Invalid feature spec {}.'.format(feature_spec)) - return result + raise ValueError(f"Invalid feature spec {feature_spec}.") + return result def make_tensor_to_arrow_converter( - schema: schema_pb2.Schema) -> tensor_to_arrow.TensorsToRecordBatchConverter: - """Constructs a `tf.Tensor` to `pa.RecordBatch` converter.""" - # Ragged sequence features will have an additional (sequence) dimension that - # doesn't come from feature partition. Hence, we need to generate type spec - # accordingly. - ragged_sequence_features = set() - feature_specs = schema_utils.schema_as_feature_spec(schema).feature_spec - for feature in schema.feature: - if feature.type == schema_pb2.FeatureType.STRUCT: - for child_feature in feature.struct_domain.feature: - ragged_sequence_features.add(child_feature.name) - type_specs = get_type_specs_from_feature_specs( - feature_specs, frozenset(ragged_sequence_features) - ) - - # Make sure that SparseFeatures are handled as generic SparseTensors as - # opposed to VarLenSparse. Note that at this point only sparse outputs with - # rank >2 are inferred as SparseFeatures, but this is likely to change. - sparse_tensor_names = set() - for name, spec in feature_specs.items(): - if isinstance(spec, tf.io.SparseFeature): - sparse_tensor_names.add(name) - options = tensor_to_arrow.TensorsToRecordBatchConverter.Options( - sparse_tensor_value_column_name_template=schema_inference - .SPARSE_VALUES_NAME_TEMPLATE, - sparse_tensor_index_column_name_template=schema_inference - .SPARSE_INDICES_NAME_TEMPLATE, - generic_sparse_tensor_names=frozenset(sparse_tensor_names)) - return tensor_to_arrow.TensorsToRecordBatchConverter(type_specs, options) + schema: schema_pb2.Schema, +) -> tensor_to_arrow.TensorsToRecordBatchConverter: + """Constructs a `tf.Tensor` to `pa.RecordBatch` converter.""" + # Ragged sequence features will have an additional (sequence) dimension that + # doesn't come from feature partition. Hence, we need to generate type spec + # accordingly. + ragged_sequence_features = set() + feature_specs = schema_utils.schema_as_feature_spec(schema).feature_spec + for feature in schema.feature: + if feature.type == schema_pb2.FeatureType.STRUCT: + for child_feature in feature.struct_domain.feature: + ragged_sequence_features.add(child_feature.name) + type_specs = get_type_specs_from_feature_specs( + feature_specs, frozenset(ragged_sequence_features) + ) + + # Make sure that SparseFeatures are handled as generic SparseTensors as + # opposed to VarLenSparse. Note that at this point only sparse outputs with + # rank >2 are inferred as SparseFeatures, but this is likely to change. 
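+ # The names collected below are passed to the converter through
+ # `Options.generic_sparse_tensor_names`, alongside the sparse value/index
+ # column name templates defined in `schema_inference`.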
+ sparse_tensor_names = set() + for name, spec in feature_specs.items(): + if isinstance(spec, tf.io.SparseFeature): + sparse_tensor_names.add(name) + options = tensor_to_arrow.TensorsToRecordBatchConverter.Options( + sparse_tensor_value_column_name_template=schema_inference.SPARSE_VALUES_NAME_TEMPLATE, + sparse_tensor_index_column_name_template=schema_inference.SPARSE_INDICES_NAME_TEMPLATE, + generic_sparse_tensor_names=frozenset(sparse_tensor_names), + ) + return tensor_to_arrow.TensorsToRecordBatchConverter(type_specs, options) # TODO(b/149997088): Split into two APIs one that will just trace the @@ -326,123 +363,128 @@ def get_traced_transform_fn( tf_graph_context: graph_context.TFGraphContext, output_keys_to_name_map: Optional[Dict[str, str]] = None, ) -> tf.types.experimental.GenericFunction: - """Get preprocessing_fn traced using tf.function. - - Args: - preprocessing_fn: A user defined python function to be traced. - input_signature: `tf.TypeSpec`s describing the inputs to the - `preprocessing_fn`. - tf_graph_context: A `TFGraphContext` context manager to invoke the - `preprocessing_fn` in. - output_keys_to_name_map: (Optional) A map from output dictionary keys to the - names of the tensors that they represent. - - Returns: - A tf.function object representing a function with the same input signature - as `preprocessing_fn`. - If `output_keys_to_name_map` is None or there are no more TFT analyzers to - evaluate in the `preprocessing_fn`, the output signature of this - tf.function - is the same as the `preprocessing_fn`. - Otherwise, its output signature contains the keys in - `output_keys_to_name_map` and the tensor represented by the corresponding - dictionary values. - """ - - assert all([isinstance(s, tf.TypeSpec) for s in input_signature.values()]) - - # TODO(b/177672051): Investigate performance impact of enabling autograph. - @tf.function(input_signature=[input_signature], autograph=False) - def transform_fn(inputs): - graph = ops.get_default_graph() - # If any analyzers have already been evaluated, pass them using the - # `graph_context.TFGraphContext`. This will be used in place of the analyzer - # nodes. - # The user defined `preprocessing_fn` may directly modify its inputs which - # is not allowed in a tf.function. Hence, we make a copy here. - inputs_copy = tf_utils.copy_tensors(inputs) - with tf_graph_context: - transformed_features = preprocessing_fn(inputs_copy) - # An empty `TENSOR_REPLACEMENTS` collection symbolizes that there is no - # analyzer left for Transform to evaluate. Either if this collection is - # empty or if no specific outputs have been requested, return - # the same output as `preprocessing_fn` (i.e, transformed_features). - if (output_keys_to_name_map is None or - not graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)): - return transformed_features - else: - return { - key: graph.get_tensor_by_name(value) - for key, value in output_keys_to_name_map.items() - } + """Get preprocessing_fn traced using tf.function. + + Args: + ---- + preprocessing_fn: A user defined python function to be traced. + input_signature: `tf.TypeSpec`s describing the inputs to the + `preprocessing_fn`. + tf_graph_context: A `TFGraphContext` context manager to invoke the + `preprocessing_fn` in. + output_keys_to_name_map: (Optional) A map from output dictionary keys to the + names of the tensors that they represent. + + Returns: + ------- + A tf.function object representing a function with the same input signature + as `preprocessing_fn`. 
+ If `output_keys_to_name_map` is None or there are no more TFT analyzers to + evaluate in the `preprocessing_fn`, the output signature of this + tf.function + is the same as the `preprocessing_fn`. + Otherwise, its output signature contains the keys in + `output_keys_to_name_map` and the tensor represented by the corresponding + dictionary values. + """ + assert all([isinstance(s, tf.TypeSpec) for s in input_signature.values()]) + + # TODO(b/177672051): Investigate performance impact of enabling autograph. + @tf.function(input_signature=[input_signature], autograph=False) + def transform_fn(inputs): + graph = ops.get_default_graph() + # If any analyzers have already been evaluated, pass them using the + # `graph_context.TFGraphContext`. This will be used in place of the analyzer + # nodes. + # The user defined `preprocessing_fn` may directly modify its inputs which + # is not allowed in a tf.function. Hence, we make a copy here. + inputs_copy = tf_utils.copy_tensors(inputs) + with tf_graph_context: + transformed_features = preprocessing_fn(inputs_copy) + # An empty `TENSOR_REPLACEMENTS` collection symbolizes that there is no + # analyzer left for Transform to evaluate. Either if this collection is + # empty or if no specific outputs have been requested, return + # the same output as `preprocessing_fn` (i.e, transformed_features). + if output_keys_to_name_map is None or not graph.get_collection( + analyzer_nodes.TENSOR_REPLACEMENTS + ): + return transformed_features + else: + return { + key: graph.get_tensor_by_name(value) + for key, value in output_keys_to_name_map.items() + } - return transform_fn + return transform_fn def _trace_preprocessing_fn_v1(preprocessing_fn, specs): - """Trace TF1 graph for `preprocessing_fn`.""" - with tf.compat.v1.Graph().as_default() as graph: - with tf.compat.v1.name_scope('inputs'): - structured_inputs = batched_placeholders_from_specs(specs) - # In order to avoid a bug where import_graph_def fails when the - # input_map and return_elements of an imported graph are the same - # (b/34288791), we avoid using the placeholder of an input column as an - # output of a graph. We do this by applying tf.identity to all inputs of - # the preprocessing_fn. Note this applies at the level of raw tensors. - # TODO(b/34288791): Remove this workaround and use a shallow copy of - # inputs instead. A shallow copy is needed in case - # self._preprocessing_fn mutates its input. - copied_inputs = tf_utils.copy_tensors(structured_inputs) - - structured_outputs = preprocessing_fn(copied_inputs) - return graph, structured_inputs, structured_outputs + """Trace TF1 graph for `preprocessing_fn`.""" + with tf.compat.v1.Graph().as_default() as graph: + with tf.compat.v1.name_scope("inputs"): + structured_inputs = batched_placeholders_from_specs(specs) + # In order to avoid a bug where import_graph_def fails when the + # input_map and return_elements of an imported graph are the same + # (b/34288791), we avoid using the placeholder of an input column as an + # output of a graph. We do this by applying tf.identity to all inputs of + # the preprocessing_fn. Note this applies at the level of raw tensors. + # TODO(b/34288791): Remove this workaround and use a shallow copy of + # inputs instead. A shallow copy is needed in case + # self._preprocessing_fn mutates its input. 
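For illustration only, the shallow copy discussed in the comment above guards against hypothetical user code along these lines, which writes back into the dictionary it receives:

    def preprocessing_fn(inputs):
        # Writes into the dict it was handed; the copy keeps the original
        # input placeholders intact.
        inputs["x"] = inputs["x"] + 1
        return {"x_plus_1": inputs["x"]}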
+ copied_inputs = tf_utils.copy_tensors(structured_inputs) + + structured_outputs = preprocessing_fn(copied_inputs) + return graph, structured_inputs, structured_outputs def _trace_preprocessing_fn_v2(preprocessing_fn, specs, base_temp_dir): - """Trace TF2 graph for `preprocessing_fn`.""" - tf_graph_context = graph_context.TFGraphContext( - module_to_export=tf.Module(), - temp_dir=base_temp_dir, - evaluated_replacements=None) - with annotators.object_tracker_scope(annotators.ObjectTracker()): - concrete_fn = get_traced_transform_fn( - preprocessing_fn, specs, tf_graph_context).get_concrete_function() - return (concrete_fn.graph, - tf2_utils.get_structured_inputs_from_func_graph(concrete_fn.graph), - concrete_fn.structured_outputs) - - -def trace_preprocessing_function(preprocessing_fn, - input_specs, - use_tf_compat_v1, - base_temp_dir=None): - """Trace graph for `preprocessing_fn`. - - Args: - preprocessing_fn: A user defined python function to be traced. - input_specs: A dictionary from input feature name to its FeatureSpec or - TypeSpec. If use_tf_compat_v1 is `False`, input_specs must be a dictionary - of TypeSpecs. - use_tf_compat_v1: (Optional) If `True`, the `preprocessing_fn` is traced as - a TF 1.x graph. Else, it is traced using tf.function. - base_temp_dir: (Optional) Base path to write any dummy assets to during - tracing. Required when `use_tf_compat_v1` is `False`. - - Returns: - A tuple of: - - 0. the graph representing the traced `preprocessing_fn` - 1. the graph's structured inputs - 2. the graph's structured outputs - """ - if use_tf_compat_v1: - return _trace_preprocessing_fn_v1(preprocessing_fn, input_specs) - else: - return _trace_preprocessing_fn_v2( - preprocessing_fn, input_specs, base_temp_dir + """Trace TF2 graph for `preprocessing_fn`.""" + tf_graph_context = graph_context.TFGraphContext( + module_to_export=tf.Module(), + temp_dir=base_temp_dir, + evaluated_replacements=None, + ) + with annotators.object_tracker_scope(annotators.ObjectTracker()): + concrete_fn = get_traced_transform_fn( + preprocessing_fn, specs, tf_graph_context + ).get_concrete_function() + return ( + concrete_fn.graph, + tf2_utils.get_structured_inputs_from_func_graph(concrete_fn.graph), + concrete_fn.structured_outputs, ) +def trace_preprocessing_function( + preprocessing_fn, input_specs, use_tf_compat_v1, base_temp_dir=None +): + """Trace graph for `preprocessing_fn`. + + Args: + ---- + preprocessing_fn: A user defined python function to be traced. + input_specs: A dictionary from input feature name to its FeatureSpec or + TypeSpec. If use_tf_compat_v1 is `False`, input_specs must be a dictionary + of TypeSpecs. + use_tf_compat_v1: (Optional) If `True`, the `preprocessing_fn` is traced as + a TF 1.x graph. Else, it is traced using tf.function. + base_temp_dir: (Optional) Base path to write any dummy assets to during + tracing. Required when `use_tf_compat_v1` is `False`. + + Returns: + ------- + A tuple of: + + 0. the graph representing the traced `preprocessing_fn` + 1. the graph's structured inputs + 2. 
the graph's structured outputs + """ + if use_tf_compat_v1: + return _trace_preprocessing_fn_v1(preprocessing_fn, input_specs) + else: + return _trace_preprocessing_fn_v2(preprocessing_fn, input_specs, base_temp_dir) + + def _trace_and_write_transform_fn( saved_model_dir: str, preprocessing_fn: Callable[ @@ -455,25 +497,25 @@ def _trace_and_write_transform_fn( output_keys_to_name_map: Optional[Dict[str, str]], save_options: Optional[tf.saved_model.SaveOptions], ) -> function.ConcreteFunction: - """Trace `preprocessing_fn` and serialize to a SavedModel.""" - tf_graph_context = graph_context.TFGraphContext( - module_to_export=tf.Module(), - temp_dir=base_temp_dir, - evaluated_replacements=tensor_replacement_map, - ) - transform_fn = get_traced_transform_fn( - preprocessing_fn, - input_signature, - tf_graph_context, - output_keys_to_name_map=output_keys_to_name_map, - ) - return saved_transform_io_v2.write_v2_saved_model( - tf_graph_context.module_to_export, - transform_fn, - 'transform_fn', - saved_model_dir, - save_options, - ) + """Trace `preprocessing_fn` and serialize to a SavedModel.""" + tf_graph_context = graph_context.TFGraphContext( + module_to_export=tf.Module(), + temp_dir=base_temp_dir, + evaluated_replacements=tensor_replacement_map, + ) + transform_fn = get_traced_transform_fn( + preprocessing_fn, + input_signature, + tf_graph_context, + output_keys_to_name_map=output_keys_to_name_map, + ) + return saved_transform_io_v2.write_v2_saved_model( + tf_graph_context.module_to_export, + transform_fn, + "transform_fn", + saved_model_dir, + save_options, + ) def _trace_and_get_metadata( @@ -486,69 +528,67 @@ def _trace_and_get_metadata( base_temp_dir: Optional[str], tensor_replacement_map: Optional[Dict[str, tf.Tensor]], ) -> dataset_metadata.DatasetMetadata: - """Compute and return metadata for the outputs of `concrete_transform_fn`.""" - tf_graph_context = graph_context.TFGraphContext( - module_to_export=tf.Module(), - temp_dir=base_temp_dir, - evaluated_replacements=tensor_replacement_map, - ) - concrete_metadata_fn = schema_inference.get_traced_metadata_fn( - preprocessing_fn, - structured_inputs, - tf_graph_context, - evaluate_schema_overrides=True, - ) - return dataset_metadata.DatasetMetadata( - schema=schema_inference.infer_feature_schema_v2( - concrete_transform_fn.structured_outputs, - concrete_metadata_fn, - evaluate_schema_overrides=True, - ) - ) + """Compute and return metadata for the outputs of `concrete_transform_fn`.""" + tf_graph_context = graph_context.TFGraphContext( + module_to_export=tf.Module(), + temp_dir=base_temp_dir, + evaluated_replacements=tensor_replacement_map, + ) + concrete_metadata_fn = schema_inference.get_traced_metadata_fn( + preprocessing_fn, + structured_inputs, + tf_graph_context, + evaluate_schema_overrides=True, + ) + return dataset_metadata.DatasetMetadata( + schema=schema_inference.infer_feature_schema_v2( + concrete_transform_fn.structured_outputs, + concrete_metadata_fn, + evaluate_schema_overrides=True, + ) + ) def _validate_analyzers_fingerprint( - baseline_analyzers_fingerprint: Mapping[ - str, graph_tools.AnalyzersFingerprint - ], + baseline_analyzers_fingerprint: Mapping[str, graph_tools.AnalyzersFingerprint], graph: tf.Graph, structured_inputs: Mapping[str, common_types.TensorType], ): - """Validates analyzers fingerprint in `graph` is same as baseline.""" - analyzers_fingerprint = graph_tools.get_analyzers_fingerprint( - graph, structured_inputs - ) - error_msg = ( - 'The order of analyzers in your `preprocessing_fn` appears to be ' - 
'non-deterministic. This can be fixed either by changing your ' - '`preprocessing_fn` such that tf.Transform analyzers are encountered ' - 'in a deterministic order or by passing a unique name to each ' - 'analyzer API call.' - ) - for analyzer in analyzers_fingerprint: - if analyzer not in baseline_analyzers_fingerprint: - prefix_msg = ( - f'Analyzer node ({analyzer}) not found in ' - f'{baseline_analyzers_fingerprint.keys()}. ' - ) - raise RuntimeError(prefix_msg + error_msg) - if ( - baseline_analyzers_fingerprint[analyzer].source_keys - != analyzers_fingerprint[analyzer].source_keys - ): - raise RuntimeError(error_msg) - - if ( - baseline_analyzers_fingerprint[analyzer].unique_path_hash - != analyzers_fingerprint[analyzer].unique_path_hash - ): - logging.warning( - "Analyzer (%s) node's cache key varies on repeated tracing." - ' This warning is safe to ignore if you either specify `name` for all' - ' analyzers or if the order in which they are invoked is' - ' deterministic. If not, please file a bug with details.', - analyzer, - ) + """Validates analyzers fingerprint in `graph` is same as baseline.""" + analyzers_fingerprint = graph_tools.get_analyzers_fingerprint( + graph, structured_inputs + ) + error_msg = ( + "The order of analyzers in your `preprocessing_fn` appears to be " + "non-deterministic. This can be fixed either by changing your " + "`preprocessing_fn` such that tf.Transform analyzers are encountered " + "in a deterministic order or by passing a unique name to each " + "analyzer API call." + ) + for analyzer in analyzers_fingerprint: + if analyzer not in baseline_analyzers_fingerprint: + prefix_msg = ( + f"Analyzer node ({analyzer}) not found in " + f"{baseline_analyzers_fingerprint.keys()}. " + ) + raise RuntimeError(prefix_msg + error_msg) + if ( + baseline_analyzers_fingerprint[analyzer].source_keys + != analyzers_fingerprint[analyzer].source_keys + ): + raise RuntimeError(error_msg) + + if ( + baseline_analyzers_fingerprint[analyzer].unique_path_hash + != analyzers_fingerprint[analyzer].unique_path_hash + ): + logging.warning( + "Analyzer (%s) node's cache key varies on repeated tracing." + " This warning is safe to ignore if you either specify `name` for all" + " analyzers or if the order in which they are invoked is" + " deterministic. If not, please file a bug with details.", + analyzer, + ) def trace_and_write_v2_saved_model( @@ -559,126 +599,157 @@ def trace_and_write_v2_saved_model( ], input_signature: Mapping[str, tf.TypeSpec], base_temp_dir: Optional[str], - baseline_analyzers_fingerprint: Mapping[ - str, graph_tools.AnalyzersFingerprint - ], + baseline_analyzers_fingerprint: Mapping[str, graph_tools.AnalyzersFingerprint], tensor_replacement_map: Optional[Dict[str, tf.Tensor]], output_keys_to_name_map: Optional[Dict[str, str]], save_options: Optional[tf.saved_model.SaveOptions], ): - """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function. - - The SavedModel written contains a method called `transform_fn` that - represents the traced `preprocessing_fn`. Additionally, if this is the final - SavedModel being written out, it will contain a method called `metadata_fn` - that provides deferred schema annotations. - - Args: - saved_model_dir: Path to write SavedModel to. - preprocessing_fn: A user defined python function to be traced. - input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`. - base_temp_dir: Base path to write temporary artifacts to. 
- baseline_analyzers_fingerprint: A mapping from analyzer name to a set of - paths that define its fingerprint. - tensor_replacement_map: A map from placeholder tensor names to their - evaluated replacement tensors. - output_keys_to_name_map: A map from output dictionary keys to the names of - the tensors that they represent. - save_options: The options to use when saving the saved_model. - - Returns: - A tuple containing a pair of `tf.ConcreteFunction`s: - 1. The traced preprocessing_fn. - 2. A metadata_fn that returns a dictionary containing the deferred - annotations added to the graph when invoked with any valid input. - - Raises: - RuntimeError: if analyzers in `preprocessing_fn` are encountered in a - non-deterministic order. - """ - concrete_transform_fn = _trace_and_write_transform_fn( - saved_model_dir, preprocessing_fn, input_signature, base_temp_dir, - tensor_replacement_map, output_keys_to_name_map, save_options) - structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( - concrete_transform_fn.graph) - _validate_analyzers_fingerprint(baseline_analyzers_fingerprint, - concrete_transform_fn.graph, - structured_inputs) - - # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers - # in the `preprocessing_fn` have already been evaluated. - if not concrete_transform_fn.graph.get_collection( - analyzer_nodes.TENSOR_REPLACEMENTS): - metadata = _trace_and_get_metadata(concrete_transform_fn, structured_inputs, - preprocessing_fn, base_temp_dir, - tensor_replacement_map) - metadata_io.write_metadata(metadata, - os.path.join(saved_model_dir, METADATA_DIR_NAME)) + """Writes out a SavedModelV2 with preprocessing_fn traced using tf.function. + + The SavedModel written contains a method called `transform_fn` that + represents the traced `preprocessing_fn`. Additionally, if this is the final + SavedModel being written out, it will contain a method called `metadata_fn` + that provides deferred schema annotations. + + Args: + ---- + saved_model_dir: Path to write SavedModel to. + preprocessing_fn: A user defined python function to be traced. + input_signature: TypeSpecs describing the inputs to the `preprocessing_fn`. + base_temp_dir: Base path to write temporary artifacts to. + baseline_analyzers_fingerprint: A mapping from analyzer name to a set of + paths that define its fingerprint. + tensor_replacement_map: A map from placeholder tensor names to their + evaluated replacement tensors. + output_keys_to_name_map: A map from output dictionary keys to the names of + the tensors that they represent. + save_options: The options to use when saving the saved_model. + + Returns: + ------- + A tuple containing a pair of `tf.ConcreteFunction`s: + 1. The traced preprocessing_fn. + 2. A metadata_fn that returns a dictionary containing the deferred + annotations added to the graph when invoked with any valid input. + + Raises: + ------ + RuntimeError: if analyzers in `preprocessing_fn` are encountered in a + non-deterministic order. 
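As a minimal sketch of the naming workaround mentioned just above (the feature keys and names here are hypothetical), giving each analyzer call an explicit `name` keeps its cache key stable regardless of tracing order:

    import tensorflow_transform as tft

    def preprocessing_fn(inputs):
        # Explicit, unique names make the analyzer fingerprints deterministic.
        return {
            "a_centered": inputs["a"] - tft.mean(inputs["a"], name="mean_a"),
            "b_centered": inputs["b"] - tft.mean(inputs["b"], name="mean_b"),
        }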
+ """ + concrete_transform_fn = _trace_and_write_transform_fn( + saved_model_dir, + preprocessing_fn, + input_signature, + base_temp_dir, + tensor_replacement_map, + output_keys_to_name_map, + save_options, + ) + structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( + concrete_transform_fn.graph + ) + _validate_analyzers_fingerprint( + baseline_analyzers_fingerprint, concrete_transform_fn.graph, structured_inputs + ) + + # If the `TENSOR_REPLACEMENTS` graph collection is empty, all TFT analyzers + # in the `preprocessing_fn` have already been evaluated. + if not concrete_transform_fn.graph.get_collection( + analyzer_nodes.TENSOR_REPLACEMENTS + ): + metadata = _trace_and_get_metadata( + concrete_transform_fn, + structured_inputs, + preprocessing_fn, + base_temp_dir, + tensor_replacement_map, + ) + metadata_io.write_metadata( + metadata, os.path.join(saved_model_dir, METADATA_DIR_NAME) + ) def _assert_no_analyzers_in_graph(graph): - if graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS): - raise RuntimeError('TFT analyzers found when tracing the given ' - '`preprocessing_fn`. Please use ' - '`tft.beam.AnalyzeDataset` to analyze this function.') - - -def analyze_in_place(preprocessing_fn, force_tf_compat_v1, feature_specs, - type_specs, transform_output_path): - """Analyzes the `preprocessing_fn` in-place without looking at the data. - - This should only be used if the `preprocessing_fn` contains no TFT - analyzers or TFT mappers that use analyzers. - - Writes out a transform function and transformed metadata to subdirs under - `transform_output_path`. - - Args: - preprocessing_fn: The tf.Transform preprocessing_fn. - force_tf_compat_v1: If True, call Transform's API to use Tensorflow in - tf.compat.v1 mode. - feature_specs: a Dict from input feature key to its feature spec. - type_specs: a Dict from input feature key to its type spec. - transform_output_path: An absolute path to write the output to. - - Raises: - RuntimeError if `preprocessing_fn` contains TFT analyzers. 
- """ - use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) - transform_fn_path = os.path.join(transform_output_path, - TFTransformOutput.TRANSFORM_FN_DIR) - if use_tf_compat_v1: - graph, structured_inputs, structured_outputs = ( - trace_preprocessing_function( - preprocessing_fn, feature_specs, use_tf_compat_v1=use_tf_compat_v1)) - _assert_no_analyzers_in_graph(graph) - with tf.compat.v1.Session(graph=graph) as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(tf.compat.v1.tables_initializer()) - saved_transform_io.write_saved_transform_from_session( - sess, structured_inputs, structured_outputs, transform_fn_path) - - transformed_metadata = dataset_metadata.DatasetMetadata( - schema=schema_inference.infer_feature_schema(structured_outputs, - graph, sess)) - else: - concrete_transform_fn = _trace_and_write_transform_fn( - saved_model_dir=transform_fn_path, - preprocessing_fn=preprocessing_fn, - input_signature=type_specs, - base_temp_dir=None, - tensor_replacement_map=None, - output_keys_to_name_map=None, - save_options=None) - _assert_no_analyzers_in_graph(concrete_transform_fn.graph) - structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( - concrete_transform_fn.graph) - transformed_metadata = _trace_and_get_metadata( - concrete_transform_fn=concrete_transform_fn, - structured_inputs=structured_inputs, - preprocessing_fn=preprocessing_fn, - base_temp_dir=None, - tensor_replacement_map=None) - transformed_metadata_dir = os.path.join( - transform_output_path, TFTransformOutput.TRANSFORMED_METADATA_DIR) - metadata_io.write_metadata(transformed_metadata, transformed_metadata_dir) + if graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS): + raise RuntimeError( + "TFT analyzers found when tracing the given " + "`preprocessing_fn`. Please use " + "`tft.beam.AnalyzeDataset` to analyze this function." + ) + + +def analyze_in_place( + preprocessing_fn, + force_tf_compat_v1, + feature_specs, + type_specs, + transform_output_path, +): + """Analyzes the `preprocessing_fn` in-place without looking at the data. + + This should only be used if the `preprocessing_fn` contains no TFT + analyzers or TFT mappers that use analyzers. + + Writes out a transform function and transformed metadata to subdirs under + `transform_output_path`. + + Args: + ---- + preprocessing_fn: The tf.Transform preprocessing_fn. + force_tf_compat_v1: If True, call Transform's API to use Tensorflow in + tf.compat.v1 mode. + feature_specs: a Dict from input feature key to its feature spec. + type_specs: a Dict from input feature key to its type spec. + transform_output_path: An absolute path to write the output to. + + Raises: + ------ + RuntimeError if `preprocessing_fn` contains TFT analyzers. 
+ """ + use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) + transform_fn_path = os.path.join( + transform_output_path, TFTransformOutput.TRANSFORM_FN_DIR + ) + if use_tf_compat_v1: + graph, structured_inputs, structured_outputs = trace_preprocessing_function( + preprocessing_fn, feature_specs, use_tf_compat_v1=use_tf_compat_v1 + ) + _assert_no_analyzers_in_graph(graph) + with tf.compat.v1.Session(graph=graph) as sess: + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) + saved_transform_io.write_saved_transform_from_session( + sess, structured_inputs, structured_outputs, transform_fn_path + ) + + transformed_metadata = dataset_metadata.DatasetMetadata( + schema=schema_inference.infer_feature_schema( + structured_outputs, graph, sess + ) + ) + else: + concrete_transform_fn = _trace_and_write_transform_fn( + saved_model_dir=transform_fn_path, + preprocessing_fn=preprocessing_fn, + input_signature=type_specs, + base_temp_dir=None, + tensor_replacement_map=None, + output_keys_to_name_map=None, + save_options=None, + ) + _assert_no_analyzers_in_graph(concrete_transform_fn.graph) + structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( + concrete_transform_fn.graph + ) + transformed_metadata = _trace_and_get_metadata( + concrete_transform_fn=concrete_transform_fn, + structured_inputs=structured_inputs, + preprocessing_fn=preprocessing_fn, + base_temp_dir=None, + tensor_replacement_map=None, + ) + transformed_metadata_dir = os.path.join( + transform_output_path, TFTransformOutput.TRANSFORMED_METADATA_DIR + ) + metadata_io.write_metadata(transformed_metadata, transformed_metadata_dir) diff --git a/tensorflow_transform/impl_helper_test.py b/tensorflow_transform/impl_helper_test.py index ff3a0c9..a4fe632 100644 --- a/tensorflow_transform/impl_helper_test.py +++ b/tensorflow_transform/impl_helper_test.py @@ -17,962 +17,979 @@ import os import numpy as np -from packaging import version import pyarrow as pa import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import impl_helper -from tensorflow_transform import schema_inference -from tensorflow_transform import test_case +from packaging import version + +from tensorflow_transform import analyzers, impl_helper, schema_inference, test_case from tensorflow_transform.output_wrapper import TFTransformOutput from tensorflow_transform.tf_metadata import schema_utils -def _sparse_index_name(index, tensor_name='sparse'): - return schema_inference.SPARSE_INDICES_NAME_TEMPLATE.format( - tensor_name=tensor_name, index=index) +def _sparse_index_name(index, tensor_name="sparse"): + return schema_inference.SPARSE_INDICES_NAME_TEMPLATE.format( + tensor_name=tensor_name, index=index + ) -def _sparse_value_name(tensor_name='sparse'): - return schema_inference.SPARSE_VALUES_NAME_TEMPLATE.format( - tensor_name=tensor_name) +def _sparse_value_name(tensor_name="sparse"): + return schema_inference.SPARSE_VALUES_NAME_TEMPLATE.format(tensor_name=tensor_name) _FEATURE_SPEC = { - 'a': - tf.io.FixedLenFeature([], tf.int64), - 'b': - tf.io.FixedLenFeature([], tf.float32), - 'c': - tf.io.FixedLenFeature([1], tf.float32), - 'd': - tf.io.FixedLenFeature([2, 2], tf.float32), - 'e': - tf.io.VarLenFeature(tf.string), - 'f': - tf.io.SparseFeature( - _sparse_index_name(0, 'f'), _sparse_value_name('f'), tf.float32, - 10), - 'g': - tf.io.SparseFeature([_sparse_index_name(idx, 'g') for idx in range(2)], - _sparse_value_name('g'), tf.float32, [2, 10]), - 'h': - 
tf.io.RaggedFeature(tf.float32, value_key='h_val'), - 'i': - tf.io.RaggedFeature( - tf.float32, - value_key='i_val', - partitions=[tf.io.RaggedFeature.RowLengths('i_row_lengths1')]), # pytype: disable=attribute-error - 'j': - tf.io.RaggedFeature( - tf.float32, - value_key='j_val', - partitions=[ - tf.io.RaggedFeature.RowLengths('j_row_lengths1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.RowLengths('j_row_lengths2'), # pytype: disable=attribute-error - ]), - 'k': - tf.io.RaggedFeature( - tf.int64, - value_key='k_val', - partitions=[ - tf.io.RaggedFeature.UniformRowLength(3), # pytype: disable=attribute-error - ]), - 'l': - tf.io.RaggedFeature( - tf.int64, - value_key='l_val', - partitions=[ - tf.io.RaggedFeature.RowLengths('l_row_lengths1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.UniformRowLength(2), # pytype: disable=attribute-error - ]), + "a": tf.io.FixedLenFeature([], tf.int64), + "b": tf.io.FixedLenFeature([], tf.float32), + "c": tf.io.FixedLenFeature([1], tf.float32), + "d": tf.io.FixedLenFeature([2, 2], tf.float32), + "e": tf.io.VarLenFeature(tf.string), + "f": tf.io.SparseFeature( + _sparse_index_name(0, "f"), _sparse_value_name("f"), tf.float32, 10 + ), + "g": tf.io.SparseFeature( + [_sparse_index_name(idx, "g") for idx in range(2)], + _sparse_value_name("g"), + tf.float32, + [2, 10], + ), + "h": tf.io.RaggedFeature(tf.float32, value_key="h_val"), + "i": tf.io.RaggedFeature( + tf.float32, + value_key="i_val", + partitions=[tf.io.RaggedFeature.RowLengths("i_row_lengths1")], + ), # pytype: disable=attribute-error + "j": tf.io.RaggedFeature( + tf.float32, + value_key="j_val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "j_row_lengths1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.RowLengths( + "j_row_lengths2" + ), # pytype: disable=attribute-error + ], + ), + "k": tf.io.RaggedFeature( + tf.int64, + value_key="k_val", + partitions=[ + tf.io.RaggedFeature.UniformRowLength(3), # pytype: disable=attribute-error + ], + ), + "l": tf.io.RaggedFeature( + tf.int64, + value_key="l_val", + partitions=[ + tf.io.RaggedFeature.RowLengths( + "l_row_lengths1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.UniformRowLength(2), # pytype: disable=attribute-error + ], + ), } _FEED_DICT = { - 'a': - np.array([100, 100]), - 'b': - np.array([1.0, 2.0], np.float32), - 'c': - np.array([[2.0], [4.0]], np.float32), - 'd': - np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], - np.float32), - 'e': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), - values=np.array([b'doe', b'a', b'deer', b'a', b'female', b'deer'], - dtype=object), - dense_shape=(2, 3)), - 'f': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 2), (0, 4), (0, 8)]), - values=np.array([10.0, 20.0, 30.0], np.float32), - dense_shape=(2, 10)), - 'g': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0, 3), (0, 1, 5), (0, 1, 9)]), - values=np.array([110.0, 210.0, 310.0], np.float32), - dense_shape=(2, 2, 10)), - 'h': - tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], dtype=np.float32), - row_splits=np.array([0, 3, 5])), - 'i': - tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 3., 3., 1.], np.float32), - row_splits=np.array([0, 0, 3, 6])), - row_splits=np.array([0, 2, 3])), - 'j': - tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - 
values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - 'k': - tf.compat.v1.ragged.RaggedTensorValue( - values=np.reshape(np.arange(12, dtype=np.int64), (4, 3)), - row_splits=np.array([0, 3, 4])), - 'l': - tf.compat.v1.ragged.RaggedTensorValue( + "a": np.array([100, 100]), + "b": np.array([1.0, 2.0], np.float32), + "c": np.array([[2.0], [4.0]], np.float32), + "d": np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], np.float32), + "e": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), + values=np.array( + [b"doe", b"a", b"deer", b"a", b"female", b"deer"], dtype=object + ), + dense_shape=(2, 3), + ), + "f": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 2), (0, 4), (0, 8)]), + values=np.array([10.0, 20.0, 30.0], np.float32), + dense_shape=(2, 10), + ), + "g": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0, 3), (0, 1, 5), (0, 1, 9)]), + values=np.array([110.0, 210.0, 310.0], np.float32), + dense_shape=(2, 2, 10), + ), + "h": tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32), + row_splits=np.array([0, 3, 5]), + ), + "i": tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 3.0, 3.0, 1.0], np.float32), + row_splits=np.array([0, 0, 3, 6]), + ), + row_splits=np.array([0, 2, 3]), + ), + "j": tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.reshape(np.arange(8, dtype=np.int64), (4, 2)), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + "k": tf.compat.v1.ragged.RaggedTensorValue( + values=np.reshape(np.arange(12, dtype=np.int64), (4, 3)), + row_splits=np.array([0, 3, 4]), + ), + "l": tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.reshape(np.arange(8, dtype=np.int64), (4, 2)), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), } _MULTIPLE_FEATURES_CASE_RECORD_BATCH = { - 'a': - pa.array([[100], [100]], type=pa.large_list(pa.int64())), - 'b': - pa.array([[1.0], [2.0]], type=pa.large_list(pa.float32())), - 'c': - pa.array([[2.0], [4.0]], type=pa.large_list(pa.float32())), - 'd': - pa.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], - type=pa.large_list(pa.float32())), - 'e': - pa.array([[b'doe', b'a', b'deer'], [b'a', b'female', b'deer']], - type=pa.large_list(pa.large_binary())), - _sparse_index_name(0, 'f'): - pa.array([[2, 4, 8], []], type=pa.large_list(pa.int64())), - _sparse_value_name('f'): - pa.array([[10.0, 20.0, 30.0], []], type=pa.large_list(pa.float32())), - _sparse_index_name(0, 'g'): - pa.array([[0, 1, 1], []], type=pa.large_list(pa.int64())), - _sparse_index_name(1, 'g'): - pa.array([[3, 5, 9], []], type=pa.large_list(pa.int64())), - _sparse_value_name('g'): - pa.array([[110.0, 210.0, 310.0], []], type=pa.large_list(pa.float32())), - 'h': - pa.array([[1., 2., 3.], [4., 5.]], type=pa.large_list(pa.float32())), - 'i': - pa.array([[[], [1., 2., 3.]], [[3., 3., 1.]]], - type=pa.large_list(pa.large_list(pa.float32()))), - 'j': - pa.array([[[[1., 2.], 
[3.]], [[4.]]], [[[5.]]]], - type=pa.large_list(pa.large_list(pa.large_list( - pa.float32())))), - 'k': - pa.array([[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11]], - type=pa.large_list(pa.int64())), - 'l': - pa.array([[[0, 1, 2, 3], [4, 5]], [[6, 7]]], - type=pa.large_list(pa.large_list(pa.int64()))), + "a": pa.array([[100], [100]], type=pa.large_list(pa.int64())), + "b": pa.array([[1.0], [2.0]], type=pa.large_list(pa.float32())), + "c": pa.array([[2.0], [4.0]], type=pa.large_list(pa.float32())), + "d": pa.array( + [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], type=pa.large_list(pa.float32()) + ), + "e": pa.array( + [[b"doe", b"a", b"deer"], [b"a", b"female", b"deer"]], + type=pa.large_list(pa.large_binary()), + ), + _sparse_index_name(0, "f"): pa.array( + [[2, 4, 8], []], type=pa.large_list(pa.int64()) + ), + _sparse_value_name("f"): pa.array( + [[10.0, 20.0, 30.0], []], type=pa.large_list(pa.float32()) + ), + _sparse_index_name(0, "g"): pa.array( + [[0, 1, 1], []], type=pa.large_list(pa.int64()) + ), + _sparse_index_name(1, "g"): pa.array( + [[3, 5, 9], []], type=pa.large_list(pa.int64()) + ), + _sparse_value_name("g"): pa.array( + [[110.0, 210.0, 310.0], []], type=pa.large_list(pa.float32()) + ), + "h": pa.array([[1.0, 2.0, 3.0], [4.0, 5.0]], type=pa.large_list(pa.float32())), + "i": pa.array( + [[[], [1.0, 2.0, 3.0]], [[3.0, 3.0, 1.0]]], + type=pa.large_list(pa.large_list(pa.float32())), + ), + "j": pa.array( + [[[[1.0, 2.0], [3.0]], [[4.0]]], [[[5.0]]]], + type=pa.large_list(pa.large_list(pa.large_list(pa.float32()))), + ), + "k": pa.array( + [[0, 1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11]], type=pa.large_list(pa.int64()) + ), + "l": pa.array( + [[[0, 1, 2, 3], [4, 5]], [[6, 7]]], + type=pa.large_list(pa.large_list(pa.int64())), + ), } _ROUNDTRIP_CASES = [ dict( - testcase_name='multiple_features', + testcase_name="multiple_features", feature_spec=_FEATURE_SPEC, - instances=[{ - 'a': 100, - 'b': 1.0, - 'c': [2.0], - 'd': [[1.0, 2.0], [3.0, 4.0]], - 'e': [b'doe', b'a', b'deer'], - _sparse_index_name(0, 'f'): [2, 4, 8], - _sparse_value_name('f'): [10.0, 20.0, 30.0], - _sparse_index_name(0, 'g'): [0, 1, 1], - _sparse_index_name(1, 'g'): [3, 5, 9], - _sparse_value_name('g'): [110.0, 210.0, 310.0], - 'h_val': [1., 2., 3.], - 'i_val': [1., 2., 3.], - 'i_row_lengths1': [0, 3], - 'j_val': [1., 2., 3., 4.], - 'j_row_lengths1': [2, 1], - 'j_row_lengths2': [2, 1, 1], - 'k_val': [0, 1, 2, 3, 4, 5, 6, 7, 8], - 'l_val': [0, 1, 2, 3, 4, 5], - 'l_row_lengths1': [2, 1], - }, { - 'a': 100, - 'b': 2.0, - 'c': [4.0], - 'd': [[5.0, 6.0], [7.0, 8.0]], - 'e': [b'a', b'female', b'deer'], - _sparse_index_name(0, 'f'): [], - _sparse_value_name('f'): [], - _sparse_index_name(0, 'g'): [], - _sparse_index_name(1, 'g'): [], - _sparse_value_name('g'): [], - 'h_val': [4., 5.], - 'i_val': [3., 3., 1.], - 'i_row_lengths1': [3], - 'j_val': [5.], - 'j_row_lengths1': [1], - 'j_row_lengths2': [1], - 'k_val': [9, 10, 11], - 'l_val': [6, 7], - 'l_row_lengths1': [1], - }], + instances=[ + { + "a": 100, + "b": 1.0, + "c": [2.0], + "d": [[1.0, 2.0], [3.0, 4.0]], + "e": [b"doe", b"a", b"deer"], + _sparse_index_name(0, "f"): [2, 4, 8], + _sparse_value_name("f"): [10.0, 20.0, 30.0], + _sparse_index_name(0, "g"): [0, 1, 1], + _sparse_index_name(1, "g"): [3, 5, 9], + _sparse_value_name("g"): [110.0, 210.0, 310.0], + "h_val": [1.0, 2.0, 3.0], + "i_val": [1.0, 2.0, 3.0], + "i_row_lengths1": [0, 3], + "j_val": [1.0, 2.0, 3.0, 4.0], + "j_row_lengths1": [2, 1], + "j_row_lengths2": [2, 1, 1], + "k_val": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "l_val": [0, 
1, 2, 3, 4, 5], + "l_row_lengths1": [2, 1], + }, + { + "a": 100, + "b": 2.0, + "c": [4.0], + "d": [[5.0, 6.0], [7.0, 8.0]], + "e": [b"a", b"female", b"deer"], + _sparse_index_name(0, "f"): [], + _sparse_value_name("f"): [], + _sparse_index_name(0, "g"): [], + _sparse_index_name(1, "g"): [], + _sparse_value_name("g"): [], + "h_val": [4.0, 5.0], + "i_val": [3.0, 3.0, 1.0], + "i_row_lengths1": [3], + "j_val": [5.0], + "j_row_lengths1": [1], + "j_row_lengths2": [1], + "k_val": [9, 10, 11], + "l_val": [6, 7], + "l_row_lengths1": [1], + }, + ], record_batch=_MULTIPLE_FEATURES_CASE_RECORD_BATCH, - feed_dict=_FEED_DICT), + feed_dict=_FEED_DICT, + ), dict( - testcase_name='multiple_features_ndarrays', + testcase_name="multiple_features_ndarrays", feature_spec=_FEATURE_SPEC, - instances=[{ - 'a': - np.int64(100), - 'b': - np.array(1.0, np.float32), - 'c': - np.array([2.0], np.float32), - 'd': - np.array([[1.0, 2.0], [3.0, 4.0]], np.float32), - 'e': [b'doe', b'a', b'deer'], - _sparse_index_name(0, 'f'): - np.array([2, 4, 8]), - _sparse_value_name('f'): - np.array([10.0, 20.0, 30.0], np.float32), - _sparse_index_name(0, 'g'): - np.array([0, 1, 1]), - _sparse_index_name(1, 'g'): - np.array([3, 5, 9]), - _sparse_value_name('g'): - np.array([110.0, 210.0, 310.0], np.float32), - 'h_val': - np.array([1., 2., 3.], np.float32), - 'i_val': - np.array([1., 2., 3.], np.float32), - 'i_row_lengths1': - np.array([0, 3]), - 'j_val': - np.array([1., 2., 3., 4], np.float32), - 'j_row_lengths1': - np.array([2, 1]), - 'j_row_lengths2': - np.array([2, 1, 1]), - 'k_val': - np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]), - 'l_val': - np.array([0, 1, 2, 3, 4, 5]), - 'l_row_lengths1': - np.array([2, 1]), - }, { - 'a': np.int64(100), - 'b': np.array(2.0, np.float32), - 'c': np.array([4.0], np.float32), - 'd': np.array([[5.0, 6.0], [7.0, 8.0]], np.float32), - 'e': [b'a', b'female', b'deer'], - _sparse_index_name(0, 'f'): np.array([], np.int32), - _sparse_value_name('f'): np.array([], np.float32), - _sparse_index_name(0, 'g'): np.array([], np.float32), - _sparse_index_name(1, 'g'): np.array([], np.float32), - _sparse_value_name('g'): np.array([], np.float32), - 'h_val': np.array([4., 5.], np.float32), - 'i_val': np.array([3., 3., 1.], np.float32), - 'i_row_lengths1': np.array([3]), - 'j_val': np.array([5.], np.float32), - 'j_row_lengths1': np.array([1]), - 'j_row_lengths2': np.array([1]), - 'k_val': np.array([9, 10, 11]), - 'l_val': np.array([6, 7]), - 'l_row_lengths1': np.array([1]), - }], + instances=[ + { + "a": np.int64(100), + "b": np.array(1.0, np.float32), + "c": np.array([2.0], np.float32), + "d": np.array([[1.0, 2.0], [3.0, 4.0]], np.float32), + "e": [b"doe", b"a", b"deer"], + _sparse_index_name(0, "f"): np.array([2, 4, 8]), + _sparse_value_name("f"): np.array([10.0, 20.0, 30.0], np.float32), + _sparse_index_name(0, "g"): np.array([0, 1, 1]), + _sparse_index_name(1, "g"): np.array([3, 5, 9]), + _sparse_value_name("g"): np.array([110.0, 210.0, 310.0], np.float32), + "h_val": np.array([1.0, 2.0, 3.0], np.float32), + "i_val": np.array([1.0, 2.0, 3.0], np.float32), + "i_row_lengths1": np.array([0, 3]), + "j_val": np.array([1.0, 2.0, 3.0, 4], np.float32), + "j_row_lengths1": np.array([2, 1]), + "j_row_lengths2": np.array([2, 1, 1]), + "k_val": np.array([0, 1, 2, 3, 4, 5, 6, 7, 8]), + "l_val": np.array([0, 1, 2, 3, 4, 5]), + "l_row_lengths1": np.array([2, 1]), + }, + { + "a": np.int64(100), + "b": np.array(2.0, np.float32), + "c": np.array([4.0], np.float32), + "d": np.array([[5.0, 6.0], [7.0, 8.0]], np.float32), + "e": [b"a", 
b"female", b"deer"], + _sparse_index_name(0, "f"): np.array([], np.int32), + _sparse_value_name("f"): np.array([], np.float32), + _sparse_index_name(0, "g"): np.array([], np.float32), + _sparse_index_name(1, "g"): np.array([], np.float32), + _sparse_value_name("g"): np.array([], np.float32), + "h_val": np.array([4.0, 5.0], np.float32), + "i_val": np.array([3.0, 3.0, 1.0], np.float32), + "i_row_lengths1": np.array([3]), + "j_val": np.array([5.0], np.float32), + "j_row_lengths1": np.array([1]), + "j_row_lengths2": np.array([1]), + "k_val": np.array([9, 10, 11]), + "l_val": np.array([6, 7]), + "l_row_lengths1": np.array([1]), + }, + ], record_batch=_MULTIPLE_FEATURES_CASE_RECORD_BATCH, - feed_dict=_FEED_DICT), + feed_dict=_FEED_DICT, + ), dict( - testcase_name='empty_var_len_feature', - feature_spec={'varlen': tf.io.VarLenFeature(tf.string)}, - instances=[{ - 'varlen': [] - }], + testcase_name="empty_var_len_feature", + feature_spec={"varlen": tf.io.VarLenFeature(tf.string)}, + instances=[{"varlen": []}], record_batch={ - 'varlen': pa.array([[]], type=pa.large_list(pa.large_binary())), + "varlen": pa.array([[]], type=pa.large_list(pa.large_binary())), }, feed_dict={ - 'varlen': - tf.compat.v1.SparseTensorValue( - indices=np.empty([0, 2]), - values=np.array([], dtype=object), - dense_shape=[1, 0]) - }), + "varlen": tf.compat.v1.SparseTensorValue( + indices=np.empty([0, 2]), + values=np.array([], dtype=object), + dense_shape=[1, 0], + ) + }, + ), # Mainly to test the empty-ndarray optimization though this is also # exercised by empty_var_len_feature dict( - testcase_name='some_empty_int_var_len_feature', - feature_spec={'varlen': tf.io.VarLenFeature(tf.int64)}, - instances=[{ - 'varlen': [0] - }, { - 'varlen': [] - }, { - 'varlen': [1] - }, { - 'varlen': [] - }], + testcase_name="some_empty_int_var_len_feature", + feature_spec={"varlen": tf.io.VarLenFeature(tf.int64)}, + instances=[{"varlen": [0]}, {"varlen": []}, {"varlen": [1]}, {"varlen": []}], record_batch={ - 'varlen': - pa.array([[0], [], [1], []], type=pa.large_list(pa.int64())), + "varlen": pa.array([[0], [], [1], []], type=pa.large_list(pa.int64())), }, feed_dict={ - 'varlen': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (2, 0)]), - values=np.array([0, 1], np.int64), - dense_shape=(4, 1)), - }), + "varlen": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (2, 0)]), + values=np.array([0, 1], np.int64), + dense_shape=(4, 1), + ), + }, + ), dict( - testcase_name='some_empty_float_var_len_feature', - feature_spec={'varlen': tf.io.VarLenFeature(tf.float32)}, - instances=[{ - 'varlen': [0.5] - }, { - 'varlen': [] - }, { - 'varlen': [1.5] - }, { - 'varlen': [] - }], + testcase_name="some_empty_float_var_len_feature", + feature_spec={"varlen": tf.io.VarLenFeature(tf.float32)}, + instances=[ + {"varlen": [0.5]}, + {"varlen": []}, + {"varlen": [1.5]}, + {"varlen": []}, + ], record_batch={ - 'varlen': - pa.array([[0.5], [], [1.5], []], - type=pa.large_list(pa.float32())), + "varlen": pa.array( + [[0.5], [], [1.5], []], type=pa.large_list(pa.float32()) + ), }, feed_dict={ - 'varlen': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (2, 0)]), - values=np.array([0.5, 1.5], np.float32), - dense_shape=(4, 1)), - }), + "varlen": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (2, 0)]), + values=np.array([0.5, 1.5], np.float32), + dense_shape=(4, 1), + ), + }, + ), dict( - testcase_name='some_empty_string_var_len_feature', - feature_spec={'varlen': tf.io.VarLenFeature(tf.string)}, - instances=[{ - 
'varlen': [b'a'] - }, { - 'varlen': [] - }, { - 'varlen': [b'b'] - }, { - 'varlen': [] - }], + testcase_name="some_empty_string_var_len_feature", + feature_spec={"varlen": tf.io.VarLenFeature(tf.string)}, + instances=[ + {"varlen": [b"a"]}, + {"varlen": []}, + {"varlen": [b"b"]}, + {"varlen": []}, + ], record_batch={ - 'varlen': - pa.array([[b'a'], [], [b'b'], []], - type=pa.large_list(pa.large_binary())), + "varlen": pa.array( + [[b"a"], [], [b"b"], []], type=pa.large_list(pa.large_binary()) + ), }, feed_dict={ - 'varlen': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (2, 0)]), - values=np.array([b'a', b'b'], object), - dense_shape=(4, 1)), - }), + "varlen": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (2, 0)]), + values=np.array([b"a", b"b"], object), + dense_shape=(4, 1), + ), + }, + ), dict( - testcase_name='empty_sparse_feature', + testcase_name="empty_sparse_feature", feature_spec={ - 'sparse': - tf.io.SparseFeature( - _sparse_index_name(0), _sparse_value_name(), tf.string, 10) + "sparse": tf.io.SparseFeature( + _sparse_index_name(0), _sparse_value_name(), tf.string, 10 + ) }, - instances=[{ - _sparse_index_name(0): [], - _sparse_value_name(): np.array([], object) - }], + instances=[ + {_sparse_index_name(0): [], _sparse_value_name(): np.array([], object)} + ], record_batch={ - _sparse_index_name(0): - pa.array([[]], type=pa.large_list(pa.int64())), - _sparse_value_name(): - pa.array([[]], type=pa.large_list(pa.large_binary())), + _sparse_index_name(0): pa.array([[]], type=pa.large_list(pa.int64())), + _sparse_value_name(): pa.array([[]], type=pa.large_list(pa.large_binary())), }, feed_dict={ - 'sparse': - tf.compat.v1.SparseTensorValue( - indices=np.empty([0, 2]), - values=np.array([], object), - dense_shape=[1, 10]) - }), + "sparse": tf.compat.v1.SparseTensorValue( + indices=np.empty([0, 2]), + values=np.array([], object), + dense_shape=[1, 10], + ) + }, + ), dict( - testcase_name='non_ragged_sparse_feature', + testcase_name="non_ragged_sparse_feature", feature_spec={ - 'sparse': - tf.io.SparseFeature( - _sparse_index_name(0), _sparse_value_name(), tf.float32, 10) + "sparse": tf.io.SparseFeature( + _sparse_index_name(0), _sparse_value_name(), tf.float32, 10 + ) }, - instances=[{ - _sparse_index_name(0): [], - _sparse_value_name(): np.array([], np.float32) - }, { - _sparse_index_name(0): [9], - _sparse_value_name(): np.array([0.3], np.float32) - }], + instances=[ + {_sparse_index_name(0): [], _sparse_value_name(): np.array([], np.float32)}, + { + _sparse_index_name(0): [9], + _sparse_value_name(): np.array([0.3], np.float32), + }, + ], record_batch={ - _sparse_index_name(0): - pa.array([[], [9]], type=pa.large_list(pa.int64())), - _sparse_value_name(): - pa.array([[], [0.3]], type=pa.large_list(pa.float32())), + _sparse_index_name(0): pa.array([[], [9]], type=pa.large_list(pa.int64())), + _sparse_value_name(): pa.array( + [[], [0.3]], type=pa.large_list(pa.float32()) + ), }, feed_dict={ - 'sparse': - tf.compat.v1.SparseTensorValue( - indices=np.array([[1, 9]]), - values=np.array([0.3], np.float32), - dense_shape=[2, 10]) - }), + "sparse": tf.compat.v1.SparseTensorValue( + indices=np.array([[1, 9]]), + values=np.array([0.3], np.float32), + dense_shape=[2, 10], + ) + }, + ), dict( - testcase_name='2d_sparse_feature', + testcase_name="2d_sparse_feature", feature_spec={ - 'sparse': - tf.io.SparseFeature( - [_sparse_index_name(idx) for idx in range(2)], - _sparse_value_name(), tf.float32, [10, 11]) + "sparse": tf.io.SparseFeature( + [_sparse_index_name(idx) 
for idx in range(2)], + _sparse_value_name(), + tf.float32, + [10, 11], + ) }, - instances=[{ - _sparse_index_name(0): [], - _sparse_index_name(1): [], - _sparse_value_name(): np.array([], np.float32) - }, { - _sparse_index_name(0): [9], - _sparse_index_name(1): [7], - _sparse_value_name(): np.array([0.3], np.float32) - }], + instances=[ + { + _sparse_index_name(0): [], + _sparse_index_name(1): [], + _sparse_value_name(): np.array([], np.float32), + }, + { + _sparse_index_name(0): [9], + _sparse_index_name(1): [7], + _sparse_value_name(): np.array([0.3], np.float32), + }, + ], record_batch={ - _sparse_index_name(0): - pa.array([[], [9]], type=pa.large_list(pa.int64())), - _sparse_index_name(1): - pa.array([[], [7]], type=pa.large_list(pa.int64())), - _sparse_value_name(): - pa.array([[], [0.3]], type=pa.large_list(pa.float32())), + _sparse_index_name(0): pa.array([[], [9]], type=pa.large_list(pa.int64())), + _sparse_index_name(1): pa.array([[], [7]], type=pa.large_list(pa.int64())), + _sparse_value_name(): pa.array( + [[], [0.3]], type=pa.large_list(pa.float32()) + ), }, feed_dict={ - 'sparse': - tf.compat.v1.SparseTensorValue( - indices=np.array([[1, 9, 7]]), - values=np.array([0.3], np.float32), - dense_shape=[2, 10, 11]) - }), + "sparse": tf.compat.v1.SparseTensorValue( + indices=np.array([[1, 9, 7]]), + values=np.array([0.3], np.float32), + dense_shape=[2, 10, 11], + ) + }, + ), ] # Non-canonical inputs that will not be the output of to_instance_dicts but # are valid inputs to make_feed_dict. _MAKE_FEED_DICT_CASES = [ dict( - testcase_name='none_feature', + testcase_name="none_feature", feature_spec={ - 'varlen_feature': tf.io.VarLenFeature(tf.int64), + "varlen_feature": tf.io.VarLenFeature(tf.int64), }, - instances=[{ - 'varlen_feature': [] - }, { - 'varlen_feature': None - }, { - 'varlen_feature': [1, 2] - }], + instances=[ + {"varlen_feature": []}, + {"varlen_feature": None}, + {"varlen_feature": [1, 2]}, + ], feed_dict={ - 'varlen_feature': - tf.compat.v1.SparseTensorValue( - indices=np.array([(2, 0), (2, 1)]), - values=np.array([1, 2]), - dense_shape=[3, 2]) - }), + "varlen_feature": tf.compat.v1.SparseTensorValue( + indices=np.array([(2, 0), (2, 1)]), + values=np.array([1, 2]), + dense_shape=[3, 2], + ) + }, + ), ] _TO_INSTANCE_DICT_ERROR_CASES = [ dict( - testcase_name='var_len_with_rank_not_2', - feature_spec={'a': tf.io.VarLenFeature(tf.float32)}, + testcase_name="var_len_with_rank_not_2", + feature_spec={"a": tf.io.VarLenFeature(tf.float32)}, feed_dict={ - 'a': tf.compat.v1.SparseTensorValue( + "a": tf.compat.v1.SparseTensorValue( indices=np.array([(0, 0, 1), (0, 0, 2), (0, 0, 3)]), values=np.array([10.0, 20.0, 30.0], np.float32), dense_shape=(1, 10, 10), ) }, error_msg=( - r'Expected SparseTensorSpec\(TensorShape\(' - r'\[(None|Dimension\(None\)), (None|Dimension\(None\))\]\)' + r"Expected SparseTensorSpec\(TensorShape\(" + r"\[(None|Dimension\(None\)), (None|Dimension\(None\))\]\)" ), error_type=TypeError, ), dict( - testcase_name='var_len_with_out_of_order_indices', - feature_spec={'a': tf.io.VarLenFeature(tf.float32)}, + testcase_name="var_len_with_out_of_order_indices", + feature_spec={"a": tf.io.VarLenFeature(tf.float32)}, feed_dict={ - 'a': tf.compat.v1.SparseTensorValue( + "a": tf.compat.v1.SparseTensorValue( indices=np.array([(0, 2), (2, 4), (1, 8)]), values=np.array([10.0, 20.0, 30.0], np.float32), dense_shape=(3, 20), ) }, - error_msg='The sparse indices must be sorted', + error_msg="The sparse indices must be sorted", error_type=AssertionError, ), dict( - 
testcase_name='var_len_with_different_batch_dim_sizes', + testcase_name="var_len_with_different_batch_dim_sizes", feature_spec={ - 'a': tf.io.VarLenFeature(tf.float32), - 'b': tf.io.VarLenFeature(tf.float32), + "a": tf.io.VarLenFeature(tf.float32), + "b": tf.io.VarLenFeature(tf.float32), }, feed_dict={ - 'a': tf.compat.v1.SparseTensorValue( + "a": tf.compat.v1.SparseTensorValue( indices=np.array([(0, 0)]), values=np.array([10.0], np.float32), dense_shape=(1, 20), ), - 'b': tf.compat.v1.SparseTensorValue( + "b": tf.compat.v1.SparseTensorValue( indices=np.array([(0, 0)]), values=np.array([10.0], np.float32), dense_shape=(2, 20), ), }, - error_msg=r'Arrays were not all the same length: \d vs \d', + error_msg=r"Arrays were not all the same length: \d vs \d", ), dict( - testcase_name='fixed_len_with_different_batch_dim_sizes', + testcase_name="fixed_len_with_different_batch_dim_sizes", feature_spec={ - 'a': tf.io.FixedLenFeature([], tf.float32), - 'b': tf.io.FixedLenFeature([], tf.float32), + "a": tf.io.FixedLenFeature([], tf.float32), + "b": tf.io.FixedLenFeature([], tf.float32), }, feed_dict={ - 'a': np.array([1], np.float32), - 'b': np.array([1, 2], np.float32), + "a": np.array([1], np.float32), + "b": np.array([1, 2], np.float32), }, - error_msg=r'Arrays were not all the same length: \d vs \d', + error_msg=r"Arrays were not all the same length: \d vs \d", ), ] _CONVERT_TO_ARROW_ERROR_CASES = [ dict( - testcase_name='var_len_with_rank_not_2', - feature_spec={'a': tf.io.VarLenFeature(tf.float32)}, + testcase_name="var_len_with_rank_not_2", + feature_spec={"a": tf.io.VarLenFeature(tf.float32)}, feed_dict={ - 'a': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0, 1), (0, 0, 2), (0, 0, 3)]), - values=np.array([10.0, 20.0, 30.0], np.float32), - dense_shape=(1, 10, 10)) + "a": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0, 1), (0, 0, 2), (0, 0, 3)]), + values=np.array([10.0, 20.0, 30.0], np.float32), + dense_shape=(1, 10, 10), + ) }, - error_msg=(r'Expected SparseTensorSpec\(TensorShape\(' - r'\[(None|Dimension\(None\)), (None|Dimension\(None\))\]\)'), - error_type=TypeError), + error_msg=( + r"Expected SparseTensorSpec\(TensorShape\(" + r"\[(None|Dimension\(None\)), (None|Dimension\(None\))\]\)" + ), + error_type=TypeError, + ), dict( - testcase_name='var_len_with_out_of_order_indices', - feature_spec={'a': tf.io.VarLenFeature(tf.float32)}, + testcase_name="var_len_with_out_of_order_indices", + feature_spec={"a": tf.io.VarLenFeature(tf.float32)}, feed_dict={ - 'a': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 2), (2, 4), (1, 8)]), - values=np.array([10.0, 20.0, 30.0], np.float32), - dense_shape=(3, 20)) + "a": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 2), (2, 4), (1, 8)]), + values=np.array([10.0, 20.0, 30.0], np.float32), + dense_shape=(3, 20), + ) }, - error_msg='The sparse indices must be sorted', - error_type=AssertionError), + error_msg="The sparse indices must be sorted", + error_type=AssertionError, + ), dict( - testcase_name='var_len_with_different_batch_dim_sizes', + testcase_name="var_len_with_different_batch_dim_sizes", feature_spec={ - 'a': tf.io.VarLenFeature(tf.float32), - 'b': tf.io.VarLenFeature(tf.float32), + "a": tf.io.VarLenFeature(tf.float32), + "b": tf.io.VarLenFeature(tf.float32), }, feed_dict={ - 'a': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0)]), - values=np.array([10.0], np.float32), - dense_shape=(1, 20)), - 'b': - tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0)]), - values=np.array([10.0], 
np.float32), - dense_shape=(2, 20)), + "a": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0)]), + values=np.array([10.0], np.float32), + dense_shape=(1, 20), + ), + "b": tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0)]), + values=np.array([10.0], np.float32), + dense_shape=(2, 20), + ), }, - error_msg='Arrays were not all the same length'), + error_msg="Arrays were not all the same length", + ), dict( - testcase_name='fixed_len_with_different_batch_dim_sizes', + testcase_name="fixed_len_with_different_batch_dim_sizes", feature_spec={ - 'a': tf.io.FixedLenFeature([], tf.float32), - 'b': tf.io.FixedLenFeature([], tf.float32), + "a": tf.io.FixedLenFeature([], tf.float32), + "b": tf.io.FixedLenFeature([], tf.float32), }, feed_dict={ - 'a': np.array([1], dtype=np.float32), - 'b': np.array([1, 2], dtype=np.float32) + "a": np.array([1], dtype=np.float32), + "b": np.array([1, 2], dtype=np.float32), }, - error_msg=('Arrays were not all the same length')), + error_msg=("Arrays were not all the same length"), + ), ] def _ragged_tensor_from_value(value): - if isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): - return tf.RaggedTensor.from_row_splits( - values=_ragged_tensor_from_value(value.values), - row_splits=value.row_splits) - else: - # Recursion base case, value here is a numpy array. - return tf.constant(value) + if isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): + return tf.RaggedTensor.from_row_splits( + values=_ragged_tensor_from_value(value.values), row_splits=value.row_splits + ) + else: + # Recursion base case, value here is a numpy array. + return tf.constant(value) def _eager_tensor_from_values(values): - result = {} - for key, value in values.items(): - if isinstance(value, tf.compat.v1.SparseTensorValue): - result[key] = tf.sparse.SparseTensor.from_value(value) - elif isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): - result[key] = _ragged_tensor_from_value(value) - else: - result[key] = tf.constant(value) - return result + result = {} + for key, value in values.items(): + if isinstance(value, tf.compat.v1.SparseTensorValue): + result[key] = tf.sparse.SparseTensor.from_value(value) + elif isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): + result[key] = _ragged_tensor_from_value(value) + else: + result[key] = tf.constant(value) + return result class ImplHelperTest(test_case.TransformTestCase): + def test_batched_placeholders_from_feature_spec(self): + feature_spec = { + "fixed_len_float": tf.io.FixedLenFeature([2, 3], tf.float32), + "fixed_len_string": tf.io.FixedLenFeature([], tf.string), + "_var_len_underscored": tf.io.VarLenFeature(tf.string), + "var_len_int": tf.io.VarLenFeature(tf.int64), + "sparse_1d": tf.io.SparseFeature("1d_idx", "1d_value", tf.int64, 7), + "sparse_2d": tf.io.SparseFeature( + ["2d_idx0", "2d_idx1"], "2d_value", tf.int64, [2, 17] + ), + } + with tf.compat.v1.Graph().as_default(): + features = impl_helper.batched_placeholders_from_specs(feature_spec) + self.assertCountEqual( + features.keys(), + [ + "fixed_len_float", + "fixed_len_string", + "var_len_int", + "_var_len_underscored", + "sparse_1d", + "sparse_2d", + ], + ) + self.assertIsInstance(features["fixed_len_float"], tf.Tensor) + self.assertEqual( + features["fixed_len_float"].get_shape().as_list(), [None, 2, 3] + ) + self.assertIsInstance(features["fixed_len_string"], tf.Tensor) + self.assertEqual(features["fixed_len_string"].get_shape().as_list(), [None]) + self.assertEqual(type(features["var_len_int"]), tf.SparseTensor) + 
self.assertEqual(features["var_len_int"].get_shape().as_list(), [None, None]) + self.assertEqual(type(features["_var_len_underscored"]), tf.SparseTensor) + self.assertEqual( + features["_var_len_underscored"].get_shape().as_list(), [None, None] + ) + self.assertEqual(type(features["sparse_1d"]), tf.SparseTensor) + self.assertEqual(type(features["sparse_2d"]), tf.SparseTensor) + if version.parse(tf.__version__) >= version.parse("2"): + self.assertEqual(features["sparse_1d"].get_shape().as_list(), [None, 7]) + self.assertEqual(features["sparse_2d"].get_shape().as_list(), [None, 2, 17]) + else: + self.assertEqual(features["sparse_1d"].get_shape().as_list(), [None, None]) + self.assertEqual( + features["sparse_2d"].get_shape().as_list(), [None, None, None] + ) - def test_batched_placeholders_from_feature_spec(self): - feature_spec = { - 'fixed_len_float': - tf.io.FixedLenFeature([2, 3], tf.float32), - 'fixed_len_string': - tf.io.FixedLenFeature([], tf.string), - '_var_len_underscored': - tf.io.VarLenFeature(tf.string), - 'var_len_int': - tf.io.VarLenFeature(tf.int64), - 'sparse_1d': - tf.io.SparseFeature('1d_idx', '1d_value', tf.int64, 7), - 'sparse_2d': - tf.io.SparseFeature(['2d_idx0', '2d_idx1'], '2d_value', tf.int64, - [2, 17]), - } - with tf.compat.v1.Graph().as_default(): - features = impl_helper.batched_placeholders_from_specs(feature_spec) - self.assertCountEqual(features.keys(), [ - 'fixed_len_float', - 'fixed_len_string', - 'var_len_int', - '_var_len_underscored', - 'sparse_1d', - 'sparse_2d', - ]) - self.assertIsInstance(features['fixed_len_float'], tf.Tensor) - self.assertEqual(features['fixed_len_float'].get_shape().as_list(), - [None, 2, 3]) - self.assertIsInstance(features['fixed_len_string'], tf.Tensor) - self.assertEqual(features['fixed_len_string'].get_shape().as_list(), [None]) - self.assertEqual(type(features['var_len_int']), tf.SparseTensor) - self.assertEqual(features['var_len_int'].get_shape().as_list(), - [None, None]) - self.assertEqual(type(features['_var_len_underscored']), tf.SparseTensor) - self.assertEqual(features['_var_len_underscored'].get_shape().as_list(), - [None, None]) - self.assertEqual(type(features['sparse_1d']), tf.SparseTensor) - self.assertEqual(type(features['sparse_2d']), tf.SparseTensor) - if version.parse(tf.__version__) >= version.parse('2'): - self.assertEqual(features['sparse_1d'].get_shape().as_list(), [None, 7]) - self.assertEqual(features['sparse_2d'].get_shape().as_list(), - [None, 2, 17]) - else: - self.assertEqual(features['sparse_1d'].get_shape().as_list(), - [None, None]) - self.assertEqual(features['sparse_2d'].get_shape().as_list(), - [None, None, None]) - - def test_batched_placeholders_from_typespecs(self): - typespecs = { - 'dense_float': - tf.TensorSpec(dtype=tf.float32, shape=[None, 2, 3]), - 'dense_string': - tf.TensorSpec(shape=[None], dtype=tf.string), - '_sparse_underscored': - tf.SparseTensorSpec(dtype=tf.string, shape=[None, None, 17]), - 'ragged_string': - tf.RaggedTensorSpec( - dtype=tf.string, ragged_rank=1, shape=[None, None]), - 'ragged_multi_dimension': - tf.RaggedTensorSpec( + def test_batched_placeholders_from_typespecs(self): + typespecs = { + "dense_float": tf.TensorSpec(dtype=tf.float32, shape=[None, 2, 3]), + "dense_string": tf.TensorSpec(shape=[None], dtype=tf.string), + "_sparse_underscored": tf.SparseTensorSpec( + dtype=tf.string, shape=[None, None, 17] + ), + "ragged_string": tf.RaggedTensorSpec( + dtype=tf.string, ragged_rank=1, shape=[None, None] + ), + "ragged_multi_dimension": tf.RaggedTensorSpec( + 
dtype=tf.int64, ragged_rank=3, shape=[None, None, None, None, 5] + ), + } + with tf.compat.v1.Graph().as_default(): + features = impl_helper.batched_placeholders_from_specs(typespecs) + self.assertCountEqual( + features.keys(), + [ + "dense_float", + "dense_string", + "_sparse_underscored", + "ragged_string", + "ragged_multi_dimension", + ], + ) + self.assertIsInstance(features["dense_float"], tf.Tensor) + self.assertEqual(features["dense_float"].get_shape().as_list(), [None, 2, 3]) + self.assertEqual(features["dense_float"].dtype, tf.float32) + + self.assertIsInstance(features["dense_string"], tf.Tensor) + self.assertEqual(features["dense_string"].get_shape().as_list(), [None]) + self.assertEqual(features["dense_string"].dtype, tf.string) + + self.assertEqual(type(features["_sparse_underscored"]), tf.SparseTensor) + self.assertEqual( + features["_sparse_underscored"].get_shape().as_list(), [None, None, 17] + ) + self.assertEqual(features["_sparse_underscored"].dtype, tf.string) + + self.assertEqual(type(features["ragged_string"]), tf.RaggedTensor) + self.assertEqual(features["ragged_string"].shape.as_list(), [None, None]) + self.assertEqual(features["ragged_string"].ragged_rank, 1) + self.assertEqual(features["ragged_string"].dtype, tf.string) + + self.assertEqual(type(features["ragged_multi_dimension"]), tf.RaggedTensor) + self.assertEqual( + features["ragged_multi_dimension"].shape.as_list(), + [None, None, None, None, 5], + ) + self.assertEqual(features["ragged_multi_dimension"].ragged_rank, 3) + self.assertEqual(features["ragged_multi_dimension"].dtype, tf.int64) + + def test_batched_placeholders_from_specs_invalid_dtype(self): + with self.assertRaisesRegex(ValueError, "had invalid dtype"): + impl_helper.batched_placeholders_from_specs( + {"f": tf.TensorSpec(dtype=tf.int32, shape=[None])} + ) + with self.assertRaisesRegex(ValueError, "had invalid dtype"): + impl_helper.batched_placeholders_from_specs( + {"f": tf.io.FixedLenFeature(dtype=tf.int32, shape=[None])} + ) + + def test_batched_placeholders_from_specs_invalid_mixing(self): + with self.assertRaisesRegex(TypeError, "Specs must be all"): + impl_helper.batched_placeholders_from_specs( + { + "f1": tf.TensorSpec(dtype=tf.int64, shape=[None]), + "f2": tf.io.FixedLenFeature(dtype=tf.int64, shape=[None]), + } + ) + + @test_case.named_parameters( + *test_case.cross_named_parameters( + _ROUNDTRIP_CASES, + [ + dict(testcase_name="eager_tensors", feed_eager_tensors=True), + dict(testcase_name="session_run_values", feed_eager_tensors=False), + ], + ) + ) + def test_to_instance_dicts( + self, feature_spec, instances, record_batch, feed_dict, feed_eager_tensors + ): + del record_batch + if feed_eager_tensors: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + schema = schema_utils.schema_from_feature_spec(feature_spec) + feed_dict_local = ( + _eager_tensor_from_values(feed_dict) + if feed_eager_tensors + else copy.copy(feed_dict) + ) + arrow_converter = impl_helper.make_tensor_to_arrow_converter(schema) + record_batch = arrow_converter.convert(feed_dict_local) + result = impl_helper.record_batch_to_instance_dicts(record_batch, schema) + np.testing.assert_equal(instances, result) + + @test_case.named_parameters(*_TO_INSTANCE_DICT_ERROR_CASES) + def test_to_instance_dicts_error( + self, feature_spec, feed_dict, error_msg, error_type=ValueError + ): + schema = schema_utils.schema_from_feature_spec(feature_spec) + arrow_converter = impl_helper.make_tensor_to_arrow_converter(schema) + with self.assertRaisesRegex(error_type, error_msg): + 
record_batch = arrow_converter.convert(feed_dict) + _ = impl_helper.record_batch_to_instance_dicts(record_batch, schema) + + @test_case.named_parameters( + *test_case.cross_named_parameters( + _ROUNDTRIP_CASES, + [ + dict(testcase_name="eager_tensors", feed_eager_tensors=True), + dict(testcase_name="session_run_values", feed_eager_tensors=False), + ], + ) + ) + def test_convert_to_arrow( + self, feature_spec, instances, record_batch, feed_dict, feed_eager_tensors + ): + del instances + if feed_eager_tensors: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + schema = schema_utils.schema_from_feature_spec(feature_spec) + converter = impl_helper.make_tensor_to_arrow_converter(schema) + feed_dict_local = ( + _eager_tensor_from_values(feed_dict) + if feed_eager_tensors + else copy.copy(feed_dict) + ) + actual = converter.convert(feed_dict_local) + expected = pa.RecordBatch.from_arrays( + list(record_batch.values()), names=list(record_batch.keys()) + ) + np.testing.assert_equal(actual.to_pydict(), expected.to_pydict()) + + @test_case.named_parameters(*_CONVERT_TO_ARROW_ERROR_CASES) + def test_convert_to_arrow_error( + self, feature_spec, feed_dict, error_msg, error_type=ValueError + ): + schema = schema_utils.schema_from_feature_spec(feature_spec) + converter = impl_helper.make_tensor_to_arrow_converter(schema) + with self.assertRaisesRegex(error_type, error_msg): + converter.convert(feed_dict) + + @test_case.named_parameters( + dict(testcase_name="tf_compat_v1", force_tf_compat_v1=True), + dict(testcase_name="native_tf2", force_tf_compat_v1=False), + ) + def test_analyze_in_place(self, force_tf_compat_v1): + if not force_tf_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + + def preprocessing_fn(inputs): + return {"x_add_1": inputs["x"] + 1} + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.int64)} + type_spec = { + "x": tf.TensorSpec( dtype=tf.int64, - ragged_rank=3, - shape=[None, None, None, None, 5]), - } - with tf.compat.v1.Graph().as_default(): - features = impl_helper.batched_placeholders_from_specs(typespecs) - self.assertCountEqual(features.keys(), [ - 'dense_float', - 'dense_string', - '_sparse_underscored', - 'ragged_string', - 'ragged_multi_dimension', - ]) - self.assertIsInstance(features['dense_float'], tf.Tensor) - self.assertEqual(features['dense_float'].get_shape().as_list(), - [None, 2, 3]) - self.assertEqual(features['dense_float'].dtype, tf.float32) - - self.assertIsInstance(features['dense_string'], tf.Tensor) - self.assertEqual(features['dense_string'].get_shape().as_list(), [None]) - self.assertEqual(features['dense_string'].dtype, tf.string) - - self.assertEqual(type(features['_sparse_underscored']), tf.SparseTensor) - self.assertEqual( - features['_sparse_underscored'].get_shape().as_list(), [None, None, 17] + shape=[ + None, + ], + ) + } + output_path = os.path.join(self.get_temp_dir(), self._testMethodName) + impl_helper.analyze_in_place( + preprocessing_fn, force_tf_compat_v1, feature_spec, type_spec, output_path + ) + + tft_output = TFTransformOutput(output_path) + expected_value = np.array([2], dtype=np.int64) + if force_tf_compat_v1: + with tf.Graph().as_default() as graph: + with tf.compat.v1.Session(graph=graph).as_default(): + transformed_features = tft_output.transform_raw_features( + {"x": tf.constant([1], dtype=tf.int64)} + ) + transformed_value = transformed_features["x_add_1"].eval() + else: + transformed_features = tft_output.transform_raw_features( + {"x": tf.constant([1], dtype=tf.int64)} + ) + transformed_value = 
transformed_features["x_add_1"].numpy() + self.assertEqual(transformed_value, expected_value) + + transformed_feature_spec = tft_output.transformed_feature_spec() + expected_feature_spec = feature_spec = { + "x_add_1": tf.io.FixedLenFeature([], tf.int64) + } + self.assertEqual(transformed_feature_spec, expected_feature_spec) + + @test_case.named_parameters( + dict(testcase_name="tf_compat_v1", force_tf_compat_v1=True), + dict(testcase_name="native_tf2", force_tf_compat_v1=False), ) - self.assertEqual(features['_sparse_underscored'].dtype, tf.string) - - self.assertEqual(type(features['ragged_string']), tf.RaggedTensor) - self.assertEqual(features['ragged_string'].shape.as_list(), [None, None]) - self.assertEqual(features['ragged_string'].ragged_rank, 1) - self.assertEqual(features['ragged_string'].dtype, tf.string) - - self.assertEqual(type(features['ragged_multi_dimension']), tf.RaggedTensor) - self.assertEqual(features['ragged_multi_dimension'].shape.as_list(), - [None, None, None, None, 5]) - self.assertEqual(features['ragged_multi_dimension'].ragged_rank, 3) - self.assertEqual(features['ragged_multi_dimension'].dtype, tf.int64) - - def test_batched_placeholders_from_specs_invalid_dtype(self): - with self.assertRaisesRegex(ValueError, 'had invalid dtype'): - impl_helper.batched_placeholders_from_specs( - {'f': tf.TensorSpec(dtype=tf.int32, shape=[None])}) - with self.assertRaisesRegex(ValueError, 'had invalid dtype'): - impl_helper.batched_placeholders_from_specs( - {'f': tf.io.FixedLenFeature(dtype=tf.int32, shape=[None])}) - - def test_batched_placeholders_from_specs_invalid_mixing(self): - with self.assertRaisesRegex(TypeError, 'Specs must be all'): - impl_helper.batched_placeholders_from_specs({ - 'f1': tf.TensorSpec(dtype=tf.int64, shape=[None]), - 'f2': tf.io.FixedLenFeature(dtype=tf.int64, shape=[None]), - }) - - @test_case.named_parameters(*test_case.cross_named_parameters( - _ROUNDTRIP_CASES, [ - dict(testcase_name='eager_tensors', feed_eager_tensors=True), - dict(testcase_name='session_run_values', feed_eager_tensors=False) - ])) - def test_to_instance_dicts(self, feature_spec, instances, record_batch, - feed_dict, feed_eager_tensors): - del record_batch - if feed_eager_tensors: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - schema = schema_utils.schema_from_feature_spec(feature_spec) - feed_dict_local = ( - _eager_tensor_from_values(feed_dict) - if feed_eager_tensors else copy.copy(feed_dict)) - arrow_converter = impl_helper.make_tensor_to_arrow_converter(schema) - record_batch = arrow_converter.convert(feed_dict_local) - result = impl_helper.record_batch_to_instance_dicts(record_batch, schema) - np.testing.assert_equal(instances, result) - - @test_case.named_parameters(*_TO_INSTANCE_DICT_ERROR_CASES) - def test_to_instance_dicts_error(self, - feature_spec, - feed_dict, - error_msg, - error_type=ValueError): - schema = schema_utils.schema_from_feature_spec(feature_spec) - arrow_converter = impl_helper.make_tensor_to_arrow_converter(schema) - with self.assertRaisesRegex(error_type, error_msg): - record_batch = arrow_converter.convert(feed_dict) - _ = impl_helper.record_batch_to_instance_dicts(record_batch, schema) - - @test_case.named_parameters(*test_case.cross_named_parameters( - _ROUNDTRIP_CASES, [ - dict(testcase_name='eager_tensors', feed_eager_tensors=True), - dict(testcase_name='session_run_values', feed_eager_tensors=False) - ])) - def test_convert_to_arrow(self, feature_spec, instances, record_batch, - feed_dict, feed_eager_tensors): - del instances - if 
feed_eager_tensors: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - schema = schema_utils.schema_from_feature_spec(feature_spec) - converter = impl_helper.make_tensor_to_arrow_converter(schema) - feed_dict_local = ( - _eager_tensor_from_values(feed_dict) - if feed_eager_tensors else copy.copy(feed_dict)) - actual = converter.convert(feed_dict_local) - expected = pa.RecordBatch.from_arrays( - list(record_batch.values()), names=list(record_batch.keys())) - np.testing.assert_equal(actual.to_pydict(), expected.to_pydict()) - - @test_case.named_parameters(*_CONVERT_TO_ARROW_ERROR_CASES) - def test_convert_to_arrow_error(self, - feature_spec, - feed_dict, - error_msg, - error_type=ValueError): - schema = schema_utils.schema_from_feature_spec(feature_spec) - converter = impl_helper.make_tensor_to_arrow_converter(schema) - with self.assertRaisesRegex(error_type, error_msg): - converter.convert(feed_dict) - - @test_case.named_parameters( - dict(testcase_name='tf_compat_v1', force_tf_compat_v1=True), - dict(testcase_name='native_tf2', force_tf_compat_v1=False)) - def test_analyze_in_place(self, force_tf_compat_v1): - if not force_tf_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - - def preprocessing_fn(inputs): - return {'x_add_1': inputs['x'] + 1} - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)} - type_spec = { - 'x': tf.TensorSpec(dtype=tf.int64, shape=[ - None, - ]) - } - output_path = os.path.join(self.get_temp_dir(), self._testMethodName) - impl_helper.analyze_in_place(preprocessing_fn, force_tf_compat_v1, - feature_spec, type_spec, output_path) - - tft_output = TFTransformOutput(output_path) - expected_value = np.array([2], dtype=np.int64) - if force_tf_compat_v1: - with tf.Graph().as_default() as graph: - with tf.compat.v1.Session(graph=graph).as_default(): - transformed_features = tft_output.transform_raw_features( - {'x': tf.constant([1], dtype=tf.int64)}) - transformed_value = transformed_features['x_add_1'].eval() - else: - transformed_features = tft_output.transform_raw_features( - {'x': tf.constant([1], dtype=tf.int64)}) - transformed_value = transformed_features['x_add_1'].numpy() - self.assertEqual(transformed_value, expected_value) - - transformed_feature_spec = tft_output.transformed_feature_spec() - expected_feature_spec = feature_spec = { - 'x_add_1': tf.io.FixedLenFeature([], tf.int64) - } - self.assertEqual(transformed_feature_spec, expected_feature_spec) - - @test_case.named_parameters( - dict(testcase_name='tf_compat_v1', force_tf_compat_v1=True), - dict(testcase_name='native_tf2', force_tf_compat_v1=False)) - def test_analyze_in_place_with_analyzers_raises_error(self, - force_tf_compat_v1): - if not force_tf_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - - def preprocessing_fn(inputs): - return {'x_add_1': analyzers.mean(inputs['x'])} - - feature_spec = {'x': tf.io.FixedLenFeature([], tf.int64)} - type_spec = { - 'x': tf.TensorSpec(dtype=tf.int64, shape=[ - None, - ]) - } - output_path = os.path.join(self.get_temp_dir(), self._testMethodName) - with self.assertRaisesRegex(RuntimeError, 'analyzers found when tracing'): - impl_helper.analyze_in_place(preprocessing_fn, force_tf_compat_v1, - feature_spec, type_spec, output_path) - - @test_case.named_parameters( - dict( - testcase_name='valid', - value=tf.compat.v1.SparseTensorValue( - indices=np.array( - [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)] - ), - values=np.array( - [b'doe', b'a', b'deer', b'a', b'female', b'deer'], - dtype=object, - ), - dense_shape=(2, 
3), - ), - ), - dict( - testcase_name='empty', - value=tf.compat.v1.SparseTensorValue( - indices=np.empty((0, 2)), - values=np.empty((0)), - dense_shape=(2, 3), - ), - ), - dict( - testcase_name='3d', - value=tf.compat.v1.SparseTensorValue( - indices=np.empty((0, 3)), - values=np.empty((0)), - dense_shape=(2, 3), - ), - error='Encountered non 2-D varlen sparse', - ), - dict( - testcase_name='instance_index_not_sorted', - value=tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (1, 1), (0, 2)]), - values=np.array([b'a', b'female', b'deer'], dtype=object), - dense_shape=(2, 3), - ), - error='Encountered decreasing instance indices', - ), - dict( - testcase_name='value_index_not_sorted', - value=tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (1, 1), (1, 0)]), - values=np.array([b'a', b'female', b'deer'], dtype=object), - dense_shape=(2, 3), - ), - error='Encountered non-consecutive value indices', - ), - dict( - testcase_name='instance_start_index_nonzero', - value=tf.compat.v1.SparseTensorValue( - indices=np.array([(0, 0), (1, 1), (1, 2)]), - values=np.array([b'a', b'female', b'deer'], dtype=object), - dense_shape=(2, 3), - ), - error='Encountered non-zero starting value indices', - ), - ) - def test_validate_varlen_sparse_value(self, value, error=None): - if error is None: - self.assertIsNone( - impl_helper.validate_varlen_sparse_value('varlen', value) - ) - else: - with self.assertRaisesRegex(ValueError, error): - impl_helper.validate_varlen_sparse_value('varlen', value) + def test_analyze_in_place_with_analyzers_raises_error(self, force_tf_compat_v1): + if not force_tf_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + + def preprocessing_fn(inputs): + return {"x_add_1": analyzers.mean(inputs["x"])} + + feature_spec = {"x": tf.io.FixedLenFeature([], tf.int64)} + type_spec = { + "x": tf.TensorSpec( + dtype=tf.int64, + shape=[ + None, + ], + ) + } + output_path = os.path.join(self.get_temp_dir(), self._testMethodName) + with self.assertRaisesRegex(RuntimeError, "analyzers found when tracing"): + impl_helper.analyze_in_place( + preprocessing_fn, + force_tf_compat_v1, + feature_spec, + type_spec, + output_path, + ) + + @test_case.named_parameters( + dict( + testcase_name="valid", + value=tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]), + values=np.array( + [b"doe", b"a", b"deer", b"a", b"female", b"deer"], + dtype=object, + ), + dense_shape=(2, 3), + ), + ), + dict( + testcase_name="empty", + value=tf.compat.v1.SparseTensorValue( + indices=np.empty((0, 2)), + values=np.empty(0), + dense_shape=(2, 3), + ), + ), + dict( + testcase_name="3d", + value=tf.compat.v1.SparseTensorValue( + indices=np.empty((0, 3)), + values=np.empty(0), + dense_shape=(2, 3), + ), + error="Encountered non 2-D varlen sparse", + ), + dict( + testcase_name="instance_index_not_sorted", + value=tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (1, 1), (0, 2)]), + values=np.array([b"a", b"female", b"deer"], dtype=object), + dense_shape=(2, 3), + ), + error="Encountered decreasing instance indices", + ), + dict( + testcase_name="value_index_not_sorted", + value=tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (1, 1), (1, 0)]), + values=np.array([b"a", b"female", b"deer"], dtype=object), + dense_shape=(2, 3), + ), + error="Encountered non-consecutive value indices", + ), + dict( + testcase_name="instance_start_index_nonzero", + value=tf.compat.v1.SparseTensorValue( + indices=np.array([(0, 0), (1, 1), (1, 2)]), + 
values=np.array([b"a", b"female", b"deer"], dtype=object), + dense_shape=(2, 3), + ), + error="Encountered non-zero starting value indices", + ), + ) + def test_validate_varlen_sparse_value(self, value, error=None): + if error is None: + self.assertIsNone(impl_helper.validate_varlen_sparse_value("varlen", value)) + else: + with self.assertRaisesRegex(ValueError, error): + impl_helper.validate_varlen_sparse_value("varlen", value) def _subtract_ten_with_tf_while(x): - """Subtracts 10 from x using control flow ops. + """Subtracts 10 from x using control flow ops. - This function is equivalent to "x - 10" but uses a tf.while_loop, in order - to test the use of functions that involve control flow ops. + This function is equivalent to "x - 10" but uses a tf.while_loop, in order + to test the use of functions that involve control flow ops. - Args: - x: A tensor of integral type. + Args: + ---- + x: A tensor of integral type. - Returns: - A tensor representing x - 10. - """ + Returns: + ------- + A tensor representing x - 10. + """ - def stop_condition(counter, x_minus_counter): - del x_minus_counter # unused - return tf.less(counter, 10) + def stop_condition(counter, x_minus_counter): + del x_minus_counter # unused + return tf.less(counter, 10) - def iteration(counter, x_minus_counter): - return tf.add(counter, 1), tf.add(x_minus_counter, -1) + def iteration(counter, x_minus_counter): + return tf.add(counter, 1), tf.add(x_minus_counter, -1) - initial_values = [tf.constant(0), x] - return tf.while_loop( - cond=stop_condition, body=iteration, loop_vars=initial_values)[1] + initial_values = [tf.constant(0), x] + return tf.while_loop(cond=stop_condition, body=iteration, loop_vars=initial_values)[ + 1 + ] -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/info_theory.py b/tensorflow_transform/info_theory.py index f536d10..eed3024 100644 --- a/tensorflow_transform/info_theory.py +++ b/tensorflow_transform/info_theory.py @@ -16,111 +16,126 @@ import math # math.log2 was added in Python 3.3 -log2 = getattr(math, 'log2', lambda x: math.log(x, 2)) +log2 = getattr(math, "log2", lambda x: math.log(x, 2)) # TODO(b/157302701): Evaluate optimizations or approximations for this function, # in particular the _hypergeometric_pmf. def calculate_partial_expected_mutual_information(n, x_i, y_j): - """Calculates the partial expected mutual information (EMI) of two variables. - - EMI reflects the MI expected by chance, and is used to compute adjusted - mutual information. See www.wikipedia.org/wiki/Adjusted_mutual_information. - - The EMI for two variables x and y, is the sum of the expected mutual info - for each value of x with each value of y. This function computes the EMI - for a single value of each variable (x_i, y_j) and is thus considered a - partial EMI calculation. - - Specifically: - EMI(x, y) = sum_{n_ij = max(0, x_i + y_j - n) to min(x_i, y_j)} ( - n_ij / n * log2((n * n_ij / (x_i * y_j)) - * ((x_i! * y_j! * (n - x_i)! * (n - y_j)!) / - (n! * n_ij! * (x_i - n_ij)! * (y_j - n_ij)! * (n - x_i - y_j + n_ij)!))) - where n_ij is the joint count of x taking on value i and y taking on - value j, x_i is the count for x taking on value i, y_j is the count for y - taking on value j, and n represents total count. - - Args: - n: The sum of weights for all values. 
- x_i: The sum of weights for the first variable taking on value i - y_j: The sum of weights for the second variable taking on value j - - Returns: - Calculated expected mutual information for x_i, y_j. - """ - if x_i == 0 or y_j == 0: - return 0 - coefficient = (-log2(x_i) - log2(y_j) + log2(n)) - sum_probability = 0.0 - partial_result = 0.0 - for n_j, p_j in _hypergeometric_pmf(n, x_i, y_j): - if n_j != 0: - partial_result += n_j * (coefficient + log2(n_j)) * p_j - sum_probability += p_j - # The values of p_j should sum to 1, but given approximate calculations for - # log2(x) and exp2(x) with large x, the full pmf might not sum to exactly 1. - # We correct for this by dividing by the sum of the probabilities. - return partial_result / sum_probability + """Calculates the partial expected mutual information (EMI) of two variables. + + EMI reflects the MI expected by chance, and is used to compute adjusted + mutual information. See www.wikipedia.org/wiki/Adjusted_mutual_information. + + The EMI for two variables x and y, is the sum of the expected mutual info + for each value of x with each value of y. This function computes the EMI + for a single value of each variable (x_i, y_j) and is thus considered a + partial EMI calculation. + + Specifically: + EMI(x, y) = sum_{n_ij = max(0, x_i + y_j - n) to min(x_i, y_j)} ( + n_ij / n * log2((n * n_ij / (x_i * y_j)) + * ((x_i! * y_j! * (n - x_i)! * (n - y_j)!) / + (n! * n_ij! * (x_i - n_ij)! * (y_j - n_ij)! * (n - x_i - y_j + n_ij)!))) + where n_ij is the joint count of x taking on value i and y taking on + value j, x_i is the count for x taking on value i, y_j is the count for y + taking on value j, and n represents total count. + + Args: + ---- + n: The sum of weights for all values. + x_i: The sum of weights for the first variable taking on value i + y_j: The sum of weights for the second variable taking on value j + + Returns: + ------- + Calculated expected mutual information for x_i, y_j. + """ + if x_i == 0 or y_j == 0: + return 0 + coefficient = -log2(x_i) - log2(y_j) + log2(n) + sum_probability = 0.0 + partial_result = 0.0 + for n_j, p_j in _hypergeometric_pmf(n, x_i, y_j): + if n_j != 0: + partial_result += n_j * (coefficient + log2(n_j)) * p_j + sum_probability += p_j + # The values of p_j should sum to 1, but given approximate calculations for + # log2(x) and exp2(x) with large x, the full pmf might not sum to exactly 1. + # We correct for this by dividing by the sum of the probabilities. + return partial_result / sum_probability def calculate_partial_mutual_information(n_ij, x_i, y_j, n): - """Calculates Mutual Information for x=i, y=j from sample counts. - - The standard formulation of mutual information is: - MI(X,Y) = Sum_i,j {p_ij * log2(p_ij / p_i * p_j)} - We are operating over counts (p_ij = n_ij / n), so this is transformed into - MI(X,Y) = Sum_i,j {n_ij * (log2(n_ij) + log2(n) - log2(x_i) - log2(y_j))} / n - This function returns the argument to the summation, the mutual information - for a particular pair of values x_i, y_j (the caller is expected to divide - the summation by n to compute the final mutual information result). - - Args: - n_ij: The co-occurrence of x=i and y=j - x_i: The frequency of x=i. - y_j: The frequency of y=j. - n: The total # observations - - Returns: - Mutual information for the cell x=i, y=j. - """ - if n_ij == 0 or x_i == 0 or y_j == 0: - return 0 - return n_ij * ((log2(n_ij) + log2(n)) - - (log2(x_i) + log2(y_j))) + """Calculates Mutual Information for x=i, y=j from sample counts. 
+ + The standard formulation of mutual information is: + MI(X,Y) = Sum_i,j {p_ij * log2(p_ij / p_i * p_j)} + We are operating over counts (p_ij = n_ij / n), so this is transformed into + MI(X,Y) = Sum_i,j {n_ij * (log2(n_ij) + log2(n) - log2(x_i) - log2(y_j))} / n + This function returns the argument to the summation, the mutual information + for a particular pair of values x_i, y_j (the caller is expected to divide + the summation by n to compute the final mutual information result). + + Args: + ---- + n_ij: The co-occurrence of x=i and y=j + x_i: The frequency of x=i. + y_j: The frequency of y=j. + n: The total # observations + + Returns: + ------- + Mutual information for the cell x=i, y=j. + """ + if n_ij == 0 or x_i == 0 or y_j == 0: + return 0 + return n_ij * ((log2(n_ij) + log2(n)) - (log2(x_i) + log2(y_j))) def _hypergeometric_pmf(n, x_i, y_j): - """Probablity for expectation computation under hypergeometric distribution. - - Args: - n: The sum of weights for all values. - x_i: The sum of weights for the first variable taking on value i - y_j: The sum of weights for the second variable taking on value j - - Yields: - The probability p_j at point n_j in the hypergeometric distribution. - """ - start = int(round(max(0, x_i + y_j - n))) - end = int(round(min(x_i, y_j))) - # Use log factorial to preserve calculation precision. - # Note: because the factorials are expensive to compute, we compute the - # denominator incrementally, at the cost of some readability. - numerator = ( - _logfactorial(x_i) + _logfactorial(y_j) + _logfactorial(n - x_i) + - _logfactorial(n - y_j)) - denominator = ( - _logfactorial(n) + _logfactorial(start) + _logfactorial(x_i - start) + - _logfactorial(y_j - start) + _logfactorial(n - x_i - y_j + start)) - for n_j in range(start, end + 1): - p_j = math.exp(numerator - denominator) - if n_j != end: - denominator += ( - math.log(n_j + 1) - math.log(x_i - n_j) - math.log(y_j - n_j) + - math.log(n - x_i - y_j + n_j + 1)) - yield n_j, p_j + """Probablity for expectation computation under hypergeometric distribution. + + Args: + ---- + n: The sum of weights for all values. + x_i: The sum of weights for the first variable taking on value i + y_j: The sum of weights for the second variable taking on value j + + Yields: + ------ + The probability p_j at point n_j in the hypergeometric distribution. + """ + start = int(round(max(0, x_i + y_j - n))) + end = int(round(min(x_i, y_j))) + # Use log factorial to preserve calculation precision. + # Note: because the factorials are expensive to compute, we compute the + # denominator incrementally, at the cost of some readability. 
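# ---- Editor's illustrative sketch (not part of this patch) ----
# A direct, non-incremental evaluation of the same hypergeometric log-pmf that
# _hypergeometric_pmf computes with an incrementally updated denominator.
# Useful as a cross-check of that optimization; only the standard library is used.
import math


def _hypergeometric_pmf_direct(n, x_i, y_j):
    def lf(k):  # natural log of k!
        return math.lgamma(k + 1)

    start = int(round(max(0, x_i + y_j - n)))
    end = int(round(min(x_i, y_j)))
    for n_j in range(start, end + 1):
        log_p = (
            lf(x_i) + lf(y_j) + lf(n - x_i) + lf(n - y_j)
            - lf(n) - lf(n_j) - lf(x_i - n_j) - lf(y_j - n_j)
            - lf(n - x_i - y_j + n_j)
        )
        yield n_j, math.exp(log_p)


# As the tests below assert for the library version, the pmf sums to ~1.
assert abs(sum(p for _, p in _hypergeometric_pmf_direct(1000, 5, 10)) - 1.0) < 1e-6
# ----------------------------------------------------------------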
+ numerator = ( + _logfactorial(x_i) + + _logfactorial(y_j) + + _logfactorial(n - x_i) + + _logfactorial(n - y_j) + ) + denominator = ( + _logfactorial(n) + + _logfactorial(start) + + _logfactorial(x_i - start) + + _logfactorial(y_j - start) + + _logfactorial(n - x_i - y_j + start) + ) + for n_j in range(start, end + 1): + p_j = math.exp(numerator - denominator) + if n_j != end: + denominator += ( + math.log(n_j + 1) + - math.log(x_i - n_j) + - math.log(y_j - n_j) + + math.log(n - x_i - y_j + n_j + 1) + ) + yield n_j, p_j def _logfactorial(n): - """Calculate natural logarithm of n!.""" - return math.lgamma(n + 1) + """Calculate natural logarithm of n!.""" + return math.lgamma(n + 1) diff --git a/tensorflow_transform/info_theory_test.py b/tensorflow_transform/info_theory_test.py index b8136b4..173f5dd 100644 --- a/tensorflow_transform/info_theory_test.py +++ b/tensorflow_transform/info_theory_test.py @@ -13,158 +13,167 @@ # limitations under the License. """Tests for tensorflow_transform.info_theory.""" -from tensorflow_transform import info_theory -from tensorflow_transform import test_case - - import unittest +from tensorflow_transform import info_theory, test_case EPSILON = 1e-4 def _make_hypergeometric_pmf_sum_up_to_one_parameters(): - start = 1000 - end = 10000 - range_length = end - start - num_chunks = 15 - assert range_length % num_chunks == 0 - chunk_size = int(range_length / num_chunks) - sub_ranges = [(x, x + chunk_size) for x in range(start, end, chunk_size)] - return [ # pylint: disable=g-complex-comprehension - dict( - testcase_name='{}_to_{}'.format(a, b), - test_range=range(a, b), - n=end, - y_j=start) for a, b in sub_ranges - ] + start = 1000 + end = 10000 + range_length = end - start + num_chunks = 15 + assert range_length % num_chunks == 0 + chunk_size = int(range_length / num_chunks) + sub_ranges = [(x, x + chunk_size) for x in range(start, end, chunk_size)] + return [ # pylint: disable=g-complex-comprehension + dict(testcase_name=f"{a}_to_{b}", test_range=range(a, b), n=end, y_j=start) + for a, b in sub_ranges + ] class InfoTheoryTest(test_case.TransformTestCase): - - def testHypergeometricPmf(self): - expected_results = [(0, 0.75), (1, 0.25)] - results = list(info_theory._hypergeometric_pmf(4, 1, 1)) - for expected_result, result in zip(expected_results, results): - self.assertEqual(expected_result[0], result[0]) - self.assertNear(expected_result[1], result[1], EPSILON) - - def testHypergeometricPmf_LargeN(self): - expected_results = [(0, 0.9508937), (1, 0.0482198), (2, 0.0008794), - (3, 7.1e-06), (4, 2.5e-08), (5, 0.0)] - results = list(info_theory._hypergeometric_pmf(1000, 5, 10)) - for expected_result, result in zip(expected_results, results): - self.assertEqual(expected_result[0], result[0]) - self.assertNear(expected_result[1], result[1], EPSILON) - - @test_case.named_parameters( - *_make_hypergeometric_pmf_sum_up_to_one_parameters()) - def test_hypergeometric_pmf_sum_up_to_one(self, test_range, n, y_j): - for x in test_range: - probs = [prob for _, prob in info_theory._hypergeometric_pmf(n, x, y_j)] - sum_prob = sum(probs) - self.assertNear(sum_prob, 1.0, EPSILON) - - @test_case.named_parameters( - dict( - testcase_name='all_co_occur', - n=10, - x_i=10, - y_j=10, - expected=0, - ), - dict( - testcase_name='2_co_occur_no_observations', - n=10, - x_i=0, - y_j=0, - expected=0, - ), - dict( - testcase_name='2_values_appear_half_the_time', - n=10, - x_i=5, - y_j=5, - expected=0.215411, - ), - dict( - testcase_name='2_values_differing_frequencies', - n=10, - x_i=2, - 
y_j=4, - expected=0.524209, - ), - ) - def test_calculate_partial_expected_mutual_information( - self, n, x_i, y_j, expected): - self.assertNear( - info_theory.calculate_partial_expected_mutual_information(n, x_i, y_j), - expected, EPSILON) - - @test_case.named_parameters( - dict( - testcase_name='strongly_positive_mi', - cell_count=2, - row_count=10, - col_count=2, - total_count=14, - expected_mi=0.970854), - dict( - testcase_name='weakly_positive_mi', - cell_count=4, - row_count=15, - col_count=6, - total_count=25, - expected_mi=0.608012), - dict( - testcase_name='strongly_negative_mi', - cell_count=2, - row_count=10, - col_count=6, - total_count=25, - expected_mi=-0.526069), - dict( - testcase_name='weakly_negative_mi', - cell_count=3, - row_count=31, - col_count=4, - total_count=41, - expected_mi=-0.0350454), - dict( - testcase_name='zero_mi', - cell_count=4, - row_count=8, - col_count=8, - total_count=16, - expected_mi=0), - dict( - testcase_name='invalid_input_zero_cell_count', - cell_count=4, - row_count=0, - col_count=8, - total_count=8, - expected_mi=0), - dict( - testcase_name='invalid_input_zero_row_count', - cell_count=4, - row_count=0, - col_count=8, - total_count=8, - expected_mi=0), - dict( - testcase_name='invalid_input_zero_col_count', - cell_count=4, - row_count=8, - col_count=0, - total_count=8, - expected_mi=0), - ) - def test_mutual_information(self, cell_count, row_count, col_count, - total_count, expected_mi): - per_cell_mi = info_theory.calculate_partial_mutual_information( - cell_count, row_count, col_count, total_count) - self.assertNear(per_cell_mi, expected_mi, EPSILON) - - -if __name__ == '__main__': - unittest.main() + def testHypergeometricPmf(self): + expected_results = [(0, 0.75), (1, 0.25)] + results = list(info_theory._hypergeometric_pmf(4, 1, 1)) + for expected_result, result in zip(expected_results, results): + self.assertEqual(expected_result[0], result[0]) + self.assertNear(expected_result[1], result[1], EPSILON) + + def testHypergeometricPmf_LargeN(self): + expected_results = [ + (0, 0.9508937), + (1, 0.0482198), + (2, 0.0008794), + (3, 7.1e-06), + (4, 2.5e-08), + (5, 0.0), + ] + results = list(info_theory._hypergeometric_pmf(1000, 5, 10)) + for expected_result, result in zip(expected_results, results): + self.assertEqual(expected_result[0], result[0]) + self.assertNear(expected_result[1], result[1], EPSILON) + + @test_case.named_parameters(*_make_hypergeometric_pmf_sum_up_to_one_parameters()) + def test_hypergeometric_pmf_sum_up_to_one(self, test_range, n, y_j): + for x in test_range: + probs = [prob for _, prob in info_theory._hypergeometric_pmf(n, x, y_j)] + sum_prob = sum(probs) + self.assertNear(sum_prob, 1.0, EPSILON) + + @test_case.named_parameters( + dict( + testcase_name="all_co_occur", + n=10, + x_i=10, + y_j=10, + expected=0, + ), + dict( + testcase_name="2_co_occur_no_observations", + n=10, + x_i=0, + y_j=0, + expected=0, + ), + dict( + testcase_name="2_values_appear_half_the_time", + n=10, + x_i=5, + y_j=5, + expected=0.215411, + ), + dict( + testcase_name="2_values_differing_frequencies", + n=10, + x_i=2, + y_j=4, + expected=0.524209, + ), + ) + def test_calculate_partial_expected_mutual_information(self, n, x_i, y_j, expected): + self.assertNear( + info_theory.calculate_partial_expected_mutual_information(n, x_i, y_j), + expected, + EPSILON, + ) + + @test_case.named_parameters( + dict( + testcase_name="strongly_positive_mi", + cell_count=2, + row_count=10, + col_count=2, + total_count=14, + expected_mi=0.970854, + ), + dict( + 
testcase_name="weakly_positive_mi", + cell_count=4, + row_count=15, + col_count=6, + total_count=25, + expected_mi=0.608012, + ), + dict( + testcase_name="strongly_negative_mi", + cell_count=2, + row_count=10, + col_count=6, + total_count=25, + expected_mi=-0.526069, + ), + dict( + testcase_name="weakly_negative_mi", + cell_count=3, + row_count=31, + col_count=4, + total_count=41, + expected_mi=-0.0350454, + ), + dict( + testcase_name="zero_mi", + cell_count=4, + row_count=8, + col_count=8, + total_count=16, + expected_mi=0, + ), + dict( + testcase_name="invalid_input_zero_cell_count", + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0, + ), + dict( + testcase_name="invalid_input_zero_row_count", + cell_count=4, + row_count=0, + col_count=8, + total_count=8, + expected_mi=0, + ), + dict( + testcase_name="invalid_input_zero_col_count", + cell_count=4, + row_count=8, + col_count=0, + total_count=8, + expected_mi=0, + ), + ) + def test_mutual_information( + self, cell_count, row_count, col_count, total_count, expected_mi + ): + per_cell_mi = info_theory.calculate_partial_mutual_information( + cell_count, row_count, col_count, total_count + ) + self.assertNear(per_cell_mi, expected_mi, EPSILON) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/inspect_preprocessing_fn.py b/tensorflow_transform/inspect_preprocessing_fn.py index b92826a..5d0aad8 100644 --- a/tensorflow_transform/inspect_preprocessing_fn.py +++ b/tensorflow_transform/inspect_preprocessing_fn.py @@ -17,85 +17,105 @@ from typing import Callable, List, Mapping, Union import tensorflow as tf -from tensorflow_transform import analyzer_nodes -from tensorflow_transform import common_types -from tensorflow_transform import graph_tools -from tensorflow_transform import impl_helper -from tensorflow_transform import nodes -from tensorflow_transform import tf2_utils + +from tensorflow_transform import ( + analyzer_nodes, + common_types, + graph_tools, + impl_helper, + nodes, + tf2_utils, +) def get_analyze_input_columns( - preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], - Mapping[str, common_types.TensorType]], + preprocessing_fn: Callable[ + [Mapping[str, common_types.TensorType]], Mapping[str, common_types.TensorType] + ], specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]], - force_tf_compat_v1: bool = False) -> List[str]: - """Return columns that are required inputs of `AnalyzeDataset`. + force_tf_compat_v1: bool = False, +) -> List[str]: + """Return columns that are required inputs of `AnalyzeDataset`. - Args: - preprocessing_fn: A tf.transform preprocessing_fn. - specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is - True, this can also be feature specifications. - force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode. - Defaults to `False`. + Args: + ---- + preprocessing_fn: A tf.transform preprocessing_fn. + specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is + True, this can also be feature specifications. + force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode. + Defaults to `False`. - Returns: - A list of columns that are required inputs of analyzers. 
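# ---- Editor's usage sketch (not part of this patch) ----
# Typical use of the two inspection helpers in this module. The preprocessing_fn
# and feature names are made up for illustration: only the column feeding an
# analyzer (the z-score statistics over "x") is an analyze input, while the
# transform inputs cover every column the output graph reads.
import tensorflow as tf
import tensorflow_transform as tft
from tensorflow_transform import inspect_preprocessing_fn


def example_preprocessing_fn(inputs):
    return {
        "x_scaled": tft.scale_to_z_score(inputs["x"]),
        "s_copy": tf.identity(inputs["s"]),
    }


specs = {
    "x": tf.TensorSpec([None], tf.float32),
    "s": tf.TensorSpec([None], tf.string),
}
print(inspect_preprocessing_fn.get_analyze_input_columns(example_preprocessing_fn, specs))
# expected: ["x"]
print(inspect_preprocessing_fn.get_transform_input_columns(example_preprocessing_fn, specs))
# expected: ["x", "s"] (order not guaranteed)
# ----------------------------------------------------------------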
- """ - use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) - if not use_tf_compat_v1: - assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1)) + Returns: + ------- + A list of columns that are required inputs of analyzers. + """ + use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) + if not use_tf_compat_v1: + assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1 + ) + ) - tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) - visitor = graph_tools.SourcedTensorsVisitor() - for tensor_sink in tensor_sinks: - nodes.Traverser(visitor).visit_value_node(tensor_sink.future) + tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS) + visitor = graph_tools.SourcedTensorsVisitor() + for tensor_sink in tensor_sinks: + nodes.Traverser(visitor).visit_value_node(tensor_sink.future) - if use_tf_compat_v1: - control_dependency_ops = [] - else: - # If traced in TF2 as a tf.function, inputs that end up in control - # dependencies are required for the function to execute. Return such inputs - # as required inputs of analyzers as well. - _, control_dependency_ops = ( - tf2_utils.strip_and_get_tensors_and_control_dependencies( - tf.nest.flatten(structured_outputs, expand_composites=True))) + if use_tf_compat_v1: + control_dependency_ops = [] + else: + # If traced in TF2 as a tf.function, inputs that end up in control + # dependencies are required for the function to execute. Return such inputs + # as required inputs of analyzers as well. + _, control_dependency_ops = ( + tf2_utils.strip_and_get_tensors_and_control_dependencies( + tf.nest.flatten(structured_outputs, expand_composites=True) + ) + ) - output_tensors = list( - itertools.chain(visitor.sourced_tensors, control_dependency_ops)) - analyze_input_tensors = graph_tools.get_dependent_inputs( - graph, structured_inputs, output_tensors) - return list(analyze_input_tensors.keys()) + output_tensors = list( + itertools.chain(visitor.sourced_tensors, control_dependency_ops) + ) + analyze_input_tensors = graph_tools.get_dependent_inputs( + graph, structured_inputs, output_tensors + ) + return list(analyze_input_tensors.keys()) def get_transform_input_columns( - preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], - Mapping[str, common_types.TensorType]], + preprocessing_fn: Callable[ + [Mapping[str, common_types.TensorType]], Mapping[str, common_types.TensorType] + ], specs: Mapping[str, Union[common_types.FeatureSpecType, tf.TypeSpec]], - force_tf_compat_v1: bool = False) -> List[str]: - """Return columns that are required inputs of `TransformDataset`. + force_tf_compat_v1: bool = False, +) -> List[str]: + """Return columns that are required inputs of `TransformDataset`. - Args: - preprocessing_fn: A tf.transform preprocessing_fn. - specs: A dict of feature name to tf.TypeSpecs. If `force_tf_compat_v1` is - True, this can also be feature specifications. - force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode. - Defaults to `False`. + Args: + ---- + preprocessing_fn: A tf.transform preprocessing_fn. + specs: A dict of feature name to tf.TypeSpecs. 
If `force_tf_compat_v1` is + True, this can also be feature specifications. + force_tf_compat_v1: (Optional) If `True`, use Tensorflow in compat.v1 mode. + Defaults to `False`. - Returns: - A list of columns that are required inputs of the transform `tf.Graph` - defined by `preprocessing_fn`. - """ - use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) - if not use_tf_compat_v1: - assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs - graph, structured_inputs, structured_outputs = ( - impl_helper.trace_preprocessing_function( - preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1)) + Returns: + ------- + A list of columns that are required inputs of the transform `tf.Graph` + defined by `preprocessing_fn`. + """ + use_tf_compat_v1 = tf2_utils.use_tf_compat_v1(force_tf_compat_v1) + if not use_tf_compat_v1: + assert all([isinstance(s, tf.TypeSpec) for s in specs.values()]), specs + graph, structured_inputs, structured_outputs = ( + impl_helper.trace_preprocessing_function( + preprocessing_fn, specs, use_tf_compat_v1=use_tf_compat_v1 + ) + ) - transform_input_tensors = graph_tools.get_dependent_inputs( - graph, structured_inputs, structured_outputs) - return list(transform_input_tensors.keys()) + transform_input_tensors = graph_tools.get_dependent_inputs( + graph, structured_inputs, structured_outputs + ) + return list(transform_input_tensors.keys()) diff --git a/tensorflow_transform/inspect_preprocessing_fn_test.py b/tensorflow_transform/inspect_preprocessing_fn_test.py index 7e37f61..f4cc779 100644 --- a/tensorflow_transform/inspect_preprocessing_fn_test.py +++ b/tensorflow_transform/inspect_preprocessing_fn_test.py @@ -14,157 +14,173 @@ """Tests for inspect_preprocessing_fn.""" import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import inspect_preprocessing_fn -from tensorflow_transform import mappers -from tensorflow_transform import test_case + +from tensorflow_transform import analyzers, inspect_preprocessing_fn, mappers, test_case _FEATURE_SPEC = { - 'x': tf.io.FixedLenFeature([], tf.float32), - 'y': tf.io.VarLenFeature(tf.int64), - 's': tf.io.FixedLenFeature([], tf.string), + "x": tf.io.FixedLenFeature([], tf.float32), + "y": tf.io.VarLenFeature(tf.int64), + "s": tf.io.FixedLenFeature([], tf.string), } _TYPE_SPEC = { - 'x': tf.TensorSpec([None], tf.float32), - 'y': tf.SparseTensorSpec(shape=[None, None], dtype=tf.int64), - 's': tf.TensorSpec([None], tf.string), + "x": tf.TensorSpec([None], tf.float32), + "y": tf.SparseTensorSpec(shape=[None, None], dtype=tf.int64), + "s": tf.TensorSpec([None], tf.string), } def _identity_preprocessing_fn(inputs): - return inputs.copy() + return inputs.copy() def _side_affect_preprocessing_fn(inputs): - _ = analyzers.vocabulary(inputs['s']) - return {} + _ = analyzers.vocabulary(inputs["s"]) + return {} def _non_identity_ops_preprocessing_fn(inputs): - outputs = inputs.copy() - outputs['new_feature'] = tf.constant(1) - return outputs + outputs = inputs.copy() + outputs["new_feature"] = tf.constant(1) + return outputs def _renaming_preprocessing_fn(inputs): - return {'id_{}'.format(key): value for key, value in inputs.items()} + return {f"id_{key}": value for key, value in inputs.items()} @tf.function def _plus_one(x): - return x + 1 + return x + 1 def _one_phase_preprocessing_fn(inputs): - x_plus_one = _plus_one(inputs['x']) - subtracted = tf.sparse.add( - tf.cast(inputs['y'], tf.float32), -analyzers.mean(x_plus_one)) - _ = analyzers.vocabulary(inputs['s']) - return 
{'subtracted': subtracted} + x_plus_one = _plus_one(inputs["x"]) + subtracted = tf.sparse.add( + tf.cast(inputs["y"], tf.float32), -analyzers.mean(x_plus_one) + ) + _ = analyzers.vocabulary(inputs["s"]) + return {"subtracted": subtracted} def _two_phases_preprocessing_fn(inputs): - x = inputs['x'] - x_mean = analyzers.mean(x) - x_square_deviations = tf.square(x - x_mean) - x_var = analyzers.mean(x_square_deviations + analyzers.mean(inputs['y'])) - x_normalized = (x - x_mean) / tf.sqrt(x_var) - return { - 'x_normalized': x_normalized, - 's_id': mappers.compute_and_apply_vocabulary(inputs['s']) - } + x = inputs["x"] + x_mean = analyzers.mean(x) + x_square_deviations = tf.square(x - x_mean) + x_var = analyzers.mean(x_square_deviations + analyzers.mean(inputs["y"])) + x_normalized = (x - x_mean) / tf.sqrt(x_var) + return { + "x_normalized": x_normalized, + "s_id": mappers.compute_and_apply_vocabulary(inputs["s"]), + } def _preprocessing_fn_with_control_dependency(inputs): - with tf.init_scope(): - initializer = tf.lookup.KeyValueTensorInitializer(['foo', 'bar'], [0, 1]) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - # The table created here will add an automatic control dependency. - s_int = table.lookup(inputs['s']) + 1 + with tf.init_scope(): + initializer = tf.lookup.KeyValueTensorInitializer(["foo", "bar"], [0, 1]) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + # The table created here will add an automatic control dependency. + s_int = table.lookup(inputs["s"]) + 1 - # Perform some TF Ops to ensure x is part of the graph of dependencies for the - # outputs. - x_abs = tf.math.abs(inputs['x']) - y_centered = ( - tf.sparse.add( - tf.cast(inputs['y'], tf.float32), -analyzers.mean(inputs['y']))) - return {'s_int': s_int, 'x_abs': x_abs, 'y_centered': y_centered} + # Perform some TF Ops to ensure x is part of the graph of dependencies for the + # outputs. 
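# ---- Editor's illustrative sketch (not part of this patch) ----
# The lookup-table pattern used by _preprocessing_fn_with_control_dependency
# above, run standalone in eager mode: unknown keys fall back to
# default_value=-1, so after the "+ 1" the ids become {foo: 1, bar: 2, OOV: 0}.
import tensorflow as tf

initializer = tf.lookup.KeyValueTensorInitializer(["foo", "bar"], [0, 1])
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
s_int = table.lookup(tf.constant(["foo", "bar", "baz"])) + 1
print(s_int.numpy())  # [1 2 0]
# ----------------------------------------------------------------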
+ x_abs = tf.math.abs(inputs["x"]) + y_centered = tf.sparse.add( + tf.cast(inputs["y"], tf.float32), -analyzers.mean(inputs["y"]) + ) + return {"s_int": s_int, "x_abs": x_abs, "y_centered": y_centered} class InspectPreprocessingFnTest(test_case.TransformTestCase): - - @test_case.named_parameters( - *test_case.cross_named_parameters([ - dict( - testcase_name='identity', - preprocessing_fn=_identity_preprocessing_fn, - expected_analyze_input_columns=[], - expected_transform_input_columns=['x', 'y', 's']), - dict( - testcase_name='side_affect', - preprocessing_fn=_side_affect_preprocessing_fn, - expected_analyze_input_columns=['s'], - expected_transform_input_columns=[]), - dict( - testcase_name='non_identity_ops', - preprocessing_fn=_non_identity_ops_preprocessing_fn, - expected_analyze_input_columns=[], - expected_transform_input_columns=['x', 'y', 's']), - dict( - testcase_name='feature_renaming', - preprocessing_fn=_renaming_preprocessing_fn, - expected_analyze_input_columns=[], - expected_transform_input_columns=['x', 'y', 's']), - dict( - testcase_name='one_phase', - preprocessing_fn=_one_phase_preprocessing_fn, - expected_analyze_input_columns=['x', 's'], - expected_transform_input_columns=['y']), - dict( - testcase_name='two_phases', - preprocessing_fn=_two_phases_preprocessing_fn, - expected_analyze_input_columns=['x', 'y', 's'], - expected_transform_input_columns=['x', 's']) - ], [ - dict(testcase_name='tf_compat_v1', force_tf_compat_v1=True), - dict(testcase_name='tf2', force_tf_compat_v1=False) - ]), - *test_case.cross_named_parameters([ - dict( - testcase_name='control_dependencies', - preprocessing_fn=_preprocessing_fn_with_control_dependency, - expected_transform_input_columns=['x', 'y', 's']) - ], [ - dict( - testcase_name='tf_compat_v1', - force_tf_compat_v1=True, - expected_analyze_input_columns=['y']), - dict( - testcase_name='tf2', - force_tf_compat_v1=False, - expected_analyze_input_columns=['s', 'y']) - ])) - def test_column_inference(self, preprocessing_fn, - expected_analyze_input_columns, - expected_transform_input_columns, - force_tf_compat_v1): - if not force_tf_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - specs = _TYPE_SPEC - else: - specs = _FEATURE_SPEC - - analyze_input_columns = ( - inspect_preprocessing_fn.get_analyze_input_columns( - preprocessing_fn, specs, force_tf_compat_v1)) - transform_input_columns = ( - inspect_preprocessing_fn.get_transform_input_columns( - preprocessing_fn, specs, force_tf_compat_v1)) - self.assertCountEqual(analyze_input_columns, expected_analyze_input_columns) - self.assertCountEqual(transform_input_columns, - expected_transform_input_columns) - - -if __name__ == '__main__': - test_case.main() + @test_case.named_parameters( + *test_case.cross_named_parameters( + [ + dict( + testcase_name="identity", + preprocessing_fn=_identity_preprocessing_fn, + expected_analyze_input_columns=[], + expected_transform_input_columns=["x", "y", "s"], + ), + dict( + testcase_name="side_affect", + preprocessing_fn=_side_affect_preprocessing_fn, + expected_analyze_input_columns=["s"], + expected_transform_input_columns=[], + ), + dict( + testcase_name="non_identity_ops", + preprocessing_fn=_non_identity_ops_preprocessing_fn, + expected_analyze_input_columns=[], + expected_transform_input_columns=["x", "y", "s"], + ), + dict( + testcase_name="feature_renaming", + preprocessing_fn=_renaming_preprocessing_fn, + expected_analyze_input_columns=[], + expected_transform_input_columns=["x", "y", "s"], + ), + dict( + 
testcase_name="one_phase", + preprocessing_fn=_one_phase_preprocessing_fn, + expected_analyze_input_columns=["x", "s"], + expected_transform_input_columns=["y"], + ), + dict( + testcase_name="two_phases", + preprocessing_fn=_two_phases_preprocessing_fn, + expected_analyze_input_columns=["x", "y", "s"], + expected_transform_input_columns=["x", "s"], + ), + ], + [ + dict(testcase_name="tf_compat_v1", force_tf_compat_v1=True), + dict(testcase_name="tf2", force_tf_compat_v1=False), + ], + ), + *test_case.cross_named_parameters( + [ + dict( + testcase_name="control_dependencies", + preprocessing_fn=_preprocessing_fn_with_control_dependency, + expected_transform_input_columns=["x", "y", "s"], + ) + ], + [ + dict( + testcase_name="tf_compat_v1", + force_tf_compat_v1=True, + expected_analyze_input_columns=["y"], + ), + dict( + testcase_name="tf2", + force_tf_compat_v1=False, + expected_analyze_input_columns=["s", "y"], + ), + ], + ), + ) + def test_column_inference( + self, + preprocessing_fn, + expected_analyze_input_columns, + expected_transform_input_columns, + force_tf_compat_v1, + ): + if not force_tf_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + specs = _TYPE_SPEC + else: + specs = _FEATURE_SPEC + + analyze_input_columns = inspect_preprocessing_fn.get_analyze_input_columns( + preprocessing_fn, specs, force_tf_compat_v1 + ) + transform_input_columns = inspect_preprocessing_fn.get_transform_input_columns( + preprocessing_fn, specs, force_tf_compat_v1 + ) + self.assertCountEqual(analyze_input_columns, expected_analyze_input_columns) + self.assertCountEqual(transform_input_columns, expected_transform_input_columns) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/keras_lib.py b/tensorflow_transform/keras_lib.py index 668b16b..3ebb187 100644 --- a/tensorflow_transform/keras_lib.py +++ b/tensorflow_transform/keras_lib.py @@ -12,31 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. """Imports keras 2.""" + import os -from absl import logging import tensorflow as tf +from absl import logging -if 'TF_USE_LEGACY_KERAS' not in os.environ: - # Make sure we are using Keras 2. - os.environ['TF_USE_LEGACY_KERAS'] = '1' -elif os.environ['TF_USE_LEGACY_KERAS'] not in ('true', 'True', '1'): - logging.warning( - 'TF_USE_LEGACY_KERAS is set to %s, which will not use Keras 2. Tensorflow' - ' Transform is only compatible with Keras 2. Please set' - ' TF_USE_LEGACY_KERAS=1.', - os.environ['TF_USE_LEGACY_KERAS'], - ) +if "TF_USE_LEGACY_KERAS" not in os.environ: + # Make sure we are using Keras 2. + os.environ["TF_USE_LEGACY_KERAS"] = "1" +elif os.environ["TF_USE_LEGACY_KERAS"] not in ("true", "True", "1"): + logging.warning( + "TF_USE_LEGACY_KERAS is set to %s, which will not use Keras 2. Tensorflow" + " Transform is only compatible with Keras 2. Please set" + " TF_USE_LEGACY_KERAS=1.", + os.environ["TF_USE_LEGACY_KERAS"], + ) -version_fn = getattr(tf.keras, 'version', None) -if version_fn and version_fn().startswith('3.'): - # `tf.keras` points to `keras 3`, so use `tf_keras` package - try: - import tf_keras # pylint: disable=g-import-not-at-top,unused-import - except ImportError: - raise ImportError( # pylint: disable=raise-missing-from - 'Keras 2 requires the `tf_keras` package.' - 'Please install it with `pip install tf_keras`.' 
- ) from None +version_fn = getattr(tf.keras, "version", None) +if version_fn and version_fn().startswith("3."): + # `tf.keras` points to `keras 3`, so use `tf_keras` package + try: + import tf_keras # pylint: disable=g-import-not-at-top,unused-import + except ImportError: + raise ImportError( # pylint: disable=raise-missing-from + "Keras 2 requires the `tf_keras` package." + "Please install it with `pip install tf_keras`." + ) from None else: - tf_keras = tf.keras # Keras 2 + tf_keras = tf.keras # Keras 2 diff --git a/tensorflow_transform/mappers.py b/tensorflow_transform/mappers.py index 5806849..324217b 100644 --- a/tensorflow_transform/mappers.py +++ b/tensorflow_transform/mappers.py @@ -54,189 +54,203 @@ def preprocessing_fn(inputs): import os from typing import Any, Callable, Iterable, Optional, Tuple, Union - import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import gaussianization -from tensorflow_transform import schema_inference -from tensorflow_transform import tf_utils + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple +from tensorflow_transform import ( + analyzers, + common, + common_types, + gaussianization, + schema_inference, + tf_utils, +) + @common.log_api_use(common.MAPPER_COLLECTION) def scale_to_gaussian( x: common_types.ConsistentTensorType, elementwise: bool = False, name: Optional[str] = None, - output_dtype: Optional[tf.DType] = None + output_dtype: Optional[tf.DType] = None, ) -> common_types.ConsistentTensorType: - """Returns an (approximately) normal column with mean to 0 and variance 1. - - We transform the column to values that are approximately distributed - according to a standard normal distribution. - The transformation is obtained by applying the moments method to estimate - the parameters of a Tukey HH distribution and applying the inverse of the - estimated function to the column values. - The method is partially described in - - Georg M. Georgm "The Lambert Way to Gaussianize Heavy-Tailed Data with the - Inverse of Tukey's h Transformation as a Special Case," The Scientific World - Journal, Vol. 2015, Hindawi Publishing Corporation. - - We use the L-moments instead of conventional moments to be able to deal with - long-tailed distributions. The expressions of the L-moments for the Tukey HH - distribution is in - - Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey H and - HH-Distributions through L-Moments and the L-Correlation," ISRN Applied - Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 - - Note that the transformation to Gaussian is applied only if the column has - long-tails. If this is not the case, for instance if values are uniformly - distributed, the values are only normalized using the z score. This applies - also to the cases where only one of the tails is long; the other tail is only - rescaled but not non linearly transformed. - Also, if the analysis set is empty, the transformation is set to to leave the - input vaules unchanged. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - elementwise: If true, scales each element of the tensor independently; - otherwise uses the parameters of the whole tensor. - name: (Optional) A name for this operation. - output_dtype: (Optional) If not None, casts the output tensor to this type. 
- - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column - transformed to be approximately standard distributed (i.e. a Gaussian with - mean 0 and variance 1). If `x` is floating point, the mean will have the - same type as `x`. If `x` is integral, the output is cast to tf.float32. - - Note that TFLearn generally permits only tf.int64 and tf.float32, so casting - this scaler's output may be necessary. - """ - with tf.compat.v1.name_scope(name, 'scale_to_gaussian'): - return _scale_to_gaussian_internal( - x=x, - elementwise=elementwise, - output_dtype=output_dtype) + """Returns an (approximately) normal column with mean to 0 and variance 1. + + We transform the column to values that are approximately distributed + according to a standard normal distribution. + The transformation is obtained by applying the moments method to estimate + the parameters of a Tukey HH distribution and applying the inverse of the + estimated function to the column values. + The method is partially described in + + Georg M. Georgm "The Lambert Way to Gaussianize Heavy-Tailed Data with the + Inverse of Tukey's h Transformation as a Special Case," The Scientific World + Journal, Vol. 2015, Hindawi Publishing Corporation. + + We use the L-moments instead of conventional moments to be able to deal with + long-tailed distributions. The expressions of the L-moments for the Tukey HH + distribution is in + + Todd C. Headrick, and Mohan D. Pant. "Characterizing Tukey H and + HH-Distributions through L-Moments and the L-Correlation," ISRN Applied + Mathematics, vol. 2012, 2012. doi:10.5402/2012/980153 + + Note that the transformation to Gaussian is applied only if the column has + long-tails. If this is not the case, for instance if values are uniformly + distributed, the values are only normalized using the z score. This applies + also to the cases where only one of the tails is long; the other tail is only + rescaled but not non linearly transformed. + Also, if the analysis set is empty, the transformation is set to to leave the + input vaules unchanged. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + elementwise: If true, scales each element of the tensor independently; + otherwise uses the parameters of the whole tensor. + name: (Optional) A name for this operation. + output_dtype: (Optional) If not None, casts the output tensor to this type. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column + transformed to be approximately standard distributed (i.e. a Gaussian with + mean 0 and variance 1). If `x` is floating point, the mean will have the + same type as `x`. If `x` is integral, the output is cast to tf.float32. + + Note that TFLearn generally permits only tf.int64 and tf.float32, so casting + this scaler's output may be necessary. + """ + with tf.compat.v1.name_scope(name, "scale_to_gaussian"): + return _scale_to_gaussian_internal( + x=x, elementwise=elementwise, output_dtype=output_dtype + ) def _scale_to_gaussian_internal( x: common_types.ConsistentTensorType, elementwise: bool = False, - output_dtype: Optional[tf.DType] = None + output_dtype: Optional[tf.DType] = None, ) -> common_types.ConsistentTensorType: - """Implementation for scale_to_gaussian.""" - # x_mean will be float16, float32, or float64, depending on type of x. 
- x_loc, x_scale, hl, hr = analyzers._tukey_parameters( # pylint: disable=protected-access - x, reduce_instance_dims=not elementwise, output_dtype=output_dtype) - - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - x_values = tf_utils.get_values(x) - - x_var = analyzers.var(x, reduce_instance_dims=not elementwise, - output_dtype=output_dtype) - - if isinstance(x, tf.SparseTensor): - if elementwise: - x_loc = tf.gather_nd(x_loc, x.indices[:, 1:]) - x_scale = tf.gather_nd(x_scale, x.indices[:, 1:]) - hl = tf.gather_nd(hl, x.indices[:, 1:]) - hr = tf.gather_nd(hr, x.indices[:, 1:]) - x_var = tf.gather_nd(x_var, x.indices[:, 1:]) - elif isinstance(x, tf.RaggedTensor): - if elementwise: - raise NotImplementedError( - 'Elementwise scale_to_gaussian does not support RaggedTensors.') - - numerator = tf.cast(x_values, x_loc.dtype) - x_loc - is_long_tailed = tf.math.logical_or(hl > 0.0, hr > 0.0) - - # If the distribution is long-tailed, we apply the robust scale computed - # with L-moments; otherwise, we scale using the standard deviation so that - # we obtain the same result of scale_to_z_score. - denominator = tf.where(is_long_tailed, x_scale, tf.sqrt(x_var)) - cond = tf.not_equal(denominator, 0) - - if cond.shape.as_list() != x_values.shape.as_list(): - # Repeats cond when necessary across the batch dimension for it to be - # compatible with the shape of numerator. - cond = tf.cast( - tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype), - dtype=tf.bool) - - scaled_values = tf.where(cond, tf.divide(numerator, denominator), - numerator) - gaussianized_values = gaussianization.inverse_tukey_hh(scaled_values, hl, hr) - return compose_result_fn(gaussianized_values) + """Implementation for scale_to_gaussian.""" + # x_mean will be float16, float32, or float64, depending on type of x. + x_loc, x_scale, hl, hr = analyzers._tukey_parameters( # pylint: disable=protected-access + x, reduce_instance_dims=not elementwise, output_dtype=output_dtype + ) + + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + x_values = tf_utils.get_values(x) + + x_var = analyzers.var( + x, reduce_instance_dims=not elementwise, output_dtype=output_dtype + ) + + if isinstance(x, tf.SparseTensor): + if elementwise: + x_loc = tf.gather_nd(x_loc, x.indices[:, 1:]) + x_scale = tf.gather_nd(x_scale, x.indices[:, 1:]) + hl = tf.gather_nd(hl, x.indices[:, 1:]) + hr = tf.gather_nd(hr, x.indices[:, 1:]) + x_var = tf.gather_nd(x_var, x.indices[:, 1:]) + elif isinstance(x, tf.RaggedTensor): + if elementwise: + raise NotImplementedError( + "Elementwise scale_to_gaussian does not support RaggedTensors." + ) + + numerator = tf.cast(x_values, x_loc.dtype) - x_loc + is_long_tailed = tf.math.logical_or(hl > 0.0, hr > 0.0) + + # If the distribution is long-tailed, we apply the robust scale computed + # with L-moments; otherwise, we scale using the standard deviation so that + # we obtain the same result of scale_to_z_score. + denominator = tf.where(is_long_tailed, x_scale, tf.sqrt(x_var)) + cond = tf.not_equal(denominator, 0) + + if cond.shape.as_list() != x_values.shape.as_list(): + # Repeats cond when necessary across the batch dimension for it to be + # compatible with the shape of numerator. 
+ cond = tf.cast( + tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype), dtype=tf.bool + ) + + scaled_values = tf.where(cond, tf.divide(numerator, denominator), numerator) + gaussianized_values = gaussianization.inverse_tukey_hh(scaled_values, hl, hr) + return compose_result_fn(gaussianized_values) @common.log_api_use(common.MAPPER_COLLECTION) def sparse_tensor_to_dense_with_shape( x: tf.SparseTensor, shape: Union[tf.TensorShape, Iterable[int]], - default_value: Union[tf.Tensor, int, float, str] = 0) -> tf.Tensor: - """Converts a `SparseTensor` into a dense tensor and sets its shape. - - Args: - x: A `SparseTensor`. - shape: The desired shape of the densified `Tensor`. - default_value: (Optional) Value to set for indices not specified. Defaults - to zero. - - Returns: - A `Tensor` with the desired shape. - - Raises: - ValueError: If input is not a `SparseTensor`. - """ - if not isinstance(x, tf.SparseTensor): - raise ValueError('input must be a SparseTensor') - new_dense_shape = [ - x.dense_shape[i] if size is None else size - for i, size in enumerate(shape) - ] - dense = tf.raw_ops.SparseToDense( - sparse_indices=x.indices, - output_shape=new_dense_shape, - sparse_values=x.values, - default_value=default_value) - dense.set_shape(shape) - return dense + default_value: Union[tf.Tensor, int, float, str] = 0, +) -> tf.Tensor: + """Converts a `SparseTensor` into a dense tensor and sets its shape. + + Args: + ---- + x: A `SparseTensor`. + shape: The desired shape of the densified `Tensor`. + default_value: (Optional) Value to set for indices not specified. Defaults + to zero. + + Returns: + ------- + A `Tensor` with the desired shape. + + Raises: + ------ + ValueError: If input is not a `SparseTensor`. + """ + if not isinstance(x, tf.SparseTensor): + raise ValueError("input must be a SparseTensor") + new_dense_shape = [ + x.dense_shape[i] if size is None else size for i, size in enumerate(shape) + ] + dense = tf.raw_ops.SparseToDense( + sparse_indices=x.indices, + output_shape=new_dense_shape, + sparse_values=x.values, + default_value=default_value, + ) + dense.set_shape(shape) + return dense @common.log_api_use(common.MAPPER_COLLECTION) def sparse_tensor_left_align(sparse_tensor: tf.SparseTensor) -> tf.SparseTensor: - """Re-arranges a `tf.SparseTensor` and returns a left-aligned version of it. + """Re-arranges a `tf.SparseTensor` and returns a left-aligned version of it. - This mapper can be useful when returning a sparse tensor that may not be - left-aligned from a preprocessing_fn. + This mapper can be useful when returning a sparse tensor that may not be + left-aligned from a preprocessing_fn. - Args: - sparse_tensor: A 2D `tf.SparseTensor`. + Args: + ---- + sparse_tensor: A 2D `tf.SparseTensor`. - Raises: - ValueError if `sparse_tensor` is not 2D. + Raises: + ------ + ValueError if `sparse_tensor` is not 2D. - Returns: - A left-aligned version of sparse_tensor as a `tf.SparseTensor`. 
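For readers unfamiliar with what "left-aligned" means here, the following standalone sketch (not part of this patch, and assuming the mapper is exposed on the public `tft` namespace like the other mappers in this module) shows the effect on a small 2D `SparseTensor`:

```
import tensorflow as tf
import tensorflow_transform as tft

# Values start at arbitrary columns within each row.
st = tf.SparseTensor(
    indices=[[0, 2], [0, 5], [1, 3]],
    values=tf.constant([10, 20, 30], tf.int64),
    dense_shape=[2, 6],
)
aligned = tft.sparse_tensor_left_align(st)
print(tf.sparse.to_dense(aligned).numpy())
# Per the contract above, the values are packed into columns 0..k-1:
# [[10 20  0  0  0  0]
#  [30  0  0  0  0  0]]
```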
- """ - if sparse_tensor.get_shape().ndims != 2: - raise ValueError('sparse_tensor_left_align requires a 2D input') - reordered_tensor = tf.sparse.reorder(sparse_tensor) - transposed_indices = tf.transpose(reordered_tensor.indices) - row_indices = transposed_indices[0] - row_counts = tf.unique_with_counts(row_indices, out_idx=tf.int64).count - column_indices = tf.ragged.range(row_counts).flat_values - return tf.SparseTensor( - indices=tf.transpose(tf.stack([row_indices, column_indices])), - values=reordered_tensor.values, - dense_shape=reordered_tensor.dense_shape) + Returns: + ------- + A left-aligned version of sparse_tensor as a `tf.SparseTensor`. + """ + if sparse_tensor.get_shape().ndims != 2: + raise ValueError("sparse_tensor_left_align requires a 2D input") + reordered_tensor = tf.sparse.reorder(sparse_tensor) + transposed_indices = tf.transpose(reordered_tensor.indices) + row_indices = transposed_indices[0] + row_counts = tf.unique_with_counts(row_indices, out_idx=tf.int64).count + column_indices = tf.ragged.range(row_counts).flat_values + return tf.SparseTensor( + indices=tf.transpose(tf.stack([row_indices, column_indices])), + values=reordered_tensor.values, + dense_shape=reordered_tensor.dense_shape, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -245,32 +259,37 @@ def scale_by_min_max( output_min: float = 0.0, output_max: float = 1.0, elementwise: bool = False, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Scale a numerical column into the range [output_min, output_max]. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - output_min: The minimum of the range of output values. - output_max: The maximum of the range of output values. - elementwise: If true, scale each element of the tensor independently. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` containing the input column scaled to [output_min, output_max]. - If the analysis dataset is empty or contains a singe distinct value, then - `x` is scaled using a sigmoid function. - - Raises: - ValueError: If output_min, output_max have the wrong order. - """ - with tf.compat.v1.name_scope(name, 'scale_by_min_max'): - return _scale_by_min_max_internal( - x, - key=None, - output_min=output_min, - output_max=output_max, - elementwise=elementwise, - key_vocabulary_filename=None) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Scale a numerical column into the range [output_min, output_max]. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + output_min: The minimum of the range of output values. + output_max: The maximum of the range of output values. + elementwise: If true, scale each element of the tensor independently. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor` containing the input column scaled to [output_min, output_max]. + If the analysis dataset is empty or contains a singe distinct value, then + `x` is scaled using a sigmoid function. + + Raises: + ------ + ValueError: If output_min, output_max have the wrong order. 
+ """ + with tf.compat.v1.name_scope(name, "scale_by_min_max"): + return _scale_by_min_max_internal( + x, + key=None, + output_min=output_min, + output_max=output_max, + elementwise=elementwise, + key_vocabulary_filename=None, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -281,71 +300,76 @@ def scale_by_min_max_per_key( output_max: float = 1.0, elementwise: bool = False, key_vocabulary_filename: Optional[str] = None, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - # pyformat: disable - """Scale a numerical column into a predefined range on a per-key basis. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. - Must meet one of the following conditions: - 0. key is None - 1. Both x and key are dense, - 2. Both x and key are composite and `key` must exactly match `x` in - everything except values, - 3. The axis=1 index of each x matches its index of dense key. - output_min: The minimum of the range of output values. - output_max: The maximum of the range of output values. - elementwise: If true, scale each element of the tensor independently. - key_vocabulary_filename: (Optional) The file name for the per-key file. - If None, this combiner will assume the keys fit in memory and will not - store the analyzer result in a file. If '', a file name will be chosen - based on the current TensorFlow scope. If not '', it should be unique - within a given preprocessing function. - name: (Optional) A name for this operation. - - Example: - - >>> def preprocessing_fn(inputs): - ... return { - ... 'scaled': tft.scale_by_min_max_per_key(inputs['x'], inputs['s']) - ... } - >>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')] - >>> feature_spec = dict( - ... x=tf.io.FixedLenFeature([], tf.float32), - ... s=tf.io.FixedLenFeature([], tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}] - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column scaled to - [output_min, output_max] on a per-key basis if a key is provided. If the - analysis dataset is empty, a certain key contains a single distinct value or - the computed key vocabulary doesn't have an entry for `key`, then `x` is - scaled using a sigmoid function. - - Raises: - ValueError: If output_min, output_max have the wrong order. - NotImplementedError: If elementwise is True and key is not None. - InvalidArgumentError: If indices of sparse x and key do not match. - """ - # pyformat: enable - with tf.compat.v1.name_scope(name, 'scale_by_min_max_per_key'): - if key is None: - raise ValueError('key is None, call `tft.scale_by_min_max` instead') - return _scale_by_min_max_internal( - x, - key=key, - output_min=output_min, - output_max=output_max, - elementwise=elementwise, - key_vocabulary_filename=key_vocabulary_filename) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + # pyformat: disable + """Scale a numerical column into a predefined range on a per-key basis. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. 
+ key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. + Must meet one of the following conditions: + 0. key is None + 1. Both x and key are dense, + 2. Both x and key are composite and `key` must exactly match `x` in + everything except values, + 3. The axis=1 index of each x matches its index of dense key. + output_min: The minimum of the range of output values. + output_max: The maximum of the range of output values. + elementwise: If true, scale each element of the tensor independently. + key_vocabulary_filename: (Optional) The file name for the per-key file. + If None, this combiner will assume the keys fit in memory and will not + store the analyzer result in a file. If '', a file name will be chosen + based on the current TensorFlow scope. If not '', it should be unique + within a given preprocessing function. + name: (Optional) A name for this operation. + + Example: + ------- + >>> def preprocessing_fn(inputs): + ... return { + ... 'scaled': tft.scale_by_min_max_per_key(inputs['x'], inputs['s']) + ... } + >>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')] + >>> feature_spec = dict( + ... x=tf.io.FixedLenFeature([], tf.float32), + ... s=tf.io.FixedLenFeature([], tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}] + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column scaled to + [output_min, output_max] on a per-key basis if a key is provided. If the + analysis dataset is empty, a certain key contains a single distinct value or + the computed key vocabulary doesn't have an entry for `key`, then `x` is + scaled using a sigmoid function. + + Raises: + ------ + ValueError: If output_min, output_max have the wrong order. + NotImplementedError: If elementwise is True and key is not None. + InvalidArgumentError: If indices of sparse x and key do not match. 
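The per-key behavior documented above, including the sigmoid fallback when a key has a single distinct value, can be modeled outside of TensorFlow; this NumPy sketch reproduces the doctest output and is only a simplified stand-in for the per-key analyzer:

```
import numpy as np

x = np.array([1.0, 0.0, 3.0])
keys = np.array(["a", "b", "a"])
out_min, out_max = 0.0, 1.0

scaled = np.empty_like(x)
for k in np.unique(keys):
    mask = keys == k
    lo, hi = x[mask].min(), x[mask].max()
    if lo < hi:
        unit = (x[mask] - lo) / (hi - lo)
    else:
        # Key "b" has a single distinct value, so it falls back to a sigmoid.
        unit = 1.0 / (1.0 + np.exp(-x[mask]))
    scaled[mask] = unit * (out_max - out_min) + out_min

print(scaled)  # [0.  0.5 1. ] -- matches the doctest above
```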
+ """ + # pyformat: enable + with tf.compat.v1.name_scope(name, "scale_by_min_max_per_key"): + if key is None: + raise ValueError("key is None, call `tft.scale_by_min_max` instead") + return _scale_by_min_max_internal( + x, + key=key, + output_min=output_min, + output_max=output_max, + elementwise=elementwise, + key_vocabulary_filename=key_vocabulary_filename, + ) def _scale_by_min_max_internal( @@ -354,98 +378,112 @@ def _scale_by_min_max_internal( output_min: float, output_max: float, elementwise: bool, - key_vocabulary_filename: Optional[str] = None + key_vocabulary_filename: Optional[str] = None, ) -> common_types.ConsistentTensorType: - """Implementation for scale_by_min_max.""" - if output_min >= output_max: - raise ValueError('output_min must be less than output_max') + """Implementation for scale_by_min_max.""" + if output_min >= output_max: + raise ValueError("output_min must be less than output_max") - x = tf.cast(x, tf.float32) - if key is None: - min_x_value, max_x_value = analyzers._min_and_max( # pylint: disable=protected-access - x, - reduce_instance_dims=not elementwise) - else: - if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'Per-key elementwise reduction of Composite Tensors not supported') - key_values = analyzers._min_and_max_per_key( # pylint: disable=protected-access - x, - key, - reduce_instance_dims=not elementwise, - key_vocabulary_filename=key_vocabulary_filename) - if key_vocabulary_filename is None: - key_vocab, min_x_value, max_x_value = key_values - # Missing keys will translate to 0 for both min and max which will be - # ignored below in the tf.where. - min_x_value, max_x_value = tf_utils.map_per_key_reductions( - (min_x_value, max_x_value), key, key_vocab, x, not elementwise) + x = tf.cast(x, tf.float32) + if key is None: + min_x_value, max_x_value = analyzers._min_and_max( # pylint: disable=protected-access + x, reduce_instance_dims=not elementwise + ) else: - if elementwise: - raise NotImplementedError( - 'Elementwise scaling does not support key_vocabulary_filename') - minus_min_max_for_key = tf_utils.apply_per_key_vocabulary( - key_values, key, target_ndims=x.get_shape().ndims) - min_x_value, max_x_value = (-minus_min_max_for_key[:, 0], - minus_min_max_for_key[:, 1]) - - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - x_values = tf_utils.get_values(x) - if isinstance(x, tf.SparseTensor): - if elementwise: - min_x_value = tf.gather_nd( - tf.broadcast_to(min_x_value, x.dense_shape), x.indices) - max_x_value = tf.gather_nd( - tf.broadcast_to(max_x_value, x.dense_shape), x.indices) - elif isinstance(x, tf.RaggedTensor): - if elementwise: - raise NotImplementedError( - 'Elementwise min_and_max does not support RaggedTensors.') - - # If min>=max, then the corresponding input to the min_and_max analyzer either - # was empty and the analyzer returned default values, or contained only one - # distinct value. In this case we scale x by applying a sigmoid function which - # is continuous, increasing and maps (-inf, inf) -> (0, 1). Its output is - # then projected on the requested range. Note that both the options of - # tf.where are computed, which means that this will compute unused NaNs. 
- numerator = tf.cast(x_values, min_x_value.dtype) - min_x_value - where_cond = min_x_value < max_x_value - where_cond = tf.cast( - tf.zeros_like(numerator) + tf.cast(where_cond, numerator.dtype), - dtype=tf.bool) - scaled_result = tf.where(where_cond, numerator / (max_x_value - min_x_value), - tf.math.sigmoid(x_values)) - - return compose_result_fn((scaled_result * (output_max - output_min)) + - output_min) + if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "Per-key elementwise reduction of Composite Tensors not supported" + ) + key_values = analyzers._min_and_max_per_key( # pylint: disable=protected-access + x, + key, + reduce_instance_dims=not elementwise, + key_vocabulary_filename=key_vocabulary_filename, + ) + if key_vocabulary_filename is None: + key_vocab, min_x_value, max_x_value = key_values + # Missing keys will translate to 0 for both min and max which will be + # ignored below in the tf.where. + min_x_value, max_x_value = tf_utils.map_per_key_reductions( + (min_x_value, max_x_value), key, key_vocab, x, not elementwise + ) + else: + if elementwise: + raise NotImplementedError( + "Elementwise scaling does not support key_vocabulary_filename" + ) + minus_min_max_for_key = tf_utils.apply_per_key_vocabulary( + key_values, key, target_ndims=x.get_shape().ndims + ) + min_x_value, max_x_value = ( + -minus_min_max_for_key[:, 0], + minus_min_max_for_key[:, 1], + ) + + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + x_values = tf_utils.get_values(x) + if isinstance(x, tf.SparseTensor): + if elementwise: + min_x_value = tf.gather_nd( + tf.broadcast_to(min_x_value, x.dense_shape), x.indices + ) + max_x_value = tf.gather_nd( + tf.broadcast_to(max_x_value, x.dense_shape), x.indices + ) + elif isinstance(x, tf.RaggedTensor): + if elementwise: + raise NotImplementedError( + "Elementwise min_and_max does not support RaggedTensors." + ) + + # If min>=max, then the corresponding input to the min_and_max analyzer either + # was empty and the analyzer returned default values, or contained only one + # distinct value. In this case we scale x by applying a sigmoid function which + # is continuous, increasing and maps (-inf, inf) -> (0, 1). Its output is + # then projected on the requested range. Note that both the options of + # tf.where are computed, which means that this will compute unused NaNs. + numerator = tf.cast(x_values, min_x_value.dtype) - min_x_value + where_cond = min_x_value < max_x_value + where_cond = tf.cast( + tf.zeros_like(numerator) + tf.cast(where_cond, numerator.dtype), dtype=tf.bool + ) + scaled_result = tf.where( + where_cond, numerator / (max_x_value - min_x_value), tf.math.sigmoid(x_values) + ) + + return compose_result_fn((scaled_result * (output_max - output_min)) + output_min) @common.log_api_use(common.MAPPER_COLLECTION) def scale_to_0_1( x: common_types.ConsistentTensorType, elementwise: bool = False, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Returns a column which is the input column scaled to have range [0,1]. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - elementwise: If true, scale each element of the tensor independently. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column - scaled to - [0, 1]. If the analysis dataset is empty or contains a single distinct - value, then `x` is scaled using a sigmoid function. 
- """ - with tf.compat.v1.name_scope(name, 'scale_to_0_1'): - return _scale_by_min_max_internal( - x, - key=None, - output_min=0, - output_max=1, - elementwise=elementwise, - key_vocabulary_filename=None) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Returns a column which is the input column scaled to have range [0,1]. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + elementwise: If true, scale each element of the tensor independently. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column + scaled to + [0, 1]. If the analysis dataset is empty or contains a single distinct + value, then `x` is scaled using a sigmoid function. + """ + with tf.compat.v1.name_scope(name, "scale_to_0_1"): + return _scale_by_min_max_internal( + x, + key=None, + output_min=0, + output_max=1, + elementwise=elementwise, + key_vocabulary_filename=None, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -454,57 +492,61 @@ def scale_to_0_1_per_key( key: common_types.TensorType, elementwise: bool = False, key_vocabulary_filename: Optional[str] = None, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - # pyformat: disable - """Returns a column which is the input column scaled to have range [0,1]. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type string. - elementwise: If true, scale each element of the tensor independently. - key_vocabulary_filename: (Optional) The file name for the per-key file. If - None, this combiner will assume the keys fit in memory and will not store - the analyzer result in a file. If '', a file name will be chosen based on - the current TensorFlow scope. If not '', it should be unique within a - given preprocessing function. - name: (Optional) A name for this operation. - - Example: - - >>> def preprocessing_fn(inputs): - ... return { - ... 'scaled': tft.scale_to_0_1_per_key(inputs['x'], inputs['s']) - ... } - >>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')] - >>> feature_spec = dict( - ... x=tf.io.FixedLenFeature([], tf.float32), - ... s=tf.io.FixedLenFeature([], tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}] - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column scaled to [0, 1], - per key. If the analysis dataset is empty, contains a single distinct value - or the computed key vocabulary doesn't have an entry for `key`, then `x` is - scaled using a sigmoid function. - """ - # pyformat: enable - with tf.compat.v1.name_scope(name, 'scale_to_0_1_per_key'): - if key is None: - raise ValueError('key is None, call `tft.scale_to_0_1` instead') - return _scale_by_min_max_internal( - x, - key=key, - output_min=0, - output_max=1, - elementwise=elementwise, - key_vocabulary_filename=key_vocabulary_filename) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + # pyformat: disable + """Returns a column which is the input column scaled to have range [0,1]. 
+ + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type string. + elementwise: If true, scale each element of the tensor independently. + key_vocabulary_filename: (Optional) The file name for the per-key file. If + None, this combiner will assume the keys fit in memory and will not store + the analyzer result in a file. If '', a file name will be chosen based on + the current TensorFlow scope. If not '', it should be unique within a + given preprocessing function. + name: (Optional) A name for this operation. + + Example: + ------- + >>> def preprocessing_fn(inputs): + ... return { + ... 'scaled': tft.scale_to_0_1_per_key(inputs['x'], inputs['s']) + ... } + >>> raw_data = [dict(x=1, s='a'), dict(x=0, s='b'), dict(x=3, s='a')] + >>> feature_spec = dict( + ... x=tf.io.FixedLenFeature([], tf.float32), + ... s=tf.io.FixedLenFeature([], tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'scaled': 0.0}, {'scaled': 0.5}, {'scaled': 1.0}] + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column scaled to [0, 1], + per key. If the analysis dataset is empty, contains a single distinct value + or the computed key vocabulary doesn't have an entry for `key`, then `x` is + scaled using a sigmoid function. + """ + # pyformat: enable + with tf.compat.v1.name_scope(name, "scale_to_0_1_per_key"): + if key is None: + raise ValueError("key is None, call `tft.scale_to_0_1` instead") + return _scale_by_min_max_internal( + x, + key=key, + output_min=0, + output_max=1, + elementwise=elementwise, + key_vocabulary_filename=key_vocabulary_filename, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -512,40 +554,43 @@ def scale_to_z_score( x: common_types.ConsistentTensorType, elementwise: bool = False, name: Optional[str] = None, - output_dtype: Optional[tf.DType] = None + output_dtype: Optional[tf.DType] = None, ) -> common_types.ConsistentTensorType: - """Returns a standardized column with mean 0 and variance 1. - - Scaling to z-score subtracts out the mean and divides by standard deviation. - Note that the standard deviation computed here is based on the biased variance - (0 delta degrees of freedom), as computed by analyzers.var. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - elementwise: If true, scales each element of the tensor independently; - otherwise uses the mean and variance of the whole tensor. - name: (Optional) A name for this operation. - output_dtype: (Optional) If not None, casts the output tensor to this type. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column - scaled to mean 0 - and variance 1 (standard deviation 1), given by: (x - mean(x)) / std_dev(x). - If `x` is floating point, the mean will have the same type as `x`. If `x` is - integral, the output is cast to tf.float32. If the analysis dataset is empty - or contains a single distinct value, then the input is returned without - scaling. - - Note that TFLearn generally permits only tf.int64 and tf.float32, so casting - this scaler's output may be necessary. 
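The z-score above uses the biased (population) standard deviation; a NumPy sketch of the arithmetic, with the analyzer's mean and variance replaced by in-memory statistics:

```
import numpy as np

x = np.array([1.0, 2.0, 3.0, 6.0])
mean = x.mean()
std = x.std()                 # ddof=0, i.e. the biased variance noted above
z = (x - mean) / std
print(z)  # approx. [-1.069  -0.5345  0.      1.6036]
```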
- """ - with tf.compat.v1.name_scope(name, 'scale_to_z_score'): - return _scale_to_z_score_internal( - x=x, - key=None, - elementwise=elementwise, - key_vocabulary_filename=None, - output_dtype=output_dtype) + """Returns a standardized column with mean 0 and variance 1. + + Scaling to z-score subtracts out the mean and divides by standard deviation. + Note that the standard deviation computed here is based on the biased variance + (0 delta degrees of freedom), as computed by analyzers.var. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + elementwise: If true, scales each element of the tensor independently; + otherwise uses the mean and variance of the whole tensor. + name: (Optional) A name for this operation. + output_dtype: (Optional) If not None, casts the output tensor to this type. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column + scaled to mean 0 + and variance 1 (standard deviation 1), given by: (x - mean(x)) / std_dev(x). + If `x` is floating point, the mean will have the same type as `x`. If `x` is + integral, the output is cast to tf.float32. If the analysis dataset is empty + or contains a single distinct value, then the input is returned without + scaling. + + Note that TFLearn generally permits only tf.int64 and tf.float32, so casting + this scaler's output may be necessary. + """ + with tf.compat.v1.name_scope(name, "scale_to_z_score"): + return _scale_to_z_score_internal( + x=x, + key=None, + elementwise=elementwise, + key_vocabulary_filename=None, + output_dtype=output_dtype, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -555,368 +600,403 @@ def scale_to_z_score_per_key( elementwise: bool = False, key_vocabulary_filename: Optional[str] = None, name: Optional[str] = None, - output_dtype: Optional[tf.DType] = None + output_dtype: Optional[tf.DType] = None, ) -> common_types.ConsistentTensorType: - """Returns a standardized column with mean 0 and variance 1, grouped per key. - - Scaling to z-score subtracts out the mean and divides by standard deviation. - Note that the standard deviation computed here is based on the biased variance - (0 delta degrees of freedom), as computed by analyzers.var. - - Args: - x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. Must - meet one of the following conditions: - 0. key is None, - 1. Both x and key are dense, - 2. Both x and key are sparse and `key` must exactly match `x` in - everything except values, - 3. The axis=1 index of each x matches its index of dense key. - elementwise: If true, scales each element of the tensor independently; - otherwise uses the mean and variance of the whole tensor. Currently, not - supported for per-key operations. - key_vocabulary_filename: (Optional) The file name for the per-key file. If - None, this combiner will assume the keys fit in memory and will not store - the analyzer result in a file. If '', a file name will be chosen based on - the current TensorFlow scope. If not '', it should be unique within a - given preprocessing function. - name: (Optional) A name for this operation. - output_dtype: (Optional) If not None, casts the output tensor to this type. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column - scaled to mean 0 - and variance 1 (standard deviation 1), grouped per key if a key is provided. - - That is, for all keys k: (x - mean(x)) / std_dev(x) for all x with key k. 
- If `x` is floating point, the mean will have the same type as `x`. If `x` is - integral, the output is cast to tf.float32. If the analysis dataset is - empty, contains a single distinct value or the computed key vocabulary - doesn't have an entry for `key`, then the input is returned without scaling. - - Note that TFLearn generally permits only tf.int64 and tf.float32, so casting - this scaler's output may be necessary. - """ - with tf.compat.v1.name_scope(name, 'scale_to_z_score_per_key'): - if key is None: - raise ValueError('key is None, call `tft.scale_to_z_score` instead') - return _scale_to_z_score_internal( - x=x, - key=key, - elementwise=elementwise, - key_vocabulary_filename=key_vocabulary_filename, - output_dtype=output_dtype) + """Returns a standardized column with mean 0 and variance 1, grouped per key. + + Scaling to z-score subtracts out the mean and divides by standard deviation. + Note that the standard deviation computed here is based on the biased variance + (0 delta degrees of freedom), as computed by analyzers.var. + + Args: + ---- + x: A numeric `Tensor`, `SparseTensor`, or `RaggedTensor`. + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype tf.string. Must + meet one of the following conditions: + 0. key is None, + 1. Both x and key are dense, + 2. Both x and key are sparse and `key` must exactly match `x` in + everything except values, + 3. The axis=1 index of each x matches its index of dense key. + elementwise: If true, scales each element of the tensor independently; + otherwise uses the mean and variance of the whole tensor. Currently, not + supported for per-key operations. + key_vocabulary_filename: (Optional) The file name for the per-key file. If + None, this combiner will assume the keys fit in memory and will not store + the analyzer result in a file. If '', a file name will be chosen based on + the current TensorFlow scope. If not '', it should be unique within a + given preprocessing function. + name: (Optional) A name for this operation. + output_dtype: (Optional) If not None, casts the output tensor to this type. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` containing the input column + scaled to mean 0 + and variance 1 (standard deviation 1), grouped per key if a key is provided. + + That is, for all keys k: (x - mean(x)) / std_dev(x) for all x with key k. + If `x` is floating point, the mean will have the same type as `x`. If `x` is + integral, the output is cast to tf.float32. If the analysis dataset is + empty, contains a single distinct value or the computed key vocabulary + doesn't have an entry for `key`, then the input is returned without scaling. + + Note that TFLearn generally permits only tf.int64 and tf.float32, so casting + this scaler's output may be necessary. 
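As with the per-key min-max mapper, the per-key z-score can be modeled with a simple group-by; this NumPy sketch is only an in-memory stand-in for the per-key mean/variance analyzer:

```
import numpy as np

x = np.array([1.0, 3.0, 10.0, 20.0])
keys = np.array(["a", "a", "b", "b"])

z = np.empty_like(x)
for k in np.unique(keys):
    mask = keys == k
    mean, std = x[mask].mean(), x[mask].std()   # biased std, as above
    # A key with a single distinct value would give std == 0; TFT then
    # returns the input unscaled, which this sketch does not handle.
    z[mask] = (x[mask] - mean) / std

print(z)  # [-1.  1. -1.  1.]
```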
+ """ + with tf.compat.v1.name_scope(name, "scale_to_z_score_per_key"): + if key is None: + raise ValueError("key is None, call `tft.scale_to_z_score` instead") + return _scale_to_z_score_internal( + x=x, + key=key, + elementwise=elementwise, + key_vocabulary_filename=key_vocabulary_filename, + output_dtype=output_dtype, + ) def _scale_to_z_score_internal( x: common_types.ConsistentTensorType, - key: Optional[common_types.TensorType], elementwise: bool, + key: Optional[common_types.TensorType], + elementwise: bool, key_vocabulary_filename: Optional[str], - output_dtype: Optional[tf.DType]) -> common_types.ConsistentTensorType: - """Implementation for scale_to_z_score.""" - # x_mean will be float16, float32, or float64, depending on type of x - if key is None: - x_mean, x_var = analyzers._mean_and_var( # pylint: disable=protected-access - x, - reduce_instance_dims=not elementwise, - output_dtype=output_dtype) - else: - if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'Per-key elementwise reduction of Composite Tensors not supported') - - mean_and_var_per_key_result = analyzers._mean_and_var_per_key( # pylint: disable=protected-access - x, - key, - reduce_instance_dims=not elementwise, - key_vocabulary_filename=key_vocabulary_filename, - output_dtype=output_dtype) - - if key_vocabulary_filename is None: - # Missing keys will translate to 0 for both mean and var which will be - # ignored below in the tf.where. - key_vocab, key_means, key_vars = mean_and_var_per_key_result - x_mean, x_var = tf_utils.map_per_key_reductions( - (key_means, key_vars), key, key_vocab, x, not elementwise) + output_dtype: Optional[tf.DType], +) -> common_types.ConsistentTensorType: + """Implementation for scale_to_z_score.""" + # x_mean will be float16, float32, or float64, depending on type of x + if key is None: + x_mean, x_var = analyzers._mean_and_var( # pylint: disable=protected-access + x, reduce_instance_dims=not elementwise, output_dtype=output_dtype + ) else: - if elementwise: - raise NotImplementedError( - 'Elementwise scaling does not support key_vocabulary_filename') - mean_var_for_key = tf_utils.apply_per_key_vocabulary( - mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims) - x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1]) - - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - x_values = tf_utils.get_values(x) - - if isinstance(x, tf.SparseTensor): - if elementwise: - x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices) - x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices) - elif isinstance(x, tf.RaggedTensor): - if elementwise: - raise NotImplementedError( - 'Elementwise scale_to_z_score does not support RaggedTensors') - - numerator = tf.cast(x_values, x_mean.dtype) - x_mean - denominator = tf.sqrt(x_var) - cond = tf.not_equal(denominator, 0) - - if cond.shape.as_list() != x_values.shape.as_list(): - # Repeats cond when necessary across the batch dimension for it to be - # compatible with the shape of numerator. 
- cond = tf.cast( - tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype), - dtype=tf.bool) - - deviation_values = tf.where(cond, tf.divide(numerator, denominator), - numerator) - return compose_result_fn(deviation_values) + if elementwise and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "Per-key elementwise reduction of Composite Tensors not supported" + ) + + mean_and_var_per_key_result = analyzers._mean_and_var_per_key( # pylint: disable=protected-access + x, + key, + reduce_instance_dims=not elementwise, + key_vocabulary_filename=key_vocabulary_filename, + output_dtype=output_dtype, + ) + + if key_vocabulary_filename is None: + # Missing keys will translate to 0 for both mean and var which will be + # ignored below in the tf.where. + key_vocab, key_means, key_vars = mean_and_var_per_key_result + x_mean, x_var = tf_utils.map_per_key_reductions( + (key_means, key_vars), key, key_vocab, x, not elementwise + ) + else: + if elementwise: + raise NotImplementedError( + "Elementwise scaling does not support key_vocabulary_filename" + ) + mean_var_for_key = tf_utils.apply_per_key_vocabulary( + mean_and_var_per_key_result, key, target_ndims=x.get_shape().ndims + ) + x_mean, x_var = (mean_var_for_key[:, 0], mean_var_for_key[:, 1]) + + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + x_values = tf_utils.get_values(x) + + if isinstance(x, tf.SparseTensor): + if elementwise: + x_mean = tf.gather_nd(tf.broadcast_to(x_mean, x.dense_shape), x.indices) + x_var = tf.gather_nd(tf.broadcast_to(x_var, x.dense_shape), x.indices) + elif isinstance(x, tf.RaggedTensor): + if elementwise: + raise NotImplementedError( + "Elementwise scale_to_z_score does not support RaggedTensors" + ) + + numerator = tf.cast(x_values, x_mean.dtype) - x_mean + denominator = tf.sqrt(x_var) + cond = tf.not_equal(denominator, 0) + + if cond.shape.as_list() != x_values.shape.as_list(): + # Repeats cond when necessary across the batch dimension for it to be + # compatible with the shape of numerator. + cond = tf.cast( + tf.zeros_like(numerator) + tf.cast(cond, numerator.dtype), dtype=tf.bool + ) + + deviation_values = tf.where(cond, tf.divide(numerator, denominator), numerator) + return compose_result_fn(deviation_values) @common.log_api_use(common.MAPPER_COLLECTION) def tfidf( - x: tf.SparseTensor, - vocab_size: int, - smooth: bool = True, - name: Optional[str] = None) -> Tuple[tf.SparseTensor, tf.SparseTensor]: - # pyformat: disable - """Maps the terms in x to their term frequency * inverse document frequency. - - The term frequency of a term in a document is calculated as - (count of term in document) / (document size) - - The inverse document frequency of a term is, by default, calculated as - 1 + log((corpus size + 1) / (count of documents containing term + 1)). - - - Example usage: - - >>> def preprocessing_fn(inputs): - ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) - ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) - ... vocab_index, tfidf_weight = tft.tfidf(integerized, vocab_size) - ... return { - ... 'index': vocab_index, - ... 'tf_idf': tfidf_weight, - ... 'integerized': integerized, - ... } - >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), - ... dict(x=["yum", "yum", "pie"])] - >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... 
transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'index': array([0, 2, 3]), 'integerized': array([3, 2, 0, 0, 0]), - 'tf_idf': array([0.6, 0.28109303, 0.28109303], dtype=float32)}, - {'index': array([0, 1]), 'integerized': array([1, 1, 0]), - 'tf_idf': array([0.33333334, 0.9369768 ], dtype=float32)}] - - ``` - example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] - in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2]], - values=[1, 2, 0, 0, 0, 3, 3, 0]) - out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], - values=[1, 2, 0, 3, 0]) - SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], - values=[(1/5)*(log(3/2)+1), (1/5)*(log(3/2)+1), (3/5), - (2/3)*(log(3/2)+1), (1/3)] - ``` - - NOTE: the first doc's duplicate "pie" strings have been combined to - one output, as have the second doc's duplicate "yum" strings. - - Args: - x: A 2D `SparseTensor` representing int64 values (most likely that are the - result of calling `compute_and_apply_vocabulary` on a tokenized string). - vocab_size: An int - the count of vocab used to turn the string into int64s - including any OOV buckets. - smooth: A bool indicating if the inverse document frequency should be - smoothed. If True, which is the default, then the idf is calculated as - 1 + log((corpus size + 1) / (document frequency of term + 1)). - Otherwise, the idf is - 1 +log((corpus size) / (document frequency of term)), which could - result in a division by zero error. - name: (Optional) A name for this operation. - - Returns: - Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. - The first has values vocab_index, which is taken from input `x`. - The second has values tfidf_weight. - - Raises: - ValueError if `x` does not have 2 dimensions. - """ - # pyformat: enable - if x.get_shape().ndims != 2: - raise ValueError('tft.tfidf requires a 2D SparseTensor input. ' - 'Input had {} dimensions.'.format(x.get_shape().ndims)) - - with tf.compat.v1.name_scope(name, 'tfidf'): - cleaned_input = tf_utils.to_vocab_range(x, vocab_size) - - term_frequencies = _to_term_frequency(cleaned_input, vocab_size) - - count_docs_with_term_column = _count_docs_with_term(term_frequencies) - # Expand dims to get around the min_tensor_rank checks - sizes = tf.expand_dims(tf.shape(input=cleaned_input)[0], 0) - # [batch, vocab] - tfidf - tfidfs = _to_tfidf(term_frequencies, - analyzers.sum(count_docs_with_term_column, - reduce_instance_dims=False), - analyzers.sum(sizes), - smooth) - return _split_tfidfs_to_outputs(tfidfs) + x: tf.SparseTensor, vocab_size: int, smooth: bool = True, name: Optional[str] = None +) -> Tuple[tf.SparseTensor, tf.SparseTensor]: + # pyformat: disable + """Maps the terms in x to their term frequency * inverse document frequency. + + The term frequency of a term in a document is calculated as + (count of term in document) / (document size) + + The inverse document frequency of a term is, by default, calculated as + 1 + log((corpus size + 1) / (count of documents containing term + 1)). + + + Example usage: + + >>> def preprocessing_fn(inputs): + ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) + ... vocab_size = tft.get_num_buckets_for_transformed_feature(integerized) + ... vocab_index, tfidf_weight = tft.tfidf(integerized, vocab_size) + ... return { + ... 
'index': vocab_index, + ... 'tf_idf': tfidf_weight, + ... 'integerized': integerized, + ... } + >>> raw_data = [dict(x=["I", "like", "pie", "pie", "pie"]), + ... dict(x=["yum", "yum", "pie"])] + >>> feature_spec = dict(x=tf.io.VarLenFeature(tf.string)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'index': array([0, 2, 3]), 'integerized': array([3, 2, 0, 0, 0]), + 'tf_idf': array([0.6, 0.28109303, 0.28109303], dtype=float32)}, + {'index': array([0, 1]), 'integerized': array([1, 1, 0]), + 'tf_idf': array([0.33333334, 0.9369768 ], dtype=float32)}] + + ``` + example strings: [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie]] + in: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], + [1, 0], [1, 1], [1, 2]], + values=[1, 2, 0, 0, 0, 3, 3, 0]) + out: SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + values=[1, 2, 0, 3, 0]) + SparseTensor(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], + values=[(1/5)*(log(3/2)+1), (1/5)*(log(3/2)+1), (3/5), + (2/3)*(log(3/2)+1), (1/3)] + ``` + + NOTE: the first doc's duplicate "pie" strings have been combined to + one output, as have the second doc's duplicate "yum" strings. + + Args: + ---- + x: A 2D `SparseTensor` representing int64 values (most likely that are the + result of calling `compute_and_apply_vocabulary` on a tokenized string). + vocab_size: An int - the count of vocab used to turn the string into int64s + including any OOV buckets. + smooth: A bool indicating if the inverse document frequency should be + smoothed. If True, which is the default, then the idf is calculated as + 1 + log((corpus size + 1) / (document frequency of term + 1)). + Otherwise, the idf is + 1 +log((corpus size) / (document frequency of term)), which could + result in a division by zero error. + name: (Optional) A name for this operation. + + Returns: + ------- + Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. + The first has values vocab_index, which is taken from input `x`. + The second has values tfidf_weight. + + Raises: + ------ + ValueError if `x` does not have 2 dimensions. + """ + # pyformat: enable + if x.get_shape().ndims != 2: + raise ValueError( + "tft.tfidf requires a 2D SparseTensor input. " + f"Input had {x.get_shape().ndims} dimensions." + ) + + with tf.compat.v1.name_scope(name, "tfidf"): + cleaned_input = tf_utils.to_vocab_range(x, vocab_size) + + term_frequencies = _to_term_frequency(cleaned_input, vocab_size) + + count_docs_with_term_column = _count_docs_with_term(term_frequencies) + # Expand dims to get around the min_tensor_rank checks + sizes = tf.expand_dims(tf.shape(input=cleaned_input)[0], 0) + # [batch, vocab] - tfidf + tfidfs = _to_tfidf( + term_frequencies, + analyzers.sum(count_docs_with_term_column, reduce_instance_dims=False), + analyzers.sum(sizes), + smooth, + ) + return _split_tfidfs_to_outputs(tfidfs) def _split_tfidfs_to_outputs( - tfidfs: tf.SparseTensor) -> Tuple[tf.SparseTensor, tf.SparseTensor]: - """Splits [batch, vocab]-weight into [batch, bow]-vocab & [batch, bow]-tfidf. - - Args: - tfidfs: the `SparseTensor` output of _to_tfidf - - Returns: - Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. 
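The numbers in the tfidf doctest above follow directly from the smoothed formula in the docstring; a quick NumPy check, with the document frequencies read off the two example documents:

```
import numpy as np

docs = [["I", "like", "pie", "pie", "pie"], ["yum", "yum", "pie"]]
corpus_size = len(docs)

# Smoothed idf: 1 + log((corpus_size + 1) / (df + 1)).
df_pie, df_like = 2, 1                      # "pie" in both docs, "like" in one
idf_pie = 1 + np.log((corpus_size + 1) / (df_pie + 1))    # = 1.0
idf_like = 1 + np.log((corpus_size + 1) / (df_like + 1))  # = 1 + log(3/2)

print((3 / 5) * idf_pie)   # 0.6              ("pie" is 3 of doc 1's 5 tokens)
print((1 / 5) * idf_like)  # approx. 0.281093, the float32 0.28109303 above
```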
- The first has values vocab_index, which is taken from input `x`. - The second has values tfidf_weight. - """ - # Split tfidfs tensor into [batch, dummy] -> vocab & [batch, dummy] -> tfidf - # The "dummy" index counts from 0 to the number of unique tokens in the doc. - # So example doc ["I", "like", "pie", "pie", "pie"], with 3 unique tokens, - # will have "dummy" indices [0, 1, 2]. The particular dummy index that any - # token receives is not important, only that the tfidf value and vocab index - # have the *same* dummy index, so that feature_column can apply the weight to - # the correct vocab item. - dummy_index = segment_indices(tfidfs.indices[:, 0]) - out_index = tf.concat( - [tf.expand_dims(tfidfs.indices[:, 0], 1), - tf.expand_dims(dummy_index, 1)], 1) - - out_shape_second_dim = tf.maximum( - tf.reduce_max(input_tensor=dummy_index), -1) + 1 - out_shape = tf.stack([tfidfs.dense_shape[0], out_shape_second_dim]) - out_shape.set_shape([2]) - - de_duped_indicies_out = tf.SparseTensor( # NOTYPO ('indices') - indices=out_index, - values=tfidfs.indices[:, 1], - dense_shape=out_shape) - de_duped_tfidf_out = tf.SparseTensor( - indices=out_index, - values=tfidfs.values, - dense_shape=out_shape) - return de_duped_indicies_out, de_duped_tfidf_out # NOTYPO ('indices') - - -def _to_term_frequency(x: tf.SparseTensor, - vocab_size: Union[int, tf.Tensor]) -> tf.SparseTensor: - """Creates a SparseTensor of term frequency for every doc/term pair. - - Args: - x : a SparseTensor of int64 representing string indices in vocab. - vocab_size: A scalar int64 Tensor - the count of vocab used to turn the - string into int64s including any OOV buckets. - - Returns: - a SparseTensor with the count of times a term appears in a document at - indices , , - with size (num_docs_in_batch, vocab_size). - """ - # Construct intermediary sparse tensor with indices - # [, , ] and tf.ones values. - vocab_size = tf.convert_to_tensor(value=vocab_size, dtype=tf.int64) - split_indices = tf.cast( - tf.split(x.indices, axis=1, num_or_size_splits=2), dtype=tf.int64) - expanded_values = tf.cast(tf.expand_dims(x.values, 1), dtype=tf.int64) - next_index = tf.concat( - [split_indices[0], split_indices[1], expanded_values], axis=1) - - next_values = tf.ones_like(x.values) - expanded_vocab_size = tf.expand_dims(vocab_size, 0) - next_shape = tf.concat( - [x.dense_shape, expanded_vocab_size], 0) - - next_tensor = tf.SparseTensor( - indices=tf.cast(next_index, dtype=tf.int64), - values=next_values, - dense_shape=next_shape) - - # Take the intermediary tensor and reduce over the term_index_in_doc - # dimension. 
This produces a tensor with indices [, ] - # and values [count_of_term_in_doc] and shape batch x vocab_size - term_count_per_doc = tf.compat.v1.sparse_reduce_sum_sparse(next_tensor, 1) - - dense_doc_sizes = tf.cast( - tf.sparse.reduce_sum( - tf.SparseTensor( - indices=x.indices, - values=tf.ones_like(x.values), - dense_shape=x.dense_shape), 1), - dtype=tf.float64) - - gather_indices = term_count_per_doc.indices[:, 0] - gathered_doc_sizes = tf.gather(dense_doc_sizes, gather_indices) - - term_frequency = ( - tf.cast(term_count_per_doc.values, dtype=tf.float64) / - tf.cast(gathered_doc_sizes, dtype=tf.float64)) - return tf.SparseTensor( - indices=term_count_per_doc.indices, - values=term_frequency, - dense_shape=term_count_per_doc.dense_shape) - - -def _to_tfidf(term_frequency: tf.SparseTensor, reduced_term_freq: tf.Tensor, - corpus_size: tf.Tensor, smooth: bool) -> tf.SparseTensor: - """Calculates the inverse document frequency of terms in the corpus. - - Args: - term_frequency: The `SparseTensor` output of _to_term_frequency. - reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the - count of the number of documents with each term. - corpus_size: A scalar count of the number of documents in the corpus. - smooth: A bool indicating if the idf value should be smoothed. See - tfidf_weights documentation for details. - - Returns: - A `SparseTensor` with indices=, , - values=term frequency * inverse document frequency, - and shape=(batch, vocab_size) - """ - # The idf tensor has shape (vocab_size,) - idf = tf_utils.document_frequency_to_idf( - reduced_term_freq, corpus_size, smooth=smooth, add_baseline=True) - gathered_idfs = tf.gather(tf.squeeze(idf), term_frequency.indices[:, 1]) - tfidf_values = (tf.cast(term_frequency.values, tf.float32) - * tf.cast(gathered_idfs, tf.float32)) - - return tf.SparseTensor( - indices=term_frequency.indices, - values=tfidf_values, - dense_shape=term_frequency.dense_shape) + tfidfs: tf.SparseTensor, +) -> Tuple[tf.SparseTensor, tf.SparseTensor]: + """Splits [batch, vocab]-weight into [batch, bow]-vocab & [batch, bow]-tfidf. + + Args: + ---- + tfidfs: the `SparseTensor` output of _to_tfidf + + Returns: + ------- + Two `SparseTensor`s with indices [index_in_batch, index_in_bag_of_words]. + The first has values vocab_index, which is taken from input `x`. + The second has values tfidf_weight. + """ + # Split tfidfs tensor into [batch, dummy] -> vocab & [batch, dummy] -> tfidf + # The "dummy" index counts from 0 to the number of unique tokens in the doc. + # So example doc ["I", "like", "pie", "pie", "pie"], with 3 unique tokens, + # will have "dummy" indices [0, 1, 2]. The particular dummy index that any + # token receives is not important, only that the tfidf value and vocab index + # have the *same* dummy index, so that feature_column can apply the weight to + # the correct vocab item. 
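The "dummy index" idea in the comment above can be sketched in plain Python (toy values, loosely following the doctest): each document's (vocab_index, weight) pairs are written out in parallel under a shared within-row position.

```
doc_weights = {0: {3: 0.6, 2: 0.281, 0: 0.281}}   # doc -> {vocab_index: tfidf}

indices, vocab_out, weight_out = [], [], []
for doc, weights in doc_weights.items():
    for dummy, (vocab_index, w) in enumerate(weights.items()):
        indices.append((doc, dummy))   # the dummy column just keeps alignment
        vocab_out.append(vocab_index)
        weight_out.append(w)

print(indices)     # [(0, 0), (0, 1), (0, 2)]
print(vocab_out)   # [3, 2, 0]
print(weight_out)  # [0.6, 0.281, 0.281]
```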
+ dummy_index = segment_indices(tfidfs.indices[:, 0]) + out_index = tf.concat( + [tf.expand_dims(tfidfs.indices[:, 0], 1), tf.expand_dims(dummy_index, 1)], 1 + ) + + out_shape_second_dim = tf.maximum(tf.reduce_max(input_tensor=dummy_index), -1) + 1 + out_shape = tf.stack([tfidfs.dense_shape[0], out_shape_second_dim]) + out_shape.set_shape([2]) + + de_duped_indicies_out = tf.SparseTensor( # NOTYPO ('indices') + indices=out_index, values=tfidfs.indices[:, 1], dense_shape=out_shape + ) + de_duped_tfidf_out = tf.SparseTensor( + indices=out_index, values=tfidfs.values, dense_shape=out_shape + ) + return de_duped_indicies_out, de_duped_tfidf_out # NOTYPO ('indices') + + +def _to_term_frequency( + x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor] +) -> tf.SparseTensor: + """Creates a SparseTensor of term frequency for every doc/term pair. + + Args: + ---- + x : a SparseTensor of int64 representing string indices in vocab. + vocab_size: A scalar int64 Tensor - the count of vocab used to turn the + string into int64s including any OOV buckets. + + Returns: + ------- + a SparseTensor with the count of times a term appears in a document at + indices , , + with size (num_docs_in_batch, vocab_size). + """ + # Construct intermediary sparse tensor with indices + # [, , ] and tf.ones values. + vocab_size = tf.convert_to_tensor(value=vocab_size, dtype=tf.int64) + split_indices = tf.cast( + tf.split(x.indices, axis=1, num_or_size_splits=2), dtype=tf.int64 + ) + expanded_values = tf.cast(tf.expand_dims(x.values, 1), dtype=tf.int64) + next_index = tf.concat( + [split_indices[0], split_indices[1], expanded_values], axis=1 + ) + + next_values = tf.ones_like(x.values) + expanded_vocab_size = tf.expand_dims(vocab_size, 0) + next_shape = tf.concat([x.dense_shape, expanded_vocab_size], 0) + + next_tensor = tf.SparseTensor( + indices=tf.cast(next_index, dtype=tf.int64), + values=next_values, + dense_shape=next_shape, + ) + + # Take the intermediary tensor and reduce over the term_index_in_doc + # dimension. This produces a tensor with indices [, ] + # and values [count_of_term_in_doc] and shape batch x vocab_size + term_count_per_doc = tf.compat.v1.sparse_reduce_sum_sparse(next_tensor, 1) + + dense_doc_sizes = tf.cast( + tf.sparse.reduce_sum( + tf.SparseTensor( + indices=x.indices, + values=tf.ones_like(x.values), + dense_shape=x.dense_shape, + ), + 1, + ), + dtype=tf.float64, + ) + + gather_indices = term_count_per_doc.indices[:, 0] + gathered_doc_sizes = tf.gather(dense_doc_sizes, gather_indices) + + term_frequency = tf.cast(term_count_per_doc.values, dtype=tf.float64) / tf.cast( + gathered_doc_sizes, dtype=tf.float64 + ) + return tf.SparseTensor( + indices=term_count_per_doc.indices, + values=term_frequency, + dense_shape=term_count_per_doc.dense_shape, + ) + + +def _to_tfidf( + term_frequency: tf.SparseTensor, + reduced_term_freq: tf.Tensor, + corpus_size: tf.Tensor, + smooth: bool, +) -> tf.SparseTensor: + """Calculates the inverse document frequency of terms in the corpus. + + Args: + ---- + term_frequency: The `SparseTensor` output of _to_term_frequency. + reduced_term_freq: A `Tensor` of shape (vocabSize,) that represents the + count of the number of documents with each term. + corpus_size: A scalar count of the number of documents in the corpus. + smooth: A bool indicating if the idf value should be smoothed. See + tfidf_weights documentation for details. 
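# A small worked sketch of the value computed per (doc, term) pair: the term
# frequency from _to_term_frequency times the idf. The idf below is a
# placeholder number; the real value comes from
# tf_utils.document_frequency_to_idf and depends on corpus_size and smoothing.
term_frequency_of_pie = 2.0 / 3.0  # "pie" appears twice in a 3-token document
idf_of_pie = 1.5                   # placeholder idf for "pie"
tfidf_of_pie = term_frequency_of_pie * idf_of_pie  # ~= 1.0 for this pair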
+ + Returns: + ------- + A `SparseTensor` with indices=, , + values=term frequency * inverse document frequency, + and shape=(batch, vocab_size) + """ + # The idf tensor has shape (vocab_size,) + idf = tf_utils.document_frequency_to_idf( + reduced_term_freq, corpus_size, smooth=smooth, add_baseline=True + ) + gathered_idfs = tf.gather(tf.squeeze(idf), term_frequency.indices[:, 1]) + tfidf_values = tf.cast(term_frequency.values, tf.float32) * tf.cast( + gathered_idfs, tf.float32 + ) + + return tf.SparseTensor( + indices=term_frequency.indices, + values=tfidf_values, + dense_shape=term_frequency.dense_shape, + ) def _count_docs_with_term(term_frequency: tf.SparseTensor) -> tf.Tensor: - """Computes the number of documents in a batch that contain each term. + """Computes the number of documents in a batch that contain each term. - Args: - term_frequency: The `SparseTensor` output of _to_term_frequency. + Args: + ---- + term_frequency: The `SparseTensor` output of _to_term_frequency. - Returns: - A `Tensor` of shape (vocab_size,) that contains the number of documents in - the batch that contain each term. - """ - count_of_doc_inter = tf.SparseTensor( - indices=term_frequency.indices, - values=tf.ones_like(term_frequency.values), - dense_shape=term_frequency.dense_shape) - return tf.sparse.reduce_sum(count_of_doc_inter, axis=0, keepdims=True) + Returns: + ------- + A `Tensor` of shape (vocab_size,) that contains the number of documents in + the batch that contain each term. + """ + count_of_doc_inter = tf.SparseTensor( + indices=term_frequency.indices, + values=tf.ones_like(term_frequency.values), + dense_shape=term_frequency.dense_shape, + ) + return tf.sparse.reduce_sum(count_of_doc_inter, axis=0, keepdims=True) @common.log_api_use(common.MAPPER_COLLECTION) @@ -941,128 +1021,131 @@ def compute_and_apply_vocabulary( reserved_tokens: Optional[Union[Iterable[str], tf.Tensor]] = None, name: Optional[str] = None, ) -> common_types.ConsistentTensorType: - r"""Generates a vocabulary for `x` and maps it to an integer with this vocab. - - In case one of the tokens contains the '\n' or '\r' characters or is empty it - will be discarded since we are currently writing the vocabularies as text - files. This behavior will likely be fixed/improved in the future. - - Note that this function will cause a vocabulary to be computed. For large - datasets it is highly recommended to either set frequency_threshold or top_k - to control the size of the vocabulary, and also the run time of this - operation. - - Args: - x: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type tf.string or - tf.int[8|16|32|64]. - default_value: The value to use for out-of-vocabulary values, unless - 'num_oov_buckets' is greater than zero. - top_k: Limit the generated vocabulary to the first `top_k` elements. If set - to None, the full vocabulary is generated. - frequency_threshold: Limit the generated vocabulary only to elements whose - absolute frequency is >= to the supplied threshold. If set to None, the - full vocabulary is generated. Absolute frequency means the number of - occurences of the element in the dataset, as opposed to the proportion of - instances that contain that element. If labels are provided and the vocab - is computed using mutual information, tokens are filtered if their mutual - information with the label is < the supplied threshold. - num_oov_buckets: Any lookup of an out-of-vocabulary token will return a - bucket ID based on its hash if `num_oov_buckets` is greater than zero. 
- Otherwise it is assigned the `default_value`. - vocab_filename: The file name for the vocabulary file. If None, a name based - on the scope name in the context of this graph will be used as the file - name. If not None, should be unique within a given preprocessing function. - NOTE in order to make your pipelines resilient to implementation details - please set `vocab_filename` when you are using the vocab_filename on a - downstream component. - weights: (Optional) Weights `Tensor` for the vocabulary. It must have the - same shape as x. - labels: (Optional) A `Tensor` of labels for the vocabulary. If provided, the - vocabulary is calculated based on mutual information with the label, - rather than frequency. The labels must have the same batch dimension as x. - If x is sparse, labels should be a 1D tensor reflecting row-wise labels. - If x is dense, labels can either be a 1D tensor of row-wise labels, or a - dense tensor of the identical shape as x (i.e. element-wise labels). - Labels should be a discrete integerized tensor (If the label is numeric, - it should first be bucketized; If the label is a string, an integer - vocabulary should first be applied). Note: `CompositeTensor` labels are - not yet supported (b/134931826). WARNING: when labels are provided, the - frequency_threshold argument functions as a mutual information threshold, - which is a float. TODO(b/116308354): Fix confusing naming. - use_adjusted_mutual_info: If true, use adjusted mutual information. - min_diff_from_avg: Mutual information of a feature will be adjusted to zero - whenever the difference between count of the feature with any label and - its expected count is lower than min_diff_from_average. - coverage_top_k: (Optional), (Experimental) The minimum number of elements - per key to be included in the vocabulary. - coverage_frequency_threshold: (Optional), (Experimental) Limit the coverage - arm of the vocabulary only to elements whose absolute frequency is >= this - threshold for a given key. - key_fn: (Optional), (Experimental) A fn that takes in a single entry of `x` - and returns the corresponding key for coverage calculation. If this is - `None`, no coverage arm is added to the vocabulary. - fingerprint_shuffle: (Optional), (Experimental) Whether to sort the - vocabularies by fingerprint instead of counts. This is useful for load - balancing on the training parameter servers. Shuffle only happens while - writing the files, so all the filters above will still take effect. - file_format: (Optional) A str. The format of the resulting vocabulary file. - Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires - tensorflow>=2.4. The default value is 'text'. - store_frequency: If True, frequency of the words is stored in the vocabulary - file. In the case labels are provided, the mutual information is stored in - the file instead. Each line in the file will be of the form 'frequency - word'. NOTE: if True and text_format is 'text' then spaces will be - replaced to avoid information loss. - reserved_tokens: (Optional) A list of tokens that should appear in the - vocabulary regardless of their appearance in the input. These tokens would - maintain their order, and have a reserved spot at the beginning of the - vocabulary. Note: this field has no affect on cache. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is - mapped to an integer. 
Each unique string value that appears in the - vocabulary is mapped to a different integer and integers are consecutive - starting from zero. String value not in the vocabulary is assigned - `default_value`. Alternatively, if `num_oov_buckets` is specified, out of - vocabulary strings are hashed to values in - [vocab_size, vocab_size + num_oov_buckets) for an overall range of - [0, vocab_size + num_oov_buckets). - - Raises: - ValueError: If `top_k` or `frequency_threshold` is negative. - If `coverage_top_k` or `coverage_frequency_threshold` is negative. - """ - with tf.compat.v1.name_scope(name, 'compute_and_apply_vocabulary'): - if store_frequency and file_format == 'text': - x = tf_utils.maybe_format_vocabulary_input(x) - deferred_vocab_and_filename = analyzers.vocabulary( - x=x, - top_k=top_k, - frequency_threshold=frequency_threshold, - vocab_filename=vocab_filename, - store_frequency=store_frequency, - weights=weights, - labels=labels, - use_adjusted_mutual_info=use_adjusted_mutual_info, - min_diff_from_avg=min_diff_from_avg, - coverage_top_k=coverage_top_k, - coverage_frequency_threshold=coverage_frequency_threshold, - key_fn=key_fn, - fingerprint_shuffle=fingerprint_shuffle, - file_format=file_format, - reserved_tokens=reserved_tokens, - ) - return _apply_vocabulary_internal( - x, - deferred_vocab_and_filename, - default_value, - num_oov_buckets, - lookup_fn=None, - store_frequency=store_frequency, - file_format=file_format, - name=None, - ) + r"""Generates a vocabulary for `x` and maps it to an integer with this vocab. + + In case one of the tokens contains the '\n' or '\r' characters or is empty it + will be discarded since we are currently writing the vocabularies as text + files. This behavior will likely be fixed/improved in the future. + + Note that this function will cause a vocabulary to be computed. For large + datasets it is highly recommended to either set frequency_threshold or top_k + to control the size of the vocabulary, and also the run time of this + operation. + + Args: + ---- + x: A `Tensor`, `SparseTensor`, or `RaggedTensor` of type tf.string or + tf.int[8|16|32|64]. + default_value: The value to use for out-of-vocabulary values, unless + 'num_oov_buckets' is greater than zero. + top_k: Limit the generated vocabulary to the first `top_k` elements. If set + to None, the full vocabulary is generated. + frequency_threshold: Limit the generated vocabulary only to elements whose + absolute frequency is >= to the supplied threshold. If set to None, the + full vocabulary is generated. Absolute frequency means the number of + occurences of the element in the dataset, as opposed to the proportion of + instances that contain that element. If labels are provided and the vocab + is computed using mutual information, tokens are filtered if their mutual + information with the label is < the supplied threshold. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + vocab_filename: The file name for the vocabulary file. If None, a name based + on the scope name in the context of this graph will be used as the file + name. If not None, should be unique within a given preprocessing function. + NOTE in order to make your pipelines resilient to implementation details + please set `vocab_filename` when you are using the vocab_filename on a + downstream component. + weights: (Optional) Weights `Tensor` for the vocabulary. 
It must have the + same shape as x. + labels: (Optional) A `Tensor` of labels for the vocabulary. If provided, the + vocabulary is calculated based on mutual information with the label, + rather than frequency. The labels must have the same batch dimension as x. + If x is sparse, labels should be a 1D tensor reflecting row-wise labels. + If x is dense, labels can either be a 1D tensor of row-wise labels, or a + dense tensor of the identical shape as x (i.e. element-wise labels). + Labels should be a discrete integerized tensor (If the label is numeric, + it should first be bucketized; If the label is a string, an integer + vocabulary should first be applied). Note: `CompositeTensor` labels are + not yet supported (b/134931826). WARNING: when labels are provided, the + frequency_threshold argument functions as a mutual information threshold, + which is a float. TODO(b/116308354): Fix confusing naming. + use_adjusted_mutual_info: If true, use adjusted mutual information. + min_diff_from_avg: Mutual information of a feature will be adjusted to zero + whenever the difference between count of the feature with any label and + its expected count is lower than min_diff_from_average. + coverage_top_k: (Optional), (Experimental) The minimum number of elements + per key to be included in the vocabulary. + coverage_frequency_threshold: (Optional), (Experimental) Limit the coverage + arm of the vocabulary only to elements whose absolute frequency is >= this + threshold for a given key. + key_fn: (Optional), (Experimental) A fn that takes in a single entry of `x` + and returns the corresponding key for coverage calculation. If this is + `None`, no coverage arm is added to the vocabulary. + fingerprint_shuffle: (Optional), (Experimental) Whether to sort the + vocabularies by fingerprint instead of counts. This is useful for load + balancing on the training parameter servers. Shuffle only happens while + writing the files, so all the filters above will still take effect. + file_format: (Optional) A str. The format of the resulting vocabulary file. + Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires + tensorflow>=2.4. The default value is 'text'. + store_frequency: If True, frequency of the words is stored in the vocabulary + file. In the case labels are provided, the mutual information is stored in + the file instead. Each line in the file will be of the form 'frequency + word'. NOTE: if True and text_format is 'text' then spaces will be + replaced to avoid information loss. + reserved_tokens: (Optional) A list of tokens that should appear in the + vocabulary regardless of their appearance in the input. These tokens would + maintain their order, and have a reserved spot at the beginning of the + vocabulary. Note: this field has no affect on cache. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is + mapped to an integer. Each unique string value that appears in the + vocabulary is mapped to a different integer and integers are consecutive + starting from zero. String value not in the vocabulary is assigned + `default_value`. Alternatively, if `num_oov_buckets` is specified, out of + vocabulary strings are hashed to values in + [vocab_size, vocab_size + num_oov_buckets) for an overall range of + [0, vocab_size + num_oov_buckets). + + Raises: + ------ + ValueError: If `top_k` or `frequency_threshold` is negative. + If `coverage_top_k` or `coverage_frequency_threshold` is negative. 
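# A minimal usage sketch for this mapper inside a preprocessing_fn; the
# feature name "terms" and the top_k / num_oov_buckets values are arbitrary
# example choices, not defaults.
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    return {
        "terms_integerized": tft.compute_and_apply_vocabulary(
            inputs["terms"],
            top_k=10000,
            num_oov_buckets=1,
            vocab_filename="terms_vocab",
        ),
    }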
+ """ + with tf.compat.v1.name_scope(name, "compute_and_apply_vocabulary"): + if store_frequency and file_format == "text": + x = tf_utils.maybe_format_vocabulary_input(x) + deferred_vocab_and_filename = analyzers.vocabulary( + x=x, + top_k=top_k, + frequency_threshold=frequency_threshold, + vocab_filename=vocab_filename, + store_frequency=store_frequency, + weights=weights, + labels=labels, + use_adjusted_mutual_info=use_adjusted_mutual_info, + min_diff_from_avg=min_diff_from_avg, + coverage_top_k=coverage_top_k, + coverage_frequency_threshold=coverage_frequency_threshold, + key_fn=key_fn, + fingerprint_shuffle=fingerprint_shuffle, + file_format=file_format, + reserved_tokens=reserved_tokens, + ) + return _apply_vocabulary_internal( + x, + deferred_vocab_and_filename, + default_value, + num_oov_buckets, + lookup_fn=None, + store_frequency=store_frequency, + file_format=file_format, + name=None, + ) @common.log_api_use(common.MAPPER_COLLECTION) @@ -1072,97 +1155,97 @@ def apply_vocabulary( *, # Force passing optional parameters by keys. default_value: Any = -1, num_oov_buckets: int = 0, - lookup_fn: Optional[Callable[[common_types.TensorType, tf.Tensor], - Tuple[tf.Tensor, tf.Tensor]]] = None, - file_format: common_types.VocabularyFileFormatType = analyzers - .DEFAULT_VOCABULARY_FILE_FORMAT, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - r"""Maps `x` to a vocabulary specified by the deferred tensor. - - This function also writes domain statistics about the vocabulary min and max - values. Note that the min and max are inclusive, and depend on the vocab size, - num_oov_buckets and default_value. - - Args: - x: A categorical `Tensor`, `SparseTensor`, or `RaggedTensor` of type - tf.string or tf.int[8|16|32|64] to which the vocabulary transformation - should be applied. The column names are those intended for the transformed - tensors. - deferred_vocab_filename_tensor: The deferred vocab filename tensor as - returned by `tft.vocabulary`, as long as the frequencies were not stored. - default_value: The value to use for out-of-vocabulary values, unless - 'num_oov_buckets' is greater than zero. - num_oov_buckets: Any lookup of an out-of-vocabulary token will return a - bucket ID based on its hash if `num_oov_buckets` is greater than zero. - Otherwise it is assigned the `default_value`. - lookup_fn: Optional lookup function, if specified it should take a tensor - and a deferred vocab filename as an input and return a lookup `op` along - with the table size, by default `apply_vocabulary` constructs a - StaticHashTable for the table lookup. - file_format: (Optional) A str. The format of the given vocabulary. Accepted - formats are: 'tfrecord_gzip', 'text'. The default value is 'text'. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is - mapped to an integer. Each unique string value that appears in the - vocabulary is mapped to a different integer and integers are consecutive - starting from zero, and string value not in the vocabulary is - assigned default_value. 
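# A sketch of the two-step form: compute the vocabulary once with
# tft.vocabulary and apply it to several features. Feature names here are
# hypothetical.
import tensorflow_transform as tft


def preprocessing_fn(inputs):
    vocab_path = tft.vocabulary(inputs["query_tokens"], vocab_filename="shared_vocab")
    return {
        "query_ids": tft.apply_vocabulary(inputs["query_tokens"], vocab_path),
        "title_ids": tft.apply_vocabulary(inputs["title_tokens"], vocab_path),
    }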
- """ - return _apply_vocabulary_internal( - x, - deferred_vocab_filename_tensor, - default_value, - num_oov_buckets, - lookup_fn, - file_format, - False, - name, - ) + lookup_fn: Optional[ + Callable[[common_types.TensorType, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]] + ] = None, + file_format: common_types.VocabularyFileFormatType = analyzers.DEFAULT_VOCABULARY_FILE_FORMAT, + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + r"""Maps `x` to a vocabulary specified by the deferred tensor. + + This function also writes domain statistics about the vocabulary min and max + values. Note that the min and max are inclusive, and depend on the vocab size, + num_oov_buckets and default_value. + + Args: + ---- + x: A categorical `Tensor`, `SparseTensor`, or `RaggedTensor` of type + tf.string or tf.int[8|16|32|64] to which the vocabulary transformation + should be applied. The column names are those intended for the transformed + tensors. + deferred_vocab_filename_tensor: The deferred vocab filename tensor as + returned by `tft.vocabulary`, as long as the frequencies were not stored. + default_value: The value to use for out-of-vocabulary values, unless + 'num_oov_buckets' is greater than zero. + num_oov_buckets: Any lookup of an out-of-vocabulary token will return a + bucket ID based on its hash if `num_oov_buckets` is greater than zero. + Otherwise it is assigned the `default_value`. + lookup_fn: Optional lookup function, if specified it should take a tensor + and a deferred vocab filename as an input and return a lookup `op` along + with the table size, by default `apply_vocabulary` constructs a + StaticHashTable for the table lookup. + file_format: (Optional) A str. The format of the given vocabulary. Accepted + formats are: 'tfrecord_gzip', 'text'. The default value is 'text'. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` where each string value is + mapped to an integer. Each unique string value that appears in the + vocabulary is mapped to a different integer and integers are consecutive + starting from zero, and string value not in the vocabulary is + assigned default_value. + """ + return _apply_vocabulary_internal( + x, + deferred_vocab_filename_tensor, + default_value, + num_oov_buckets, + lookup_fn, + file_format, + False, + name, + ) def _make_construct_vocabulary_table_function( x, file_format, num_oov_buckets, default_value, store_frequency ): - """Defines a function to construct a vocabulary lookup table.""" - def construct_table(asset_filepath): - if file_format == 'tfrecord_gzip': - initializer = tf_utils.make_tfrecord_vocabulary_lookup_initializer( - asset_filepath, - x.dtype, - return_indicator_as_value=False, - has_indicator=store_frequency, - ) - elif file_format == 'text': - key_index = 1 if store_frequency else tf.lookup.TextFileIndex.WHOLE_LINE - kwargs = {'delimiter': ' '} if store_frequency else {} - initializer = tf.lookup.TextFileInitializer( - asset_filepath, - key_dtype=x.dtype, - key_index=key_index, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER, - **kwargs, - ) - else: - raise ValueError( - '"{}" is not an accepted file_format. 
It should be one of: {}'.format( - file_format, analyzers.ALLOWED_VOCABULARY_FILE_FORMATS - ) - ) - - if num_oov_buckets > 0: - table = tf.lookup.StaticVocabularyTable( - initializer, num_oov_buckets=num_oov_buckets, lookup_key_dtype=x.dtype - ) - else: - table = tf.lookup.StaticHashTable( - initializer, default_value=default_value - ) - return table - - return construct_table + """Defines a function to construct a vocabulary lookup table.""" + + def construct_table(asset_filepath): + if file_format == "tfrecord_gzip": + initializer = tf_utils.make_tfrecord_vocabulary_lookup_initializer( + asset_filepath, + x.dtype, + return_indicator_as_value=False, + has_indicator=store_frequency, + ) + elif file_format == "text": + key_index = 1 if store_frequency else tf.lookup.TextFileIndex.WHOLE_LINE + kwargs = {"delimiter": " "} if store_frequency else {} + initializer = tf.lookup.TextFileInitializer( + asset_filepath, + key_dtype=x.dtype, + key_index=key_index, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + **kwargs, + ) + else: + raise ValueError( + f'"{file_format}" is not an accepted file_format. It should be one of: {analyzers.ALLOWED_VOCABULARY_FILE_FORMATS}' + ) + + if num_oov_buckets > 0: + table = tf.lookup.StaticVocabularyTable( + initializer, num_oov_buckets=num_oov_buckets, lookup_key_dtype=x.dtype + ) + else: + table = tf.lookup.StaticHashTable(initializer, default_value=default_value) + return table + + return construct_table def _apply_vocabulary_internal( @@ -1171,533 +1254,582 @@ def _apply_vocabulary_internal( default_value: Any, num_oov_buckets: int, lookup_fn: Optional[ - Callable[ - [common_types.TensorType, tf.Tensor], Tuple[tf.Tensor, tf.Tensor] - ] + Callable[[common_types.TensorType, tf.Tensor], Tuple[tf.Tensor, tf.Tensor]] ], file_format: common_types.VocabularyFileFormatType, store_frequency: bool, name: Optional[str], ) -> common_types.ConsistentTensorType: - """See apply_vocabulary doc.""" - with tf.compat.v1.name_scope(name, 'apply_vocab'): - if x.dtype != tf.string and not x.dtype.is_integer: - raise ValueError('expected tf.string or tf.int[8|16|32|64] but got %r' % - x.dtype) - - if lookup_fn: - result, table_size = tf_utils.lookup_table( - lookup_fn, deferred_vocab_filename_tensor, x) - else: - if (deferred_vocab_filename_tensor is None or - (isinstance(deferred_vocab_filename_tensor, - (bytes, str)) and not deferred_vocab_filename_tensor)): - raise ValueError('`deferred_vocab_filename_tensor` must not be empty.') - construct_table = _make_construct_vocabulary_table_function( - x, file_format, num_oov_buckets, default_value, store_frequency - ) - x_values = tf_utils.get_values(x) - result, table_size = tf_utils.construct_and_lookup_table( - construct_table, deferred_vocab_filename_tensor, x_values - ) - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - result = compose_result_fn(result) - - # Specify schema overrides which will override the values in the schema - # with the min and max values, which are deferred as they are only known - # once the analyzer has run. - # - # `table_size` includes the num oov buckets. The default value is only used - # if num_oov_buckets <= 0. 
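# Illustrative arithmetic for the recorded [min, max] range, using made-up sizes.
vocab_size, default_value = 5, -1
for num_oov_buckets in (2, 0):
    table_size = vocab_size + num_oov_buckets   # the table includes any OOV buckets
    min_value, max_value = 0, table_size - 1    # [0, 6] when num_oov_buckets == 2
    if num_oov_buckets <= 0:                    # no OOV buckets: widen to cover default_value
        min_value = min(min_value, default_value)
        max_value = max(max_value, default_value)
    print(num_oov_buckets, (min_value, max_value))  # 2 -> (0, 6); 0 -> (-1, 4)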
- min_value = tf.constant(0, tf.int64) - max_value = table_size - 1 - if num_oov_buckets <= 0: - min_value = tf.minimum(min_value, default_value) - max_value = tf.maximum(max_value, default_value) - schema_inference.set_tensor_schema_override( - tf_utils.get_values(result), min_value, max_value) - return result + """See apply_vocabulary doc.""" + with tf.compat.v1.name_scope(name, "apply_vocab"): + if x.dtype != tf.string and not x.dtype.is_integer: + raise ValueError( + "expected tf.string or tf.int[8|16|32|64] but got %r" % x.dtype + ) + + if lookup_fn: + result, table_size = tf_utils.lookup_table( + lookup_fn, deferred_vocab_filename_tensor, x + ) + else: + if deferred_vocab_filename_tensor is None or ( + isinstance(deferred_vocab_filename_tensor, (bytes, str)) + and not deferred_vocab_filename_tensor + ): + raise ValueError("`deferred_vocab_filename_tensor` must not be empty.") + construct_table = _make_construct_vocabulary_table_function( + x, file_format, num_oov_buckets, default_value, store_frequency + ) + x_values = tf_utils.get_values(x) + result, table_size = tf_utils.construct_and_lookup_table( + construct_table, deferred_vocab_filename_tensor, x_values + ) + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + result = compose_result_fn(result) + + # Specify schema overrides which will override the values in the schema + # with the min and max values, which are deferred as they are only known + # once the analyzer has run. + # + # `table_size` includes the num oov buckets. The default value is only used + # if num_oov_buckets <= 0. + min_value = tf.constant(0, tf.int64) + max_value = table_size - 1 + if num_oov_buckets <= 0: + min_value = tf.minimum(min_value, default_value) + max_value = tf.maximum(max_value, default_value) + schema_inference.set_tensor_schema_override( + tf_utils.get_values(result), min_value, max_value + ) + return result @common.log_api_use(common.MAPPER_COLLECTION) def get_num_buckets_for_transformed_feature( - transformed_feature: common_types.TensorType) -> tf.Tensor: - # pyformat: disable - """Provides the number of buckets for a transformed feature if annotated. - - This for example can be used for the direct output of `tft.bucketize`, - `tft.apply_buckets`, `tft.compute_and_apply_vocabulary`, - `tft.apply_vocabulary`. - These methods annotate the transformed feature with additional information. - If the given `transformed_feature` isn't annotated, this method will fail. - - Example: - - >>> def preprocessing_fn(inputs): - ... bucketized = tft.bucketize(inputs['x'], num_buckets=3) - ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) - ... zeros = tf.zeros_like(inputs['x'], tf.int64) - ... return { - ... 'bucketized': bucketized, - ... 'bucketized_num_buckets': ( - ... zeros + tft.get_num_buckets_for_transformed_feature(bucketized)), - ... 'integerized': integerized, - ... 'integerized_num_buckets': ( - ... zeros + tft.get_num_buckets_for_transformed_feature(integerized)), - ... } - >>> raw_data = [dict(x=3),dict(x=23)] - >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) - >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) - >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): - ... transformed_dataset, transform_fn = ( - ... (raw_data, raw_data_metadata) - ... 
| tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) - >>> transformed_data, transformed_metadata = transformed_dataset - >>> transformed_data - [{'bucketized': 1, 'bucketized_num_buckets': 3, - 'integerized': 0, 'integerized_num_buckets': 2}, - {'bucketized': 2, 'bucketized_num_buckets': 3, - 'integerized': 1, 'integerized_num_buckets': 2}] - - Args: - transformed_feature: A `Tensor` or `SparseTensor` which is the direct output - of `tft.bucketize`, `tft.apply_buckets`, - `tft.compute_and_apply_vocabulary` or `tft.apply_vocabulary`. - - Raises: - ValueError: If the given tensor has not been annotated a the number of - buckets. - - Returns: - A `Tensor` with the number of buckets for the given `transformed_feature`. - """ - # pyformat: enable - # Adding 1 to the 2nd Tensor of the returned pair in order to compute max + 1. - return tf.cast( - schema_inference.get_tensor_schema_override(transformed_feature)[1] + 1, - tf.int64) + transformed_feature: common_types.TensorType, +) -> tf.Tensor: + # pyformat: disable + """Provides the number of buckets for a transformed feature if annotated. + + This for example can be used for the direct output of `tft.bucketize`, + `tft.apply_buckets`, `tft.compute_and_apply_vocabulary`, + `tft.apply_vocabulary`. + These methods annotate the transformed feature with additional information. + If the given `transformed_feature` isn't annotated, this method will fail. + + Example: + ------- + >>> def preprocessing_fn(inputs): + ... bucketized = tft.bucketize(inputs['x'], num_buckets=3) + ... integerized = tft.compute_and_apply_vocabulary(inputs['x']) + ... zeros = tf.zeros_like(inputs['x'], tf.int64) + ... return { + ... 'bucketized': bucketized, + ... 'bucketized_num_buckets': ( + ... zeros + tft.get_num_buckets_for_transformed_feature(bucketized)), + ... 'integerized': integerized, + ... 'integerized_num_buckets': ( + ... zeros + tft.get_num_buckets_for_transformed_feature(integerized)), + ... } + >>> raw_data = [dict(x=3),dict(x=23)] + >>> feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64)) + >>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec) + >>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()): + ... transformed_dataset, transform_fn = ( + ... (raw_data, raw_data_metadata) + ... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn)) + >>> transformed_data, transformed_metadata = transformed_dataset + >>> transformed_data + [{'bucketized': 1, 'bucketized_num_buckets': 3, + 'integerized': 0, 'integerized_num_buckets': 2}, + {'bucketized': 2, 'bucketized_num_buckets': 3, + 'integerized': 1, 'integerized_num_buckets': 2}] + + Args: + ---- + transformed_feature: A `Tensor` or `SparseTensor` which is the direct output + of `tft.bucketize`, `tft.apply_buckets`, + `tft.compute_and_apply_vocabulary` or `tft.apply_vocabulary`. + + Raises: + ------ + ValueError: If the given tensor has not been annotated a the number of + buckets. + + Returns: + ------- + A `Tensor` with the number of buckets for the given `transformed_feature`. + """ + # pyformat: enable + # Adding 1 to the 2nd Tensor of the returned pair in order to compute max + 1. + return tf.cast( + schema_inference.get_tensor_schema_override(transformed_feature)[1] + 1, + tf.int64, + ) @common.log_api_use(common.MAPPER_COLLECTION) -def segment_indices(segment_ids: tf.Tensor, - name: Optional[str] = None) -> tf.Tensor: - """Returns a `Tensor` of indices within each segment. 
- - segment_ids should be a sequence of non-decreasing non-negative integers that - define a set of segments, e.g. [0, 0, 1, 2, 2, 2] defines 3 segments of length - 2, 1 and 3. The return value is a `Tensor` containing the indices within each - segment. - - Example: - - >>> result = tft.segment_indices(tf.constant([0, 0, 1, 2, 2, 2])) - >>> print(result) - tf.Tensor([0 1 0 0 1 2], shape=(6,), dtype=int32) - - Args: - segment_ids: A 1-d `Tensor` containing an non-decreasing sequence of - non-negative integers with type `tf.int32` or `tf.int64`. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` containing the indices within each segment. - """ - ndims = segment_ids.get_shape().ndims - if ndims != 1 and ndims is not None: - raise ValueError( - 'segment_indices requires a 1-dimensional input. ' - 'segment_indices has {} dimensions.'.format(ndims)) - with tf.compat.v1.name_scope(name, 'segment_indices'): - # TODO(KesterTong): This is a fundamental operation for segments, write a C++ - # op to do this. - # TODO(KesterTong): Add a check that segment_ids are increasing. - segment_lengths = tf.math.segment_sum( - tf.ones_like(segment_ids), segment_ids) - segment_starts = tf.gather(tf.concat([[0], tf.cumsum(segment_lengths)], 0), - segment_ids) - return (tf.range(tf.size(input=segment_ids, out_type=segment_ids.dtype)) - - segment_starts) +def segment_indices(segment_ids: tf.Tensor, name: Optional[str] = None) -> tf.Tensor: + """Returns a `Tensor` of indices within each segment. + + segment_ids should be a sequence of non-decreasing non-negative integers that + define a set of segments, e.g. [0, 0, 1, 2, 2, 2] defines 3 segments of length + 2, 1 and 3. The return value is a `Tensor` containing the indices within each + segment. + + Example: + ------- + >>> result = tft.segment_indices(tf.constant([0, 0, 1, 2, 2, 2])) + >>> print(result) + tf.Tensor([0 1 0 0 1 2], shape=(6,), dtype=int32) + + Args: + ---- + segment_ids: A 1-d `Tensor` containing an non-decreasing sequence of + non-negative integers with type `tf.int32` or `tf.int64`. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor` containing the indices within each segment. + """ + ndims = segment_ids.get_shape().ndims + if ndims != 1 and ndims is not None: + raise ValueError( + "segment_indices requires a 1-dimensional input. " + f"segment_indices has {ndims} dimensions." + ) + with tf.compat.v1.name_scope(name, "segment_indices"): + # TODO(KesterTong): This is a fundamental operation for segments, write a C++ + # op to do this. + # TODO(KesterTong): Add a check that segment_ids are increasing. + segment_lengths = tf.math.segment_sum(tf.ones_like(segment_ids), segment_ids) + segment_starts = tf.gather( + tf.concat([[0], tf.cumsum(segment_lengths)], 0), segment_ids + ) + return ( + tf.range(tf.size(input=segment_ids, out_type=segment_ids.dtype)) + - segment_starts + ) @common.log_api_use(common.MAPPER_COLLECTION) def deduplicate_tensor_per_row(input_tensor, name=None): - """Deduplicates each row (0-th dimension) of the provided tensor. - - Args: - input_tensor: A two-dimensional `Tensor` or `SparseTensor`. The first - dimension is assumed to be the batch or "row" dimension, and deduplication - is done on the 2nd dimension. If the Tensor is 1D it is returned as the - equivalent `SparseTensor` since the "row" is a scalar can't be further - deduplicated. - name: Optional name for the operation. - - Returns: - A `SparseTensor` containing the unique set of values from each - row of the input. 
Note: the original order of the input may not be - preserved. - """ - with tf.compat.v1.name_scope(name, 'deduplicate_per_row'): - - if isinstance(input_tensor, tf.SparseTensor): - batch_dim = tf.cast(input_tensor.dense_shape[0], tf.int32) - rank = input_tensor.dense_shape.shape[0] - else: - batch_dim = tf.cast(tf.shape(input_tensor)[0], tf.int32) - rank = input_tensor.shape.rank - - def _univalent_dense_to_sparse(batch_dim, input_tensor): - """Helper to convert a 1D dense `Tensor` to a `SparseTensor`.""" - indices = tf.cast( - tf.stack([ - tf.range(batch_dim, dtype=tf.int32), - tf.zeros(batch_dim, dtype=tf.int32) - ], - axis=1), - dtype=tf.int64) - - return tf.SparseTensor( - indices=indices, values=input_tensor, dense_shape=(batch_dim, 1)) - - if rank is not None: - # If the rank is known at graph construction time, and it's rank 1, there - # is no deduplication to be done so we can return early. - if rank <= 1: + """Deduplicates each row (0-th dimension) of the provided tensor. + + Args: + ---- + input_tensor: A two-dimensional `Tensor` or `SparseTensor`. The first + dimension is assumed to be the batch or "row" dimension, and deduplication + is done on the 2nd dimension. If the Tensor is 1D it is returned as the + equivalent `SparseTensor` since the "row" is a scalar can't be further + deduplicated. + name: Optional name for the operation. + + Returns: + ------- + A `SparseTensor` containing the unique set of values from each + row of the input. Note: the original order of the input may not be + preserved. + """ + with tf.compat.v1.name_scope(name, "deduplicate_per_row"): if isinstance(input_tensor, tf.SparseTensor): - return input_tensor - # Even though we are just returning as is, we convert to a SparseTensor - # to ensure consistent output type. - return _univalent_dense_to_sparse(batch_dim, input_tensor) - if rank > 2: - raise ValueError( - 'Deduplication assumes a rank 2 tensor, got {}.'.format(rank)) - return _deduplicate_tensor_per_row(input_tensor, batch_dim) + batch_dim = tf.cast(input_tensor.dense_shape[0], tf.int32) + rank = input_tensor.dense_shape.shape[0] + else: + batch_dim = tf.cast(tf.shape(input_tensor)[0], tf.int32) + rank = input_tensor.shape.rank + + def _univalent_dense_to_sparse(batch_dim, input_tensor): + """Helper to convert a 1D dense `Tensor` to a `SparseTensor`.""" + indices = tf.cast( + tf.stack( + [ + tf.range(batch_dim, dtype=tf.int32), + tf.zeros(batch_dim, dtype=tf.int32), + ], + axis=1, + ), + dtype=tf.int64, + ) + + return tf.SparseTensor( + indices=indices, values=input_tensor, dense_shape=(batch_dim, 1) + ) + + if rank is not None: + # If the rank is known at graph construction time, and it's rank 1, there + # is no deduplication to be done so we can return early. + if rank <= 1: + if isinstance(input_tensor, tf.SparseTensor): + return input_tensor + # Even though we are just returning as is, we convert to a SparseTensor + # to ensure consistent output type. + return _univalent_dense_to_sparse(batch_dim, input_tensor) + if rank > 2: + raise ValueError(f"Deduplication assumes a rank 2 tensor, got {rank}.") + return _deduplicate_tensor_per_row(input_tensor, batch_dim) - if isinstance(input_tensor, tf.SparseTensor): - return _deduplicate_tensor_per_row(input_tensor, batch_dim) - else: - # Again check for rank 1 tensor (that doesn't need deduplication), this - # time handling inputs where rank isn't known until execution time. 
- dynamic_rank = tf.rank(input_tensor) - return tf.cond( - tf.equal(dynamic_rank, 1), - lambda: _univalent_dense_to_sparse(batch_dim, input_tensor), - lambda: _deduplicate_tensor_per_row(input_tensor, batch_dim), - ) + if isinstance(input_tensor, tf.SparseTensor): + return _deduplicate_tensor_per_row(input_tensor, batch_dim) + else: + # Again check for rank 1 tensor (that doesn't need deduplication), this + # time handling inputs where rank isn't known until execution time. + dynamic_rank = tf.rank(input_tensor) + return tf.cond( + tf.equal(dynamic_rank, 1), + lambda: _univalent_dense_to_sparse(batch_dim, input_tensor), + lambda: _deduplicate_tensor_per_row(input_tensor, batch_dim), + ) _DedupRowLoopArgs = tfx_namedtuple.namedtuple( - 'DedupRowLoopArgs', + "DedupRowLoopArgs", [ - 'index', # Index representing the row of input_tensor to be processed. - 'input_tensor', # `Tensor` or `SparseTensor` to be deuplicated per row. - 'indices', # `TensorArray` containing indices of each deduplicated row. - 'values', # `TensorArray` containing values of each deduplicated row. - 'max_unique', # Tracks the maximum size of any row. - ]) + "index", # Index representing the row of input_tensor to be processed. + "input_tensor", # `Tensor` or `SparseTensor` to be deuplicated per row. + "indices", # `TensorArray` containing indices of each deduplicated row. + "values", # `TensorArray` containing values of each deduplicated row. + "max_unique", # Tracks the maximum size of any row. + ], +) class _DedupRowLoopVars(_DedupRowLoopArgs): - """Loop variables for _deduplicate_per_row.""" - pass + """Loop variables for _deduplicate_per_row.""" + + pass def _deduplicate_tensor_per_row(input_tensor, batch_dim): - """Helper function for deduplicating each row of the provided tensor. - - For each input row, computes the unique values and set them in positions 0 - through num_unique - 1 within the row. - - Args: - input_tensor: A `Tensor` or `SparseTensor` to be deuplicated per row. - batch_dim: The batch dimension or number of "rows" in the batch. - - Returns: - A `SparseTensor` containing the unique set of values from each - row of the input. Note: the original order of the input may not be - preserved. - """ - max_unique = tf.constant(0, dtype=tf.int64) - values = tf.TensorArray( - size=batch_dim, - dtype=input_tensor.dtype, - element_shape=[None], - infer_shape=False) - indices = tf.TensorArray( - size=batch_dim, - dtype=tf.int64, - element_shape=[None, 2], - infer_shape=False) - - def _deduplicate_row(dedup_row_loop_vars): - """Deduplicates the values in the i-th row of the input. + """Helper function for deduplicating each row of the provided tensor. + + For each input row, computes the unique values and set them in positions 0 + through num_unique - 1 within the row. Args: - dedup_row_loop_vars: A _DedupRowLoopVars NamedTuple. + ---- + input_tensor: A `Tensor` or `SparseTensor` to be deuplicated per row. + batch_dim: The batch dimension or number of "rows" in the batch. Returns: - Updated version of the _DedupRowLoopVars for the loop iteration. + ------- + A `SparseTensor` containing the unique set of values from each + row of the input. Note: the original order of the input may not be + preserved. 
""" - index, input_tensor, indices, values, max_unique = dedup_row_loop_vars - if isinstance(input_tensor, tf.SparseTensor): + max_unique = tf.constant(0, dtype=tf.int64) + values = tf.TensorArray( + size=batch_dim, + dtype=input_tensor.dtype, + element_shape=[None], + infer_shape=False, + ) + indices = tf.TensorArray( + size=batch_dim, dtype=tf.int64, element_shape=[None, 2], infer_shape=False + ) - row = tf.sparse.slice(input_tensor, [index, 0], - [1, input_tensor.dense_shape[1]]) - row_values, _ = tf.unique(row.values) - else: - row = input_tensor[index] - row_values, _ = tf.unique(row) - - # Keep track of the maximum number of unique elements in a row, as this - # will determine the resulting dense shape. - num_unique_values = tf.shape(row_values)[0] - max_unique = tf.cast( - tf.maximum(tf.cast(num_unique_values, tf.int64), max_unique), - tf.int64) - column_indices = tf.cast( - tf.expand_dims(tf.range(num_unique_values), axis=1), tf.int64) - row_indices = tf.fill(tf.shape(column_indices), tf.cast(index, tf.int64)) - values = values.write(index, row_values) - indices = indices.write(index, tf.concat([row_indices, column_indices], 1)) - return [ - _DedupRowLoopVars(index + 1, input_tensor, indices, values, max_unique) - ] + def _deduplicate_row(dedup_row_loop_vars): + """Deduplicates the values in the i-th row of the input. + + Args: + ---- + dedup_row_loop_vars: A _DedupRowLoopVars NamedTuple. - index = tf.constant(0, tf.int32) - (loop_output,) = tf.while_loop( - lambda loop_args: loop_args.index < batch_dim, - _deduplicate_row, - [_DedupRowLoopVars(index, input_tensor, indices, values, max_unique)], - back_prop=False) + Returns: + ------- + Updated version of the _DedupRowLoopVars for the loop iteration. + """ + index, input_tensor, indices, values, max_unique = dedup_row_loop_vars + if isinstance(input_tensor, tf.SparseTensor): + row = tf.sparse.slice( + input_tensor, [index, 0], [1, input_tensor.dense_shape[1]] + ) + row_values, _ = tf.unique(row.values) + else: + row = input_tensor[index] + row_values, _ = tf.unique(row) + + # Keep track of the maximum number of unique elements in a row, as this + # will determine the resulting dense shape. 
+ num_unique_values = tf.shape(row_values)[0] + max_unique = tf.cast( + tf.maximum(tf.cast(num_unique_values, tf.int64), max_unique), tf.int64 + ) + column_indices = tf.cast( + tf.expand_dims(tf.range(num_unique_values), axis=1), tf.int64 + ) + row_indices = tf.fill(tf.shape(column_indices), tf.cast(index, tf.int64)) + values = values.write(index, row_values) + indices = indices.write(index, tf.concat([row_indices, column_indices], 1)) + return [_DedupRowLoopVars(index + 1, input_tensor, indices, values, max_unique)] + + index = tf.constant(0, tf.int32) + (loop_output,) = tf.while_loop( + lambda loop_args: loop_args.index < batch_dim, + _deduplicate_row, + [_DedupRowLoopVars(index, input_tensor, indices, values, max_unique)], + back_prop=False, + ) - dense_shape = tf.convert_to_tensor( - [tf.cast(batch_dim, tf.int64), - tf.cast(loop_output.max_unique, tf.int64)], - dtype=tf.int64) - return tf.SparseTensor( - indices=tf.cast(loop_output.indices.concat(), tf.int64), - values=loop_output.values.concat(), - dense_shape=dense_shape) + dense_shape = tf.convert_to_tensor( + [tf.cast(batch_dim, tf.int64), tf.cast(loop_output.max_unique, tf.int64)], + dtype=tf.int64, + ) + return tf.SparseTensor( + indices=tf.cast(loop_output.indices.concat(), tf.int64), + values=loop_output.values.concat(), + dense_shape=dense_shape, + ) @common.log_api_use(common.MAPPER_COLLECTION) -def bag_of_words(tokens: tf.SparseTensor, - ngram_range: Tuple[int, int], - separator: str, - name: Optional[str] = None) -> tf.SparseTensor: - """Computes a bag of "words" based on the specified ngram configuration. - - A light wrapper around tft.ngrams. First computes ngrams, then transforms the - ngram representation (list semantics) into a Bag of Words (set semantics) per - row. Each row reflects the set of *unique* ngrams present in an input record. - - See tft.ngrams for more information. - - Args: - tokens: a two-dimensional `SparseTensor` of dtype `tf.string` containing - tokens that will be used to construct a bag of words. - ngram_range: A pair with the range (inclusive) of ngram sizes to compute. - separator: a string that will be inserted between tokens when ngrams are - constructed. - name: (Optional) A name for this operation. - - Returns: - A `SparseTensor` containing the unique set of ngrams from each row of the - input. Note: the original order of the ngrams may not be preserved. - """ - if tokens.get_shape().ndims != 2: - raise ValueError('bag_of_words requires `tokens` to be 2-dimensional') - with tf.compat.v1.name_scope(name, 'bag_of_words'): - # First compute the ngram representation, which will contain ordered and - # possibly duplicated ngrams per row. - all_ngrams = ngrams(tokens, ngram_range, separator) - # Then deduplicate the ngrams in each row. - return deduplicate_tensor_per_row(all_ngrams) +def bag_of_words( + tokens: tf.SparseTensor, + ngram_range: Tuple[int, int], + separator: str, + name: Optional[str] = None, +) -> tf.SparseTensor: + """Computes a bag of "words" based on the specified ngram configuration. + + A light wrapper around tft.ngrams. First computes ngrams, then transforms the + ngram representation (list semantics) into a Bag of Words (set semantics) per + row. Each row reflects the set of *unique* ngrams present in an input record. + + See tft.ngrams for more information. + + Args: + ---- + tokens: a two-dimensional `SparseTensor` of dtype `tf.string` containing + tokens that will be used to construct a bag of words. 
+ ngram_range: A pair with the range (inclusive) of ngram sizes to compute. + separator: a string that will be inserted between tokens when ngrams are + constructed. + name: (Optional) A name for this operation. + + Returns: + ------- + A `SparseTensor` containing the unique set of ngrams from each row of the + input. Note: the original order of the ngrams may not be preserved. + """ + if tokens.get_shape().ndims != 2: + raise ValueError("bag_of_words requires `tokens` to be 2-dimensional") + with tf.compat.v1.name_scope(name, "bag_of_words"): + # First compute the ngram representation, which will contain ordered and + # possibly duplicated ngrams per row. + all_ngrams = ngrams(tokens, ngram_range, separator) + # Then deduplicate the ngrams in each row. + return deduplicate_tensor_per_row(all_ngrams) @common.log_api_use(common.MAPPER_COLLECTION) -def ngrams(tokens: tf.SparseTensor, - ngram_range: Tuple[int, int], - separator: str, - name: Optional[str] = None) -> tf.SparseTensor: - """Create a `SparseTensor` of n-grams. - - Given a `SparseTensor` of tokens, returns a `SparseTensor` containing the - ngrams that can be constructed from each row. - - `separator` is inserted between each pair of tokens, so " " would be an - appropriate choice if the tokens are words, while "" would be an appropriate - choice if they are characters. - - Example: - - >>> tokens = tf.SparseTensor( - ... indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]], - ... values=['One', 'was', 'Johnny', 'Two', 'was', 'a', 'rat'], - ... dense_shape=[2, 4]) - >>> print(tft.ngrams(tokens, ngram_range=(1, 3), separator=' ')) - SparseTensor(indices=tf.Tensor( - [[0 0] [0 1] [0 2] [0 3] [0 4] [0 5] - [1 0] [1 1] [1 2] [1 3] [1 4] [1 5] [1 6] [1 7] [1 8]], - shape=(15, 2), dtype=int64), - values=tf.Tensor( - [b'One' b'One was' b'One was Johnny' b'was' b'was Johnny' b'Johnny' b'Two' - b'Two was' b'Two was a' b'was' b'was a' b'was a rat' b'a' b'a rat' - b'rat'], shape=(15,), dtype=string), - dense_shape=tf.Tensor([2 9], shape=(2,), dtype=int64)) - - Args: - tokens: a two-dimensional`SparseTensor` of dtype `tf.string` containing - tokens that will be used to construct ngrams. - ngram_range: A pair with the range (inclusive) of ngram sizes to return. - separator: a string that will be inserted between tokens when ngrams are - constructed. - name: (Optional) A name for this operation. - - Returns: - A `SparseTensor` containing all ngrams from each row of the input. Note: - if an ngram appears multiple times in the input row, it will be present the - same number of times in the output. For unique ngrams, see tft.bag_of_words. - - Raises: - ValueError: if `tokens` is not 2D. - ValueError: if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0] - """ - # This function is implemented as follows. 
Assume we start with the following - # `SparseTensor`: - # - # indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [2, 0], [2, 1], [2, 2]] - # values=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] - # dense_shape=[3, 4] - # - # First we then create shifts of the values and first column of indices, - # buffering to avoid overrunning the end of the array, so the shifted values - # (if we are ngrams up to size 3) are - # - # shifted_batch_indices[0]=[0, 0, 0, 0, 1, 2, 2, 2] - # shifted_tokens[0]=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] - # - # shifted_batch_indices[1]=[0, 0, 0, 1, 2, 2, 2, -1] - # shifted_tokens[1]=['b', 'c', 'd', 'q', 'x', 'y', 'z', ''] - # - # shifted_batch_indices[2]=[0, 0, 1, 2, 2, 2, -1, -1] - # shifted_tokens[2]=['c', 'd', 'q', 'x', 'y', 'z', '', ''] - # - # These shifted ngrams are used to create the ngrams as follows. We use - # tf.string_join to join shifted_tokens[:k] to create k-grams. The `separator` - # string is inserted between each pair of tokens in the k-gram. - # The batch that the first of these belonged to is given by - # shifted_batch_indices[0]. However some of these will cross the boundaries - # between 'batches' and so we we create a boolean mask which is True when - # shifted_indices[:k] are all equal. - # - # This results in tensors of ngrams, their batch indices and a boolean mask, - # which we then use to construct the output SparseTensor. - if tokens.get_shape().ndims != 2: - raise ValueError('ngrams requires `tokens` to be 2-dimensional') - with tf.compat.v1.name_scope(name, 'ngrams'): - if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0]: - raise ValueError('Invalid ngram_range: %r' % (ngram_range,)) - - def _sliding_windows(values, num_shifts, fill_value): - buffered_values = tf.concat( - [values, tf.fill([num_shifts - 1], fill_value)], 0) - return [ - tf.slice(buffered_values, [i], tf.shape(input=values)) - for i in range(num_shifts) - ] - - shifted_batch_indices = _sliding_windows( - tokens.indices[:, 0], ngram_range[1] + 1, - tf.constant(-1, dtype=tf.int64)) - shifted_tokens = _sliding_windows(tokens.values, ngram_range[1] + 1, '') - - # Construct a tensor of the form - # [['a', 'ab, 'abc'], ['b', 'bcd', cde'], ...] - def _string_join(tensors): - if tensors: - return tf.strings.join(tensors, separator=separator) - else: - return +def ngrams( + tokens: tf.SparseTensor, + ngram_range: Tuple[int, int], + separator: str, + name: Optional[str] = None, +) -> tf.SparseTensor: + """Create a `SparseTensor` of n-grams. + + Given a `SparseTensor` of tokens, returns a `SparseTensor` containing the + ngrams that can be constructed from each row. + + `separator` is inserted between each pair of tokens, so " " would be an + appropriate choice if the tokens are words, while "" would be an appropriate + choice if they are characters. + + Example: + ------- + >>> tokens = tf.SparseTensor( + ... indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [1, 3]], + ... values=['One', 'was', 'Johnny', 'Two', 'was', 'a', 'rat'], + ... 
dense_shape=[2, 4]) + >>> print(tft.ngrams(tokens, ngram_range=(1, 3), separator=' ')) + SparseTensor(indices=tf.Tensor( + [[0 0] [0 1] [0 2] [0 3] [0 4] [0 5] + [1 0] [1 1] [1 2] [1 3] [1 4] [1 5] [1 6] [1 7] [1 8]], + shape=(15, 2), dtype=int64), + values=tf.Tensor( + [b'One' b'One was' b'One was Johnny' b'was' b'was Johnny' b'Johnny' b'Two' + b'Two was' b'Two was a' b'was' b'was a' b'was a rat' b'a' b'a rat' + b'rat'], shape=(15,), dtype=string), + dense_shape=tf.Tensor([2 9], shape=(2,), dtype=int64)) - ngrams_array = [_string_join(shifted_tokens[:k]) - for k in range(ngram_range[0], ngram_range[1] + 1)] - ngrams_tensor = tf.stack(ngrams_array, 1) - - # Construct a boolean mask for whether each ngram in ngram_tensor is valid, - # in that each character came from the same batch. - valid_ngram = tf.equal( - tf.math.cumprod( - tf.cast( - tf.equal( - tf.stack(shifted_batch_indices, 1), - tf.expand_dims(shifted_batch_indices[0], 1)), - dtype=tf.int32), - axis=1), 1) - valid_ngram = valid_ngram[:, (ngram_range[0] - 1):ngram_range[1]] - - # Construct a tensor with the batch that each ngram in ngram_tensor belongs - # to. - batch_indices = tf.tile(tf.expand_dims(tokens.indices[:, 0], 1), - [1, ngram_range[1] + 1 - ngram_range[0]]) - - # Apply the boolean mask and construct a SparseTensor with the given indices - # and values, where another index is added to give the position within a - # batch. - batch_indices = tf.boolean_mask(tensor=batch_indices, mask=valid_ngram) - ngrams_tensor = tf.boolean_mask(tensor=ngrams_tensor, mask=valid_ngram) - instance_indices = segment_indices(batch_indices) - dense_shape_second_dim = tf.maximum( - tf.reduce_max(input_tensor=instance_indices), -1) + 1 - return tf.SparseTensor( - indices=tf.stack([batch_indices, instance_indices], 1), - values=ngrams_tensor, - dense_shape=tf.stack( - [tokens.dense_shape[0], dense_shape_second_dim])) + Args: + ---- + tokens: a two-dimensional`SparseTensor` of dtype `tf.string` containing + tokens that will be used to construct ngrams. + ngram_range: A pair with the range (inclusive) of ngram sizes to return. + separator: a string that will be inserted between tokens when ngrams are + constructed. + name: (Optional) A name for this operation. + + Returns: + ------- + A `SparseTensor` containing all ngrams from each row of the input. Note: + if an ngram appears multiple times in the input row, it will be present the + same number of times in the output. For unique ngrams, see tft.bag_of_words. + + Raises: + ------ + ValueError: if `tokens` is not 2D. + ValueError: if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0] + """ + # This function is implemented as follows. Assume we start with the following + # `SparseTensor`: + # + # indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [2, 0], [2, 1], [2, 2]] + # values=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] + # dense_shape=[3, 4] + # + # First we then create shifts of the values and first column of indices, + # buffering to avoid overrunning the end of the array, so the shifted values + # (if we are ngrams up to size 3) are + # + # shifted_batch_indices[0]=[0, 0, 0, 0, 1, 2, 2, 2] + # shifted_tokens[0]=['a', 'b', 'c', 'd', 'q', 'x', 'y', 'z'] + # + # shifted_batch_indices[1]=[0, 0, 0, 1, 2, 2, 2, -1] + # shifted_tokens[1]=['b', 'c', 'd', 'q', 'x', 'y', 'z', ''] + # + # shifted_batch_indices[2]=[0, 0, 1, 2, 2, 2, -1, -1] + # shifted_tokens[2]=['c', 'd', 'q', 'x', 'y', 'z', '', ''] + # + # These shifted ngrams are used to create the ngrams as follows. 
We use + # tf.string_join to join shifted_tokens[:k] to create k-grams. The `separator` + # string is inserted between each pair of tokens in the k-gram. + # The batch that the first of these belonged to is given by + # shifted_batch_indices[0]. However some of these will cross the boundaries + # between 'batches' and so we we create a boolean mask which is True when + # shifted_indices[:k] are all equal. + # + # This results in tensors of ngrams, their batch indices and a boolean mask, + # which we then use to construct the output SparseTensor. + if tokens.get_shape().ndims != 2: + raise ValueError("ngrams requires `tokens` to be 2-dimensional") + with tf.compat.v1.name_scope(name, "ngrams"): + if ngram_range[0] < 1 or ngram_range[1] < ngram_range[0]: + raise ValueError("Invalid ngram_range: %r" % (ngram_range,)) + + def _sliding_windows(values, num_shifts, fill_value): + buffered_values = tf.concat( + [values, tf.fill([num_shifts - 1], fill_value)], 0 + ) + return [ + tf.slice(buffered_values, [i], tf.shape(input=values)) + for i in range(num_shifts) + ] + + shifted_batch_indices = _sliding_windows( + tokens.indices[:, 0], ngram_range[1] + 1, tf.constant(-1, dtype=tf.int64) + ) + shifted_tokens = _sliding_windows(tokens.values, ngram_range[1] + 1, "") + + # Construct a tensor of the form + # [['a', 'ab, 'abc'], ['b', 'bcd', cde'], ...] + def _string_join(tensors): + if tensors: + return tf.strings.join(tensors, separator=separator) + else: + return None + + ngrams_array = [ + _string_join(shifted_tokens[:k]) + for k in range(ngram_range[0], ngram_range[1] + 1) + ] + ngrams_tensor = tf.stack(ngrams_array, 1) + + # Construct a boolean mask for whether each ngram in ngram_tensor is valid, + # in that each character came from the same batch. + valid_ngram = tf.equal( + tf.math.cumprod( + tf.cast( + tf.equal( + tf.stack(shifted_batch_indices, 1), + tf.expand_dims(shifted_batch_indices[0], 1), + ), + dtype=tf.int32, + ), + axis=1, + ), + 1, + ) + valid_ngram = valid_ngram[:, (ngram_range[0] - 1) : ngram_range[1]] + + # Construct a tensor with the batch that each ngram in ngram_tensor belongs + # to. + batch_indices = tf.tile( + tf.expand_dims(tokens.indices[:, 0], 1), + [1, ngram_range[1] + 1 - ngram_range[0]], + ) + + # Apply the boolean mask and construct a SparseTensor with the given indices + # and values, where another index is added to give the position within a + # batch. + batch_indices = tf.boolean_mask(tensor=batch_indices, mask=valid_ngram) + ngrams_tensor = tf.boolean_mask(tensor=ngrams_tensor, mask=valid_ngram) + instance_indices = segment_indices(batch_indices) + dense_shape_second_dim = ( + tf.maximum(tf.reduce_max(input_tensor=instance_indices), -1) + 1 + ) + return tf.SparseTensor( + indices=tf.stack([batch_indices, instance_indices], 1), + values=ngrams_tensor, + dense_shape=tf.stack([tokens.dense_shape[0], dense_shape_second_dim]), + ) @common.log_api_use(common.MAPPER_COLLECTION) -def word_count(tokens: Union[tf.SparseTensor, tf.RaggedTensor], - name: Optional[str] = None) -> tf.Tensor: - # pyformat: disable - """Find the token count of each document/row. - - `tokens` is either a `RaggedTensor` or `SparseTensor`, representing tokenized - strings. This function simply returns size of each row, so the dtype is not - constrained to string. - - Example: - >>> sparse = tf.SparseTensor(indices=[[0, 0], [0, 1], [2, 2]], - ... 
values=['a', 'b', 'c'], dense_shape=(4, 4)) - >>> tft.word_count(sparse) - - - Args: - tokens: either - (1) a `SparseTensor`, or - (2) a `RaggedTensor` with ragged rank of 1, non-ragged rank of 1 - of dtype `tf.string` containing tokens to be counted - name: (Optional) A name for this operation. - - Returns: - A one-dimensional `Tensor` the token counts of each row. - - Raises: - ValueError: if tokens is neither sparse nor ragged - """ - # pyformat: enable - with tf.compat.v1.name_scope(name, 'word_count'): - if isinstance(tokens, tf.RaggedTensor): - return tokens.row_lengths() - elif isinstance(tokens, tf.SparseTensor): - result = tf.sparse.reduce_sum( - tf.SparseTensor(indices=tokens.indices, - values=tf.ones_like(tokens.values, dtype=tf.int64), - dense_shape=tokens.dense_shape), - axis=list(range(1, tokens.get_shape().ndims))) - result.set_shape([tokens.shape[0]]) - return result - else: - raise ValueError('Invalid token tensor') +def word_count( + tokens: Union[tf.SparseTensor, tf.RaggedTensor], name: Optional[str] = None +) -> tf.Tensor: + # pyformat: disable + """Find the token count of each document/row. + + `tokens` is either a `RaggedTensor` or `SparseTensor`, representing tokenized + strings. This function simply returns size of each row, so the dtype is not + constrained to string. + + Example: + ------- + >>> sparse = tf.SparseTensor(indices=[[0, 0], [0, 1], [2, 2]], + ... values=['a', 'b', 'c'], dense_shape=(4, 4)) + >>> tft.word_count(sparse) + + + Args: + ---- + tokens: either + (1) a `SparseTensor`, or + (2) a `RaggedTensor` with ragged rank of 1, non-ragged rank of 1 + of dtype `tf.string` containing tokens to be counted + name: (Optional) A name for this operation. + + Returns: + ------- + A one-dimensional `Tensor` the token counts of each row. + + Raises: + ------ + ValueError: if tokens is neither sparse nor ragged + """ + # pyformat: enable + with tf.compat.v1.name_scope(name, "word_count"): + if isinstance(tokens, tf.RaggedTensor): + return tokens.row_lengths() + elif isinstance(tokens, tf.SparseTensor): + result = tf.sparse.reduce_sum( + tf.SparseTensor( + indices=tokens.indices, + values=tf.ones_like(tokens.values, dtype=tf.int64), + dense_shape=tokens.dense_shape, + ), + axis=list(range(1, tokens.get_shape().ndims)), + ) + result.set_shape([tokens.shape[0]]) + return result + else: + raise ValueError("Invalid token tensor") @common.log_api_use(common.MAPPER_COLLECTION) @@ -1705,125 +1837,138 @@ def hash_strings( strings: common_types.ConsistentTensorType, hash_buckets: int, key: Optional[Iterable[int]] = None, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Hash strings into buckets. - - Args: - strings: a `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype `tf.string`. - hash_buckets: the number of hash buckets. - key: optional. An array of two Python `uint64`. If passed, output will be a - deterministic function of `strings` and `key`. Note that hashing will be - slower if this value is specified. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype `tf.int64` with the - same shape as - the input `strings`. - - Raises: - TypeError: if `strings` is not a `Tensor`, `SparseTensor`, or `RaggedTensor` - of dtype `tf.string`. 
- """ - if (not isinstance(strings, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)) or - strings.dtype != tf.string): - raise TypeError( - 'Input to hash_strings must be a `Tensor`, `SparseTensor`, or ' - f'`RaggedTensor` of dtype string; got {strings.dtype}') - if isinstance(strings, tf.Tensor): - if name is None: - name = 'hash_strings' - if key is None: - return tf.strings.to_hash_bucket_fast(strings, hash_buckets, name=name) - return tf.strings.to_hash_bucket_strong( - strings, hash_buckets, key, name=name) - else: - compose_result_fn = _make_composite_tensor_wrapper_if_composite(strings) - values = tf_utils.get_values(strings) - return compose_result_fn(hash_strings(values, hash_buckets, key)) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Hash strings into buckets. + + Args: + ---- + strings: a `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype `tf.string`. + hash_buckets: the number of hash buckets. + key: optional. An array of two Python `uint64`. If passed, output will be a + deterministic function of `strings` and `key`. Note that hashing will be + slower if this value is specified. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` of dtype `tf.int64` with the + same shape as + the input `strings`. + + Raises: + ------ + TypeError: if `strings` is not a `Tensor`, `SparseTensor`, or `RaggedTensor` + of dtype `tf.string`. + """ + if ( + not isinstance(strings, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)) + or strings.dtype != tf.string + ): + raise TypeError( + "Input to hash_strings must be a `Tensor`, `SparseTensor`, or " + f"`RaggedTensor` of dtype string; got {strings.dtype}" + ) + if isinstance(strings, tf.Tensor): + if name is None: + name = "hash_strings" + if key is None: + return tf.strings.to_hash_bucket_fast(strings, hash_buckets, name=name) + return tf.strings.to_hash_bucket_strong(strings, hash_buckets, key, name=name) + else: + compose_result_fn = _make_composite_tensor_wrapper_if_composite(strings) + values = tf_utils.get_values(strings) + return compose_result_fn(hash_strings(values, hash_buckets, key)) @common.log_api_use(common.MAPPER_COLLECTION) -def bucketize(x: common_types.ConsistentTensorType, - num_buckets: int, - epsilon: Optional[float] = None, - weights: Optional[tf.Tensor] = None, - elementwise: bool = False, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Returns a bucketized column, with a bucket index assigned to each input. - - Args: - x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` whose values - should be mapped to buckets. For a `CompositeTensor` only non-missing - values will be included in the quantiles computation, and the result of - `bucketize` will be a `CompositeTensor` with non-missing values mapped to - buckets. If elementwise=True then `x` must be dense. - num_buckets: Values in the input `x` are divided into approximately - equal-sized buckets, where the number of buckets is `num_buckets`. - epsilon: (Optional) Error tolerance, typically a small fraction close to - zero. If a value is not specified by the caller, a suitable value is - computed based on experimental results. For `num_buckets` less than 100, - the value of 0.01 is chosen to handle a dataset of up to ~1 trillion input - data values. 
If `num_buckets` is larger, then epsilon is set to - (1/`num_buckets`) to enforce a stricter error tolerance, because more - buckets will result in smaller range for each bucket, and so we want the - boundaries to be less fuzzy. See analyzers.quantiles() for details. - weights: (Optional) Weights tensor for the quantiles. Tensor must have the - same shape as x. - elementwise: (Optional) If true, bucketize each element of the tensor - independently. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` of the same shape as `x`, with each element in the - returned tensor representing the bucketized value. Bucketized value is - in the range [0, actual_num_buckets). Sometimes the actual number of buckets - can be different than num_buckets hint, for example in case the number of - distinct values is smaller than num_buckets, or in cases where the - input values are not uniformly distributed. - NaN values are mapped to the last bucket. Values with NaN weights are - ignored in bucket boundaries calculation. - - Raises: - TypeError: If num_buckets is not an int. - ValueError: If value of num_buckets is not > 1. - ValueError: If elementwise=True and x is a `CompositeTensor`. - """ - with tf.compat.v1.name_scope(name, 'bucketize'): - if not isinstance(num_buckets, int): - raise TypeError('num_buckets must be an int, got %s' % type(num_buckets)) - - if num_buckets < 1: - raise ValueError('Invalid num_buckets %d' % num_buckets) - - if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)) and elementwise: - raise ValueError( - 'bucketize requires `x` to be dense if `elementwise=True`') - - if epsilon is None: - # See explanation in args documentation for epsilon. - epsilon = min(1.0 / num_buckets, 0.01) +def bucketize( + x: common_types.ConsistentTensorType, + num_buckets: int, + epsilon: Optional[float] = None, + weights: Optional[tf.Tensor] = None, + elementwise: bool = False, + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Returns a bucketized column, with a bucket index assigned to each input. - x_values = tf_utils.get_values(x) - bucket_boundaries = analyzers.quantiles( - x_values, - num_buckets, - epsilon, - weights, - reduce_instance_dims=not elementwise) - - if not elementwise: - return apply_buckets(x, bucket_boundaries) - - num_features = tf.math.reduce_prod(x.get_shape()[1:]) - bucket_boundaries = tf.reshape(bucket_boundaries, [num_features, -1]) - x_reshaped = tf.reshape(x, [-1, num_features]) - bucketized = [] - for idx, boundaries in enumerate(tf.unstack(bucket_boundaries, axis=0)): - bucketized.append(apply_buckets(x_reshaped[:, idx], - tf.expand_dims(boundaries, axis=0))) - return tf.reshape(tf.stack(bucketized, axis=1), - [-1] + x.get_shape().as_list()[1:]) + Args: + ---- + x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` whose values + should be mapped to buckets. For a `CompositeTensor` only non-missing + values will be included in the quantiles computation, and the result of + `bucketize` will be a `CompositeTensor` with non-missing values mapped to + buckets. If elementwise=True then `x` must be dense. + num_buckets: Values in the input `x` are divided into approximately + equal-sized buckets, where the number of buckets is `num_buckets`. + epsilon: (Optional) Error tolerance, typically a small fraction close to + zero. If a value is not specified by the caller, a suitable value is + computed based on experimental results. 
For `num_buckets` less than 100, + the value of 0.01 is chosen to handle a dataset of up to ~1 trillion input + data values. If `num_buckets` is larger, then epsilon is set to + (1/`num_buckets`) to enforce a stricter error tolerance, because more + buckets will result in smaller range for each bucket, and so we want the + boundaries to be less fuzzy. See analyzers.quantiles() for details. + weights: (Optional) Weights tensor for the quantiles. Tensor must have the + same shape as x. + elementwise: (Optional) If true, bucketize each element of the tensor + independently. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor` of the same shape as `x`, with each element in the + returned tensor representing the bucketized value. Bucketized value is + in the range [0, actual_num_buckets). Sometimes the actual number of buckets + can be different than num_buckets hint, for example in case the number of + distinct values is smaller than num_buckets, or in cases where the + input values are not uniformly distributed. + NaN values are mapped to the last bucket. Values with NaN weights are + ignored in bucket boundaries calculation. + + Raises: + ------ + TypeError: If num_buckets is not an int. + ValueError: If value of num_buckets is not > 1. + ValueError: If elementwise=True and x is a `CompositeTensor`. + """ + with tf.compat.v1.name_scope(name, "bucketize"): + if not isinstance(num_buckets, int): + raise TypeError("num_buckets must be an int, got %s" % type(num_buckets)) + + if num_buckets < 1: + raise ValueError("Invalid num_buckets %d" % num_buckets) + + if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)) and elementwise: + raise ValueError("bucketize requires `x` to be dense if `elementwise=True`") + + if epsilon is None: + # See explanation in args documentation for epsilon. + epsilon = min(1.0 / num_buckets, 0.01) + + x_values = tf_utils.get_values(x) + bucket_boundaries = analyzers.quantiles( + x_values, + num_buckets, + epsilon, + weights, + reduce_instance_dims=not elementwise, + ) + + if not elementwise: + return apply_buckets(x, bucket_boundaries) + + num_features = tf.math.reduce_prod(x.get_shape()[1:]) + bucket_boundaries = tf.reshape(bucket_boundaries, [num_features, -1]) + x_reshaped = tf.reshape(x, [-1, num_features]) + bucketized = [] + for idx, boundaries in enumerate(tf.unstack(bucket_boundaries, axis=0)): + bucketized.append( + apply_buckets(x_reshaped[:, idx], tf.expand_dims(boundaries, axis=0)) + ) + return tf.reshape( + tf.stack(bucketized, axis=1), [-1] + x.get_shape().as_list()[1:] + ) # TODO(b/179891014): Implement key_vocabulary_filename for bucketize_per_key. @@ -1834,80 +1979,95 @@ def bucketize_per_key( num_buckets: int, epsilon: Optional[float] = None, weights: Optional[common_types.ConsistentTensorType] = None, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Returns a bucketized column, with a bucket index assigned to each input. - - Args: - x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` with rank 1, - whose values should be mapped to buckets. `CompositeTensor`s will have - their non-missing values mapped and missing values left as missing. - key: A `Tensor`, `SparseTensor`, or `RaggedTensor` with the same shape as - `x` and dtype tf.string. If `x` is a `CompositeTensor`, `key` must - exactly match `x` in everything except values, i.e. indices and - dense_shape or nested row splits must be identical. 
- num_buckets: Values in the input `x` are divided into approximately - equal-sized buckets, where the number of buckets is num_buckets. - epsilon: (Optional) see `bucketize`. - weights: (Optional) A `Tensor`, `SparseTensor`, or `RaggedTensor` with the - same shape as `x` and dtype tf.float32. Used as weights for quantiles - calculation. If `x` is a `CompositeTensor`, `weights` must exactly match - `x` in everything except values. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, with - each element in the returned tensor representing the bucketized value. - Bucketized value is in the range [0, actual_num_buckets). If the computed - key vocabulary doesn't have an entry for `key` then the resulting bucket is - -1. - - Raises: - ValueError: If value of num_buckets is not > 1. - """ - with tf.compat.v1.name_scope(name, 'bucketize_per_key'): - if not isinstance(num_buckets, int): - raise TypeError( - 'num_buckets must be an int, got {}'.format(type(num_buckets))) - - if num_buckets < 1: - raise ValueError('Invalid num_buckets {}'.format(num_buckets)) - - if epsilon is None: - # See explanation in args documentation for epsilon. - epsilon = min(1.0 / num_buckets, 0.01) - - (key_vocab, bucket_boundaries, scale_factor_per_key, shift_per_key, - actual_num_buckets) = ( - analyzers._quantiles_per_key( # pylint: disable=protected-access - tf_utils.get_values(x), - tf_utils.get_values(key), - num_buckets, - epsilon, - weights=tf_utils.get_values(weights))) - return _apply_buckets_with_keys(x, key, key_vocab, bucket_boundaries, - scale_factor_per_key, shift_per_key, - actual_num_buckets) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Returns a bucketized column, with a bucket index assigned to each input. + + Args: + ---- + x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` with rank 1, + whose values should be mapped to buckets. `CompositeTensor`s will have + their non-missing values mapped and missing values left as missing. + key: A `Tensor`, `SparseTensor`, or `RaggedTensor` with the same shape as + `x` and dtype tf.string. If `x` is a `CompositeTensor`, `key` must + exactly match `x` in everything except values, i.e. indices and + dense_shape or nested row splits must be identical. + num_buckets: Values in the input `x` are divided into approximately + equal-sized buckets, where the number of buckets is num_buckets. + epsilon: (Optional) see `bucketize`. + weights: (Optional) A `Tensor`, `SparseTensor`, or `RaggedTensor` with the + same shape as `x` and dtype tf.float32. Used as weights for quantiles + calculation. If `x` is a `CompositeTensor`, `weights` must exactly match + `x` in everything except values. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, with + each element in the returned tensor representing the bucketized value. + Bucketized value is in the range [0, actual_num_buckets). If the computed + key vocabulary doesn't have an entry for `key` then the resulting bucket is + -1. + + Raises: + ------ + ValueError: If value of num_buckets is not > 1. 
+ """ + with tf.compat.v1.name_scope(name, "bucketize_per_key"): + if not isinstance(num_buckets, int): + raise TypeError(f"num_buckets must be an int, got {type(num_buckets)}") + + if num_buckets < 1: + raise ValueError(f"Invalid num_buckets {num_buckets}") + + if epsilon is None: + # See explanation in args documentation for epsilon. + epsilon = min(1.0 / num_buckets, 0.01) + + ( + key_vocab, + bucket_boundaries, + scale_factor_per_key, + shift_per_key, + actual_num_buckets, + ) = analyzers._quantiles_per_key( # pylint: disable=protected-access + tf_utils.get_values(x), + tf_utils.get_values(key), + num_buckets, + epsilon, + weights=tf_utils.get_values(weights), + ) + return _apply_buckets_with_keys( + x, + key, + key_vocab, + bucket_boundaries, + scale_factor_per_key, + shift_per_key, + actual_num_buckets, + ) def _make_composite_tensor_wrapper_if_composite( - x: common_types.ConsistentTensorType + x: common_types.ConsistentTensorType, ) -> Callable[[tf.Tensor], common_types.ConsistentTensorType]: - """Produces a function to wrap values in the composite structure of x.""" - if isinstance(x, tf.SparseTensor): - return lambda values: tf.SparseTensor(x.indices, values, x.dense_shape) - elif isinstance(x, tf.RaggedTensor): + """Produces a function to wrap values in the composite structure of x.""" + if isinstance(x, tf.SparseTensor): + return lambda values: tf.SparseTensor(x.indices, values, x.dense_shape) + elif isinstance(x, tf.RaggedTensor): - def from_nested_row_splits(values): - return tf.RaggedTensor.from_nested_row_splits( - values, x.nested_row_splits, validate=False) + def from_nested_row_splits(values): + return tf.RaggedTensor.from_nested_row_splits( + values, x.nested_row_splits, validate=False + ) - return from_nested_row_splits - else: - return lambda values: values + return from_nested_row_splits + else: + return lambda values: values def _fill_shape(value, shape, dtype): - return tf.cast(tf.fill(shape, value), dtype) + return tf.cast(tf.fill(shape, value), dtype) def _apply_buckets_with_keys( @@ -1918,362 +2078,415 @@ def _apply_buckets_with_keys( scale_factor_per_key: tf.Tensor, shift_per_key: tf.Tensor, num_buckets: int, - name: Optional[int] = None) -> common_types.ConsistentTensorType: - """Bucketize an input where boundaries depend on the index. - - Args: - x: A 1-d `Tensor`, `SparseTensor`, or `RaggedTensor`. - key: A 1-d `Tensor`, `SparseTensor`, or `RaggedTensor` with the same size as - `x`. - key_vocab: A vocab containing all keys. Must be exhaustive, an out-of-vocab - entry in `key` will cause a crash. - bucket_boundaries: A rank-1 Tensor. - scale_factor_per_key: A rank-1 Tensor of shape (key_size,). - shift_per_key: A rank-1 Tensor of shape (key_size,). - num_buckets: A scalar. - name: (Optional) A name for this operation. - - Returns: - A tensor with the same shape as `x` and dtype tf.int64. If any value in - `key` is not present in `key_vocab` then the resulting bucket will be -1. - """ - with tf.compat.v1.name_scope(name, 'apply_buckets_with_keys'): - x_values = tf.cast(tf_utils.get_values(x), tf.float32) - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - key_values = tf_utils.get_values(key) + name: Optional[int] = None, +) -> common_types.ConsistentTensorType: + """Bucketize an input where boundaries depend on the index. - # Convert `key_values` to indices in key_vocab. - key_indices = tf_utils.lookup_key(key_values, key_vocab) + Args: + ---- + x: A 1-d `Tensor`, `SparseTensor`, or `RaggedTensor`. 
+ key: A 1-d `Tensor`, `SparseTensor`, or `RaggedTensor` with the same size as + `x`. + key_vocab: A vocab containing all keys. Must be exhaustive, an out-of-vocab + entry in `key` will cause a crash. + bucket_boundaries: A rank-1 Tensor. + scale_factor_per_key: A rank-1 Tensor of shape (key_size,). + shift_per_key: A rank-1 Tensor of shape (key_size,). + num_buckets: A scalar. + name: (Optional) A name for this operation. - adjusted_key_indices = tf.where( - key_indices < 0, _fill_shape(0, tf.shape(key_indices), tf.int64), - key_indices) + Returns: + ------- + A tensor with the same shape as `x` and dtype tf.int64. If any value in + `key` is not present in `key_vocab` then the resulting bucket will be -1. + """ + with tf.compat.v1.name_scope(name, "apply_buckets_with_keys"): + x_values = tf.cast(tf_utils.get_values(x), tf.float32) + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + key_values = tf_utils.get_values(key) - # Apply the per-key offsets to x, which produces offset buckets (where the - # bucket offset is an integer offset). Then remove this offset to get the - # actual per-key buckets for x. - scale_factors = tf.gather(scale_factor_per_key, adjusted_key_indices) - shifts = tf.gather(shift_per_key, adjusted_key_indices) + # Convert `key_values` to indices in key_vocab. + key_indices = tf_utils.lookup_key(key_values, key_vocab) - transformed_x = x_values * scale_factors + shifts + adjusted_key_indices = tf.where( + key_indices < 0, + _fill_shape(0, tf.shape(key_indices), tf.int64), + key_indices, + ) - offset_buckets = tf_utils.assign_buckets( - transformed_x, bucket_boundaries, side=tf_utils.Side.RIGHT) + # Apply the per-key offsets to x, which produces offset buckets (where the + # bucket offset is an integer offset). Then remove this offset to get the + # actual per-key buckets for x. + scale_factors = tf.gather(scale_factor_per_key, adjusted_key_indices) + shifts = tf.gather(shift_per_key, adjusted_key_indices) - max_bucket = num_buckets - 1 + transformed_x = x_values * scale_factors + shifts - # Shift the bucket numbers back to the correct range [0, num_buckets]. - # We use max_bucket-1 due to different keys sharing 1 boundary. - corrected_buckets = offset_buckets - ( - (max_bucket - 1) * adjusted_key_indices) - bucketized_values = tf.clip_by_value(corrected_buckets, 0, max_bucket) + offset_buckets = tf_utils.assign_buckets( + transformed_x, bucket_boundaries, side=tf_utils.Side.RIGHT + ) - # Set values with missing keys as -1. - bucketized_values = tf.where(key_indices < 0, key_indices, - bucketized_values) + max_bucket = num_buckets - 1 - # Attach the relevant metadata to result, so that the corresponding - # output feature will have this metadata set. - min_value = tf.constant(0, tf.int64) - schema_inference.set_tensor_schema_override( - bucketized_values, min_value, max_bucket) + # Shift the bucket numbers back to the correct range [0, num_buckets]. + # We use max_bucket-1 due to different keys sharing 1 boundary. + corrected_buckets = offset_buckets - ((max_bucket - 1) * adjusted_key_indices) + bucketized_values = tf.clip_by_value(corrected_buckets, 0, max_bucket) - return compose_result_fn(bucketized_values) + # Set values with missing keys as -1. + bucketized_values = tf.where(key_indices < 0, key_indices, bucketized_values) + + # Attach the relevant metadata to result, so that the corresponding + # output feature will have this metadata set. 
+ min_value = tf.constant(0, tf.int64) + schema_inference.set_tensor_schema_override( + bucketized_values, min_value, max_bucket + ) + + return compose_result_fn(bucketized_values) @common.log_api_use(common.MAPPER_COLLECTION) def apply_buckets_with_interpolation( x: common_types.ConsistentTensorType, bucket_boundaries: common_types.BucketBoundariesType, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Interpolates within the provided buckets and then normalizes to 0 to 1. - - A method for normalizing continuous numeric data to the range [0, 1]. - Numeric values are first bucketized according to the provided boundaries, then - linearly interpolated within their respective bucket ranges. Finally, the - interpolated values are normalized to the range [0, 1]. Values that are - less than or equal to the lowest boundary, or greater than or equal to the - highest boundary, will be mapped to 0 and 1 respectively. NaN values will be - mapped to the middle of the range (.5). - - This is a non-linear approach to normalization that is less sensitive to - outliers than min-max or z-score scaling. When outliers are present, standard - forms of normalization can leave the majority of the data compressed into a - very small segment of the output range, whereas this approach tends to spread - out the more frequent values (if quantile buckets are used). Note that - distance relationships in the raw data are not necessarily preserved (data - points that close to each other in the raw feature space may not be equally - close in the transformed feature space). This means that unlike linear - normalization methods, correlations between features may be distorted by the - transformation. This scaling method may help with stability and minimize - exploding gradients in neural networks. - - Args: - x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` - (tf.float[32|64], tf.int[32|64]). - bucket_boundaries: Sorted bucket boundaries as a rank-2 `Tensor` or list. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, - normalized to the range [0, 1]. If the input x is tf.float64, the returned - values will be tf.float64. Otherwise, returned values are tf.float32. - """ - with tf.compat.v1.name_scope(name, 'buckets_with_interpolation'): - bucket_boundaries = tf.convert_to_tensor(bucket_boundaries) - tf.compat.v1.assert_rank(bucket_boundaries, 2) - x_values = tf_utils.get_values(x) - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - if not (x_values.dtype.is_floating or x_values.dtype.is_integer): - raise ValueError( - 'Input tensor to be normalized must be numeric, got {}.'.format( - x_values.dtype)) - # Remove any non-finite boundaries. - if bucket_boundaries.dtype in (tf.float64, tf.float32): - bucket_boundaries = tf.expand_dims( - tf.gather_nd(bucket_boundaries, - tf.where(tf.math.is_finite(bucket_boundaries))), - axis=0) - return_type = tf.float64 if x.dtype == tf.float64 else tf.float32 - num_boundaries = tf.cast( - tf.shape(bucket_boundaries)[1], dtype=tf.int64, name='num_boundaries') - assert_some_finite_boundaries = tf.compat.v1.assert_greater( - num_boundaries, - tf.constant(0, tf.int64), - name='assert_1_or_more_finite_boundaries') - with tf.control_dependencies([assert_some_finite_boundaries]): - bucket_indices = tf_utils.assign_buckets( - x_values, bucket_boundaries, side=tf_utils.Side.RIGHT) - # Get max, min, and width of the corresponding bucket for each element. 
- bucket_max = tf.cast( - tf.gather( - tf.concat([bucket_boundaries[0], bucket_boundaries[:, -1]], - axis=0), bucket_indices), return_type) - bucket_min = tf.cast( - tf.gather( - tf.concat([bucket_boundaries[:, 0], bucket_boundaries[0]], - axis=0), bucket_indices), return_type) - bucket_width = bucket_max - bucket_min - zeros = tf.zeros_like(x_values, dtype=return_type) - ones = tf.ones_like(x_values, dtype=return_type) - - # Linearly interpolate each value within its respective bucket range. - interpolation_value = ( - (tf.cast(x_values, return_type) - bucket_min) / bucket_width) - bucket_interpolation = tf.compat.v1.verify_tensor_all_finite( - tf.where( - # If bucket index is first or last, which represents "less than - # min" and "greater than max" respectively, the bucket logically - # has an infinite width and we can't meaningfully interpolate. - tf.logical_or( - tf.equal(bucket_indices, 0), - tf.equal(bucket_indices, num_boundaries)), - zeros, + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Interpolates within the provided buckets and then normalizes to 0 to 1. + + A method for normalizing continuous numeric data to the range [0, 1]. + Numeric values are first bucketized according to the provided boundaries, then + linearly interpolated within their respective bucket ranges. Finally, the + interpolated values are normalized to the range [0, 1]. Values that are + less than or equal to the lowest boundary, or greater than or equal to the + highest boundary, will be mapped to 0 and 1 respectively. NaN values will be + mapped to the middle of the range (.5). + + This is a non-linear approach to normalization that is less sensitive to + outliers than min-max or z-score scaling. When outliers are present, standard + forms of normalization can leave the majority of the data compressed into a + very small segment of the output range, whereas this approach tends to spread + out the more frequent values (if quantile buckets are used). Note that + distance relationships in the raw data are not necessarily preserved (data + points that close to each other in the raw feature space may not be equally + close in the transformed feature space). This means that unlike linear + normalization methods, correlations between features may be distorted by the + transformation. This scaling method may help with stability and minimize + exploding gradients in neural networks. + + Args: + ---- + x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` + (tf.float[32|64], tf.int[32|64]). + bucket_boundaries: Sorted bucket boundaries as a rank-2 `Tensor` or list. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, + normalized to the range [0, 1]. If the input x is tf.float64, the returned + values will be tf.float64. Otherwise, returned values are tf.float32. + """ + with tf.compat.v1.name_scope(name, "buckets_with_interpolation"): + bucket_boundaries = tf.convert_to_tensor(bucket_boundaries) + tf.compat.v1.assert_rank(bucket_boundaries, 2) + x_values = tf_utils.get_values(x) + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + if not (x_values.dtype.is_floating or x_values.dtype.is_integer): + raise ValueError( + f"Input tensor to be normalized must be numeric, got {x_values.dtype}." + ) + # Remove any non-finite boundaries. 
+ if bucket_boundaries.dtype in (tf.float64, tf.float32): + bucket_boundaries = tf.expand_dims( + tf.gather_nd( + bucket_boundaries, tf.where(tf.math.is_finite(bucket_boundaries)) + ), + axis=0, + ) + return_type = tf.float64 if x.dtype == tf.float64 else tf.float32 + num_boundaries = tf.cast( + tf.shape(bucket_boundaries)[1], dtype=tf.int64, name="num_boundaries" + ) + assert_some_finite_boundaries = tf.compat.v1.assert_greater( + num_boundaries, + tf.constant(0, tf.int64), + name="assert_1_or_more_finite_boundaries", + ) + with tf.control_dependencies([assert_some_finite_boundaries]): + bucket_indices = tf_utils.assign_buckets( + x_values, bucket_boundaries, side=tf_utils.Side.RIGHT + ) + # Get max, min, and width of the corresponding bucket for each element. + bucket_max = tf.cast( + tf.gather( + tf.concat([bucket_boundaries[0], bucket_boundaries[:, -1]], axis=0), + bucket_indices, + ), + return_type, + ) + bucket_min = tf.cast( + tf.gather( + tf.concat([bucket_boundaries[:, 0], bucket_boundaries[0]], axis=0), + bucket_indices, + ), + return_type, + ) + bucket_width = bucket_max - bucket_min + zeros = tf.zeros_like(x_values, dtype=return_type) + ones = tf.ones_like(x_values, dtype=return_type) + + # Linearly interpolate each value within its respective bucket range. + interpolation_value = ( + tf.cast(x_values, return_type) - bucket_min + ) / bucket_width + bucket_interpolation = tf.compat.v1.verify_tensor_all_finite( tf.where( - # If the bucket width is zero due to numerical imprecision, - # there is no point in interpolating - tf.equal(bucket_width, 0.0), - ones / 2.0, - # Finally, for a bucket with a valid width, we can interpolate. - interpolation_value)), - 'bucket_interpolation') - bucket_indices_with_interpolation = tf.cast( - tf.maximum(bucket_indices - 1, 0), return_type) + bucket_interpolation - - # Normalize the interpolated values to the range [0, 1]. - denominator = tf.cast(tf.maximum(num_boundaries - 1, 1), return_type) - normalized_values = bucket_indices_with_interpolation / denominator - if x_values.dtype.is_floating: - # Impute NaNs with .5, the middle value of the normalized output range. - imputed_values = tf.ones_like(x_values, dtype=return_type) / 2.0 - normalized_values = tf.where( - tf.math.is_nan(x_values), imputed_values, normalized_values) - # If there is only one boundary, all values < the boundary are 0, all values - # >= the boundary are 1. - single_boundary_values = lambda: tf.where( # pylint: disable=g-long-lambda - tf.equal(bucket_indices, 0), zeros, ones) - normalized_result = tf.cond( - tf.equal(num_boundaries, 1), - single_boundary_values, lambda: normalized_values) - return compose_result_fn(normalized_result) + # If bucket index is first or last, which represents "less than + # min" and "greater than max" respectively, the bucket logically + # has an infinite width and we can't meaningfully interpolate. + tf.logical_or( + tf.equal(bucket_indices, 0), + tf.equal(bucket_indices, num_boundaries), + ), + zeros, + tf.where( + # If the bucket width is zero due to numerical imprecision, + # there is no point in interpolating + tf.equal(bucket_width, 0.0), + ones / 2.0, + # Finally, for a bucket with a valid width, we can interpolate. + interpolation_value, + ), + ), + "bucket_interpolation", + ) + bucket_indices_with_interpolation = ( + tf.cast(tf.maximum(bucket_indices - 1, 0), return_type) + + bucket_interpolation + ) + + # Normalize the interpolated values to the range [0, 1]. 
+ denominator = tf.cast(tf.maximum(num_boundaries - 1, 1), return_type) + normalized_values = bucket_indices_with_interpolation / denominator + if x_values.dtype.is_floating: + # Impute NaNs with .5, the middle value of the normalized output range. + imputed_values = tf.ones_like(x_values, dtype=return_type) / 2.0 + normalized_values = tf.where( + tf.math.is_nan(x_values), imputed_values, normalized_values + ) + # If there is only one boundary, all values < the boundary are 0, all values + # >= the boundary are 1. + single_boundary_values = lambda: tf.where( # pylint: disable=g-long-lambda + tf.equal(bucket_indices, 0), zeros, ones + ) + normalized_result = tf.cond( + tf.equal(num_boundaries, 1), + single_boundary_values, + lambda: normalized_values, + ) + return compose_result_fn(normalized_result) @common.log_api_use(common.MAPPER_COLLECTION) def apply_buckets( x: common_types.ConsistentTensorType, bucket_boundaries: common_types.BucketBoundariesType, - name: Optional[str] = None) -> common_types.ConsistentTensorType: - """Returns a bucketized column, with a bucket index assigned to each input. - - Each element `e` in `x` is mapped to a positive index `i` for which - `bucket_boundaries[i-1] <= e < bucket_boundaries[i]`, if it exists. - If `e < bucket_boundaries[0]`, then `e` is mapped to `0`. If - `e >= bucket_boundaries[-1]`, then `e` is mapped to `len(bucket_boundaries)`. - NaNs are mapped to `len(bucket_boundaries)`. - - Example: - - >>> x = tf.constant([[4.0, float('nan'), 1.0], [float('-inf'), 7.5, 10.0]]) - >>> bucket_boundaries = tf.constant([[2.0, 5.0, 10.0]]) - >>> tft.apply_buckets(x, bucket_boundaries) - - - Args: - x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` whose values - should be mapped to buckets. For `CompositeTensor`s, the non-missing - values will be mapped to buckets and missing value left missing. - bucket_boundaries: A rank 2 `Tensor` or list representing the bucket - boundaries sorted in ascending order. - name: (Optional) A name for this operation. - - Returns: - A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, with - each element in the returned tensor representing the bucketized value. - Bucketized value is in the range [0, len(bucket_boundaries)]. - """ - with tf.compat.v1.name_scope(name, 'apply_buckets'): - bucket_boundaries = tf.convert_to_tensor(bucket_boundaries) - tf.compat.v1.assert_rank(bucket_boundaries, 2) - - bucketized_values = tf_utils.assign_buckets( - tf_utils.get_values(x), bucket_boundaries, side=tf_utils.Side.RIGHT) - - # Attach the relevant metadata to result, so that the corresponding - # output feature will have this metadata set. - min_value = tf.constant(0, tf.int64) - max_value = tf.shape(input=bucket_boundaries)[1] - schema_inference.set_tensor_schema_override( - bucketized_values, min_value, max_value) - _annotate_buckets(bucketized_values, bucket_boundaries) - compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) - return compose_result_fn(bucketized_values) + name: Optional[str] = None, +) -> common_types.ConsistentTensorType: + """Returns a bucketized column, with a bucket index assigned to each input. + + Each element `e` in `x` is mapped to a positive index `i` for which + `bucket_boundaries[i-1] <= e < bucket_boundaries[i]`, if it exists. + If `e < bucket_boundaries[0]`, then `e` is mapped to `0`. If + `e >= bucket_boundaries[-1]`, then `e` is mapped to `len(bucket_boundaries)`. + NaNs are mapped to `len(bucket_boundaries)`. 
+ + Example: + ------- + >>> x = tf.constant([[4.0, float('nan'), 1.0], [float('-inf'), 7.5, 10.0]]) + >>> bucket_boundaries = tf.constant([[2.0, 5.0, 10.0]]) + >>> tft.apply_buckets(x, bucket_boundaries) + + + Args: + ---- + x: A numeric input `Tensor`, `SparseTensor`, or `RaggedTensor` whose values + should be mapped to buckets. For `CompositeTensor`s, the non-missing + values will be mapped to buckets and missing value left missing. + bucket_boundaries: A rank 2 `Tensor` or list representing the bucket + boundaries sorted in ascending order. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor`, `SparseTensor`, or `RaggedTensor` of the same shape as `x`, with + each element in the returned tensor representing the bucketized value. + Bucketized value is in the range [0, len(bucket_boundaries)]. + """ + with tf.compat.v1.name_scope(name, "apply_buckets"): + bucket_boundaries = tf.convert_to_tensor(bucket_boundaries) + tf.compat.v1.assert_rank(bucket_boundaries, 2) + + bucketized_values = tf_utils.assign_buckets( + tf_utils.get_values(x), bucket_boundaries, side=tf_utils.Side.RIGHT + ) + + # Attach the relevant metadata to result, so that the corresponding + # output feature will have this metadata set. + min_value = tf.constant(0, tf.int64) + max_value = tf.shape(input=bucket_boundaries)[1] + schema_inference.set_tensor_schema_override( + bucketized_values, min_value, max_value + ) + _annotate_buckets(bucketized_values, bucket_boundaries) + compose_result_fn = _make_composite_tensor_wrapper_if_composite(x) + return compose_result_fn(bucketized_values) def _annotate_buckets(x: tf.Tensor, bucket_boundaries: tf.Tensor) -> None: - """Annotates a bucketized tensor with the boundaries that were applied. - - Creates a deferred annotation for the specified tensor. - - Args: - x: The tensor to annotate. - bucket_boundaries: A tensor of boundaries that were used to bucketize x. - """ - # The annotations proto currently isn't available in OSS builds, so schema - # annotations are not supported. - if not common.IS_ANNOTATIONS_PB_AVAILABLE: - return - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top - message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name - - # The BucketBoundaries annotation expects a float field. - bucket_boundaries = tf.cast(bucket_boundaries, tf.float32) - # Some callers provide rank 2 boundaries like [[.25], [.5], [.75], [1.]], - # whereas we expect rank 2 boundaries like [[.25, .5, .75, 1.]] - bucket_boundaries = tf.reshape(bucket_boundaries, [-1]) - bucket_boundaries = tf.expand_dims(bucket_boundaries, 0) - size = (tf.shape(bucket_boundaries)[1],) - message_proto = tf.raw_ops.EncodeProto(sizes=[size], - values=[bucket_boundaries], - field_names=['boundaries'], - message_type=message_type) - assert message_proto.shape == [1] - message_proto = message_proto[0] - - type_url = os.path.join(common.ANNOTATION_PREFIX_URL, message_type) - schema_inference.annotate(type_url, message_proto, tensor=x) + """Annotates a bucketized tensor with the boundaries that were applied. + + Creates a deferred annotation for the specified tensor. + + Args: + ---- + x: The tensor to annotate. + bucket_boundaries: A tensor of boundaries that were used to bucketize x. + """ + # The annotations proto currently isn't available in OSS builds, so schema + # annotations are not supported. 
+ if not common.IS_ANNOTATIONS_PB_AVAILABLE: + return + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) + + message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name + + # The BucketBoundaries annotation expects a float field. + bucket_boundaries = tf.cast(bucket_boundaries, tf.float32) + # Some callers provide rank 2 boundaries like [[.25], [.5], [.75], [1.]], + # whereas we expect rank 2 boundaries like [[.25, .5, .75, 1.]] + bucket_boundaries = tf.reshape(bucket_boundaries, [-1]) + bucket_boundaries = tf.expand_dims(bucket_boundaries, 0) + size = (tf.shape(bucket_boundaries)[1],) + message_proto = tf.raw_ops.EncodeProto( + sizes=[size], + values=[bucket_boundaries], + field_names=["boundaries"], + message_type=message_type, + ) + assert message_proto.shape == [1] + message_proto = message_proto[0] + + type_url = os.path.join(common.ANNOTATION_PREFIX_URL, message_type) + schema_inference.annotate(type_url, message_proto, tensor=x) @common.log_api_use(common.MAPPER_COLLECTION) -def estimated_probability_density(x: tf.Tensor, - boundaries: Optional[Union[tf.Tensor, - int]] = None, - categorical: bool = False, - name: Optional[str] = None) -> tf.Tensor: - """Computes an approximate probability density at each x, given the bins. - - Using this type of fixed-interval method has several benefits compared to - bucketization, although may not always be preferred. - 1. Quantiles does not work on categorical data. - 2. The quantiles algorithm does not currently operate on multiple features - jointly, only independently. - - Ex: Outlier detection in a multi-modal or arbitrary distribution. - Imagine a value x where a simple model is highly predictive of a target y - within certain densely populated ranges. Outside these ranges, we may want - to treat the data differently, but there are too few samples for the model - to detect them by case-by-case treatment. - One option would be to use the density estimate for this purpose: - - outputs['x_density'] = tft.estimated_prob(inputs['x'], bins=100) - outputs['outlier_x'] = tf.where(outputs['x_density'] < OUTLIER_THRESHOLD, - tf.constant([1]), tf.constant([0])) - - This exercise uses a single variable for illustration, but a direct density - metric would become more useful with higher dimensions. - - Note that we normalize by average bin_width to arrive at a probability density - estimate. The result resembles a pdf, not the probability that a value falls - in the bucket (except in the categorical case). - - Args: - x: A `Tensor`. - boundaries: (Optional) A `Tensor` or int used to approximate the density. - If possible provide boundaries as a Tensor of multiple sorted values. - Will default to 10 intervals over the 0-1 range, or find the min/max - if an int is provided (not recommended because multi-phase analysis is - inefficient). If the boundaries are known as potentially arbitrary - interval boundaries, sizes are assumed to be equal. If the sizes are - unequal, density may be inaccurate. Ignored if `categorical` is true. - categorical: (Optional) A `bool` that will treat x as categorical if true. - name: (Optional) A name for this operation. - - Returns: - A `Tensor` the same shape as x, the probability density estimate at x (or - probability mass estimate if `categorical` is True). - - Raises: - NotImplementedError: If `x` is CompositeTensor. 
- """ - with tf.compat.v1.name_scope(name, 'estimated_probability_density'): - if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'estimated probability density does not support Composite Tensors') - if x.get_shape().ndims > 1 and x.shape[-1] > 1: - raise NotImplementedError( - 'estimated probability density does not support multiple dimensions') - - counts, boundaries = analyzers.histogram(x, boundaries=boundaries, - categorical=categorical) - - xdims = x.get_shape().ndims - counts = tf.cast(counts, tf.float32) - probabilities = counts / tf.reduce_sum(counts) - - x = tf.reshape(x, [-1]) - - if categorical: - bucket_indices = tf_utils.lookup_key(x, boundaries) - bucket_densities = probabilities - else: - # We need to compute the bin width so that density does not depend on - # number of intervals. - bin_width = tf.cast(boundaries[0, -1] - boundaries[0, 0], tf.float32) / ( - tf.cast(tf.size(probabilities), tf.float32)) - bucket_densities = probabilities / bin_width - - bucket_indices = tf_utils.assign_buckets( - tf.cast(x, tf.float32), - analyzers.remove_leftmost_boundary(boundaries)) - bucket_indices = tf_utils._align_dims(bucket_indices, xdims) # pylint: disable=protected-access - - # In the categorical case, when keys are missing, the indices may be -1, - # therefore we replace those with 0 in order to use tf.gather. - adjusted_bucket_indices = tf.where( - bucket_indices < 0, _fill_shape(0, tf.shape(bucket_indices), tf.int64), - bucket_indices) - bucket_densities = tf.gather(bucket_densities, adjusted_bucket_indices) - return tf.where(bucket_indices < 0, - _fill_shape(0, tf.shape(bucket_indices), tf.float32), - bucket_densities) +def estimated_probability_density( + x: tf.Tensor, + boundaries: Optional[Union[tf.Tensor, int]] = None, + categorical: bool = False, + name: Optional[str] = None, +) -> tf.Tensor: + """Computes an approximate probability density at each x, given the bins. + + Using this type of fixed-interval method has several benefits compared to + bucketization, although may not always be preferred. + 1. Quantiles does not work on categorical data. + 2. The quantiles algorithm does not currently operate on multiple features + jointly, only independently. + + Ex: Outlier detection in a multi-modal or arbitrary distribution. + Imagine a value x where a simple model is highly predictive of a target y + within certain densely populated ranges. Outside these ranges, we may want + to treat the data differently, but there are too few samples for the model + to detect them by case-by-case treatment. + One option would be to use the density estimate for this purpose: + + outputs['x_density'] = tft.estimated_prob(inputs['x'], bins=100) + outputs['outlier_x'] = tf.where(outputs['x_density'] < OUTLIER_THRESHOLD, + tf.constant([1]), tf.constant([0])) + + This exercise uses a single variable for illustration, but a direct density + metric would become more useful with higher dimensions. + + Note that we normalize by average bin_width to arrive at a probability density + estimate. The result resembles a pdf, not the probability that a value falls + in the bucket (except in the categorical case). + + Args: + ---- + x: A `Tensor`. + boundaries: (Optional) A `Tensor` or int used to approximate the density. + If possible provide boundaries as a Tensor of multiple sorted values. + Will default to 10 intervals over the 0-1 range, or find the min/max + if an int is provided (not recommended because multi-phase analysis is + inefficient). 
If the boundaries are known as potentially arbitrary + interval boundaries, sizes are assumed to be equal. If the sizes are + unequal, density may be inaccurate. Ignored if `categorical` is true. + categorical: (Optional) A `bool` that will treat x as categorical if true. + name: (Optional) A name for this operation. + + Returns: + ------- + A `Tensor` the same shape as x, the probability density estimate at x (or + probability mass estimate if `categorical` is True). + + Raises: + ------ + NotImplementedError: If `x` is CompositeTensor. + """ + with tf.compat.v1.name_scope(name, "estimated_probability_density"): + if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "estimated probability density does not support Composite Tensors" + ) + if x.get_shape().ndims > 1 and x.shape[-1] > 1: + raise NotImplementedError( + "estimated probability density does not support multiple dimensions" + ) + + counts, boundaries = analyzers.histogram( + x, boundaries=boundaries, categorical=categorical + ) + + xdims = x.get_shape().ndims + counts = tf.cast(counts, tf.float32) + probabilities = counts / tf.reduce_sum(counts) + + x = tf.reshape(x, [-1]) + + if categorical: + bucket_indices = tf_utils.lookup_key(x, boundaries) + bucket_densities = probabilities + else: + # We need to compute the bin width so that density does not depend on + # number of intervals. + bin_width = tf.cast(boundaries[0, -1] - boundaries[0, 0], tf.float32) / ( + tf.cast(tf.size(probabilities), tf.float32) + ) + bucket_densities = probabilities / bin_width + + bucket_indices = tf_utils.assign_buckets( + tf.cast(x, tf.float32), analyzers.remove_leftmost_boundary(boundaries) + ) + bucket_indices = tf_utils._align_dims(bucket_indices, xdims) # pylint: disable=protected-access + + # In the categorical case, when keys are missing, the indices may be -1, + # therefore we replace those with 0 in order to use tf.gather. 
+ adjusted_bucket_indices = tf.where( + bucket_indices < 0, + _fill_shape(0, tf.shape(bucket_indices), tf.int64), + bucket_indices, + ) + bucket_densities = tf.gather(bucket_densities, adjusted_bucket_indices) + return tf.where( + bucket_indices < 0, + _fill_shape(0, tf.shape(bucket_indices), tf.float32), + bucket_densities, + ) diff --git a/tensorflow_transform/mappers_test.py b/tensorflow_transform/mappers_test.py index 9ff41b9..e6a6004 100644 --- a/tensorflow_transform/mappers_test.py +++ b/tensorflow_transform/mappers_test.py @@ -14,923 +14,1220 @@ """Tests for tensorflow_transform.mappers.""" import numpy as np - import tensorflow as tf -from tensorflow_transform import mappers -from tensorflow_transform import test_case + +from tensorflow_transform import mappers, test_case mock = tf.compat.v1.test.mock class MappersTest(test_case.TransformTestCase): + def assertSparseOutput( + self, + expected_indices, + expected_values, + expected_shape, + actual_sparse_tensor, + close_values, + ): + actual = self.evaluate(actual_sparse_tensor) + self.assertAllEqual(expected_indices, actual.indices) + self.assertAllEqual(expected_shape, actual.dense_shape) + if close_values: + self.assertAllClose(expected_values, actual.values) + else: + self.assertAllEqual(expected_values, actual.values) + + def testSegmentIndices(self): + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session(): + self.assertAllEqual( + mappers.segment_indices( + tf.constant([0, 0, 1, 2, 2, 2], tf.int64), name="test_name" + ).eval(), + [0, 1, 0, 0, 1, 2], + ) + self.assertAllEqual( + mappers.segment_indices(tf.constant([], tf.int64)).eval(), [] + ) + + def testSegmentIndicesSkipOne(self): + with tf.compat.v1.Graph().as_default(): + input_tensor = tf.constant([0, 0, 2, 2]) + with tf.compat.v1.Session(): + self.assertAllEqual( + [0, 1, 0, 1], mappers.segment_indices(input_tensor).eval() + ) + + def testNGramsEmpty(self): + with tf.compat.v1.Graph().as_default(): + output_tensor = mappers.ngrams( + tf.compat.v1.strings.split(tf.constant([""])), (1, 5), "" + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertEqual((0, 2), output.indices.shape) + self.assertAllEqual([1, 0], output.dense_shape) + self.assertEqual(0, len(output.values)) + + def testNGrams(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["abc", "def", "fghijklm", "z", ""]) + tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter="") + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, ngram_range=(1, 5), separator="" + ) + self.assertSparseOutput( + expected_indices=[ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [0, 5], + [1, 0], + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [1, 5], + [2, 0], + [2, 1], + [2, 2], + [2, 3], + [2, 4], + [2, 5], + [2, 6], + [2, 7], + [2, 8], + [2, 9], + [2, 10], + [2, 11], + [2, 12], + [2, 13], + [2, 14], + [2, 15], + [2, 16], + [2, 17], + [2, 18], + [2, 19], + [2, 20], + [2, 21], + [2, 22], + [2, 23], + [2, 24], + [2, 25], + [2, 26], + [2, 27], + [2, 28], + [2, 29], + [3, 0], + ], + expected_values=[ + b"a", + b"ab", + b"abc", + b"b", + b"bc", + b"c", + b"d", + b"de", + b"def", + b"e", + b"ef", + b"f", + b"f", + b"fg", + b"fgh", + b"fghi", + b"fghij", + b"g", + b"gh", + b"ghi", + b"ghij", + b"ghijk", + b"h", + b"hi", + b"hij", + b"hijk", + b"hijkl", + b"i", + b"ij", + b"ijk", + b"ijkl", + b"ijklm", + b"j", + b"jk", + b"jkl", + b"jklm", + b"k", + b"kl", + b"klm", + b"l", + b"lm", + b"m", + b"z", + ], + expected_shape=[5, 30], + 
actual_sparse_tensor=output_tensor, + close_values=False, + ) + + def testNGramsMinSizeNotOne(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["abc", "def", "fghijklm", "z", ""]) + tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter="") + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, ngram_range=(2, 5), separator="" + ) + self.assertSparseOutput( + expected_indices=[ + [0, 0], + [0, 1], + [0, 2], + [1, 0], + [1, 1], + [1, 2], + [2, 0], + [2, 1], + [2, 2], + [2, 3], + [2, 4], + [2, 5], + [2, 6], + [2, 7], + [2, 8], + [2, 9], + [2, 10], + [2, 11], + [2, 12], + [2, 13], + [2, 14], + [2, 15], + [2, 16], + [2, 17], + [2, 18], + [2, 19], + [2, 20], + [2, 21], + ], + expected_values=[ + b"ab", + b"abc", + b"bc", + b"de", + b"def", + b"ef", + b"fg", + b"fgh", + b"fghi", + b"fghij", + b"gh", + b"ghi", + b"ghij", + b"ghijk", + b"hi", + b"hij", + b"hijk", + b"hijkl", + b"ij", + b"ijk", + b"ijkl", + b"ijklm", + b"jk", + b"jkl", + b"jklm", + b"kl", + b"klm", + b"lm", + ], + expected_shape=[5, 22], + actual_sparse_tensor=output_tensor, + close_values=False, + ) + + def testNGramsWithSpaceSeparator(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["One was Johnny", "Two was a rat"]) + tokenized_tensor = tf.compat.v1.strings.split(string_tensor, sep=" ") + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, ngram_range=(1, 2), separator=" " + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual( + output.indices, + [ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [1, 0], + [1, 1], + [1, 2], + [1, 3], + [1, 4], + [1, 5], + [1, 6], + ], + ) + self.assertAllEqual( + output.values, + [ + b"One", + b"One was", + b"was", + b"was Johnny", + b"Johnny", + b"Two", + b"Two was", + b"was", + b"was a", + b"a", + b"a rat", + b"rat", + ], + ) + self.assertAllEqual(output.dense_shape, [2, 7]) + + def testNGramsWithRepeatedTokensPerRow(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["Cats or dogs or bunnies", "Cats not rats"]) + tokenized_tensor = tf.compat.v1.strings.split(string_tensor, sep=" ") + output_tensor = mappers.ngrams( + tokens=tokenized_tensor, ngram_range=(1, 1), separator=" " + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual( + output.indices, + [ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [1, 0], + [1, 1], + [1, 2], + ], + ) + # Note: the ngram "or" is represented twice for the first document. 
+ self.assertAllEqual( + output.values, + [ + b"Cats", + b"or", + b"dogs", + b"or", + b"bunnies", + b"Cats", + b"not", + b"rats", + ], + ) + self.assertAllEqual(output.dense_shape, [2, 5]) + + def testNGramsBadSizes(self): + string_tensor = tf.constant(["abc", "def", "fghijklm", "z", ""]) + tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter="") + with self.assertRaisesRegex(ValueError, "Invalid ngram_range"): + mappers.ngrams(tokenized_tensor, (0, 5), separator="") + with self.assertRaisesRegex(ValueError, "Invalid ngram_range"): + mappers.ngrams(tokenized_tensor, (6, 5), separator="") + + def testNGramsBagOfWordsEmpty(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant([], dtype=tf.string) + tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter="") + ngrams = mappers.ngrams(tokenized_tensor, (1, 2), separator="") + bow = mappers.bag_of_words(tokenized_tensor, (1, 2), separator="") + with tf.compat.v1.Session(): + ngrams_output = ngrams.eval() + bow_output = bow.eval() + self.assertAllEqual(ngrams_output.values, []) + self.assertAllEqual(bow_output.values, []) + self.assertAllEqual(ngrams_output.dense_shape, [0, 0]) + self.assertAllEqual(bow_output.dense_shape, [0, 0]) + + @test_case.named_parameters( + dict( + testcase_name="bag_of_words", + strings=["snakes or dogs and bunnies", "cats not rats"], + expected_output_indices=[ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [1, 0], + [1, 1], + [1, 2], + ], + expected_output_values=[ + b"snakes", + b"or", + b"dogs", + b"and", + b"bunnies", + b"cats", + b"not", + b"rats", + ], + ), + dict( + testcase_name="bag_of_words_duplicates_within_rows", + strings=["Cats or dogs or bunnies", "Cats not rats"], + expected_output_indices=[ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [1, 0], + [1, 1], + [1, 2], + ], + expected_output_values=[ + b"Cats", + b"or", + b"dogs", + b"bunnies", + b"Cats", + b"not", + b"rats", + ], + ), + dict( + testcase_name="bag_of_words_duplicates_across_rows", + strings=["cats or dogs or cats", "cats or dogs"], + expected_output_indices=[ + [0, 0], + [0, 1], + [0, 2], + [1, 0], + [1, 1], + [1, 2], + ], + expected_output_values=[b"cats", b"or", b"dogs", b"cats", b"or", b"dogs"], + ), + dict( + testcase_name="bag_of_words_some_empty", + strings=["boots and cats and boots and cats", "", "cats or dogs", ""], + expected_output_indices=[ + [0, 0], + [0, 1], + [0, 2], + [2, 0], + [2, 1], + [2, 2], + ], + expected_output_values=[b"boots", b"and", b"cats", b"cats", b"or", b"dogs"], + ), + dict( + testcase_name="bag_of_words_bigrams", + strings=["i like cats and i like cats to pet", "i like cats"], + expected_output_indices=[ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [0, 4], + [0, 5], + [1, 0], + [1, 1], + ], + # bigrams 'i like' and 'like cats' appear twice in the input but only + # once in the output for that row. 
+ expected_output_values=[ + b"i like", + b"like cats", + b"cats and", + b"and i", + b"cats to", + b"to pet", + b"i like", + b"like cats", + ], + ngram_range=[2, 2], + ), + ) + def testBagOfWords( + self, + strings, + expected_output_indices, + expected_output_values, + ngram_range=(1, 1), + separator=" ", + ): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(strings, dtype=tf.string) + tokenized_tensor = tf.compat.v1.string_split( + string_tensor, delimiter=separator + ) + output_tensor = mappers.bag_of_words( + tokens=tokenized_tensor, ngram_range=ngram_range, separator=separator + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual(output.indices, expected_output_indices) + self.assertAllEqual(output.values, expected_output_values) + + @test_case.named_parameters( + dict( + testcase_name="deduplicate_no_op", + indices=[ + [0, 0], + [1, 0], + [1, 1], + [1, 2], + ], + values=[b"foo", b"bar", b"biz", b"buzz"], + dense_shape=[2, 3], + expected_output_indices=[ + [0, 0], + [1, 0], + [1, 1], + [1, 2], + ], + expected_output_values=[b"foo", b"bar", b"biz", b"buzz"], + expected_output_shape=[2, 3], + ), + dict( + testcase_name="deduplicate_integers", + indices=[ + [1, 0], + [3, 1], + [3, 2], + [4, 4], + [4, 1], + ], + values=[1, 1, 1, 0, 0], + dense_shape=[5, 5], + expected_output_indices=[ + [1, 0], + [3, 0], + [4, 0], + ], + expected_output_values=[1, 1, 0], + expected_output_shape=[5, 1], + ), + dict( + testcase_name="deduplicate_empty_rows", + indices=[ + [0, 0], + [2, 1], + [2, 2], + [2, 4], + [4, 1], + ], + values=[b"foo", b"bar", b"biz", b"bar", b"foo"], + dense_shape=[5, 5], + expected_output_indices=[ + [0, 0], + [2, 0], + [2, 1], + [4, 0], + ], + expected_output_values=[b"foo", b"bar", b"biz", b"foo"], + expected_output_shape=[5, 2], + ), + dict( + testcase_name="deduplicate_shape_change", + indices=[ + [0, 0], + [0, 3], + [1, 0], + [1, 1], + [1, 2], + ], + values=[b"foo", b"foo", b"bar", b"buzz", b"bar"], + dense_shape=[2, 4], + expected_output_indices=[ + [0, 0], + [1, 0], + [1, 1], + ], + expected_output_values=[b"foo", b"bar", b"buzz"], + expected_output_shape=[2, 2], + ), + ) + def testDedupeSparseTensorPerRow( + self, + indices, + values, + dense_shape, + expected_output_indices, + expected_output_values, + expected_output_shape, + ): + with tf.compat.v1.Graph().as_default(): + sp_input = tf.SparseTensor( + indices=indices, values=values, dense_shape=dense_shape + ) + output_tensor = mappers.deduplicate_tensor_per_row(sp_input) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual(output.indices, expected_output_indices) + self.assertAllEqual(output.values, expected_output_values) + self.assertAllEqual(output.dense_shape, expected_output_shape) + + @test_case.named_parameters( + dict( + testcase_name="deduplicate_no_op", + values=[[b"a", b"b"], [b"c", b"d"]], + expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + expected_output=[b"a", b"b", b"c", b"d"], + ), + # Note: because the first dimension is the batch/row dimension, a 1D + # tensor is always returned as is (since there's only 1 value per row). 
+ dict( + testcase_name="deduplicate_1D", + values=[b"a", b"b", b"a", b"d"], + expected_indices=[[0, 0], [1, 0], [2, 0], [3, 0]], + expected_output=[b"a", b"b", b"a", b"d"], + ), + dict( + testcase_name="deduplicate", + values=[[b"a", b"b", b"a", b"b"], [b"c", b"c", b"d", b"d"]], + expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + expected_output=[b"a", b"b", b"c", b"d"], + ), + dict( + testcase_name="deduplicate_different_sizes", + # 2 uniques in the first row, 3 in the second row. + values=[[b"a", b"b", b"a", b"b"], [b"c", b"a", b"d", b"d"]], + expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], + expected_output=[b"a", b"b", b"c", b"a", b"d"], + ), + dict( + testcase_name="deduplicate_keeps_dups_across_rows", + values=[[b"a", b"b", b"a", b"b"], [b"b", b"a", b"b", b"b"]], + expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], + expected_output=[b"a", b"b", b"b", b"a"], + ), + ) + def testDedupeDenseTensorPerRow(self, values, expected_indices, expected_output): + with tf.compat.v1.Graph().as_default(): + dense_input = tf.constant(values) + output_tensor = mappers.deduplicate_tensor_per_row(dense_input) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual(output.indices, expected_indices) + self.assertAllEqual(output.values, expected_output) + + def testDedup3dInputRaises(self): + dense_input = tf.constant( + [[[b"a", b"a"], [b"b", b"b"]], [[b"a", b"a"], [b"d", b"d"]]] + ) + with self.assertRaises(ValueError): + mappers.deduplicate_tensor_per_row(dense_input) + + def testWordCountEmpty(self): + with tf.compat.v1.Graph().as_default(): + output_tensor = mappers.word_count( + tf.compat.v1.string_split(tf.constant([""])) + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertEqual(1, len(output)) + self.assertEqual(0, sum(output)) + + def testWordCount(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["abc", "def", "fghijklm", "z", ""]) + tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter="") + output_tensor = mappers.word_count(tokenized_tensor) + output_3d_tensor = mappers.word_count( + tf.sparse.expand_dims( + tf.sparse.expand_dims(tokenized_tensor, axis=1), axis=1 + ) + ) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertEqual(5, len(output)) + self.assertEqual(15, sum(output)) + self.assertAllEqual(output, [3, 3, 8, 1, 0]) + self.assertAllEqual(output, output_3d_tensor.eval()) + + def testWordCountRagged(self): + with tf.compat.v1.Graph().as_default(): + string_tensor = tf.constant(["abc", "def", "fghijklm", "z", ""]) + tokenized_tensor = tf.RaggedTensor.from_sparse( + tf.compat.v1.string_split(string_tensor, delimiter="") + ) + output_tensor = mappers.word_count(tokenized_tensor) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertEqual(5, len(output)) + self.assertEqual(15, sum(output)) + self.assertAllEqual(output, [3, 3, 8, 1, 0]) + + def testTermFrequency(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], + [1, 2, 0, 0, 0, 3, 0], + [2, 5], + ) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[(3 / 5), (1 / 5), (1 / 5), (1 / 2), (1 / 2)], + expected_shape=[2, 4], + actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 4), + close_values=True, + ) + + def testTermFrequencyUnusedTerm(self): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], + [4, 2, 0, 0, 0, 3, 0], + [2, 
5], + ) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 2], [0, 4], [1, 0], [1, 3]], + expected_values=[(3 / 5), (1 / 5), (1 / 5), (1 / 2), (1 / 2)], + expected_shape=[2, 5], + actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 5), + close_values=True, + ) + + def testCountDocsWithTerm(self): + with tf.compat.v1.Graph().as_default(): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3 / 5), (1 / 5), (1 / 5), (1 / 2), (1 / 2)], + [2, 4], + ) + output_tensor = mappers._count_docs_with_term(input_tensor) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual([[2, 1, 1, 1]], output) + + def testCountDocsWithTermUnusedTerm(self): + with tf.compat.v1.Graph().as_default(): + input_tensor = tf.SparseTensor( + [[0, 0], [0, 2], [1, 0], [1, 3]], + [(3 / 5), (1 / 5), (1 / 2), (1 / 2)], + [2, 4], + ) + output_tensor = mappers._count_docs_with_term(input_tensor) + with tf.compat.v1.Session(): + output = output_tensor.eval() + self.assertAllEqual([[2, 0, 1, 1]], output) + + def testToTFIDF(self): + term_freq = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3 / 5), (1 / 5), (1 / 5), (1 / 2), (1 / 2)], + [2, 4], + ) + reduced_term_freq = tf.constant([[2, 1, 1, 1]]) + output_tensor = mappers._to_tfidf( + term_freq, reduced_term_freq, tf.constant(2), True + ) + log_3_over_2 = 1.4054651 + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[ + (3 / 5), + (1 / 5) * log_3_over_2, + (1 / 5) * log_3_over_2, + (1 / 2), + (1 / 2) * log_3_over_2, + ], + expected_shape=[2, 4], + actual_sparse_tensor=output_tensor, + close_values=True, + ) + + def testToTFIDFNotSmooth(self): + term_freq = tf.SparseTensor( + [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + [(3 / 5), (1 / 5), (1 / 5), (1 / 2), (1 / 2)], + [2, 4], + ) + reduced_term_freq = tf.constant([[2, 1, 1, 1]]) + output_tensor = mappers._to_tfidf( + term_freq, reduced_term_freq, tf.constant(2), False + ) + log_2_over_1 = 1.6931471 + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], + expected_values=[ + (3 / 5), + (1 / 5) * log_2_over_1, + (1 / 5) * log_2_over_1, + (1 / 2), + (1 / 2) * log_2_over_1, + ], + expected_shape=[2, 4], + actual_sparse_tensor=output_tensor, + close_values=True, + ) + + def testSplitTFIDF(self): + tfidfs = tf.SparseTensor( + [[0, 0], [0, 1], [2, 1], [2, 2]], + [0.23104906, 0.19178806, 0.14384104, 0.34657359], + [3, 4], + ) + + out_index, out_weight = mappers._split_tfidfs_to_outputs(tfidfs) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], + expected_values=[0, 1, 1, 2], + expected_shape=[3, 2], + actual_sparse_tensor=out_index, + close_values=False, + ) + self.assertSparseOutput( + expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], + expected_values=[0.23104906, 0.19178806, 0.14384104, 0.34657359], + expected_shape=[3, 2], + actual_sparse_tensor=out_weight, + close_values=True, + ) + + def testSplitTFIDFWithEmptyInput(self): + # TODO(b/123242111): rewrite this test using public functions. 
+ with tf.compat.v1.Graph().as_default(): + tfidf = tf.SparseTensor( + values=tf.constant([], shape=[0], dtype=tf.float32), + indices=tf.constant([], shape=[0, 2], dtype=tf.int64), + dense_shape=[2, 0], + ) - def assertSparseOutput(self, expected_indices, expected_values, - expected_shape, actual_sparse_tensor, close_values): - actual = self.evaluate(actual_sparse_tensor) - self.assertAllEqual(expected_indices, actual.indices) - self.assertAllEqual(expected_shape, actual.dense_shape) - if close_values: - self.assertAllClose(expected_values, actual.values) - else: - self.assertAllEqual(expected_values, actual.values) - - def testSegmentIndices(self): - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session(): - self.assertAllEqual( - mappers.segment_indices(tf.constant([0, 0, 1, 2, 2, 2], tf.int64), - name='test_name').eval(), - [0, 1, 0, 0, 1, 2]) - self.assertAllEqual( - mappers.segment_indices(tf.constant([], tf.int64)).eval(), - []) - - def testSegmentIndicesSkipOne(self): - with tf.compat.v1.Graph().as_default(): - input_tensor = tf.constant([0, 0, 2, 2]) - with tf.compat.v1.Session(): - self.assertAllEqual([0, 1, 0, 1], - mappers.segment_indices(input_tensor).eval()) - - def testNGramsEmpty(self): - with tf.compat.v1.Graph().as_default(): - output_tensor = mappers.ngrams( - tf.compat.v1.strings.split(tf.constant([''])), (1, 5), '') - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertEqual((0, 2), output.indices.shape) - self.assertAllEqual([1, 0], output.dense_shape) - self.assertEqual(0, len(output.values)) - - def testNGrams(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) - tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='') - output_tensor = mappers.ngrams( - tokens=tokenized_tensor, - ngram_range=(1, 5), - separator='') - self.assertSparseOutput( - expected_indices=[ - [0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], - [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], - [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], - [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], - [2, 15], [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21], - [2, 22], [2, 23], [2, 24], [2, 25], [2, 26], [2, 27], [2, 28], - [2, 29], [3, 0]], - expected_values=[ - b'a', b'ab', b'abc', b'b', b'bc', b'c', b'd', b'de', b'def', b'e', - b'ef', b'f', b'f', b'fg', b'fgh', b'fghi', b'fghij', b'g', b'gh', - b'ghi', b'ghij', b'ghijk', b'h', b'hi', b'hij', b'hijk', b'hijkl', - b'i', b'ij', b'ijk', b'ijkl', b'ijklm', b'j', b'jk', b'jkl', - b'jklm', b'k', b'kl', b'klm', b'l', b'lm', b'm', b'z' - ], - expected_shape=[5, 30], - actual_sparse_tensor=output_tensor, - close_values=False) - - def testNGramsMinSizeNotOne(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) - tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='') - output_tensor = mappers.ngrams( - tokens=tokenized_tensor, - ngram_range=(2, 5), - separator='') - self.assertSparseOutput( - expected_indices=[ - [0, 0], [0, 1], [0, 2], - [1, 0], [1, 1], [1, 2], - [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5], [2, 6], [2, 7], - [2, 8], [2, 9], [2, 10], [2, 11], [2, 12], [2, 13], [2, 14], - [2, 15], [2, 16], [2, 17], [2, 18], [2, 19], [2, 20], [2, 21]], - expected_values=[ - b'ab', b'abc', b'bc', b'de', b'def', b'ef', b'fg', b'fgh', - b'fghi', b'fghij', b'gh', b'ghi', b'ghij', b'ghijk', b'hi', - b'hij', b'hijk', b'hijkl', b'ij', b'ijk', 
b'ijkl', b'ijklm', - b'jk', b'jkl', b'jklm', b'kl', b'klm', b'lm' - ], - expected_shape=[5, 22], - actual_sparse_tensor=output_tensor, - close_values=False) - - def testNGramsWithSpaceSeparator(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['One was Johnny', 'Two was a rat']) - tokenized_tensor = tf.compat.v1.strings.split(string_tensor, sep=' ') - output_tensor = mappers.ngrams( - tokens=tokenized_tensor, - ngram_range=(1, 2), - separator=' ') - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual( - output.indices, - [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], - [1, 0], [1, 1], [1, 2], [1, 3], [1, 4], [1, 5], [1, 6]]) - self.assertAllEqual(output.values, [ - b'One', b'One was', b'was', b'was Johnny', b'Johnny', b'Two', - b'Two was', b'was', b'was a', b'a', b'a rat', b'rat' - ]) - self.assertAllEqual(output.dense_shape, [2, 7]) - - def testNGramsWithRepeatedTokensPerRow(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['Cats or dogs or bunnies', 'Cats not rats']) - tokenized_tensor = tf.compat.v1.strings.split(string_tensor, sep=' ') - output_tensor = mappers.ngrams( - tokens=tokenized_tensor, ngram_range=(1, 1), separator=' ') - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual(output.indices, [ - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [0, 4], - [1, 0], - [1, 1], - [1, 2], - ]) - # Note: the ngram "or" is represented twice for the first document. - self.assertAllEqual(output.values, [ - b'Cats', b'or', b'dogs', b'or', b'bunnies', b'Cats', b'not', b'rats' - ]) - self.assertAllEqual(output.dense_shape, [2, 5]) - - def testNGramsBadSizes(self): - string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) - tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='') - with self.assertRaisesRegex(ValueError, 'Invalid ngram_range'): - mappers.ngrams(tokenized_tensor, (0, 5), separator='') - with self.assertRaisesRegex(ValueError, 'Invalid ngram_range'): - mappers.ngrams(tokenized_tensor, (6, 5), separator='') - - def testNGramsBagOfWordsEmpty(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant([], dtype=tf.string) - tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='') - ngrams = mappers.ngrams(tokenized_tensor, (1, 2), separator='') - bow = mappers.bag_of_words(tokenized_tensor, (1, 2), separator='') - with tf.compat.v1.Session(): - ngrams_output = ngrams.eval() - bow_output = bow.eval() - self.assertAllEqual(ngrams_output.values, []) - self.assertAllEqual(bow_output.values, []) - self.assertAllEqual(ngrams_output.dense_shape, [0, 0]) - self.assertAllEqual(bow_output.dense_shape, [0, 0]) - - @test_case.named_parameters( - dict( - testcase_name='bag_of_words', - strings=['snakes or dogs and bunnies', 'cats not rats'], - expected_output_indices=[ - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [0, 4], - [1, 0], - [1, 1], - [1, 2], - ], - expected_output_values=[ - b'snakes', b'or', b'dogs', b'and', b'bunnies', b'cats', b'not', - b'rats' - ]), - dict( - testcase_name='bag_of_words_duplicates_within_rows', - strings=['Cats or dogs or bunnies', 'Cats not rats'], - expected_output_indices=[ - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [1, 0], - [1, 1], - [1, 2], - ], - expected_output_values=[ - b'Cats', b'or', b'dogs', b'bunnies', b'Cats', b'not', b'rats' - ]), - dict( - testcase_name='bag_of_words_duplicates_across_rows', - strings=['cats or dogs or cats', 'cats or dogs'], - expected_output_indices=[ - [0, 0], - [0, 
1], - [0, 2], - [1, 0], - [1, 1], - [1, 2], - ], - expected_output_values=[ - b'cats', b'or', b'dogs', b'cats', b'or', b'dogs' - ]), - dict( - testcase_name='bag_of_words_some_empty', - strings=['boots and cats and boots and cats', '', 'cats or dogs', ''], - expected_output_indices=[ - [0, 0], - [0, 1], - [0, 2], - [2, 0], - [2, 1], - [2, 2], - ], - expected_output_values=[ - b'boots', b'and', b'cats', b'cats', b'or', b'dogs' - ]), - dict( - testcase_name='bag_of_words_bigrams', - strings=['i like cats and i like cats to pet', 'i like cats'], - expected_output_indices=[ - [0, 0], - [0, 1], - [0, 2], - [0, 3], - [0, 4], - [0, 5], - [1, 0], - [1, 1], - ], - # bigrams 'i like' and 'like cats' appear twice in the input but only - # once in the output for that row. - expected_output_values=[ - b'i like', - b'like cats', - b'cats and', - b'and i', - b'cats to', - b'to pet', - b'i like', - b'like cats', - ], - ngram_range=[2, 2]), - ) - def testBagOfWords(self, - strings, - expected_output_indices, - expected_output_values, - ngram_range=(1, 1), - separator=' '): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(strings, dtype=tf.string) - tokenized_tensor = tf.compat.v1.string_split( - string_tensor, delimiter=separator) - output_tensor = mappers.bag_of_words( - tokens=tokenized_tensor, ngram_range=ngram_range, separator=separator) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual(output.indices, expected_output_indices) - self.assertAllEqual(output.values, expected_output_values) - - @test_case.named_parameters( - dict( - testcase_name='deduplicate_no_op', - indices=[ - [0, 0], - [1, 0], - [1, 1], - [1, 2], - ], - values=[b'foo', b'bar', b'biz', b'buzz'], - dense_shape=[2, 3], - expected_output_indices=[ - [0, 0], - [1, 0], - [1, 1], - [1, 2], - ], - expected_output_values=[b'foo', b'bar', b'biz', b'buzz'], - expected_output_shape=[2, 3], - ), - dict( - testcase_name='deduplicate_integers', - indices=[ - [1, 0], - [3, 1], - [3, 2], - [4, 4], - [4, 1], - ], - values=[1, 1, 1, 0, 0], - dense_shape=[5, 5], - expected_output_indices=[ - [1, 0], - [3, 0], - [4, 0], - ], - expected_output_values=[1, 1, 0], - expected_output_shape=[5, 1], - ), - dict( - testcase_name='deduplicate_empty_rows', - indices=[ - [0, 0], - [2, 1], - [2, 2], - [2, 4], - [4, 1], - ], - values=[b'foo', b'bar', b'biz', b'bar', b'foo'], - dense_shape=[5, 5], - expected_output_indices=[ - [0, 0], - [2, 0], - [2, 1], - [4, 0], - ], - expected_output_values=[b'foo', b'bar', b'biz', b'foo'], - expected_output_shape=[5, 2], - ), - dict( - testcase_name='deduplicate_shape_change', - indices=[ - [0, 0], - [0, 3], - [1, 0], - [1, 1], - [1, 2], - ], - values=[b'foo', b'foo', b'bar', b'buzz', b'bar'], - dense_shape=[2, 4], - expected_output_indices=[ - [0, 0], - [1, 0], - [1, 1], - ], - expected_output_values=[b'foo', b'bar', b'buzz'], - expected_output_shape=[2, 2], - )) - def testDedupeSparseTensorPerRow(self, indices, values, dense_shape, - expected_output_indices, - expected_output_values, - expected_output_shape): - with tf.compat.v1.Graph().as_default(): - sp_input = tf.SparseTensor( - indices=indices, values=values, dense_shape=dense_shape) - output_tensor = mappers.deduplicate_tensor_per_row(sp_input) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual(output.indices, expected_output_indices) - self.assertAllEqual(output.values, expected_output_values) - self.assertAllEqual(output.dense_shape, expected_output_shape) - - 
@test_case.named_parameters( - dict( - testcase_name='deduplicate_no_op', - values=[[b'a', b'b'], [b'c', b'd']], - expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], - expected_output=[b'a', b'b', b'c', b'd'], - ), - # Note: because the first dimension is the batch/row dimension, a 1D - # tensor is always returned as is (since there's only 1 value per row). - dict( - testcase_name='deduplicate_1D', - values=[b'a', b'b', b'a', b'd'], - expected_indices=[[0, 0], [1, 0], [2, 0], [3, 0]], - expected_output=[b'a', b'b', b'a', b'd'], - ), - dict( - testcase_name='deduplicate', - values=[[b'a', b'b', b'a', b'b'], [b'c', b'c', b'd', b'd']], - expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], - expected_output=[b'a', b'b', b'c', b'd'], - ), - dict( - testcase_name='deduplicate_different_sizes', - # 2 uniques in the first row, 3 in the second row. - values=[[b'a', b'b', b'a', b'b'], [b'c', b'a', b'd', b'd']], - expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], - expected_output=[b'a', b'b', b'c', b'a', b'd'], - ), - dict( - testcase_name='deduplicate_keeps_dups_across_rows', - values=[[b'a', b'b', b'a', b'b'], [b'b', b'a', b'b', b'b']], - expected_indices=[[0, 0], [0, 1], [1, 0], [1, 1]], - expected_output=[b'a', b'b', b'b', b'a'], - ), - ) - def testDedupeDenseTensorPerRow(self, values, expected_indices, - expected_output): - with tf.compat.v1.Graph().as_default(): - dense_input = tf.constant(values) - output_tensor = mappers.deduplicate_tensor_per_row(dense_input) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual(output.indices, expected_indices) - self.assertAllEqual(output.values, expected_output) - - def testDedup3dInputRaises(self): - dense_input = tf.constant([[[b'a', b'a'], [b'b', b'b']], - [[b'a', b'a'], [b'd', b'd']]]) - with self.assertRaises(ValueError): - mappers.deduplicate_tensor_per_row(dense_input) - - def testWordCountEmpty(self): - with tf.compat.v1.Graph().as_default(): - output_tensor = mappers.word_count( - tf.compat.v1.string_split(tf.constant(['']))) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertEqual(1, len(output)) - self.assertEqual(0, sum(output)) - - def testWordCount(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) - tokenized_tensor = tf.compat.v1.string_split(string_tensor, delimiter='') - output_tensor = mappers.word_count(tokenized_tensor) - output_3d_tensor = mappers.word_count( - tf.sparse.expand_dims( - tf.sparse.expand_dims(tokenized_tensor, axis=1), axis=1)) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertEqual(5, len(output)) - self.assertEqual(15, sum(output)) - self.assertAllEqual(output, [3, 3, 8, 1, 0]) - self.assertAllEqual(output, output_3d_tensor.eval()) - - def testWordCountRagged(self): - with tf.compat.v1.Graph().as_default(): - string_tensor = tf.constant(['abc', 'def', 'fghijklm', 'z', '']) - tokenized_tensor = tf.RaggedTensor.from_sparse( - tf.compat.v1.string_split(string_tensor, delimiter='')) - output_tensor = mappers.word_count(tokenized_tensor) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertEqual(5, len(output)) - self.assertEqual(15, sum(output)) - self.assertAllEqual(output, [3, 3, 8, 1, 0]) - - def testTermFrequency(self): - input_tensor = tf.SparseTensor( - [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], - [1, 2, 0, 0, 0, 3, 0], - [2, 5]) - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - 
expected_values=[(3/5), (1/5), (1/5), (1/2), (1/2)], - expected_shape=[2, 4], - actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 4), - close_values=True) - - def testTermFrequencyUnusedTerm(self): - input_tensor = tf.SparseTensor( - [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 0], [1, 1]], - [4, 2, 0, 0, 0, 3, 0], - [2, 5]) - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 2], [0, 4], [1, 0], [1, 3]], - expected_values=[(3/5), (1/5), (1/5), (1/2), (1/2)], - expected_shape=[2, 5], - actual_sparse_tensor=mappers._to_term_frequency(input_tensor, 5), - close_values=True) - - def testCountDocsWithTerm(self): - with tf.compat.v1.Graph().as_default(): - input_tensor = tf.SparseTensor( - [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - [(3/5), (1/5), (1/5), (1/2), (1/2)], - [2, 4]) - output_tensor = mappers._count_docs_with_term(input_tensor) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual([[2, 1, 1, 1]], output) - - def testCountDocsWithTermUnusedTerm(self): - with tf.compat.v1.Graph().as_default(): - input_tensor = tf.SparseTensor( - [[0, 0], [0, 2], [1, 0], [1, 3]], - [(3/5), (1/5), (1/2), (1/2)], - [2, 4]) - output_tensor = mappers._count_docs_with_term(input_tensor) - with tf.compat.v1.Session(): - output = output_tensor.eval() - self.assertAllEqual([[2, 0, 1, 1]], output) - - def testToTFIDF(self): - term_freq = tf.SparseTensor( - [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - [(3/5), (1/5), (1/5), (1/2), (1/2)], - [2, 4]) - reduced_term_freq = tf.constant([[2, 1, 1, 1]]) - output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, - tf.constant(2), True) - log_3_over_2 = 1.4054651 - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - expected_values=[(3/5), (1/5)*log_3_over_2, (1/5)*log_3_over_2, - (1/2), (1/2)*log_3_over_2], - expected_shape=[2, 4], - actual_sparse_tensor=output_tensor, - close_values=True) - - def testToTFIDFNotSmooth(self): - term_freq = tf.SparseTensor( - [[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - [(3/5), (1/5), (1/5), (1/2), (1/2)], - [2, 4]) - reduced_term_freq = tf.constant([[2, 1, 1, 1]]) - output_tensor = mappers._to_tfidf(term_freq, reduced_term_freq, - tf.constant(2), False) - log_2_over_1 = 1.6931471 - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 3]], - expected_values=[(3/5), (1/5)*log_2_over_1, (1/5)*log_2_over_1, - (1/2), (1/2)*log_2_over_1], - expected_shape=[2, 4], - actual_sparse_tensor=output_tensor, - close_values=True) - - def testSplitTFIDF(self): - tfidfs = tf.SparseTensor( - [[0, 0], [0, 1], [2, 1], [2, 2]], - [0.23104906, 0.19178806, 0.14384104, 0.34657359], - [3, 4]) - - out_index, out_weight = mappers._split_tfidfs_to_outputs(tfidfs) - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], - expected_values=[0, 1, 1, 2], - expected_shape=[3, 2], - actual_sparse_tensor=out_index, - close_values=False) - self.assertSparseOutput( - expected_indices=[[0, 0], [0, 1], [2, 0], [2, 1]], - expected_values=[0.23104906, 0.19178806, 0.14384104, 0.34657359], - expected_shape=[3, 2], - actual_sparse_tensor=out_weight, - close_values=True) - - def testSplitTFIDFWithEmptyInput(self): - # TODO(b/123242111): rewrite this test using public functions. 
- with tf.compat.v1.Graph().as_default(): - tfidf = tf.SparseTensor( - values=tf.constant([], shape=[0], dtype=tf.float32), - indices=tf.constant([], shape=[0, 2], dtype=tf.int64), - dense_shape=[2, 0]) - - _, weights = mappers._split_tfidfs_to_outputs(tfidf) - - with self.test_session() as sess: - weights_shape = sess.run(weights.dense_shape) - self.assertAllEqual(weights_shape, [2, 0]) - - def testHashStringsNoKeyDenseInput(self): - with tf.compat.v1.Graph().as_default(): - strings = tf.constant(['Car', 'Bus', 'Tree']) - expected_output = [8, 4, 5] - - hash_buckets = 11 - hashed_strings = mappers.hash_strings(strings, hash_buckets) - with self.test_session() as sess: - output = sess.run(hashed_strings) - self.assertAllEqual(expected_output, output) - - def testHashStringsNoKeySparseInput(self): - strings = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]], - values=['Dog', 'Cat', ''], - dense_shape=[2, 2]) - hash_buckets = 17 - expected_indices = [[0, 0], [0, 1], [1, 0]] - expected_values = [12, 4, 11] - expected_shape = [2, 2] - hashed_strings = mappers.hash_strings(strings, hash_buckets) - self.assertSparseOutput( - expected_indices=expected_indices, - expected_values=expected_values, - expected_shape=expected_shape, - actual_sparse_tensor=hashed_strings, - close_values=False) - - def testHashStringsNoKeyRaggedInput(self): - strings = tf.RaggedTensor.from_row_splits( - values=['Dog', 'Cat', ''], row_splits=[0, 1, 1, 1, 1, 3]) - hash_buckets = 17 - expected_hashed_strings = tf.RaggedTensor.from_row_splits( - values=[12, 4, 11], row_splits=[0, 1, 1, 1, 1, 3]) - hashed_strings = mappers.hash_strings(strings, hash_buckets) - self.assertAllEqual(expected_hashed_strings, hashed_strings) - - def testHashStringsWithKeyDenseInput(self): - with tf.compat.v1.Graph().as_default(): - strings = tf.constant(['Cake', 'Pie', 'Sundae']) - expected_output = [6, 5, 6] - hash_buckets = 11 - hashed_strings = mappers.hash_strings( - strings, hash_buckets, key=[123, 456]) - with self.test_session() as sess: - output = sess.run(hashed_strings) - self.assertAllEqual(expected_output, output) - - def testHashStringsWithKeySparseInput(self): - strings = tf.SparseTensor( - indices=[[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 1, 0]], - values=['$$$', '%^#', '&$!#@', '$$$'], - dense_shape=[3, 3, 2]) - hash_buckets = 173 - expected_indices = strings.indices - expected_values = [16, 156, 9, 16] - expected_shape = strings.dense_shape - hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[321, 555]) - self.assertSparseOutput( - expected_indices=expected_indices, - expected_values=expected_values, - expected_shape=expected_shape, - actual_sparse_tensor=hashed_strings, - close_values=False) - - def testHashStringsWithKeyRaggedInput(self): - strings = tf.RaggedTensor.from_row_splits( - values=['$$$', '%^#', '&$!#@', '$$$'], row_splits=[0, 1, 1, 2, 2, 4]) - hash_buckets = 173 - expected_hashed_strings = tf.RaggedTensor.from_row_splits( - values=[16, 156, 9, 16], row_splits=[0, 1, 1, 2, 2, 4]) - hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[321, 555]) - self.assertAllEqual(expected_hashed_strings, hashed_strings) - - @test_case.named_parameters( - dict( - testcase_name='few_buckets', - x=4, - bucket_boundaries=[[5]], - expected_buckets=0), - dict( - testcase_name='large_buckets', - x=50_000_000, - bucket_boundaries=[[0, 50_000_001, 100_000_001]], - expected_buckets=1), - dict( - testcase_name='with_nans', - x=[4.0, float('nan'), float('-inf'), 7.5, 10.0], - bucket_boundaries=[[2, 5, 8]], - 
expected_buckets=[1, 3, 0, 2, 3]), - dict( - testcase_name='with_inf_boundary', - x=[4.0, float('-inf'), .8, 7.5, 10.0], - bucket_boundaries=[[float('-inf'), 2, 5, 8]], - expected_buckets=[2, 1, 1, 3, 4]), - ) - def testApplyBuckets(self, x, bucket_boundaries, expected_buckets): - x = tf.constant(x) - bucket_boundaries = tf.constant(bucket_boundaries) - expected_buckets = tf.constant(expected_buckets, dtype=tf.int64) - buckets = mappers.apply_buckets(x, bucket_boundaries) - self.assertAllEqual(buckets, expected_buckets) - - def testApplybucketsToSparseTensor(self): - inputs = tf.SparseTensor( - indices=[[0, 0, 0], [0, 1, 1], [2, 2, 2]], - values=[10, 20, -1], - dense_shape=[3, 3, 4]) - quantiles = [-10, 0, 13] - bucketized = mappers.apply_buckets(inputs, [quantiles]) - self.assertSparseOutput( - inputs.indices, - tf.constant([2, 3, 1]), - inputs.dense_shape, - bucketized, - close_values=False) - - def testApplybucketsToRaggedTensor(self): - inputs = tf.RaggedTensor.from_row_splits( - values=tf.RaggedTensor.from_row_splits( - values=[10, 20, -1], row_splits=[0, 1, 1, 2, 2, 3]), - row_splits=[0, 1, 1, 2, 3, 5]) - quantiles = [-10, 0, 13] - expected_bucketized = tf.RaggedTensor.from_row_splits( - values=tf.RaggedTensor.from_row_splits( - values=[2, 3, 1], row_splits=[0, 1, 1, 2, 2, 3]), - row_splits=[0, 1, 1, 2, 3, 5]) - bucketized = mappers.apply_buckets(inputs, [quantiles]) - self.assertAllEqual(expected_bucketized, bucketized) - - def testApplyBucketsWithKeys(self): - with tf.compat.v1.Graph().as_default(): - values = tf.constant([ - -100, -0.05, 0.05, 0.25, 0.15, 100, -100, 0, 4.3, 4.5, 4.4, 4.6, 100 - ], - dtype=tf.float32) - keys = tf.constant([ - 'a', 'a', 'a', 'a', 'a', 'a', 'b', 'missing', 'b', 'b', 'b', 'b', 'b' - ]) - key_vocab = tf.constant(['a', 'b']) - # Pre-normalization boundaries: [[0, 0.1, 0.2], [4.33, 4.43, 4.53]] - bucket_boundaries = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0], - dtype=tf.float32) - scales = 1.0 / ( - tf.constant([0.2, 4.53], dtype=tf.float32) - - tf.constant([0, 4.33], dtype=tf.float32)) - shifts = tf.constant([0, 1.0 - (4.33 * 5)], dtype=tf.float32) - num_buckets = tf.constant(4, dtype=tf.int64) - buckets = mappers._apply_buckets_with_keys(values, keys, key_vocab, - bucket_boundaries, scales, - shifts, num_buckets) - with self.test_session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - output = sess.run(buckets) - self.assertAllEqual([0, 0, 1, 3, 2, 3, 0, -1, 0, 2, 1, 3, 3], output) - - @test_case.named_parameters( - dict( - testcase_name='single_input_value', - x=1, - boundaries=[0, 2], - expected_results=.5), - dict( - testcase_name='single_boundary', - x=[-1, 9, 10, 11], - boundaries=[10], - expected_results=[0, 0, 1, 1]), - dict( - testcase_name='out_of_bounds', - x=[-1111, 0, 5, 9, 10, 11, 15, 19, 20, 21, 1111], - boundaries=[10, 20], - expected_results=[0, 0, 0, 0, 0, .1, 0.5, .9, 1, 1, 1]), - dict( - testcase_name='2d_input', - x=[[15, 10], [20, 17], [-1111, 21]], - boundaries=[10, 20], - expected_results=[[0.5, 0], [1, .7], [0, 1]]), - dict( - testcase_name='integer_input', - x=[15, 20, 25], - boundaries=[10, 20], - expected_results=[.5, 1, 1], - input_dtype=tf.int64), - dict( - testcase_name='float_input', - x=[-10, 0, 0.1, 2.3, 4.5, 6.7, 8.9, 10, 100], - boundaries=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], - expected_results=[0, 0, 0.01, 0.23, 0.45, 0.67, 0.89, 1, 1]), - dict( - testcase_name='float_input_with_nans', - x=[ - float('-inf'), -10, 0, 0.1, 2.3, - float('nan'), 4.5, 6.7, 8.9, 10, 100, - float('inf') - ], - boundaries=[0, 1, 2, 3, 
4, 5, 6, 7, 8, 9, 10], - expected_results=[0, 0, 0, 0.01, 0.23, .5, 0.45, 0.67, 0.89, 1, 1, - 1]), - dict( - testcase_name='float_input_with_inf_boundaries', - x=[ - float('-inf'), - float('-inf'), - float(0), - float('-inf'), - ], - boundaries=[float('-inf'), 0], - expected_results=[0, 0, 1, 0]), - dict( - testcase_name='float_input_with_nan_boundaries', - x=[ - float('-inf'), - float('nan'), - float(0), - float(1), - ], - boundaries=[float('nan'), 0, 1], - expected_results=[0, .5, 0, 1]), - dict( - testcase_name='integer_boundaries', - x=[15, 20, 25], - boundaries=[10, 20], - expected_results=[.5, 1, 1], - boundaries_dtype=tf.int64), - dict( - testcase_name='negative_boundaries', - x=[-10, -5, -3, 0, 2, 4, 8, 12, 18], - boundaries=[-20, -4, 1, 4, 20], - expected_results=[ - 0.15625, 0.234375, .3, .45, 0.583333, .75, 0.8125, .875, 0.96875 - ]), - dict( - testcase_name='interpolates_properly', - x=[-1111, 10, 50, 100, 1000, 9000, 10000, 1293817391], - boundaries=[10, 100, 1000, 10000], - expected_results=[ - 0, 0, (4.0 / 9 / 3), (1.0 / 3), (2.0 / 3), ((2 + 8.0 / 9) / 3), 1, - 1 - ], - boundaries_dtype=tf.int64), - ) - def testApplyBucketsWithInterpolation(self, - x, - boundaries, - expected_results, - input_dtype=tf.float32, - boundaries_dtype=tf.float32): - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - x = tf.constant(x, dtype=input_dtype) - boundaries = tf.constant([boundaries], dtype=boundaries_dtype) - output = mappers.apply_buckets_with_interpolation(x, boundaries) - self.assertAllClose(sess.run(output), expected_results, 1e-6) - - def testApplyBucketsWithInterpolationAllNanBoundariesRaises(self): - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - x = tf.constant([float('-inf'), float('nan'), 0.0, 1.0]) - boundaries = tf.constant([[float('nan'), float('nan'), float('nan')]]) - with self.assertRaisesRegex( - tf.errors.InvalidArgumentError, 'num_boundaries' - ): - sess.run(mappers.apply_buckets_with_interpolation(x, boundaries)) - - def testApplyBucketsWithInterpolationRaises(self): - # We should raise an exception if you try to scale a non-numeric tensor. 
- with self.test_session(): - x = tf.constant(['a', 'b', 'c'], dtype=tf.string) - boundaries = tf.constant([.2, .4], dtype=tf.float32) - with self.assertRaises(ValueError): - mappers.apply_buckets_with_interpolation(x, boundaries) - - def testApplyBucketsWithInterpolationSparseTensor(self): - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - x = tf.SparseTensor( - indices=[[0, 0, 0], [1, 1, 2], [3, 1, 4], [1, 1, 4], [6, 1, 1], - [3, 1, 2]], - values=[15, 10, 20, 17, -1111, 21], - dense_shape=[7, 3, 5]) + _, weights = mappers._split_tfidfs_to_outputs(tfidf) + + with self.test_session() as sess: + weights_shape = sess.run(weights.dense_shape) + self.assertAllEqual(weights_shape, [2, 0]) + + def testHashStringsNoKeyDenseInput(self): + with tf.compat.v1.Graph().as_default(): + strings = tf.constant(["Car", "Bus", "Tree"]) + expected_output = [8, 4, 5] + + hash_buckets = 11 + hashed_strings = mappers.hash_strings(strings, hash_buckets) + with self.test_session() as sess: + output = sess.run(hashed_strings) + self.assertAllEqual(expected_output, output) + + def testHashStringsNoKeySparseInput(self): + strings = tf.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["Dog", "Cat", ""], + dense_shape=[2, 2], + ) + hash_buckets = 17 + expected_indices = [[0, 0], [0, 1], [1, 0]] + expected_values = [12, 4, 11] + expected_shape = [2, 2] + hashed_strings = mappers.hash_strings(strings, hash_buckets) + self.assertSparseOutput( + expected_indices=expected_indices, + expected_values=expected_values, + expected_shape=expected_shape, + actual_sparse_tensor=hashed_strings, + close_values=False, + ) + + def testHashStringsNoKeyRaggedInput(self): + strings = tf.RaggedTensor.from_row_splits( + values=["Dog", "Cat", ""], row_splits=[0, 1, 1, 1, 1, 3] + ) + hash_buckets = 17 + expected_hashed_strings = tf.RaggedTensor.from_row_splits( + values=[12, 4, 11], row_splits=[0, 1, 1, 1, 1, 3] + ) + hashed_strings = mappers.hash_strings(strings, hash_buckets) + self.assertAllEqual(expected_hashed_strings, hashed_strings) + + def testHashStringsWithKeyDenseInput(self): + with tf.compat.v1.Graph().as_default(): + strings = tf.constant(["Cake", "Pie", "Sundae"]) + expected_output = [6, 5, 6] + hash_buckets = 11 + hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[123, 456]) + with self.test_session() as sess: + output = sess.run(hashed_strings) + self.assertAllEqual(expected_output, output) + + def testHashStringsWithKeySparseInput(self): + strings = tf.SparseTensor( + indices=[[0, 0, 0], [0, 1, 1], [1, 1, 0], [2, 1, 0]], + values=["$$$", "%^#", "&$!#@", "$$$"], + dense_shape=[3, 3, 2], + ) + hash_buckets = 173 + expected_indices = strings.indices + expected_values = [16, 156, 9, 16] + expected_shape = strings.dense_shape + hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[321, 555]) + self.assertSparseOutput( + expected_indices=expected_indices, + expected_values=expected_values, + expected_shape=expected_shape, + actual_sparse_tensor=hashed_strings, + close_values=False, + ) + + def testHashStringsWithKeyRaggedInput(self): + strings = tf.RaggedTensor.from_row_splits( + values=["$$$", "%^#", "&$!#@", "$$$"], row_splits=[0, 1, 1, 2, 2, 4] + ) + hash_buckets = 173 + expected_hashed_strings = tf.RaggedTensor.from_row_splits( + values=[16, 156, 9, 16], row_splits=[0, 1, 1, 2, 2, 4] + ) + hashed_strings = mappers.hash_strings(strings, hash_buckets, key=[321, 555]) + self.assertAllEqual(expected_hashed_strings, hashed_strings) + + @test_case.named_parameters( + 
dict( + testcase_name="few_buckets", + x=4, + bucket_boundaries=[[5]], + expected_buckets=0, + ), + dict( + testcase_name="large_buckets", + x=50_000_000, + bucket_boundaries=[[0, 50_000_001, 100_000_001]], + expected_buckets=1, + ), + dict( + testcase_name="with_nans", + x=[4.0, float("nan"), float("-inf"), 7.5, 10.0], + bucket_boundaries=[[2, 5, 8]], + expected_buckets=[1, 3, 0, 2, 3], + ), + dict( + testcase_name="with_inf_boundary", + x=[4.0, float("-inf"), 0.8, 7.5, 10.0], + bucket_boundaries=[[float("-inf"), 2, 5, 8]], + expected_buckets=[2, 1, 1, 3, 4], + ), + ) + def testApplyBuckets(self, x, bucket_boundaries, expected_buckets): + x = tf.constant(x) + bucket_boundaries = tf.constant(bucket_boundaries) + expected_buckets = tf.constant(expected_buckets, dtype=tf.int64) + buckets = mappers.apply_buckets(x, bucket_boundaries) + self.assertAllEqual(buckets, expected_buckets) + + def testApplybucketsToSparseTensor(self): + inputs = tf.SparseTensor( + indices=[[0, 0, 0], [0, 1, 1], [2, 2, 2]], + values=[10, 20, -1], + dense_shape=[3, 3, 4], + ) + quantiles = [-10, 0, 13] + bucketized = mappers.apply_buckets(inputs, [quantiles]) + self.assertSparseOutput( + inputs.indices, + tf.constant([2, 3, 1]), + inputs.dense_shape, + bucketized, + close_values=False, + ) + + def testApplybucketsToRaggedTensor(self): + inputs = tf.RaggedTensor.from_row_splits( + values=tf.RaggedTensor.from_row_splits( + values=[10, 20, -1], row_splits=[0, 1, 1, 2, 2, 3] + ), + row_splits=[0, 1, 1, 2, 3, 5], + ) + quantiles = [-10, 0, 13] + expected_bucketized = tf.RaggedTensor.from_row_splits( + values=tf.RaggedTensor.from_row_splits( + values=[2, 3, 1], row_splits=[0, 1, 1, 2, 2, 3] + ), + row_splits=[0, 1, 1, 2, 3, 5], + ) + bucketized = mappers.apply_buckets(inputs, [quantiles]) + self.assertAllEqual(expected_bucketized, bucketized) + + def testApplyBucketsWithKeys(self): + with tf.compat.v1.Graph().as_default(): + values = tf.constant( + [-100, -0.05, 0.05, 0.25, 0.15, 100, -100, 0, 4.3, 4.5, 4.4, 4.6, 100], + dtype=tf.float32, + ) + keys = tf.constant( + ["a", "a", "a", "a", "a", "a", "b", "missing", "b", "b", "b", "b", "b"] + ) + key_vocab = tf.constant(["a", "b"]) + # Pre-normalization boundaries: [[0, 0.1, 0.2], [4.33, 4.43, 4.53]] + bucket_boundaries = tf.constant([0.0, 0.5, 1.0, 1.5, 2.0], dtype=tf.float32) + scales = 1.0 / ( + tf.constant([0.2, 4.53], dtype=tf.float32) + - tf.constant([0, 4.33], dtype=tf.float32) + ) + shifts = tf.constant([0, 1.0 - (4.33 * 5)], dtype=tf.float32) + num_buckets = tf.constant(4, dtype=tf.int64) + buckets = mappers._apply_buckets_with_keys( + values, keys, key_vocab, bucket_boundaries, scales, shifts, num_buckets + ) + with self.test_session() as sess: + sess.run(tf.compat.v1.tables_initializer()) + output = sess.run(buckets) + self.assertAllEqual([0, 0, 1, 3, 2, 3, 0, -1, 0, 2, 1, 3, 3], output) + + @test_case.named_parameters( + dict( + testcase_name="single_input_value", + x=1, + boundaries=[0, 2], + expected_results=0.5, + ), + dict( + testcase_name="single_boundary", + x=[-1, 9, 10, 11], + boundaries=[10], + expected_results=[0, 0, 1, 1], + ), + dict( + testcase_name="out_of_bounds", + x=[-1111, 0, 5, 9, 10, 11, 15, 19, 20, 21, 1111], + boundaries=[10, 20], + expected_results=[0, 0, 0, 0, 0, 0.1, 0.5, 0.9, 1, 1, 1], + ), + dict( + testcase_name="2d_input", + x=[[15, 10], [20, 17], [-1111, 21]], + boundaries=[10, 20], + expected_results=[[0.5, 0], [1, 0.7], [0, 1]], + ), + dict( + testcase_name="integer_input", + x=[15, 20, 25], + boundaries=[10, 20], + 
expected_results=[0.5, 1, 1], + input_dtype=tf.int64, + ), + dict( + testcase_name="float_input", + x=[-10, 0, 0.1, 2.3, 4.5, 6.7, 8.9, 10, 100], + boundaries=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + expected_results=[0, 0, 0.01, 0.23, 0.45, 0.67, 0.89, 1, 1], + ), + dict( + testcase_name="float_input_with_nans", + x=[ + float("-inf"), + -10, + 0, + 0.1, + 2.3, + float("nan"), + 4.5, + 6.7, + 8.9, + 10, + 100, + float("inf"), + ], + boundaries=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + expected_results=[0, 0, 0, 0.01, 0.23, 0.5, 0.45, 0.67, 0.89, 1, 1, 1], + ), + dict( + testcase_name="float_input_with_inf_boundaries", + x=[ + float("-inf"), + float("-inf"), + float(0), + float("-inf"), + ], + boundaries=[float("-inf"), 0], + expected_results=[0, 0, 1, 0], + ), + dict( + testcase_name="float_input_with_nan_boundaries", + x=[ + float("-inf"), + float("nan"), + float(0), + float(1), + ], + boundaries=[float("nan"), 0, 1], + expected_results=[0, 0.5, 0, 1], + ), + dict( + testcase_name="integer_boundaries", + x=[15, 20, 25], + boundaries=[10, 20], + expected_results=[0.5, 1, 1], + boundaries_dtype=tf.int64, + ), + dict( + testcase_name="negative_boundaries", + x=[-10, -5, -3, 0, 2, 4, 8, 12, 18], + boundaries=[-20, -4, 1, 4, 20], + expected_results=[ + 0.15625, + 0.234375, + 0.3, + 0.45, + 0.583333, + 0.75, + 0.8125, + 0.875, + 0.96875, + ], + ), + dict( + testcase_name="interpolates_properly", + x=[-1111, 10, 50, 100, 1000, 9000, 10000, 1293817391], + boundaries=[10, 100, 1000, 10000], + expected_results=[ + 0, + 0, + (4.0 / 9 / 3), + (1.0 / 3), + (2.0 / 3), + ((2 + 8.0 / 9) / 3), + 1, + 1, + ], + boundaries_dtype=tf.int64, + ), + ) + def testApplyBucketsWithInterpolation( + self, + x, + boundaries, + expected_results, + input_dtype=tf.float32, + boundaries_dtype=tf.float32, + ): + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + x = tf.constant(x, dtype=input_dtype) + boundaries = tf.constant([boundaries], dtype=boundaries_dtype) + output = mappers.apply_buckets_with_interpolation(x, boundaries) + self.assertAllClose(sess.run(output), expected_results, 1e-6) + + def testApplyBucketsWithInterpolationAllNanBoundariesRaises(self): + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + x = tf.constant([float("-inf"), float("nan"), 0.0, 1.0]) + boundaries = tf.constant([[float("nan"), float("nan"), float("nan")]]) + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, "num_boundaries" + ): + sess.run(mappers.apply_buckets_with_interpolation(x, boundaries)) + + def testApplyBucketsWithInterpolationRaises(self): + # We should raise an exception if you try to scale a non-numeric tensor. 
+ with self.test_session(): + x = tf.constant(["a", "b", "c"], dtype=tf.string) + boundaries = tf.constant([0.2, 0.4], dtype=tf.float32) + with self.assertRaises(ValueError): + mappers.apply_buckets_with_interpolation(x, boundaries) + + def testApplyBucketsWithInterpolationSparseTensor(self): + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + x = tf.SparseTensor( + indices=[ + [0, 0, 0], + [1, 1, 2], + [3, 1, 4], + [1, 1, 4], + [6, 1, 1], + [3, 1, 2], + ], + values=[15, 10, 20, 17, -1111, 21], + dense_shape=[7, 3, 5], + ) + boundaries = [[10, 20]] + output = mappers.apply_buckets_with_interpolation(x, boundaries) + expected_results = tf.SparseTensor( + indices=x.indices, + values=[0.5, 0, 1, 0.7, 0, 1], + dense_shape=x.dense_shape, + ) + actual_results = sess.run(output) + self.assertAllClose( + actual_results.values, expected_results.values, 1e-6 + ) + self.assertAllEqual(actual_results.indices, expected_results.indices) + self.assertAllEqual( + actual_results.dense_shape, expected_results.dense_shape + ) + + def testApplyBucketsWithInterpolationRaggedTensor(self): + inputs = tf.RaggedTensor.from_row_splits( + values=[15, 10, 20, 17, -1111, 21], row_splits=[0, 1, 1, 2, 4, 5, 6] + ) boundaries = [[10, 20]] - output = mappers.apply_buckets_with_interpolation(x, boundaries) - expected_results = tf.SparseTensor( - indices=x.indices, - values=[.5, 0, 1, .7, 0, 1], - dense_shape=x.dense_shape) - actual_results = sess.run(output) - self.assertAllClose(actual_results.values, - expected_results.values, - 1e-6) - self.assertAllEqual(actual_results.indices, expected_results.indices) - self.assertAllEqual(actual_results.dense_shape, - expected_results.dense_shape) - - def testApplyBucketsWithInterpolationRaggedTensor(self): - inputs = tf.RaggedTensor.from_row_splits( - values=[15, 10, 20, 17, -1111, 21], row_splits=[0, 1, 1, 2, 4, 5, 6]) - boundaries = [[10, 20]] - expected_bucketized = tf.RaggedTensor.from_row_splits( - values=[.5, 0, 1, .7, 0, 1], row_splits=[0, 1, 1, 2, 4, 5, 6]) - bucketized = mappers.apply_buckets_with_interpolation(inputs, boundaries) - self.assertAllEqual(expected_bucketized, bucketized) - - def testBucketsWithInterpolationUnknownShapeBoundary(self): - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - x = tf.constant([0, 1, 5, 12], dtype=tf.float32) - # The shape used to generate the boundaries is random, and therefore - # the size of the boundaries tensor is not known. - num_boundaries = tf.random.uniform([1], 1, 2, dtype=tf.int64)[0] - boundaries = tf.random.uniform([1, num_boundaries], 0, 10) - # We don't assert anything about the outcome because we're intentionally - # using randomized boundaries, but we ensure the operations succeed. 
- _ = sess.run(mappers.apply_buckets_with_interpolation(x, boundaries)) - - def testSparseTensorToDenseWithShape(self): - with tf.compat.v1.Graph().as_default(): - sparse = tf.compat.v1.sparse_placeholder( - tf.int64, shape=[None, None, None]) - dense = mappers.sparse_tensor_to_dense_with_shape(sparse, [None, 5, 6]) - self.assertAllEqual(dense.get_shape().as_list(), [None, 5, 6]) - - def testSparseTensorLeftAlign(self): - with tf.compat.v1.Graph().as_default(): - with self.test_session() as sess: - x = tf.SparseTensor( - indices=[[0, 3], [1, 2], [1, 4], [3, 2], [3, 4], [5, 0], [6, 1]], - values=[15, 10, 20, 17, -1111, 13, 21], - dense_shape=[7, 5]) - y = mappers.sparse_tensor_left_align(x) - expected_indices = [[0, 0], [1, 0], [1, 1], [3, 0], [3, 1], [5, 0], - [6, 0]] - self.assertAllEqual(sess.run(y.indices), expected_indices) - - def testEstimatedProbabilityDensityMissingKey(self): - input_size = 5 - - with tf.compat.v1.Graph().as_default(): - input_data = tf.constant([[str(x + 1)] for x in range(input_size)]) - - count = tf.constant([3] * input_size, tf.int64) - boundaries = tf.as_string(tf.range(input_size)) - with mock.patch.object( - mappers.analyzers, 'histogram', side_effect=[(count, boundaries)]): - - result = mappers.estimated_probability_density( - input_data, categorical=True) - - expected = np.array([[0.2], [0.2], [0.2], [0.2], [0.]], np.float32) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - self.assertAllEqual(expected, sess.run(result)) - - -if __name__ == '__main__': - test_case.main() + expected_bucketized = tf.RaggedTensor.from_row_splits( + values=[0.5, 0, 1, 0.7, 0, 1], row_splits=[0, 1, 1, 2, 4, 5, 6] + ) + bucketized = mappers.apply_buckets_with_interpolation(inputs, boundaries) + self.assertAllEqual(expected_bucketized, bucketized) + + def testBucketsWithInterpolationUnknownShapeBoundary(self): + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + x = tf.constant([0, 1, 5, 12], dtype=tf.float32) + # The shape used to generate the boundaries is random, and therefore + # the size of the boundaries tensor is not known. + num_boundaries = tf.random.uniform([1], 1, 2, dtype=tf.int64)[0] + boundaries = tf.random.uniform([1, num_boundaries], 0, 10) + # We don't assert anything about the outcome because we're intentionally + # using randomized boundaries, but we ensure the operations succeed. 
+ _ = sess.run(mappers.apply_buckets_with_interpolation(x, boundaries)) + + def testSparseTensorToDenseWithShape(self): + with tf.compat.v1.Graph().as_default(): + sparse = tf.compat.v1.sparse_placeholder(tf.int64, shape=[None, None, None]) + dense = mappers.sparse_tensor_to_dense_with_shape(sparse, [None, 5, 6]) + self.assertAllEqual(dense.get_shape().as_list(), [None, 5, 6]) + + def testSparseTensorLeftAlign(self): + with tf.compat.v1.Graph().as_default(): + with self.test_session() as sess: + x = tf.SparseTensor( + indices=[[0, 3], [1, 2], [1, 4], [3, 2], [3, 4], [5, 0], [6, 1]], + values=[15, 10, 20, 17, -1111, 13, 21], + dense_shape=[7, 5], + ) + y = mappers.sparse_tensor_left_align(x) + expected_indices = [ + [0, 0], + [1, 0], + [1, 1], + [3, 0], + [3, 1], + [5, 0], + [6, 0], + ] + self.assertAllEqual(sess.run(y.indices), expected_indices) + + def testEstimatedProbabilityDensityMissingKey(self): + input_size = 5 + + with tf.compat.v1.Graph().as_default(): + input_data = tf.constant([[str(x + 1)] for x in range(input_size)]) + + count = tf.constant([3] * input_size, tf.int64) + boundaries = tf.as_string(tf.range(input_size)) + with mock.patch.object( + mappers.analyzers, "histogram", side_effect=[(count, boundaries)] + ): + result = mappers.estimated_probability_density( + input_data, categorical=True + ) + + expected = np.array([[0.2], [0.2], [0.2], [0.2], [0.0]], np.float32) + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.tables_initializer()) + self.assertAllEqual(expected, sess.run(result)) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/nodes.py b/tensorflow_transform/nodes.py index 57ea1f8..5fbde3f 100644 --- a/tensorflow_transform/nodes.py +++ b/tensorflow_transform/nodes.py @@ -33,367 +33,385 @@ class OperationDef(metaclass=abc.ABCMeta): - """The definition of an operation. + """The definition of an operation. - This class contains all the information needed to run an operation, except - the number of inputs and their values. A subclass should document + This class contains all the information needed to run an operation, except + the number of inputs and their values. A subclass should document - - How many inputs it expects, and what they should contain. - - What it outputs, as a function of its inputs. + - How many inputs it expects, and what they should contain. + - What it outputs, as a function of its inputs. - An OperationDef is just a specification and does not contain the actual - computation. - """ + An OperationDef is just a specification and does not contain the actual + computation. + """ - @property - def num_outputs(self) -> int: - """The number of outputs returned by this operation.""" - return 1 + @property + def num_outputs(self) -> int: + """The number of outputs returned by this operation.""" + return 1 - @abc.abstractproperty - def label(self) -> str: - """A unique label for this operation in the graph.""" - pass + @abc.abstractproperty + def label(self) -> str: + """A unique label for this operation in the graph.""" + pass - def get_field_str(self, field_name: str) -> str: - """Returns a str representation of the requested field.""" - return getattr(self, field_name) + def get_field_str(self, field_name: str) -> str: + """Returns a str representation of the requested field.""" + return getattr(self, field_name) - @property - def is_partitionable(self) -> bool: - """If True, means that this operation can be applied on partitioned data. 
+ @property + def is_partitionable(self) -> bool: + """If True, means that this operation can be applied on partitioned data. - Being able to be applied on partitioned data means that partitioning the - data, running this operation on each of the data subsets independently, and - then having the next operation get the flattened results as inputs would be - equivalent to running this operation on the entire data and passing the - result to the next operation. + Being able to be applied on partitioned data means that partitioning the + data, running this operation on each of the data subsets independently, and + then having the next operation get the flattened results as inputs would be + equivalent to running this operation on the entire data and passing the + result to the next operation. - Returns: - A bool indicating whether or not this operation is partitionable. - """ - return False + Returns + ------- + A bool indicating whether or not this operation is partitionable. + """ + return False - @property - def cache_coder(self) -> Optional[object]: - """A CacheCoder object used to cache outputs returned by this operation. + @property + def cache_coder(self) -> Optional[object]: + """A CacheCoder object used to cache outputs returned by this operation. - If this doesn't return None, then: - * num_outputs has to be 1 - * is_partitionable has to be True. - """ - return None + If this doesn't return None, then: + * num_outputs has to be 1 + * is_partitionable has to be True. + """ + return None @dataclasses.dataclass(frozen=True) class ValueNode: - """A placeholder that will ultimately be translated to a PCollection. + """A placeholder that will ultimately be translated to a PCollection. - Attributes: - parent_operation: The `OperationNode` that produces this value. - value_index: The index of this value in the outputs of `parent_operation`. - """ + Attributes + ---------- + parent_operation: The `OperationNode` that produces this value. + value_index: The index of this value in the outputs of `parent_operation`. + """ - parent_operation: 'OperationNode' - value_index: int + parent_operation: "OperationNode" + value_index: int - def __post_init__(self): - num_outputs = self.parent_operation.operation_def.num_outputs - if not (0 <= self.value_index and self.value_index < num_outputs): - raise ValueError( - 'value_index was {} but parent_operation had {} outputs'.format( - self.value_index, num_outputs - ) - ) + def __post_init__(self): + num_outputs = self.parent_operation.operation_def.num_outputs + if not (self.value_index >= 0 and self.value_index < num_outputs): + raise ValueError( + f"value_index was {self.value_index} but parent_operation had {num_outputs} outputs" + ) class OperationNode: - """A placeholder that will ultimately be translated to a PTransform. - - Attributes: - operation_def: An `OperationDef`. - inputs: A tuple of `ValueNode`s. 
- """ - - def __init__(self, operation_def, inputs): - self._operation_def = operation_def - self._inputs = inputs - if not isinstance(operation_def, OperationDef): - raise TypeError( - 'operation_def must be an OperationDef, got {} of type {}'.format( - operation_def, type(operation_def))) - if not isinstance(inputs, tuple): - raise TypeError( - 'inputs must be a tuple, got {} of type {}'.format( - inputs, type(inputs))) - for value_node in inputs: - if not isinstance(value_node, ValueNode): - raise TypeError( - 'Inputs to Operation must be a ValueNode, got {} of type {}'.format( - value_node, type(value_node))) - - def __repr__(self): - return '{}(operation_def={}, inputs={})'.format( - self.__class__.__name__, self.operation_def, self.inputs) - - @property - def operation_def(self): - return self._operation_def - - @property - def inputs(self): - return self._inputs - - @property - def outputs(self): - """A tuple of `ValueNode`s representing outputs of this operation.""" - return tuple(ValueNode(self, value_index) - for value_index in range(self.operation_def.num_outputs)) + """A placeholder that will ultimately be translated to a PTransform. + + Attributes + ---------- + operation_def: An `OperationDef`. + inputs: A tuple of `ValueNode`s. + """ + + def __init__(self, operation_def, inputs): + self._operation_def = operation_def + self._inputs = inputs + if not isinstance(operation_def, OperationDef): + raise TypeError( + f"operation_def must be an OperationDef, got {operation_def} of type {type(operation_def)}" + ) + if not isinstance(inputs, tuple): + raise TypeError( + f"inputs must be a tuple, got {inputs} of type {type(inputs)}" + ) + for value_node in inputs: + if not isinstance(value_node, ValueNode): + raise TypeError( + f"Inputs to Operation must be a ValueNode, got {value_node} of type {type(value_node)}" + ) + + def __repr__(self): + return f"{self.__class__.__name__}(operation_def={self.operation_def}, inputs={self.inputs})" + + @property + def operation_def(self): + return self._operation_def + + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + """A tuple of `ValueNode`s representing outputs of this operation.""" + return tuple( + ValueNode(self, value_index) + for value_index in range(self.operation_def.num_outputs) + ) def apply_operation(operation_def_cls, *args, **kwargs): - """Applies an operation to some inputs and returns its output. + """Applies an operation to some inputs and returns its output. - This function is syntactic sugar on top of the constructor for OperationNode. - The operation must return a single output. + This function is syntactic sugar on top of the constructor for OperationNode. + The operation must return a single output. - Args: - operation_def_cls: A class that is a subclass of `OperationDef`. - *args: The inputs to the `OperationNode`. - **kwargs: Constructor args for `operation_def_cls`. + Args: + ---- + operation_def_cls: A class that is a subclass of `OperationDef`. + *args: The inputs to the `OperationNode`. + **kwargs: Constructor args for `operation_def_cls`. - Returns: - The output of the `OperationNode` that was constructed. - """ - (result,) = apply_multi_output_operation(operation_def_cls, *args, **kwargs) - return result + Returns: + ------- + The output of the `OperationNode` that was constructed. 
+ """ + (result,) = apply_multi_output_operation(operation_def_cls, *args, **kwargs) + return result def apply_multi_output_operation(operation_def_cls, *args, **kwargs): - """Like `apply_operation` but returns a tuple of outputs.""" - try: - return OperationNode(operation_def_cls(**kwargs), args).outputs - except TypeError as e: - raise RuntimeError('Failed to apply Operation {}, with error: {}'.format( - operation_def_cls, str(e))) + """Like `apply_operation` but returns a tuple of outputs.""" + try: + return OperationNode(operation_def_cls(**kwargs), args).outputs + except TypeError as e: + raise RuntimeError( + f"Failed to apply Operation {operation_def_cls}, with error: {str(e)}" + ) class Visitor(metaclass=abc.ABCMeta): - """Class to visit nodes in the graph.""" + """Class to visit nodes in the graph.""" - @abc.abstractmethod - def validate_value(self, value): - """Validate the value of a ValueNode. + @abc.abstractmethod + def validate_value(self, value): + """Validate the value of a ValueNode. - Should raise an error if `value` is invalid. + Should raise an error if `value` is invalid. - Args: - value: An element of the tuple returned by visit. - """ - pass + Args: + ---- + value: An element of the tuple returned by visit. + """ + pass - @abc.abstractmethod - def visit(self, operation_def, input_values): - """Visits an `OperationNode` in the graph. + @abc.abstractmethod + def visit(self, operation_def, input_values): + """Visits an `OperationNode` in the graph. - Called once for each `OperationNode` in the graph that is visited. Will - be called with the `operation_def` of that `OperationNode`, and values - determined by cached recursive calls to the `OperationNode`s that produce - each input `ValueNode` of the current `OperationNode`. + Called once for each `OperationNode` in the graph that is visited. Will + be called with the `operation_def` of that `OperationNode`, and values + determined by cached recursive calls to the `OperationNode`s that produce + each input `ValueNode` of the current `OperationNode`. - Args: - operation_def: The `OperationDef` of the current `OperationNode`. - input_values: Values corresponding to each input of the current - `OperationNode`. + Args: + ---- + operation_def: The `OperationDef` of the current `OperationNode`. + input_values: Values corresponding to each input of the current + `OperationNode`. - Returns: - A tuple of values corresponding to the outputs of the current - `OperationNode`. - """ - pass + Returns: + ------- + A tuple of values corresponding to the outputs of the current + `OperationNode`. + """ + pass class Traverser: - """Class to traverse the DAG of nodes.""" - - def __init__(self, visitor: Visitor): - """Init method for Traverser. - - Args: - visitor: A `Visitor` object. - """ - self._cached_value_nodes_values: Dict[ValueNode, Any] = {} - self._stack: List[OperationNode] = [] - self._visitor = visitor - - def visit_value_node(self, value_node: ValueNode): - """Visit a value node, and return a corresponding value. - - Args: - value_node: A `ValueNode`. - - Returns: - A value corresponding to `value_node` determined by the implementation of - the abstract `visit` method. - """ - return self._maybe_visit_value_node(value_node) - - def _maybe_visit_value_node(self, value_node: ValueNode): - """Visit a value node if not cached, and return a corresponding value. - - Args: - value_node: A `ValueNode`. - - Returns: - A value corresponding to `value_node` determined by the implementation of - the abstract `visit` method. 
- """ - if value_node not in self._cached_value_nodes_values: - self._visit_operation(value_node.parent_operation) - return self._cached_value_nodes_values[value_node] - - def _visit_operation(self, operation: OperationNode): - """Visit an `OperationNode`.""" - if operation in self._stack: - cycle = self._stack[self._stack.index(operation):] + [operation] - # For readability, just print the label of `operation_def`s - cycle = ', '.join(operation.operation_def.label for operation in cycle) - raise AssertionError('Cycle detected: [{}]'.format(cycle)) - self._stack.append(operation) - input_values = tuple(map(self._maybe_visit_value_node, operation.inputs)) - assert operation is self._stack.pop() - output_values = self._visitor.visit(operation.operation_def, input_values) - outputs = operation.outputs - - # Expect a tuple of outputs. Since ValueNode and OperationDef are both - # subclasses of tuple, we also explicitly disallow them, since returning - # a single ValueNode or OperationDef is almost certainly an error. - try: - _ = iter(output_values) - output_iterable = not isinstance(output_values, str) - except TypeError: - output_iterable = False - if (not output_iterable or isinstance(output_values, - (ValueNode, OperationDef))): - raise ValueError( - 'When running operation {} expected visitor to return a tuple, got ' - '{} of type {}'.format(operation.operation_def.label, output_values, - type(output_values))) - # DoOutputsTuple doesn't work with len(). - if hasattr(output_values, '__len__') and len(output_values) != len(outputs): - raise ValueError( - 'Operation {} has {} outputs but visitor returned {} values: ' - '{}'.format(operation.operation_def, len(outputs), - len(output_values), output_values)) - - for output, value in zip(outputs, output_values): - self._visitor.validate_value(value) - self._cached_value_nodes_values[output] = value + """Class to traverse the DAG of nodes.""" + + def __init__(self, visitor: Visitor): + """Init method for Traverser. + + Args: + ---- + visitor: A `Visitor` object. + """ + self._cached_value_nodes_values: Dict[ValueNode, Any] = {} + self._stack: List[OperationNode] = [] + self._visitor = visitor + + def visit_value_node(self, value_node: ValueNode): + """Visit a value node, and return a corresponding value. + + Args: + ---- + value_node: A `ValueNode`. + + Returns: + ------- + A value corresponding to `value_node` determined by the implementation of + the abstract `visit` method. + """ + return self._maybe_visit_value_node(value_node) + + def _maybe_visit_value_node(self, value_node: ValueNode): + """Visit a value node if not cached, and return a corresponding value. + + Args: + ---- + value_node: A `ValueNode`. + + Returns: + ------- + A value corresponding to `value_node` determined by the implementation of + the abstract `visit` method. 
+ """ + if value_node not in self._cached_value_nodes_values: + self._visit_operation(value_node.parent_operation) + return self._cached_value_nodes_values[value_node] + + def _visit_operation(self, operation: OperationNode): + """Visit an `OperationNode`.""" + if operation in self._stack: + cycle = self._stack[self._stack.index(operation) :] + [operation] + # For readability, just print the label of `operation_def`s + cycle = ", ".join(operation.operation_def.label for operation in cycle) + raise AssertionError(f"Cycle detected: [{cycle}]") + self._stack.append(operation) + input_values = tuple(map(self._maybe_visit_value_node, operation.inputs)) + assert operation is self._stack.pop() + output_values = self._visitor.visit(operation.operation_def, input_values) + outputs = operation.outputs + + # Expect a tuple of outputs. Since ValueNode and OperationDef are both + # subclasses of tuple, we also explicitly disallow them, since returning + # a single ValueNode or OperationDef is almost certainly an error. + try: + _ = iter(output_values) + output_iterable = not isinstance(output_values, str) + except TypeError: + output_iterable = False + if not output_iterable or isinstance(output_values, (ValueNode, OperationDef)): + raise ValueError( + f"When running operation {operation.operation_def.label} expected visitor to return a tuple, got " + f"{output_values} of type {type(output_values)}" + ) + # DoOutputsTuple doesn't work with len(). + if hasattr(output_values, "__len__") and len(output_values) != len(outputs): + raise ValueError( + f"Operation {operation.operation_def} has {len(outputs)} outputs but visitor returned {len(output_values)} values: " + f"{output_values}" + ) + + for output, value in zip(outputs, output_values): + self._visitor.validate_value(value) + self._cached_value_nodes_values[output] = value def _escape(line: str) -> str: - for char in '<>{}': - line = line.replace(char, '\\%s' % char) - return line + for char in "<>{}": + line = line.replace(char, "\\%s" % char) + return line class _PrintGraphVisitor(Visitor): - """Visitor to produce a human readable string for a graph.""" + """Visitor to produce a human readable string for a graph.""" - def __init__(self): - self._print_result = '' - self._dot_graph = pydot.Dot(directed=True) - self._dot_graph.obj_dict = collections.OrderedDict( - sorted(self._dot_graph.obj_dict.items(), key=lambda t: t[0])) - self._dot_graph.set_node_defaults(shape='Mrecord') - super().__init__() + def __init__(self): + self._print_result = "" + self._dot_graph = pydot.Dot(directed=True) + self._dot_graph.obj_dict = collections.OrderedDict( + sorted(self._dot_graph.obj_dict.items(), key=lambda t: t[0]) + ) + self._dot_graph.set_node_defaults(shape="Mrecord") + super().__init__() - def get_dot_graph(self) -> pydot.Dot: - return self._dot_graph + def get_dot_graph(self) -> pydot.Dot: + return self._dot_graph - def visit(self, operation_def, input_nodes) -> Tuple[pydot.Node, ...]: - num_outputs = operation_def.num_outputs - node_name = operation_def.label + def visit(self, operation_def, input_nodes) -> Tuple[pydot.Node, ...]: + num_outputs = operation_def.num_outputs + node_name = operation_def.label - display_label_rows = ([operation_def.__class__.__name__] + [ - _escape('%s: %s' % (field, operation_def.get_field_str(field))) - for field in operation_def._fields - ]) + display_label_rows = [operation_def.__class__.__name__] + [ + _escape("%s: %s" % (field, operation_def.get_field_str(field))) + for field in operation_def._fields + ] - if 
operation_def.is_partitionable: - display_label_rows.append('partitionable: %s' % True) + if operation_def.is_partitionable: + display_label_rows.append("partitionable: %s" % True) - if num_outputs != 1: - ports = '|'.join('<{0}>{0}'.format(idx) for idx in range(num_outputs)) - display_label_rows.append('{%s}' % ports) - display_label = '{%s}' % '|'.join(display_label_rows) + if num_outputs != 1: + ports = "|".join(f"<{idx}>{idx}" for idx in range(num_outputs)) + display_label_rows.append("{%s}" % ports) + display_label = "{%s}" % "|".join(display_label_rows) - node = pydot.Node(node_name, label=display_label) + node = pydot.Node(node_name, label=display_label) - self._dot_graph.add_node(node) + self._dot_graph.add_node(node) - for input_node in input_nodes: - self._dot_graph.add_edge(pydot.Edge(input_node, node)) + for input_node in input_nodes: + self._dot_graph.add_edge(pydot.Edge(input_node, node)) - if num_outputs == 1: - return (node,) - else: - return tuple( - pydot.Node(obj_dict={'name': '"{}":{}'.format(node_name, idx)}) - for idx in range(num_outputs)) + if num_outputs == 1: + return (node,) + else: + return tuple( + pydot.Node(obj_dict={"name": f'"{node_name}":{idx}'}) + for idx in range(num_outputs) + ) - def validate_value(self, value: pydot.Node): - assert isinstance(value, pydot.Node) + def validate_value(self, value: pydot.Node): + assert isinstance(value, pydot.Node) def get_dot_graph(leaf_nodes: Collection[ValueNode]) -> pydot.Dot: - """Utility to print a graph in a human readable manner. + """Utility to print a graph in a human readable manner. - The format resembles a sequence of calls to apply_operation or - apply_multi_output_operation. + The format resembles a sequence of calls to apply_operation or + apply_multi_output_operation. - Args: - leaf_nodes: A list of leaf `ValueNode`s to define the graph. The graph will - be the transitive parents of the leaf nodes. + Args: + ---- + leaf_nodes: A list of leaf `ValueNode`s to define the graph. The graph will + be the transitive parents of the leaf nodes. - Returns: - A human readable summary of the graph. - """ - visitor = _PrintGraphVisitor() - traverser = Traverser(visitor) - for value_node in leaf_nodes: - traverser.visit_value_node(value_node) - return visitor.get_dot_graph() + Returns: + ------- + A human readable summary of the graph. + """ + visitor = _PrintGraphVisitor() + traverser = Traverser(visitor) + for value_node in leaf_nodes: + traverser.visit_value_node(value_node) + return visitor.get_dot_graph() class _CountGraphNodes(Visitor): - """Visitor which counts the graph nodes.""" + """Visitor which counts the graph nodes.""" - num_nodes = 0 + num_nodes = 0 - def visit(self, operation_def: OperationDef, _) -> Tuple[int]: - self.num_nodes += 1 - return tuple(1 for _ in range(operation_def.num_outputs)) + def visit(self, operation_def: OperationDef, _) -> Tuple[int]: + self.num_nodes += 1 + return tuple(1 for _ in range(operation_def.num_outputs)) - def validate_value(self, value: int): - pass + def validate_value(self, value: int): + pass def count_graph_nodes(leaf_nodes: Collection[ValueNode]) -> int: - """Counts the number of graph nodes. - - Note: these nodes only include the TFT graph nodes, it doesn't count beam - nodes constructed directly. - - Args: - leaf_nodes: A list of leaf `ValueNode`s to define the graph. The graph will - be the transitive parents of the leaf nodes. - - Returns: - The count of TFT graph nodes. 
- """ - visitor = _CountGraphNodes() - traverser = Traverser(visitor) - for value_node in leaf_nodes: - traverser.visit_value_node(value_node) - return visitor.num_nodes + """Counts the number of graph nodes. + + Note: these nodes only include the TFT graph nodes, it doesn't count beam + nodes constructed directly. + + Args: + ---- + leaf_nodes: A list of leaf `ValueNode`s to define the graph. The graph will + be the transitive parents of the leaf nodes. + + Returns: + ------- + The count of TFT graph nodes. + """ + visitor = _CountGraphNodes() + traverser = Traverser(visitor) + for value_node in leaf_nodes: + traverser.visit_value_node(value_node) + return visitor.num_nodes diff --git a/tensorflow_transform/nodes_test.py b/tensorflow_transform/nodes_test.py index 9699a8f..b22cab7 100644 --- a/tensorflow_transform/nodes_test.py +++ b/tensorflow_transform/nodes_test.py @@ -14,195 +14,215 @@ """Tests for tensorflow_transform.nodes.""" import tensorflow as tf -from tensorflow_transform import nodes -from tensorflow_transform import test_case + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple +from tensorflow_transform import nodes, test_case + mock = tf.compat.v1.test.mock -class _Concat( - tfx_namedtuple.namedtuple('_Concat', ['label']), nodes.OperationDef): - __slots__ = () +class _Concat(tfx_namedtuple.namedtuple("_Concat", ["label"]), nodes.OperationDef): + __slots__ = () -class _Swap(tfx_namedtuple.namedtuple('_Swap', ['label']), nodes.OperationDef): - __slots__ = () +class _Swap(tfx_namedtuple.namedtuple("_Swap", ["label"]), nodes.OperationDef): + __slots__ = () - @property - def num_outputs(self): - return 2 + @property + def num_outputs(self): + return 2 class _Constant( - tfx_namedtuple.namedtuple('_Constant', ['value', 'label']), - nodes.OperationDef): - __slots__ = () + tfx_namedtuple.namedtuple("_Constant", ["value", "label"]), nodes.OperationDef +): + __slots__ = () -class _Identity( - tfx_namedtuple.namedtuple('_Identity', ['label']), nodes.OperationDef): - __slots__ = () +class _Identity(tfx_namedtuple.namedtuple("_Identity", ["label"]), nodes.OperationDef): + __slots__ = () class NodesTest(test_case.TransformTestCase): - - def testApplyOperationWithKwarg(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - op = a.parent_operation - self.assertEqual(a.value_index, 0) - self.assertEqual(op.operation_def, _Constant('a', 'Constant[a]')) - self.assertEqual(op.inputs, ()) - self.assertEqual(op.outputs, (a,)) - - def testApplyOperationWithTupleOutput(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - b = nodes.apply_operation(_Constant, value='b', label='Constant[b]') - b_copy, a_copy = nodes.apply_multi_output_operation( - _Swap, a, b, label='Swap') - op = b_copy.parent_operation - self.assertEqual(b_copy.value_index, 0) - self.assertEqual(a_copy.parent_operation, op) - self.assertEqual(a_copy.value_index, 1) - self.assertEqual(op.operation_def, _Swap('Swap')) - self.assertEqual(op.inputs, (a, b)) - self.assertEqual(op.outputs, (b_copy, a_copy)) - - def testValueNodeWithNegativeValueIndex(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - with self.assertRaisesWithLiteralMatch( - ValueError, 'value_index was -1 but parent_operation had 1 outputs'): - nodes.ValueNode(a.parent_operation, -1) - - def testValueNodeWithTooHighValueIndex(self): - a = nodes.apply_operation(_Constant, value='a', 
label='Constant[a]') - with self.assertRaisesWithLiteralMatch( - ValueError, 'value_index was 2 but parent_operation had 1 outputs'): - nodes.ValueNode(a.parent_operation, 2) - - def testTraverserSimpleGraph(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - mock_visitor = mock.MagicMock() - mock_visitor.visit.side_effect = [('a',)] - nodes.Traverser(mock_visitor).visit_value_node(a) - mock_visitor.assert_has_calls([ - mock.call.visit(_Constant('a', 'Constant[a]'), ()), - mock.call.validate_value('a'), - ]) - - def testTraverserComplexGraph(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - b = nodes.apply_operation(_Constant, value='b', label='Constant[b]') - c = nodes.apply_operation(_Constant, value='c', label='Constant[c]') - b_copy, a_copy = nodes.apply_multi_output_operation( - _Swap, a, b, label='Swap') - b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label='Concat[0]') - b_a_c = nodes.apply_operation(_Concat, b_a, c, label='Concat[1]') - - mock_visitor = mock.MagicMock() - mock_visitor.visit.side_effect = [ - ('a',), ('b',), ('b', 'a'), ('ba',), ('c',), ('bac',)] - - nodes.Traverser(mock_visitor).visit_value_node(b_a_c) - - mock_visitor.assert_has_calls([ - mock.call.visit(_Constant('a', 'Constant[a]'), ()), - mock.call.validate_value('a'), - mock.call.visit(_Constant('b', 'Constant[b]'), ()), - mock.call.validate_value('b'), - mock.call.visit(_Swap('Swap'), ('a', 'b')), - mock.call.validate_value('b'), - mock.call.validate_value('a'), - mock.call.visit(_Concat('Concat[0]'), ('b', 'a')), - mock.call.validate_value('ba'), - mock.call.visit(_Constant('c', 'Constant[c]'), ()), - mock.call.validate_value('c'), - mock.call.visit(_Concat('Concat[1]'), ('ba', 'c')), - mock.call.validate_value('bac'), - ]) - - def testTraverserComplexGraphMultipleCalls(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - b = nodes.apply_operation(_Constant, value='b', label='Constant[b]') - c = nodes.apply_operation(_Constant, value='c', label='Constant[c]') - b_copy, a_copy = nodes.apply_multi_output_operation( - _Swap, a, b, label='Swap') - b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label='Concat[0]') - b_a_c = nodes.apply_operation(_Concat, b_a, c, label='Concat[1]') - - mock_visitor = mock.MagicMock() - mock_visitor.visit.side_effect = [ - ('a',), ('b',), ('b', 'a'), ('ba',), ('c',), ('bac',)] - - traverser = nodes.Traverser(mock_visitor) - traverser.visit_value_node(b_a) - traverser.visit_value_node(b_a_c) - - mock_visitor.assert_has_calls([ - mock.call.visit(_Constant('a', 'Constant[a]'), ()), - mock.call.validate_value('a'), - mock.call.visit(_Constant('b', 'Constant[b]'), ()), - mock.call.validate_value('b'), - mock.call.visit(_Swap('Swap'), ('a', 'b')), - mock.call.validate_value('b'), - mock.call.validate_value('a'), - mock.call.visit(_Concat('Concat[0]'), ('b', 'a')), - mock.call.validate_value('ba'), - mock.call.visit(_Constant('c', 'Constant[c]'), ()), - mock.call.validate_value('c'), - mock.call.visit(_Concat('Concat[1]'), ('ba', 'c')), - mock.call.validate_value('bac'), - ]) - - def testTraverserOutputsNotATuple(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - - mock_visitor = mock.MagicMock() - mock_visitor.visit.side_effect = [42] - - with self.assertRaisesRegex( - ValueError, r'expected visitor to return a tuple, got'): - nodes.Traverser(mock_visitor).visit_value_node(a) - - def testTraverserBadNumOutputs(self): - a = nodes.apply_operation(_Constant, 
value='a', label='Constant[a]') - mock_visitor = mock.MagicMock() - mock_visitor.visit.side_effect = [('a', 'b')] - - with self.assertRaisesRegex( - ValueError, 'has 1 outputs but visitor returned 2 values: '): - nodes.Traverser(mock_visitor).visit_value_node(a) - - def testTraverserCycle(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - x_0 = nodes.apply_operation(_Identity, a, label='Identity[0]') - x_1 = nodes.apply_operation(_Identity, x_0, label='Identity[1]') - x_2 = nodes.apply_operation(_Identity, x_1, label='Identity[2]') - x_0.parent_operation._inputs = (x_2,) - - mock_visitor = mock.MagicMock() - mock_visitor.visit.return_value = ('x',) - - with self.assertRaisesWithLiteralMatch( - AssertionError, - 'Cycle detected: [Identity[2], Identity[1], Identity[0], Identity[2]]'): - nodes.Traverser(mock_visitor).visit_value_node(x_2) - - def testGetDotGraph(self): - a = nodes.apply_operation(_Constant, value='a', label='Constant[a]') - b = nodes.apply_operation(_Constant, value='b', label='Constant[b]') - b_copy, a_copy = nodes.apply_multi_output_operation( - _Swap, a, b, label='Swap[0]') - b_copy2, unused_a_copy2 = nodes.apply_multi_output_operation( - _Swap, a_copy, b_copy, label='Swap[1]') - dot_string = nodes.get_dot_graph([b_copy2]).to_string() - self.WriteRenderedDotFile(dot_string) - - self.assertMultiLineEqual( - dot_string, - """\ + def testApplyOperationWithKwarg(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + op = a.parent_operation + self.assertEqual(a.value_index, 0) + self.assertEqual(op.operation_def, _Constant("a", "Constant[a]")) + self.assertEqual(op.inputs, ()) + self.assertEqual(op.outputs, (a,)) + + def testApplyOperationWithTupleOutput(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + b = nodes.apply_operation(_Constant, value="b", label="Constant[b]") + b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b, label="Swap") + op = b_copy.parent_operation + self.assertEqual(b_copy.value_index, 0) + self.assertEqual(a_copy.parent_operation, op) + self.assertEqual(a_copy.value_index, 1) + self.assertEqual(op.operation_def, _Swap("Swap")) + self.assertEqual(op.inputs, (a, b)) + self.assertEqual(op.outputs, (b_copy, a_copy)) + + def testValueNodeWithNegativeValueIndex(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + with self.assertRaisesWithLiteralMatch( + ValueError, "value_index was -1 but parent_operation had 1 outputs" + ): + nodes.ValueNode(a.parent_operation, -1) + + def testValueNodeWithTooHighValueIndex(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + with self.assertRaisesWithLiteralMatch( + ValueError, "value_index was 2 but parent_operation had 1 outputs" + ): + nodes.ValueNode(a.parent_operation, 2) + + def testTraverserSimpleGraph(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + mock_visitor = mock.MagicMock() + mock_visitor.visit.side_effect = [("a",)] + nodes.Traverser(mock_visitor).visit_value_node(a) + mock_visitor.assert_has_calls( + [ + mock.call.visit(_Constant("a", "Constant[a]"), ()), + mock.call.validate_value("a"), + ] + ) + + def testTraverserComplexGraph(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + b = nodes.apply_operation(_Constant, value="b", label="Constant[b]") + c = nodes.apply_operation(_Constant, value="c", label="Constant[c]") + b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b, label="Swap") + b_a 
= nodes.apply_operation(_Concat, b_copy, a_copy, label="Concat[0]") + b_a_c = nodes.apply_operation(_Concat, b_a, c, label="Concat[1]") + + mock_visitor = mock.MagicMock() + mock_visitor.visit.side_effect = [ + ("a",), + ("b",), + ("b", "a"), + ("ba",), + ("c",), + ("bac",), + ] + + nodes.Traverser(mock_visitor).visit_value_node(b_a_c) + + mock_visitor.assert_has_calls( + [ + mock.call.visit(_Constant("a", "Constant[a]"), ()), + mock.call.validate_value("a"), + mock.call.visit(_Constant("b", "Constant[b]"), ()), + mock.call.validate_value("b"), + mock.call.visit(_Swap("Swap"), ("a", "b")), + mock.call.validate_value("b"), + mock.call.validate_value("a"), + mock.call.visit(_Concat("Concat[0]"), ("b", "a")), + mock.call.validate_value("ba"), + mock.call.visit(_Constant("c", "Constant[c]"), ()), + mock.call.validate_value("c"), + mock.call.visit(_Concat("Concat[1]"), ("ba", "c")), + mock.call.validate_value("bac"), + ] + ) + + def testTraverserComplexGraphMultipleCalls(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + b = nodes.apply_operation(_Constant, value="b", label="Constant[b]") + c = nodes.apply_operation(_Constant, value="c", label="Constant[c]") + b_copy, a_copy = nodes.apply_multi_output_operation(_Swap, a, b, label="Swap") + b_a = nodes.apply_operation(_Concat, b_copy, a_copy, label="Concat[0]") + b_a_c = nodes.apply_operation(_Concat, b_a, c, label="Concat[1]") + + mock_visitor = mock.MagicMock() + mock_visitor.visit.side_effect = [ + ("a",), + ("b",), + ("b", "a"), + ("ba",), + ("c",), + ("bac",), + ] + + traverser = nodes.Traverser(mock_visitor) + traverser.visit_value_node(b_a) + traverser.visit_value_node(b_a_c) + + mock_visitor.assert_has_calls( + [ + mock.call.visit(_Constant("a", "Constant[a]"), ()), + mock.call.validate_value("a"), + mock.call.visit(_Constant("b", "Constant[b]"), ()), + mock.call.validate_value("b"), + mock.call.visit(_Swap("Swap"), ("a", "b")), + mock.call.validate_value("b"), + mock.call.validate_value("a"), + mock.call.visit(_Concat("Concat[0]"), ("b", "a")), + mock.call.validate_value("ba"), + mock.call.visit(_Constant("c", "Constant[c]"), ()), + mock.call.validate_value("c"), + mock.call.visit(_Concat("Concat[1]"), ("ba", "c")), + mock.call.validate_value("bac"), + ] + ) + + def testTraverserOutputsNotATuple(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + + mock_visitor = mock.MagicMock() + mock_visitor.visit.side_effect = [42] + + with self.assertRaisesRegex( + ValueError, r"expected visitor to return a tuple, got" + ): + nodes.Traverser(mock_visitor).visit_value_node(a) + + def testTraverserBadNumOutputs(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + mock_visitor = mock.MagicMock() + mock_visitor.visit.side_effect = [("a", "b")] + + with self.assertRaisesRegex( + ValueError, "has 1 outputs but visitor returned 2 values: " + ): + nodes.Traverser(mock_visitor).visit_value_node(a) + + def testTraverserCycle(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + x_0 = nodes.apply_operation(_Identity, a, label="Identity[0]") + x_1 = nodes.apply_operation(_Identity, x_0, label="Identity[1]") + x_2 = nodes.apply_operation(_Identity, x_1, label="Identity[2]") + x_0.parent_operation._inputs = (x_2,) + + mock_visitor = mock.MagicMock() + mock_visitor.visit.return_value = ("x",) + + with self.assertRaisesWithLiteralMatch( + AssertionError, + "Cycle detected: [Identity[2], Identity[1], Identity[0], Identity[2]]", + ): + 
nodes.Traverser(mock_visitor).visit_value_node(x_2) + + def testGetDotGraph(self): + a = nodes.apply_operation(_Constant, value="a", label="Constant[a]") + b = nodes.apply_operation(_Constant, value="b", label="Constant[b]") + b_copy, a_copy = nodes.apply_multi_output_operation( + _Swap, a, b, label="Swap[0]" + ) + b_copy2, unused_a_copy2 = nodes.apply_multi_output_operation( + _Swap, a_copy, b_copy, label="Swap[1]" + ) + dot_string = nodes.get_dot_graph([b_copy2]).to_string() + self.WriteRenderedDotFile(dot_string) + + self.assertMultiLineEqual( + dot_string, + """\ digraph G { directed=True; node [shape=Mrecord]; @@ -216,8 +236,9 @@ def testGetDotGraph(self): "Swap[0]":0 -> "Swap[1]"; } """, - msg='Result dot graph is:\n{}'.format(dot_string)) + msg=f"Result dot graph is:\n{dot_string}", + ) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/output_wrapper.py b/tensorflow_transform/output_wrapper.py index b7d8926..977387c 100644 --- a/tensorflow_transform/output_wrapper.py +++ b/tensorflow_transform/output_wrapper.py @@ -19,518 +19,565 @@ import numpy as np import tensorflow as tf -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import graph_tools -from tensorflow_transform.analyzers import sanitized_vocab_filename -from tensorflow_transform.keras_lib import tf_keras -from tensorflow_transform.saved import saved_transform_io -from tensorflow_transform.saved import saved_transform_io_v2 -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import metadata_io -from tensorflow_transform.tf_metadata import schema_utils # pylint: disable=g-direct-tensorflow-import from tensorflow.python.framework import ops from tensorflow.tools.docs import doc_controls + # pylint: enable=g-direct-tensorflow-import from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform import common, common_types, graph_tools +from tensorflow_transform.analyzers import sanitized_vocab_filename +from tensorflow_transform.keras_lib import tf_keras +from tensorflow_transform.saved import saved_transform_io, saved_transform_io_v2 +from tensorflow_transform.tf_metadata import dataset_metadata, metadata_io, schema_utils -def _get_tensor_value(tensor_or_eager_tensor: tf.Tensor) -> Any: - if ops.executing_eagerly_outside_functions(): - return np.asarray(tensor_or_eager_tensor) - else: - with tf.compat.v1.Session(): - return tensor_or_eager_tensor.eval() - - -class _TransformedFeaturesDict(dict): - """A wrapper around dict. - - Overrides pop to return None instead of throwing a KeyError when invoked with - a key that is not found in the dictionary. - - NOTE: Do not use directly. - """ - - def pop(self, key, default=None): # pylint: disable=useless-super-delegation - return super().pop(key, default) - - -class TFTransformOutput: - """A wrapper around the output of the tf.Transform.""" - - # Locations relative to the base output directory, where outputs of - # tf.Transform should be written in order to be read by TFTransformOutput. - # WriteTransformFn will follow these conventions. - TRANSFORMED_METADATA_DIR = 'transformed_metadata' - TRANSFORM_FN_DIR = 'transform_fn' - ASSET_MAP = 'asset_map' - - def __init__(self, transform_output_dir: str): - """Init method for TFTransformOutput. - - Args: - transform_output_dir: The directory containig tf.Transform output. 
- """ - self._transform_output_dir = transform_output_dir - - # Lazily constructed properties. - self._transformed_metadata = None - self._raw_metadata = None - self._transform_features_layer = None - self._exported_as_v1_value = None - self._transformed_domains = None - - @property - def transformed_metadata(self) -> dataset_metadata.DatasetMetadata: - """A DatasetMetadata.""" - if self._transformed_metadata is None: - self._transformed_metadata = metadata_io.read_metadata( - self._transformed_metadata_dir) - return self._transformed_metadata - - @property - def transform_savedmodel_dir(self) -> str: - """A python str.""" - return os.path.join(self._transform_output_dir, self.TRANSFORM_FN_DIR) - - @property - def _exported_as_v1(self) -> bool: - """A boolean. - - Indicates whether the SavedModel was exported using TF 1.x or TF 2.x APIs. - """ - if self._exported_as_v1_value is None: - self._exported_as_v1_value = saved_transform_io.exported_as_v1( - self.transform_savedmodel_dir) - return self._exported_as_v1_value - - @property - def _transformed_metadata_dir(self) -> str: - return os.path.join(self._transform_output_dir, - self.TRANSFORMED_METADATA_DIR) - - def transformed_feature_spec(self) -> Dict[str, common_types.FeatureSpecType]: - """Returns a feature_spec for the transformed features. - - Returns: - A dict from feature names to FixedLenFeature/SparseFeature/VarLenFeature. - """ - return schema_utils.schema_as_feature_spec( - self.transformed_metadata.schema).feature_spec - - def transformed_domains(self) -> Dict[str, common_types.DomainType]: - """Returns domains for the transformed features. - - Returns: - A dict from feature names to one of schema_pb2.IntDomain, - schema_pb2.StringDomain or schema_pb2.FloatDomain. - """ - if self._transformed_domains is None: - self._transformed_domains = schema_utils.schema_as_feature_spec( - self.transformed_metadata.schema).domains - return self._transformed_domains - - def vocabulary_file_by_name(self, vocab_filename: str) -> Optional[str]: - """Returns the vocabulary file path created in the preprocessing function. - - `vocab_filename` must either be (i) the name used as the vocab_filename - argument to tft.compute_and_apply_vocabulary / tft.vocabulary or (ii) the - key used in tft.annotate_asset. - - When a mapping has been specified by calls to tft.annotate_asset, it will be - checked first for the provided filename. If present, this filename will be - used directly to construct a path. - - If the mapping does not exist or `vocab_filename` is not present within it, - we will default to sanitizing `vocab_filename` and searching for files - matching it within the assets directory. - - In either case, if the constructed path does not point to an existing file - within the assets subdirectory, we will return a None. - Args: - vocab_filename: The vocabulary name to lookup. 
- """ - mapping_path = os.path.join(self._transformed_metadata_dir, self.ASSET_MAP) - - mapping = {} - if tf.io.gfile.exists(mapping_path): - with tf.io.gfile.GFile(mapping_path) as f: - mapping = json.loads(f.read()) - if vocab_filename in mapping: - vocab_path = os.path.join(self.transform_savedmodel_dir, - tf.saved_model.ASSETS_DIRECTORY, - mapping[vocab_filename]) - if tf.io.gfile.exists(vocab_path): - return vocab_path - - prefix = os.path.join(self.transform_savedmodel_dir, - tf.saved_model.ASSETS_DIRECTORY, - sanitized_vocab_filename(filename=vocab_filename)) - files = tf.io.gfile.glob(prefix) + tf.io.gfile.glob( - '{}.tfrecord.gz'.format(prefix)) - if not files: - return None - if len(files) != 1: - raise ValueError('Found too many vocabulary files: {}'.format(files)) - return files[0] - - def _vocabulary_size_from_annotations(self, - vocab_filename: str) -> Optional[int]: - """If vocabulary size is present in annotations return it, else None.""" - if not common.IS_ANNOTATIONS_PB_AVAILABLE: - return None - - try: - schema = self.transformed_metadata.schema - except IOError: - return None - - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top - for annotation in schema.annotation.extra_metadata: - message = annotations_pb2.VocabularyMetadata() - annotation.Unpack(message) - # Check message.filtered_vocabulary_size is not 0 for backwards - # compatibility. - if (message.file_name == vocab_filename and - message.filtered_vocabulary_size != 0): - return message.filtered_vocabulary_size - - return None - - def vocabulary_size_by_name(self, vocab_filename: str) -> int: - """Like vocabulary_file_by_name, but returns the size of vocabulary.""" - vocab_size_from_annotations = self._vocabulary_size_from_annotations( - vocab_filename) - if vocab_size_from_annotations is not None: - return vocab_size_from_annotations - - vocab_path = self.vocabulary_file_by_name(vocab_filename) - if not vocab_path: - raise ValueError( - 'Could not compute vocabulary size for {}, does not exist'.format( - vocab_filename)) - elif vocab_path.endswith('tfrecord.gz'): - dataset = tf.data.TFRecordDataset(vocab_path, compression_type='GZIP') - - def reduce_fn(accum, elem): - return tf.size(elem, out_type=tf.int64, name='vocabulary_size') + accum - - return _get_tensor_value( - dataset.batch(tf.int32.max).reduce( - tf.constant(0, tf.int64), reduce_fn)) - else: - with tf.io.gfile.GFile(vocab_path, 'rb') as f: - return sum(1 for _ in f) - - def vocabulary_by_name(self, vocab_filename: str) -> List[bytes]: - """Like vocabulary_file_by_name but returns a list.""" - vocab_path = self.vocabulary_file_by_name(vocab_filename) - if not vocab_path: - raise ValueError('Could not read vocabulary: {}, does not exist'.format( - vocab_filename)) - elif vocab_path.endswith('tfrecord.gz'): - dataset = tf.data.TFRecordDataset(vocab_path, compression_type='GZIP') - vocab_tensor = dataset.batch(tf.int32.max).reduce( - tf.constant([], dtype=tf.string), - lambda state, elem: tf.concat([state, elem], axis=-1)) - # Using as_numpy_iterator only works when executing eagerly. 
- return _get_tensor_value(vocab_tensor).tolist() - else: - with tf.io.gfile.GFile(vocab_path, 'rb') as f: - return [l.rstrip(os.linesep.encode('utf-8')) for l in f] - - # TODO(KesterTong): Add test for this in output_wrapper_test.py - def num_buckets_for_transformed_feature(self, name: str) -> int: - """Returns the number of buckets for an integerized transformed feature.""" - # Do checks that this tensor can be wrapped in - # sparse_column_with_integerized_feature - try: - domain = self.transformed_domains()[name] - except KeyError: - raise ValueError('Column {} did not have a domain provided.'.format(name)) - if not isinstance(domain, schema_pb2.IntDomain): - raise ValueError('Column {} has domain {}, expected an IntDomain'.format( - name, domain)) - if domain.min != 0: - raise ValueError('Column {} has min value {}, should be 0'.format( - name, domain.min)) - return domain.max + 1 - - def transform_features_layer(self) -> tf_keras.Model: - """Creates a `TransformFeaturesLayer` from this transform output. - - If a `TransformFeaturesLayer` has already been created for self, the same - one will be returned. - - Returns: - A `TransformFeaturesLayer` instance. - """ - if self._transform_features_layer is None: - self._transform_features_layer = TransformFeaturesLayer( - self, exported_as_v1=self._exported_as_v1) - return self._transform_features_layer - - def transform_raw_features( - self, - raw_features: Mapping[str, common_types.TensorType], - drop_unused_features: bool = True # LEGACY_VALUE=False - ) -> Dict[str, common_types.TensorType]: - """Takes a dict of tensors representing raw features and transforms them. - - Takes a dictionary of `Tensor`, `SparseTensor`, or `RaggedTensor`s that - represent the raw features, and applies the transformation defined by - tf.Transform. - - If False it returns all transformed features defined by tf.Transform. To - only return features transformed from the given 'raw_features', set - `drop_unused_features` to True. - - Note: If eager execution is enabled and this API is invoked inside a - tf.function or an API that uses tf.function such as dataset.map, please use - `transform_features_layer` instead. It separates out loading of the - transform graph and hence resources will not be initialized on each - invocation. This can have significant performance improvement if the - transform graph was exported as a TF1 SavedModel and guarantees correctness - if it was exported as a TF2 SavedModel. - - Args: - raw_features: A dict whose keys are feature names and values are - `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. - drop_unused_features: If True, the result will be filtered. Only the - features that are transformed from 'raw_features' will be included in - the returned result. If a feature is transformed from multiple raw - features (e.g, feature cross), it will only be included if all its base - raw features are present in `raw_features`. - - Returns: - A dict whose keys are feature names and values are `Tensor`s, - `SparseTensor`s, or `RaggedTensor`s representing transformed features. - """ - if self._exported_as_v1: - transformed_features = self._transform_raw_features_compat_v1( - raw_features, drop_unused_features) - else: - tft_layer = self.transform_features_layer() - if not drop_unused_features: - tf.compat.v1.logging.warning( - 'Unused features are always dropped in the TF 2.x ' - 'implementation. 
Ignoring value of drop_unused_features.') - - transformed_features = tft_layer(raw_features) - return _TransformedFeaturesDict(transformed_features) - - def _transform_raw_features_compat_v1( - self, raw_features: Mapping[str, common_types.TensorType], - drop_unused_features: bool) -> Dict[str, common_types.TensorType]: - """Takes a dict of tensors representing raw features and transforms them.""" - unbounded_raw_features, transformed_features = ( - saved_transform_io.partially_apply_saved_transform_internal( - self.transform_savedmodel_dir, raw_features)) - if drop_unused_features: - graph = tf.compat.v1.get_default_graph() - graph_analyzer = graph_tools.InitializableGraphAnalyzer( - graph, raw_features, - [(t, False) for t in unbounded_raw_features.values()]) - return { - name: feature - for name, feature in transformed_features.items() - if graph_analyzer.ready_to_run(feature) - } +def _get_tensor_value(tensor_or_eager_tensor: tf.Tensor) -> Any: + if ops.executing_eagerly_outside_functions(): + return np.asarray(tensor_or_eager_tensor) else: - return transformed_features + with tf.compat.v1.Session(): + return tensor_or_eager_tensor.eval() - def load_transform_graph(self): - """Load the transform graph without replacing any placeholders. - - This is necessary to ensure that variables in the transform graph are - included in the training checkpoint when using tf.Estimator. This should - be called in the training input_fn. - """ - if self._exported_as_v1 is None: - self._exported_as_v1 = saved_transform_io.exported_as_v1( - self.transform_savedmodel_dir) - - if self._exported_as_v1: - saved_transform_io.partially_apply_saved_transform_internal( - self.transform_savedmodel_dir, {}) - else: - # Note: This should use the same mechanism as `transform_raw_features` to - # load the SavedModel into the current graph context. - _ = self.transform_features_layer()({}) - - RAW_METADATA_DIR = 'metadata' - _FEATURE_STATS_PB = 'FeatureStats.pb' - PRE_TRANSFORM_FEATURE_STATS_PATH = os.path.join( - 'pre_transform_feature_stats', _FEATURE_STATS_PB) - POST_TRANSFORM_FEATURE_STATS_PATH = os.path.join( - 'post_transform_feature_stats', _FEATURE_STATS_PB) - - @property - def raw_metadata(self) -> dataset_metadata.DatasetMetadata: - """A DatasetMetadata. - - Note: raw_metadata is not guaranteed to exist in the output of tf.transform - and hence using this could fail, if raw_metadata is not present in - TFTransformOutput. - - Returns: - A DatasetMetadata - """ - if self._raw_metadata is None: - self._raw_metadata = metadata_io.read_metadata( - os.path.join(self._transform_output_dir, self.RAW_METADATA_DIR)) - return self._raw_metadata - def raw_feature_spec(self) -> Dict[str, common_types.FeatureSpecType]: - """Returns a feature_spec for the raw features. - - Returns: - A dict from feature names to FixedLenFeature/SparseFeature/VarLenFeature. - """ - return schema_utils.schema_as_feature_spec( - self.raw_metadata.schema).feature_spec +class _TransformedFeaturesDict(dict): + """A wrapper around dict. - def raw_domains(self) -> Dict[str, common_types.DomainType]: - """Returns domains for the raw features. + Overrides pop to return None instead of throwing a KeyError when invoked with + a key that is not found in the dictionary. - Returns: - A dict from feature names to one of schema_pb2.IntDomain, - schema_pb2.StringDomain or schema_pb2.FloatDomain. + NOTE: Do not use directly. 
""" - return schema_utils.schema_as_feature_spec( - self.raw_metadata.schema).domains - @property - def pre_transform_statistics_path(self) -> str: - """Returns the path to the pre-transform datum statistics. + def pop(self, key, default=None): # pylint: disable=useless-super-delegation + return super().pop(key, default) - Note: pre_transform_statistics is not guaranteed to exist in the output of - tf.transform and hence using this could fail, if pre_transform statistics is - not present in TFTransformOutput. - """ - return os.path.join( - self._transform_output_dir, self.PRE_TRANSFORM_FEATURE_STATS_PATH) - @property - def post_transform_statistics_path(self) -> str: - """Returns the path to the post-transform datum statistics. - - Note: post_transform_statistics is not guaranteed to exist in the output of - tf.transform and hence using this could fail, if post_transform statistics - is not present in TFTransformOutput. - """ - return os.path.join( - self._transform_output_dir, self.POST_TRANSFORM_FEATURE_STATS_PATH) +class TFTransformOutput: + """A wrapper around the output of the tf.Transform.""" + + # Locations relative to the base output directory, where outputs of + # tf.Transform should be written in order to be read by TFTransformOutput. + # WriteTransformFn will follow these conventions. + TRANSFORMED_METADATA_DIR = "transformed_metadata" + TRANSFORM_FN_DIR = "transform_fn" + ASSET_MAP = "asset_map" + + def __init__(self, transform_output_dir: str): + """Init method for TFTransformOutput. + + Args: + ---- + transform_output_dir: The directory containig tf.Transform output. + """ + self._transform_output_dir = transform_output_dir + + # Lazily constructed properties. + self._transformed_metadata = None + self._raw_metadata = None + self._transform_features_layer = None + self._exported_as_v1_value = None + self._transformed_domains = None + + @property + def transformed_metadata(self) -> dataset_metadata.DatasetMetadata: + """A DatasetMetadata.""" + if self._transformed_metadata is None: + self._transformed_metadata = metadata_io.read_metadata( + self._transformed_metadata_dir + ) + return self._transformed_metadata + + @property + def transform_savedmodel_dir(self) -> str: + """A python str.""" + return os.path.join(self._transform_output_dir, self.TRANSFORM_FN_DIR) + + @property + def _exported_as_v1(self) -> bool: + """A boolean. + + Indicates whether the SavedModel was exported using TF 1.x or TF 2.x APIs. + """ + if self._exported_as_v1_value is None: + self._exported_as_v1_value = saved_transform_io.exported_as_v1( + self.transform_savedmodel_dir + ) + return self._exported_as_v1_value + + @property + def _transformed_metadata_dir(self) -> str: + return os.path.join(self._transform_output_dir, self.TRANSFORMED_METADATA_DIR) + + def transformed_feature_spec(self) -> Dict[str, common_types.FeatureSpecType]: + """Returns a feature_spec for the transformed features. + + Returns + ------- + A dict from feature names to FixedLenFeature/SparseFeature/VarLenFeature. + """ + return schema_utils.schema_as_feature_spec( + self.transformed_metadata.schema + ).feature_spec + + def transformed_domains(self) -> Dict[str, common_types.DomainType]: + """Returns domains for the transformed features. + + Returns + ------- + A dict from feature names to one of schema_pb2.IntDomain, + schema_pb2.StringDomain or schema_pb2.FloatDomain. 
+ """ + if self._transformed_domains is None: + self._transformed_domains = schema_utils.schema_as_feature_spec( + self.transformed_metadata.schema + ).domains + return self._transformed_domains + + def vocabulary_file_by_name(self, vocab_filename: str) -> Optional[str]: + """Returns the vocabulary file path created in the preprocessing function. + + `vocab_filename` must either be (i) the name used as the vocab_filename + argument to tft.compute_and_apply_vocabulary / tft.vocabulary or (ii) the + key used in tft.annotate_asset. + + When a mapping has been specified by calls to tft.annotate_asset, it will be + checked first for the provided filename. If present, this filename will be + used directly to construct a path. + + If the mapping does not exist or `vocab_filename` is not present within it, + we will default to sanitizing `vocab_filename` and searching for files + matching it within the assets directory. + + In either case, if the constructed path does not point to an existing file + within the assets subdirectory, we will return a None. + + Args: + ---- + vocab_filename: The vocabulary name to lookup. + """ + mapping_path = os.path.join(self._transformed_metadata_dir, self.ASSET_MAP) + + mapping = {} + if tf.io.gfile.exists(mapping_path): + with tf.io.gfile.GFile(mapping_path) as f: + mapping = json.loads(f.read()) + if vocab_filename in mapping: + vocab_path = os.path.join( + self.transform_savedmodel_dir, + tf.saved_model.ASSETS_DIRECTORY, + mapping[vocab_filename], + ) + if tf.io.gfile.exists(vocab_path): + return vocab_path + + prefix = os.path.join( + self.transform_savedmodel_dir, + tf.saved_model.ASSETS_DIRECTORY, + sanitized_vocab_filename(filename=vocab_filename), + ) + files = tf.io.gfile.glob(prefix) + tf.io.gfile.glob(f"{prefix}.tfrecord.gz") + if not files: + return None + if len(files) != 1: + raise ValueError(f"Found too many vocabulary files: {files}") + return files[0] + + def _vocabulary_size_from_annotations(self, vocab_filename: str) -> Optional[int]: + """If vocabulary size is present in annotations return it, else None.""" + if not common.IS_ANNOTATIONS_PB_AVAILABLE: + return None + + try: + schema = self.transformed_metadata.schema + except OSError: + return None + + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) + + for annotation in schema.annotation.extra_metadata: + message = annotations_pb2.VocabularyMetadata() + annotation.Unpack(message) + # Check message.filtered_vocabulary_size is not 0 for backwards + # compatibility. 
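+            # A filtered_vocabulary_size of 0 is treated here as "not recorded", so
+            # this helper returns None and vocabulary_size_by_name (below) falls back
+            # to counting entries in the vocabulary file itself. Hypothetical usage
+            # sketch (path and vocabulary name are illustrative only):
+            #   tft_output = TFTransformOutput("/tmp/transform_output")
+            #   size = tft_output.vocabulary_size_by_name("my_vocab")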
+ if ( + message.file_name == vocab_filename + and message.filtered_vocabulary_size != 0 + ): + return message.filtered_vocabulary_size + + return None + + def vocabulary_size_by_name(self, vocab_filename: str) -> int: + """Like vocabulary_file_by_name, but returns the size of vocabulary.""" + vocab_size_from_annotations = self._vocabulary_size_from_annotations( + vocab_filename + ) + if vocab_size_from_annotations is not None: + return vocab_size_from_annotations + + vocab_path = self.vocabulary_file_by_name(vocab_filename) + if not vocab_path: + raise ValueError( + f"Could not compute vocabulary size for {vocab_filename}, does not exist" + ) + elif vocab_path.endswith("tfrecord.gz"): + dataset = tf.data.TFRecordDataset(vocab_path, compression_type="GZIP") + + def reduce_fn(accum, elem): + return tf.size(elem, out_type=tf.int64, name="vocabulary_size") + accum + + return _get_tensor_value( + dataset.batch(tf.int32.max).reduce(tf.constant(0, tf.int64), reduce_fn) + ) + else: + with tf.io.gfile.GFile(vocab_path, "rb") as f: + return sum(1 for _ in f) + + def vocabulary_by_name(self, vocab_filename: str) -> List[bytes]: + """Like vocabulary_file_by_name but returns a list.""" + vocab_path = self.vocabulary_file_by_name(vocab_filename) + if not vocab_path: + raise ValueError( + f"Could not read vocabulary: {vocab_filename}, does not exist" + ) + elif vocab_path.endswith("tfrecord.gz"): + dataset = tf.data.TFRecordDataset(vocab_path, compression_type="GZIP") + vocab_tensor = dataset.batch(tf.int32.max).reduce( + tf.constant([], dtype=tf.string), + lambda state, elem: tf.concat([state, elem], axis=-1), + ) + # Using as_numpy_iterator only works when executing eagerly. + return _get_tensor_value(vocab_tensor).tolist() + else: + with tf.io.gfile.GFile(vocab_path, "rb") as f: + return [l.rstrip(os.linesep.encode("utf-8")) for l in f] + + # TODO(KesterTong): Add test for this in output_wrapper_test.py + def num_buckets_for_transformed_feature(self, name: str) -> int: + """Returns the number of buckets for an integerized transformed feature.""" + # Do checks that this tensor can be wrapped in + # sparse_column_with_integerized_feature + try: + domain = self.transformed_domains()[name] + except KeyError: + raise ValueError(f"Column {name} did not have a domain provided.") + if not isinstance(domain, schema_pb2.IntDomain): + raise ValueError( + f"Column {name} has domain {domain}, expected an IntDomain" + ) + if domain.min != 0: + raise ValueError(f"Column {name} has min value {domain.min}, should be 0") + return domain.max + 1 + + def transform_features_layer(self) -> tf_keras.Model: + """Creates a `TransformFeaturesLayer` from this transform output. + + If a `TransformFeaturesLayer` has already been created for self, the same + one will be returned. + + Returns + ------- + A `TransformFeaturesLayer` instance. + """ + if self._transform_features_layer is None: + self._transform_features_layer = TransformFeaturesLayer( + self, exported_as_v1=self._exported_as_v1 + ) + return self._transform_features_layer + + def transform_raw_features( + self, + raw_features: Mapping[str, common_types.TensorType], + drop_unused_features: bool = True, # LEGACY_VALUE=False + ) -> Dict[str, common_types.TensorType]: + """Takes a dict of tensors representing raw features and transforms them. + + Takes a dictionary of `Tensor`, `SparseTensor`, or `RaggedTensor`s that + represent the raw features, and applies the transformation defined by + tf.Transform. 
+ + If False it returns all transformed features defined by tf.Transform. To + only return features transformed from the given 'raw_features', set + `drop_unused_features` to True. + + Note: If eager execution is enabled and this API is invoked inside a + tf.function or an API that uses tf.function such as dataset.map, please use + `transform_features_layer` instead. It separates out loading of the + transform graph and hence resources will not be initialized on each + invocation. This can have significant performance improvement if the + transform graph was exported as a TF1 SavedModel and guarantees correctness + if it was exported as a TF2 SavedModel. + + Args: + ---- + raw_features: A dict whose keys are feature names and values are + `Tensor`s, `SparseTensor`s, or `RaggedTensor`s. + drop_unused_features: If True, the result will be filtered. Only the + features that are transformed from 'raw_features' will be included in + the returned result. If a feature is transformed from multiple raw + features (e.g, feature cross), it will only be included if all its base + raw features are present in `raw_features`. + + Returns: + ------- + A dict whose keys are feature names and values are `Tensor`s, + `SparseTensor`s, or `RaggedTensor`s representing transformed features. + """ + if self._exported_as_v1: + transformed_features = self._transform_raw_features_compat_v1( + raw_features, drop_unused_features + ) + else: + tft_layer = self.transform_features_layer() + if not drop_unused_features: + tf.compat.v1.logging.warning( + "Unused features are always dropped in the TF 2.x " + "implementation. Ignoring value of drop_unused_features." + ) + + transformed_features = tft_layer(raw_features) + return _TransformedFeaturesDict(transformed_features) + + def _transform_raw_features_compat_v1( + self, + raw_features: Mapping[str, common_types.TensorType], + drop_unused_features: bool, + ) -> Dict[str, common_types.TensorType]: + """Takes a dict of tensors representing raw features and transforms them.""" + unbounded_raw_features, transformed_features = ( + saved_transform_io.partially_apply_saved_transform_internal( + self.transform_savedmodel_dir, raw_features + ) + ) + if drop_unused_features: + graph = tf.compat.v1.get_default_graph() + graph_analyzer = graph_tools.InitializableGraphAnalyzer( + graph, + raw_features, + [(t, False) for t in unbounded_raw_features.values()], + ) + return { + name: feature + for name, feature in transformed_features.items() + if graph_analyzer.ready_to_run(feature) + } + else: + return transformed_features + + def load_transform_graph(self): + """Load the transform graph without replacing any placeholders. + + This is necessary to ensure that variables in the transform graph are + included in the training checkpoint when using tf.Estimator. This should + be called in the training input_fn. + """ + if self._exported_as_v1 is None: + self._exported_as_v1 = saved_transform_io.exported_as_v1( + self.transform_savedmodel_dir + ) + + if self._exported_as_v1: + saved_transform_io.partially_apply_saved_transform_internal( + self.transform_savedmodel_dir, {} + ) + else: + # Note: This should use the same mechanism as `transform_raw_features` to + # load the SavedModel into the current graph context. 
+ _ = self.transform_features_layer()({}) + + RAW_METADATA_DIR = "metadata" + _FEATURE_STATS_PB = "FeatureStats.pb" + PRE_TRANSFORM_FEATURE_STATS_PATH = os.path.join( + "pre_transform_feature_stats", _FEATURE_STATS_PB + ) + POST_TRANSFORM_FEATURE_STATS_PATH = os.path.join( + "post_transform_feature_stats", _FEATURE_STATS_PB + ) + + @property + def raw_metadata(self) -> dataset_metadata.DatasetMetadata: + """A DatasetMetadata. + + Note: raw_metadata is not guaranteed to exist in the output of tf.transform + and hence using this could fail, if raw_metadata is not present in + TFTransformOutput. + + Returns + ------- + A DatasetMetadata + """ + if self._raw_metadata is None: + self._raw_metadata = metadata_io.read_metadata( + os.path.join(self._transform_output_dir, self.RAW_METADATA_DIR) + ) + return self._raw_metadata + + def raw_feature_spec(self) -> Dict[str, common_types.FeatureSpecType]: + """Returns a feature_spec for the raw features. + + Returns + ------- + A dict from feature names to FixedLenFeature/SparseFeature/VarLenFeature. + """ + return schema_utils.schema_as_feature_spec( + self.raw_metadata.schema + ).feature_spec + + def raw_domains(self) -> Dict[str, common_types.DomainType]: + """Returns domains for the raw features. + + Returns + ------- + A dict from feature names to one of schema_pb2.IntDomain, + schema_pb2.StringDomain or schema_pb2.FloatDomain. + """ + return schema_utils.schema_as_feature_spec(self.raw_metadata.schema).domains + + @property + def pre_transform_statistics_path(self) -> str: + """Returns the path to the pre-transform datum statistics. + + Note: pre_transform_statistics is not guaranteed to exist in the output of + tf.transform and hence using this could fail, if pre_transform statistics is + not present in TFTransformOutput. + """ + return os.path.join( + self._transform_output_dir, self.PRE_TRANSFORM_FEATURE_STATS_PATH + ) + + @property + def post_transform_statistics_path(self) -> str: + """Returns the path to the post-transform datum statistics. + + Note: post_transform_statistics is not guaranteed to exist in the output of + tf.transform and hence using this could fail, if post_transform statistics + is not present in TFTransformOutput. + """ + return os.path.join( + self._transform_output_dir, self.POST_TRANSFORM_FEATURE_STATS_PATH + ) # TODO(b/162055065): Possibly switch back to inherit from Layer when possible. -@tf_keras.utils.register_keras_serializable(package='TensorFlowTransform') +@tf_keras.utils.register_keras_serializable(package="TensorFlowTransform") class TransformFeaturesLayer(tf_keras.Model): - """A Keras layer for applying a tf.Transform output to input layers.""" - - def __init__(self, - tft_output: TFTransformOutput, - exported_as_v1: Optional[bool] = None): - super().__init__(trainable=False) - self._tft_output = tft_output - if exported_as_v1 is None: - self._exported_as_v1 = saved_transform_io.exported_as_v1( - tft_output.transform_savedmodel_dir) - else: - self._exported_as_v1 = exported_as_v1 - self._saved_model_loader_value = None - self._loaded_saved_model_graph = None - if tf.compat.v1.executing_eagerly_outside_functions(): - # The model must be tracked by assigning to an attribute of the Keras - # layer. Hence, we track the attributes of _saved_model_loader here as - # well. - self._saved_model_loader_tracked_dict = self._saved_model_loader.__dict__ - - # TODO(b/162055065): This is needed because otherwise we'd get an error in - # some cases: - # ValueError: Your Layer or Model is in an invalid state. 
This can happen - # if you are interleaving estimator/non-estimator models or interleaving - # models/layers made in tf.compat.v1.Graph.as_default() with models/layers - # created outside of it. Converting a model to an estimator (via - # model_to_estimator) invalidates all models/layers made before the - # conversion (even if they were not the model converted to an estimator). - # Similarly, making a layer or a model inside a a tf.compat.v1.Graph - # invalidates all layers/models you previously made outside of the graph. - self._originally_built_as_v1 = True - - @property - def _saved_model_loader(self) -> saved_transform_io_v2.SavedModelLoader: - """A `saved_transform_io_v2.SavedModelLoader`.""" - if self._saved_model_loader_value is None: - self._saved_model_loader_value = saved_transform_io_v2.SavedModelLoader( - self._tft_output.transform_savedmodel_dir) - self._loaded_saved_model_graph = ops.get_default_graph() - - if tf.compat.v1.executing_eagerly_outside_functions(): - return self._saved_model_loader_value - else: - assert not self._exported_as_v1 - # TODO(b/149997088): Raise an exception once we no longer support using - # the Keras layer with estimator based Trainer. - tf.compat.v1.logging.warning('Loading a TF2 SavedModel but eager mode ' - 'seems disabled.') - # If exported as TF2 SavedModel but not invoked in eager mode, - # re-initialize the saved_model_loader_value as __init__ could have been - # called in a different graph context. - default_graph = ops.get_default_graph() - if (self._loaded_saved_model_graph is None or - self._loaded_saved_model_graph is not default_graph): - self._saved_model_loader_value = saved_transform_io_v2.SavedModelLoader( - self._tft_output.transform_savedmodel_dir) - self._loaded_saved_model_graph = default_graph - return self._saved_model_loader_value - - def _init_batch_counters(self, *args, **kwargs): # pylint: disable=g-doc-args - """Overriding this method because Model's implementation creates variables. - - These Variables are not needed for TransformFeaturesLayer. - """ - pass - - def call( # pytype: disable=signature-mismatch # overriding-parameter-count-checks - self, inputs: Mapping[str, common_types.TensorType] - ) -> Dict[str, common_types.TensorType]: - - if self._exported_as_v1 and not ops.executing_eagerly_outside_functions(): - tf.compat.v1.logging.warning('Falling back to transform_raw_features...') - return self._tft_output._transform_raw_features_compat_v1( # pylint: disable=protected-access - inputs, - drop_unused_features=True) - else: - return self._saved_model_loader.apply_transform_model(inputs) + """A Keras layer for applying a tf.Transform output to input layers.""" + + def __init__( + self, tft_output: TFTransformOutput, exported_as_v1: Optional[bool] = None + ): + super().__init__(trainable=False) + self._tft_output = tft_output + if exported_as_v1 is None: + self._exported_as_v1 = saved_transform_io.exported_as_v1( + tft_output.transform_savedmodel_dir + ) + else: + self._exported_as_v1 = exported_as_v1 + self._saved_model_loader_value = None + self._loaded_saved_model_graph = None + if tf.compat.v1.executing_eagerly_outside_functions(): + # The model must be tracked by assigning to an attribute of the Keras + # layer. Hence, we track the attributes of _saved_model_loader here as + # well. + self._saved_model_loader_tracked_dict = self._saved_model_loader.__dict__ + + # TODO(b/162055065): This is needed because otherwise we'd get an error in + # some cases: + # ValueError: Your Layer or Model is in an invalid state. 
This can happen + # if you are interleaving estimator/non-estimator models or interleaving + # models/layers made in tf.compat.v1.Graph.as_default() with models/layers + # created outside of it. Converting a model to an estimator (via + # model_to_estimator) invalidates all models/layers made before the + # conversion (even if they were not the model converted to an estimator). + # Similarly, making a layer or a model inside a a tf.compat.v1.Graph + # invalidates all layers/models you previously made outside of the graph. + self._originally_built_as_v1 = True + + @property + def _saved_model_loader(self) -> saved_transform_io_v2.SavedModelLoader: + """A `saved_transform_io_v2.SavedModelLoader`.""" + if self._saved_model_loader_value is None: + self._saved_model_loader_value = saved_transform_io_v2.SavedModelLoader( + self._tft_output.transform_savedmodel_dir + ) + self._loaded_saved_model_graph = ops.get_default_graph() + + if tf.compat.v1.executing_eagerly_outside_functions(): + return self._saved_model_loader_value + else: + assert not self._exported_as_v1 + # TODO(b/149997088): Raise an exception once we no longer support using + # the Keras layer with estimator based Trainer. + tf.compat.v1.logging.warning( + "Loading a TF2 SavedModel but eager mode " "seems disabled." + ) + # If exported as TF2 SavedModel but not invoked in eager mode, + # re-initialize the saved_model_loader_value as __init__ could have been + # called in a different graph context. + default_graph = ops.get_default_graph() + if ( + self._loaded_saved_model_graph is None + or self._loaded_saved_model_graph is not default_graph + ): + self._saved_model_loader_value = saved_transform_io_v2.SavedModelLoader( + self._tft_output.transform_savedmodel_dir + ) + self._loaded_saved_model_graph = default_graph + return self._saved_model_loader_value + + def _init_batch_counters(self, *args, **kwargs): # pylint: disable=g-doc-args + """Overriding this method because Model's implementation creates variables. + + These Variables are not needed for TransformFeaturesLayer. + """ + pass + + def call( # pytype: disable=signature-mismatch # overriding-parameter-count-checks + self, inputs: Mapping[str, common_types.TensorType] + ) -> Dict[str, common_types.TensorType]: + if self._exported_as_v1 and not ops.executing_eagerly_outside_functions(): + tf.compat.v1.logging.warning("Falling back to transform_raw_features...") + return self._tft_output._transform_raw_features_compat_v1( # pylint: disable=protected-access + inputs, drop_unused_features=True + ) + else: + return self._saved_model_loader.apply_transform_model(inputs) def _make_method_override(name): + @doc_controls.do_not_generate_docs + def method_override(*args, **kwargs): + raise NotImplementedError(name) - @doc_controls.do_not_generate_docs - def method_override(*args, **kwargs): - raise NotImplementedError(name) - - return method_override + return method_override # TODO(zoyahav): Get rid of property attributes docs as well. 
def _override_parent_methods(keep_items): - """Makes inheritted attributes of the TFT layer unusable and undocumented.""" - for name in dir(tf_keras.Model): - if name.startswith('_') or name in keep_items: - continue - if callable(getattr(tf_keras.Model, name)): - setattr(TransformFeaturesLayer, name, _make_method_override(name)) - elif not isinstance(getattr(TransformFeaturesLayer, name), property): - doc_controls.do_not_generate_docs(getattr(TransformFeaturesLayer, name)) - - -_override_parent_methods(keep_items=[ - 'call', 'build', 'compute_mask', 'add_loss', 'count_params', - 'finalize_state', 'save_spec' -]) + """Makes inheritted attributes of the TFT layer unusable and undocumented.""" + for name in dir(tf_keras.Model): + if name.startswith("_") or name in keep_items: + continue + if callable(getattr(tf_keras.Model, name)): + setattr(TransformFeaturesLayer, name, _make_method_override(name)) + elif not isinstance(getattr(TransformFeaturesLayer, name), property): + doc_controls.do_not_generate_docs(getattr(TransformFeaturesLayer, name)) + + +_override_parent_methods( + keep_items=[ + "call", + "build", + "compute_mask", + "add_loss", + "count_params", + "finalize_state", + "save_spec", + ] +) diff --git a/tensorflow_transform/pickle_helper.py b/tensorflow_transform/pickle_helper.py index f47050d..0c0e6e1 100644 --- a/tensorflow_transform/pickle_helper.py +++ b/tensorflow_transform/pickle_helper.py @@ -14,17 +14,22 @@ """Functions to fix pickling of certain objects (see b/121323638).""" import copyreg + import tensorflow as tf +from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2 + from tensorflow_transform import common -from tensorflow_metadata.proto.v0 import schema_pb2 -from tensorflow_metadata.proto.v0 import statistics_pb2 if common.IS_ANNOTATIONS_PB_AVAILABLE: - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) -_ANNOTATION_CLASSES = [ - annotations_pb2.VocabularyMetadata, annotations_pb2.BucketBoundaries -] if common.IS_ANNOTATIONS_PB_AVAILABLE else [] +_ANNOTATION_CLASSES = ( + [annotations_pb2.VocabularyMetadata, annotations_pb2.BucketBoundaries] + if common.IS_ANNOTATIONS_PB_AVAILABLE + else [] +) _PROTO_CLASSES = [ tf.compat.v1.ConfigProto, @@ -34,30 +39,33 @@ ] + _ANNOTATION_CLASSES -_PROTO_CLS_BY_NAME = {proto_cls.DESCRIPTOR.name: proto_cls - for proto_cls in _PROTO_CLASSES} +_PROTO_CLS_BY_NAME = { + proto_cls.DESCRIPTOR.name: proto_cls for proto_cls in _PROTO_CLASSES +} def _pickle_proto(proto): - return _unpickle_proto, (proto.DESCRIPTOR.name, proto.SerializeToString()) + return _unpickle_proto, (proto.DESCRIPTOR.name, proto.SerializeToString()) def _unpickle_proto(name, serialized_proto): - return _PROTO_CLS_BY_NAME[name].FromString(serialized_proto) + return _PROTO_CLS_BY_NAME[name].FromString(serialized_proto) def _pickle_tensor_spec(tensor_spec): - return _unpickle_tensor_spec, (tensor_spec.shape.as_list(), - tensor_spec.dtype.as_numpy_dtype) + return _unpickle_tensor_spec, ( + tensor_spec.shape.as_list(), + tensor_spec.dtype.as_numpy_dtype, + ) def _unpickle_tensor_spec(shape, numpy_dtype): - return tf.TensorSpec(shape, tf.as_dtype(numpy_dtype)) + return tf.TensorSpec(shape, tf.as_dtype(numpy_dtype)) def fix_internal_object_pickling(): - """Fix pickling issues (see b/121323638).""" - for proto_cls in _PROTO_CLASSES: - copyreg.pickle(proto_cls, _pickle_proto) + """Fix pickling issues (see b/121323638).""" + for 
proto_cls in _PROTO_CLASSES: + copyreg.pickle(proto_cls, _pickle_proto) - copyreg.pickle(tf.TensorSpec, _pickle_tensor_spec) + copyreg.pickle(tf.TensorSpec, _pickle_tensor_spec) diff --git a/tensorflow_transform/pretrained_models.py b/tensorflow_transform/pretrained_models.py index 4653695..c8a5ee0 100644 --- a/tensorflow_transform/pretrained_models.py +++ b/tensorflow_transform/pretrained_models.py @@ -23,261 +23,284 @@ # TODO(b/141936246) Replace this function with a V2-safe way to load models. -def _get_variables(scope=None, - suffix=None, - collection=tf.compat.v1.GraphKeys.GLOBAL_VARIABLES): - """Gets the list of variables, filtered by scope and/or suffix. - - Taken from tensorflow/contrib/framework/python/ops/variables.py. - - Args: - scope: an optional scope for filtering the variables to return. Can be a - variable scope or a string. - suffix: an optional suffix for filtering the variables to return. - collection: in which collection search for. Defaults to - `GraphKeys.GLOBAL_VARIABLES`. - - Returns: - a list of variables in collection with scope and suffix. - """ - if scope is not None and isinstance(scope, tf.compat.v1.VariableScope): - scope = scope.name - if suffix is not None: - if ':' not in suffix: - suffix += ':' - scope = (scope or '') + '.*' + suffix - return tf.compat.v1.get_collection(collection, scope) +def _get_variables( + scope=None, suffix=None, collection=tf.compat.v1.GraphKeys.GLOBAL_VARIABLES +): + """Gets the list of variables, filtered by scope and/or suffix. + + Taken from tensorflow/contrib/framework/python/ops/variables.py. + + Args: + ---- + scope: an optional scope for filtering the variables to return. Can be a + variable scope or a string. + suffix: an optional suffix for filtering the variables to return. + collection: in which collection search for. Defaults to + `GraphKeys.GLOBAL_VARIABLES`. + + Returns: + ------- + a list of variables in collection with scope and suffix. + """ + if scope is not None and isinstance(scope, tf.compat.v1.VariableScope): + scope = scope.name + if suffix is not None: + if ":" not in suffix: + suffix += ":" + scope = (scope or "") + ".*" + suffix + return tf.compat.v1.get_collection(collection, scope) # TODO(b/141936246) Replace this function with a V2-safe way to load models. def _get_variables_to_restore(include=None, exclude=None): - """Gets the list of the variables to restore. - - Taken from tensorflow/contrib/framework/python/ops/variables.py. - - Args: - include: an optional list/tuple of scope strings for filtering which - variables from the VARIABLES collection to include. None would include all - the variables. - exclude: an optional list/tuple of scope strings for filtering which - variables from the VARIABLES collection to exclude. None it would not - exclude any. - - Returns: - a list of variables to restore. - - Raises: - TypeError: include or exclude is provided but is not a list or a tuple. - """ - if include is None: - # Include all variables. 
- vars_to_include = _get_variables() - else: - if not isinstance(include, (list, tuple)): - raise TypeError('include is provided but is not a list or a tuple.') - vars_to_include = [] - for scope in include: - vars_to_include += _get_variables(scope) - vars_to_exclude = set() - if exclude is not None: - if not isinstance(exclude, (list, tuple)): - raise TypeError('exclude is provided but is not a list or a tuple.') - for scope in exclude: - vars_to_exclude |= set(_get_variables(scope)) - # Exclude the variables in vars_to_exclude - return [v for v in vars_to_include if v not in vars_to_exclude] - - -def apply_saved_model(model_dir, inputs, tags, signature_name=None, - output_keys_in_signature=None): - """Applies a SavedModel to some `Tensor`s. - - Applies a SavedModel to `inputs`. The SavedModel is specified with - `model_dir`, `tags` and `signature_name`. Note that the SavedModel will be - converted to an all-constants graph. - - Note: This API can only be used when TF2 is disabled or - `tft_beam.Context.force_tf_compat_v1=True`. - - Args: - model_dir: A path containing a SavedModel. - inputs: A dict whose keys are the names from the input signature and whose - values are `Tensor`s. If there is only one input in the model's input - signature then `inputs` can be a single `Tensor`. - tags: The tags specifying which metagraph to load from the SavedModel. - signature_name: Specify signature of the loaded model. The default value - None can be used if there is only one signature in the MetaGraphDef. - output_keys_in_signature: A list of strings which should be a subset of - the outputs in the signature of the SavedModel. The returned `Tensor`s - will correspond to specified output `Tensor`s, in the same order. The - default value None can be used if there is only one output from - signature. - - Returns: - A `Tensor` or list of `Tensor`s representing the application of the - SavedModel. - - Raises: - ValueError: if - `inputs` is invalid type, or - `signature_name` is None but the SavedModel contains multiple signature, or - `inputs` do not match the signature inputs, or - `output_keys_in_signature` is not a subset of the signature outputs. - """ - # Load model, get graph, inputs and outputs. - loaded_graph = tf.compat.v1.Graph() - loaded_initializer_op_names = [] - - with loaded_graph.as_default(): - sess = tf.compat.v1.Session() - meta_graph = tf.compat.v1.saved_model.load(sess, - export_dir=model_dir, - tags=tags) - loaded_initializer_op_names = [ - op.name for op in tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) - ] - - if signature_name: - signature = meta_graph.signature_def[signature_name] - elif len(meta_graph.signature_def) > 1: - raise ValueError( - 'The SavedModel contains multiple signatures (%r) but signature_name ' - 'was not specified.' % (meta_graph.signature_def.keys(),)) + """Gets the list of the variables to restore. + + Taken from tensorflow/contrib/framework/python/ops/variables.py. + + Args: + ---- + include: an optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to include. None would include all + the variables. + exclude: an optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to exclude. None it would not + exclude any. + + Returns: + ------- + a list of variables to restore. + + Raises: + ------ + TypeError: include or exclude is provided but is not a list or a tuple. + """ + if include is None: + # Include all variables. 
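+        # (With include=None, every variable currently in the
+        # GLOBAL_VARIABLES collection is selected; any `exclude` scopes are
+        # filtered out further below.)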
+ vars_to_include = _get_variables() else: - signature = next(iter(meta_graph.signature_def.values())) - - # Generate mapping from tensors in the graph to the input tensors. - if isinstance(inputs, dict): - if set(signature.inputs.keys()) != set(inputs.keys()): - raise ValueError( - 'The keys in `inputs` (%r) do not match inputs of the SavedModel ' - '(%r).' % (inputs.keys(), signature.inputs.keys())) - input_name_to_tensor_map = { - signature.inputs[key].name: inputs[key] - for key in inputs.keys()} - elif len(signature.inputs) != 1: - raise ValueError( - 'The SavedModel does not have exactly one input (had inputs %r) but ' - '`inputs` was not a dict.' % (signature.inputs.keys(),)) - else: - input_name_to_tensor_map = { - next(iter(signature.inputs.values())).name: inputs - } - - # Get output tensor names. - if output_keys_in_signature: - if not set(output_keys_in_signature) <= set(signature.outputs.keys()): - raise ValueError( - 'output_keys_in_signature (%r) is not a subset of outputs of the ' - 'SavedModel (%r).' - % (output_keys_in_signature, signature.outputs.keys())) - - output_tensor_names = [ - signature.outputs[key].name for key in output_keys_in_signature - ] - output_single_tensor = False - elif len(signature.outputs) != 1: - raise ValueError( - 'The SavedModel does not have exactly one output (had outputs %r) but ' - 'output_keys_in_signature was not specified.' - % (signature.outputs.keys(),)) - else: - output_tensor_names = [next(iter(signature.outputs.values())).name] - output_single_tensor = True - - # Convert_variables_to_constants() requires op name. - output_op_names = [loaded_graph.get_tensor_by_name(tensor_name).op.name - for tensor_name in output_tensor_names] - constant_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( - sess, loaded_graph.as_graph_def(), - output_op_names + loaded_initializer_op_names) - sess.close() - - returned_elements = tf.import_graph_def( - constant_graph_def, - input_map=input_name_to_tensor_map, - return_elements=output_tensor_names + loaded_initializer_op_names) - returned_output_tensors = returned_elements[:len(output_tensor_names)] - returned_initializer_ops = returned_elements[len(output_tensor_names):] - - for initializer_op in returned_initializer_ops: - tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS, - initializer_op) - - if output_single_tensor: - assert len(output_tensor_names) == 1 - return returned_output_tensors[0] - else: - return returned_output_tensors - - -def apply_function_with_checkpoint(fn, inputs, checkpoint, include=None, - exclude=None): - """Applies a tensor-in-tensor-out function with variables to some `Tensor`s. - - Variable values are loaded from the given checkpoint path. Note that the - input_tensor_func, together with the checkpoint, will be converted to an - all-constants graph, so ops requiring graph collections, such as table lookup - (which requires a table init op being added to TABLE_INITIALIZERS collection), - are not supported. - - Note: This API can only be used when TF2 is disabled or - `tft_beam.Context.force_tf_compat_v1=True`. - - Args: - fn: A tensor-in-tensor-out function that may contain variables. - inputs: A list of `Tensor`s to apply `fn` to. - checkpoint: The checkpoint path to load variables from. - include: An optional list/tuple of scope strings for filtering which - variables from the VARIABLES collection to include. If None, all - variables will be included. 
- exclude: An optional list/tuple of scope strings for filtering which - variables from the VARIABLES collection to exclude. If None, no - variables will be excluded. - - Returns: - A `Tensor` or list of `Tensor`s representing the application of `fn`. - - Raises: - ValueError: if the input tensor-in-tensor-out function adds to - TABLE_INITIALIZERS collections. - """ - loaded_graph = tf.compat.v1.Graph() - with loaded_graph.as_default(): - input_placeholders = [ - tf.compat.v1.placeholder( - dtype=tensor.dtype, shape=tensor.shape, name=tensor.op.name) - for tensor in inputs + if not isinstance(include, (list, tuple)): + raise TypeError("include is provided but is not a list or a tuple.") + vars_to_include = [] + for scope in include: + vars_to_include += _get_variables(scope) + vars_to_exclude = set() + if exclude is not None: + if not isinstance(exclude, (list, tuple)): + raise TypeError("exclude is provided but is not a list or a tuple.") + for scope in exclude: + vars_to_exclude |= set(_get_variables(scope)) + # Exclude the variables in vars_to_exclude + return [v for v in vars_to_include if v not in vars_to_exclude] + + +def apply_saved_model( + model_dir, inputs, tags, signature_name=None, output_keys_in_signature=None +): + """Applies a SavedModel to some `Tensor`s. + + Applies a SavedModel to `inputs`. The SavedModel is specified with + `model_dir`, `tags` and `signature_name`. Note that the SavedModel will be + converted to an all-constants graph. + + Note: This API can only be used when TF2 is disabled or + `tft_beam.Context.force_tf_compat_v1=True`. + + Args: + ---- + model_dir: A path containing a SavedModel. + inputs: A dict whose keys are the names from the input signature and whose + values are `Tensor`s. If there is only one input in the model's input + signature then `inputs` can be a single `Tensor`. + tags: The tags specifying which metagraph to load from the SavedModel. + signature_name: Specify signature of the loaded model. The default value + None can be used if there is only one signature in the MetaGraphDef. + output_keys_in_signature: A list of strings which should be a subset of + the outputs in the signature of the SavedModel. The returned `Tensor`s + will correspond to specified output `Tensor`s, in the same order. The + default value None can be used if there is only one output from + signature. + + Returns: + ------- + A `Tensor` or list of `Tensor`s representing the application of the + SavedModel. + + Raises: + ------ + ValueError: if + `inputs` is invalid type, or + `signature_name` is None but the SavedModel contains multiple signature, or + `inputs` do not match the signature inputs, or + `output_keys_in_signature` is not a subset of the signature outputs. + """ + # Load model, get graph, inputs and outputs. + loaded_graph = tf.compat.v1.Graph() + loaded_initializer_op_names = [] + + with loaded_graph.as_default(): + sess = tf.compat.v1.Session() + meta_graph = tf.compat.v1.saved_model.load( + sess, export_dir=model_dir, tags=tags + ) + loaded_initializer_op_names = [ + op.name + for op in tf.compat.v1.get_collection( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS + ) + ] + + if signature_name: + signature = meta_graph.signature_def[signature_name] + elif len(meta_graph.signature_def) > 1: + raise ValueError( + "The SavedModel contains multiple signatures (%r) but signature_name " + "was not specified." 
% (meta_graph.signature_def.keys(),) + ) + else: + signature = next(iter(meta_graph.signature_def.values())) + + # Generate mapping from tensors in the graph to the input tensors. + if isinstance(inputs, dict): + if set(signature.inputs.keys()) != set(inputs.keys()): + raise ValueError( + "The keys in `inputs` (%r) do not match inputs of the SavedModel " + "(%r)." % (inputs.keys(), signature.inputs.keys()) + ) + input_name_to_tensor_map = { + signature.inputs[key].name: inputs[key] for key in inputs.keys() + } + elif len(signature.inputs) != 1: + raise ValueError( + "The SavedModel does not have exactly one input (had inputs %r) but " + "`inputs` was not a dict." % (signature.inputs.keys(),) + ) + else: + input_name_to_tensor_map = {next(iter(signature.inputs.values())).name: inputs} + + # Get output tensor names. + if output_keys_in_signature: + if not set(output_keys_in_signature) <= set(signature.outputs.keys()): + raise ValueError( + "output_keys_in_signature (%r) is not a subset of outputs of the " + "SavedModel (%r)." + % (output_keys_in_signature, signature.outputs.keys()) + ) + + output_tensor_names = [ + signature.outputs[key].name for key in output_keys_in_signature + ] + output_single_tensor = False + elif len(signature.outputs) != 1: + raise ValueError( + "The SavedModel does not have exactly one output (had outputs %r) but " + "output_keys_in_signature was not specified." % (signature.outputs.keys(),) + ) + else: + output_tensor_names = [next(iter(signature.outputs.values())).name] + output_single_tensor = True + + # Convert_variables_to_constants() requires op name. + output_op_names = [ + loaded_graph.get_tensor_by_name(tensor_name).op.name + for tensor_name in output_tensor_names ] - output = fn(*input_placeholders) - if isinstance(output, tf.Tensor): - output_tensors = [output] - output_single_tensor = True + constant_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( + sess, loaded_graph.as_graph_def(), output_op_names + loaded_initializer_op_names + ) + sess.close() + + returned_elements = tf.import_graph_def( + constant_graph_def, + input_map=input_name_to_tensor_map, + return_elements=output_tensor_names + loaded_initializer_op_names, + ) + returned_output_tensors = returned_elements[: len(output_tensor_names)] + returned_initializer_ops = returned_elements[len(output_tensor_names) :] + + for initializer_op in returned_initializer_ops: + tf.compat.v1.add_to_collection( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS, initializer_op + ) + + if output_single_tensor: + assert len(output_tensor_names) == 1 + return returned_output_tensors[0] + else: + return returned_output_tensors + + +def apply_function_with_checkpoint(fn, inputs, checkpoint, include=None, exclude=None): + """Applies a tensor-in-tensor-out function with variables to some `Tensor`s. + + Variable values are loaded from the given checkpoint path. Note that the + input_tensor_func, together with the checkpoint, will be converted to an + all-constants graph, so ops requiring graph collections, such as table lookup + (which requires a table init op being added to TABLE_INITIALIZERS collection), + are not supported. + + Note: This API can only be used when TF2 is disabled or + `tft_beam.Context.force_tf_compat_v1=True`. + + Args: + ---- + fn: A tensor-in-tensor-out function that may contain variables. + inputs: A list of `Tensor`s to apply `fn` to. + checkpoint: The checkpoint path to load variables from. 
+ include: An optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to include. If None, all + variables will be included. + exclude: An optional list/tuple of scope strings for filtering which + variables from the VARIABLES collection to exclude. If None, no + variables will be excluded. + + Returns: + ------- + A `Tensor` or list of `Tensor`s representing the application of `fn`. + + Raises: + ------ + ValueError: if the input tensor-in-tensor-out function adds to + TABLE_INITIALIZERS collections. + """ + loaded_graph = tf.compat.v1.Graph() + with loaded_graph.as_default(): + input_placeholders = [ + tf.compat.v1.placeholder( + dtype=tensor.dtype, shape=tensor.shape, name=tensor.op.name + ) + for tensor in inputs + ] + output = fn(*input_placeholders) + if isinstance(output, tf.Tensor): + output_tensors = [output] + output_single_tensor = True + else: + output_tensors = output + output_single_tensor = False + + # TODO(qimingj/kestert): Copy table initializers to the composed graph. + if tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS): + raise ValueError("Models with table init ops are not supported.") + + vars_to_restore = _get_variables_to_restore(include=include, exclude=exclude) + saver = tf.compat.v1.train.Saver(vars_to_restore) + with tf.compat.v1.Session() as sess: + saver.restore(sess, checkpoint) + output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( + sess, + loaded_graph.as_graph_def(), + [tensor.op.name for tensor in output_tensors], + ) + + input_map = {tensor.name: tensor for tensor in inputs} + output_tensors = tf.import_graph_def( + output_graph_def, + input_map=input_map, + return_elements=[tensor.name for tensor in output_tensors], + ) + + if output_single_tensor: + assert len(output_tensors) == 1 + return output_tensors[0] else: - output_tensors = output - output_single_tensor = False - - # TODO(qimingj/kestert): Copy table initializers to the composed graph. 
- if tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS): - raise ValueError('Models with table init ops are not supported.') - - vars_to_restore = _get_variables_to_restore(include=include, - exclude=exclude) - saver = tf.compat.v1.train.Saver(vars_to_restore) - with tf.compat.v1.Session() as sess: - saver.restore(sess, checkpoint) - output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants( - sess, loaded_graph.as_graph_def(), - [tensor.op.name for tensor in output_tensors]) - - input_map = {tensor.name: tensor for tensor in inputs} - output_tensors = tf.import_graph_def( - output_graph_def, input_map=input_map, - return_elements=[tensor.name for tensor in output_tensors]) - - if output_single_tensor: - assert len(output_tensors) == 1 - return output_tensors[0] - else: - return output_tensors + return output_tensors diff --git a/tensorflow_transform/pretrained_models_test.py b/tensorflow_transform/pretrained_models_test.py index 9d729f6..4ded297 100644 --- a/tensorflow_transform/pretrained_models_test.py +++ b/tensorflow_transform/pretrained_models_test.py @@ -16,143 +16,170 @@ import os import tensorflow as tf + from tensorflow_transform import pretrained_models class PretrainedModelsTest(tf.test.TestCase): + def save_model_with_single_input(self, export_dir): + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="myinput" + ) + initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [5], dtype=tf.int32) + output1 = tf.add(v1, input1, name="myadd") + inputs = {"single_input": input1} + outputs = {"single_output": output1} + signature_def_map = { + "my_signature_single_input": tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + } + sess.run(tf.compat.v1.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map + ) + builder.save(False) - def save_model_with_single_input(self, export_dir): - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='myinput') - initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [5], dtype=tf.int32) - output1 = tf.add(v1, input1, name='myadd') - inputs = {'single_input': input1} - outputs = {'single_output': output1} - signature_def_map = { - 'my_signature_single_input': - tf.compat.v1.saved_model.signature_def_utils - .predict_signature_def(inputs, outputs) - } - sess.run(tf.compat.v1.global_variables_initializer()) - builder.add_meta_graph_and_variables( - sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map) - builder.save(False) + def save_model_with_multi_inputs(self, export_dir): + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="myinput1" + ) + input2 
= tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="myinput2" + ) + input3 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="myinput3" + ) + initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [5], dtype=tf.int32) + o1 = tf.add(v1, input1, name="myadd1") + o2 = tf.add(o1, input2, name="myadd2") + output1 = tf.add(o2, input3, name="myadd3") + inputs = { + "input_name1": input1, + "input_name2": input2, + "input_name3": input3, + } + outputs = {"single_output": output1} + signature_def_map = { + "my_signature_multi_input": tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + } + sess.run(tf.compat.v1.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map + ) + builder.save(False) - def save_model_with_multi_inputs(self, export_dir): - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='myinput1') - input2 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='myinput2') - input3 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='myinput3') - initializer = tf.compat.v1.initializers.constant([1, 2, 3, 4, 5]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [5], dtype=tf.int32) - o1 = tf.add(v1, input1, name='myadd1') - o2 = tf.add(o1, input2, name='myadd2') - output1 = tf.add(o2, input3, name='myadd3') - inputs = {'input_name1': input1, 'input_name2': input2, - 'input_name3': input3} - outputs = {'single_output': output1} - signature_def_map = { - 'my_signature_multi_input': - tf.compat.v1.saved_model.signature_def_utils - .predict_signature_def(inputs, outputs) - } - sess.run(tf.compat.v1.global_variables_initializer()) - builder.add_meta_graph_and_variables( - sess, [tf.saved_model.SERVING], signature_def_map=signature_def_map) - builder.save(False) + def make_tensor_fn_two_inputs(self): + def tensor_fn(input1, input2): + initializer = tf.compat.v1.initializers.constant([1, 2, 3]) + with tf.compat.v1.variable_scope( + "Model", reuse=None, initializer=initializer + ): + v1 = tf.compat.v1.get_variable("v1", [3], dtype=tf.int64) + o1 = tf.add(v1, input1, name="myadda1") + o = tf.subtract(o1, input2, name="myadda2") + return o - def make_tensor_fn_two_inputs(self): - def tensor_fn(input1, input2): - initializer = tf.compat.v1.initializers.constant([1, 2, 3]) - with tf.compat.v1.variable_scope( - 'Model', reuse=None, initializer=initializer): - v1 = tf.compat.v1.get_variable('v1', [3], dtype=tf.int64) - o1 = tf.add(v1, input1, name='myadda1') - o = tf.subtract(o1, input2, name='myadda2') - return o - return tensor_fn + return tensor_fn - def save_checkpoint_with_two_inputs(self, checkpoint_path): - test_tensor_fn = self.make_tensor_fn_two_inputs() - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinputa') - input2 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='myinputb') - test_tensor_fn(input1, input2) - saver = tf.compat.v1.train.Saver() - sess.run(tf.compat.v1.global_variables_initializer()) 
- saver.save(sess, checkpoint_path) + def save_checkpoint_with_two_inputs(self, checkpoint_path): + test_tensor_fn = self.make_tensor_fn_two_inputs() + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinputa" + ) + input2 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="myinputb" + ) + test_tensor_fn(input1, input2) + saver = tf.compat.v1.train.Saver() + sess.run(tf.compat.v1.global_variables_initializer()) + saver.save(sess, checkpoint_path) - def testApplySavedModelSingleInput(self): - export_dir = os.path.join(self.get_temp_dir(), 'single_input') - self.save_model_with_single_input(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input_tensor = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='input_tensor') - output_tensor = pretrained_models.apply_saved_model( - export_dir, input_tensor, [tf.saved_model.SERVING]) - feed_dict = {input_tensor: [2, 2, 2, 2, 2]} - output_value = sess.run(output_tensor, feed_dict=feed_dict) - self.assertAllEqual(output_value, [3, 4, 5, 6, 7]) + def testApplySavedModelSingleInput(self): + export_dir = os.path.join(self.get_temp_dir(), "single_input") + self.save_model_with_single_input(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input_tensor = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="input_tensor" + ) + output_tensor = pretrained_models.apply_saved_model( + export_dir, input_tensor, [tf.saved_model.SERVING] + ) + feed_dict = {input_tensor: [2, 2, 2, 2, 2]} + output_value = sess.run(output_tensor, feed_dict=feed_dict) + self.assertAllEqual(output_value, [3, 4, 5, 6, 7]) - def testApplySavedModelMultiInputs(self): - export_dir = os.path.join(self.get_temp_dir(), 'multi_inputs') - self.save_model_with_multi_inputs(export_dir) - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input_tensor_1 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='input_tensor_1') - input_tensor_2 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='input_tensor_2') - input_tensor_3 = tf.compat.v1.placeholder( - dtype=tf.int32, shape=[5], name='input_tensor_3') - inputs = { - 'input_name1': input_tensor_1, - 'input_name2': input_tensor_2, - 'input_name3': input_tensor_3 - } - output_tensor = pretrained_models.apply_saved_model( - export_dir, - inputs, [tf.saved_model.SERVING], - signature_name='my_signature_multi_input') - feed_dict = {input_tensor_1: [2, 3, 4, 5, 6], - input_tensor_2: [1, 1, 1, 1, 1], - input_tensor_3: [1, 1, 1, 1, -1]} - output_value = sess.run(output_tensor, feed_dict=feed_dict) - self.assertAllEqual(output_value, [5, 7, 9, 11, 11]) + def testApplySavedModelMultiInputs(self): + export_dir = os.path.join(self.get_temp_dir(), "multi_inputs") + self.save_model_with_multi_inputs(export_dir) + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input_tensor_1 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="input_tensor_1" + ) + input_tensor_2 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="input_tensor_2" + ) + input_tensor_3 = tf.compat.v1.placeholder( + dtype=tf.int32, shape=[5], name="input_tensor_3" + ) + inputs = { + "input_name1": input_tensor_1, + "input_name2": input_tensor_2, + "input_name3": input_tensor_3, + } + 
output_tensor = pretrained_models.apply_saved_model( + export_dir, + inputs, + [tf.saved_model.SERVING], + signature_name="my_signature_multi_input", + ) + feed_dict = { + input_tensor_1: [2, 3, 4, 5, 6], + input_tensor_2: [1, 1, 1, 1, 1], + input_tensor_3: [1, 1, 1, 1, -1], + } + output_value = sess.run(output_tensor, feed_dict=feed_dict) + self.assertAllEqual(output_value, [5, 7, 9, 11, 11]) - def testApplyFunctionWithCheckpointTwoInputs(self): - checkpoint = os.path.join(self.get_temp_dir(), 'checkpoint_two') - self.save_checkpoint_with_two_inputs(checkpoint) - with tf.compat.v1.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - input1 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='input1') - input2 = tf.compat.v1.placeholder( - dtype=tf.int64, shape=[3], name='input2') - output_tensor = pretrained_models.apply_function_with_checkpoint( - self.make_tensor_fn_two_inputs(), [input1, input2], checkpoint) - feed_dict = {input1: [1, 2, 3], input2: [3, 2, 1]} - output_value = sess.run(output_tensor, feed_dict=feed_dict) - # [1, 2, 3] + [1, 2, 3] - [3, 2, 1] = [-1, 2, 5] - self.assertAllEqual(output_value, [-1, 2, 5]) + def testApplyFunctionWithCheckpointTwoInputs(self): + checkpoint = os.path.join(self.get_temp_dir(), "checkpoint_two") + self.save_checkpoint_with_two_inputs(checkpoint) + with tf.compat.v1.Graph().as_default() as graph: + with self.test_session(graph=graph) as sess: + input1 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="input1" + ) + input2 = tf.compat.v1.placeholder( + dtype=tf.int64, shape=[3], name="input2" + ) + output_tensor = pretrained_models.apply_function_with_checkpoint( + self.make_tensor_fn_two_inputs(), [input1, input2], checkpoint + ) + feed_dict = {input1: [1, 2, 3], input2: [3, 2, 1]} + output_value = sess.run(output_tensor, feed_dict=feed_dict) + # [1, 2, 3] + [1, 2, 3] - [3, 2, 1] = [-1, 2, 5] + self.assertAllEqual(output_value, [-1, 2, 5]) -if __name__ == '__main__': - tf.test.main() +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_transform/py.typed b/tensorflow_transform/py.typed index 73dc702..8688373 100644 --- a/tensorflow_transform/py.typed +++ b/tensorflow_transform/py.typed @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/tensorflow_transform/py_func/__init__.py b/tensorflow_transform/py_func/__init__.py index c88a163..f93770d 100644 --- a/tensorflow_transform/py_func/__init__.py +++ b/tensorflow_transform/py_func/__init__.py @@ -14,4 +14,6 @@ """Module level imports for tensorflow_transform.py_func.""" from tensorflow_transform.py_func.api import apply_pyfunc -from tensorflow_transform.py_func.pyfunc_helper import register_pyfuncs_from_saved_transform +from tensorflow_transform.py_func.pyfunc_helper import ( + register_pyfuncs_from_saved_transform, +) diff --git a/tensorflow_transform/py_func/api.py b/tensorflow_transform/py_func/api.py index 2fadc81..0111a48 100644 --- a/tensorflow_transform/py_func/api.py +++ b/tensorflow_transform/py_func/api.py @@ -18,46 +18,49 @@ # TODO(b/178867088): Figure out the TF2 compatibility plan for this API. def apply_pyfunc(func, Tout, stateful=True, name=None, *args): # pylint: disable=invalid-name - """Applies a python function to some `Tensor`s. 
- - Applies a python function to some `Tensor`s given by the argument list. The - number of arguments should match the number of inputs to the function. - - This function is for using inside a preprocessing_fn. It is a wrapper around - `tf.py_func`. A function added this way can run in Transform, and during - training when the graph is imported using the `transform_raw_features` method - of the `TFTransformOutput` class. However if the resulting training graph is - serialized and deserialized, then the `tf.py_func` op will not work and will - cause an error. This means that TensorFlow Serving will not be able to serve - this graph. - - The underlying reason for this limited support is that `tf.py_func` ops were - not designed to be serialized since they contain a reference to arbitrary - Python functions. This function pickles those functions and including them in - the graph, and `transform_raw_features` similarly unpickles the functions. - But unpickling requires a Python environment, so there it's not possible to - provide support in non-Python languages for loading such ops. Therefore - loading these ops in libraries such as TensorFlow Serving is not supported. - - Note: This API can only be used when TF2 is disabled or - `tft_beam.Context.force_tf_compat_v1=True`. - - Args: - func: A Python function, which accepts a list of NumPy `ndarray` objects - having element types that match the corresponding `tf.Tensor` objects - in `*args`, and returns a list of `ndarray` objects (or a single - `ndarray`) having element types that match the corresponding values - in `Tout`. - Tout: A list or tuple of tensorflow data types or a single tensorflow data - type if there is only one, indicating what `func` returns. - stateful: (Boolean.) If True, the function should be considered stateful. - If a function is stateless, when given the same input it will return the - same output and have no observable side effects. Optimizations such as - common subexpression elimination are only performed on stateless - operations. - name: A name for the operation (optional). - *args: The list of `Tensor`s to apply the arguments to. - Returns: - A `Tensor` representing the application of the function. - """ - return pyfunc_helper.insert_pyfunc(func, Tout, stateful, name, *args) + """Applies a python function to some `Tensor`s. + + Applies a python function to some `Tensor`s given by the argument list. The + number of arguments should match the number of inputs to the function. + + This function is for using inside a preprocessing_fn. It is a wrapper around + `tf.py_func`. A function added this way can run in Transform, and during + training when the graph is imported using the `transform_raw_features` method + of the `TFTransformOutput` class. However if the resulting training graph is + serialized and deserialized, then the `tf.py_func` op will not work and will + cause an error. This means that TensorFlow Serving will not be able to serve + this graph. + + The underlying reason for this limited support is that `tf.py_func` ops were + not designed to be serialized since they contain a reference to arbitrary + Python functions. This function pickles those functions and including them in + the graph, and `transform_raw_features` similarly unpickles the functions. + But unpickling requires a Python environment, so there it's not possible to + provide support in non-Python languages for loading such ops. Therefore + loading these ops in libraries such as TensorFlow Serving is not supported. 
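+
+    Example (an illustrative sketch; `preprocessing_fn`, `add_one` and the
+    feature names are hypothetical):
+
+        import tensorflow as tf
+        import tensorflow_transform as tft
+
+        def preprocessing_fn(inputs):
+            def add_one(x):
+                # Plain Python/NumPy code executed through tf.py_func.
+                return x + 1
+
+            y = tft.apply_pyfunc(add_one, tf.int64, True, "add_one", inputs["x"])
+            return {"x_plus_one": y}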
+ + Note: This API can only be used when TF2 is disabled or + `tft_beam.Context.force_tf_compat_v1=True`. + + Args: + ---- + func: A Python function, which accepts a list of NumPy `ndarray` objects + having element types that match the corresponding `tf.Tensor` objects + in `*args`, and returns a list of `ndarray` objects (or a single + `ndarray`) having element types that match the corresponding values + in `Tout`. + Tout: A list or tuple of tensorflow data types or a single tensorflow data + type if there is only one, indicating what `func` returns. + stateful: (Boolean.) If True, the function should be considered stateful. + If a function is stateless, when given the same input it will return the + same output and have no observable side effects. Optimizations such as + common subexpression elimination are only performed on stateless + operations. + name: A name for the operation (optional). + *args: The list of `Tensor`s to apply the arguments to. + + Returns: + ------- + A `Tensor` representing the application of the function. + """ + return pyfunc_helper.insert_pyfunc(func, Tout, stateful, name, *args) diff --git a/tensorflow_transform/py_func/pyfunc_helper.py b/tensorflow_transform/py_func/pyfunc_helper.py index c596d48..444ffd6 100644 --- a/tensorflow_transform/py_func/pyfunc_helper.py +++ b/tensorflow_transform/py_func/pyfunc_helper.py @@ -15,50 +15,54 @@ import dill import tensorflow as tf + +# pylint: disable=g-direct-tensorflow-import +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import ops from tfx_bsl import beam as tfx_bsl_beam + # TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` # once the Spark issue is resolved. from tfx_bsl.types import tfx_namedtuple -# pylint: disable=g-direct-tensorflow-import -from tensorflow.core.framework import attr_value_pb2 -from tensorflow.python.framework import ops # pylint: enable=g-direct-tensorflow-import -_PYFUNC_COLLECTION_KEY = 'pyfuncs' +_PYFUNC_COLLECTION_KEY = "pyfuncs" tfx_bsl_beam.fix_code_type_pickling() -class _PyFuncDef(tfx_namedtuple.namedtuple('_PyFuncDef', ['token', 'func'])): - """An internal wrapper around tuple(token, func). +class _PyFuncDef(tfx_namedtuple.namedtuple("_PyFuncDef", ["token", "func"])): + """An internal wrapper around tuple(token, func). - `token` can be either a single token (if the py_func returns a tensor), or a - list of tokens (if the py_func returns a list of tensors). + `token` can be either a single token (if the py_func returns a tensor), or a + list of tokens (if the py_func returns a list of tensors). - The main purpose of this class is to provides the two methods: - `from_proto` and `to_proto` that enable storing tuple objects in the graph's - collections as proto objects. - """ - __slots__ = () + The main purpose of this class is to provides the two methods: + `from_proto` and `to_proto` that enable storing tuple objects in the graph's + collections as proto objects. 
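+
+    (Serialization detail: `to_proto` pickles the whole namedtuple with `dill`
+    into an `AttrValue.s` bytes field, and `from_proto` / `from_proto_string`
+    unpickle it again when the graph collection is re-imported.)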
+ """ - @staticmethod - def from_proto(attr_value, import_scope=None): - del import_scope # Unused - return dill.loads(attr_value.s) + __slots__ = () - @staticmethod - def from_proto_string(proto_str, import_scope=None): - del import_scope # Unused - attr_value = attr_value_pb2.AttrValue() - attr_value.ParseFromString(proto_str) - return _PyFuncDef.from_proto(attr_value) + @staticmethod + def from_proto(attr_value, import_scope=None): + del import_scope # Unused + return dill.loads(attr_value.s) + + @staticmethod + def from_proto_string(proto_str, import_scope=None): + del import_scope # Unused + attr_value = attr_value_pb2.AttrValue() + attr_value.ParseFromString(proto_str) + return _PyFuncDef.from_proto(attr_value) + + def to_proto(self, export_scope=None): + del export_scope # Unused + result = attr_value_pb2.AttrValue() + result.s = dill.dumps(self) + return result - def to_proto(self, export_scope=None): - del export_scope # Unused - result = attr_value_pb2.AttrValue() - result.s = dill.dumps(self) - return result # Register the pyfuncs collection to use `AttrValue` proto type. # The proto object stored in the graph collection will contain the pickled value @@ -66,102 +70,108 @@ def to_proto(self, export_scope=None): # Note that `AttrValue` is used here only as a convenient placeholder for a # string, and does not represent the actual attributes of an `op` as in the # usual case. -ops.register_proto_function(_PYFUNC_COLLECTION_KEY, - proto_type=attr_value_pb2.AttrValue, - to_proto=_PyFuncDef.to_proto, - from_proto=_PyFuncDef.from_proto) +ops.register_proto_function( + _PYFUNC_COLLECTION_KEY, + proto_type=attr_value_pb2.AttrValue, + to_proto=_PyFuncDef.to_proto, + from_proto=_PyFuncDef.from_proto, +) def insert_pyfunc(func, Tout, stateful, name, *args): # pylint: disable=invalid-name - """Calls tf.py_func and inserts the `func` in the internal registry.""" - result = tf.compat.v1.py_func( - func, inp=list(args), Tout=Tout, stateful=stateful, name=name) - # A py_func can either return a tensor or a list. Since we care only about the - # op, it doesn't matter which result we take. - if isinstance(result, list): - first_result = result[0] if result else None - else: - first_result = result - if first_result is None: - raise ValueError('func must return a tensor or list of tensors') - token = first_result.op.node_def.attr['token'].s - tf.compat.v1.add_to_collection(_PYFUNC_COLLECTION_KEY, - _PyFuncDef(token, func)) - return result + """Calls tf.py_func and inserts the `func` in the internal registry.""" + result = tf.compat.v1.py_func( + func, inp=list(args), Tout=Tout, stateful=stateful, name=name + ) + # A py_func can either return a tensor or a list. Since we care only about the + # op, it doesn't matter which result we take. + if isinstance(result, list): + first_result = result[0] if result else None + else: + first_result = result + if first_result is None: + raise ValueError("func must return a tensor or list of tensors") + token = first_result.op.node_def.attr["token"].s + tf.compat.v1.add_to_collection(_PYFUNC_COLLECTION_KEY, _PyFuncDef(token, func)) + return result def register_pyfuncs_from_saved_transform(graph, meta_graph, loaded_in_tf2): - """Registers `py_func`s in the MetaGraphDef. - - Takes the picked `py_func`s stored in the MetaGraphDef and adds them to the - graph. Registered `py_func`s are referred to internally by the token - attribute of the `py_func` op. 
We first create some arbitrary ops which - are not used, but which result in the pickled functions stored in the - MetaGraphDef being registered. We then take the tokens of these newly - registered functions, and remap the tokens in the MetaGraphDef to contain - the new tokens for each function (this remapping is required since we cannot - specify what token should be used to register a function). - - Args: - graph: The tf.Graph into which the meta_graph_def will be imported. - meta_graph: The MetaGraphDef containing the `py_func`s. All the `py_func` - ops in the graph will be modified in-place to have their token point to - the newly regsitered function. - loaded_in_tf2: A boolean indicating whether the saved transform is being - re-loaded in TF1 or TF2. - - Returns: - Modified graph_def if pyfuncs were found, else None. - - Raises: - ValueError if an unregistered pyfunc is encountered in `graph`. - """ - if _PYFUNC_COLLECTION_KEY not in meta_graph.collection_def: - return None - - # TODO(b/35929054) to enable it in TF itself. Once supported, - # we should refactor this code to remove extra work for pickling and - # re-registering of the py_funcs. - pyfuncs_collection = meta_graph.collection_def[_PYFUNC_COLLECTION_KEY] - - new_tokens_by_old_token = {} - with graph.as_default(): - for func_def_str in pyfuncs_collection.bytes_list.value: - func_def = _PyFuncDef.from_proto_string(func_def_str) - # Re-insert the original python function into the default graph. - # The operation itself in the graph does not matter (hence the dummy - # values for name, Tout, and stateful). This is done only to reinsert - # the function body in the internal TF's function registry. - # TODO(b/123241062): We should even remove this op from the graph if - # possible. - func_temp_name = func_def.token + b'_temp' - output_tensor = insert_pyfunc( - func_def.func, tf.float32, False, func_temp_name) - # Store the token associated with the function associated with the call - # to tf.py_func. - token = output_tensor.op.get_attr('token') - new_tokens_by_old_token[func_def.token] = token - - if loaded_in_tf2: - graph_def = graph.as_graph_def() - # Since we are updating the GraphDef of the graph in whose context pyfuncs - # were re-inserted, new tokens will also be present. - expected_tokens_in_graph_def = ( - list(new_tokens_by_old_token.keys()) + - list(new_tokens_by_old_token.values())) - else: - graph_def = meta_graph.graph_def - expected_tokens_in_graph_def = new_tokens_by_old_token.keys() - # Swap the old token stored for the function with the new one, if there are - # any tokens to change. - if new_tokens_by_old_token: - for node in graph_def.node: - if node.op == 'PyFunc' or node.op == 'PyFuncStateless': - token = node.attr['token'] - new_token = new_tokens_by_old_token.get(token.s, None) - if new_token is not None: - token.s = new_token - else: - if token.s not in expected_tokens_in_graph_def: - raise ValueError(f'Function: {node.name} was not registered') - return graph_def + """Registers `py_func`s in the MetaGraphDef. + + Takes the picked `py_func`s stored in the MetaGraphDef and adds them to the + graph. Registered `py_func`s are referred to internally by the token + attribute of the `py_func` op. We first create some arbitrary ops which + are not used, but which result in the pickled functions stored in the + MetaGraphDef being registered. 
We then take the tokens of these newly + registered functions, and remap the tokens in the MetaGraphDef to contain + the new tokens for each function (this remapping is required since we cannot + specify what token should be used to register a function). + + Args: + ---- + graph: The tf.Graph into which the meta_graph_def will be imported. + meta_graph: The MetaGraphDef containing the `py_func`s. All the `py_func` + ops in the graph will be modified in-place to have their token point to + the newly regsitered function. + loaded_in_tf2: A boolean indicating whether the saved transform is being + re-loaded in TF1 or TF2. + + Returns: + ------- + Modified graph_def if pyfuncs were found, else None. + + Raises: + ------ + ValueError if an unregistered pyfunc is encountered in `graph`. + """ + if _PYFUNC_COLLECTION_KEY not in meta_graph.collection_def: + return None + + # TODO(b/35929054) to enable it in TF itself. Once supported, + # we should refactor this code to remove extra work for pickling and + # re-registering of the py_funcs. + pyfuncs_collection = meta_graph.collection_def[_PYFUNC_COLLECTION_KEY] + + new_tokens_by_old_token = {} + with graph.as_default(): + for func_def_str in pyfuncs_collection.bytes_list.value: + func_def = _PyFuncDef.from_proto_string(func_def_str) + # Re-insert the original python function into the default graph. + # The operation itself in the graph does not matter (hence the dummy + # values for name, Tout, and stateful). This is done only to reinsert + # the function body in the internal TF's function registry. + # TODO(b/123241062): We should even remove this op from the graph if + # possible. + func_temp_name = func_def.token + b"_temp" + output_tensor = insert_pyfunc( + func_def.func, tf.float32, False, func_temp_name + ) + # Store the token associated with the function associated with the call + # to tf.py_func. + token = output_tensor.op.get_attr("token") + new_tokens_by_old_token[func_def.token] = token + + if loaded_in_tf2: + graph_def = graph.as_graph_def() + # Since we are updating the GraphDef of the graph in whose context pyfuncs + # were re-inserted, new tokens will also be present. + expected_tokens_in_graph_def = list(new_tokens_by_old_token.keys()) + list( + new_tokens_by_old_token.values() + ) + else: + graph_def = meta_graph.graph_def + expected_tokens_in_graph_def = new_tokens_by_old_token.keys() + # Swap the old token stored for the function with the new one, if there are + # any tokens to change. + if new_tokens_by_old_token: + for node in graph_def.node: + if node.op == "PyFunc" or node.op == "PyFuncStateless": + token = node.attr["token"] + new_token = new_tokens_by_old_token.get(token.s, None) + if new_token is not None: + token.s = new_token + else: + if token.s not in expected_tokens_in_graph_def: + raise ValueError(f"Function: {node.name} was not registered") + return graph_def diff --git a/tensorflow_transform/saved/constants.py b/tensorflow_transform/saved/constants.py index e215850..d549e79 100644 --- a/tensorflow_transform/saved/constants.py +++ b/tensorflow_transform/saved/constants.py @@ -14,6 +14,6 @@ """Constants for tf.Transform SavedModels.""" # TODO(b/123243166) eventually migrate this constant to tag_constants.TRANSFORM. 
-TRANSFORM_TAG = 'transform' +TRANSFORM_TAG = "transform" -TRANSFORM_SIGNATURE = 'transform_signature' +TRANSFORM_SIGNATURE = "transform_signature" diff --git a/tensorflow_transform/saved/saved_model_loader.py b/tensorflow_transform/saved/saved_model_loader.py index e008265..725baae 100644 --- a/tensorflow_transform/saved/saved_model_loader.py +++ b/tensorflow_transform/saved/saved_model_loader.py @@ -13,72 +13,83 @@ # limitations under the License. """Utility functions to build input_fns for use with tf.Learn.""" +from tensorflow.python.saved_model import ( + loader_impl, # pylint: disable=g-direct-tensorflow-import +) + from tensorflow_transform.saved import constants -from tensorflow.python.saved_model import loader_impl # pylint: disable=g-direct-tensorflow-import def parse_saved_model(saved_model_dir): - return loader_impl.parse_saved_model(saved_model_dir) + return loader_impl.parse_saved_model(saved_model_dir) def _choose_meta_graph_def_internal(saved_model, tags): - """Find a MetaGraphDef within the SavedModel with exactly matching tags. - - Args: - saved_model: A `SavedModel` protocol buffer. - tags: Set of string tags to identify the required MetaGraphDef. These should - correspond to the tags used when saving the variables using the - SavedModel `save()` API. - Returns: - The chosen `MetaGraphDef` protocol buffer. This can be used to further - extract signature-defs, collection-defs, etc. If tags cannot be found, - returns None. - """ - result = None - for meta_graph_def in saved_model.meta_graphs: - if set(meta_graph_def.meta_info_def.tags) == set(tags): - result = meta_graph_def - break - - return result + """Find a MetaGraphDef within the SavedModel with exactly matching tags. + + Args: + ---- + saved_model: A `SavedModel` protocol buffer. + tags: Set of string tags to identify the required MetaGraphDef. These should + correspond to the tags used when saving the variables using the + SavedModel `save()` API. + + Returns: + ------- + The chosen `MetaGraphDef` protocol buffer. This can be used to further + extract signature-defs, collection-defs, etc. If tags cannot be found, + returns None. + """ + result = None + for meta_graph_def in saved_model.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set(tags): + result = meta_graph_def + break + + return result def choose_meta_graph_def(saved_model): - """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. + """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. - Args: - saved_model: A `SavedModel` protocol buffer. + Args: + ---- + saved_model: A `SavedModel` protocol buffer. - Returns: - The chosen `MetaGraphDef` protocol buffer. This can be used to further - extract signature-defs, collection-defs, etc. If tags cannot be found, - returns None. - """ - return _choose_meta_graph_def_internal(saved_model, [constants.TRANSFORM_TAG]) + Returns: + ------- + The chosen `MetaGraphDef` protocol buffer. This can be used to further + extract signature-defs, collection-defs, etc. If tags cannot be found, + returns None. + """ + return _choose_meta_graph_def_internal(saved_model, [constants.TRANSFORM_TAG]) def choose_meta_graph_def_and_raise(saved_model): - """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. + """Find a MetaGraphDef in the SavedModel with tag `constants.TRANSFORM_TAG`. - Args: - saved_model: A `SavedModel` protocol buffer. + Args: + ---- + saved_model: A `SavedModel` protocol buffer. - Returns: - The chosen `MetaGraphDef` protocol buffer. 
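
A small illustration of the exact-tag matching above, using a hand-built SavedModel proto: a MetaGraphDef tagged only 'serve' does not match the requested 'transform' tag, so the lookup returns None.

    from tensorflow.core.protobuf import saved_model_pb2

    sm = saved_model_pb2.SavedModel()
    sm.meta_graphs.add().meta_info_def.tags.append("serve")
    assert _choose_meta_graph_def_internal(sm, ["transform"]) is None
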
This can be used to further - extract signature-defs, collection-defs, etc. + Returns: + ------- + The chosen `MetaGraphDef` protocol buffer. This can be used to further + extract signature-defs, collection-defs, etc. - Raises: - RuntimeError: MetaGraphDef associated with the tags cannot be found. - """ - result = choose_meta_graph_def(saved_model) + Raises: + ------ + RuntimeError: MetaGraphDef associated with the tags cannot be found. + """ + result = choose_meta_graph_def(saved_model) - if result is None: - raise RuntimeError( - 'MetaGraphDef associated with tags {} could not be found in SavedModel' - .format(constants.TRANSFORM_TAG)) + if result is None: + raise RuntimeError( + f"MetaGraphDef associated with tags {constants.TRANSFORM_TAG} could not be found in SavedModel" + ) - return result + return result def get_asset_tensors(saved_model_dir, meta_graph_def_to_load): - return loader_impl.get_asset_tensors(saved_model_dir, meta_graph_def_to_load) + return loader_impl.get_asset_tensors(saved_model_dir, meta_graph_def_to_load) diff --git a/tensorflow_transform/saved/saved_model_loader_test.py b/tensorflow_transform/saved/saved_model_loader_test.py index 70b8f65..e973fe1 100644 --- a/tensorflow_transform/saved/saved_model_loader_test.py +++ b/tensorflow_transform/saved/saved_model_loader_test.py @@ -15,36 +15,36 @@ import os import tempfile +import unittest import tensorflow as tf from tensorflow_transform.saved import saved_transform_io -import unittest - def _create_test_saved_model_dir(): - export_path = os.path.join(tempfile.mkdtemp(), 'export') + export_path = os.path.join(tempfile.mkdtemp(), "export") - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_float = tf.compat.v1.placeholder(tf.float32, shape=[1]) - output = (input_float - 2.0) / 5.0 - inputs = {'x': input_float} - outputs = {'x_scaled': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_float = tf.compat.v1.placeholder(tf.float32, shape=[1]) + output = (input_float - 2.0) / 5.0 + inputs = {"x": input_float} + outputs = {"x_scaled": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) - return export_path + return export_path class SavedModelLoaderTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._test_saved_model_dir = _create_test_saved_model_dir() - @classmethod - def setUpClass(cls): - cls._test_saved_model_dir = _create_test_saved_model_dir() + # This class has no tests at the moment. - # This class has no tests at the moment. 
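
A short sketch tying the loader functions above to the `transform` / `transform_signature` constants, reusing the `_create_test_saved_model_dir` helper defined in the test just above; the expected signature keys come from that toy model.

    from tensorflow_transform.saved import constants, saved_model_loader

    export_dir = _create_test_saved_model_dir()  # helper from the test above
    saved_model = saved_model_loader.parse_saved_model(export_dir)
    meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise(saved_model)
    assert constants.TRANSFORM_TAG in meta_graph_def.meta_info_def.tags

    signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE]
    print(sorted(signature.inputs))   # ['x']
    print(sorted(signature.outputs))  # ['x_scaled']
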
-if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/saved/saved_transform_io.py b/tensorflow_transform/saved/saved_transform_io.py index 51b98fc..7d7f690 100644 --- a/tensorflow_transform/saved/saved_transform_io.py +++ b/tensorflow_transform/saved/saved_transform_io.py @@ -18,442 +18,479 @@ import re import tensorflow as tf -from tensorflow_transform.py_func import pyfunc_helper -from tensorflow_transform.saved import constants -from tensorflow_transform.saved import saved_model_loader + # pylint: disable=g-direct-tensorflow-import from tensorflow.core.protobuf import struct_pb2 from tensorflow.python.framework import ops from tensorflow.python.saved_model import nested_structure_coder from tensorflow.python.training import saver as tf_saver + +from tensorflow_transform.py_func import pyfunc_helper +from tensorflow_transform.saved import constants, saved_model_loader + # pylint: enable=g-direct-tensorflow-import _MANGLED_TENSOR_NAME_RE = re.compile( - r'(.*)\$(indices|values|dense_shape|dense_tensor)$') + r"(.*)\$(indices|values|dense_shape|dense_tensor)$" +) def _update_legacy_signature(signature): - """Update a legacy name-mangled signature in-place. - - Note this code will not work if there are clashes between the old and new - names, e.g. if x$dense_tensor$dense_tensor and x$dense_tensor are both - features, but this is an edge case that we do not expect to ever happen. - - Args: - signature: A SignatureDef. - """ - for tensor_info_map in [signature.inputs, signature.outputs]: - # It is necessary to make a copy of tensor_info_map.items() since we need to - # modify tensor_info_map while iterating it. - for original_name, original_tensor_info in list(tensor_info_map.items()): - match = _MANGLED_TENSOR_NAME_RE.match(original_name) - if not match: - continue - tf.compat.v1.logging.warn( - 'Converting feature %s from legacy signature. New models will ' - 'be written without name-mangling in the signature', original_name) - name = match.group(1) - if name == 'dense_shape': - assert name not in tensor_info_map - else: - assert (name not in tensor_info_map or - tensor_info_map[name].WhichOneof('encoding') == 'coo_sparse') - new_tensor_info = tensor_info_map[name] - original_tensor_type = match.group(2) - if original_tensor_type == 'indices': - new_tensor_info.coo_sparse.indices_tensor_name = ( - original_tensor_info.name) - elif original_tensor_type == 'values': - new_tensor_info.dtype = original_tensor_info.dtype - new_tensor_info.coo_sparse.values_tensor_name = ( - original_tensor_info.name) - elif original_tensor_type == 'dense_shape': - new_tensor_info.coo_sparse.dense_shape_tensor_name = ( - original_tensor_info.name) - else: - new_tensor_info.CopyFrom(tensor_info_map[original_name]) - del tensor_info_map[original_name] + """Update a legacy name-mangled signature in-place. + + Note this code will not work if there are clashes between the old and new + names, e.g. if x$dense_tensor$dense_tensor and x$dense_tensor are both + features, but this is an edge case that we do not expect to ever happen. + + Args: + ---- + signature: A SignatureDef. + """ + for tensor_info_map in [signature.inputs, signature.outputs]: + # It is necessary to make a copy of tensor_info_map.items() since we need to + # modify tensor_info_map while iterating it. 
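
A quick illustration of the legacy name mangling handled by `_update_legacy_signature`, with made-up feature names: the `x$indices` / `x$values` / `x$dense_shape` entries collapse into a single `x` entry encoded as `coo_sparse` in the rewritten signature.

    import re

    # Mirrors _MANGLED_TENSOR_NAME_RE defined above; feature names are made up.
    mangled_re = re.compile(r"(.*)\$(indices|values|dense_shape|dense_tensor)$")
    m = mangled_re.match("x$dense_shape")
    assert m.group(1) == "x" and m.group(2) == "dense_shape"
    assert mangled_re.match("plain_feature") is None
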
+ for original_name, original_tensor_info in list(tensor_info_map.items()): + match = _MANGLED_TENSOR_NAME_RE.match(original_name) + if not match: + continue + tf.compat.v1.logging.warn( + "Converting feature %s from legacy signature. New models will " + "be written without name-mangling in the signature", + original_name, + ) + name = match.group(1) + if name == "dense_shape": + assert name not in tensor_info_map + else: + assert ( + name not in tensor_info_map + or tensor_info_map[name].WhichOneof("encoding") == "coo_sparse" + ) + new_tensor_info = tensor_info_map[name] + original_tensor_type = match.group(2) + if original_tensor_type == "indices": + new_tensor_info.coo_sparse.indices_tensor_name = ( + original_tensor_info.name + ) + elif original_tensor_type == "values": + new_tensor_info.dtype = original_tensor_info.dtype + new_tensor_info.coo_sparse.values_tensor_name = ( + original_tensor_info.name + ) + elif original_tensor_type == "dense_shape": + new_tensor_info.coo_sparse.dense_shape_tensor_name = ( + original_tensor_info.name + ) + else: + new_tensor_info.CopyFrom(tensor_info_map[original_name]) + del tensor_info_map[original_name] def _load_transform_saved_model(transform_savedmodel_dir): - """Load a SavedModel representing a transform function from disk. + """Load a SavedModel representing a transform function from disk. - Args: - transform_savedmodel_dir: a SavedModel directory. + Args: + ---- + transform_savedmodel_dir: a SavedModel directory. - Returns: - A tuple with a `MetaGraphDef` proto, the input and outputs of a - `SignatureDef` proto, and a dict from tensor names to absolute paths for - asset filepaths. - """ - saved_model = saved_model_loader.parse_saved_model( - transform_savedmodel_dir) - meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise( - saved_model) + Returns: + ------- + A tuple with a `MetaGraphDef` proto, the input and outputs of a + `SignatureDef` proto, and a dict from tensor names to absolute paths for + asset filepaths. + """ + saved_model = saved_model_loader.parse_saved_model(transform_savedmodel_dir) + meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise(saved_model) - signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE] - # The following code handles models produced prior to CL/200123875. These - # models used a non-standard naming convention for features in order to - # support SparseTensor. - # TODO(b/34253951): Remove the following code once we no longer want to - # support the legacy formats. - _update_legacy_signature(signature) + signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE] + # The following code handles models produced prior to CL/200123875. These + # models used a non-standard naming convention for features in order to + # support SparseTensor. + # TODO(b/34253951): Remove the following code once we no longer want to + # support the legacy formats. + _update_legacy_signature(signature) - # maps name to TensorInfo - input_signature = signature.inputs - output_signature = signature.outputs + # maps name to TensorInfo + input_signature = signature.inputs + output_signature = signature.outputs - # asset_path_dict is {string: string}, mapping tensor names to absolute paths. - asset_path_dict = saved_model_loader.get_asset_tensors( - transform_savedmodel_dir, meta_graph_def) + # asset_path_dict is {string: string}, mapping tensor names to absolute paths. 
+ asset_path_dict = saved_model_loader.get_asset_tensors( + transform_savedmodel_dir, meta_graph_def + ) - return meta_graph_def, input_signature, output_signature, asset_path_dict + return meta_graph_def, input_signature, output_signature, asset_path_dict def _expand_input_map(logical_input_map, input_signature): - """Expands user provided inputs to component tensors in the graph. - - The user specified `logical_input_map` contains mappings from logical feature - names to `Tensor`s or `CompositeTensor`s. These are expanded into mappings - from component tensor names in the graph to their corresponding component - tensor value. - - Args: - logical_input_map: a dict of logical name to Tensor. The logical names must - be a subset of those in the input signature of the transform graph, and - the corresponding Tensors must have the expected types and shapes. - input_signature: The inputs of a `SignatureDef` proto for the graph to be - imported. - - Returns: - A map from tensor names in `input_signature` to the tensors - specified in `logical_input_map`. - """ - result = {} - for logical_name, replacement in logical_input_map.items(): - tensor_info = input_signature[logical_name] - encoding = tensor_info.WhichOneof('encoding') - if encoding == 'coo_sparse': - assert isinstance(replacement, tf.SparseTensor), logical_name - result[tensor_info.coo_sparse.indices_tensor_name] = replacement.indices - result[tensor_info.coo_sparse.values_tensor_name] = replacement.values - result[tensor_info.coo_sparse.dense_shape_tensor_name] = ( - replacement.dense_shape) - elif encoding == 'composite_tensor': - component_infos = tensor_info.composite_tensor.components - component_tensors = tf.nest.flatten(replacement, expand_composites=True) - for (info, tensor) in zip(component_infos, component_tensors): - result[info.name] = tensor - elif encoding == 'name': - result[tensor_info.name] = replacement - else: - raise ValueError('Unsupported TensorInfo encoding %s' % encoding) - return result - - -_PARTITIONED_VARIABLE_NAME_RE = re.compile(r'^(.*)/part_(\d*)$') + """Expands user provided inputs to component tensors in the graph. + + The user specified `logical_input_map` contains mappings from logical feature + names to `Tensor`s or `CompositeTensor`s. These are expanded into mappings + from component tensor names in the graph to their corresponding component + tensor value. + + Args: + ---- + logical_input_map: a dict of logical name to Tensor. The logical names must + be a subset of those in the input signature of the transform graph, and + the corresponding Tensors must have the expected types and shapes. + input_signature: The inputs of a `SignatureDef` proto for the graph to be + imported. + + Returns: + ------- + A map from tensor names in `input_signature` to the tensors + specified in `logical_input_map`. 
+ """ + result = {} + for logical_name, replacement in logical_input_map.items(): + tensor_info = input_signature[logical_name] + encoding = tensor_info.WhichOneof("encoding") + if encoding == "coo_sparse": + assert isinstance(replacement, tf.SparseTensor), logical_name + result[tensor_info.coo_sparse.indices_tensor_name] = replacement.indices + result[tensor_info.coo_sparse.values_tensor_name] = replacement.values + result[tensor_info.coo_sparse.dense_shape_tensor_name] = ( + replacement.dense_shape + ) + elif encoding == "composite_tensor": + component_infos = tensor_info.composite_tensor.components + component_tensors = tf.nest.flatten(replacement, expand_composites=True) + for info, tensor in zip(component_infos, component_tensors): + result[info.name] = tensor + elif encoding == "name": + result[tensor_info.name] = replacement + else: + raise ValueError("Unsupported TensorInfo encoding %s" % encoding) + return result + + +_PARTITIONED_VARIABLE_NAME_RE = re.compile(r"^(.*)/part_(\d*)$") # TODO(b/159982957): Replace this with a mechinism that registers any custom op. def _maybe_register_addon_ops(): - """Optionally import libraries to register additional TF ops.""" - - def _try_import(name): - try: - importlib.import_module(name) - except (ImportError, tf.errors.NotFoundError): - tf.compat.v1.logging.info('{} is not available.'.format(name)) - pass - - # LINT.IfChange - _try_import('struct2tensor') - _try_import('tensorflow_decision_forests') - _try_import('tensorflow_text') - # LINT.ThenChange(tensorflow_model_analysis/utils/model_util.py) - - -def _partially_apply_saved_transform_impl(saved_model_dir, - logical_input_map, - tensor_replacement_map=None): - """Shared code for partially_apply_saved_transform and fetch_tensor_values. - - This adds nodes to a graph that already contains Tensors representing the - inputs. These input Tensors may be placeholders that will be fed when the - graph is executed, or may be the outputs of some Ops. Most typically, the - input Tensors are reading and/or parsing Ops, but they could be anything-- - including the outputs of a prior application of this function using another - transform graph. - - This function operates on the default Graph in the default Session, and so - must be called within a context where these are provided. - - Args: - saved_model_dir: A SavedModel directory providing a transform - graph. The MetaGraphDef and signature are selected from the SavedModel - using keys defined in `../constants.py` ('transform' and - 'transform_signature', respectively). - logical_input_map: a dict of logical name to Tensor. The logical names must - be a subset of those in the input signature of the transform graph, and - the corresponding Tensors must have the expected types and shapes. - tensor_replacement_map: a dict of tensor names to `Tensors`. - - Returns: - A tuple of (unbound_inputs, outputs, assets_dict) where - * unbound_inputs is a dict of logical name to Tensors that are yet to be - mapped or fed - * outputs is a dict of logical name to Tensor, as provided by the output - signature of the transform graph - - Raises: - ValueError: if the provided input_tensors dict has keys that are not part - of the input signature, or any of the provided inputs have the wrong - type or shape. - RuntimeError: if there is no default graph available to which to apply the - transform. 
- """ - _maybe_register_addon_ops() - graph = tf.compat.v1.get_default_graph() - if graph is None: - raise RuntimeError('apply_saved_transform() requires a default graph.') - - meta_graph_def, input_signature, output_signature, asset_path_dict = ( - _load_transform_saved_model(saved_model_dir)) - asset_tensor_dict = { - k: tf.convert_to_tensor(v) for k, v in asset_path_dict.items() - } - - # Check for inputs that were not part of the input signature. - unexpected_inputs = ( - set(logical_input_map.keys()) - set(input_signature.keys())) - if unexpected_inputs: - raise ValueError('Unexpected inputs ' - 'to transform: {}'.format(unexpected_inputs)) - - # Create a map from tensor names in the graph to be imported, to the tensors - # specified in `input_tensors`. - input_map = _expand_input_map(logical_input_map, input_signature) - - input_map.update(asset_tensor_dict) - if tensor_replacement_map: - input_map.update(tensor_replacement_map) - - # unique_name may produce e.g. transform_5. The result has no trailing slash. - scope = graph.unique_name('transform', mark_as_used=False) - - # unique_name returns an "absolute" name while we want a name relative to the - # current scope. Therefore, we check if the current name stack is non-empty, - # and if so, strip out the existing name scope. - if graph.get_name_scope(): - current_name_scope = graph.get_name_scope() + '/' - assert scope.startswith(current_name_scope) - import_scope = scope[len(current_name_scope):] - else: - import_scope = scope - - # If the saved_model contained py_funcs, will reinsert them in the graph - # here and update their associated token in the model. - _ = pyfunc_helper.register_pyfuncs_from_saved_transform( - graph, meta_graph_def, loaded_in_tf2=False) - - # Save the ASSET_FILEPATHS before importing the MetaGraphDef - current_assets = graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS) - - # Warn user if meta_graph_def has saved variables - if tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def: - trainable_vars = meta_graph_def.collection_def[ - tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES].bytes_list.value - if trainable_vars: - raise ValueError( - 'The SavedModel contained trainable variables {}. Because this ' - 'function is typically called in the input_fn, trainable variables ' - 'are disallowed'.format(trainable_vars)) - - # Load the transform graph, applying it to existing Tensors via input_map. - # Throws ValueError if the input_map gives mismatched types or shapes. - saver = tf_saver.import_meta_graph(meta_graph_def, - import_scope=import_scope, - input_map=input_map) - - # Wipe out AssetFileDef collection; it is obsolete after loading - graph.clear_collection(tf.saved_model.ASSETS_KEY) - - # The import may have added Tensors to the ASSET_FILEPATHS collection that - # were substituted via input_map. To account for this, wipe out the - # collection, restore the preexisting collection values, and then write in - # the new substituted Tensors. 
- graph.clear_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS) - for asset_path_tensor in current_assets: - graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - asset_path_tensor) - for asset_path_tensor in asset_tensor_dict.values(): - graph.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, - asset_path_tensor) - - if saver: - checkpoint_path = os.path.join( - tf.compat.as_bytes(saved_model_dir), - tf.compat.as_bytes(tf.saved_model.VARIABLES_DIRECTORY), - tf.compat.as_bytes(tf.saved_model.VARIABLES_FILENAME)) - - # We can't use the scope rename from init_from_checkpoint because it relies - # on var scopes not rebuilt by import_meta_graph. So we need to construct it - # explicitly by iterating over the variables. - # TODO(b/78624684): remove this workaround. - var_map = {} - for var in tf.compat.v1.global_variables(): - var_name = var.op.name - if not var_name.startswith(scope + '/'): - continue - - # Generate original name before importing into scope. - original_var_name = var_name[len(scope)+1:] - - match = _PARTITIONED_VARIABLE_NAME_RE.match(original_var_name) - if match: - # If the variable is partitioned, extract the base variable name and - # the index in the partition, then update var_map[base_name] to have - # var_map[base_name][partition_index] = var. - base_name = match.group(1) - partition_index = int(match.group(2)) - if base_name not in var_map: - var_map[base_name] = [] - while not partition_index < len(var_map[base_name]): - var_map[base_name].append(None) - assert var_map[base_name][partition_index] is None - var_map[base_name][partition_index] = var - else: - var_map[original_var_name] = var - - if var_map: - tf.compat.v1.train.init_from_checkpoint(checkpoint_path, var_map) - - # Add computed output tensors to the output. There are two cases. When the - # output is not in the input_map, then we look up the tensor in the imported - # graph by prepending the import scope and looking up the tensor by name. - # This will fail if the expected output tensor is not now in the graph - # under the expected name scope. When the output is in the input map, then - # that tensor will have been re-mapped so we use the tensor given in the - # input_map. - def lookup_remapped_tensor(tensor_name): - if tensor_name in input_map: - return input_map[tensor_name] + """Optionally import libraries to register additional TF ops.""" + + def _try_import(name): + try: + importlib.import_module(name) + except (ImportError, tf.errors.NotFoundError): + tf.compat.v1.logging.info(f"{name} is not available.") + pass + + # LINT.IfChange + _try_import("struct2tensor") + _try_import("tensorflow_decision_forests") + _try_import("tensorflow_text") + # LINT.ThenChange(tensorflow_model_analysis/utils/model_util.py) + + +def _partially_apply_saved_transform_impl( + saved_model_dir, logical_input_map, tensor_replacement_map=None +): + """Shared code for partially_apply_saved_transform and fetch_tensor_values. + + This adds nodes to a graph that already contains Tensors representing the + inputs. These input Tensors may be placeholders that will be fed when the + graph is executed, or may be the outputs of some Ops. Most typically, the + input Tensors are reading and/or parsing Ops, but they could be anything-- + including the outputs of a prior application of this function using another + transform graph. + + This function operates on the default Graph in the default Session, and so + must be called within a context where these are provided. 
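
A usage sketch of this partial application, run against the toy `(x - 2.0) / 5.0` transform written by `_create_test_saved_model` in the test file further below (an assumption of this example); it goes through the `partially_apply_saved_transform_internal` wrapper defined later in this module.

    import tensorflow as tf
    from tensorflow_transform.saved import saved_transform_io

    export_path = _create_test_saved_model()  # toy transform from the tests below
    with tf.compat.v1.Graph().as_default():
        with tf.compat.v1.Session().as_default() as session:
            raw_x = tf.constant([12.0])
            unbound_inputs, outputs = (
                saved_transform_io.partially_apply_saved_transform_internal(
                    export_path, {"x": raw_x}))
            print(unbound_inputs)                    # {} since every input was mapped
            print(session.run(outputs["x_scaled"]))  # [2.]
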
+ + Args: + ---- + saved_model_dir: A SavedModel directory providing a transform + graph. The MetaGraphDef and signature are selected from the SavedModel + using keys defined in `../constants.py` ('transform' and + 'transform_signature', respectively). + logical_input_map: a dict of logical name to Tensor. The logical names must + be a subset of those in the input signature of the transform graph, and + the corresponding Tensors must have the expected types and shapes. + tensor_replacement_map: a dict of tensor names to `Tensors`. + + Returns: + ------- + A tuple of (unbound_inputs, outputs, assets_dict) where + * unbound_inputs is a dict of logical name to Tensors that are yet to be + mapped or fed + * outputs is a dict of logical name to Tensor, as provided by the output + signature of the transform graph + + Raises: + ------ + ValueError: if the provided input_tensors dict has keys that are not part + of the input signature, or any of the provided inputs have the wrong + type or shape. + RuntimeError: if there is no default graph available to which to apply the + transform. + """ + _maybe_register_addon_ops() + graph = tf.compat.v1.get_default_graph() + if graph is None: + raise RuntimeError("apply_saved_transform() requires a default graph.") + + meta_graph_def, input_signature, output_signature, asset_path_dict = ( + _load_transform_saved_model(saved_model_dir) + ) + asset_tensor_dict = {k: tf.convert_to_tensor(v) for k, v in asset_path_dict.items()} + + # Check for inputs that were not part of the input signature. + unexpected_inputs = set(logical_input_map.keys()) - set(input_signature.keys()) + if unexpected_inputs: + raise ValueError("Unexpected inputs " f"to transform: {unexpected_inputs}") + + # Create a map from tensor names in the graph to be imported, to the tensors + # specified in `input_tensors`. + input_map = _expand_input_map(logical_input_map, input_signature) + + input_map.update(asset_tensor_dict) + if tensor_replacement_map: + input_map.update(tensor_replacement_map) + + # unique_name may produce e.g. transform_5. The result has no trailing slash. + scope = graph.unique_name("transform", mark_as_used=False) + + # unique_name returns an "absolute" name while we want a name relative to the + # current scope. Therefore, we check if the current name stack is non-empty, + # and if so, strip out the existing name scope. + if graph.get_name_scope(): + current_name_scope = graph.get_name_scope() + "/" + assert scope.startswith(current_name_scope) + import_scope = scope[len(current_name_scope) :] else: - return graph.get_tensor_by_name( - ops.prepend_name_scope(tensor_name, scope)) - def lookup_tensor_or_sparse_or_composite_tensor(tensor_info): - """Returns the remapped tensor corresponding to TensorInfo.""" - encoding = tensor_info.WhichOneof('encoding') - if encoding == 'coo_sparse': - return tf.SparseTensor( - lookup_remapped_tensor(tensor_info.coo_sparse.indices_tensor_name), - lookup_remapped_tensor(tensor_info.coo_sparse.values_tensor_name), - lookup_remapped_tensor( - tensor_info.coo_sparse.dense_shape_tensor_name)) - elif encoding == 'composite_tensor': - components = [lookup_remapped_tensor(info.name) - for info in tensor_info.composite_tensor.components] - spec_proto = struct_pb2.StructuredValue( - type_spec_value=tensor_info.composite_tensor.type_spec) - # StrcutureCoder.decode_proto was migrated after TF 2.7 to - # nested_structure_coder.decode_proto. 
- try: - spec = nested_structure_coder.decode_proto(spec_proto) - except AttributeError: - struct_coder = nested_structure_coder.StructureCoder() - spec = struct_coder.decode_proto(spec_proto) - return spec._from_components(components) # pylint: disable=protected-access - elif encoding == 'name': - return lookup_remapped_tensor(tensor_info.name) - else: - raise ValueError('Unsupported TensorInfo encoding %s' % encoding) - outputs = { - logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info) - for logical_name, tensor_info in output_signature.items() - } - # Do the same for input tensors, although such tensors should never be in the - # input_map since identical tensors in an input_map would be an error. - unbound_inputs = { - logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info) - for logical_name, tensor_info in input_signature.items() - if logical_name not in logical_input_map - } - - return unbound_inputs, outputs - - -def partially_apply_saved_transform_internal(saved_model_dir, - logical_input_map, - tensor_replacement_map=None): - """Apply a transform graph, represented as a SavedModel, to existing Tensors. - - For internal use only. Users should use the `transform_raw_features` or - `transform_raw_features_layer` method of the TFTrandformOutput class. - - This adds nodes to a graph that already contains Tensors representing the - inputs. These input Tensors may be placeholders that will be fed when the - graph is executed, or may be the outputs of some Ops. Most typically, the - input Tensors are reading and/or parsing Ops, but they could be anything-- - including the outputs of a prior application of this function using another - transform graph. - - This function operates on the default Graph in the default Session, and so - must be called within a context where these are provided. - - Args: - saved_model_dir: A SavedModel directory providing a transform - graph. The MetaGraphDef and signature are selected from the SavedModel - using keys defined in `../constants.py` ('transform' and - 'transform_signature', respectively). - logical_input_map: a dict of logical name to Tensor. The logical names must - be a subset of those in the input signature of the transform graph, and - the corresponding Tensors must have the expected types and shapes. - tensor_replacement_map: a dict of tensor names to `Tensors`. - - Returns: - A pair of (unbound_inputs, outputs) where unbound_inputs is a dict of - logical name to Tensors that are yet to be mapped or fed, and outputs is - a dict of logical name to Tensor, as provided by the output signature - of the transform graph - - Raises: - ValueError: if the provided input_tensors dict has keys that are not part - of the input signature, or any of the provided inputs have the wrong - type or shape. - RuntimeError: if there is no default graph available to which to apply the - transform. - """ - unbound_inputs, outputs = _partially_apply_saved_transform_impl( - saved_model_dir, logical_input_map, tensor_replacement_map) - return unbound_inputs, outputs + import_scope = scope + + # If the saved_model contained py_funcs, will reinsert them in the graph + # here and update their associated token in the model. 
+ _ = pyfunc_helper.register_pyfuncs_from_saved_transform( + graph, meta_graph_def, loaded_in_tf2=False + ) + + # Save the ASSET_FILEPATHS before importing the MetaGraphDef + current_assets = graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS) + + # Warn user if meta_graph_def has saved variables + if tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES in meta_graph_def.collection_def: + trainable_vars = meta_graph_def.collection_def[ + tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES + ].bytes_list.value + if trainable_vars: + raise ValueError( + f"The SavedModel contained trainable variables {trainable_vars}. Because this " + "function is typically called in the input_fn, trainable variables " + "are disallowed" + ) + + # Load the transform graph, applying it to existing Tensors via input_map. + # Throws ValueError if the input_map gives mismatched types or shapes. + saver = tf_saver.import_meta_graph( + meta_graph_def, import_scope=import_scope, input_map=input_map + ) + + # Wipe out AssetFileDef collection; it is obsolete after loading + graph.clear_collection(tf.saved_model.ASSETS_KEY) + + # The import may have added Tensors to the ASSET_FILEPATHS collection that + # were substituted via input_map. To account for this, wipe out the + # collection, restore the preexisting collection values, and then write in + # the new substituted Tensors. + graph.clear_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS) + for asset_path_tensor in current_assets: + graph.add_to_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS, asset_path_tensor + ) + for asset_path_tensor in asset_tensor_dict.values(): + graph.add_to_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS, asset_path_tensor + ) + + if saver: + checkpoint_path = os.path.join( + tf.compat.as_bytes(saved_model_dir), + tf.compat.as_bytes(tf.saved_model.VARIABLES_DIRECTORY), + tf.compat.as_bytes(tf.saved_model.VARIABLES_FILENAME), + ) + + # We can't use the scope rename from init_from_checkpoint because it relies + # on var scopes not rebuilt by import_meta_graph. So we need to construct it + # explicitly by iterating over the variables. + # TODO(b/78624684): remove this workaround. + var_map = {} + for var in tf.compat.v1.global_variables(): + var_name = var.op.name + if not var_name.startswith(scope + "/"): + continue + + # Generate original name before importing into scope. + original_var_name = var_name[len(scope) + 1 :] + + match = _PARTITIONED_VARIABLE_NAME_RE.match(original_var_name) + if match: + # If the variable is partitioned, extract the base variable name and + # the index in the partition, then update var_map[base_name] to have + # var_map[base_name][partition_index] = var. + base_name = match.group(1) + partition_index = int(match.group(2)) + if base_name not in var_map: + var_map[base_name] = [] + while not partition_index < len(var_map[base_name]): + var_map[base_name].append(None) + assert var_map[base_name][partition_index] is None + var_map[base_name][partition_index] = var + else: + var_map[original_var_name] = var + + if var_map: + tf.compat.v1.train.init_from_checkpoint(checkpoint_path, var_map) + + # Add computed output tensors to the output. There are two cases. When the + # output is not in the input_map, then we look up the tensor in the imported + # graph by prepending the import scope and looking up the tensor by name. + # This will fail if the expected output tensor is not now in the graph + # under the expected name scope. 
When the output is in the input map, then + # that tensor will have been re-mapped so we use the tensor given in the + # input_map. + def lookup_remapped_tensor(tensor_name): + if tensor_name in input_map: + return input_map[tensor_name] + else: + return graph.get_tensor_by_name(ops.prepend_name_scope(tensor_name, scope)) + + def lookup_tensor_or_sparse_or_composite_tensor(tensor_info): + """Returns the remapped tensor corresponding to TensorInfo.""" + encoding = tensor_info.WhichOneof("encoding") + if encoding == "coo_sparse": + return tf.SparseTensor( + lookup_remapped_tensor(tensor_info.coo_sparse.indices_tensor_name), + lookup_remapped_tensor(tensor_info.coo_sparse.values_tensor_name), + lookup_remapped_tensor(tensor_info.coo_sparse.dense_shape_tensor_name), + ) + elif encoding == "composite_tensor": + components = [ + lookup_remapped_tensor(info.name) + for info in tensor_info.composite_tensor.components + ] + spec_proto = struct_pb2.StructuredValue( + type_spec_value=tensor_info.composite_tensor.type_spec + ) + # StrcutureCoder.decode_proto was migrated after TF 2.7 to + # nested_structure_coder.decode_proto. + try: + spec = nested_structure_coder.decode_proto(spec_proto) + except AttributeError: + struct_coder = nested_structure_coder.StructureCoder() + spec = struct_coder.decode_proto(spec_proto) + return spec._from_components(components) # pylint: disable=protected-access + elif encoding == "name": + return lookup_remapped_tensor(tensor_info.name) + else: + raise ValueError("Unsupported TensorInfo encoding %s" % encoding) + + outputs = { + logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info) + for logical_name, tensor_info in output_signature.items() + } + # Do the same for input tensors, although such tensors should never be in the + # input_map since identical tensors in an input_map would be an error. + unbound_inputs = { + logical_name: lookup_tensor_or_sparse_or_composite_tensor(tensor_info) + for logical_name, tensor_info in input_signature.items() + if logical_name not in logical_input_map + } + + return unbound_inputs, outputs + + +def partially_apply_saved_transform_internal( + saved_model_dir, logical_input_map, tensor_replacement_map=None +): + """Apply a transform graph, represented as a SavedModel, to existing Tensors. + + For internal use only. Users should use the `transform_raw_features` or + `transform_raw_features_layer` method of the TFTrandformOutput class. + + This adds nodes to a graph that already contains Tensors representing the + inputs. These input Tensors may be placeholders that will be fed when the + graph is executed, or may be the outputs of some Ops. Most typically, the + input Tensors are reading and/or parsing Ops, but they could be anything-- + including the outputs of a prior application of this function using another + transform graph. + + This function operates on the default Graph in the default Session, and so + must be called within a context where these are provided. + + Args: + ---- + saved_model_dir: A SavedModel directory providing a transform + graph. The MetaGraphDef and signature are selected from the SavedModel + using keys defined in `../constants.py` ('transform' and + 'transform_signature', respectively). + logical_input_map: a dict of logical name to Tensor. The logical names must + be a subset of those in the input signature of the transform graph, and + the corresponding Tensors must have the expected types and shapes. + tensor_replacement_map: a dict of tensor names to `Tensors`. 
+ + Returns: + ------- + A pair of (unbound_inputs, outputs) where unbound_inputs is a dict of + logical name to Tensors that are yet to be mapped or fed, and outputs is + a dict of logical name to Tensor, as provided by the output signature + of the transform graph + + Raises: + ------ + ValueError: if the provided input_tensors dict has keys that are not part + of the input signature, or any of the provided inputs have the wrong + type or shape. + RuntimeError: if there is no default graph available to which to apply the + transform. + """ + unbound_inputs, outputs = _partially_apply_saved_transform_impl( + saved_model_dir, logical_input_map, tensor_replacement_map + ) + return unbound_inputs, outputs def write_saved_transform_from_session( - session, inputs, outputs, export_path, as_text=False): - """Write the current session as a SavedModel.""" - predict_signature_def = ( - tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( - inputs, outputs)) - - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path) - builder.add_meta_graph_and_variables( - session, [constants.TRANSFORM_TAG], - signature_def_map={constants.TRANSFORM_SIGNATURE: predict_signature_def}, - assets_collection=tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.ASSET_FILEPATHS)) - builder.save(as_text) + session, inputs, outputs, export_path, as_text=False +): + """Write the current session as a SavedModel.""" + predict_signature_def = ( + tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + ) + + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(export_path) + builder.add_meta_graph_and_variables( + session, + [constants.TRANSFORM_TAG], + signature_def_map={constants.TRANSFORM_SIGNATURE: predict_signature_def}, + assets_collection=tf.compat.v1.get_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS + ), + ) + builder.save(as_text) def exported_as_v1(transform_savedmodel_dir): - """Check if a SavedModel was exported as a TF 1 model or not. - - Args: - transform_savedmodel_dir: a SavedModel directory. - - Returns: - `True` if `transform_savedmodel_dir` contains a TF1 SavedModel else - returns `False`. - """ - saved_model = saved_model_loader.parse_saved_model(transform_savedmodel_dir) - meta_graph_def = saved_model_loader.choose_meta_graph_def(saved_model) - return meta_graph_def is not None + """Check if a SavedModel was exported as a TF 1 model or not. + + Args: + ---- + transform_savedmodel_dir: a SavedModel directory. + + Returns: + ------- + `True` if `transform_savedmodel_dir` contains a TF1 SavedModel else + returns `False`. 
+ """ + saved_model = saved_model_loader.parse_saved_model(transform_savedmodel_dir) + meta_graph_def = saved_model_loader.choose_meta_graph_def(saved_model) + return meta_graph_def is not None diff --git a/tensorflow_transform/saved/saved_transform_io_test.py b/tensorflow_transform/saved/saved_transform_io_test.py index 30ee65a..892da02 100644 --- a/tensorflow_transform/saved/saved_transform_io_test.py +++ b/tensorflow_transform/saved/saved_transform_io_test.py @@ -18,292 +18,333 @@ import numpy as np import tensorflow as tf -from tensorflow_transform.saved import saved_transform_io # pylint: disable=g-direct-tensorflow-import from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.ops import lookup_ops + +from tensorflow_transform.saved import saved_transform_io + # pylint: enable=g-direct-tensorflow-import # TODO(b/123241798): Find an open-source compatible way to access # FLAGS.test_tmpdir. def _create_test_saved_model(): - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_float = tf.compat.v1.placeholder(tf.float32, shape=[1]) - output = (input_float - 2.0) / 5.0 - inputs = {'x': input_float} - outputs = {'x_scaled': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) - - return export_path - + export_path = os.path.join(tempfile.mkdtemp(), "export") -class SavedTransformIOTest(tf.test.TestCase): - - @classmethod - def setUpClass(cls): - cls._test_saved_model = _create_test_saved_model() - - def test_apply_saved_transform(self): - with tf.compat.v1.Graph().as_default() as graph: - with tf.compat.v1.Session().as_default() as session: - input_floats = tf.constant([1237.0]) # tf.float32 - input_features = {'x': input_floats} - _, transformed_features = ( - saved_transform_io.partially_apply_saved_transform_internal( - self._test_saved_model, input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertIsInstance(result_tensor, tf.Tensor) - - self.assertAllEqual(session.run(result_tensor), [247.0]) - self.assertEqual(graph.get_tensor_by_name('Const:0'), input_floats) - self.assertEqual( - graph.get_tensor_by_name('transform/truediv:0'), - result_tensor) - - def test_apply_transform_extra_features_no_passthrough(self): - with self.assertRaises(ValueError): - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default(): - input_floats = tf.constant([1234.0]) # tf.float32 - input_features = {'x': input_floats, - 'extra_1': tf.constant('1'), - 'extra_2': tf.constant('2')} - saved_transform_io.partially_apply_saved_transform_internal( - self._test_saved_model, input_features) - - def test_apply_transform_type_mismatch(self): - with self.assertRaises(ValueError): - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default(): - input_strings = tf.constant(['bogus']) # tf.string - input_features = {'x': input_strings} - saved_transform_io.partially_apply_saved_transform_internal( - self._test_saved_model, input_features) - - def test_apply_transform_shape_mismatch(self): - with self.assertRaises(ValueError): - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default(): - input_floats = tf.constant(1234.0) # tf.float32 - input_features = {'x': input_floats} - saved_transform_io.partially_apply_saved_transform_internal( - 
self._test_saved_model, input_features) - - def test_apply_saved_transform_to_tensor_inside_scope(self): - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.name_scope('my_scope'): - with tf.compat.v1.Session().as_default() as session: - input_floats = tf.constant([1237.0]) # tf.float32 - input_features = {'x': input_floats} - _, transformed_features = ( - saved_transform_io.partially_apply_saved_transform_internal( - self._test_saved_model, input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertAllEqual(session.run(result_tensor), [247.0]) - - def test_apply_saved_transform_to_tensor_outside_scope(self): with tf.compat.v1.Graph().as_default(): - input_floats = tf.constant([1237.0]) # tf.float32 - with tf.compat.v1.name_scope('my_scope'): with tf.compat.v1.Session().as_default() as session: - input_features = {'x': input_floats} - _, transformed_features = ( - saved_transform_io.partially_apply_saved_transform_internal( - self._test_saved_model, input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertAllEqual(session.run(result_tensor), [247.0]) - - def test_dense_roundtrip(self): - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_float = tf.compat.v1.placeholder(tf.float32) - # show that unrelated & unmapped placeholders do not interfere - tf.compat.v1.placeholder(tf.int64) - output = input_float / 5.0 - inputs = {'input': input_float} - outputs = {'output': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) + input_float = tf.compat.v1.placeholder(tf.float32, shape=[1]) + output = (input_float - 2.0) / 5.0 + inputs = {"x": input_float} + outputs = {"x_scaled": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - # Using a computed input gives confidence that the graphs are fused. - input_float = tf.constant(25.0) * 2 - inputs = {'input': input_float} - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - export_path, inputs)) - result = session.run(outputs['output']) - # (25 * 2) / 5 = 10 - self.assertEqual(10.0, result) - - def test_table_roundtrip(self): - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_string = tf.compat.v1.placeholder(tf.string) - # Map string through a table, in this case based on a constant tensor. - table_keys = ['cat', 'dog', 'giraffe'] - initializer = tf.lookup.KeyValueTensorInitializer( - keys=table_keys, - values=tf.cast(tf.range(len(table_keys)), tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - - output = table.lookup(input_string) - inputs = {'input': input_string} - outputs = {'output': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - # Using a computed input gives confidence that the graphs are fused. 
- input_string = tf.constant('dog') - inputs = {'input': input_string} - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - export_path, inputs)) - session.run(tf.compat.v1.tables_initializer()) - result = session.run(outputs['output']) - self.assertEqual(1, result) - - def test_sparse_roundtrip(self): - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_float = tf.compat.v1.sparse_placeholder(tf.float32) - output = input_float / 5.0 - inputs = {'input': input_float} - outputs = {'output': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) + return export_path - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64) - values = np.array([1.0, 2.0], dtype=np.float32) - shape = np.array([7, 9, 2], dtype=np.int64) - input_sparse = tf.SparseTensor( - indices=indices, values=values, dense_shape=shape) - - # Using a computed input gives confidence that the graphs are fused - inputs = {'input': input_sparse * 10} - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - export_path, inputs)) - output_sparse = outputs['output'] - self.assertIsInstance(output_sparse, tf.SparseTensor) - result = session.run(output_sparse) - - # indices and shape unchanged; values multiplied by 10 and divided by 5 - self.assertEqual(indices.tolist(), result.indices.tolist()) - self.assertEqual([2.0, 4.0], result.values.tolist()) - self.assertEqual(shape.tolist(), result.dense_shape.tolist()) - - def test_ragged_roundtrip(self): - if not hasattr(meta_graph_pb2.TensorInfo, 'CompositeTensor'): - self.skipTest('This version of TensorFlow does not support ' - 'CompositeTenors in TensorInfo.') - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - input_float = tf.compat.v1.ragged.placeholder(tf.float32, ragged_rank=1, - value_shape=[]) - output = input_float / 2.0 - inputs = {'input': input_float} - outputs = {'output': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - splits = np.array([0, 2, 3], dtype=np.int64) - values = np.array([1.0, 2.0, 4.0], dtype=np.float32) - input_ragged = tf.RaggedTensor.from_row_splits(values, splits) - - # Using a computed input gives confidence that the graphs are fused - inputs = {'input': input_ragged * 10} - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - export_path, inputs)) - output_ragged = outputs['output'] - self.assertIsInstance(output_ragged, tf.RaggedTensor) - result = session.run(output_ragged) - - # indices and shape unchanged; values multipled by 10 and divided by 2 - self.assertAllEqual(splits, result.row_splits) - self.assertEqual([5.0, 10.0, 20.0], result.values.tolist()) - - def test_stale_asset_collections_are_cleaned(self): - vocabulary_file = os.path.join( - tf.compat.as_bytes(self.get_temp_dir()), tf.compat.as_bytes('asset')) - file_io.write_string_to_file(vocabulary_file, 'foo bar baz') - - export_path = os.path.join(tempfile.mkdtemp(), 'export') - - # create a SavedModel including assets - with tf.compat.v1.Graph().as_default(): - with 
tf.compat.v1.Session().as_default() as session: - input_string = tf.compat.v1.placeholder(tf.string) - # Map string through a table loaded from an asset file - initializer = tf.lookup.TextFileInitializer( - vocabulary_file, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - table = tf.lookup.StaticHashTable(initializer, default_value=12) - table = lookup_ops.IdTableWithHashBuckets(table, - num_oov_buckets=12, - key_dtype=tf.string) - output = table.lookup(input_string) - inputs = {'input': input_string} - outputs = {'output': output} - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) - - # Load it and save it again repeatedly, verifying that the asset collections - # remain valid. - for _ in [1, 2, 3]: - with tf.compat.v1.Graph().as_default() as g: - with tf.compat.v1.Session().as_default() as session: - input_string = tf.constant('dog') - inputs = {'input': input_string} - _, outputs = ( - saved_transform_io.partially_apply_saved_transform_internal( - export_path, inputs)) - - self.assertEqual( - 1, len(g.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS))) - self.assertEqual(0, len(g.get_collection(tf.saved_model.ASSETS_KEY))) - - # Check that every ASSET_FILEPATHS refers to a Tensor in the graph. - # If not, get_tensor_by_name() raises KeyError. - for asset_path in g.get_collection( - tf.compat.v1.GraphKeys.ASSET_FILEPATHS): - tensor_name = asset_path.name - g.get_tensor_by_name(tensor_name) - - export_path = os.path.join(tempfile.mkdtemp(), 'export') - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path) - -if __name__ == '__main__': - tf.test.main() +class SavedTransformIOTest(tf.test.TestCase): + @classmethod + def setUpClass(cls): + cls._test_saved_model = _create_test_saved_model() + + def test_apply_saved_transform(self): + with tf.compat.v1.Graph().as_default() as graph: + with tf.compat.v1.Session().as_default() as session: + input_floats = tf.constant([1237.0]) # tf.float32 + input_features = {"x": input_floats} + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + ) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertIsInstance(result_tensor, tf.Tensor) + + self.assertAllEqual(session.run(result_tensor), [247.0]) + self.assertEqual(graph.get_tensor_by_name("Const:0"), input_floats) + self.assertEqual( + graph.get_tensor_by_name("transform/truediv:0"), result_tensor + ) + + def test_apply_transform_extra_features_no_passthrough(self): + with self.assertRaises(ValueError): + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default(): + input_floats = tf.constant([1234.0]) # tf.float32 + input_features = { + "x": input_floats, + "extra_1": tf.constant("1"), + "extra_2": tf.constant("2"), + } + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + + def test_apply_transform_type_mismatch(self): + with self.assertRaises(ValueError): + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default(): + input_strings = tf.constant(["bogus"]) # tf.string + input_features = {"x": input_strings} + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + + def 
test_apply_transform_shape_mismatch(self): + with self.assertRaises(ValueError): + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default(): + input_floats = tf.constant(1234.0) # tf.float32 + input_features = {"x": input_floats} + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + + def test_apply_saved_transform_to_tensor_inside_scope(self): + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.name_scope("my_scope"): + with tf.compat.v1.Session().as_default() as session: + input_floats = tf.constant([1237.0]) # tf.float32 + input_features = {"x": input_floats} + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + ) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertAllEqual(session.run(result_tensor), [247.0]) + + def test_apply_saved_transform_to_tensor_outside_scope(self): + with tf.compat.v1.Graph().as_default(): + input_floats = tf.constant([1237.0]) # tf.float32 + with tf.compat.v1.name_scope("my_scope"): + with tf.compat.v1.Session().as_default() as session: + input_features = {"x": input_floats} + _, transformed_features = ( + saved_transform_io.partially_apply_saved_transform_internal( + self._test_saved_model, input_features + ) + ) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertAllEqual(session.run(result_tensor), [247.0]) + + def test_dense_roundtrip(self): + export_path = os.path.join(tempfile.mkdtemp(), "export") + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_float = tf.compat.v1.placeholder(tf.float32) + # show that unrelated & unmapped placeholders do not interfere + tf.compat.v1.placeholder(tf.int64) + output = input_float / 5.0 + inputs = {"input": input_float} + outputs = {"output": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + # Using a computed input gives confidence that the graphs are fused. + input_float = tf.constant(25.0) * 2 + inputs = {"input": input_float} + _, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + export_path, inputs + ) + ) + result = session.run(outputs["output"]) + # (25 * 2) / 5 = 10 + self.assertEqual(10.0, result) + + def test_table_roundtrip(self): + export_path = os.path.join(tempfile.mkdtemp(), "export") + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_string = tf.compat.v1.placeholder(tf.string) + # Map string through a table, in this case based on a constant tensor. 
+ table_keys = ["cat", "dog", "giraffe"] + initializer = tf.lookup.KeyValueTensorInitializer( + keys=table_keys, + values=tf.cast(tf.range(len(table_keys)), tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + + output = table.lookup(input_string) + inputs = {"input": input_string} + outputs = {"output": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + # Using a computed input gives confidence that the graphs are fused. + input_string = tf.constant("dog") + inputs = {"input": input_string} + _, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + export_path, inputs + ) + ) + session.run(tf.compat.v1.tables_initializer()) + result = session.run(outputs["output"]) + self.assertEqual(1, result) + + def test_sparse_roundtrip(self): + export_path = os.path.join(tempfile.mkdtemp(), "export") + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_float = tf.compat.v1.sparse_placeholder(tf.float32) + output = input_float / 5.0 + inputs = {"input": input_float} + outputs = {"output": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64) + values = np.array([1.0, 2.0], dtype=np.float32) + shape = np.array([7, 9, 2], dtype=np.int64) + input_sparse = tf.SparseTensor( + indices=indices, values=values, dense_shape=shape + ) + + # Using a computed input gives confidence that the graphs are fused + inputs = {"input": input_sparse * 10} + _, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + export_path, inputs + ) + ) + output_sparse = outputs["output"] + self.assertIsInstance(output_sparse, tf.SparseTensor) + result = session.run(output_sparse) + + # indices and shape unchanged; values multiplied by 10 and divided by 5 + self.assertEqual(indices.tolist(), result.indices.tolist()) + self.assertEqual([2.0, 4.0], result.values.tolist()) + self.assertEqual(shape.tolist(), result.dense_shape.tolist()) + + def test_ragged_roundtrip(self): + if not hasattr(meta_graph_pb2.TensorInfo, "CompositeTensor"): + self.skipTest( + "This version of TensorFlow does not support " + "CompositeTenors in TensorInfo." 
+ ) + export_path = os.path.join(tempfile.mkdtemp(), "export") + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_float = tf.compat.v1.ragged.placeholder( + tf.float32, ragged_rank=1, value_shape=[] + ) + output = input_float / 2.0 + inputs = {"input": input_float} + outputs = {"output": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + splits = np.array([0, 2, 3], dtype=np.int64) + values = np.array([1.0, 2.0, 4.0], dtype=np.float32) + input_ragged = tf.RaggedTensor.from_row_splits(values, splits) + + # Using a computed input gives confidence that the graphs are fused + inputs = {"input": input_ragged * 10} + _, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + export_path, inputs + ) + ) + output_ragged = outputs["output"] + self.assertIsInstance(output_ragged, tf.RaggedTensor) + result = session.run(output_ragged) + + # indices and shape unchanged; values multipled by 10 and divided by 2 + self.assertAllEqual(splits, result.row_splits) + self.assertEqual([5.0, 10.0, 20.0], result.values.tolist()) + + def test_stale_asset_collections_are_cleaned(self): + vocabulary_file = os.path.join( + tf.compat.as_bytes(self.get_temp_dir()), tf.compat.as_bytes("asset") + ) + file_io.write_string_to_file(vocabulary_file, "foo bar baz") + + export_path = os.path.join(tempfile.mkdtemp(), "export") + + # create a SavedModel including assets + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + input_string = tf.compat.v1.placeholder(tf.string) + # Map string through a table loaded from an asset file + initializer = tf.lookup.TextFileInitializer( + vocabulary_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=12) + table = lookup_ops.IdTableWithHashBuckets( + table, num_oov_buckets=12, key_dtype=tf.string + ) + output = table.lookup(input_string) + inputs = {"input": input_string} + outputs = {"output": output} + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + # Load it and save it again repeatedly, verifying that the asset collections + # remain valid. + for _ in [1, 2, 3]: + with tf.compat.v1.Graph().as_default() as g: + with tf.compat.v1.Session().as_default() as session: + input_string = tf.constant("dog") + inputs = {"input": input_string} + _, outputs = ( + saved_transform_io.partially_apply_saved_transform_internal( + export_path, inputs + ) + ) + + self.assertEqual( + 1, len(g.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)) + ) + self.assertEqual( + 0, len(g.get_collection(tf.saved_model.ASSETS_KEY)) + ) + + # Check that every ASSET_FILEPATHS refers to a Tensor in the graph. + # If not, get_tensor_by_name() raises KeyError. 
+ for asset_path in g.get_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS + ): + tensor_name = asset_path.name + g.get_tensor_by_name(tensor_name) + + export_path = os.path.join(tempfile.mkdtemp(), "export") + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow_transform/saved/saved_transform_io_v2.py b/tensorflow_transform/saved/saved_transform_io_v2.py index 7c60ef7..e8254f1 100644 --- a/tensorflow_transform/saved/saved_transform_io_v2.py +++ b/tensorflow_transform/saved/saved_transform_io_v2.py @@ -16,448 +16,487 @@ from typing import Any, Dict, Iterable, Mapping, Optional, Tuple, Union import tensorflow as tf -from tensorflow_transform import annotators -from tensorflow_transform import common_types -from tensorflow_transform import graph_tools -from tensorflow_transform import tf2_utils -from tensorflow_transform import tf_utils -from tensorflow_transform.py_func import pyfunc_helper -from tensorflow_transform.saved import constants -from tensorflow_transform.saved import saved_model_loader -from tensorflow_transform.saved import saved_transform_io + # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.eager import function -from tensorflow.python.eager import wrap_function +from tensorflow.python.eager import function, wrap_function from tensorflow.python.framework import composite_tensor from tensorflow.python.ops import lookup_ops from tensorflow.python.util import object_identity +from tensorflow_transform import ( + annotators, + common_types, + graph_tools, + tf2_utils, + tf_utils, +) +from tensorflow_transform.py_func import pyfunc_helper +from tensorflow_transform.saved import constants, saved_model_loader, saved_transform_io + # pylint: disable=g-import-not-at-top try: - # Moved in TensorFlow 2.10. - from tensorflow.python.trackable import resource as tracking + # Moved in TensorFlow 2.10. + from tensorflow.python.trackable import resource as tracking except ImportError: - from tensorflow.python.training.tracking import tracking + from tensorflow.python.training.tracking import tracking # pylint: enable=g-direct-tensorflow-import, g-import-not-at-top def _restore_from_v1_saved_model( restored_function: function.ConcreteFunction, saved_model_dir: str -) -> Tuple[function.ConcreteFunction, Mapping[str, Any], Mapping[ - str, common_types.TensorType]]: - """Restores an exported TF1 compat SavedModel.""" - saved_model = saved_model_loader.parse_saved_model(saved_model_dir) - meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise( - saved_model) - signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE] - # Re-register pyfuncs, if any. 
- graph_def = pyfunc_helper.register_pyfuncs_from_saved_transform( - restored_function.graph, meta_graph_def, loaded_in_tf2=True) - if graph_def is None: - return (restored_function, signature.inputs, - restored_function.structured_outputs) - - inputs = [t.name for t in restored_function.graph.inputs] - outputs = [t.name for t in restored_function.graph.outputs] - wrapped = wrap_function.function_from_graph_def(graph_def, inputs, outputs) - structured_outputs = ( - tf.nest.pack_sequence_as( - restored_function.structured_outputs, - wrapped.outputs, - expand_composites=True)) - wrapped = wrapped.prune(wrapped.inputs, structured_outputs) - return (wrapped, signature.inputs, wrapped.structured_outputs) +) -> Tuple[ + function.ConcreteFunction, Mapping[str, Any], Mapping[str, common_types.TensorType] +]: + """Restores an exported TF1 compat SavedModel.""" + saved_model = saved_model_loader.parse_saved_model(saved_model_dir) + meta_graph_def = saved_model_loader.choose_meta_graph_def_and_raise(saved_model) + signature = meta_graph_def.signature_def[constants.TRANSFORM_SIGNATURE] + # Re-register pyfuncs, if any. + graph_def = pyfunc_helper.register_pyfuncs_from_saved_transform( + restored_function.graph, meta_graph_def, loaded_in_tf2=True + ) + if graph_def is None: + return ( + restored_function, + signature.inputs, + restored_function.structured_outputs, + ) + + inputs = [t.name for t in restored_function.graph.inputs] + outputs = [t.name for t in restored_function.graph.outputs] + wrapped = wrap_function.function_from_graph_def(graph_def, inputs, outputs) + structured_outputs = tf.nest.pack_sequence_as( + restored_function.structured_outputs, wrapped.outputs, expand_composites=True + ) + wrapped = wrapped.prune(wrapped.inputs, structured_outputs) + return (wrapped, signature.inputs, wrapped.structured_outputs) def _as_operation(op_or_tensor: Union[tf.Operation, tf.Tensor]) -> tf.Operation: - if isinstance(op_or_tensor, tf.Tensor): - return op_or_tensor.op - return op_or_tensor + if isinstance(op_or_tensor, tf.Tensor): + return op_or_tensor.op + return op_or_tensor def _get_component_tensors( - tensor: Union[tf.Tensor, composite_tensor.CompositeTensor] + tensor: Union[tf.Tensor, composite_tensor.CompositeTensor], ) -> Iterable[tf.Tensor]: - """Get all component tensors. - - Args: - tensor: A `Tensor` or `CompositeTensor`. + """Get all component tensors. - Returns: - All `Tensor` components of `tensor`. - - Raises: - ValueError if supplied `tensor` parameter is neither a `Tensor` nor a - `CompositeTensor`. - """ - if isinstance(tensor, tf.Tensor): - return [tensor] - elif isinstance(tensor, composite_tensor.CompositeTensor): - return tf.nest.flatten(tensor, expand_composites=True) - else: - raise ValueError( - 'Unsupported tensor. Arg `tensor` is neither a `Tensor` nor a ' - f'`CompositeTensor`: {tensor}.') - - -def _get_output_to_inputs_map( - output_signature: Mapping[str, common_types.TensorType] -) -> Dict[str, Iterable[tf.Tensor]]: - """Gets all graph inputs that the tensors in output_signature depend on.""" - result = {} - for name, output in output_signature.items(): - components = _get_component_tensors(output) - sinks = [_as_operation(component) for component in components] - # Ignore control dependencies when walking the graph as we only care about - # which user defined inputs this output depends on. 
- result[name] = graph_tools.retrieve_sources( - sinks, ignore_control_dependencies=True) - return result - - -class SavedModelLoader: - """Handles a SavedModel exported using TF 1.x APIs in TF 2.x.""" + Args: + ---- + tensor: A `Tensor` or `CompositeTensor`. - def __init__(self, saved_model_dir: str): - """Init method for SavedModelLoader. + Returns: + ------- + All `Tensor` components of `tensor`. - Args: - saved_model_dir: A SavedModel directory providing a transform graph. The - MetaGraphDef and signature are selected from the SavedModel using keys - defined in `../constants.py` ('transform' and 'transform_signature', - respectively). + Raises: + ------ + ValueError if supplied `tensor` parameter is neither a `Tensor` nor a + `CompositeTensor`. """ - imported = tf.saved_model.load(saved_model_dir) - load_v2_in_compat = constants.TRANSFORM_SIGNATURE in imported.signatures - if load_v2_in_compat: - restored_function = imported.signatures[constants.TRANSFORM_SIGNATURE] - wrapped, structured_inputs, structured_outputs = ( - _restore_from_v1_saved_model(restored_function, saved_model_dir)) + if isinstance(tensor, tf.Tensor): + return [tensor] + elif isinstance(tensor, composite_tensor.CompositeTensor): + return tf.nest.flatten(tensor, expand_composites=True) else: - # transform_fn is now a ConcreteFunction, but was a tf.function. We need - # to handle both to maintain backward compatiblity. If it's a tf.function, - # since `input_signature` was specified when exporting the tf function to - # `SavedModel`, there should be exactly one concrete function present on - # loading the `SavedModel`. - if hasattr(imported.transform_fn, 'concrete_functions'): - concrete_functions = imported.transform_fn.concrete_functions - assert len(concrete_functions) == 1, concrete_functions - wrapped = concrete_functions[0] - else: - wrapped = imported.transform_fn - func_graph = wrapped.graph - structured_inputs = ( - tf2_utils.get_structured_inputs_from_func_graph(func_graph)) - structured_outputs = tf.nest.pack_sequence_as( - func_graph.structured_outputs, - func_graph.outputs, - expand_composites=True) - outputs_to_inputs_map = _get_output_to_inputs_map(structured_outputs) - self._initialize(load_v2_in_compat, imported, wrapped, structured_inputs, - structured_outputs, outputs_to_inputs_map) - saved_transform_io._maybe_register_addon_ops() # pylint: disable=protected-access - - def _initialize(self, load_v2_in_compat, imported, wrapped, structured_inputs, - structured_outputs, outputs_to_inputs_map): - """Initializes all class arguments.""" - self._load_v2_in_compat = load_v2_in_compat - self._imported = imported - self._wrapped_function = wrapped - self._func_graph = self._wrapped_function.graph - self._structured_inputs = structured_inputs - self._structured_outputs = structured_outputs - self._output_to_inputs_map = outputs_to_inputs_map - self._sorted_unfed_input_keys = None - self._wrapped_function_finalized = None - self._is_finalized = False - - @property - def load_v2_in_compat(self): - return self._load_v2_in_compat - - @property - def structured_outputs(self): - return self._structured_outputs - - def _get_feeds(self, unfed_input_keys: Iterable[str]) -> Iterable[tf.Tensor]: - """Returns set of tensors that will be fed.""" - result = object_identity.ObjectIdentitySet(self._func_graph.inputs) - for input_key in unfed_input_keys: - unfed_input_components = _get_component_tensors( - self._structured_inputs[input_key]) - result = result.difference(unfed_input_components) - return result + raise 
ValueError( + "Unsupported tensor. Arg `tensor` is neither a `Tensor` nor a " + f"`CompositeTensor`: {tensor}." + ) - def _get_unfed_input_keys(self, - input_tensor_keys: Iterable[str]) -> Iterable[str]: - return set(self._structured_inputs.keys()).difference(input_tensor_keys) - def _get_fetches( - self, feeds: Iterable[tf.Tensor]) -> Dict[str, common_types.TensorType]: - """Returns set of tensors that can be fetched when `feeds` is supplied.""" - result = {} - for name, output in self._structured_outputs.items(): - extra_sources = self._output_to_inputs_map[name].difference(feeds) - # If output does not depend on an input placeholder that is not being fed, - # add it to fetches. - if not extra_sources.difference(self._func_graph.internal_captures): - result[name] = output - return result - - def _get_fetches_keys(self, feeds: Iterable[tf.Tensor]) -> Iterable[str]: - return self._get_fetches(feeds).keys() - - def _get_missing_inputs( - self, unfed_input_keys: Iterable[str], - batch_size: int) -> Dict[str, common_types.TensorType]: - """Supplies inputs for `unfed_input_keys`.""" - result = {} - if unfed_input_keys: - result = ( - tf2_utils.supply_missing_inputs(self._structured_inputs, batch_size, - unfed_input_keys)) - return result - - def _apply_v1_transform_model_in_v2( - self, logical_input_map: Mapping[str, common_types.TensorType] - ) -> Dict[str, common_types.TensorType]: - """Applies a V1 transform graph to dictionary of (Composite)Tensors. - - This method applies the transformation graph as a pruned function to the - `logical_input_map`. - It prunes the function loaded from the SavedModel to return only outputs - that can be computed from the keys provided in `logical_input_map`. - - Args: - logical_input_map: a dict of logical name to Tensor. The logical names - must be a subset of those in the input signature of the transform graph, - and the corresponding Tensors must have the expected types and shapes. - - Returns: - A dict of logical name to Tensor, as provided by the output signature of - the transform graph. - """ - input_map = ( - saved_transform_io._expand_input_map( # pylint: disable=protected-access - logical_input_map, self._structured_inputs)) - - feeds = [] - pruned_input_args = [] - for name in input_map: - tensor = self._func_graph.get_tensor_by_name(name) - try: - tensor.shape.assert_is_compatible_with(input_map[name].shape) - except ValueError as e: - raise ValueError('{}: {}'.format(name, e)) - feeds.append(tensor) - pruned_input_args.append(input_map[name]) - - fetches = self._get_fetches(feeds) - pruned = self._wrapped_function.prune(feeds, fetches) - result = pruned(*pruned_input_args) - # TODO(b/163329414): Remove set_shape when calling pruned no longer produces - # tensors with unknown shapes. 
- for name, output in fetches.items(): - if hasattr(result[name], 'set_shape'): - result[name].set_shape(output.shape) - return result - - def _format_input_map_as_tensors(self, input_map): - """Returns a map from string to `tf.Tensor` or `CompositeTensor`.""" +def _get_output_to_inputs_map( + output_signature: Mapping[str, common_types.TensorType], +) -> Dict[str, Iterable[tf.Tensor]]: + """Gets all graph inputs that the tensors in output_signature depend on.""" result = {} - for key, value in input_map.items(): - if isinstance(value, (tf.Tensor, composite_tensor.CompositeTensor)): - result[key] = value - else: - result[key] = tf.convert_to_tensor(value) + for name, output in output_signature.items(): + components = _get_component_tensors(output) + sinks = [_as_operation(component) for component in components] + # Ignore control dependencies when walking the graph as we only care about + # which user defined inputs this output depends on. + result[name] = graph_tools.retrieve_sources( + sinks, ignore_control_dependencies=True + ) return result - def _apply_v2_transform_model_finalized( - self, logical_input_map: Mapping[str, common_types.TensorType] - ) -> Dict[str, common_types.TensorType]: - """Applies a V2 transform graph to dictionary of (Composite)Tensors. - - This method applies the transformation graph to the `logical_input_map` to - return only outputs that can be computed from the keys provided in - `logical_input_map`. It assumes that self.finalize has been called before - this method is invoked. - - Args: - logical_input_map: a dict of logical name to Tensor. The logical names - must be a subset of those in the input signature of the transform graph, - and the corresponding Tensors must have the expected types and shapes. - - Returns: - A dict of logical name to Tensor, as provided by the output signature of - the transform graph. - """ - - # Assert that the same keys are fed as this model was finalized with. - unfed_input_keys = self._get_unfed_input_keys(logical_input_map.keys()) - assert sorted(unfed_input_keys) == self._sorted_unfed_input_keys - - modified_inputs = self._format_input_map_as_tensors(logical_input_map) - return self._wrapped_function_finalized(modified_inputs) - - def _apply_v2_transform_model( - self, logical_input_map: Mapping[str, common_types.TensorType] - ) -> Dict[str, common_types.TensorType]: - """Applies a V2 transform graph to dictionary of (Composite)Tensors. - - This method applies the transformation graph to the `logical_input_map` to - return only outputs that can be computed from the keys provided in - `logical_input_map`. - - Args: - logical_input_map: a dict of logical name to Tensor. The logical names - must be a subset of those in the input signature of the transform graph, - and the corresponding Tensors must have the expected types and shapes. - - Returns: - A dict of logical name to Tensor, as provided by the output signature of - the transform graph. 
- """ - - unfed_input_keys = self._get_unfed_input_keys(logical_input_map.keys()) - feeds = self._get_feeds(unfed_input_keys) - modified_inputs = self._format_input_map_as_tensors(logical_input_map) - - if unfed_input_keys: - batch_size = 1 - if logical_input_map: - an_input = next(iter(logical_input_map.values())) - if isinstance(an_input, tf.RaggedTensor): - batch_size = an_input.bounding_shape(axis=0) - elif tf.shape(an_input)[0] is not None: - batch_size = tf.shape(an_input)[0] - - missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size) - modified_inputs.update(missing_inputs) - - flattened_inputs = tf.nest.flatten(modified_inputs, expand_composites=True) - - # self._wrapped_function.inputs may be longer than flattened_inputs as it - # also contains captured inputs. However, we only want the user inputs here - # so we don't assert equal length. - for input_t, wrapped_input in zip(flattened_inputs, - self._wrapped_function.inputs): - try: - wrapped_input.shape.assert_is_compatible_with(input_t.shape) - except ValueError as e: - raise ValueError('{}: {}'.format(input_t, e)) - - transformed_features = self._wrapped_function(*flattened_inputs) - fetches_keys = self._get_fetches_keys(feeds) - return {key: transformed_features[key] for key in fetches_keys} - - def apply_transform_model( - self, logical_input_map: Mapping[str, common_types.TensorType] - ) -> Dict[str, common_types.TensorType]: - """Applies a transform graph to dictionary of (Composite)Tensors. - Args: - logical_input_map: a dict of logical name to Tensor. The logical names - must be a subset of those in the input signature of the transform graph, - and the corresponding Tensors must have the expected types and shapes. - - Returns: - A dict of logical name to Tensor, as provided by the output signature of - the transform graph. - """ - unexpected_inputs = ( - set(logical_input_map.keys()) - set(self._structured_inputs.keys())) - if unexpected_inputs: - raise ValueError( - 'Unexpected inputs to transform: {}'.format(unexpected_inputs)) - - if self.load_v2_in_compat: - return self._apply_v1_transform_model_in_v2(logical_input_map) - elif self._is_finalized: - return self._apply_v2_transform_model_finalized(logical_input_map) - else: - return self._apply_v2_transform_model(logical_input_map) - - def _finalize_wrapped_function( - self, unfed_input_keys: Iterable[str], - fetches_keys: Iterable[str]) -> function.ConcreteFunction: - """Constructs a function that can be invoked without `unfed_input_keys`.""" - original_input_signature = ( - self._wrapped_function.structured_input_signature[0][0]) - input_signature = { - k: v - for k, v in original_input_signature.items() - if k not in unfed_input_keys - } - - @tf.function(input_signature=[input_signature], autograph=False) - def wrapped_finalized(inputs): - missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size=1) - # Directly modifying inputs is not allowed in a tf.function. Hence, we - # make a deep copy here. - inputs_copy = tf_utils.copy_tensors(inputs) - inputs_copy.update(missing_inputs) - flattened_inputs = tf.nest.flatten(inputs_copy, expand_composites=True) - transformed_features = self._wrapped_function(*flattened_inputs) - return {key: transformed_features[key] for key in fetches_keys} - - return wrapped_finalized.get_concrete_function() - - # TODO(b/177672051): Consider calling finalize in the TransformFeaturesLayer. 
- def finalize(self, input_tensor_keys: Iterable[str], - output_tensor_keys: Iterable[str]): - """Finalizes the set of inputs with which this SavedModel will be called. - - Note: This is not Thread-safe. Should be called prior to any calls to - `apply_transform_model`. - - Args: - input_tensor_keys: Set of input keys with which the SavedModel will be - called. - output_tensor_keys: Set of output keys that should be returned by the - SavedModel. - """ - self._sorted_unfed_input_keys = sorted( - self._get_unfed_input_keys(input_tensor_keys)) - feeds = self._get_feeds(self._sorted_unfed_input_keys) - unexpected_outputs = ( - set(output_tensor_keys) - set(self._get_fetches_keys(feeds))) - if unexpected_outputs: - raise ValueError( - 'Unexpected output keys requested: {}'.format(unexpected_outputs)) - self._wrapped_function_finalized = self._finalize_wrapped_function( - self._sorted_unfed_input_keys, sorted(output_tensor_keys)) - self._is_finalized = True +class SavedModelLoader: + """Handles a SavedModel exported using TF 1.x APIs in TF 2.x.""" + + def __init__(self, saved_model_dir: str): + """Init method for SavedModelLoader. + + Args: + ---- + saved_model_dir: A SavedModel directory providing a transform graph. The + MetaGraphDef and signature are selected from the SavedModel using keys + defined in `../constants.py` ('transform' and 'transform_signature', + respectively). + """ + imported = tf.saved_model.load(saved_model_dir) + load_v2_in_compat = constants.TRANSFORM_SIGNATURE in imported.signatures + if load_v2_in_compat: + restored_function = imported.signatures[constants.TRANSFORM_SIGNATURE] + wrapped, structured_inputs, structured_outputs = ( + _restore_from_v1_saved_model(restored_function, saved_model_dir) + ) + else: + # transform_fn is now a ConcreteFunction, but was a tf.function. We need + # to handle both to maintain backward compatiblity. If it's a tf.function, + # since `input_signature` was specified when exporting the tf function to + # `SavedModel`, there should be exactly one concrete function present on + # loading the `SavedModel`. 
+ if hasattr(imported.transform_fn, "concrete_functions"): + concrete_functions = imported.transform_fn.concrete_functions + assert len(concrete_functions) == 1, concrete_functions + wrapped = concrete_functions[0] + else: + wrapped = imported.transform_fn + func_graph = wrapped.graph + structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( + func_graph + ) + structured_outputs = tf.nest.pack_sequence_as( + func_graph.structured_outputs, + func_graph.outputs, + expand_composites=True, + ) + outputs_to_inputs_map = _get_output_to_inputs_map(structured_outputs) + self._initialize( + load_v2_in_compat, + imported, + wrapped, + structured_inputs, + structured_outputs, + outputs_to_inputs_map, + ) + saved_transform_io._maybe_register_addon_ops() # pylint: disable=protected-access + + def _initialize( + self, + load_v2_in_compat, + imported, + wrapped, + structured_inputs, + structured_outputs, + outputs_to_inputs_map, + ): + """Initializes all class arguments.""" + self._load_v2_in_compat = load_v2_in_compat + self._imported = imported + self._wrapped_function = wrapped + self._func_graph = self._wrapped_function.graph + self._structured_inputs = structured_inputs + self._structured_outputs = structured_outputs + self._output_to_inputs_map = outputs_to_inputs_map + self._sorted_unfed_input_keys = None + self._wrapped_function_finalized = None + self._is_finalized = False + + @property + def load_v2_in_compat(self): + return self._load_v2_in_compat + + @property + def structured_outputs(self): + return self._structured_outputs + + def _get_feeds(self, unfed_input_keys: Iterable[str]) -> Iterable[tf.Tensor]: + """Returns set of tensors that will be fed.""" + result = object_identity.ObjectIdentitySet(self._func_graph.inputs) + for input_key in unfed_input_keys: + unfed_input_components = _get_component_tensors( + self._structured_inputs[input_key] + ) + result = result.difference(unfed_input_components) + return result + + def _get_unfed_input_keys(self, input_tensor_keys: Iterable[str]) -> Iterable[str]: + return set(self._structured_inputs.keys()).difference(input_tensor_keys) + + def _get_fetches( + self, feeds: Iterable[tf.Tensor] + ) -> Dict[str, common_types.TensorType]: + """Returns set of tensors that can be fetched when `feeds` is supplied.""" + result = {} + for name, output in self._structured_outputs.items(): + extra_sources = self._output_to_inputs_map[name].difference(feeds) + # If output does not depend on an input placeholder that is not being fed, + # add it to fetches. + if not extra_sources.difference(self._func_graph.internal_captures): + result[name] = output + return result + + def _get_fetches_keys(self, feeds: Iterable[tf.Tensor]) -> Iterable[str]: + return self._get_fetches(feeds).keys() + + def _get_missing_inputs( + self, unfed_input_keys: Iterable[str], batch_size: int + ) -> Dict[str, common_types.TensorType]: + """Supplies inputs for `unfed_input_keys`.""" + result = {} + if unfed_input_keys: + result = tf2_utils.supply_missing_inputs( + self._structured_inputs, batch_size, unfed_input_keys + ) + return result + + def _apply_v1_transform_model_in_v2( + self, logical_input_map: Mapping[str, common_types.TensorType] + ) -> Dict[str, common_types.TensorType]: + """Applies a V1 transform graph to dictionary of (Composite)Tensors. + + This method applies the transformation graph as a pruned function to the + `logical_input_map`. 
+ It prunes the function loaded from the SavedModel to return only outputs + that can be computed from the keys provided in `logical_input_map`. + + Args: + ---- + logical_input_map: a dict of logical name to Tensor. The logical names + must be a subset of those in the input signature of the transform graph, + and the corresponding Tensors must have the expected types and shapes. + + Returns: + ------- + A dict of logical name to Tensor, as provided by the output signature of + the transform graph. + """ + input_map = saved_transform_io._expand_input_map( # pylint: disable=protected-access + logical_input_map, self._structured_inputs + ) + + feeds = [] + pruned_input_args = [] + for name in input_map: + tensor = self._func_graph.get_tensor_by_name(name) + try: + tensor.shape.assert_is_compatible_with(input_map[name].shape) + except ValueError as e: + raise ValueError(f"{name}: {e}") + feeds.append(tensor) + pruned_input_args.append(input_map[name]) + + fetches = self._get_fetches(feeds) + pruned = self._wrapped_function.prune(feeds, fetches) + result = pruned(*pruned_input_args) + # TODO(b/163329414): Remove set_shape when calling pruned no longer produces + # tensors with unknown shapes. + for name, output in fetches.items(): + if hasattr(result[name], "set_shape"): + result[name].set_shape(output.shape) + return result + + def _format_input_map_as_tensors(self, input_map): + """Returns a map from string to `tf.Tensor` or `CompositeTensor`.""" + result = {} + for key, value in input_map.items(): + if isinstance(value, (tf.Tensor, composite_tensor.CompositeTensor)): + result[key] = value + else: + result[key] = tf.convert_to_tensor(value) + return result + + def _apply_v2_transform_model_finalized( + self, logical_input_map: Mapping[str, common_types.TensorType] + ) -> Dict[str, common_types.TensorType]: + """Applies a V2 transform graph to dictionary of (Composite)Tensors. + + This method applies the transformation graph to the `logical_input_map` to + return only outputs that can be computed from the keys provided in + `logical_input_map`. It assumes that self.finalize has been called before + this method is invoked. + + Args: + ---- + logical_input_map: a dict of logical name to Tensor. The logical names + must be a subset of those in the input signature of the transform graph, + and the corresponding Tensors must have the expected types and shapes. + + Returns: + ------- + A dict of logical name to Tensor, as provided by the output signature of + the transform graph. + """ + # Assert that the same keys are fed as this model was finalized with. + unfed_input_keys = self._get_unfed_input_keys(logical_input_map.keys()) + assert sorted(unfed_input_keys) == self._sorted_unfed_input_keys + + modified_inputs = self._format_input_map_as_tensors(logical_input_map) + return self._wrapped_function_finalized(modified_inputs) + + def _apply_v2_transform_model( + self, logical_input_map: Mapping[str, common_types.TensorType] + ) -> Dict[str, common_types.TensorType]: + """Applies a V2 transform graph to dictionary of (Composite)Tensors. + + This method applies the transformation graph to the `logical_input_map` to + return only outputs that can be computed from the keys provided in + `logical_input_map`. + + Args: + ---- + logical_input_map: a dict of logical name to Tensor. The logical names + must be a subset of those in the input signature of the transform graph, + and the corresponding Tensors must have the expected types and shapes. 
+ + Returns: + ------- + A dict of logical name to Tensor, as provided by the output signature of + the transform graph. + """ + unfed_input_keys = self._get_unfed_input_keys(logical_input_map.keys()) + feeds = self._get_feeds(unfed_input_keys) + modified_inputs = self._format_input_map_as_tensors(logical_input_map) + + if unfed_input_keys: + batch_size = 1 + if logical_input_map: + an_input = next(iter(logical_input_map.values())) + if isinstance(an_input, tf.RaggedTensor): + batch_size = an_input.bounding_shape(axis=0) + elif tf.shape(an_input)[0] is not None: + batch_size = tf.shape(an_input)[0] + + missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size) + modified_inputs.update(missing_inputs) + + flattened_inputs = tf.nest.flatten(modified_inputs, expand_composites=True) + + # self._wrapped_function.inputs may be longer than flattened_inputs as it + # also contains captured inputs. However, we only want the user inputs here + # so we don't assert equal length. + for input_t, wrapped_input in zip( + flattened_inputs, self._wrapped_function.inputs + ): + try: + wrapped_input.shape.assert_is_compatible_with(input_t.shape) + except ValueError as e: + raise ValueError(f"{input_t}: {e}") + + transformed_features = self._wrapped_function(*flattened_inputs) + fetches_keys = self._get_fetches_keys(feeds) + return {key: transformed_features[key] for key in fetches_keys} + + def apply_transform_model( + self, logical_input_map: Mapping[str, common_types.TensorType] + ) -> Dict[str, common_types.TensorType]: + """Applies a transform graph to dictionary of (Composite)Tensors. + + Args: + ---- + logical_input_map: a dict of logical name to Tensor. The logical names + must be a subset of those in the input signature of the transform graph, + and the corresponding Tensors must have the expected types and shapes. + + Returns: + ------- + A dict of logical name to Tensor, as provided by the output signature of + the transform graph. + """ + unexpected_inputs = set(logical_input_map.keys()) - set( + self._structured_inputs.keys() + ) + if unexpected_inputs: + raise ValueError(f"Unexpected inputs to transform: {unexpected_inputs}") + + if self.load_v2_in_compat: + return self._apply_v1_transform_model_in_v2(logical_input_map) + elif self._is_finalized: + return self._apply_v2_transform_model_finalized(logical_input_map) + else: + return self._apply_v2_transform_model(logical_input_map) + + def _finalize_wrapped_function( + self, unfed_input_keys: Iterable[str], fetches_keys: Iterable[str] + ) -> function.ConcreteFunction: + """Constructs a function that can be invoked without `unfed_input_keys`.""" + original_input_signature = self._wrapped_function.structured_input_signature[0][ + 0 + ] + input_signature = { + k: v + for k, v in original_input_signature.items() + if k not in unfed_input_keys + } + + @tf.function(input_signature=[input_signature], autograph=False) + def wrapped_finalized(inputs): + missing_inputs = self._get_missing_inputs(unfed_input_keys, batch_size=1) + # Directly modifying inputs is not allowed in a tf.function. Hence, we + # make a deep copy here. + inputs_copy = tf_utils.copy_tensors(inputs) + inputs_copy.update(missing_inputs) + flattened_inputs = tf.nest.flatten(inputs_copy, expand_composites=True) + transformed_features = self._wrapped_function(*flattened_inputs) + return {key: transformed_features[key] for key in fetches_keys} + + return wrapped_finalized.get_concrete_function() + + # TODO(b/177672051): Consider calling finalize in the TransformFeaturesLayer. 
+ def finalize( + self, input_tensor_keys: Iterable[str], output_tensor_keys: Iterable[str] + ): + """Finalizes the set of inputs with which this SavedModel will be called. + + Note: This is not Thread-safe. Should be called prior to any calls to + `apply_transform_model`. + + Args: + ---- + input_tensor_keys: Set of input keys with which the SavedModel will be + called. + output_tensor_keys: Set of output keys that should be returned by the + SavedModel. + """ + self._sorted_unfed_input_keys = sorted( + self._get_unfed_input_keys(input_tensor_keys) + ) + feeds = self._get_feeds(self._sorted_unfed_input_keys) + unexpected_outputs = set(output_tensor_keys) - set( + self._get_fetches_keys(feeds) + ) + if unexpected_outputs: + raise ValueError(f"Unexpected output keys requested: {unexpected_outputs}") + self._wrapped_function_finalized = self._finalize_wrapped_function( + self._sorted_unfed_input_keys, sorted(output_tensor_keys) + ) + self._is_finalized = True # TODO(b/177606209): Remove once TF supports saving optimized functions. # TODO(b/129646028): WrappedFunction.prune does not support composite tensors. # Hence, add additional handling when supporting composite tensors in TFT. def optimize_concrete_function( - concrete_function: function.ConcreteFunction, - strip_control_dependencies: bool) -> wrap_function.WrappedFunction: - """Returns optimized function with same signature as `concrete_function`.""" - wrapped_fn = wrap_function.WrappedFunction( - concrete_function.graph, - variable_holder=wrap_function.VariableHolder(share_variables=True)) - fetches = concrete_function.structured_outputs - if strip_control_dependencies: - flat_outputs, _ = tf2_utils.strip_and_get_tensors_and_control_dependencies( - tf.nest.flatten(fetches, expand_composites=True)) - fetches = tf.nest.pack_sequence_as( - concrete_function.structured_outputs, - flat_outputs, - expand_composites=True) - result = wrapped_fn.prune( - feeds=concrete_function.inputs, - fetches=fetches, - input_signature=concrete_function.structured_input_signature) - # TODO(b/163329414): Remove once `prune` retains shape information for all - # components. - for original_out, pruned_out in zip(concrete_function.outputs, - result.outputs): - pruned_out.set_shape(original_out.get_shape()) - return result + concrete_function: function.ConcreteFunction, strip_control_dependencies: bool +) -> wrap_function.WrappedFunction: + """Returns optimized function with same signature as `concrete_function`.""" + wrapped_fn = wrap_function.WrappedFunction( + concrete_function.graph, + variable_holder=wrap_function.VariableHolder(share_variables=True), + ) + fetches = concrete_function.structured_outputs + if strip_control_dependencies: + flat_outputs, _ = tf2_utils.strip_and_get_tensors_and_control_dependencies( + tf.nest.flatten(fetches, expand_composites=True) + ) + fetches = tf.nest.pack_sequence_as( + concrete_function.structured_outputs, flat_outputs, expand_composites=True + ) + result = wrapped_fn.prune( + feeds=concrete_function.inputs, + fetches=fetches, + input_signature=concrete_function.structured_input_signature, + ) + # TODO(b/163329414): Remove once `prune` retains shape information for all + # components. 
+ for original_out, pruned_out in zip(concrete_function.outputs, result.outputs): + pruned_out.set_shape(original_out.get_shape()) + return result def trace_and_update_module( @@ -466,67 +505,74 @@ def trace_and_update_module( name: str, strip_control_dependencies: bool, ) -> function.ConcreteFunction: - """Traces `tf_function` and saves under attr `name` of `module`. - - Args: - module: A saveable module which will contain the traced `tf_function` under - attr `name`. - tf_function: A tf.function to trace. - name: A name to same the traced `tf_function` to. - strip_control_dependencies: Boolean. If True, automatic control dependencies - will be stripped from the outputs of `tf_function`. This should almost - always be False. It is useful only if you want to use the structure of the - TF graph to perform any graph manipulations. - - Returns: - The concrete function obtained from tracing `tf_function`. - """ - resource_tracker = tracking.ResourceTracker() - object_tracker = annotators.ObjectTracker() - created_variables = [] - - def _variable_creator(next_creator, **kwargs): - var = next_creator(**kwargs) - created_variables.append(var) - return var - - # Trace `tf_function` to gather any resources in it using the - # resource_tracker. These are then assigned to `module.resources` and tracked - # before exporting to SavedModel. - with tracking.resource_tracker_scope(resource_tracker), \ - annotators.object_tracker_scope(object_tracker), \ - tf.variable_creator_scope(_variable_creator): - concrete_fn = tf_function.get_concrete_function() - - # Prior to 2020/10/08, saving a tf.function with a concrete function signature - # would ensure that the function was not re-traced in a round-trip to a - # SavedModel. Since this is no longer the case, we save the concrete function - # directly. - if tf.compat.forward_compatible(2020, 10, 8): - pruned_function = optimize_concrete_function(concrete_fn, - strip_control_dependencies) - module.pruned_variables = pruned_function.variables - setattr(module, name, pruned_function) - else: - setattr(module, name, tf_function) - - # Any variables created need to be explicitly tracked. - module.created_variables = created_variables - # Resources need to be explicitly tracked. - module.resources = resource_tracker.resources - module.trackable_objects = object_tracker.trackable_objects - # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the - # table should be sufficient. - initializers = [] - for resource in module.resources: - if isinstance(resource, lookup_ops.InitializableLookupTableBase): - initializers.append(resource._initializer) # pylint: disable=protected-access - module.initializers = initializers - module.assets = [ - tf.saved_model.Asset(asset_filepath) for asset_filepath in - concrete_fn.graph.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS) - ] - return concrete_fn + """Traces `tf_function` and saves under attr `name` of `module`. + + Args: + ---- + module: A saveable module which will contain the traced `tf_function` under + attr `name`. + tf_function: A tf.function to trace. + name: A name to same the traced `tf_function` to. + strip_control_dependencies: Boolean. If True, automatic control dependencies + will be stripped from the outputs of `tf_function`. This should almost + always be False. It is useful only if you want to use the structure of the + TF graph to perform any graph manipulations. + + Returns: + ------- + The concrete function obtained from tracing `tf_function`. 
+ """ + resource_tracker = tracking.ResourceTracker() + object_tracker = annotators.ObjectTracker() + created_variables = [] + + def _variable_creator(next_creator, **kwargs): + var = next_creator(**kwargs) + created_variables.append(var) + return var + + # Trace `tf_function` to gather any resources in it using the + # resource_tracker. These are then assigned to `module.resources` and tracked + # before exporting to SavedModel. + with tracking.resource_tracker_scope( + resource_tracker + ), annotators.object_tracker_scope(object_tracker), tf.variable_creator_scope( + _variable_creator + ): + concrete_fn = tf_function.get_concrete_function() + + # Prior to 2020/10/08, saving a tf.function with a concrete function signature + # would ensure that the function was not re-traced in a round-trip to a + # SavedModel. Since this is no longer the case, we save the concrete function + # directly. + if tf.compat.forward_compatible(2020, 10, 8): + pruned_function = optimize_concrete_function( + concrete_fn, strip_control_dependencies + ) + module.pruned_variables = pruned_function.variables + setattr(module, name, pruned_function) + else: + setattr(module, name, tf_function) + + # Any variables created need to be explicitly tracked. + module.created_variables = created_variables + # Resources need to be explicitly tracked. + module.resources = resource_tracker.resources + module.trackable_objects = object_tracker.trackable_objects + # TODO(b/158011374) - Stop explicitly tracking initializers. Tracking the + # table should be sufficient. + initializers = [] + for resource in module.resources: + if isinstance(resource, lookup_ops.InitializableLookupTableBase): + initializers.append(resource._initializer) # pylint: disable=protected-access + module.initializers = initializers + module.assets = [ + tf.saved_model.Asset(asset_filepath) + for asset_filepath in concrete_fn.graph.get_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS + ) + ] + return concrete_fn def write_v2_saved_model( @@ -536,13 +582,13 @@ def write_v2_saved_model( saved_model_dir: str, save_options: Optional[tf.saved_model.SaveOptions] = None, ) -> function.ConcreteFunction: - """Writes `tf_function` under attr `name` of `module` to `saved_model_dir`.""" - concrete_fn = trace_and_update_module( - module, tf_function, name, strip_control_dependencies=False - ) - tf.saved_model.save( - module, - saved_model_dir, - options=save_options, - ) - return concrete_fn + """Writes `tf_function` under attr `name` of `module` to `saved_model_dir`.""" + concrete_fn = trace_and_update_module( + module, tf_function, name, strip_control_dependencies=False + ) + tf.saved_model.save( + module, + saved_model_dir, + options=save_options, + ) + return concrete_fn diff --git a/tensorflow_transform/saved/saved_transform_io_v2_test.py b/tensorflow_transform/saved/saved_transform_io_v2_test.py index c6d2a06..9e73d25 100644 --- a/tensorflow_transform/saved/saved_transform_io_v2_test.py +++ b/tensorflow_transform/saved/saved_transform_io_v2_test.py @@ -19,652 +19,699 @@ import numpy as np import tensorflow as tf -from tensorflow_transform import graph_context -from tensorflow_transform import impl_helper -from tensorflow_transform import tf_utils -from tensorflow_transform import test_case -from tensorflow_transform.py_func.api import apply_pyfunc -from tensorflow_transform.saved import constants -from tensorflow_transform.saved import saved_transform_io -from tensorflow_transform.saved import saved_transform_io_v2 # pylint: disable=g-direct-tensorflow-import 
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.ops import script_ops + +from tensorflow_transform import graph_context, impl_helper, test_case, tf_utils +from tensorflow_transform.py_func.api import apply_pyfunc +from tensorflow_transform.saved import ( + constants, + saved_transform_io, + saved_transform_io_v2, +) + # pylint: enable=g-direct-tensorflow-import _TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES = [ - dict(testcase_name='_exported_in_tf1', exported_in_tf1=True), - dict(testcase_name='_exported_in_tf2', exported_in_tf1=False) + dict(testcase_name="_exported_in_tf1", exported_in_tf1=True), + dict(testcase_name="_exported_in_tf2", exported_in_tf1=False), ] def _get_preprocessing_fn_asset_table(asset_file): + def construct_table(asset_path): + initializer = tf.lookup.TextFileInitializer( + asset_path, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + return tf.lookup.StaticHashTable(initializer, default_value=-1) - def construct_table(asset_path): - initializer = tf.lookup.TextFileInitializer( - asset_path, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - return tf.lookup.StaticHashTable(initializer, default_value=-1) - - def preprocessing_fn(inputs): - output, unused_table_size = tf_utils.construct_and_lookup_table( - construct_table, asset_file, inputs['input']) - return {'output': output} + def preprocessing_fn(inputs): + output, unused_table_size = tf_utils.construct_and_lookup_table( + construct_table, asset_file, inputs["input"] + ) + return {"output": output} - return preprocessing_fn + return preprocessing_fn def _get_preprocessing_fn_non_asset_table(asset_file): - del asset_file + del asset_file - def preprocessing_fn(inputs): - initializer = tf.lookup.KeyValueTensorInitializer( - keys=['foo', 'bar', 'baz'], - values=tf.cast(tf.range(3), tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - table = tf.lookup.StaticHashTable(initializer, default_value=12) - return { - 'output': table.lookup(inputs['input']), - } + def preprocessing_fn(inputs): + initializer = tf.lookup.KeyValueTensorInitializer( + keys=["foo", "bar", "baz"], + values=tf.cast(tf.range(3), tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=12) + return { + "output": table.lookup(inputs["input"]), + } - return preprocessing_fn + return preprocessing_fn _RE_EXPORT_TF2_TO_TF1_TEST_CASES = [ dict( - testcase_name='_asset_table', + testcase_name="_asset_table", preprocessing_fn_getter=_get_preprocessing_fn_asset_table, expected_output=2, - test_input='baz', - asset_file_contents='foo\nbar\nbaz\n'), + test_input="baz", + asset_file_contents="foo\nbar\nbaz\n", + ), dict( - testcase_name='_non_asset_table', + testcase_name="_non_asset_table", preprocessing_fn_getter=_get_preprocessing_fn_non_asset_table, expected_output=2, - test_input='baz'), + test_input="baz", + ), ] # TODO(b/123241798): Find an open-source compatible way to access # FLAGS.test_tmpdir. 
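# A minimal usage sketch of the SavedModelLoader API shown above, assuming a
# transform SavedModel has already been exported; the export directory below is
# a hypothetical placeholder (these tests obtain it from _create_test_saved_model).
# With the example preprocessing_fn (x - 2.0) / 5.0, feeding [1237.0] yields [247.0].
import tensorflow as tf

from tensorflow_transform.saved import saved_transform_io_v2

loader = saved_transform_io_v2.SavedModelLoader("/tmp/transform_fn/export")
# finalize() is optional; it pins the fed input keys and fetched output keys
# before repeated calls (e.g. inside tf.data.Dataset.map).
loader.finalize(input_tensor_keys=["x"], output_tensor_keys=["x_scaled"])
outputs = loader.apply_transform_model({"x": tf.constant([1237.0])})
assert list(outputs) == ["x_scaled"]  # x_scaled == [247.0] for this preprocessing_fn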
-def _create_test_saved_model(export_in_tf1, - input_specs, - preprocessing_fn, - export_path_suffix=None, - base_dir=None): - if not export_path_suffix: - export_path = os.path.join(tempfile.mkdtemp(dir=base_dir), 'export') - else: - export_path = os.path.join( - tempfile.mkdtemp(dir=base_dir), export_path_suffix) - if export_in_tf1: - with tf.compat.v1.Graph().as_default(): - with tf.compat.v1.Session().as_default() as session: - inputs = {} - for key in input_specs: - tensor_spec = input_specs[key] - if isinstance(tensor_spec, tf.TensorSpec): - inputs[key] = tf.compat.v1.placeholder( - tensor_spec.dtype, shape=tensor_spec.shape +def _create_test_saved_model( + export_in_tf1, input_specs, preprocessing_fn, export_path_suffix=None, base_dir=None +): + if not export_path_suffix: + export_path = os.path.join(tempfile.mkdtemp(dir=base_dir), "export") + else: + export_path = os.path.join(tempfile.mkdtemp(dir=base_dir), export_path_suffix) + if export_in_tf1: + with tf.compat.v1.Graph().as_default(): + with tf.compat.v1.Session().as_default() as session: + inputs = {} + for key in input_specs: + tensor_spec = input_specs[key] + if isinstance(tensor_spec, tf.TensorSpec): + inputs[key] = tf.compat.v1.placeholder( + tensor_spec.dtype, shape=tensor_spec.shape + ) + elif isinstance(tensor_spec, tf.SparseTensorSpec): + inputs[key] = tf.compat.v1.sparse_placeholder( + tensor_spec.dtype, shape=tensor_spec.shape + ) + elif isinstance(tensor_spec, tf.RaggedTensorSpec): + inputs[key] = tf.compat.v1.ragged.placeholder( + tensor_spec._dtype, tensor_spec._ragged_rank, [] + ) + else: + raise ValueError( + "TypeSpecs specified should be one of `tf.TensorSpec`, " + "`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`" + ) + outputs = preprocessing_fn(inputs) + # show that unrelated & unmapped placeholders do not interfere + tf.compat.v1.placeholder(tf.int64) + saved_transform_io.write_saved_transform_from_session( + session, inputs, outputs, export_path + ) + else: + module = tf.Module() + tf_graph_context = graph_context.TFGraphContext( + module_to_export=module, temp_dir=None, evaluated_replacements=None + ) + transform_fn = impl_helper.get_traced_transform_fn( + preprocessing_fn=preprocessing_fn, + input_signature=input_specs, + tf_graph_context=tf_graph_context, + output_keys_to_name_map=None, + ) + + saved_transform_io_v2.write_v2_saved_model( + module, transform_fn, "transform_fn", export_path, None + ) + return export_path + + +class SavedTransformIOV2Test(test_case.TransformTestCase): + @classmethod + def setUpClass(cls): + test_case.skip_if_not_tf2("Tensorflow 2.x required.") + input_specs = { + "x": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, ) - elif isinstance(tensor_spec, tf.SparseTensorSpec): - inputs[key] = tf.compat.v1.sparse_placeholder( - tensor_spec.dtype, shape=tensor_spec.shape + } + + def preprocessing_fn(inputs): + output = (inputs["x"] - 2.0) / 5.0 + return {"x_scaled": output} + + cls._saved_model_path_v1 = _create_test_saved_model( + True, input_specs, preprocessing_fn, "export_v1" + ) + cls._saved_model_path_v2 = _create_test_saved_model( + False, input_specs, preprocessing_fn, "export_v2" + ) + + def _get_saved_model_loader(self, exported_in_tf1): + if exported_in_tf1: + return saved_transform_io_v2.SavedModelLoader(self._saved_model_path_v1) + return saved_transform_io_v2.SavedModelLoader(self._saved_model_path_v2) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_saved_transform(self, exported_in_tf1): + input_floats = 
tf.constant([1237.0]) # tf.float32 + input_features = {"x": input_floats} + transformed_features = self._get_saved_model_loader( + exported_in_tf1 + ).apply_transform_model(input_features) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertIsInstance(result_tensor, tf.Tensor) + self.assertAllEqual(result_tensor.numpy(), [247.0]) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_saved_transform_dataset_map(self, exported_in_tf1): + ds = tf.data.Dataset.from_tensor_slices({"x": [[1237.0]]}) + model_loader = self._get_saved_model_loader(exported_in_tf1) + + def map_fn(inputs): + result = model_loader.apply_transform_model(inputs) + self.assertEqual(["x_scaled"], list(result)) + result_tensor = result["x_scaled"] + self.assertIsInstance(result_tensor, tf.Tensor) + self.assertEqual(result_tensor.shape.as_list(), [1]) + return result + + result_ds = ds.map(map_fn) + self.assertAllEqual( + list(result_ds.as_numpy_iterator()), [{"x_scaled": [247.0]}] + ) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_transform_extra_features_no_passthrough(self, exported_in_tf1): + with self.assertRaises(ValueError): + input_floats = tf.constant([1237.0]) # tf.float32 + input_features = { + "x": input_floats, + "extra_1": tf.constant("1"), + "extra_2": tf.constant("2"), + } + self._get_saved_model_loader(exported_in_tf1).apply_transform_model( + input_features ) - elif isinstance(tensor_spec, tf.RaggedTensorSpec): - inputs[key] = tf.compat.v1.ragged.placeholder( - tensor_spec._dtype, tensor_spec._ragged_rank, [] + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_transform_type_mismatch(self, exported_in_tf1): + with self.assertRaises(tf.errors.InvalidArgumentError): + input_strings = tf.constant(["bogus"]) # tf.string + input_features = {"x": input_strings} + self._get_saved_model_loader(exported_in_tf1).apply_transform_model( + input_features ) - else: - raise ValueError( - 'TypeSpecs specified should be one of `tf.TensorSpec`, ' - '`tf.SparseTensorSpec`, `tf.RaggedTensorSpec`' + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_transform_shape_mismatch(self, exported_in_tf1): + with self.assertRaises(ValueError): + input_floats = tf.constant(1237.0) # tf.float32 + input_features = {"x": input_floats} + self._get_saved_model_loader(exported_in_tf1).apply_transform_model( + input_features ) - outputs = preprocessing_fn(inputs) - # show that unrelated & unmapped placeholders do not interfere - tf.compat.v1.placeholder(tf.int64) - saved_transform_io.write_saved_transform_from_session( - session, inputs, outputs, export_path + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_saved_transform_to_tensor_inside_scope(self, exported_in_tf1): + with tf.compat.v1.name_scope("my_scope"): + input_floats = tf.constant([1237.0]) # tf.float32 + input_features = {"x": input_floats} + transformed_features = self._get_saved_model_loader( + exported_in_tf1 + ).apply_transform_model(input_features) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertIsInstance(result_tensor, tf.Tensor) + self.assertAllEqual(result_tensor.numpy(), [247.0]) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_apply_saved_transform_to_tensor_outside_scope(self, 
exported_in_tf1): + input_floats = tf.constant([1237.0]) # tf.float32 + with tf.compat.v1.name_scope("my_scope"): + input_features = {"x": input_floats} + transformed_features = self._get_saved_model_loader( + exported_in_tf1 + ).apply_transform_model(input_features) + self.assertEqual(["x_scaled"], list(transformed_features)) + result_tensor = transformed_features["x_scaled"] + self.assertIsInstance(result_tensor, tf.Tensor) + self.assertAllEqual(result_tensor.numpy(), [247.0]) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_dense_roundtrip(self, exported_in_tf1): + input_specs = {"input": tf.TensorSpec([], dtype=tf.float32)} + + def preprocessing_fn(inputs): + return {"output": inputs["input"] / 5.0} + + export_path = _create_test_saved_model( + exported_in_tf1, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() ) - else: - module = tf.Module() - tf_graph_context = graph_context.TFGraphContext( - module_to_export=module, temp_dir=None, evaluated_replacements=None - ) - transform_fn = impl_helper.get_traced_transform_fn( - preprocessing_fn=preprocessing_fn, - input_signature=input_specs, - tf_graph_context=tf_graph_context, - output_keys_to_name_map=None, - ) - saved_transform_io_v2.write_v2_saved_model( - module, transform_fn, 'transform_fn', export_path, None - ) - return export_path + # Using a computed input gives confidence that the graphs are fused. + input_float = tf.constant(25.0) * 2 + inputs = {"input": input_float} + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + outputs = saved_model_loader.apply_transform_model(inputs) + # (25 * 2) / 5 = 10 + self.assertEqual(10.0, outputs["output"].numpy()) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_table_roundtrip(self, exported_in_tf1): + input_specs = {"input": tf.TensorSpec([], dtype=tf.string)} + + def preprocessing_fn(inputs): + table_keys = ["cat", "dog", "giraffe"] + initializer = tf.lookup.KeyValueTensorInitializer( + keys=table_keys, + values=tf.cast(tf.range(len(table_keys)), tf.int64), + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=-1) + return {"output": table.lookup(inputs["input"])} + export_path = _create_test_saved_model( + exported_in_tf1, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) -class SavedTransformIOV2Test(test_case.TransformTestCase): + # Using a computed input gives confidence that the graphs are fused. 
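# A condensed sketch of the Dataset.map pattern from `test_apply_saved_transform_dataset_map`
# above; `scale_fn` is a stand-in for `SavedModelLoader.apply_transform_model` and simply
# applies the same (x - 2) / 5 transform used by the test fixture.
import tensorflow as tf


def scale_fn(features):
    # Stand-in for the traced transform model.
    return {"x_scaled": (features["x"] - 2.0) / 5.0}


ds = tf.data.Dataset.from_tensor_slices({"x": [[1237.0]]})
result_ds = ds.map(scale_fn)
print(list(result_ds.as_numpy_iterator()))  # [{'x_scaled': array([247.], dtype=float32)}]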
+ input_string = tf.constant("dog") + inputs = {"input": input_string} + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + outputs = saved_model_loader.apply_transform_model(inputs) + self.assertEqual(1, outputs["output"].numpy()) - @classmethod - def setUpClass(cls): - test_case.skip_if_not_tf2('Tensorflow 2.x required.') - input_specs = { - 'x': tf.TensorSpec([ - None, - ], dtype=tf.float32) - } + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_sparse_roundtrip(self, exported_in_tf1): + input_specs = { + "input": tf.SparseTensorSpec([None, None, None], dtype=tf.float32) + } - def preprocessing_fn(inputs): - output = (inputs['x'] - 2.0) / 5.0 - return {'x_scaled': output} - - cls._saved_model_path_v1 = _create_test_saved_model(True, input_specs, - preprocessing_fn, - 'export_v1') - cls._saved_model_path_v2 = _create_test_saved_model(False, input_specs, - preprocessing_fn, - 'export_v2') - - def _get_saved_model_loader(self, exported_in_tf1): - if exported_in_tf1: - return saved_transform_io_v2.SavedModelLoader(self._saved_model_path_v1) - return saved_transform_io_v2.SavedModelLoader(self._saved_model_path_v2) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_saved_transform(self, exported_in_tf1): - input_floats = tf.constant([1237.0]) # tf.float32 - input_features = {'x': input_floats} - transformed_features = ( - self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertIsInstance(result_tensor, tf.Tensor) - self.assertAllEqual(result_tensor.numpy(), [247.0]) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_saved_transform_dataset_map(self, exported_in_tf1): - ds = tf.data.Dataset.from_tensor_slices({'x': [[1237.0]]}) - model_loader = self._get_saved_model_loader(exported_in_tf1) - - def map_fn(inputs): - result = model_loader.apply_transform_model(inputs) - self.assertEqual(['x_scaled'], list(result)) - result_tensor = result['x_scaled'] - self.assertIsInstance(result_tensor, tf.Tensor) - self.assertEqual(result_tensor.shape.as_list(), [1]) - return result - - result_ds = ds.map(map_fn) - self.assertAllEqual( - list(result_ds.as_numpy_iterator()), [{ - 'x_scaled': [247.0] - }]) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_transform_extra_features_no_passthrough(self, exported_in_tf1): - with self.assertRaises(ValueError): - input_floats = tf.constant([1237.0]) # tf.float32 - input_features = { - 'x': input_floats, - 'extra_1': tf.constant('1'), - 'extra_2': tf.constant('2') - } - self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_transform_type_mismatch(self, exported_in_tf1): - with self.assertRaises(tf.errors.InvalidArgumentError): - input_strings = tf.constant(['bogus']) # tf.string - input_features = {'x': input_strings} - self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_transform_shape_mismatch(self, exported_in_tf1): - with self.assertRaises(ValueError): - input_floats = tf.constant(1237.0) # tf.float32 - input_features = {'x': input_floats} - 
self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_saved_transform_to_tensor_inside_scope(self, exported_in_tf1): - with tf.compat.v1.name_scope('my_scope'): - input_floats = tf.constant([1237.0]) # tf.float32 - input_features = {'x': input_floats} - transformed_features = ( - self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertIsInstance(result_tensor, tf.Tensor) - self.assertAllEqual(result_tensor.numpy(), [247.0]) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_apply_saved_transform_to_tensor_outside_scope(self, exported_in_tf1): - input_floats = tf.constant([1237.0]) # tf.float32 - with tf.compat.v1.name_scope('my_scope'): - input_features = {'x': input_floats} - transformed_features = ( - self._get_saved_model_loader(exported_in_tf1).apply_transform_model( - input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - result_tensor = transformed_features['x_scaled'] - self.assertIsInstance(result_tensor, tf.Tensor) - self.assertAllEqual(result_tensor.numpy(), [247.0]) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_dense_roundtrip(self, exported_in_tf1): - input_specs = {'input': tf.TensorSpec([], dtype=tf.float32)} + def preprocessing_fn(inputs): + return {"output": inputs["input"] / 5.0} - def preprocessing_fn(inputs): - return {'output': inputs['input'] / 5.0} - - export_path = _create_test_saved_model( - exported_in_tf1, - input_specs, - preprocessing_fn, - base_dir=self.get_temp_dir()) - - # Using a computed input gives confidence that the graphs are fused. - input_float = tf.constant(25.0) * 2 - inputs = {'input': input_float} - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - outputs = saved_model_loader.apply_transform_model(inputs) - # (25 * 2) / 5 = 10 - self.assertEqual(10.0, outputs['output'].numpy()) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_table_roundtrip(self, exported_in_tf1): - input_specs = {'input': tf.TensorSpec([], dtype=tf.string)} + export_path = _create_test_saved_model( + exported_in_tf1, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) - def preprocessing_fn(inputs): - table_keys = ['cat', 'dog', 'giraffe'] - initializer = tf.lookup.KeyValueTensorInitializer( - keys=table_keys, - values=tf.cast(tf.range(len(table_keys)), tf.int64), - key_dtype=tf.string, - value_dtype=tf.int64) - table = tf.lookup.StaticHashTable(initializer, default_value=-1) - return {'output': table.lookup(inputs['input'])} - - export_path = _create_test_saved_model( - exported_in_tf1, - input_specs, - preprocessing_fn, - base_dir=self.get_temp_dir()) - - # Using a computed input gives confidence that the graphs are fused. 
- input_string = tf.constant('dog') - inputs = {'input': input_string} - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - outputs = saved_model_loader.apply_transform_model(inputs) - self.assertEqual(1, outputs['output'].numpy()) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_sparse_roundtrip(self, exported_in_tf1): - input_specs = { - 'input': tf.SparseTensorSpec([None, None, None], dtype=tf.float32) - } + indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64) + values = np.array([1.0, 2.0], dtype=np.float32) + shape = np.array([7, 9, 2], dtype=np.int64) + input_sparse = tf.SparseTensor( + indices=indices, values=values, dense_shape=shape + ) - def preprocessing_fn(inputs): - return {'output': inputs['input'] / 5.0} - - export_path = _create_test_saved_model( - exported_in_tf1, - input_specs, - preprocessing_fn, - base_dir=self.get_temp_dir()) - - indices = np.array([[3, 2, 0], [4, 5, 1]], dtype=np.int64) - values = np.array([1.0, 2.0], dtype=np.float32) - shape = np.array([7, 9, 2], dtype=np.int64) - input_sparse = tf.SparseTensor( - indices=indices, values=values, dense_shape=shape) - - # Using a computed input gives confidence that the graphs are fused - inputs = {'input': input_sparse * 10} - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - outputs = saved_model_loader.apply_transform_model(inputs) - result = outputs['output'] - self.assertIsInstance(result, tf.SparseTensor) - - # indices and shape unchanged; values multiplied by 10 and divided by 5 - self.assertEqual(indices.tolist(), result.indices.numpy().tolist()) - self.assertEqual([2.0, 4.0], result.values.numpy().tolist()) - self.assertEqual(shape.tolist(), result.dense_shape.numpy().tolist()) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_ragged_roundtrip(self, exported_in_tf1): - if not hasattr(meta_graph_pb2.TensorInfo, 'CompositeTensor'): - self.skipTest('This version of TensorFlow does not support ' - 'CompositeTenors in TensorInfo.') - input_specs = { - 'input': - tf.RaggedTensorSpec( + # Using a computed input gives confidence that the graphs are fused + inputs = {"input": input_sparse * 10} + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + outputs = saved_model_loader.apply_transform_model(inputs) + result = outputs["output"] + self.assertIsInstance(result, tf.SparseTensor) + + # indices and shape unchanged; values multiplied by 10 and divided by 5 + self.assertEqual(indices.tolist(), result.indices.numpy().tolist()) + self.assertEqual([2.0, 4.0], result.values.numpy().tolist()) + self.assertEqual(shape.tolist(), result.dense_shape.numpy().tolist()) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_ragged_roundtrip(self, exported_in_tf1): + if not hasattr(meta_graph_pb2.TensorInfo, "CompositeTensor"): + self.skipTest( + "This version of TensorFlow does not support " + "CompositeTenors in TensorInfo." 
+ ) + input_specs = { + "input": tf.RaggedTensorSpec( shape=[None, None], dtype=tf.float32, ragged_rank=1, - row_splits_dtype=tf.int64) - } + row_splits_dtype=tf.int64, + ) + } - def preprocessing_fn(inputs): - return {'output': inputs['input'] / 2.0} - - export_path = _create_test_saved_model( - exported_in_tf1, - input_specs, - preprocessing_fn, - base_dir=self.get_temp_dir()) - - splits = np.array([0, 2, 3], dtype=np.int64) - values = np.array([1.0, 2.0, 4.0], dtype=np.float32) - input_ragged = tf.RaggedTensor.from_row_splits(values, splits) - - # Using a computed input gives confidence that the graphs are fused - inputs = {'input': input_ragged * 10} - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - outputs = saved_model_loader.apply_transform_model(inputs) - result = outputs['output'] - self.assertIsInstance(result, tf.RaggedTensor) - - # indices and shape unchanged; values multipled by 10 and divided by 2 - self.assertAllEqual(splits, result.row_splits) - self.assertEqual([5.0, 10.0, 20.0], result.values.numpy().tolist()) - - @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) - def test_ragged_with_unfed(self, exported_in_tf1): - input_specs = { - 'x': tf.RaggedTensorSpec([ - None, - None, - ], dtype=tf.float32), - 'y': tf.RaggedTensorSpec([ - None, - ], dtype=tf.float32) - } + def preprocessing_fn(inputs): + return {"output": inputs["input"] / 2.0} - def preprocessing_fn(inputs): - output = (inputs['x'] - 2.0) / 5.0 - return {'x_scaled': output, 'x_in': inputs['x'], 'y': inputs['y'] + 1} - - export_path = _create_test_saved_model( - exported_in_tf1, - input_specs, - preprocessing_fn, - base_dir=self.get_temp_dir()) - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - - # Missing 'y'. - input_features = {'x': tf.ragged.constant([[1237.0]], ragged_rank=1)} - transformed_features = ( - saved_model_loader.apply_transform_model(input_features)) - self.assertCountEqual(['x_in', 'x_scaled'], list(transformed_features)) - self.assertAllEqual(transformed_features['x_scaled'].numpy(), [[247.0]]) - self.assertAllEqual(transformed_features['x_in'].numpy(), [[1237.0]]) - - @test_case.named_parameters(*_RE_EXPORT_TF2_TO_TF1_TEST_CASES) - def test_re_export_tf2_saved_model_to_tf1(self, - preprocessing_fn_getter, - expected_output, - test_input, - asset_file_contents=None): - - asset_file = None - if asset_file_contents is not None: - asset_file_path = os.path.join( - tempfile.mkdtemp(dir=self.get_temp_dir()), 'asset') - file_io.write_string_to_file(asset_file_path, asset_file_contents) - asset_file = tf.constant(asset_file_path) - - input_specs = {'input': tf.TensorSpec([], dtype=tf.string)} - export_path = _create_test_saved_model( - False, - input_specs, - preprocessing_fn_getter(asset_file), - base_dir=self.get_temp_dir()) - - if asset_file is not None: - os.remove(asset_file.numpy()) - new_export_path = os.path.join( - tempfile.mkdtemp(dir=self.get_temp_dir()), 'export_v1') - - builder = tf.compat.v1.saved_model.builder.SavedModelBuilder( - new_export_path) - # TODO(b/175844561): Investigate why the variable names need to be different - # for the two graph and session contexts below. 
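# A small sketch of the ragged input used by the ragged round-trip test in this region:
# a RaggedTensor is assembled from flat values plus row splits, and elementwise math
# (here the same *10 then /2 the test applies) operates on the values while leaving the
# row partitioning untouched.
import numpy as np

import tensorflow as tf

splits = np.array([0, 2, 3], dtype=np.int64)
values = np.array([1.0, 2.0, 4.0], dtype=np.float32)
ragged = tf.RaggedTensor.from_row_splits(values, splits)  # [[1.0, 2.0], [4.0]]

scaled = (ragged * 10) / 2.0
print(scaled.to_list())  # [[5.0, 10.0], [20.0]]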
- with tf.compat.v1.Graph().as_default() as g1: - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - if asset_file_contents is not None: - self.assertEqual( - 1, len(g1.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS))) - with tf.compat.v1.Session().as_default() as s1: - inputs = {'input': tf.compat.v1.placeholder(tf.string)} + export_path = _create_test_saved_model( + exported_in_tf1, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) + + splits = np.array([0, 2, 3], dtype=np.int64) + values = np.array([1.0, 2.0, 4.0], dtype=np.float32) + input_ragged = tf.RaggedTensor.from_row_splits(values, splits) + + # Using a computed input gives confidence that the graphs are fused + inputs = {"input": input_ragged * 10} + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) outputs = saved_model_loader.apply_transform_model(inputs) - predict_signature_def = ( - tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( - inputs, outputs)) - builder.add_meta_graph_and_variables( - s1, ['graph_tag'], - signature_def_map={'graph_signature': predict_signature_def}, - assets_collection=tf.compat.v1.get_collection( - tf.compat.v1.GraphKeys.ASSET_FILEPATHS), - main_op=tf.compat.v1.tables_initializer()) - builder.save() - - shutil.rmtree(export_path) - - with tf.compat.v1.Graph().as_default() as g2: - with tf.compat.v1.Session().as_default() as s2: - meta_graph_def = tf.compat.v1.saved_model.loader.load( - s2, ['graph_tag'], new_export_path) - signature = meta_graph_def.signature_def['graph_signature'] - output = s2.run( - g2.get_tensor_by_name(signature.outputs['output'].name), - feed_dict={ - g2.get_tensor_by_name(signature.inputs['input'].name): - test_input - }) - self.assertEqual(expected_output, output) + result = outputs["output"] + self.assertIsInstance(result, tf.RaggedTensor) + + # indices and shape unchanged; values multipled by 10 and divided by 2 + self.assertAllEqual(splits, result.row_splits) + self.assertEqual([5.0, 10.0, 20.0], result.values.numpy().tolist()) + + @test_case.named_parameters(*_TRANFORM_FN_EXPORT_TF_VERSION_TEST_CASES) + def test_ragged_with_unfed(self, exported_in_tf1): + input_specs = { + "x": tf.RaggedTensorSpec( + [ + None, + None, + ], + dtype=tf.float32, + ), + "y": tf.RaggedTensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + } + + def preprocessing_fn(inputs): + output = (inputs["x"] - 2.0) / 5.0 + return {"x_scaled": output, "x_in": inputs["x"], "y": inputs["y"] + 1} + + export_path = _create_test_saved_model( + exported_in_tf1, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + + # Missing 'y'. 
+ input_features = {"x": tf.ragged.constant([[1237.0]], ragged_rank=1)} + transformed_features = saved_model_loader.apply_transform_model(input_features) + self.assertCountEqual(["x_in", "x_scaled"], list(transformed_features)) + self.assertAllEqual(transformed_features["x_scaled"].numpy(), [[247.0]]) + self.assertAllEqual(transformed_features["x_in"].numpy(), [[1237.0]]) + + @test_case.named_parameters(*_RE_EXPORT_TF2_TO_TF1_TEST_CASES) + def test_re_export_tf2_saved_model_to_tf1( + self, + preprocessing_fn_getter, + expected_output, + test_input, + asset_file_contents=None, + ): + asset_file = None if asset_file_contents is not None: - self.assertEqual( - 1, len(g2.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS))) + asset_file_path = os.path.join( + tempfile.mkdtemp(dir=self.get_temp_dir()), "asset" + ) + file_io.write_string_to_file(asset_file_path, asset_file_contents) + asset_file = tf.constant(asset_file_path) + + input_specs = {"input": tf.TensorSpec([], dtype=tf.string)} + export_path = _create_test_saved_model( + False, + input_specs, + preprocessing_fn_getter(asset_file), + base_dir=self.get_temp_dir(), + ) - def test_stale_asset_collections_are_cleaned(self): - vocabulary_file = os.path.join( - tempfile.mkdtemp(dir=self.get_temp_dir()), 'asset') - file_io.write_string_to_file(vocabulary_file, 'foo bar baz') + if asset_file is not None: + os.remove(asset_file.numpy()) + new_export_path = os.path.join( + tempfile.mkdtemp(dir=self.get_temp_dir()), "export_v1" + ) - input_specs = {'input': tf.TensorSpec([], dtype=tf.string)} + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(new_export_path) + # TODO(b/175844561): Investigate why the variable names need to be different + # for the two graph and session contexts below. + with tf.compat.v1.Graph().as_default() as g1: + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + if asset_file_contents is not None: + self.assertEqual( + 1, len(g1.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)) + ) + with tf.compat.v1.Session().as_default() as s1: + inputs = {"input": tf.compat.v1.placeholder(tf.string)} + outputs = saved_model_loader.apply_transform_model(inputs) + predict_signature_def = ( + tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + inputs, outputs + ) + ) + builder.add_meta_graph_and_variables( + s1, + ["graph_tag"], + signature_def_map={"graph_signature": predict_signature_def}, + assets_collection=tf.compat.v1.get_collection( + tf.compat.v1.GraphKeys.ASSET_FILEPATHS + ), + main_op=tf.compat.v1.tables_initializer(), + ) + builder.save() + + shutil.rmtree(export_path) + + with tf.compat.v1.Graph().as_default() as g2: + with tf.compat.v1.Session().as_default() as s2: + meta_graph_def = tf.compat.v1.saved_model.loader.load( + s2, ["graph_tag"], new_export_path + ) + signature = meta_graph_def.signature_def["graph_signature"] + output = s2.run( + g2.get_tensor_by_name(signature.outputs["output"].name), + feed_dict={ + g2.get_tensor_by_name( + signature.inputs["input"].name + ): test_input + }, + ) + self.assertEqual(expected_output, output) + if asset_file_contents is not None: + self.assertEqual( + 1, + len(g2.get_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS)), + ) + + def test_stale_asset_collections_are_cleaned(self): + vocabulary_file = os.path.join( + tempfile.mkdtemp(dir=self.get_temp_dir()), "asset" + ) + file_io.write_string_to_file(vocabulary_file, "foo bar baz") - def preprocessing_fn(inputs): - initializer = tf.lookup.TextFileInitializer( - 
vocabulary_file, - key_dtype=tf.string, - key_index=tf.lookup.TextFileIndex.WHOLE_LINE, - value_dtype=tf.int64, - value_index=tf.lookup.TextFileIndex.LINE_NUMBER) - table = tf.lookup.StaticHashTable(initializer, default_value=12) - return {'output': table.lookup(inputs['input'])} - - export_path = _create_test_saved_model( - False, input_specs, preprocessing_fn, base_dir=self.get_temp_dir()) - - # Load it and save it again repeatedly, verifying that the asset collections - # remain valid. - for it in [1, 2, 3]: - input_string = tf.constant('dog') - inputs = {'input': input_string} - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - outputs = saved_model_loader.apply_transform_model(inputs) - self.assertEqual(12, outputs['output']) - - new_export_path = os.path.join( - tempfile.mkdtemp(dir=self.get_temp_dir()), 'export_' + str(it)) - tf.saved_model.save(saved_model_loader._imported, new_export_path) - shutil.rmtree(export_path) - export_path = new_export_path - - def test_finalize(self): - input_keys = ['x'] - output_keys = ['x_scaled'] - - input_specs = { - 'x': tf.TensorSpec([ - None, - ], dtype=tf.float32), - 'y': tf.TensorSpec([ - None, - ], dtype=tf.float32) - } + input_specs = {"input": tf.TensorSpec([], dtype=tf.string)} - def preprocessing_fn(inputs): - output = (inputs['x'] - 2.0) / 5.0 - return {'x_scaled': output, 'x_in': inputs['x'], 'y': inputs['y'] + 1} - - export_path = _create_test_saved_model( - False, input_specs, preprocessing_fn, base_dir=self.get_temp_dir()) - saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) - - input_features = {'x': tf.constant([1237.0])} # tf.float32 - transformed_features = ( - saved_model_loader.apply_transform_model(input_features)) - self.assertCountEqual(['x_in', 'x_scaled'], list(transformed_features)) - self.assertAllEqual(transformed_features['x_scaled'].numpy(), [247.0]) - self.assertAllEqual(transformed_features['x_in'].numpy(), [1237.0]) - - # Since `finalize` is not thread-safe it is not recommended to call it after - # `apply_transform_model` has already been invoked. This is only for unit - # testing behavior differences. 
- saved_model_loader.finalize(input_keys, output_keys) - transformed_features = ( - saved_model_loader.apply_transform_model(input_features)) - self.assertEqual(['x_scaled'], list(transformed_features)) - self.assertAllEqual(transformed_features['x_scaled'].numpy(), [247.0]) - - @test_case.named_parameters( - dict( - testcase_name='_strip_control_dependencies', - strip_control_dependencies=True), - dict( - testcase_name='_keep_control_dependencies', - strip_control_dependencies=False)) - def test_optimize_concrete_function(self, strip_control_dependencies): - - @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)]) - def func(x): - z = x + 2 - with tf.init_scope(): - initializer = tf.lookup.KeyValueTensorInitializer([0, 1, 2], - ['a', 'b', 'c'], - key_dtype=tf.int64, - value_dtype=tf.string) - table = tf.lookup.StaticHashTable(initializer, default_value='NAN') - _ = table.lookup(x) - return z - - concrete_function = func.get_concrete_function() - optimized_function = saved_transform_io_v2.optimize_concrete_function( - concrete_function, - strip_control_dependencies=strip_control_dependencies) - output = optimized_function(tf.constant(0, tf.int64)) - self.assertEqual(output, 2) - - if strip_control_dependencies: - self.assertLess( - len(optimized_function.graph.as_graph_def().node), - len(concrete_function.graph.as_graph_def().node)) - else: - self.assertEqual( - len(optimized_function.graph.as_graph_def().node), - len(concrete_function.graph.as_graph_def().node)) - - def test_restore_from_v1_saved_model_with_pyfuncs(self): - input_specs = { - 'a': tf.TensorSpec([ - None, - ], dtype=tf.float32), - 'b': tf.TensorSpec([ - None, - ], dtype=tf.float32), - } - - def my_add(x, y): - return x + y - - def func(inputs): - result = { - 'a+b': - apply_pyfunc(my_add, tf.float32, True, 'add', inputs['a'], - inputs['b']) - } - for value in result.values(): - value.set_shape([1]) - return result - - saved_model_path_v1 = _create_test_saved_model(True, input_specs, func, - 'export_v1') - # Clear PyFunc registry to mimic loading a SavedModel in a new runtime. 
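# A brief sketch of why the py_func test in this region clears the registry: ops created
# via tf.py_function / tf.compat.v1.py_func call back into a Python function registered
# in the current process, so a SavedModel containing them generally cannot run in a fresh
# runtime where that callback was never registered (clearing `script_ops._py_funcs`
# simulates exactly that). The snippet below only shows the eager behaviour of such an op
# with standard TF APIs; the tft `apply_pyfunc` wrapper used above is not reproduced here.
import tensorflow as tf


def my_add(x, y):
    return x + y


out = tf.py_function(
    my_add, inp=[tf.constant([2.0]), tf.constant([3.0])], Tout=tf.float32
)
print(out.numpy())  # [5.]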
- script_ops._py_funcs._funcs.clear() # pylint: disable=protected-access - - imported = tf.compat.v2.saved_model.load(saved_model_path_v1) - imported_function = imported.signatures[constants.TRANSFORM_SIGNATURE] - input_keys = ['a', 'b'] - inputs = [ - tf.constant([2.0], dtype=tf.float32), - tf.constant([3.0], dtype=tf.float32) - ] - input_kwargs = {k: v for k, v in zip(input_keys, inputs)} - expected_output = 5.0 - restored_function, _, _ = ( - saved_transform_io_v2._restore_from_v1_saved_model( - imported_function, saved_model_path_v1)) - with self.assertRaisesRegex(tf.errors.InvalidArgumentError, - 'callback.*pyfunc_'): - imported_function(**input_kwargs) - self.assertEqual(restored_function(*inputs)['a+b'], expected_output) - - def test_restore_from_v1_saved_model_without_pyfuncs(self): - input_specs = { - 'a': tf.TensorSpec([ - None, - ], dtype=tf.float32), - 'b': tf.TensorSpec([ - None, - ], dtype=tf.float32), - } - - def func(inputs): - result = {'a+b': inputs['a'] + inputs['b']} - for value in result.values(): - value.set_shape([1]) - return result - - saved_model_path_v1 = _create_test_saved_model(True, input_specs, func, - 'export_v1') - - imported = tf.compat.v2.saved_model.load(saved_model_path_v1) - imported_function = imported.signatures[constants.TRANSFORM_SIGNATURE] - input_kwargs = { - 'a': tf.constant([2.0], dtype=tf.float32), - 'b': tf.constant([3.0], dtype=tf.float32) - } - expected_output = 5.0 - restored_function, _, _ = ( - saved_transform_io_v2._restore_from_v1_saved_model( - imported_function, saved_model_path_v1)) - self.assertEqual(imported_function(**input_kwargs)['a+b'], expected_output) - self.assertEqual(restored_function(**input_kwargs)['a+b'], expected_output) - - -if __name__ == '__main__': - test_case.main() + def preprocessing_fn(inputs): + initializer = tf.lookup.TextFileInitializer( + vocabulary_file, + key_dtype=tf.string, + key_index=tf.lookup.TextFileIndex.WHOLE_LINE, + value_dtype=tf.int64, + value_index=tf.lookup.TextFileIndex.LINE_NUMBER, + ) + table = tf.lookup.StaticHashTable(initializer, default_value=12) + return {"output": table.lookup(inputs["input"])} + + export_path = _create_test_saved_model( + False, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) + + # Load it and save it again repeatedly, verifying that the asset collections + # remain valid. 
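# A self-contained sketch (illustrative module and key names) of the load-and-resave
# cycle that `test_stale_asset_collections_are_cleaned` performs: a SavedModel holding a
# lookup table is loaded and saved again repeatedly, and the table must keep working
# after every round trip.
import tempfile

import tensorflow as tf

module = tf.Module()
module.table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(["cat", "dog"], tf.constant([0, 1], tf.int64)),
    default_value=-1,
)
module.lookup = tf.function(
    lambda keys: module.table.lookup(keys),
    input_signature=[tf.TensorSpec([None], tf.string)],
)

export_dir = tempfile.mkdtemp()
tf.saved_model.save(module, export_dir)

for _ in range(2):
    loaded = tf.saved_model.load(export_dir)
    assert loaded.lookup(tf.constant(["dog"])).numpy()[0] == 1
    export_dir = tempfile.mkdtemp()
    tf.saved_model.save(loaded, export_dir)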
+ for it in [1, 2, 3]: + input_string = tf.constant("dog") + inputs = {"input": input_string} + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + outputs = saved_model_loader.apply_transform_model(inputs) + self.assertEqual(12, outputs["output"]) + + new_export_path = os.path.join( + tempfile.mkdtemp(dir=self.get_temp_dir()), "export_" + str(it) + ) + tf.saved_model.save(saved_model_loader._imported, new_export_path) + shutil.rmtree(export_path) + export_path = new_export_path + + def test_finalize(self): + input_keys = ["x"] + output_keys = ["x_scaled"] + + input_specs = { + "x": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + "y": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + } + + def preprocessing_fn(inputs): + output = (inputs["x"] - 2.0) / 5.0 + return {"x_scaled": output, "x_in": inputs["x"], "y": inputs["y"] + 1} + + export_path = _create_test_saved_model( + False, input_specs, preprocessing_fn, base_dir=self.get_temp_dir() + ) + saved_model_loader = saved_transform_io_v2.SavedModelLoader(export_path) + + input_features = {"x": tf.constant([1237.0])} # tf.float32 + transformed_features = saved_model_loader.apply_transform_model(input_features) + self.assertCountEqual(["x_in", "x_scaled"], list(transformed_features)) + self.assertAllEqual(transformed_features["x_scaled"].numpy(), [247.0]) + self.assertAllEqual(transformed_features["x_in"].numpy(), [1237.0]) + + # Since `finalize` is not thread-safe it is not recommended to call it after + # `apply_transform_model` has already been invoked. This is only for unit + # testing behavior differences. + saved_model_loader.finalize(input_keys, output_keys) + transformed_features = saved_model_loader.apply_transform_model(input_features) + self.assertEqual(["x_scaled"], list(transformed_features)) + self.assertAllEqual(transformed_features["x_scaled"].numpy(), [247.0]) + + @test_case.named_parameters( + dict( + testcase_name="_strip_control_dependencies", strip_control_dependencies=True + ), + dict( + testcase_name="_keep_control_dependencies", strip_control_dependencies=False + ), + ) + def test_optimize_concrete_function(self, strip_control_dependencies): + @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)]) + def func(x): + z = x + 2 + with tf.init_scope(): + initializer = tf.lookup.KeyValueTensorInitializer( + [0, 1, 2], + ["a", "b", "c"], + key_dtype=tf.int64, + value_dtype=tf.string, + ) + table = tf.lookup.StaticHashTable(initializer, default_value="NAN") + _ = table.lookup(x) + return z + + concrete_function = func.get_concrete_function() + optimized_function = saved_transform_io_v2.optimize_concrete_function( + concrete_function, strip_control_dependencies=strip_control_dependencies + ) + output = optimized_function(tf.constant(0, tf.int64)) + self.assertEqual(output, 2) + + if strip_control_dependencies: + self.assertLess( + len(optimized_function.graph.as_graph_def().node), + len(concrete_function.graph.as_graph_def().node), + ) + else: + self.assertEqual( + len(optimized_function.graph.as_graph_def().node), + len(concrete_function.graph.as_graph_def().node), + ) + + def test_restore_from_v1_saved_model_with_pyfuncs(self): + input_specs = { + "a": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + "b": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + } + + def my_add(x, y): + return x + y + + def func(inputs): + result = { + "a+b": apply_pyfunc( + my_add, tf.float32, True, "add", inputs["a"], inputs["b"] + ) + } + for value in result.values(): + 
value.set_shape([1]) + return result + + saved_model_path_v1 = _create_test_saved_model( + True, input_specs, func, "export_v1" + ) + # Clear PyFunc registry to mimic loading a SavedModel in a new runtime. + script_ops._py_funcs._funcs.clear() # pylint: disable=protected-access + + imported = tf.compat.v2.saved_model.load(saved_model_path_v1) + imported_function = imported.signatures[constants.TRANSFORM_SIGNATURE] + input_keys = ["a", "b"] + inputs = [ + tf.constant([2.0], dtype=tf.float32), + tf.constant([3.0], dtype=tf.float32), + ] + input_kwargs = {k: v for k, v in zip(input_keys, inputs)} + expected_output = 5.0 + restored_function, _, _ = saved_transform_io_v2._restore_from_v1_saved_model( + imported_function, saved_model_path_v1 + ) + with self.assertRaisesRegex( + tf.errors.InvalidArgumentError, "callback.*pyfunc_" + ): + imported_function(**input_kwargs) + self.assertEqual(restored_function(*inputs)["a+b"], expected_output) + + def test_restore_from_v1_saved_model_without_pyfuncs(self): + input_specs = { + "a": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + "b": tf.TensorSpec( + [ + None, + ], + dtype=tf.float32, + ), + } + + def func(inputs): + result = {"a+b": inputs["a"] + inputs["b"]} + for value in result.values(): + value.set_shape([1]) + return result + + saved_model_path_v1 = _create_test_saved_model( + True, input_specs, func, "export_v1" + ) + + imported = tf.compat.v2.saved_model.load(saved_model_path_v1) + imported_function = imported.signatures[constants.TRANSFORM_SIGNATURE] + input_kwargs = { + "a": tf.constant([2.0], dtype=tf.float32), + "b": tf.constant([3.0], dtype=tf.float32), + } + expected_output = 5.0 + restored_function, _, _ = saved_transform_io_v2._restore_from_v1_saved_model( + imported_function, saved_model_path_v1 + ) + self.assertEqual(imported_function(**input_kwargs)["a+b"], expected_output) + self.assertEqual(restored_function(**input_kwargs)["a+b"], expected_output) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/schema_inference.py b/tensorflow_transform/schema_inference.py index d6fe81c..c507d95 100644 --- a/tensorflow_transform/schema_inference.py +++ b/tensorflow_transform/schema_inference.py @@ -22,52 +22,61 @@ import collections import itertools from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union -from absl import logging import tensorflow as tf -from tensorflow_transform import common -from tensorflow_transform import common_types -from tensorflow_transform import graph_context -from tensorflow_transform import tf2_utils -from tensorflow_transform import tf_utils -from tensorflow_transform.saved import saved_transform_io_v2 -from tensorflow_transform.tf_metadata import schema_utils -from tfx_bsl.tfxio import tensor_representation_util - +from absl import logging from google.protobuf import any_pb2 + # pylint: disable=g-direct-tensorflow-import from tensorflow.python.eager import function from tensorflow.python.framework import ops + # pylint: enable=g-direct-tensorflow-import from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx_bsl.tfxio import tensor_representation_util + +from tensorflow_transform import ( + common, + common_types, + graph_context, + tf2_utils, + tf_utils, +) +from tensorflow_transform.saved import saved_transform_io_v2 +from tensorflow_transform.tf_metadata import schema_utils -SPARSE_VALUES_NAME_TEMPLATE = '{tensor_name}$sparse_values' -SPARSE_INDICES_NAME_TEMPLATE = '{tensor_name}$sparse_indices_{index}' 
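# A sketch of what a ragged feature spec of the kind `_ragged_feature_spec_from_batched_tensor`
# below produces means at parse time: the flat values and row lengths live under separate
# keys (following the "$ragged_values"/"$row_lengths_<i>" naming convention) and
# tf.io.parse_example reassembles them into a batched RaggedTensor. The feature name "x"
# and the example contents are illustrative.
import tensorflow as tf

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "x$ragged_values": tf.train.Feature(
                float_list=tf.train.FloatList(value=[1.0, 2.0, 4.0])
            ),
            "x$row_lengths_1": tf.train.Feature(
                int64_list=tf.train.Int64List(value=[2, 1])
            ),
        }
    )
)

spec = {
    "x": tf.io.RaggedFeature(
        dtype=tf.float32,
        value_key="x$ragged_values",
        partitions=[tf.io.RaggedFeature.RowLengths(key="x$row_lengths_1")],
        row_splits_dtype=tf.int64,
    )
}
parsed = tf.io.parse_example(tf.constant([example.SerializeToString()]), spec)
print(parsed["x"].to_list())  # [[[1.0, 2.0], [4.0]]]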
+SPARSE_VALUES_NAME_TEMPLATE = "{tensor_name}$sparse_values" +SPARSE_INDICES_NAME_TEMPLATE = "{tensor_name}$sparse_indices_{index}" def _ragged_feature_spec_from_batched_tensor( - name: str, - tensor: tf.RaggedTensor) -> Union[tf.io.VarLenFeature, tf.io.RaggedFeature]: - """Infer `tf.io.RaggedFeature` from a batched `tf.RaggedTensor`.""" - partitions = [] - row_lengths_partition_idx = 1 - # Ignore batch dimension. - for dim in tensor.values.shape[1:]: - if dim or dim == 0: - partitions.append( - tf.io.RaggedFeature.UniformRowLength( # pytype: disable=attribute-error - length=dim)) - else: - partitions.append( - tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error - key='{}$row_lengths_{}'.format(name, row_lengths_partition_idx))) - row_lengths_partition_idx += 1 - - return tf.io.RaggedFeature( - dtype=tensor.dtype, - value_key='{}$ragged_values'.format(name), - partitions=partitions, - row_splits_dtype=tensor.row_splits.dtype) + name: str, tensor: tf.RaggedTensor +) -> Union[tf.io.VarLenFeature, tf.io.RaggedFeature]: + """Infer `tf.io.RaggedFeature` from a batched `tf.RaggedTensor`.""" + partitions = [] + row_lengths_partition_idx = 1 + # Ignore batch dimension. + for dim in tensor.values.shape[1:]: + if dim or dim == 0: + partitions.append( + tf.io.RaggedFeature.UniformRowLength( # pytype: disable=attribute-error + length=dim + ) + ) + else: + partitions.append( + tf.io.RaggedFeature.RowLengths( # pytype: disable=attribute-error + key=f"{name}$row_lengths_{row_lengths_partition_idx}" + ) + ) + row_lengths_partition_idx += 1 + + return tf.io.RaggedFeature( + dtype=tensor.dtype, + value_key=f"{name}$ragged_values", + partitions=partitions, + row_splits_dtype=tensor.row_splits.dtype, + ) def _feature_spec_from_batched_tensors( @@ -75,248 +84,257 @@ def _feature_spec_from_batched_tensors( is_evaluation_complete: bool, forced_sparse_keys: Iterable[str], ) -> Dict[str, common_types.FeatureSpecType]: - """Infer a feature spec from a dict of tensors. - - Args: - tensors: A dict whose keys are strings and values are `Tensor`, - `SparseTensor`, or `RaggedTensor`s. - is_evaluation_complete: A boolean indicating whether all analyzers have been - evaluated or not. - forced_sparse_keys: Keys of sparse tensors which should not be treated as - varlen. - - Returns: - A feature spec inferred from the types and shapes of the tensors. - - Raises: - ValueError: If the feature spec cannot be inferred. - TypeError: If any of the values of `tensors` are not a `Tensor`, - `SparseTensor`, or `RaggedTensor`. 
- """ - feature_spec = {} - for name, tensor in tensors.items(): - if tensor.dtype not in (tf.string, tf.int64, tf.float32): - raise ValueError( - 'Feature {} ({}) had invalid dtype {} for feature spec'.format( - name, tensor, tensor.dtype)) - if isinstance(tensor, tf.SparseTensor): - shape = tensor.get_shape() - if shape.ndims > 2 or name in forced_sparse_keys: - feature_spec[name] = tf.io.SparseFeature( - index_key=[ - SPARSE_INDICES_NAME_TEMPLATE.format( - tensor_name=name, index=idx) - for idx in range(shape.ndims - 1) - ], - value_key=SPARSE_VALUES_NAME_TEMPLATE.format(tensor_name=name), - dtype=tensor.dtype, - size=shape[1:], - already_sorted=True) - else: - feature_spec[name] = tf.io.VarLenFeature(tensor.dtype) - elif isinstance(tensor, tf.Tensor): - shape = tensor.get_shape() - if shape.ndims in [None, 0]: - raise ValueError( - 'Feature {} ({}) had invalid shape {} for FixedLenFeature: must ' - 'have rank at least 1'.format(name, tensor, shape)) - if is_evaluation_complete and any( - dim is None for dim in shape.as_list()[1:]): - raise ValueError( - 'Feature {} ({}) had invalid shape {} for FixedLenFeature: apart ' - 'from the batch dimension, all dimensions must have known size' - .format(name, tensor, shape)) - feature_spec[name] = tf.io.FixedLenFeature(shape.as_list()[1:], - tensor.dtype) - elif isinstance(tensor, tf.RaggedTensor): - feature_spec[name] = ( - _ragged_feature_spec_from_batched_tensor(name, tensor)) - else: - raise TypeError( - 'Expected a Tensor, SparseTensor, or RaggedTensor got {} of type {} ' - 'for feature {}'.format(tensor, type(tensor), name)) - - return feature_spec + """Infer a feature spec from a dict of tensors. + + Args: + ---- + tensors: A dict whose keys are strings and values are `Tensor`, + `SparseTensor`, or `RaggedTensor`s. + is_evaluation_complete: A boolean indicating whether all analyzers have been + evaluated or not. + forced_sparse_keys: Keys of sparse tensors which should not be treated as + varlen. + + Returns: + ------- + A feature spec inferred from the types and shapes of the tensors. + + Raises: + ------ + ValueError: If the feature spec cannot be inferred. + TypeError: If any of the values of `tensors` are not a `Tensor`, + `SparseTensor`, or `RaggedTensor`. 
+ """ + feature_spec = {} + for name, tensor in tensors.items(): + if tensor.dtype not in (tf.string, tf.int64, tf.float32): + raise ValueError( + f"Feature {name} ({tensor}) had invalid dtype {tensor.dtype} for feature spec" + ) + if isinstance(tensor, tf.SparseTensor): + shape = tensor.get_shape() + if shape.ndims > 2 or name in forced_sparse_keys: + feature_spec[name] = tf.io.SparseFeature( + index_key=[ + SPARSE_INDICES_NAME_TEMPLATE.format(tensor_name=name, index=idx) + for idx in range(shape.ndims - 1) + ], + value_key=SPARSE_VALUES_NAME_TEMPLATE.format(tensor_name=name), + dtype=tensor.dtype, + size=shape[1:], + already_sorted=True, + ) + else: + feature_spec[name] = tf.io.VarLenFeature(tensor.dtype) + elif isinstance(tensor, tf.Tensor): + shape = tensor.get_shape() + if shape.ndims in [None, 0]: + raise ValueError( + f"Feature {name} ({tensor}) had invalid shape {shape} for FixedLenFeature: must " + "have rank at least 1" + ) + if is_evaluation_complete and any( + dim is None for dim in shape.as_list()[1:] + ): + raise ValueError( + f"Feature {name} ({tensor}) had invalid shape {shape} for FixedLenFeature: apart " + "from the batch dimension, all dimensions must have known size" + ) + feature_spec[name] = tf.io.FixedLenFeature( + shape.as_list()[1:], tensor.dtype + ) + elif isinstance(tensor, tf.RaggedTensor): + feature_spec[name] = _ragged_feature_spec_from_batched_tensor(name, tensor) + else: + raise TypeError( + f"Expected a Tensor, SparseTensor, or RaggedTensor got {tensor} of type {type(tensor)} " + f"for feature {name}" + ) + + return feature_spec def _get_tensor_values(tensor: common_types.TensorType) -> tf.Tensor: - if isinstance(tensor, tf.SparseTensor): - return tensor.values - elif isinstance(tensor, tf.RaggedTensor): - return tensor.flat_values - else: - return tensor + if isinstance(tensor, tf.SparseTensor): + return tensor.values + elif isinstance(tensor, tf.RaggedTensor): + return tensor.flat_values + else: + return tensor def infer_feature_schema( features: Mapping[str, common_types.TensorType], graph: tf.compat.v1.Graph, - session: Optional[tf.compat.v1.Session] = None) -> schema_pb2.Schema: - """Given a dict of tensors, creates a `Schema`. - - Infers a schema, in the format of a tf.Transform `Schema`, for the given - dictionary of tensors. - - If there is an override specified, we override the inferred schema for the - given feature's tensor. An override has the meaning that we should set - is_categorical=True. If session is not provided then we just set - is_categorical=True, and if the session is provided then was also compute - values of the tensors representing the min and max values and set them in the - schema. - - If annotations have been specified, they are added to the output schema. - - Args: - features: A dict mapping column names to `Tensor`, `SparseTensor` or - `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have - a 0'th dimension which is interpreted as the batch dimension. - graph: A `tf.Graph` used to determine schema overrides. - session: (optional) A `tf.Session` used to compute schema overrides. If - None, schema overrides will not be computed. It is assumed that if all - analyzers have been evaluated, `session` is passed to this API. - - Returns: - A `Schema` proto. 
- """ - tensor_ranges = _get_tensor_ranges(graph) - if session is None: - tensor_ranges = {hashable: (None, None) for hashable in tensor_ranges} - tensor_annotations = {} - global_annotations = [] - else: - tensor_ranges = session.run(tensor_ranges) - tensor_annotations, global_annotations = _get_schema_annotations( - graph, session) - sparse_output_annotations = _get_sparse_output_annotations_v1(graph, session) - modified_sparse_output_annotations = {} - modified_tensor_ranges = {} - feature_annotations = {} - for name, tensor in features.items(): - hashable_values = tf_utils.hashable_tensor_or_op( - values := _get_tensor_values(tensor) + session: Optional[tf.compat.v1.Session] = None, +) -> schema_pb2.Schema: + """Given a dict of tensors, creates a `Schema`. + + Infers a schema, in the format of a tf.Transform `Schema`, for the given + dictionary of tensors. + + If there is an override specified, we override the inferred schema for the + given feature's tensor. An override has the meaning that we should set + is_categorical=True. If session is not provided then we just set + is_categorical=True, and if the session is provided then was also compute + values of the tensors representing the min and max values and set them in the + schema. + + If annotations have been specified, they are added to the output schema. + + Args: + ---- + features: A dict mapping column names to `Tensor`, `SparseTensor` or + `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have + a 0'th dimension which is interpreted as the batch dimension. + graph: A `tf.Graph` used to determine schema overrides. + session: (optional) A `tf.Session` used to compute schema overrides. If + None, schema overrides will not be computed. It is assumed that if all + analyzers have been evaluated, `session` is passed to this API. + + Returns: + ------- + A `Schema` proto. + """ + tensor_ranges = _get_tensor_ranges(graph) + if session is None: + tensor_ranges = {hashable: (None, None) for hashable in tensor_ranges} + tensor_annotations = {} + global_annotations = [] + else: + tensor_ranges = session.run(tensor_ranges) + tensor_annotations, global_annotations = _get_schema_annotations(graph, session) + sparse_output_annotations = _get_sparse_output_annotations_v1(graph, session) + modified_sparse_output_annotations = {} + modified_tensor_ranges = {} + feature_annotations = {} + for name, tensor in features.items(): + hashable_values = tf_utils.hashable_tensor_or_op( + values := _get_tensor_values(tensor) + ) + if hashable_values in tensor_ranges: + assert values.dtype == tf.int64 + modified_tensor_ranges[name] = tensor_ranges[hashable_values] + if hashable_values in sparse_output_annotations: + modified_sparse_output_annotations[name] = sparse_output_annotations[ + hashable_values + ] + feature_annotations[name] = tensor_annotations.get(hashable_values, []) + + return _infer_feature_schema_common( + features, + modified_tensor_ranges, + feature_annotations, + global_annotations, + modified_sparse_output_annotations, + # A session is passed to compute schema overrides which exist only if all + # analyzers have been evaluated. Hence, if `session` was provided to this + # API assume TFT's evaluation is complete. 
+ is_evaluation_complete=session is not None, ) - if hashable_values in tensor_ranges: - assert values.dtype == tf.int64 - modified_tensor_ranges[name] = tensor_ranges[hashable_values] - if hashable_values in sparse_output_annotations: - modified_sparse_output_annotations[name] = sparse_output_annotations[ - hashable_values - ] - feature_annotations[name] = tensor_annotations.get(hashable_values, []) - - return _infer_feature_schema_common( - features, - modified_tensor_ranges, - feature_annotations, - global_annotations, - modified_sparse_output_annotations, - # A session is passed to compute schema overrides which exist only if all - # analyzers have been evaluated. Hence, if `session` was provided to this - # API assume TFT's evaluation is complete. - is_evaluation_complete=session is not None, - ) def infer_feature_schema_v2( features: Mapping[str, common_types.TensorType], concrete_metadata_fn: function.ConcreteFunction, - evaluate_schema_overrides: bool) -> schema_pb2.Schema: - """Given a dict of tensors, creates a `Schema`. - - Infers a schema, in the format of a tf.Transform `Schema`, for the given - dictionary of tensors. - - If there is an override specified, we override the inferred schema for the - given feature's tensor. An override has the meaning that we should set - is_categorical=True. If evaluate_schema_overrides is False then we just set - is_categorical=True, and if evaluate_schema_overrides is True then we also - compute values of the tensors representing the min and max values and set them - in the schema. - - If annotations have been specified, they are added to the output schema. - - Args: - features: A dict mapping column names to `Tensor`, `SparseTensor` or - `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have - a 0'th dimension which is interpreted as the batch dimension. - concrete_metadata_fn: A `tf.ConcreteFunction` that returns a dictionary - containing the deferred annotations added to the graph when invoked with - any valid input. - evaluate_schema_overrides: A Boolean used to compute schema overrides. If - `False`, schema overrides will not be computed. - - Returns: - A `Schema` proto. - """ - metadata = collections.defaultdict(list, concrete_metadata_fn()) - - if not evaluate_schema_overrides: - tensor_ranges = { - tensor.numpy().decode(): (None, None) - for tensor in metadata[_TF_METADATA_TENSOR_COLLECTION] - } - tensor_annotations = {} - global_annotations = [] - else: - tensor_ranges = _get_tensor_ranges_v2(metadata) - tensor_annotations, global_annotations = _get_schema_annotations_v2( - metadata) - - def _get_metadata_sparse_output_annotations(metadata): - sparse_output_annotations = metadata[ - _METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD - ] - if not sparse_output_annotations: - return None - return { - k: [a.numpy() for a in v] - for k, v in sparse_output_annotations[0].items() - } + evaluate_schema_overrides: bool, +) -> schema_pb2.Schema: + """Given a dict of tensors, creates a `Schema`. + + Infers a schema, in the format of a tf.Transform `Schema`, for the given + dictionary of tensors. + + If there is an override specified, we override the inferred schema for the + given feature's tensor. An override has the meaning that we should set + is_categorical=True. If evaluate_schema_overrides is False then we just set + is_categorical=True, and if evaluate_schema_overrides is True then we also + compute values of the tensors representing the min and max values and set them + in the schema. 
+ + If annotations have been specified, they are added to the output schema. + + Args: + ---- + features: A dict mapping column names to `Tensor`, `SparseTensor` or + `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have + a 0'th dimension which is interpreted as the batch dimension. + concrete_metadata_fn: A `tf.ConcreteFunction` that returns a dictionary + containing the deferred annotations added to the graph when invoked with + any valid input. + evaluate_schema_overrides: A Boolean used to compute schema overrides. If + `False`, schema overrides will not be computed. + + Returns: + ------- + A `Schema` proto. + """ + metadata = collections.defaultdict(list, concrete_metadata_fn()) - return _infer_feature_schema_common( - features, - tensor_ranges, - tensor_annotations, - global_annotations, - sparse_output_annotations=_get_metadata_sparse_output_annotations( - metadata - ), - is_evaluation_complete=evaluate_schema_overrides, - ) + if not evaluate_schema_overrides: + tensor_ranges = { + tensor.numpy().decode(): (None, None) + for tensor in metadata[_TF_METADATA_TENSOR_COLLECTION] + } + tensor_annotations = {} + global_annotations = [] + else: + tensor_ranges = _get_tensor_ranges_v2(metadata) + tensor_annotations, global_annotations = _get_schema_annotations_v2(metadata) + + def _get_metadata_sparse_output_annotations(metadata): + sparse_output_annotations = metadata[_METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD] + if not sparse_output_annotations: + return None + return { + k: [a.numpy() for a in v] for k, v in sparse_output_annotations[0].items() + } + + return _infer_feature_schema_common( + features, + tensor_ranges, + tensor_annotations, + global_annotations, + sparse_output_annotations=_get_metadata_sparse_output_annotations(metadata), + is_evaluation_complete=evaluate_schema_overrides, + ) def _override_sparse_feature_annotated_shapes( feature_spec: Dict[str, common_types.FeatureSpecType], sparse_output_annotations: Dict[str, List[Union[int, str]]], ): - """Overrides feature_spec SparseFeatures with annotated shapes.""" - for k in sparse_output_annotations: - if k not in feature_spec: - raise ValueError( - f'Shape annotation for feature "{k}" which is not an output feature.' - ) - - if not isinstance(feature := feature_spec[k], tf.io.SparseFeature): - logging.warning( - 'Feature "%s" was annotated with a sparse shape but it is not' - ' sparse: %s', k, feature) - continue - - annotations = sparse_output_annotations[k] - - # Otherwise, the feature was just annotated to be truely sparse. - if any(isinstance(t, (str, bytes)) for t in annotations): - logging.warning( - 'Feature "%s" was annotated to be sparse without setting a shape.', k) - continue - - feature_spec[k] = tf.io.SparseFeature( - index_key=feature.index_key, - value_key=feature.value_key, - dtype=feature.dtype, - size=tf.TensorShape(annotations), - already_sorted=feature.already_sorted, - ) + """Overrides feature_spec SparseFeatures with annotated shapes.""" + for k in sparse_output_annotations: + if k not in feature_spec: + raise ValueError( + f'Shape annotation for feature "{k}" which is not an output feature.' + ) + + if not isinstance(feature := feature_spec[k], tf.io.SparseFeature): + logging.warning( + 'Feature "%s" was annotated with a sparse shape but it is not' + " sparse: %s", + k, + feature, + ) + continue + + annotations = sparse_output_annotations[k] + + # Otherwise, the feature was just annotated to be truely sparse. 
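+ # (String/bytes entries are the empty-string placeholder recorded for
+ # true-sparse annotations that carry no explicit shape, so no size override
+ # is applied in that case.)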
+ if any(isinstance(t, (str, bytes)) for t in annotations): + logging.warning( + 'Feature "%s" was annotated to be sparse without setting a shape.', k + ) + continue + + feature_spec[k] = tf.io.SparseFeature( + index_key=feature.index_key, + value_key=feature.value_key, + dtype=feature.dtype, + size=tf.TensorShape(annotations), + already_sorted=feature.already_sorted, + ) def _infer_feature_schema_common( @@ -327,537 +345,576 @@ def _infer_feature_schema_common( sparse_output_annotations: Optional[Dict[str, List[Union[int, str]]]], is_evaluation_complete: bool, ) -> schema_pb2.Schema: - """Given a dict of tensors, creates a `Schema`. - - Args: - features: A dict mapping column names to `Tensor`, `SparseTensor` or - `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have - a 0'th dimension which is interpreted as the batch dimension. - tensor_ranges: A dict mapping a tensor to a tuple containing its min and max - value. - feature_annotations: dictionary from feature name to list of any_pb2.Any - protos to be added as an annotation for that feature in the schema. - global_annotations: list of any_pb2.Any protos to be added at the global - schema level. - sparse_output_annotations: A dict mapping sparse feature names to their - annotations. - is_evaluation_complete: A boolean indicating whether all analyzers have been - evaluated or not. - - Returns: - A `Schema` proto. - """ - domains = {} - feature_tags = collections.defaultdict(list) - for name in features: - if name in tensor_ranges: - min_value, max_value = tensor_ranges[name] - domains[name] = schema_pb2.IntDomain( - min=min_value, max=max_value, is_categorical=True) - sparse_output_annotations = sparse_output_annotations or dict() - feature_spec = _feature_spec_from_batched_tensors( - features, is_evaluation_complete, sparse_output_annotations.keys() - ) - _override_sparse_feature_annotated_shapes( - feature_spec, sparse_output_annotations - ) - - schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains) - - # Add the annotations to the schema. - for annotation in global_annotations: - schema_proto.annotation.extra_metadata.add().CopyFrom(annotation) - # Build a map from logical feature names to Feature protos - feature_protos_by_name = {} - for feature in schema_proto.feature: - feature_protos_by_name[feature.name] = feature - for sparse_feature in schema_proto.sparse_feature: - for index_feature in sparse_feature.index_feature: - feature_protos_by_name.pop(index_feature.name) - value_feature = feature_protos_by_name.pop( - sparse_feature.value_feature.name) - feature_protos_by_name[sparse_feature.name] = value_feature - - # Handle ragged tensor representations. 
- tensor_representations = ( - tensor_representation_util.GetTensorRepresentationsFromSchema( - schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP)) - if tensor_representations is not None: - for name, tensor_representation in tensor_representations.items(): - feature_protos_by_name[name] = schema_utils.pop_ragged_source_columns( - name, tensor_representation, feature_protos_by_name) - - # Update annotations - for feature_name, annotations in feature_annotations.items(): - feature_proto = feature_protos_by_name[feature_name] - for annotation in annotations: - feature_proto.annotation.extra_metadata.add().CopyFrom(annotation) - for feature_name, tags in feature_tags.items(): - feature_proto = feature_protos_by_name[feature_name] - for tag in tags: - feature_proto.annotation.tag.append(tag) - return schema_proto + """Given a dict of tensors, creates a `Schema`. + + Args: + ---- + features: A dict mapping column names to `Tensor`, `SparseTensor` or + `RaggedTensor`. The `Tensor`, `SparseTensor` or `RaggedTensor` should have + a 0'th dimension which is interpreted as the batch dimension. + tensor_ranges: A dict mapping a tensor to a tuple containing its min and max + value. + feature_annotations: dictionary from feature name to list of any_pb2.Any + protos to be added as an annotation for that feature in the schema. + global_annotations: list of any_pb2.Any protos to be added at the global + schema level. + sparse_output_annotations: A dict mapping sparse feature names to their + annotations. + is_evaluation_complete: A boolean indicating whether all analyzers have been + evaluated or not. + + Returns: + ------- + A `Schema` proto. + """ + domains = {} + feature_tags = collections.defaultdict(list) + for name in features: + if name in tensor_ranges: + min_value, max_value = tensor_ranges[name] + domains[name] = schema_pb2.IntDomain( + min=min_value, max=max_value, is_categorical=True + ) + sparse_output_annotations = sparse_output_annotations or dict() + feature_spec = _feature_spec_from_batched_tensors( + features, is_evaluation_complete, sparse_output_annotations.keys() + ) + _override_sparse_feature_annotated_shapes(feature_spec, sparse_output_annotations) + + schema_proto = schema_utils.schema_from_feature_spec(feature_spec, domains) + + # Add the annotations to the schema. + for annotation in global_annotations: + schema_proto.annotation.extra_metadata.add().CopyFrom(annotation) + # Build a map from logical feature names to Feature protos + feature_protos_by_name = {} + for feature in schema_proto.feature: + feature_protos_by_name[feature.name] = feature + for sparse_feature in schema_proto.sparse_feature: + for index_feature in sparse_feature.index_feature: + feature_protos_by_name.pop(index_feature.name) + value_feature = feature_protos_by_name.pop(sparse_feature.value_feature.name) + feature_protos_by_name[sparse_feature.name] = value_feature + + # Handle ragged tensor representations. 
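+ # Ragged outputs are described by TensorRepresentations in the schema; their
+ # source columns are popped from feature_protos_by_name and each logical
+ # ragged feature name is mapped to the Feature proto returned by
+ # pop_ragged_source_columns (mirroring the sparse_feature handling above),
+ # so the per-feature annotations below attach to the right proto.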
+ tensor_representations = ( + tensor_representation_util.GetTensorRepresentationsFromSchema( + schema_proto, schema_utils.TENSOR_REPRESENTATION_GROUP + ) + ) + if tensor_representations is not None: + for name, tensor_representation in tensor_representations.items(): + feature_protos_by_name[name] = schema_utils.pop_ragged_source_columns( + name, tensor_representation, feature_protos_by_name + ) + + # Update annotations + for feature_name, annotations in feature_annotations.items(): + feature_proto = feature_protos_by_name[feature_name] + for annotation in annotations: + feature_proto.annotation.extra_metadata.add().CopyFrom(annotation) + for feature_name, tags in feature_tags.items(): + feature_proto = feature_protos_by_name[feature_name] + for tag in tags: + feature_proto.annotation.tag.append(tag) + return schema_proto # Names of collections, which should all be the same length and contain tensors. # Each tensor in the first collection should have its min/max described by the # tensors in the other two collections. -_TF_METADATA_TENSOR_COLLECTION = 'tft_schema_override_tensor' -_TF_METADATA_TENSOR_MIN_COLLECTION = 'tft_schema_override_min' -_TF_METADATA_TENSOR_MAX_COLLECTION = 'tft_schema_override_max' +_TF_METADATA_TENSOR_COLLECTION = "tft_schema_override_tensor" +_TF_METADATA_TENSOR_MIN_COLLECTION = "tft_schema_override_min" +_TF_METADATA_TENSOR_MAX_COLLECTION = "tft_schema_override_max" # Collections for adding to annotation.extra_metadata on the schema. Each # tensor in the first collection should have a proto type and proto message in # the other two collections -_TF_METADATA_EXTRA_ANNOTATION = 'tft_schema_override_annotation_tensor' -_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL = 'tft_schema_override_annotation_type' -_TF_METADATA_EXTRA_ANNOTATION_PROTO = 'tft_schema_override_annotation_proto' +_TF_METADATA_EXTRA_ANNOTATION = "tft_schema_override_annotation_tensor" +_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL = "tft_schema_override_annotation_type" +_TF_METADATA_EXTRA_ANNOTATION_PROTO = "tft_schema_override_annotation_proto" # Used to indicate that an annotation should be applied at the schema level. -_TF_METADATA_EXTRA_ANNOTATION_GLOBAL = 'tft_schema_override_global_sentinel' +_TF_METADATA_EXTRA_ANNOTATION_GLOBAL = "tft_schema_override_global_sentinel" # V2 metadata entry for shape overrides of SparseTensor outputs. -_METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD = 'tft_sparse_output_overrides' +_METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD = "tft_sparse_output_overrides" def set_tensor_schema_override(tensor, min_value, max_value): - """Override parts of the schema of a `Tensor`. - - Args: - tensor: The `Tensor` whose range is being set. Must have dtype int64. - min_value: A `Tensor` representing the min value of `tensor`. - max_value: A `Tensor` representing the max value of `tensor`. - - Raises: - ValueError: If any arguments are invalid. 
- """ - if not isinstance(tensor, tf.Tensor): - raise ValueError('tensor {} was not a Tensor'.format(tensor)) - if tensor.dtype != tf.int64: - raise ValueError( - 'Range can only be set for feature of type tf.int64, got {}'.format( - tensor.dtype)) - if not isinstance(min_value, tf.Tensor): - raise ValueError('min_value {} was not a Tensor'.format(min_value)) - if not isinstance(max_value, tf.Tensor): - raise ValueError('max_value {} was not a Tensor'.format(max_value)) - tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_COLLECTION, tensor) - tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_MIN_COLLECTION, min_value) - tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_MAX_COLLECTION, max_value) + """Override parts of the schema of a `Tensor`. + + Args: + ---- + tensor: The `Tensor` whose range is being set. Must have dtype int64. + min_value: A `Tensor` representing the min value of `tensor`. + max_value: A `Tensor` representing the max value of `tensor`. + + Raises: + ------ + ValueError: If any arguments are invalid. + """ + if not isinstance(tensor, tf.Tensor): + raise ValueError(f"tensor {tensor} was not a Tensor") + if tensor.dtype != tf.int64: + raise ValueError( + f"Range can only be set for feature of type tf.int64, got {tensor.dtype}" + ) + if not isinstance(min_value, tf.Tensor): + raise ValueError(f"min_value {min_value} was not a Tensor") + if not isinstance(max_value, tf.Tensor): + raise ValueError(f"max_value {max_value} was not a Tensor") + tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_COLLECTION, tensor) + tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_MIN_COLLECTION, min_value) + tf.compat.v1.add_to_collection(_TF_METADATA_TENSOR_MAX_COLLECTION, max_value) def _get_tensor_ranges(graph): - """Lookup overrides for `Tensor`s or `SparseTensor`s.""" - tensors = graph.get_collection(_TF_METADATA_TENSOR_COLLECTION) - min_values = graph.get_collection(_TF_METADATA_TENSOR_MIN_COLLECTION) - max_values = graph.get_collection(_TF_METADATA_TENSOR_MAX_COLLECTION) - assert len(tensors) == len(min_values), '{} != {}'.format(tensors, min_values) - assert len(tensors) == len(max_values), '{} != {}'.format(tensors, max_values) - return dict( - zip( - map(tf_utils.hashable_tensor_or_op, tensors), - zip(min_values, max_values))) + """Lookup overrides for `Tensor`s or `SparseTensor`s.""" + tensors = graph.get_collection(_TF_METADATA_TENSOR_COLLECTION) + min_values = graph.get_collection(_TF_METADATA_TENSOR_MIN_COLLECTION) + max_values = graph.get_collection(_TF_METADATA_TENSOR_MAX_COLLECTION) + assert len(tensors) == len(min_values), f"{tensors} != {min_values}" + assert len(tensors) == len(max_values), f"{tensors} != {max_values}" + return dict( + zip(map(tf_utils.hashable_tensor_or_op, tensors), zip(min_values, max_values)) + ) def _get_tensor_ranges_v2(metadata): - """Lookup overrides for `Tensor`s or `SparseTensor`s.""" - tensors = metadata[_TF_METADATA_TENSOR_COLLECTION] - min_values = metadata[_TF_METADATA_TENSOR_MIN_COLLECTION] - max_values = metadata[_TF_METADATA_TENSOR_MAX_COLLECTION] - assert len(tensors) == len(min_values), '{} != {}'.format(tensors, min_values) - assert len(tensors) == len(max_values), '{} != {}'.format(tensors, max_values) - return { - tensor.numpy().decode(): (min_value.numpy(), max_value.numpy()) - for (tensor, min_value, max_value) in zip(tensors, min_values, max_values) - } + """Lookup overrides for `Tensor`s or `SparseTensor`s.""" + tensors = metadata[_TF_METADATA_TENSOR_COLLECTION] + min_values = metadata[_TF_METADATA_TENSOR_MIN_COLLECTION] + 
max_values = metadata[_TF_METADATA_TENSOR_MAX_COLLECTION] + assert len(tensors) == len(min_values), f"{tensors} != {min_values}" + assert len(tensors) == len(max_values), f"{tensors} != {max_values}" + return { + tensor.numpy().decode(): (min_value.numpy(), max_value.numpy()) + for (tensor, min_value, max_value) in zip(tensors, min_values, max_values) + } def get_tensor_schema_override( - tensor: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor]: - """Lookup schema overrides for a `Tensor` or `CompositeTensor`.""" - tensor = _get_tensor_values(tensor) - overrides = _get_tensor_ranges(tensor.graph) - min_max = overrides.get(tf_utils.hashable_tensor_or_op(tensor), None) - if min_max is None: - raise ValueError('Requested tensor does not have recorded min/max values') - return min_max + tensor: common_types.TensorType, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Lookup schema overrides for a `Tensor` or `CompositeTensor`.""" + tensor = _get_tensor_values(tensor) + overrides = _get_tensor_ranges(tensor.graph) + min_max = overrides.get(tf_utils.hashable_tensor_or_op(tensor), None) + if min_max is None: + raise ValueError("Requested tensor does not have recorded min/max values") + return min_max def annotate(type_url, proto_message, tensor=None): - """Adds a deferred annotation to the schema. - - Experimental: This API is subject to change. - - This function allows analyzers or end users to annotate the post-transform - schema with additional information based on analyzer output. These annotations - are stored in the annotation.extra_metadata field of the tf.metadata schema: - https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto#L193 - - Args: - type_url: A string or string `Tensor` containing the type url which uniquely - identifies the type of the serialized proto message. See - https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/any.proto#L151 - proto_message: A deferred string tensor containing the serialized proto to - write to the feature schema. - tensor: (optional) If provided, the annotation will be written to the - Feature proto that is created for this tensor in the schema. If None, - the annotation is assumed to be global. Note: if the tensor is not present - in the output signature of `preprocessing_fn`, this will be a no-op. - """ - if tensor is None: - tensor = tf.constant('unused', name=_TF_METADATA_EXTRA_ANNOTATION_GLOBAL) - if not isinstance(tensor, (tf.Tensor, tf.SparseTensor)): - raise ValueError('tensor {} was not a Tensor'.format(tensor)) - if not isinstance(proto_message, tf.Tensor): - raise ValueError('proto_message {} was not a Tensor'.format(proto_message)) - - # If the type_url is passed as a plain string, create a string tensor. - if not isinstance(type_url, tf.Tensor): - type_url = tf.constant(type_url, dtype=tf.string) - # Note: The tensors, types, and messages are stored in separate collections - # because SavedModel only supports primitive types in collections. - tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION, tensor) - tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL, - type_url) - tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION_PROTO, - proto_message) + """Adds a deferred annotation to the schema. + + Experimental: This API is subject to change. + + This function allows analyzers or end users to annotate the post-transform + schema with additional information based on analyzer output. 
These annotations + are stored in the annotation.extra_metadata field of the tf.metadata schema: + https://github.com/tensorflow/metadata/blob/master/tensorflow_metadata/proto/v0/schema.proto#L193 + + Args: + ---- + type_url: A string or string `Tensor` containing the type url which uniquely + identifies the type of the serialized proto message. See + https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/any.proto#L151 + proto_message: A deferred string tensor containing the serialized proto to + write to the feature schema. + tensor: (optional) If provided, the annotation will be written to the + Feature proto that is created for this tensor in the schema. If None, + the annotation is assumed to be global. Note: if the tensor is not present + in the output signature of `preprocessing_fn`, this will be a no-op. + """ + if tensor is None: + tensor = tf.constant("unused", name=_TF_METADATA_EXTRA_ANNOTATION_GLOBAL) + if not isinstance(tensor, (tf.Tensor, tf.SparseTensor)): + raise ValueError(f"tensor {tensor} was not a Tensor") + if not isinstance(proto_message, tf.Tensor): + raise ValueError(f"proto_message {proto_message} was not a Tensor") + + # If the type_url is passed as a plain string, create a string tensor. + if not isinstance(type_url, tf.Tensor): + type_url = tf.constant(type_url, dtype=tf.string) + # Note: The tensors, types, and messages are stored in separate collections + # because SavedModel only supports primitive types in collections. + tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION, tensor) + tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL, type_url) + tf.compat.v1.add_to_collection(_TF_METADATA_EXTRA_ANNOTATION_PROTO, proto_message) def _get_schema_annotations(graph, session): - """Fetch extra_metadata annotations to be applied to the schema. - - Extracts any deferred annotations that have been added to the graph and - evaluates them to obtain any_pb2.Any proto messages. - - Args: - graph: A `tf.Graph` used to determine schema overrides. - session: (optional) A `tf.Session` used to compute schema annotations. If - None, schema annotations will not be computed. - - Returns: - tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to - be added as an annotation for that tensor's feature in the schema. - global_annotations: list of any_pb2.Any protos to be added at the global - schema level. - """ - tensors = graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION) - type_urls = session.run( - graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL)) - proto_values = session.run( - graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION_PROTO)) - tensor_annotation_keys = [] - for tensor in tensors: - # Entries meant for the global schema annotation will have names like - # tft_schema_override_global_sentinel:0 or - # transform/tft_schema_override_global_sentinel_1:0 - tensor_name = tensor.name.split('/')[-1] - if tensor_name.startswith(_TF_METADATA_EXTRA_ANNOTATION_GLOBAL): - tensor_annotation_keys.append(_TF_METADATA_EXTRA_ANNOTATION_GLOBAL) - else: - tensor_annotation_keys.append(tf_utils.hashable_tensor_or_op(tensor)) - return _get_schema_annotations_common(tensor_annotation_keys, type_urls, - proto_values) + """Fetch extra_metadata annotations to be applied to the schema. + + Extracts any deferred annotations that have been added to the graph and + evaluates them to obtain any_pb2.Any proto messages. + + Args: + ---- + graph: A `tf.Graph` used to determine schema overrides. 
+ session: (optional) A `tf.Session` used to compute schema annotations. If + None, schema annotations will not be computed. + + Returns: + ------- + tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to + be added as an annotation for that tensor's feature in the schema. + global_annotations: list of any_pb2.Any protos to be added at the global + schema level. + """ + tensors = graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION) + type_urls = session.run( + graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL) + ) + proto_values = session.run( + graph.get_collection(_TF_METADATA_EXTRA_ANNOTATION_PROTO) + ) + tensor_annotation_keys = [] + for tensor in tensors: + # Entries meant for the global schema annotation will have names like + # tft_schema_override_global_sentinel:0 or + # transform/tft_schema_override_global_sentinel_1:0 + tensor_name = tensor.name.split("/")[-1] + if tensor_name.startswith(_TF_METADATA_EXTRA_ANNOTATION_GLOBAL): + tensor_annotation_keys.append(_TF_METADATA_EXTRA_ANNOTATION_GLOBAL) + else: + tensor_annotation_keys.append(tf_utils.hashable_tensor_or_op(tensor)) + return _get_schema_annotations_common( + tensor_annotation_keys, type_urls, proto_values + ) def _get_schema_annotations_v2(metadata): - """Fetch extra_metadata annotations to be applied to the schema. - - Extracts any deferred annotations that have been added to the graph and - evaluates them to obtain any_pb2.Any proto messages. - - Args: - metadata: A dictionary containing the deferred annotations added to the - graph. - - Returns: - tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to - be added as an annotation for that tensor's feature in the schema. - global_annotations: list of any_pb2.Any protos to be added at the global - schema level. - """ - type_urls = [ - type_url.numpy() - for type_url in metadata[_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL] - ] - proto_values = [ - proto_value.numpy() - for proto_value in metadata[_TF_METADATA_EXTRA_ANNOTATION_PROTO] - ] - tensor_annotation_keys = [ - tensor.numpy().decode() - for tensor in metadata[_TF_METADATA_EXTRA_ANNOTATION] - ] - return _get_schema_annotations_common(tensor_annotation_keys, type_urls, - proto_values) - - -def _get_schema_annotations_common(tensor_annotation_keys, type_urls, - proto_values): - """Fetch extra_metadata annotations to be applied to the schema. - - Args: - tensor_annotation_keys: A list containing either - `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` or a hashed tensor representation - corresponding to each entry in `proto_values`. If an entry - is`_TF_METADATA_EXTRA_ANNOTATION_GLOBAL`, the corresponding any_pb2.Any - proto in `proto_values` is returned in `global_annotations`. Otherwise, it - is returned in `feature_annotations`. - type_urls: A list of type urls corresponding to the serialized protos in - `proto_values`. - proto_values: A list of serialized any_pb2.Any protos. - - Returns: - A tuple of: - tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to - be added as an annotation for that tensor's feature in the schema. - global_annotations: list of any_pb2.Any protos to be added at the global - schema level. - """ - tensor_annotations = collections.defaultdict(list) - global_annotations = [] - if not common.IS_ANNOTATIONS_PB_AVAILABLE: + """Fetch extra_metadata annotations to be applied to the schema. + + Extracts any deferred annotations that have been added to the graph and + evaluates them to obtain any_pb2.Any proto messages. 
+ + Args: + ---- + metadata: A dictionary containing the deferred annotations added to the + graph. + + Returns: + ------- + tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to + be added as an annotation for that tensor's feature in the schema. + global_annotations: list of any_pb2.Any protos to be added at the global + schema level. + """ + type_urls = [ + type_url.numpy() + for type_url in metadata[_TF_METADATA_EXTRA_ANNOTATION_TYPE_URL] + ] + proto_values = [ + proto_value.numpy() + for proto_value in metadata[_TF_METADATA_EXTRA_ANNOTATION_PROTO] + ] + tensor_annotation_keys = [ + tensor.numpy().decode() for tensor in metadata[_TF_METADATA_EXTRA_ANNOTATION] + ] + return _get_schema_annotations_common( + tensor_annotation_keys, type_urls, proto_values + ) + + +def _get_schema_annotations_common(tensor_annotation_keys, type_urls, proto_values): + """Fetch extra_metadata annotations to be applied to the schema. + + Args: + ---- + tensor_annotation_keys: A list containing either + `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` or a hashed tensor representation + corresponding to each entry in `proto_values`. If an entry + is`_TF_METADATA_EXTRA_ANNOTATION_GLOBAL`, the corresponding any_pb2.Any + proto in `proto_values` is returned in `global_annotations`. Otherwise, it + is returned in `feature_annotations`. + type_urls: A list of type urls corresponding to the serialized protos in + `proto_values`. + proto_values: A list of serialized any_pb2.Any protos. + + Returns: + ------- + A tuple of: + tensor_annotations: dictionary from tensor to list of any_pb2.Any protos to + be added as an annotation for that tensor's feature in the schema. + global_annotations: list of any_pb2.Any protos to be added at the global + schema level. + """ + tensor_annotations = collections.defaultdict(list) + global_annotations = [] + if not common.IS_ANNOTATIONS_PB_AVAILABLE: + return tensor_annotations, global_annotations + assert len(tensor_annotation_keys) == len(type_urls) == len(proto_values) + for tensor_annotation_key, type_url, proto_value in zip( + tensor_annotation_keys, type_urls, proto_values + ): + annotation = any_pb2.Any(type_url=type_url, value=proto_value) + if ( + isinstance( + _TF_METADATA_EXTRA_ANNOTATION_GLOBAL, type(tensor_annotation_key) + ) + and tensor_annotation_key == _TF_METADATA_EXTRA_ANNOTATION_GLOBAL + ): + global_annotations.append(annotation) + else: + tensor_annotations[tensor_annotation_key].append(annotation) return tensor_annotations, global_annotations - assert len(tensor_annotation_keys) == len(type_urls) == len(proto_values) - for (tensor_annotation_key, type_url, - proto_value) in zip(tensor_annotation_keys, type_urls, proto_values): - annotation = any_pb2.Any(type_url=type_url, value=proto_value) - if (isinstance(_TF_METADATA_EXTRA_ANNOTATION_GLOBAL, - type(tensor_annotation_key)) and - tensor_annotation_key == _TF_METADATA_EXTRA_ANNOTATION_GLOBAL): - global_annotations.append(annotation) - else: - tensor_annotations[tensor_annotation_key].append(annotation) - return tensor_annotations, global_annotations def _get_tensor_value_to_key_map(features_dict): - """Get reverse map from name of tensor values to key in `features_dict`.""" - result = {} - for key, tensor in features_dict.items(): - values = _get_tensor_values(tensor) - result[values.name] = key - return result - - -def _get_schema_overrides(graph, - tensor_name_to_key_map, - tensor_collection_key, - overrides_keys, - default_tensor_name=None): - """Obtain schema overrides from graph collections. 
- - For every tensor in the `tensor_collection_key` collection, the corresponding - feature name is in `tensor_name_to_key_map` and its schema overrides are in - the graph collections defined by keys in `overrides_keys`. - If a tensor does not exist in `tensor_name_to_key_map` but its name starts - with `default_tensor_name` (if provided), the overrides are returned with this - key. - - Args: - graph: A `FuncGraph`. - tensor_name_to_key_map: A dictionary from tensor name to output feature key. - tensor_collection_key: Key for the graph collection that contains list of - tensors to annotate. - overrides_keys: A list of graph collection keys that contain schema - overrides/annotations. - default_tensor_name: (Optional) A String. If provided, use as feature key if - a tensor in the graph collections is not in `tensor_name_to_key_map`. - - Returns: - A dictionary from graph collection keys to lists of features and their - schema overrides/annotations. - - """ - tensors = graph.get_collection(tensor_collection_key) - overrides_list = [graph.get_collection(k) for k in overrides_keys] - - result = collections.defaultdict(list) - if len(tensors) != len(overrides_list[0]) or any( - len(lst) != len(overrides_list[0]) for lst in overrides_list - ): - raise ValueError( - f'Unexpected collections lengths. tensors: {tensors}, overrides_list:' - f' {overrides_list}' - ) - for tensor, overrides_tuple in zip(tensors, zip(*overrides_list)): - if tensor.name in tensor_name_to_key_map: - result[tensor_collection_key].append(tensor_name_to_key_map[tensor.name]) - else: - if default_tensor_name is None: - continue - tensor_name = tensor.name.split('/')[-1] - if tensor.dtype == tf.string and tensor_name.startswith( - default_tensor_name): - result[tensor_collection_key].append(default_tensor_name) - else: - continue - - # If a feature name was added to the result list for tensor_collection_key, - # add its annotations as well. - assert len(overrides_keys) == len(overrides_tuple) - for overrides_key, override in zip(overrides_keys, overrides_tuple): - result[overrides_key].append(override) - return result + """Get reverse map from name of tensor values to key in `features_dict`.""" + result = {} + for key, tensor in features_dict.items(): + values = _get_tensor_values(tensor) + result[values.name] = key + return result -def get_traced_metadata_fn( - preprocessing_fn: Callable[[Mapping[str, common_types.TensorType]], - Mapping[str, common_types.TensorType]], - structured_inputs: Mapping[str, common_types.TensorType], - tf_graph_context: graph_context.TFGraphContext, - evaluate_schema_overrides: bool) -> tf.types.experimental.GenericFunction: - """Get a tf.function that returns a dictionary of annotations. - - Annotations are added to graph collections keyed by graph tensor names when - `preprocessing_fn` is being traced. The metadata fn defined by this method - converts the graph tensor names to output feature keys. - - If `evaluate_schema_overrides` is True, tracing the `preprocessing_fn` will - add overrides for feature ranges (min/max) and/or feature protos to the graph - collection, if applicable. These overrides are returned when the function - returned by this method is invoked. - - Args: - preprocessing_fn: A user defined python function to be traced. - structured_inputs: A dictionary of placeholder inputs to `preprocessing_fn`. - tf_graph_context: A `TFGraphContext` context manager to invoke the - `preprocessing_fn` in. 
- evaluate_schema_overrides: If `False`, the returned dictionary contains a - single key `_TF_METADATA_TENSOR_COLLECTION` as all other annotations are - deferred. Else, the returned dictionary contains several deferred - annotations. - - Returns: - A tf.function which when invoked returns a dictionary whose keys represent - the types of annotations and the values are collections of feature - keys/annotations. - """ - # Since this is a TFT-internal function with constant outputs, autograph will - # not affect its behavior. It will only increase tracing time, if enabled. - # Hence, trace with `autograph=False` here. - @tf.function(input_signature=[], autograph=False) - def metadata_fn(): - graph = ops.get_default_graph() - inputs = tf2_utils.supply_missing_inputs(structured_inputs, batch_size=1) - with tf_graph_context: - transformed_features = preprocessing_fn(inputs) - - # Get a map from tensor value names to feature keys. - reversed_features = _get_tensor_value_to_key_map(transformed_features) +def _get_schema_overrides( + graph, + tensor_name_to_key_map, + tensor_collection_key, + overrides_keys, + default_tensor_name=None, +): + """Obtain schema overrides from graph collections. + + For every tensor in the `tensor_collection_key` collection, the corresponding + feature name is in `tensor_name_to_key_map` and its schema overrides are in + the graph collections defined by keys in `overrides_keys`. + If a tensor does not exist in `tensor_name_to_key_map` but its name starts + with `default_tensor_name` (if provided), the overrides are returned with this + key. + + Args: + ---- + graph: A `FuncGraph`. + tensor_name_to_key_map: A dictionary from tensor name to output feature key. + tensor_collection_key: Key for the graph collection that contains list of + tensors to annotate. + overrides_keys: A list of graph collection keys that contain schema + overrides/annotations. + default_tensor_name: (Optional) A String. If provided, use as feature key if + a tensor in the graph collections is not in `tensor_name_to_key_map`. + + Returns: + ------- + A dictionary from graph collection keys to lists of features and their + schema overrides/annotations. + + """ + tensors = graph.get_collection(tensor_collection_key) + overrides_list = [graph.get_collection(k) for k in overrides_keys] result = collections.defaultdict(list) - if not evaluate_schema_overrides: - schema_override_tensors = graph.get_collection( - _TF_METADATA_TENSOR_COLLECTION) - for tensor in schema_override_tensors: - if tensor.name in reversed_features: - result[_TF_METADATA_TENSOR_COLLECTION].append( - reversed_features[tensor.name]) - else: - # Obtain schema overrides for feature tensor ranges. - result.update( - _get_schema_overrides(graph, reversed_features, - _TF_METADATA_TENSOR_COLLECTION, [ - _TF_METADATA_TENSOR_MIN_COLLECTION, - _TF_METADATA_TENSOR_MAX_COLLECTION - ])) - # Obtain schema overrides for feature protos. If no feature tensor is in - # the `_TF_METADATA_EXTRA_ANNOTATION` collection for a specified - # annotation, `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the - # feature name to indicate that this annotation should be added to the - # global schema. 
- result.update( - _get_schema_overrides(graph, reversed_features, - _TF_METADATA_EXTRA_ANNOTATION, [ - _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL, - _TF_METADATA_EXTRA_ANNOTATION_PROTO - ], _TF_METADATA_EXTRA_ANNOTATION_GLOBAL)) - result[_METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD] = [ - _get_sparse_output_annotations_v2(graph, reversed_features) - ] + if len(tensors) != len(overrides_list[0]) or any( + len(lst) != len(overrides_list[0]) for lst in overrides_list + ): + raise ValueError( + f"Unexpected collections lengths. tensors: {tensors}, overrides_list:" + f" {overrides_list}" + ) + for tensor, overrides_tuple in zip(tensors, zip(*overrides_list)): + if tensor.name in tensor_name_to_key_map: + result[tensor_collection_key].append(tensor_name_to_key_map[tensor.name]) + else: + if default_tensor_name is None: + continue + tensor_name = tensor.name.split("/")[-1] + if tensor.dtype == tf.string and tensor_name.startswith( + default_tensor_name + ): + result[tensor_collection_key].append(default_tensor_name) + else: + continue + + # If a feature name was added to the result list for tensor_collection_key, + # add its annotations as well. + assert len(overrides_keys) == len(overrides_tuple) + for overrides_key, override in zip(overrides_keys, overrides_tuple): + result[overrides_key].append(override) return result - # Though this tf.function is not serialized to SavedModel, if it is capturing - # any resources, they need to be tracked to ensure they are not garbage - # collected. - # We strip control dependencies from this function as we only want to evaluate - # annotations which are constants available at graph construction time and - # have no dependency on inputs. - # TODO(b/149352022): Re-factor metadata computation when it is possible to - # evaluate non output tensors from a func graph. - module = tf_graph_context.module_to_export - saved_transform_io_v2.trace_and_update_module( - module, metadata_fn, 'metadata_fn', strip_control_dependencies=True) - return module.metadata_fn -_ANNOTATED_SPARSE_SHAPE_TENSORS = 'annotated_sparse_shape_tensors' -_ANNOTATED_SPARSE_SHAPES = 'annotated_sparse_shape_dimensions' -_ANNOTATED_TRUELY_SPARSE_TENSORS = 'annotated_truely_sparse_tensors' +def get_traced_metadata_fn( + preprocessing_fn: Callable[ + [Mapping[str, common_types.TensorType]], Mapping[str, common_types.TensorType] + ], + structured_inputs: Mapping[str, common_types.TensorType], + tf_graph_context: graph_context.TFGraphContext, + evaluate_schema_overrides: bool, +) -> tf.types.experimental.GenericFunction: + """Get a tf.function that returns a dictionary of annotations. + + Annotations are added to graph collections keyed by graph tensor names when + `preprocessing_fn` is being traced. The metadata fn defined by this method + converts the graph tensor names to output feature keys. + + If `evaluate_schema_overrides` is True, tracing the `preprocessing_fn` will + add overrides for feature ranges (min/max) and/or feature protos to the graph + collection, if applicable. These overrides are returned when the function + returned by this method is invoked. + + Args: + ---- + preprocessing_fn: A user defined python function to be traced. + structured_inputs: A dictionary of placeholder inputs to `preprocessing_fn`. + tf_graph_context: A `TFGraphContext` context manager to invoke the + `preprocessing_fn` in. + evaluate_schema_overrides: If `False`, the returned dictionary contains a + single key `_TF_METADATA_TENSOR_COLLECTION` as all other annotations are + deferred. 
Else, the returned dictionary contains several deferred + annotations. + + Returns: + ------- + A tf.function which when invoked returns a dictionary whose keys represent + the types of annotations and the values are collections of feature + keys/annotations. + """ + + # Since this is a TFT-internal function with constant outputs, autograph will + # not affect its behavior. It will only increase tracing time, if enabled. + # Hence, trace with `autograph=False` here. + @tf.function(input_signature=[], autograph=False) + def metadata_fn(): + graph = ops.get_default_graph() + inputs = tf2_utils.supply_missing_inputs(structured_inputs, batch_size=1) + with tf_graph_context: + transformed_features = preprocessing_fn(inputs) + + # Get a map from tensor value names to feature keys. + reversed_features = _get_tensor_value_to_key_map(transformed_features) + + result = collections.defaultdict(list) + if not evaluate_schema_overrides: + schema_override_tensors = graph.get_collection( + _TF_METADATA_TENSOR_COLLECTION + ) + for tensor in schema_override_tensors: + if tensor.name in reversed_features: + result[_TF_METADATA_TENSOR_COLLECTION].append( + reversed_features[tensor.name] + ) + else: + # Obtain schema overrides for feature tensor ranges. + result.update( + _get_schema_overrides( + graph, + reversed_features, + _TF_METADATA_TENSOR_COLLECTION, + [ + _TF_METADATA_TENSOR_MIN_COLLECTION, + _TF_METADATA_TENSOR_MAX_COLLECTION, + ], + ) + ) + # Obtain schema overrides for feature protos. If no feature tensor is in + # the `_TF_METADATA_EXTRA_ANNOTATION` collection for a specified + # annotation, `_TF_METADATA_EXTRA_ANNOTATION_GLOBAL` is used as the + # feature name to indicate that this annotation should be added to the + # global schema. + result.update( + _get_schema_overrides( + graph, + reversed_features, + _TF_METADATA_EXTRA_ANNOTATION, + [ + _TF_METADATA_EXTRA_ANNOTATION_TYPE_URL, + _TF_METADATA_EXTRA_ANNOTATION_PROTO, + ], + _TF_METADATA_EXTRA_ANNOTATION_GLOBAL, + ) + ) + result[_METADATA_SPARSE_OUTPUT_OVERRIDES_FIELD] = [ + _get_sparse_output_annotations_v2(graph, reversed_features) + ] + return result + + # Though this tf.function is not serialized to SavedModel, if it is capturing + # any resources, they need to be tracked to ensure they are not garbage + # collected. + # We strip control dependencies from this function as we only want to evaluate + # annotations which are constants available at graph construction time and + # have no dependency on inputs. + # TODO(b/149352022): Re-factor metadata computation when it is possible to + # evaluate non output tensors from a func graph. 
+ module = tf_graph_context.module_to_export + saved_transform_io_v2.trace_and_update_module( + module, metadata_fn, "metadata_fn", strip_control_dependencies=True + ) + return module.metadata_fn + + +_ANNOTATED_SPARSE_SHAPE_TENSORS = "annotated_sparse_shape_tensors" +_ANNOTATED_SPARSE_SHAPES = "annotated_sparse_shape_dimensions" +_ANNOTATED_TRUELY_SPARSE_TENSORS = "annotated_truely_sparse_tensors" def annotate_sparse_output_shape(tensor: tf.SparseTensor, shape: tf.Tensor): - """Annotates a sparse output with a given shape.""" - tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS, tensor.values) - tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPES, shape) + """Annotates a sparse output with a given shape.""" + tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS, tensor.values) + tf.compat.v1.add_to_collection(_ANNOTATED_SPARSE_SHAPES, shape) def annotate_true_sparse_output(tensor: tf.SparseTensor): - """Annotates a true sparse output to avoid representing it as varlen.""" - tf.compat.v1.add_to_collection( - _ANNOTATED_TRUELY_SPARSE_TENSORS, tensor.values - ) + """Annotates a true sparse output to avoid representing it as varlen.""" + tf.compat.v1.add_to_collection(_ANNOTATED_TRUELY_SPARSE_TENSORS, tensor.values) -def _extract_true_sparse_annotations( - graph: tf.compat.v1.Graph) -> List[tf.Tensor]: - """Extracts true sparse annotations from the graph.""" - return graph.get_collection(_ANNOTATED_TRUELY_SPARSE_TENSORS) +def _extract_true_sparse_annotations(graph: tf.compat.v1.Graph) -> List[tf.Tensor]: + """Extracts true sparse annotations from the graph.""" + return graph.get_collection(_ANNOTATED_TRUELY_SPARSE_TENSORS) def _extract_sparse_output_annotations( - graph: tf.compat.v1.Graph) -> List[Tuple[tf.Tensor, List[tf.Tensor]]]: - """Extracts sparse output annotations from the graph.""" - tensors = graph.get_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS) - shapes = graph.get_collection(_ANNOTATED_SPARSE_SHAPES) - assert len(tensors) == len(shapes), f'{tensors} != {shapes}' - return list(zip(tensors, shapes)) + graph: tf.compat.v1.Graph, +) -> List[Tuple[tf.Tensor, List[tf.Tensor]]]: + """Extracts sparse output annotations from the graph.""" + tensors = graph.get_collection(_ANNOTATED_SPARSE_SHAPE_TENSORS) + shapes = graph.get_collection(_ANNOTATED_SPARSE_SHAPES) + assert len(tensors) == len(shapes), f"{tensors} != {shapes}" + return list(zip(tensors, shapes)) def _get_sparse_output_annotations( graph: tf.compat.v1.Graph, ) -> List[Tuple[tf.Tensor, List[Union[str, tf.Tensor]]]]: - """Provides sparse output user annotations.""" - sparse_output_annotations = _extract_sparse_output_annotations(graph) - annotated_refs = set(t[0].ref() for t in sparse_output_annotations) - # Can't put None in collection, so putting an empty string. - return list( - itertools.chain( - ( - (a, tf.constant([''])) - for a in _extract_true_sparse_annotations(graph) - if a.ref() not in annotated_refs - ), - sparse_output_annotations, - ) - ) + """Provides sparse output user annotations.""" + sparse_output_annotations = _extract_sparse_output_annotations(graph) + annotated_refs = set(t[0].ref() for t in sparse_output_annotations) + # Can't put None in collection, so putting an empty string. 
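+ # Tensors annotated as true-sparse but without an explicit shape are paired
+ # with the [""] placeholder; explicitly annotated (values, shape) pairs are
+ # passed through unchanged.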
+ return list( + itertools.chain( + ( + (a, tf.constant([""])) + for a in _extract_true_sparse_annotations(graph) + if a.ref() not in annotated_refs + ), + sparse_output_annotations, + ) + ) def _get_sparse_output_annotations_v1( graph: tf.compat.v1.Graph, session: Optional[tf.compat.v1.Session] ) -> Dict[Any, List[Union[str, tf.Tensor]]]: - if not session: - return {} - else: - return { - tf_utils.hashable_tensor_or_op(kv[0]): session.run(kv[1]) - for kv in _get_sparse_output_annotations(graph) - } + if not session: + return {} + else: + return { + tf_utils.hashable_tensor_or_op(kv[0]): session.run(kv[1]) + for kv in _get_sparse_output_annotations(graph) + } def _get_sparse_output_annotations_v2( graph: tf.compat.v1.Graph, tensor_to_feature_names: Dict[str, str] ) -> Dict[str, List[Union[str, tf.Tensor]]]: - annotations = _get_sparse_output_annotations(graph) - result = {} - for tensor, v in annotations: - if tensor.name in tensor_to_feature_names: - result[tensor_to_feature_names[tensor.name]] = v - return result + annotations = _get_sparse_output_annotations(graph) + result = {} + for tensor, v in annotations: + if tensor.name in tensor_to_feature_names: + result[tensor_to_feature_names[tensor.name]] = v + return result diff --git a/tensorflow_transform/schema_inference_test.py b/tensorflow_transform/schema_inference_test.py index a015983..7f27f1a 100644 --- a/tensorflow_transform/schema_inference_test.py +++ b/tensorflow_transform/schema_inference_test.py @@ -15,321 +15,354 @@ import functools import os +import unittest import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import common -from tensorflow_transform import graph_context -from tensorflow_transform import mappers -from tensorflow_transform import schema_inference -from tensorflow_transform import tf2_utils -from tensorflow_transform import test_case -from tensorflow_transform.tf_metadata import schema_utils_legacy -from tensorflow_transform.tf_metadata import schema_utils - from google.protobuf import text_format -import unittest from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform import ( + analyzers, + common, + graph_context, + mappers, + schema_inference, + test_case, + tf2_utils, +) +from tensorflow_transform.tf_metadata import schema_utils, schema_utils_legacy if common.IS_ANNOTATIONS_PB_AVAILABLE: - from tensorflow_transform import annotations_pb2 # pylint: disable=g-import-not-at-top + from tensorflow_transform import ( + annotations_pb2, # pylint: disable=g-import-not-at-top + ) def _make_tensors(inputs): - return {'x': tf.identity(inputs['x'])} + return {"x": tf.identity(inputs["x"])} def _make_tensors_with_override(inputs): - x = tf.identity(inputs['x']) - schema_inference.set_tensor_schema_override(x, tf.constant(5), tf.constant(6)) - return {'x': x} + x = tf.identity(inputs["x"]) + schema_inference.set_tensor_schema_override(x, tf.constant(5), tf.constant(6)) + return {"x": x} def _make_tensors_with_depth(inputs, depth=None): - if depth is None: - depth = tf.raw_ops.Placeholder(dtype=tf.int32, shape=[]) - else: - depth = tf.constant(depth, dtype=tf.int32) - return {'x': tf.one_hot(inputs['x'], depth=depth, dtype=inputs['x'].dtype)} + if depth is None: + depth = tf.raw_ops.Placeholder(dtype=tf.int32, shape=[]) + else: + depth = tf.constant(depth, dtype=tf.int32) + return {"x": tf.one_hot(inputs["x"], depth=depth, dtype=inputs["x"].dtype)} class SchemaInferenceTest(test_case.TransformTestCase): - - def _get_schema(self, - preprocessing_fn, - 
use_compat_v1, - inputs=None, - input_signature=None, - create_session=False): - if inputs is None: - inputs = {} - if input_signature is None: - input_signature = {} - if use_compat_v1: - with tf.compat.v1.Graph().as_default() as graph: - # Convert eager tensors to graph tensors. - inputs_copy = { - k: tf.constant(v, input_signature[k].dtype) - for k, v in inputs.items() - } - tensors = preprocessing_fn(inputs_copy) - if create_session: - # Create a session to actually evaluate the annotations and extract - # the output schema with annotations applied. - with tf.compat.v1.Session(graph=graph) as session: - schema = schema_inference.infer_feature_schema( - tensors, graph, session) + def _get_schema( + self, + preprocessing_fn, + use_compat_v1, + inputs=None, + input_signature=None, + create_session=False, + ): + if inputs is None: + inputs = {} + if input_signature is None: + input_signature = {} + if use_compat_v1: + with tf.compat.v1.Graph().as_default() as graph: + # Convert eager tensors to graph tensors. + inputs_copy = { + k: tf.constant(v, input_signature[k].dtype) + for k, v in inputs.items() + } + tensors = preprocessing_fn(inputs_copy) + if create_session: + # Create a session to actually evaluate the annotations and extract + # the output schema with annotations applied. + with tf.compat.v1.Session(graph=graph) as session: + schema = schema_inference.infer_feature_schema( + tensors, graph, session + ) + else: + schema = schema_inference.infer_feature_schema(tensors, graph) else: - schema = schema_inference.infer_feature_schema(tensors, graph) - else: - tf_func = tf.function( - preprocessing_fn, - input_signature=[input_signature]).get_concrete_function() - tensors = tf.nest.pack_sequence_as( - structure=tf_func.structured_outputs, - flat_sequence=tf_func.outputs, - expand_composites=True) - structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( - tf_func.graph) - tf_graph_context = graph_context.TFGraphContext( - module_to_export=tf.Module(), - temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName), - evaluated_replacements={}) - concrete_metadata_fn = schema_inference.get_traced_metadata_fn( - preprocessing_fn=preprocessing_fn, - structured_inputs=structured_inputs, - tf_graph_context=tf_graph_context, - evaluate_schema_overrides=create_session) - schema = schema_inference.infer_feature_schema_v2( - tensors, - concrete_metadata_fn, - evaluate_schema_overrides=create_session) - return schema + tf_func = tf.function( + preprocessing_fn, input_signature=[input_signature] + ).get_concrete_function() + tensors = tf.nest.pack_sequence_as( + structure=tf_func.structured_outputs, + flat_sequence=tf_func.outputs, + expand_composites=True, + ) + structured_inputs = tf2_utils.get_structured_inputs_from_func_graph( + tf_func.graph + ) + tf_graph_context = graph_context.TFGraphContext( + module_to_export=tf.Module(), + temp_dir=os.path.join(self.get_temp_dir(), self._testMethodName), + evaluated_replacements={}, + ) + concrete_metadata_fn = schema_inference.get_traced_metadata_fn( + preprocessing_fn=preprocessing_fn, + structured_inputs=structured_inputs, + tf_graph_context=tf_graph_context, + evaluate_schema_overrides=create_session, + ) + schema = schema_inference.infer_feature_schema_v2( + tensors, concrete_metadata_fn, evaluate_schema_overrides=create_session + ) + return schema - # pylint: disable=g-long-lambda - @test_case.named_parameters(*test_case.cross_named_parameters([ - dict( - testcase_name='fixed_len_int', - make_tensors_fn=_make_tensors, - 
feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}), - dict( - testcase_name='fixed_len_string', - make_tensors_fn=_make_tensors, - feature_spec={'x': tf.io.FixedLenFeature([], tf.string)}), - dict( - testcase_name='fixed_len_float', - make_tensors_fn=_make_tensors, - feature_spec={'x': tf.io.FixedLenFeature([], tf.float32)}), - dict( - testcase_name='override', - make_tensors_fn=_make_tensors_with_override, - feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}, - domains={'x': schema_pb2.IntDomain(is_categorical=True)}), - dict( - testcase_name='override_with_session', - make_tensors_fn=_make_tensors_with_override, - feature_spec={'x': tf.io.FixedLenFeature([], tf.int64)}, - domains={ - 'x': schema_pb2.IntDomain(min=5, max=6, is_categorical=True) - }, - create_session=True), - dict( - testcase_name='unknown_output_non_batch_dim', - make_tensors_fn=_make_tensors_with_depth, - feature_spec={'x': tf.io.FixedLenFeature([None], tf.int64)}), - dict( - testcase_name='known_output_non_batch_dim', - make_tensors_fn=functools.partial(_make_tensors_with_depth, depth=10), - feature_spec={'x': tf.io.FixedLenFeature([10], tf.int64)}, - create_session=True) - ], [ - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False) - ])) - # pylint: enable=g-long-lambda - def test_infer_feature_schema(self, - make_tensors_fn, - feature_spec, - use_compat_v1, - domains=None, - create_session=False): - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - x_val = '0' if feature_spec['x'].dtype == tf.string else 0 - inputs = {'x': [x_val]} - input_signature = { - 'x': tf.TensorSpec([None], dtype=feature_spec['x'].dtype) - } - schema = self._get_schema( + # pylint: disable=g-long-lambda + @test_case.named_parameters( + *test_case.cross_named_parameters( + [ + dict( + testcase_name="fixed_len_int", + make_tensors_fn=_make_tensors, + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, + ), + dict( + testcase_name="fixed_len_string", + make_tensors_fn=_make_tensors, + feature_spec={"x": tf.io.FixedLenFeature([], tf.string)}, + ), + dict( + testcase_name="fixed_len_float", + make_tensors_fn=_make_tensors, + feature_spec={"x": tf.io.FixedLenFeature([], tf.float32)}, + ), + dict( + testcase_name="override", + make_tensors_fn=_make_tensors_with_override, + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, + domains={"x": schema_pb2.IntDomain(is_categorical=True)}, + ), + dict( + testcase_name="override_with_session", + make_tensors_fn=_make_tensors_with_override, + feature_spec={"x": tf.io.FixedLenFeature([], tf.int64)}, + domains={ + "x": schema_pb2.IntDomain(min=5, max=6, is_categorical=True) + }, + create_session=True, + ), + dict( + testcase_name="unknown_output_non_batch_dim", + make_tensors_fn=_make_tensors_with_depth, + feature_spec={"x": tf.io.FixedLenFeature([None], tf.int64)}, + ), + dict( + testcase_name="known_output_non_batch_dim", + make_tensors_fn=functools.partial( + _make_tensors_with_depth, depth=10 + ), + feature_spec={"x": tf.io.FixedLenFeature([10], tf.int64)}, + create_session=True, + ), + ], + [ + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ], + ) + ) + # pylint: enable=g-long-lambda + def test_infer_feature_schema( + self, make_tensors_fn, + feature_spec, use_compat_v1, - inputs=inputs, - input_signature=input_signature, - create_session=create_session) - expected_schema = schema_utils.schema_from_feature_spec( - feature_spec, domains) - 
self.assertEqual(schema, expected_schema) - - @test_case.named_parameters( - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False)) - def test_infer_feature_schema_bad_rank(self, use_compat_v1): - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') - inputs = {'x': 0} - input_signature = {'x': tf.TensorSpec([], dtype=tf.float32)} - with self.assertRaises(ValueError): - self._get_schema( - _make_tensors, - use_compat_v1, - inputs=inputs, - input_signature=input_signature) + domains=None, + create_session=False, + ): + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + x_val = "0" if feature_spec["x"].dtype == tf.string else 0 + inputs = {"x": [x_val]} + input_signature = {"x": tf.TensorSpec([None], dtype=feature_spec["x"].dtype)} + schema = self._get_schema( + make_tensors_fn, + use_compat_v1, + inputs=inputs, + input_signature=input_signature, + create_session=create_session, + ) + expected_schema = schema_utils.schema_from_feature_spec(feature_spec, domains) + self.assertEqual(schema, expected_schema) - @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE, - 'Schema annotations are not available') - @test_case.named_parameters( - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False)) - def test_vocab_annotation(self, use_compat_v1): - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') + @test_case.named_parameters( + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ) + def test_infer_feature_schema_bad_rank(self, use_compat_v1): + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") + inputs = {"x": 0} + input_signature = {"x": tf.TensorSpec([], dtype=tf.float32)} + with self.assertRaises(ValueError): + self._get_schema( + _make_tensors, + use_compat_v1, + inputs=inputs, + input_signature=input_signature, + ) - def preprocessing_fn(_): - analyzers._maybe_annotate_vocab_metadata('file1', - tf.constant(100, dtype=tf.int64), - tf.constant(75, dtype=tf.int64)) - analyzers._maybe_annotate_vocab_metadata('file2', - tf.constant(200, dtype=tf.int64), - tf.constant(175, dtype=tf.int64)) - return { - 'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64), - } + @unittest.skipIf( + not common.IS_ANNOTATIONS_PB_AVAILABLE, "Schema annotations are not available" + ) + @test_case.named_parameters( + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ) + def test_vocab_annotation(self, use_compat_v1): + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") - schema = self._get_schema( - preprocessing_fn, use_compat_v1, create_session=True) - self.assertLen(schema.annotation.extra_metadata, 2) - unfiltered_sizes = {} - filtered_sizes = {} - for annotation in schema.annotation.extra_metadata: - message = annotations_pb2.VocabularyMetadata() - annotation.Unpack(message) - unfiltered_sizes[message.file_name] = message.unfiltered_vocabulary_size - filtered_sizes[message.file_name] = message.filtered_vocabulary_size - self.assertDictEqual(unfiltered_sizes, {'file1': 100, 'file2': 200}) - self.assertDictEqual(filtered_sizes, {'file1': 75, 'file2': 175}) + def preprocessing_fn(_): + analyzers._maybe_annotate_vocab_metadata( + "file1", + tf.constant(100, dtype=tf.int64), + tf.constant(75, dtype=tf.int64), + ) + analyzers._maybe_annotate_vocab_metadata( + "file2", + tf.constant(200, 
dtype=tf.int64), + tf.constant(175, dtype=tf.int64), + ) + return { + "foo": tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64), + } - @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE, - 'Schema annotations are not available') - @test_case.named_parameters( - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False)) - def test_bucketization_annotation(self, use_compat_v1): - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') + schema = self._get_schema(preprocessing_fn, use_compat_v1, create_session=True) + self.assertLen(schema.annotation.extra_metadata, 2) + unfiltered_sizes = {} + filtered_sizes = {} + for annotation in schema.annotation.extra_metadata: + message = annotations_pb2.VocabularyMetadata() + annotation.Unpack(message) + unfiltered_sizes[message.file_name] = message.unfiltered_vocabulary_size + filtered_sizes[message.file_name] = message.filtered_vocabulary_size + self.assertDictEqual(unfiltered_sizes, {"file1": 100, "file2": 200}) + self.assertDictEqual(filtered_sizes, {"file1": 75, "file2": 175}) - def preprocessing_fn(_): - inputs = { - 'foo': tf.convert_to_tensor([0, 1, 2, 3]), - 'bar': tf.convert_to_tensor([0, 2, 0, 2]), - } - boundaries_foo = tf.expand_dims(tf.convert_to_tensor([.5, 1.5]), axis=0) - boundaries_bar = tf.expand_dims(tf.convert_to_tensor([.1, .2]), axis=0) - outputs = {} - # tft.apply_buckets will annotate the feature in the output schema to - # indicate the bucket boundaries that were applied. - outputs['Bucketized_foo'] = mappers.apply_buckets(inputs['foo'], - boundaries_foo) - outputs['Bucketized_bar'] = mappers.apply_buckets(inputs['bar'], - boundaries_bar) - return outputs + @unittest.skipIf( + not common.IS_ANNOTATIONS_PB_AVAILABLE, "Schema annotations are not available" + ) + @test_case.named_parameters( + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ) + def test_bucketization_annotation(self, use_compat_v1): + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") - schema = self._get_schema( - preprocessing_fn, use_compat_v1, create_session=True) - self.assertLen(schema.feature, 2) - for feature in schema.feature: - self.assertLen(feature.annotation.extra_metadata, 1) - for annotation in feature.annotation.extra_metadata: + def preprocessing_fn(_): + inputs = { + "foo": tf.convert_to_tensor([0, 1, 2, 3]), + "bar": tf.convert_to_tensor([0, 2, 0, 2]), + } + boundaries_foo = tf.expand_dims(tf.convert_to_tensor([0.5, 1.5]), axis=0) + boundaries_bar = tf.expand_dims(tf.convert_to_tensor([0.1, 0.2]), axis=0) + outputs = {} + # tft.apply_buckets will annotate the feature in the output schema to + # indicate the bucket boundaries that were applied. 
+ outputs["Bucketized_foo"] = mappers.apply_buckets( + inputs["foo"], boundaries_foo + ) + outputs["Bucketized_bar"] = mappers.apply_buckets( + inputs["bar"], boundaries_bar + ) + return outputs - # Extract the annotated message and validate its contents - message = annotations_pb2.BucketBoundaries() - annotation.Unpack(message) - if feature.name == 'Bucketized_foo': - self.assertAllClose(list(message.boundaries), [.5, 1.5]) - elif feature.name == 'Bucketized_bar': - self.assertAllClose(list(message.boundaries), [.1, .2]) - else: - raise RuntimeError('Unexpected features in schema') + schema = self._get_schema(preprocessing_fn, use_compat_v1, create_session=True) + self.assertLen(schema.feature, 2) + for feature in schema.feature: + self.assertLen(feature.annotation.extra_metadata, 1) + for annotation in feature.annotation.extra_metadata: + # Extract the annotated message and validate its contents + message = annotations_pb2.BucketBoundaries() + annotation.Unpack(message) + if feature.name == "Bucketized_foo": + self.assertAllClose(list(message.boundaries), [0.5, 1.5]) + elif feature.name == "Bucketized_bar": + self.assertAllClose(list(message.boundaries), [0.1, 0.2]) + else: + raise RuntimeError("Unexpected features in schema") - @unittest.skipIf(not common.IS_ANNOTATIONS_PB_AVAILABLE, - 'Schema annotations are not available') - @test_case.named_parameters( - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False)) - def test_global_annotation(self, use_compat_v1): - # pylint: enable=g-import-not-at-top - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') + @unittest.skipIf( + not common.IS_ANNOTATIONS_PB_AVAILABLE, "Schema annotations are not available" + ) + @test_case.named_parameters( + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ) + def test_global_annotation(self, use_compat_v1): + # pylint: enable=g-import-not-at-top + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") - def preprocessing_fn(_): - # Annotate an arbitrary proto at the schema level (not sure what global - # schema boundaries would mean, but hey I'm just a test). - boundaries = tf.constant([[1.0]]) - message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name - sizes = tf.expand_dims([tf.size(boundaries)], axis=0) - message_proto = tf.raw_ops.EncodeProto( - sizes=sizes, - values=[tf.cast(boundaries, tf.float32)], - field_names=['boundaries'], - message_type=message_type)[0] - type_url = os.path.join('type.googleapis.com', message_type) - schema_inference.annotate(type_url, message_proto) - return { - 'foo': tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64), - 'bar': tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64), - } + def preprocessing_fn(_): + # Annotate an arbitrary proto at the schema level (not sure what global + # schema boundaries would mean, but hey I'm just a test). 
+ boundaries = tf.constant([[1.0]]) + message_type = annotations_pb2.BucketBoundaries.DESCRIPTOR.full_name + sizes = tf.expand_dims([tf.size(boundaries)], axis=0) + message_proto = tf.raw_ops.EncodeProto( + sizes=sizes, + values=[tf.cast(boundaries, tf.float32)], + field_names=["boundaries"], + message_type=message_type, + )[0] + type_url = os.path.join("type.googleapis.com", message_type) + schema_inference.annotate(type_url, message_proto) + return { + "foo": tf.convert_to_tensor([0, 1, 2, 3], dtype=tf.int64), + "bar": tf.convert_to_tensor([0, 2, 0, 2], dtype=tf.int64), + } - schema = self._get_schema( - preprocessing_fn, use_compat_v1, create_session=True) - self.assertLen(schema.annotation.extra_metadata, 1) - for annotation in schema.annotation.extra_metadata: - # Extract the annotated message and validate its contents - message = annotations_pb2.BucketBoundaries() - annotation.Unpack(message) - self.assertAllClose(list(message.boundaries), [1]) + schema = self._get_schema(preprocessing_fn, use_compat_v1, create_session=True) + self.assertLen(schema.annotation.extra_metadata, 1) + for annotation in schema.annotation.extra_metadata: + # Extract the annotated message and validate its contents + message = annotations_pb2.BucketBoundaries() + annotation.Unpack(message) + self.assertAllClose(list(message.boundaries), [1]) - @test_case.named_parameters( - dict(testcase_name='compat_v1', use_compat_v1=True), - dict(testcase_name='v2', use_compat_v1=False)) - def test_infer_feature_schema_with_ragged_tensor(self, use_compat_v1): - if not use_compat_v1: - test_case.skip_if_not_tf2('Tensorflow 2.x required') + @test_case.named_parameters( + dict(testcase_name="compat_v1", use_compat_v1=True), + dict(testcase_name="v2", use_compat_v1=False), + ) + def test_infer_feature_schema_with_ragged_tensor(self, use_compat_v1): + if not use_compat_v1: + test_case.skip_if_not_tf2("Tensorflow 2.x required") - def preprocessing_fn(_): - return { - 'foo': - tf.RaggedTensor.from_row_splits( - values=tf.constant([3, 1, 4, 1, 5, 9, 2, 6], tf.int64), - row_splits=[0, 4, 4, 7, 8, 8]), - 'bar': - tf.RaggedTensor.from_row_splits( - values=tf.RaggedTensor.from_row_splits( - values=tf.ones([5], tf.float32), row_splits=[0, 2, 3, 5]), - row_splits=[0, 0, 0, 2, 2, 4]), - 'baz': - tf.RaggedTensor.from_row_splits( - values=tf.ones([5, 3], tf.float32), row_splits=[0, 2, 3, 5]), - 'qux': - tf.RaggedTensor.from_row_splits( - values=tf.RaggedTensor.from_row_splits( - values=tf.ones([5, 7], tf.float32), - row_splits=[0, 2, 3, 5]), - row_splits=[0, 0, 0, 2, 2, 4]), - } + def preprocessing_fn(_): + return { + "foo": tf.RaggedTensor.from_row_splits( + values=tf.constant([3, 1, 4, 1, 5, 9, 2, 6], tf.int64), + row_splits=[0, 4, 4, 7, 8, 8], + ), + "bar": tf.RaggedTensor.from_row_splits( + values=tf.RaggedTensor.from_row_splits( + values=tf.ones([5], tf.float32), row_splits=[0, 2, 3, 5] + ), + row_splits=[0, 0, 0, 2, 2, 4], + ), + "baz": tf.RaggedTensor.from_row_splits( + values=tf.ones([5, 3], tf.float32), row_splits=[0, 2, 3, 5] + ), + "qux": tf.RaggedTensor.from_row_splits( + values=tf.RaggedTensor.from_row_splits( + values=tf.ones([5, 7], tf.float32), row_splits=[0, 2, 3, 5] + ), + row_splits=[0, 0, 0, 2, 2, 4], + ), + } - schema = self._get_schema( - preprocessing_fn, use_compat_v1, create_session=True) - expected_schema_ascii = """ + schema = self._get_schema(preprocessing_fn, use_compat_v1, create_session=True) + expected_schema_ascii = """ feature { name: "bar$ragged_values" type: FLOAT @@ -396,11 +429,10 @@ def 
preprocessing_fn(_): } } """ - expected_schema = text_format.Parse(expected_schema_ascii, - schema_pb2.Schema()) - schema_utils_legacy.set_generate_legacy_feature_spec(expected_schema, False) - self.assertProtoEquals(expected_schema, schema) + expected_schema = text_format.Parse(expected_schema_ascii, schema_pb2.Schema()) + schema_utils_legacy.set_generate_legacy_feature_spec(expected_schema, False) + self.assertProtoEquals(expected_schema, schema) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/test_case.py b/tensorflow_transform/test_case.py index 172a3e5..566777f 100644 --- a/tensorflow_transform/test_case.py +++ b/tensorflow_transform/test_case.py @@ -13,20 +13,21 @@ # limitations under the License. """Library for Tensorflow Transform test cases.""" -from builtins import zip # pylint: disable=redefined-builtin,g-importing-member import functools import inspect import itertools import os +import unittest +from builtins import zip # pylint: disable=redefined-builtin,g-importing-member -from absl.testing import parameterized import numpy as np import tensorflow as tf +from absl.testing import parameterized -import unittest # pylint: disable=g-direct-tensorflow-import from tensorflow.python import tf2 from tensorflow.python.eager import context + # pylint: enable=g-direct-tensorflow-import main = tf.test.main @@ -36,321 +37,355 @@ def cross_named_parameters(*args): - """Cross a list of lists of dicts suitable for @named_parameters. - - Takes a list of lists, where each list is suitable as an input to - @named_parameters, and crosses them, forming a new name for each crossed test - case. - - Args: - *args: A list of lists of dicts. - - Returns: - A list of dicts. - """ - def _cross_test_cases(parameters_list): - """Cross a list of test case parameters.""" - crossed_parameters = parameters_list[0].copy() - for current_parameters in parameters_list[1:]: - for name, value in current_parameters.items(): - if name == 'testcase_name': - crossed_parameters[name] = '{}_{}'.format( - crossed_parameters[name], value) - else: - assert name not in crossed_parameters, name - crossed_parameters[name] = value - return crossed_parameters - return list(map(_cross_test_cases, itertools.product(*args))) + """Cross a list of lists of dicts suitable for @named_parameters. + + Takes a list of lists, where each list is suitable as an input to + @named_parameters, and crosses them, forming a new name for each crossed test + case. + + Args: + ---- + *args: A list of lists of dicts. + + Returns: + ------- + A list of dicts. + """ + + def _cross_test_cases(parameters_list): + """Cross a list of test case parameters.""" + crossed_parameters = parameters_list[0].copy() + for current_parameters in parameters_list[1:]: + for name, value in current_parameters.items(): + if name == "testcase_name": + crossed_parameters[name] = f"{crossed_parameters[name]}_{value}" + else: + assert name not in crossed_parameters, name + crossed_parameters[name] = value + return crossed_parameters + + return list(map(_cross_test_cases, itertools.product(*args))) def parameters(*testcases): - """like parameterized.parameters but tests show arg names. + """Like parameterized.parameters but tests show arg names. - Only works for class methods without *args or **kwargs. + Only works for class methods without *args or **kwargs. - Args: - *testcases: The input to parameterized.parameters(). + Args: + ---- + *testcases: The input to parameterized.parameters(). 
- Returns: - A wrapper function which passes the arguments through as a dictionary. - """ + Returns: + ------- + A wrapper function which passes the arguments through as a dictionary. + """ - def wrapper(fn): - """Constructs and returns the arguments as a dictionary.""" - arg_names = inspect.getfullargspec(fn).args - if arg_names[0] != 'self': - raise ValueError( - 'First argument to test is expected to be "self", but is {}'.format( - arg_names[0])) - arg_names = arg_names[1:] + def wrapper(fn): + """Constructs and returns the arguments as a dictionary.""" + arg_names = inspect.getfullargspec(fn).args + if arg_names[0] != "self": + raise ValueError( + f'First argument to test is expected to be "self", but is {arg_names[0]}' + ) + arg_names = arg_names[1:] - def to_arg_dict(testcase): - if isinstance(testcase, dict): - return testcase - testcase = tuple(testcase) - if len(testcase) != len(arg_names): - raise ValueError( - 'The number of arguments to parameterized test do not match the ' - 'number of expected arguments: {} != {}, arguments: {}, names: {}'. - format(len(testcase), len(arg_names), testcase, arg_names)) - return dict(zip(arg_names, testcase)) + def to_arg_dict(testcase): + if isinstance(testcase, dict): + return testcase + testcase = tuple(testcase) + if len(testcase) != len(arg_names): + raise ValueError( + "The number of arguments to parameterized test do not match the " + f"number of expected arguments: {len(testcase)} != {len(arg_names)}, arguments: {testcase}, names: {arg_names}" + ) + return dict(zip(arg_names, testcase)) - testcases_with_names = [to_arg_dict(testcase) for testcase in testcases] - return parameterized.parameters(*testcases_with_names)(fn) + testcases_with_names = [to_arg_dict(testcase) for testcase in testcases] + return parameterized.parameters(*testcases_with_names)(fn) - return wrapper + return wrapper def cross_parameters(*args): - """Cross a sequence of sequences of parameters suitable for @parameters.""" - for p in itertools.product(*args): - yield functools.reduce(lambda x, y: x + y, p) + """Cross a sequence of sequences of parameters suitable for @parameters.""" + for p in itertools.product(*args): + yield functools.reduce(lambda x, y: x + y, p) def _make_placeholder(tensor_spec): - """Create a placeholder for the given tensor_spec.""" - - if isinstance(tensor_spec, tf.SparseTensorSpec): - return tf.compat.v1.sparse_placeholder( - shape=tensor_spec.shape, dtype=tensor_spec.dtype) - if isinstance(tensor_spec, tf.RaggedTensorSpec): - return tf.compat.v1.ragged.placeholder( - tensor_spec.dtype, tensor_spec.ragged_rank, value_shape=()) - else: - return tf.compat.v1.placeholder( - shape=tensor_spec.shape, dtype=tensor_spec.dtype) + """Create a placeholder for the given tensor_spec.""" + if isinstance(tensor_spec, tf.SparseTensorSpec): + return tf.compat.v1.sparse_placeholder( + shape=tensor_spec.shape, dtype=tensor_spec.dtype + ) + if isinstance(tensor_spec, tf.RaggedTensorSpec): + return tf.compat.v1.ragged.placeholder( + tensor_spec.dtype, tensor_spec.ragged_rank, value_shape=() + ) + else: + return tf.compat.v1.placeholder( + shape=tensor_spec.shape, dtype=tensor_spec.dtype + ) def _graph_function_handler(input_signature): - """Run the given function in graph mode, utilizing placeholders. - - Args: - input_signature: A possibly nested sequence of `tf.TensorSpec` objects - specifying the shapes and dtypes of the Tensors that will be supplied to - this function. - - Returns: - A wrapper function that accepts arguments specified by `input_signature`. 
- """ - def wrapper(fn): - """Decorator that runs decorated function in graph mode.""" - def _run_graph(*inputs): - with context.graph_mode(): # pylint: disable=missing-docstring - assert len(input_signature) == len(inputs) - placeholders = list(map(_make_placeholder, input_signature)) - output_tensor = fn(*placeholders) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - return sess.run(output_tensor, - feed_dict=dict(zip(placeholders, inputs))) - - return _run_graph - return wrapper + """Run the given function in graph mode, utilizing placeholders. + + Args: + ---- + input_signature: A possibly nested sequence of `tf.TensorSpec` objects + specifying the shapes and dtypes of the Tensors that will be supplied to + this function. + + Returns: + ------- + A wrapper function that accepts arguments specified by `input_signature`. + """ + + def wrapper(fn): + """Decorator that runs decorated function in graph mode.""" + + def _run_graph(*inputs): + with context.graph_mode(): # pylint: disable=missing-docstring + assert len(input_signature) == len(inputs) + placeholders = list(map(_make_placeholder, input_signature)) + output_tensor = fn(*placeholders) + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.tables_initializer()) + return sess.run( + output_tensor, feed_dict=dict(zip(placeholders, inputs)) + ) + + return _run_graph + + return wrapper def _ragged_value_as_constant(value, dtype): - if isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): - return tf.RaggedTensor.from_row_splits( - values=_ragged_value_as_constant(value.values, dtype), - row_splits=tf.constant(value.row_splits, dtype=tf.int64)) - else: - return tf.constant(value, dtype=dtype) + if isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): + return tf.RaggedTensor.from_row_splits( + values=_ragged_value_as_constant(value.values, dtype), + row_splits=tf.constant(value.row_splits, dtype=tf.int64), + ) + else: + return tf.constant(value, dtype=dtype) def _wrap_as_constant(value, tensor_spec): - """Wrap a value as a constant, using tensor_spec for shape and type info.""" - if isinstance(tensor_spec, tf.SparseTensorSpec): - result = tf.SparseTensor( - indices=tf.constant(value.indices, dtype=tf.int64), - values=tf.constant(value.values, dtype=tensor_spec.dtype), - dense_shape=tf.constant(value.dense_shape, dtype=tf.int64)) - elif isinstance(tensor_spec, tf.RaggedTensorSpec): - result = _ragged_value_as_constant(value, tensor_spec.dtype) - else: - result = tf.constant(value, dtype=tensor_spec.dtype) - result.shape.assert_is_compatible_with(tensor_spec.shape) - return result + """Wrap a value as a constant, using tensor_spec for shape and type info.""" + if isinstance(tensor_spec, tf.SparseTensorSpec): + result = tf.SparseTensor( + indices=tf.constant(value.indices, dtype=tf.int64), + values=tf.constant(value.values, dtype=tensor_spec.dtype), + dense_shape=tf.constant(value.dense_shape, dtype=tf.int64), + ) + elif isinstance(tensor_spec, tf.RaggedTensorSpec): + result = _ragged_value_as_constant(value, tensor_spec.dtype) + else: + result = tf.constant(value, dtype=tensor_spec.dtype) + result.shape.assert_is_compatible_with(tensor_spec.shape) + return result def _eager_function_handler(input_signature): - """Run the given function in eager mode. - - Args: - input_signature: A possibly nested sequence of `tf.TensorSpec` objects - specifying the shapes and dtypes of the Tensors that will be supplied to - this function. 
- - Returns: - A wrapper function that accepts arguments specified by `input_signature`. - """ - def wrapper(fn): - """Decorator that runs decorated function in eager mode.""" - def _run_eagerly(*inputs): # pylint: disable=missing-docstring - with context.eager_mode(): - constants = [_wrap_as_constant(value, tensor_spec) - for value, tensor_spec in zip(inputs, input_signature)] - output = fn(*constants) - if hasattr(output, '_make'): - return output._make([np.asarray(tensor) for tensor in output]) - if isinstance(output, (tuple, list)): - return [ - tensor.to_list() - if isinstance(tensor, tf.RaggedTensor) else np.asarray(tensor) - for tensor in output - ] - elif isinstance(output, tf.RaggedTensor): - return output.to_list() - else: - return np.asarray(output) + """Run the given function in eager mode. + + Args: + ---- + input_signature: A possibly nested sequence of `tf.TensorSpec` objects + specifying the shapes and dtypes of the Tensors that will be supplied to + this function. + + Returns: + ------- + A wrapper function that accepts arguments specified by `input_signature`. + """ - return _run_eagerly - return wrapper + def wrapper(fn): + """Decorator that runs decorated function in eager mode.""" + + def _run_eagerly(*inputs): # pylint: disable=missing-docstring + with context.eager_mode(): + constants = [ + _wrap_as_constant(value, tensor_spec) + for value, tensor_spec in zip(inputs, input_signature) + ] + output = fn(*constants) + if hasattr(output, "_make"): + return output._make([np.asarray(tensor) for tensor in output]) + if isinstance(output, (tuple, list)): + return [ + tensor.to_list() + if isinstance(tensor, tf.RaggedTensor) + else np.asarray(tensor) + for tensor in output + ] + elif isinstance(output, tf.RaggedTensor): + return output.to_list() + else: + return np.asarray(output) + + return _run_eagerly + + return wrapper def _tf_function_function_handler(input_signature): - """Call function in eager mode, but also wrapped in `tf.function`.""" - def wrapper(fn): - wrapped_fn = tf.function(fn, input_signature) - return _eager_function_handler(input_signature)(wrapped_fn) - return wrapper + """Call function in eager mode, but also wrapped in `tf.function`.""" + + def wrapper(fn): + wrapped_fn = tf.function(fn, input_signature) + return _eager_function_handler(input_signature)(wrapped_fn) + + return wrapper FUNCTION_HANDLERS = [ - dict(testcase_name='graph', - function_handler=_graph_function_handler), - dict(testcase_name='eager', - function_handler=_eager_function_handler), - dict(testcase_name='tf_function', - function_handler=_tf_function_function_handler) + dict(testcase_name="graph", function_handler=_graph_function_handler), + dict(testcase_name="eager", function_handler=_eager_function_handler), + dict(testcase_name="tf_function", function_handler=_tf_function_function_handler), ] def is_external_environment(): - return not os.environ.get('TEST_WORKSPACE', '').startswith('google') + return not os.environ.get("TEST_WORKSPACE", "").startswith("google") def skip_if_external_environment(reason): - if is_external_environment(): - raise unittest.SkipTest(reason) + if is_external_environment(): + raise unittest.SkipTest(reason) def skip_if_not_tf2(reason): - if not tf2.enabled(): - raise unittest.SkipTest(reason) + if not tf2.enabled(): + raise unittest.SkipTest(reason) def cross_with_function_handlers(parameters_list): - """Cross named parameters with all function handlers. + """Cross named parameters with all function handlers. 
- Takes a list of parameters suitable as an input to @named_parameters, - and crosses it with the set of function handlers. - A parameterized test function that uses this should have a parameter named - `function_handler`. + Takes a list of parameters suitable as an input to @named_parameters, + and crosses it with the set of function handlers. + A parameterized test function that uses this should have a parameter named + `function_handler`. - Args: - parameters_list: A list of dicts. + Args: + ---- + parameters_list: A list of dicts. - Returns: - A list of dicts. - """ - return cross_named_parameters(parameters_list, FUNCTION_HANDLERS) + Returns: + ------- + A list of dicts. + """ + return cross_named_parameters(parameters_list, FUNCTION_HANDLERS) class TransformTestCase(parameterized.TestCase, tf.test.TestCase): - """Base test class for testing tf-transform code.""" - - # Display context for failing rows in data assertions. - longMessage = True # pylint: disable=invalid-name - - def assertDataCloseOrEqual(self, a_data, b_data): - """Assert two datasets contain nearly equal values. - - Args: - a_data: a sequence of dicts whose values are - either strings, lists of strings, numeric types or a pair of - those. - b_data: same types as a_data - - Raises: - AssertionError: if the two datasets are not the same. - """ - msg = '' - try: - sorted_a, sorted_b = self._SortedData(a_data), self._SortedData(b_data) - self.assertEqual( - len(sorted_a), len(sorted_b), 'len(%r) != len(%r)' % (a_data, b_data)) - for i, (a_row, b_row) in enumerate(zip(sorted_a, sorted_b)): - self.assertCountEqual(a_row.keys(), b_row.keys(), msg='Row %d' % i) - for key in a_row.keys(): - a_value = a_row[key] - b_value = b_row[key] - msg = 'Row %d, key %s' % (i, key) - if isinstance(a_value, tuple): - self._assertValuesCloseOrEqual(a_value[0], b_value[0], msg=msg) - self._assertValuesCloseOrEqual(a_value[1], b_value[1], msg=msg) - else: - self._assertValuesCloseOrEqual(a_value, b_value, msg=msg) - except (AssertionError, TypeError) as e: - message = '{}\nCompared:\n{}\nvs.\n{}'.format(msg, a_data, b_data) - e.args = ((e.args[0] + ' : ' + message,) + e.args[1:]) - raise e - - def _assertValuesCloseOrEqual(self, a_value, b_value, msg=None): - if (isinstance(a_value, (bytes, str)) or isinstance(a_value, list) and - a_value and isinstance(a_value[0], (bytes, str)) or - isinstance(a_value, np.ndarray) and a_value.dtype == object): - self.assertAllEqual(a_value, b_value, msg=msg) - else: - # TODO(varshaan): Change atol only for tests for which 1e-6 is too strict. - self.assertAllClose(a_value, b_value, atol=1e-4, msg=msg) - - def AssertVocabularyContents(self, vocab_file_path, file_contents): - if vocab_file_path.endswith('.tfrecord.gz'): - file_lines = list( - tf.data.TFRecordDataset(vocab_file_path, - compression_type='GZIP').as_numpy_iterator()) - else: - with tf.io.gfile.GFile(vocab_file_path, 'rb') as f: - file_lines = f.read().splitlines() - - # Store frequency case. - if isinstance(file_contents[0], tuple): - word_and_frequency_list = [] - for content in file_lines: - frequency, word = content.split(b' ', 1) - # Split by comma for when the vocabulary file stores the result of - # per-key analyzers. 
- values = list(map(float, frequency.split(b','))) - word_and_frequency_list.append( - (word, values[0] if len(values) == 1 else values)) - - expected_words, expected_frequency = zip(*word_and_frequency_list) - actual_words, actual_frequency = zip(*file_contents) - self.assertAllEqual(actual_words, expected_words) - np.testing.assert_almost_equal( - expected_frequency, actual_frequency, decimal=6) - else: - self.assertAllEqual(file_lines, file_contents) - - def WriteRenderedDotFile(self, dot_string, output_file=None): - tf.compat.v1.logging.info( - 'Writing a rendered dot file is not yet supported.') - - def _NumpyArraysToLists(self, maybe_arrays): - return [ - x.tolist() if isinstance(x, np.ndarray) else x for x in maybe_arrays] - - def _SortedDicts(self, list_of_dicts): - # Sorts dicts by their unordered (key, value) pairs. We use string ordering - # to ensure consistent comparison of NaNs with numbers. - return sorted(list_of_dicts, key=lambda d: str(sorted(d.items()))) - - def _SortedData(self, list_of_dicts_of_arrays): - list_of_values = [ - self._NumpyArraysToLists(d.values()) for d in list_of_dicts_of_arrays - ] - list_of_keys = [d.keys() for d in list_of_dicts_of_arrays] - unsorted_dict_list = [ - dict(zip(a, b)) for a, b in zip(list_of_keys, list_of_values) - ] - return self._SortedDicts(unsorted_dict_list) + """Base test class for testing tf-transform code.""" + + # Display context for failing rows in data assertions. + longMessage = True # pylint: disable=invalid-name + + def assertDataCloseOrEqual(self, a_data, b_data): + """Assert two datasets contain nearly equal values. + + Args: + ---- + a_data: a sequence of dicts whose values are + either strings, lists of strings, numeric types or a pair of + those. + b_data: same types as a_data + + Raises: + ------ + AssertionError: if the two datasets are not the same. + """ + msg = "" + try: + sorted_a, sorted_b = self._SortedData(a_data), self._SortedData(b_data) + self.assertEqual( + len(sorted_a), len(sorted_b), "len(%r) != len(%r)" % (a_data, b_data) + ) + for i, (a_row, b_row) in enumerate(zip(sorted_a, sorted_b)): + self.assertCountEqual(a_row.keys(), b_row.keys(), msg="Row %d" % i) + for key in a_row.keys(): + a_value = a_row[key] + b_value = b_row[key] + msg = "Row %d, key %s" % (i, key) + if isinstance(a_value, tuple): + self._assertValuesCloseOrEqual(a_value[0], b_value[0], msg=msg) + self._assertValuesCloseOrEqual(a_value[1], b_value[1], msg=msg) + else: + self._assertValuesCloseOrEqual(a_value, b_value, msg=msg) + except (AssertionError, TypeError) as e: + message = f"{msg}\nCompared:\n{a_data}\nvs.\n{b_data}" + e.args = (e.args[0] + " : " + message,) + e.args[1:] + raise e + + def _assertValuesCloseOrEqual(self, a_value, b_value, msg=None): + if ( + isinstance(a_value, (bytes, str)) + or isinstance(a_value, list) + and a_value + and isinstance(a_value[0], (bytes, str)) + or isinstance(a_value, np.ndarray) + and a_value.dtype == object + ): + self.assertAllEqual(a_value, b_value, msg=msg) + else: + # TODO(varshaan): Change atol only for tests for which 1e-6 is too strict. + self.assertAllClose(a_value, b_value, atol=1e-4, msg=msg) + + def AssertVocabularyContents(self, vocab_file_path, file_contents): + if vocab_file_path.endswith(".tfrecord.gz"): + file_lines = list( + tf.data.TFRecordDataset( + vocab_file_path, compression_type="GZIP" + ).as_numpy_iterator() + ) + else: + with tf.io.gfile.GFile(vocab_file_path, "rb") as f: + file_lines = f.read().splitlines() + + # Store frequency case. 
+ if isinstance(file_contents[0], tuple): + word_and_frequency_list = [] + for content in file_lines: + frequency, word = content.split(b" ", 1) + # Split by comma for when the vocabulary file stores the result of + # per-key analyzers. + values = list(map(float, frequency.split(b","))) + word_and_frequency_list.append( + (word, values[0] if len(values) == 1 else values) + ) + + expected_words, expected_frequency = zip(*word_and_frequency_list) + actual_words, actual_frequency = zip(*file_contents) + self.assertAllEqual(actual_words, expected_words) + np.testing.assert_almost_equal( + expected_frequency, actual_frequency, decimal=6 + ) + else: + self.assertAllEqual(file_lines, file_contents) + + def WriteRenderedDotFile(self, dot_string, output_file=None): + tf.compat.v1.logging.info("Writing a rendered dot file is not yet supported.") + + def _NumpyArraysToLists(self, maybe_arrays): + return [x.tolist() if isinstance(x, np.ndarray) else x for x in maybe_arrays] + + def _SortedDicts(self, list_of_dicts): + # Sorts dicts by their unordered (key, value) pairs. We use string ordering + # to ensure consistent comparison of NaNs with numbers. + return sorted(list_of_dicts, key=lambda d: str(sorted(d.items()))) + + def _SortedData(self, list_of_dicts_of_arrays): + list_of_values = [ + self._NumpyArraysToLists(d.values()) for d in list_of_dicts_of_arrays + ] + list_of_keys = [d.keys() for d in list_of_dicts_of_arrays] + unsorted_dict_list = [ + dict(zip(a, b)) for a, b in zip(list_of_keys, list_of_values) + ] + return self._SortedDicts(unsorted_dict_list) diff --git a/tensorflow_transform/test_case_test.py b/tensorflow_transform/test_case_test.py index 397ee66..6728660 100644 --- a/tensorflow_transform/test_case_test.py +++ b/tensorflow_transform/test_case_test.py @@ -14,73 +14,69 @@ """Tests for tensorflow_transform.test_case.""" import re +import unittest from tensorflow_transform import test_case -import unittest - class TftUnitTest(test_case.TransformTestCase): + def testCrossNamedParameters(self): + test_cases_1 = [ + {"testcase_name": "a_1_b_1", "a": 1, "b": 1}, + {"testcase_name": "a_3_b_3", "a": 3, "b": 3}, + ] + test_cases_2 = [ + {"testcase_name": "c_2", "c": 2}, + {"testcase_name": "c_4", "c": 4}, + ] + expected_cross = [ + {"testcase_name": "a_1_b_1_c_2", "a": 1, "b": 1, "c": 2}, + {"testcase_name": "a_1_b_1_c_4", "a": 1, "b": 1, "c": 4}, + {"testcase_name": "a_3_b_3_c_2", "a": 3, "b": 3, "c": 2}, + {"testcase_name": "a_3_b_3_c_4", "a": 3, "b": 3, "c": 4}, + ] + self.assertEqual( + test_case.cross_named_parameters(test_cases_1, test_cases_2), expected_cross + ) - def testCrossNamedParameters(self): - test_cases_1 = [ - {'testcase_name': 'a_1_b_1', 'a': 1, 'b': 1}, - {'testcase_name': 'a_3_b_3', 'a': 3, 'b': 3}, - ] - test_cases_2 = [ - {'testcase_name': 'c_2', 'c': 2}, - {'testcase_name': 'c_4', 'c': 4}, - ] - expected_cross = [ - {'testcase_name': 'a_1_b_1_c_2', 'a': 1, 'b': 1, 'c': 2}, - {'testcase_name': 'a_1_b_1_c_4', 'a': 1, 'b': 1, 'c': 4}, - {'testcase_name': 'a_3_b_3_c_2', 'a': 3, 'b': 3, 'c': 2}, - {'testcase_name': 'a_3_b_3_c_4', 'a': 3, 'b': 3, 'c': 4}, - ] - self.assertEqual( - test_case.cross_named_parameters(test_cases_1, test_cases_2), - expected_cross) - - def testCrossParameters(self): - test_cases_1 = [('a', 1), ('b', 2)] - test_cases_2 = [(True,), (False,)] - expected_cross = [ - ('a', 1, True), ('b', 2, True), - ('a', 1, False), ('b', 2, False), - ] - self.assertCountEqual( - test_case.cross_parameters(test_cases_1, test_cases_2), expected_cross) + def 
testCrossParameters(self): + test_cases_1 = [("a", 1), ("b", 2)] + test_cases_2 = [(True,), (False,)] + expected_cross = [ + ("a", 1, True), + ("b", 2, True), + ("a", 1, False), + ("b", 2, False), + ] + self.assertCountEqual( + test_case.cross_parameters(test_cases_1, test_cases_2), expected_cross + ) - def testAssertDataCloseOrEqual(self): - self.assertDataCloseOrEqual([{'a': 'first', - 'b': 1.0, - 'c': 5, - 'd': ('second', 2.0)}, - {'e': 2, - 'f': 3}], - [{'a': 'first', - 'b': 1.0000001, - 'c': 5, - 'd': ('second', 2.0000001)}, - {'e': 2, - 'f': 3}]) - with self.assertRaisesRegex(AssertionError, r'len\(.*\) != len\(\[\]\)'): - self.assertDataCloseOrEqual([{'a': 1}], []) - with self.assertRaisesRegex( - AssertionError, - re.compile('Element counts were not equal.*: Row 0', re.DOTALL), - ): - self.assertDataCloseOrEqual([{'a': 1}], [{'b': 1}]) - with self.assertRaisesRegex( - AssertionError, - re.compile('Not equal to tolerance.*: Row 0, key a', re.DOTALL), - ): - self.assertDataCloseOrEqual([{'a': 1}], [{'a': 2}]) + def testAssertDataCloseOrEqual(self): + self.assertDataCloseOrEqual( + [{"a": "first", "b": 1.0, "c": 5, "d": ("second", 2.0)}, {"e": 2, "f": 3}], + [ + {"a": "first", "b": 1.0000001, "c": 5, "d": ("second", 2.0000001)}, + {"e": 2, "f": 3}, + ], + ) + with self.assertRaisesRegex(AssertionError, r"len\(.*\) != len\(\[\]\)"): + self.assertDataCloseOrEqual([{"a": 1}], []) + with self.assertRaisesRegex( + AssertionError, + re.compile("Element counts were not equal.*: Row 0", re.DOTALL), + ): + self.assertDataCloseOrEqual([{"a": 1}], [{"b": 1}]) + with self.assertRaisesRegex( + AssertionError, + re.compile("Not equal to tolerance.*: Row 0, key a", re.DOTALL), + ): + self.assertDataCloseOrEqual([{"a": 1}], [{"a": 2}]) - @test_case.parameters((1, 'a'), (2, 'b')) - def testSampleParametrizedTestMethod(self, my_arg, my_other_arg): - self.assertIn((my_arg, my_other_arg), {(1, 'a'), (2, 'b')}) + @test_case.parameters((1, "a"), (2, "b")) + def testSampleParametrizedTestMethod(self, my_arg, my_other_arg): + self.assertIn((my_arg, my_other_arg), {(1, "a"), (2, "b")}) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/tf2_utils.py b/tensorflow_transform/tf2_utils.py index a7e1891..41072aa 100644 --- a/tensorflow_transform/tf2_utils.py +++ b/tensorflow_transform/tf2_utils.py @@ -18,181 +18,221 @@ from typing import Collection, Iterable, Mapping, Optional, Tuple import tensorflow as tf -from tensorflow_transform import common_types + # pylint: disable=g-direct-tensorflow-import from tensorflow.python import tf2 from tensorflow.python.framework.func_graph import FuncGraph + +from tensorflow_transform import common_types + # pylint: enable=g-direct-tensorflow-import def use_tf_compat_v1(force_tf_compat_v1: bool) -> bool: - """Evaluate from environment variables if TF should be used in compat.v1 mode.""" - # If tf.enable_v2_behavior has been called, but eager execution has been - # disabled, force compat v1 behavior. Hence, check - # `executing_eagerly_outside_functions` as well. - return (force_tf_compat_v1 or not tf2.enabled() or - not tf.compat.v1.executing_eagerly_outside_functions()) + """Evaluate from environment variables if TF should be used in compat.v1 mode.""" + # If tf.enable_v2_behavior has been called, but eager execution has been + # disabled, force compat v1 behavior. Hence, check + # `executing_eagerly_outside_functions` as well. 
+ return ( + force_tf_compat_v1 + or not tf2.enabled() + or not tf.compat.v1.executing_eagerly_outside_functions() + ) def strip_and_get_tensors_and_control_dependencies( - flat_tensor_list: Iterable[tf.Tensor] + flat_tensor_list: Iterable[tf.Tensor], ) -> Tuple[Iterable[tf.Tensor], Iterable[tf.Operation]]: - """Strips automatic control dependencies from `flat_tensor_list`. - - Args: - flat_tensor_list: A flattened list of output tensors from a tf.function. - - Returns: - A tuple of: - Tensors from `flat_tensor_list` with control dependencies removed. - The set of control dependency ops that `flat_tensor_list` depended on. - """ - # If an automatic control dependency node was added, all tensors in - # `flat_tensor_list` will be the result of Identity ops with the original - # tensor as an input and the automatic control dependencies as control inputs. - if all(tensor.op.type == 'Identity' and len(tensor.op.inputs) == 1 - for tensor in flat_tensor_list): - control_dependency_ops = [t.op.control_inputs for t in flat_tensor_list] - return ([t.op.inputs[0] for t in flat_tensor_list], - set(itertools.chain(*control_dependency_ops))) - else: - return flat_tensor_list, set() - - -def supply_missing_tensor(batch_size: int, tensor_shape: tf.TensorShape, - tensor_dtype: tf.DType) -> tf.Tensor: - """Supplies a `tf.Tensor` compatible with `tensor`. - - Supports only string and numeric dtypes. - Args: - batch_size: an integer representing the size of the batch returned. - tensor_shape: a `tf.TensorShape`. The returned tensor will have shape - compatible with this. - tensor_dtype: The dtype of the returned tensors. - - Returns: - A batch of `tf.Tensor` tensors. - """ - # If tensor rank is 0 or unknown, return a scalar. - if tensor_shape.ndims is None or tensor_shape.ndims == 0: - return tf.zeros([], dtype=tensor_dtype) - - input_shape = tensor_shape.as_list() - result_shape = [input_shape[0] or batch_size] - - for shape in input_shape[1:]: - if shape is None: - result_shape = result_shape + [1] + """Strips automatic control dependencies from `flat_tensor_list`. + + Args: + ---- + flat_tensor_list: A flattened list of output tensors from a tf.function. + + Returns: + ------- + A tuple of: + Tensors from `flat_tensor_list` with control dependencies removed. + The set of control dependency ops that `flat_tensor_list` depended on. + """ + # If an automatic control dependency node was added, all tensors in + # `flat_tensor_list` will be the result of Identity ops with the original + # tensor as an input and the automatic control dependencies as control inputs. + if all( + tensor.op.type == "Identity" and len(tensor.op.inputs) == 1 + for tensor in flat_tensor_list + ): + control_dependency_ops = [t.op.control_inputs for t in flat_tensor_list] + return ( + [t.op.inputs[0] for t in flat_tensor_list], + set(itertools.chain(*control_dependency_ops)), + ) else: - result_shape = result_shape + [shape] - return tf.zeros(result_shape, dtype=tensor_dtype) + return flat_tensor_list, set() + + +def supply_missing_tensor( + batch_size: int, tensor_shape: tf.TensorShape, tensor_dtype: tf.DType +) -> tf.Tensor: + """Supplies a `tf.Tensor` compatible with `tensor`. + + Supports only string and numeric dtypes. + + Args: + ---- + batch_size: an integer representing the size of the batch returned. + tensor_shape: a `tf.TensorShape`. The returned tensor will have shape + compatible with this. + tensor_dtype: The dtype of the returned tensors. + + Returns: + ------- + A batch of `tf.Tensor` tensors. 
+ """ + # If tensor rank is 0 or unknown, return a scalar. + if tensor_shape.ndims is None or tensor_shape.ndims == 0: + return tf.zeros([], dtype=tensor_dtype) + + input_shape = tensor_shape.as_list() + result_shape = [input_shape[0] or batch_size] + + for shape in input_shape[1:]: + if shape is None: + result_shape = result_shape + [1] + else: + result_shape = result_shape + [shape] + return tf.zeros(result_shape, dtype=tensor_dtype) def supply_missing_inputs( structured_inputs: Mapping[str, common_types.TensorType], batch_size: int, - missing_keys: Optional[Collection[str]] = None + missing_keys: Optional[Collection[str]] = None, ) -> Mapping[str, common_types.TensorType]: - """Supply inputs for unfed features. - - Supports only tf.Tensor, tf.SparseTensor and tf.RaggedTensor. - - Note: Since this returns placeholders, it should be called from within a graph - context. - - Args: - structured_inputs: a dict from keys to batches of placeholder graph tensors. - batch_size: an integer representing the size of the batch returned. - missing_keys: (Optional) a subset of the keys of `structured_inputs` for - which concrete tensors need to be supplied. If `None`, tensors are - supplied for all keys. - - Returns: - A batch of placeholders with default values having the same structure as in - `structured_inputs` for the keys in `missing_keys`. - """ - missing_keys = missing_keys or list(structured_inputs) - # Return placeholders to ensure that tensor shape is not constrained to the - # dummy shape of the missing tensor created here during tracing. - result = {} - for key in missing_keys: - tensor = structured_inputs[key] - if isinstance(tensor, tf.Tensor) or (isinstance(tensor, tf.RaggedTensor) and - tensor.ragged_rank == 0): - missing_tensor = supply_missing_tensor(batch_size, tensor.shape, - tensor.dtype) - result[key] = tf.raw_ops.PlaceholderWithDefault( - input=missing_tensor, shape=tensor.shape) - elif isinstance(tensor, tf.SparseTensor): - values = supply_missing_tensor(batch_size, tensor.values.shape, - tensor.values.dtype) - dense_rank = tensor.shape.ndims - # Since values is always a 1-D tensor, set index for every ith value in - # values to be [i 0 0 ...]. Each index should be compatible with the - # rank of the SparseTensor. Hence, the number of 0s is dense_rank-1. - actual_batch_size = tf.shape(values)[0] - indices = tf.stack( - [tf.range(actual_batch_size, dtype=tf.int64)] + - [tf.zeros(actual_batch_size, dtype=tf.int64)] * (dense_rank - 1), - axis=1) - dense_shape = tf.cast( - [actual_batch_size] + [1] * (dense_rank - 1), dtype=tf.int64) - - indices = tf.raw_ops.PlaceholderWithDefault( - input=indices, shape=tensor.indices.shape) - values = tf.raw_ops.PlaceholderWithDefault( - input=values, shape=tensor.values.shape) - dense_shape = tf.raw_ops.PlaceholderWithDefault( - input=dense_shape, shape=tensor.dense_shape.shape) - result[key] = tf.SparseTensor( - indices=indices, values=values, dense_shape=dense_shape) - elif isinstance(tensor, tf.RaggedTensor): - # Builds a ragged tensor similar to tf.ragged.placeholder, except with - # default values for all components. - ragged_rank = tensor.ragged_rank - values = supply_missing_tensor(batch_size, tensor.flat_values.shape, - tensor.flat_values.dtype) - result[key] = tf.raw_ops.PlaceholderWithDefault( - input=values, shape=tensor.flat_values.shape) - for _ in range(ragged_rank): - if isinstance(values, tf.RaggedTensor): - values_batch_size = values.bounding_shape(axis=0) + """Supply inputs for unfed features. 
+ + Supports only tf.Tensor, tf.SparseTensor and tf.RaggedTensor. + + Note: Since this returns placeholders, it should be called from within a graph + context. + + Args: + ---- + structured_inputs: a dict from keys to batches of placeholder graph tensors. + batch_size: an integer representing the size of the batch returned. + missing_keys: (Optional) a subset of the keys of `structured_inputs` for + which concrete tensors need to be supplied. If `None`, tensors are + supplied for all keys. + + Returns: + ------- + A batch of placeholders with default values having the same structure as in + `structured_inputs` for the keys in `missing_keys`. + """ + missing_keys = missing_keys or list(structured_inputs) + # Return placeholders to ensure that tensor shape is not constrained to the + # dummy shape of the missing tensor created here during tracing. + result = {} + for key in missing_keys: + tensor = structured_inputs[key] + if isinstance(tensor, tf.Tensor) or ( + isinstance(tensor, tf.RaggedTensor) and tensor.ragged_rank == 0 + ): + missing_tensor = supply_missing_tensor( + batch_size, tensor.shape, tensor.dtype + ) + result[key] = tf.raw_ops.PlaceholderWithDefault( + input=missing_tensor, shape=tensor.shape + ) + elif isinstance(tensor, tf.SparseTensor): + values = supply_missing_tensor( + batch_size, tensor.values.shape, tensor.values.dtype + ) + dense_rank = tensor.shape.ndims + # Since values is always a 1-D tensor, set index for every ith value in + # values to be [i 0 0 ...]. Each index should be compatible with the + # rank of the SparseTensor. Hence, the number of 0s is dense_rank-1. + actual_batch_size = tf.shape(values)[0] + indices = tf.stack( + [tf.range(actual_batch_size, dtype=tf.int64)] + + [tf.zeros(actual_batch_size, dtype=tf.int64)] * (dense_rank - 1), + axis=1, + ) + dense_shape = tf.cast( + [actual_batch_size] + [1] * (dense_rank - 1), dtype=tf.int64 + ) + + indices = tf.raw_ops.PlaceholderWithDefault( + input=indices, shape=tensor.indices.shape + ) + values = tf.raw_ops.PlaceholderWithDefault( + input=values, shape=tensor.values.shape + ) + dense_shape = tf.raw_ops.PlaceholderWithDefault( + input=dense_shape, shape=tensor.dense_shape.shape + ) + result[key] = tf.SparseTensor( + indices=indices, values=values, dense_shape=dense_shape + ) + elif isinstance(tensor, tf.RaggedTensor): + # Builds a ragged tensor similar to tf.ragged.placeholder, except with + # default values for all components. 
+ ragged_rank = tensor.ragged_rank + values = supply_missing_tensor( + batch_size, tensor.flat_values.shape, tensor.flat_values.dtype + ) + result[key] = tf.raw_ops.PlaceholderWithDefault( + input=values, shape=tensor.flat_values.shape + ) + for _ in range(ragged_rank): + if isinstance(values, tf.RaggedTensor): + values_batch_size = values.bounding_shape(axis=0) + else: + values_batch_size = tf.shape(values)[0] + row_splits = tf.range(values_batch_size + 1, dtype=tf.int64) + values = tf.RaggedTensor.from_row_splits( + values, row_splits, validate=False + ) + row_splits = tf.raw_ops.PlaceholderWithDefault( + input=row_splits, shape=[None] + ) + result[key] = tf.RaggedTensor.from_row_splits( + result[key], row_splits, validate=False + ) else: - values_batch_size = tf.shape(values)[0] - row_splits = tf.range(values_batch_size + 1, dtype=tf.int64) - values = tf.RaggedTensor.from_row_splits( - values, row_splits, validate=False) - row_splits = tf.raw_ops.PlaceholderWithDefault( - input=row_splits, shape=[None]) - result[key] = tf.RaggedTensor.from_row_splits( - result[key], row_splits, validate=False) - else: - raise ValueError('Received unsupported input tensor type. Only ' - 'dense/sparse/ragged tensors are currently supported.') - return result + raise ValueError( + "Received unsupported input tensor type. Only " + "dense/sparse/ragged tensors are currently supported." + ) + return result def get_structured_inputs_from_func_graph( - func_graph: FuncGraph) -> Mapping[str, common_types.TensorType]: - """Get structured inputs to a FuncGraph. - - Args: - func_graph: A `FuncGraph` object. - - Returns: - Input graph tensors of `func_graph` formatted as possibly-nested python - objects received by it. - """ - # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves - # the structure of the first arg, which for the preprocessing function is - # the dictionary of features. - input_signature = func_graph.structured_input_signature[0][0] - num_captures = len(func_graph.internal_captures + - func_graph.deferred_internal_captures) - # `func_graph.inputs` contains placeholders that represent regular inputs - # followed by captured inputs. We are only interested in the regular inputs. - graph_inputs = copy.copy(func_graph.inputs) - if num_captures > 0: - graph_inputs = graph_inputs[:-num_captures] - return tf.nest.pack_sequence_as( - input_signature, graph_inputs, expand_composites=True) + func_graph: FuncGraph, +) -> Mapping[str, common_types.TensorType]: + """Get structured inputs to a FuncGraph. + + Args: + ---- + func_graph: A `FuncGraph` object. + + Returns: + ------- + Input graph tensors of `func_graph` formatted as possibly-nested python + objects received by it. + """ + # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves + # the structure of the first arg, which for the preprocessing function is + # the dictionary of features. + input_signature = func_graph.structured_input_signature[0][0] + num_captures = len( + func_graph.internal_captures + func_graph.deferred_internal_captures + ) + # `func_graph.inputs` contains placeholders that represent regular inputs + # followed by captured inputs. We are only interested in the regular inputs. 
+ graph_inputs = copy.copy(func_graph.inputs) + if num_captures > 0: + graph_inputs = graph_inputs[:-num_captures] + return tf.nest.pack_sequence_as( + input_signature, graph_inputs, expand_composites=True + ) diff --git a/tensorflow_transform/tf2_utils_test.py b/tensorflow_transform/tf2_utils_test.py index fafc5b8..0f16027 100644 --- a/tensorflow_transform/tf2_utils_test.py +++ b/tensorflow_transform/tf2_utils_test.py @@ -14,9 +14,10 @@ """Tests for tensorflow_transform.tf2_utils.""" import itertools + import tensorflow as tf -from tensorflow_transform import tf2_utils -from tensorflow_transform import test_case + +from tensorflow_transform import test_case, tf2_utils _TEST_BATCH_SIZES = [1, 10] _TEST_DTYPES = [ @@ -31,85 +32,101 @@ _TEST_TENSORS_TYPES = [ (lambda dtype: tf.TensorSpec([None], dtype=dtype), tf.Tensor, []), (lambda dtype: tf.TensorSpec([None, 2], dtype=dtype), tf.Tensor, [2]), - (lambda dtype: tf.RaggedTensorSpec([None, None], dtype=dtype), - tf.RaggedTensor, [None]), + ( + lambda dtype: tf.RaggedTensorSpec([None, None], dtype=dtype), + tf.RaggedTensor, + [None], + ), ( lambda dtype: tf.RaggedTensorSpec( # pylint: disable=g-long-lambda - [None, None, 2], - dtype=dtype, - ragged_rank=1), + [None, None, 2], dtype=dtype, ragged_rank=1 + ), tf.RaggedTensor, - [None, 2]), + [None, 2], + ), ] class TF2UtilsTest(test_case.TransformTestCase): + def test_strip_and_get_tensors_and_control_dependencies(self): + @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)]) + def func(x): + with tf.init_scope(): + initializer_1 = tf.lookup.KeyValueTensorInitializer( + [0, 1, 2], + ["a", "b", "c"], + key_dtype=tf.int64, + value_dtype=tf.string, + ) + table_1 = tf.lookup.StaticHashTable(initializer_1, default_value="NAN") + size = table_1.size() + initializer_2 = tf.lookup.KeyValueTensorInitializer( + ["a", "b", "c"], + [-1, 0, 1], + key_dtype=tf.string, + value_dtype=tf.int64, + ) + table_2 = tf.lookup.StaticHashTable(initializer_2, default_value=-777) + y = table_1.lookup(x) + _ = table_2.lookup(y) + z = x + size + return {"x": x, "z": z} - def test_strip_and_get_tensors_and_control_dependencies(self): - - @tf.function(input_signature=[tf.TensorSpec([], dtype=tf.int64)]) - def func(x): - with tf.init_scope(): - initializer_1 = tf.lookup.KeyValueTensorInitializer( - [0, 1, 2], ['a', 'b', 'c'], - key_dtype=tf.int64, - value_dtype=tf.string) - table_1 = tf.lookup.StaticHashTable(initializer_1, default_value='NAN') - size = table_1.size() - initializer_2 = tf.lookup.KeyValueTensorInitializer( - ['a', 'b', 'c'], [-1, 0, 1], - key_dtype=tf.string, - value_dtype=tf.int64) - table_2 = tf.lookup.StaticHashTable(initializer_2, default_value=-777) - y = table_1.lookup(x) - _ = table_2.lookup(y) - z = x + size - return {'x': x, 'z': z} - - concrete_function = func.get_concrete_function() - flat_outputs = tf.nest.flatten( - concrete_function.structured_outputs, expand_composites=True) - expected_flat_outputs = [t.op.inputs[0] for t in flat_outputs] - expected_control_dependencies = itertools.chain( - *[t.op.control_inputs for t in flat_outputs]) - new_flat_outputs, control_dependencies = ( - tf2_utils.strip_and_get_tensors_and_control_dependencies(flat_outputs)) - self.assertEqual(new_flat_outputs, expected_flat_outputs) - self.assertEqual(control_dependencies, set(expected_control_dependencies)) + concrete_function = func.get_concrete_function() + flat_outputs = tf.nest.flatten( + concrete_function.structured_outputs, expand_composites=True + ) + expected_flat_outputs = [t.op.inputs[0] for t in 
flat_outputs] + expected_control_dependencies = itertools.chain( + *[t.op.control_inputs for t in flat_outputs] + ) + new_flat_outputs, control_dependencies = ( + tf2_utils.strip_and_get_tensors_and_control_dependencies(flat_outputs) + ) + self.assertEqual(new_flat_outputs, expected_flat_outputs) + self.assertEqual(control_dependencies, set(expected_control_dependencies)) - @test_case.parameters(*test_case.cross_parameters( - [(x,) for x in _TEST_BATCH_SIZES], - [(x,) for x in _TEST_DTYPES], - _TEST_TENSORS_TYPES, - )) - def test_supply_missing_tensor_inputs(self, batch_size, dtype, - type_spec_getter, tensor_type, - inner_shape): - test_case.skip_if_not_tf2('Tensorflow 2.x required.') + @test_case.parameters( + *test_case.cross_parameters( + [(x,) for x in _TEST_BATCH_SIZES], + [(x,) for x in _TEST_DTYPES], + _TEST_TENSORS_TYPES, + ) + ) + def test_supply_missing_tensor_inputs( + self, batch_size, dtype, type_spec_getter, tensor_type, inner_shape + ): + test_case.skip_if_not_tf2("Tensorflow 2.x required.") - @tf.function(input_signature=[{ - 'x_1': tf.TensorSpec([None], dtype=tf.int32), - 'x_2': type_spec_getter(dtype), - }]) - def foo(inputs): - return inputs + @tf.function( + input_signature=[ + { + "x_1": tf.TensorSpec([None], dtype=tf.int32), + "x_2": type_spec_getter(dtype), + } + ] + ) + def foo(inputs): + return inputs - conc_fn = foo.get_concrete_function() - # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves - # the structure of the first arg, which for `foo` is `inputs`. - structured_inputs = tf.nest.pack_sequence_as( - conc_fn.structured_input_signature[0][0], - conc_fn.inputs, - expand_composites=True) - missing_keys = ['x_2'] - result = tf2_utils.supply_missing_inputs(structured_inputs, batch_size, - missing_keys) + conc_fn = foo.get_concrete_function() + # structured_input_signature is a tuple of (args, kwargs). [0][0] retrieves + # the structure of the first arg, which for `foo` is `inputs`. 
+ structured_inputs = tf.nest.pack_sequence_as( + conc_fn.structured_input_signature[0][0], + conc_fn.inputs, + expand_composites=True, + ) + missing_keys = ["x_2"] + result = tf2_utils.supply_missing_inputs( + structured_inputs, batch_size, missing_keys + ) - self.assertCountEqual(missing_keys, result.keys()) - self.assertIsInstance(result['x_2'], tensor_type) - self.assertEqual(result['x_2'].shape.as_list(), [batch_size] + inner_shape) - self.assertEqual(result['x_2'].dtype, dtype) + self.assertCountEqual(missing_keys, result.keys()) + self.assertIsInstance(result["x_2"], tensor_type) + self.assertEqual(result["x_2"].shape.as_list(), [batch_size] + inner_shape) + self.assertEqual(result["x_2"].dtype, dtype) -if __name__ == '__main__': - test_case.main() +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/tf_metadata/dataset_metadata.py b/tensorflow_transform/tf_metadata/dataset_metadata.py index 5af3892..8fec233 100644 --- a/tensorflow_transform/tf_metadata/dataset_metadata.py +++ b/tensorflow_transform/tf_metadata/dataset_metadata.py @@ -15,50 +15,51 @@ from typing import Mapping, Optional, Type, TypeVar +from tensorflow_metadata.proto.v0 import schema_pb2 + from tensorflow_transform import common_types from tensorflow_transform.tf_metadata import schema_utils -from tensorflow_metadata.proto.v0 import schema_pb2 -_DatasetMetadataType = TypeVar('_DatasetMetadataType', bound='DatasetMetadata') +_DatasetMetadataType = TypeVar("_DatasetMetadataType", bound="DatasetMetadata") class DatasetMetadata: - """Metadata about a dataset used for the "instance dict" format. + """Metadata about a dataset used for the "instance dict" format. - Caution: The "instance dict" format used with `DatasetMetadata` is much less - efficient than TFXIO. For any serious workloads you should use TFXIO with a - `tfxio.TensorAdapterConfig` instance as the metadata. Refer to - [Get started with TF-Transform](https://www.tensorflow.org/tfx/transform/get_started#data_formats_and_schema) - for more details. + Caution: The "instance dict" format used with `DatasetMetadata` is much less + efficient than TFXIO. For any serious workloads you should use TFXIO with a + `tfxio.TensorAdapterConfig` instance as the metadata. Refer to + [Get started with TF-Transform](https://www.tensorflow.org/tfx/transform/get_started#data_formats_and_schema) + for more details. - This is an in-memory representation that may be serialized and deserialized to - and from a variety of disk representations. - """ + This is an in-memory representation that may be serialized and deserialized to + and from a variety of disk representations. 
+ """ - def __init__(self, schema: schema_pb2.Schema): - self._schema = schema - self._output_record_batches = True + def __init__(self, schema: schema_pb2.Schema): + self._schema = schema + self._output_record_batches = True - @classmethod - def from_feature_spec( - cls: Type[_DatasetMetadataType], - feature_spec: Mapping[str, common_types.FeatureSpecType], - domains: Optional[Mapping[str, common_types.DomainType]] = None - ) -> _DatasetMetadataType: - """Creates a DatasetMetadata from a TF feature spec dict.""" - return cls(schema_utils.schema_from_feature_spec(feature_spec, domains)) + @classmethod + def from_feature_spec( + cls: Type[_DatasetMetadataType], + feature_spec: Mapping[str, common_types.FeatureSpecType], + domains: Optional[Mapping[str, common_types.DomainType]] = None, + ) -> _DatasetMetadataType: + """Creates a DatasetMetadata from a TF feature spec dict.""" + return cls(schema_utils.schema_from_feature_spec(feature_spec, domains)) - @property - def schema(self) -> schema_pb2.Schema: - return self._schema + @property + def schema(self) -> schema_pb2.Schema: + return self._schema - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.schema == other.schema - return NotImplemented + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.schema == other.schema + return NotImplemented - def __ne__(self, other): - return not self == other + def __ne__(self, other): + return not self == other - def __repr__(self): - return self.__dict__.__repr__() + def __repr__(self): + return self.__dict__.__repr__() diff --git a/tensorflow_transform/tf_metadata/dataset_metadata_test.py b/tensorflow_transform/tf_metadata/dataset_metadata_test.py index e4576b1..c1f81ea 100644 --- a/tensorflow_transform/tf_metadata/dataset_metadata_test.py +++ b/tensorflow_transform/tf_metadata/dataset_metadata_test.py @@ -13,18 +13,18 @@ # limitations under the License. 
"""Tests for dataset_metadata.""" -from tensorflow_transform.tf_metadata import test_common -from tensorflow_transform.tf_metadata import dataset_metadata import unittest +from tensorflow_transform.tf_metadata import dataset_metadata, test_common -class DatasetSchemaTest(unittest.TestCase): - def test_sanity(self): - metadata = dataset_metadata.DatasetMetadata.from_feature_spec( - test_common.test_feature_spec) - self.assertEqual(metadata.schema, test_common.get_test_schema()) +class DatasetSchemaTest(unittest.TestCase): + def test_sanity(self): + metadata = dataset_metadata.DatasetMetadata.from_feature_spec( + test_common.test_feature_spec + ) + self.assertEqual(metadata.schema, test_common.get_test_schema()) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/tf_metadata/metadata_io.py b/tensorflow_transform/tf_metadata/metadata_io.py index 7e87177..cb98977 100644 --- a/tensorflow_transform/tf_metadata/metadata_io.py +++ b/tensorflow_transform/tf_metadata/metadata_io.py @@ -16,114 +16,121 @@ import json import os - import tensorflow as tf -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import schema_utils - from google.protobuf import text_format -from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.lib.io import ( + file_io, # pylint: disable=g-direct-tensorflow-import +) from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils + def read_metadata(path): - """Load metadata in JSON format from a path into a new DatasetMetadata.""" - schema_file = os.path.join(path, 'schema.pbtxt') - legacy_schema_file = os.path.join(path, 'v1-json', 'schema.json') - if file_io.file_exists(schema_file): - text_proto = file_io.FileIO(schema_file, 'r').read() - schema_proto = text_format.Parse(text_proto, schema_pb2.Schema(), - allow_unknown_extension=True) - elif file_io.file_exists(legacy_schema_file): - schema_json = file_io.FileIO(legacy_schema_file, 'r').read() - schema_proto = _parse_schema_json(schema_json) - else: - raise IOError( - 'Schema file {} does not exist and neither did legacy format file ' - '{}'.format(schema_file, legacy_schema_file)) - return dataset_metadata.DatasetMetadata(schema_proto) + """Load metadata in JSON format from a path into a new DatasetMetadata.""" + schema_file = os.path.join(path, "schema.pbtxt") + legacy_schema_file = os.path.join(path, "v1-json", "schema.json") + if file_io.file_exists(schema_file): + text_proto = file_io.FileIO(schema_file, "r").read() + schema_proto = text_format.Parse( + text_proto, schema_pb2.Schema(), allow_unknown_extension=True + ) + elif file_io.file_exists(legacy_schema_file): + schema_json = file_io.FileIO(legacy_schema_file, "r").read() + schema_proto = _parse_schema_json(schema_json) + else: + raise OSError( + f"Schema file {schema_file} does not exist and neither did legacy format file " + f"{legacy_schema_file}" + ) + return dataset_metadata.DatasetMetadata(schema_proto) def _parse_schema_json(schema_json): - """Translate a JSON schema into a Schema proto.""" - schema_dict = json.loads(schema_json) - feature_spec = { - feature_dict['name']: _column_schema_from_json(feature_dict) - for feature_dict in schema_dict.get('feature', []) - } - domains = { - feature_dict['name']: _domain_from_json(feature_dict['domain']) - for feature_dict in schema_dict.get('feature', []) - } - return 
schema_utils.schema_from_feature_spec(feature_spec, domains) + """Translate a JSON schema into a Schema proto.""" + schema_dict = json.loads(schema_json) + feature_spec = { + feature_dict["name"]: _column_schema_from_json(feature_dict) + for feature_dict in schema_dict.get("feature", []) + } + domains = { + feature_dict["name"]: _domain_from_json(feature_dict["domain"]) + for feature_dict in schema_dict.get("feature", []) + } + return schema_utils.schema_from_feature_spec(feature_spec, domains) def _column_schema_from_json(feature_dict): - """Translate a JSON feature dict into a feature spec.""" - dtype = _dtype_from_json(feature_dict['domain']) - tf_options = feature_dict['parsingOptions']['tfOptions'] - if tf_options.get('fixedLenFeature') is not None: - default_value = None - try: - # int() is needed because protobuf JSON encodes int64 as string - default_value = _convert_scalar_or_list( - int, tf_options['fixedLenFeature']['intDefaultValue']) - except KeyError: - try: - default_value = tf_options['fixedLenFeature']['stringDefaultValue'] - except KeyError: + """Translate a JSON feature dict into a feature spec.""" + dtype = _dtype_from_json(feature_dict["domain"]) + tf_options = feature_dict["parsingOptions"]["tfOptions"] + if tf_options.get("fixedLenFeature") is not None: + default_value = None try: - default_value = tf_options['fixedLenFeature']['floatDefaultValue'] + # int() is needed because protobuf JSON encodes int64 as string + default_value = _convert_scalar_or_list( + int, tf_options["fixedLenFeature"]["intDefaultValue"] + ) except KeyError: - pass - axes = feature_dict['fixedShape'].get('axis', []) - shape = [int(axis['size']) for axis in axes] - return tf.io.FixedLenFeature(shape, dtype, default_value) - elif tf_options.get('varLenFeature') is not None: - return tf.io.VarLenFeature(dtype) - else: - raise ValueError('Could not interpret tfOptions: {}'.format(tf_options)) + try: + default_value = tf_options["fixedLenFeature"]["stringDefaultValue"] + except KeyError: + try: + default_value = tf_options["fixedLenFeature"]["floatDefaultValue"] + except KeyError: + pass + axes = feature_dict["fixedShape"].get("axis", []) + shape = [int(axis["size"]) for axis in axes] + return tf.io.FixedLenFeature(shape, dtype, default_value) + elif tf_options.get("varLenFeature") is not None: + return tf.io.VarLenFeature(dtype) + else: + raise ValueError(f"Could not interpret tfOptions: {tf_options}") def _domain_from_json(domain): - """Translate a JSON domain dict into an IntDomain or None.""" - if domain.get('ints') is not None: - def maybe_to_int(s): - return int(s) if s is not None else None - return schema_pb2.IntDomain( - min=maybe_to_int(domain['ints'].get('min')), - max=maybe_to_int(domain['ints'].get('max')), - is_categorical=domain['ints'].get('isCategorical')) - return None + """Translate a JSON domain dict into an IntDomain or None.""" + if domain.get("ints") is not None: + + def maybe_to_int(s): + return int(s) if s is not None else None + + return schema_pb2.IntDomain( + min=maybe_to_int(domain["ints"].get("min")), + max=maybe_to_int(domain["ints"].get("max")), + is_categorical=domain["ints"].get("isCategorical"), + ) + return None def _dtype_from_json(domain): - """Translate a JSON domain dict into a tf.DType.""" - if domain.get('ints') is not None: - return tf.int64 - if domain.get('floats') is not None: - return tf.float32 - if domain.get('strings') is not None: - return tf.string - raise ValueError('Unknown domain: {}'.format(domain)) + """Translate a JSON domain dict into a 
tf.DType.""" + if domain.get("ints") is not None: + return tf.int64 + if domain.get("floats") is not None: + return tf.float32 + if domain.get("strings") is not None: + return tf.string + raise ValueError(f"Unknown domain: {domain}") def write_metadata(metadata, path): - """Write metadata to given path, in JSON format. + """Write metadata to given path, in JSON format. - Args: - metadata: A `DatasetMetadata` to write. - path: a path to a directory where metadata should be written. - """ - if not file_io.file_exists(path): - file_io.recursive_create_dir(path) - schema_file = os.path.join(path, 'schema.pbtxt') - ascii_proto = text_format.MessageToString(metadata.schema) - file_io.atomic_write_string_to_file(schema_file, ascii_proto, overwrite=True) + Args: + ---- + metadata: A `DatasetMetadata` to write. + path: a path to a directory where metadata should be written. + """ + if not file_io.file_exists(path): + file_io.recursive_create_dir(path) + schema_file = os.path.join(path, "schema.pbtxt") + ascii_proto = text_format.MessageToString(metadata.schema) + file_io.atomic_write_string_to_file(schema_file, ascii_proto, overwrite=True) def _convert_scalar_or_list(fn, scalar_or_list): - if isinstance(scalar_or_list, list): - return list(map(fn, scalar_or_list)) - else: - return fn(scalar_or_list) + if isinstance(scalar_or_list, list): + return list(map(fn, scalar_or_list)) + else: + return fn(scalar_or_list) diff --git a/tensorflow_transform/tf_metadata/metadata_io_test.py b/tensorflow_transform/tf_metadata/metadata_io_test.py index 097139e..76a3190 100644 --- a/tensorflow_transform/tf_metadata/metadata_io_test.py +++ b/tensorflow_transform/tf_metadata/metadata_io_test.py @@ -11,19 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for dataset_metadata. 
-""" +"""Tests for dataset_metadata.""" import os import tempfile - -from tensorflow_transform.tf_metadata import test_common -from tensorflow_transform.tf_metadata import dataset_metadata -from tensorflow_transform.tf_metadata import metadata_io import unittest -from tensorflow.python.lib.io import file_io # pylint: disable=g-direct-tensorflow-import +from tensorflow.python.lib.io import ( + file_io, # pylint: disable=g-direct-tensorflow-import +) +from tensorflow_transform.tf_metadata import dataset_metadata, metadata_io, test_common _SCHEMA_WITH_INVALID_KEYS = """ { @@ -58,24 +56,24 @@ class SchemaIOv1JsonTest(unittest.TestCase): - - def _write_schema_to_disk(self, basedir, schema_string): - version_basedir = os.path.join(basedir, 'v1-json') - - # Write a proto by hand to disk - file_io.recursive_create_dir(version_basedir) - file_io.write_string_to_file(os.path.join(version_basedir, 'schema.json'), - schema_string) - - def test_read_with_invalid_keys(self): - # TODO(b/123241798): use TEST_TMPDIR - basedir = tempfile.mkdtemp() - self._write_schema_to_disk(basedir, _SCHEMA_WITH_INVALID_KEYS) - - def test_read_features_default_axis(self): - # TODO(b/123241798): use TEST_TMPDIR - basedir = tempfile.mkdtemp() - schema_no_sparse_features = """ + def _write_schema_to_disk(self, basedir, schema_string): + version_basedir = os.path.join(basedir, "v1-json") + + # Write a proto by hand to disk + file_io.recursive_create_dir(version_basedir) + file_io.write_string_to_file( + os.path.join(version_basedir, "schema.json"), schema_string + ) + + def test_read_with_invalid_keys(self): + # TODO(b/123241798): use TEST_TMPDIR + basedir = tempfile.mkdtemp() + self._write_schema_to_disk(basedir, _SCHEMA_WITH_INVALID_KEYS) + + def test_read_features_default_axis(self): + # TODO(b/123241798): use TEST_TMPDIR + basedir = tempfile.mkdtemp() + schema_no_sparse_features = """ { "feature": [{ "name": "my_key", @@ -92,13 +90,13 @@ def test_read_features_default_axis(self): }] } """ - self._write_schema_to_disk(basedir, schema_no_sparse_features) - _ = metadata_io.read_metadata(basedir) + self._write_schema_to_disk(basedir, schema_no_sparse_features) + _ = metadata_io.read_metadata(basedir) - def test_read_features(self): - # TODO(b/123241798): use TEST_TMPDIR - basedir = tempfile.mkdtemp() - schema_no_sparse_features = """ + def test_read_features(self): + # TODO(b/123241798): use TEST_TMPDIR + basedir = tempfile.mkdtemp() + schema_no_sparse_features = """ { "feature": [{ "name": "my_key", @@ -119,20 +117,21 @@ def test_read_features(self): }] } """ - self._write_schema_to_disk(basedir, schema_no_sparse_features) - _ = metadata_io.read_metadata(basedir) + self._write_schema_to_disk(basedir, schema_no_sparse_features) + _ = metadata_io.read_metadata(basedir) - def test_write_and_read(self): - # TODO(b/123241798): use TEST_TMPDIR - basedir = tempfile.mkdtemp() - original = dataset_metadata.DatasetMetadata( - schema=test_common.get_test_schema()) + def test_write_and_read(self): + # TODO(b/123241798): use TEST_TMPDIR + basedir = tempfile.mkdtemp() + original = dataset_metadata.DatasetMetadata( + schema=test_common.get_test_schema() + ) - metadata_io.write_metadata(original, basedir) - reloaded = metadata_io.read_metadata(basedir) + metadata_io.write_metadata(original, basedir) + reloaded = metadata_io.read_metadata(basedir) - self.assertEqual(original, reloaded) + self.assertEqual(original, reloaded) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git 
a/tensorflow_transform/tf_metadata/schema_utils.py b/tensorflow_transform/tf_metadata/schema_utils.py index 3d49b51..41225ea 100644 --- a/tensorflow_transform/tf_metadata/schema_utils.py +++ b/tensorflow_transform/tf_metadata/schema_utils.py @@ -18,528 +18,584 @@ from typing import Dict, List, Mapping, Optional, Tuple, Union import tensorflow as tf -from tensorflow_transform import common_types -from tensorflow_transform.tf_metadata import schema_utils_legacy +from tensorflow_metadata.proto.v0 import path_pb2, schema_pb2 from tfx_bsl.tfxio import tensor_representation_util -from tensorflow_metadata.proto.v0 import path_pb2 -from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform import common_types +from tensorflow_transform.tf_metadata import schema_utils_legacy # We use an empty name for the default tensor representation group in the output # schema. It contains all ragged output tensor representations. -TENSOR_REPRESENTATION_GROUP = '' +TENSOR_REPRESENTATION_GROUP = "" def schema_from_feature_spec( feature_spec: Mapping[str, common_types.FeatureSpecType], - domains: Optional[Mapping[str, common_types.DomainType]] = None + domains: Optional[Mapping[str, common_types.DomainType]] = None, ) -> schema_pb2.Schema: - """Convert a feature spec to a Schema proto. - - Args: - feature_spec: A TensorFlow feature spec - domains: (optional) a dict whose keys are feature names and values are one - of schema_pb2.IntDomain, schema_pb2.StringDomain or - schema_pb2.FloatDomain. - - Returns: - A Schema proto - - Raises: - ValueError: If the feature spec cannot be converted to a Schema proto. - """ - if domains is None: - domains = {} - - result = schema_pb2.Schema() - - # Some feature specs can only be represented with the legacy schema, in - # particular feature specs where any FixedLenFeature has default_value set. - # We represent these (and only these) using a schema with - # generate_legacy_feature_spec=True. Note the generate_legacy_feature_spec - # field is not part of the open source codebase. - if schema_utils_legacy.should_set_generate_legacy_feature_spec(feature_spec): - return _legacy_schema_from_feature_spec(feature_spec, domains) - - schema_utils_legacy.set_generate_legacy_feature_spec(result, False) - - # Add the features to the schema. - for name, spec in sorted(feature_spec.items()): - if isinstance(spec, tf.io.SparseFeature): - (index_feature, value_feature, sparse_feature) = ( - _sparse_feature_from_feature_spec(spec, name, domains)) - for f in index_feature: - result.feature.add().CopyFrom(f) - result.feature.add().CopyFrom(value_feature) - result.sparse_feature.add().CopyFrom(sparse_feature) - - elif isinstance(spec, tf.io.RaggedFeature): - (value_feature, partitions_features, ragged_tensor_representation) = ( - _ragged_tensor_representation_from_feature_spec(spec, name, domains)) - result.feature.add().CopyFrom(value_feature) - for f in partitions_features: - result.feature.add().CopyFrom(f) - tensor_representation_map = result.tensor_representation_group[ - TENSOR_REPRESENTATION_GROUP].tensor_representation - tensor_representation_map[name].CopyFrom(ragged_tensor_representation) - - else: - result.feature.add().CopyFrom( - _feature_from_feature_spec(spec, name, domains)) - return result + """Convert a feature spec to a Schema proto. + + Args: + ---- + feature_spec: A TensorFlow feature spec + domains: (optional) a dict whose keys are feature names and values are one + of schema_pb2.IntDomain, schema_pb2.StringDomain or + schema_pb2.FloatDomain. 
+ + Returns: + ------- + A Schema proto + + Raises: + ------ + ValueError: If the feature spec cannot be converted to a Schema proto. + """ + if domains is None: + domains = {} + + result = schema_pb2.Schema() + + # Some feature specs can only be represented with the legacy schema, in + # particular feature specs where any FixedLenFeature has default_value set. + # We represent these (and only these) using a schema with + # generate_legacy_feature_spec=True. Note the generate_legacy_feature_spec + # field is not part of the open source codebase. + if schema_utils_legacy.should_set_generate_legacy_feature_spec(feature_spec): + return _legacy_schema_from_feature_spec(feature_spec, domains) + + schema_utils_legacy.set_generate_legacy_feature_spec(result, False) + + # Add the features to the schema. + for name, spec in sorted(feature_spec.items()): + if isinstance(spec, tf.io.SparseFeature): + (index_feature, value_feature, sparse_feature) = ( + _sparse_feature_from_feature_spec(spec, name, domains) + ) + for f in index_feature: + result.feature.add().CopyFrom(f) + result.feature.add().CopyFrom(value_feature) + result.sparse_feature.add().CopyFrom(sparse_feature) + + elif isinstance(spec, tf.io.RaggedFeature): + (value_feature, partitions_features, ragged_tensor_representation) = ( + _ragged_tensor_representation_from_feature_spec(spec, name, domains) + ) + result.feature.add().CopyFrom(value_feature) + for f in partitions_features: + result.feature.add().CopyFrom(f) + tensor_representation_map = result.tensor_representation_group[ + TENSOR_REPRESENTATION_GROUP + ].tensor_representation + tensor_representation_map[name].CopyFrom(ragged_tensor_representation) + + else: + result.feature.add().CopyFrom( + _feature_from_feature_spec(spec, name, domains) + ) + return result def _ragged_tensor_representation_from_feature_spec( - spec: tf.io.RaggedFeature, name: str, domains: Dict[str, - common_types.DomainType] -) -> Tuple[schema_pb2.Feature, List[schema_pb2.Feature], - schema_pb2.TensorRepresentation]: - """Returns representation of a RaggedTensor from a feature spec. - - Args: - spec: A tf.io.RaggedFeature feature spec. - name: Feature name. - domains: A dict whose keys are feature names and values are one of - schema_pb2.IntDomain, schema_pb2.StringDomain or schema_pb2.FloatDomain. - - Returns: - A tuple (value_feature, partitions_features, ragged_tensor_rep), - where value_feature represents RaggedTensor values, partitions_features - represent row lengths partitions and ragged_tensor_rep - ragged - TensorRepresentation. - - Raises: - ValueError: If the feature spec contains partition types different from - UniformRowLength and RowLengths. 
- """ - value_feature = schema_pb2.Feature(name=spec.value_key or name) - _set_type(name, value_feature, spec.dtype) - _set_domain(name, value_feature, domains.get(name)) - - ragged_tensor = schema_pb2.TensorRepresentation.RaggedTensor( - feature_path=path_pb2.Path(step=[spec.value_key or name])) - - partitions_features = [] - for partition in spec.partitions: - if isinstance(partition, tf.io.RaggedFeature.UniformRowLength): # pytype: disable=attribute-error - ragged_tensor.partition.append( - schema_pb2.TensorRepresentation.RaggedTensor.Partition( - uniform_row_length=partition.length)) - elif isinstance(partition, tf.io.RaggedFeature.RowLengths): # pytype: disable=attribute-error - ragged_tensor.partition.append( - schema_pb2.TensorRepresentation.RaggedTensor.Partition( - row_length=partition.key)) - partitions_features.append( - schema_pb2.Feature(name=partition.key, type=schema_pb2.INT)) - else: - raise ValueError( - 'RaggedFeature can only be created with UniformRowLength and ' - 'RowLengths partitions.') - - return value_feature, partitions_features, schema_pb2.TensorRepresentation( - ragged_tensor=ragged_tensor) + spec: tf.io.RaggedFeature, name: str, domains: Dict[str, common_types.DomainType] +) -> Tuple[ + schema_pb2.Feature, List[schema_pb2.Feature], schema_pb2.TensorRepresentation +]: + """Returns representation of a RaggedTensor from a feature spec. + + Args: + ---- + spec: A tf.io.RaggedFeature feature spec. + name: Feature name. + domains: A dict whose keys are feature names and values are one of + schema_pb2.IntDomain, schema_pb2.StringDomain or schema_pb2.FloatDomain. + + Returns: + ------- + A tuple (value_feature, partitions_features, ragged_tensor_rep), + where value_feature represents RaggedTensor values, partitions_features + represent row lengths partitions and ragged_tensor_rep - ragged + TensorRepresentation. + + Raises: + ------ + ValueError: If the feature spec contains partition types different from + UniformRowLength and RowLengths. + """ + value_feature = schema_pb2.Feature(name=spec.value_key or name) + _set_type(name, value_feature, spec.dtype) + _set_domain(name, value_feature, domains.get(name)) + + ragged_tensor = schema_pb2.TensorRepresentation.RaggedTensor( + feature_path=path_pb2.Path(step=[spec.value_key or name]) + ) + + partitions_features = [] + for partition in spec.partitions: + if isinstance( + partition, tf.io.RaggedFeature.UniformRowLength + ): # pytype: disable=attribute-error + ragged_tensor.partition.append( + schema_pb2.TensorRepresentation.RaggedTensor.Partition( + uniform_row_length=partition.length + ) + ) + elif isinstance( + partition, tf.io.RaggedFeature.RowLengths + ): # pytype: disable=attribute-error + ragged_tensor.partition.append( + schema_pb2.TensorRepresentation.RaggedTensor.Partition( + row_length=partition.key + ) + ) + partitions_features.append( + schema_pb2.Feature(name=partition.key, type=schema_pb2.INT) + ) + else: + raise ValueError( + "RaggedFeature can only be created with UniformRowLength and " + "RowLengths partitions." 
+ ) + + return ( + value_feature, + partitions_features, + schema_pb2.TensorRepresentation(ragged_tensor=ragged_tensor), + ) def _sparse_feature_from_feature_spec(spec, name, domains): - """Returns a representation of a SparseFeature from a feature spec.""" - if isinstance(spec.index_key, list): - assert isinstance(spec.size, (list, tuple, tf.TensorShape)), type(spec.size) - assert len(spec.index_key) == len(spec.size), (spec.index_key, spec.size) - spec_size = [ - s.value if isinstance(s, tf.compat.v1.Dimension) else s - for s in spec.size - ] - spec_size = [s if s != -1 else None for s in spec_size] - int_domains = [ - schema_pb2.IntDomain(min=0, max=size - 1) if size is not None else None - for size in spec_size - ] - index_feature = [ - schema_pb2.Feature( - name=key, type=schema_pb2.INT, int_domain=int_domain) - for (key, int_domain) in zip(spec.index_key, int_domains) - ] - index_feature_ref = [ - schema_pb2.SparseFeature.IndexFeature(name=key) - for key in spec.index_key - ] - else: - # Create a index feature. - index_feature = [ - schema_pb2.Feature( - name=spec.index_key, - type=schema_pb2.INT, - int_domain=schema_pb2.IntDomain(min=0, max=spec.size - 1)) - ] - index_feature_ref = [ - schema_pb2.SparseFeature.IndexFeature(name=spec.index_key) - ] - - # Create a value feature. - value_feature = schema_pb2.Feature(name=spec.value_key) - _set_type(name, value_feature, spec.dtype) - _set_domain(name, value_feature, domains.get(name)) - - # Create a sparse feature which refers to the index and value features. - value_feature_ref = schema_pb2.SparseFeature.ValueFeature(name=spec.value_key) - sparse_feature = schema_pb2.SparseFeature( - name=name, - is_sorted=True if spec.already_sorted else None, - index_feature=index_feature_ref, - value_feature=value_feature_ref) - - return (index_feature, value_feature, sparse_feature) + """Returns a representation of a SparseFeature from a feature spec.""" + if isinstance(spec.index_key, list): + assert isinstance(spec.size, (list, tuple, tf.TensorShape)), type(spec.size) + assert len(spec.index_key) == len(spec.size), (spec.index_key, spec.size) + spec_size = [ + s.value if isinstance(s, tf.compat.v1.Dimension) else s for s in spec.size + ] + spec_size = [s if s != -1 else None for s in spec_size] + int_domains = [ + schema_pb2.IntDomain(min=0, max=size - 1) if size is not None else None + for size in spec_size + ] + index_feature = [ + schema_pb2.Feature(name=key, type=schema_pb2.INT, int_domain=int_domain) + for (key, int_domain) in zip(spec.index_key, int_domains) + ] + index_feature_ref = [ + schema_pb2.SparseFeature.IndexFeature(name=key) for key in spec.index_key + ] + else: + # Create a index feature. + index_feature = [ + schema_pb2.Feature( + name=spec.index_key, + type=schema_pb2.INT, + int_domain=schema_pb2.IntDomain(min=0, max=spec.size - 1), + ) + ] + index_feature_ref = [schema_pb2.SparseFeature.IndexFeature(name=spec.index_key)] + + # Create a value feature. + value_feature = schema_pb2.Feature(name=spec.value_key) + _set_type(name, value_feature, spec.dtype) + _set_domain(name, value_feature, domains.get(name)) + + # Create a sparse feature which refers to the index and value features. 
+ value_feature_ref = schema_pb2.SparseFeature.ValueFeature(name=spec.value_key) + sparse_feature = schema_pb2.SparseFeature( + name=name, + is_sorted=True if spec.already_sorted else None, + index_feature=index_feature_ref, + value_feature=value_feature_ref, + ) + + return (index_feature, value_feature, sparse_feature) def _feature_from_feature_spec(spec, name, domains): - """Returns a representation of a Feature from a feature spec.""" - if isinstance(spec, tf.io.FixedLenFeature): - if spec.default_value is not None: - raise ValueError( - 'feature "{}" had default_value {}, but FixedLenFeature must have ' - 'default_value=None'.format(name, spec.default_value)) - dims = [schema_pb2.FixedShape.Dim(size=size) for size in spec.shape] - feature = schema_pb2.Feature( - name=name, - presence=schema_pb2.FeaturePresence(min_fraction=1.0), - shape=schema_pb2.FixedShape(dim=dims)) - elif isinstance(spec, tf.io.VarLenFeature): - feature = schema_pb2.Feature(name=name) - else: - raise TypeError('Spec for feature "{}" was {} of type {}, expected a ' - 'FixedLenFeature, VarLenFeature or SparseFeature'.format( - name, spec, type(spec))) + """Returns a representation of a Feature from a feature spec.""" + if isinstance(spec, tf.io.FixedLenFeature): + if spec.default_value is not None: + raise ValueError( + f'feature "{name}" had default_value {spec.default_value}, but FixedLenFeature must have ' + "default_value=None" + ) + dims = [schema_pb2.FixedShape.Dim(size=size) for size in spec.shape] + feature = schema_pb2.Feature( + name=name, + presence=schema_pb2.FeaturePresence(min_fraction=1.0), + shape=schema_pb2.FixedShape(dim=dims), + ) + elif isinstance(spec, tf.io.VarLenFeature): + feature = schema_pb2.Feature(name=name) + else: + raise TypeError( + f'Spec for feature "{name}" was {spec} of type {type(spec)}, expected a ' + "FixedLenFeature, VarLenFeature or SparseFeature" + ) - _set_type(name, feature, spec.dtype) - _set_domain(name, feature, domains.get(name)) - return feature + _set_type(name, feature, spec.dtype) + _set_domain(name, feature, domains.get(name)) + return feature def _set_type(name, feature, dtype): - """Set the type of a Feature proto.""" - if dtype == tf.int64: - feature.type = schema_pb2.INT - elif dtype == tf.float32: - feature.type = schema_pb2.FLOAT - elif dtype == tf.string: - feature.type = schema_pb2.BYTES - else: - raise ValueError('Feature "{}" has invalid dtype {}'.format(name, dtype)) + """Set the type of a Feature proto.""" + if dtype == tf.int64: + feature.type = schema_pb2.INT + elif dtype == tf.float32: + feature.type = schema_pb2.FLOAT + elif dtype == tf.string: + feature.type = schema_pb2.BYTES + else: + raise ValueError(f'Feature "{name}" has invalid dtype {dtype}') def _set_domain(name, feature, domain): - """Set the domain of a Feature proto.""" - if domain is None: - return - - if isinstance(domain, schema_pb2.IntDomain): - feature.int_domain.CopyFrom(domain) - elif isinstance(domain, schema_pb2.StringDomain): - feature.string_domain.CopyFrom(domain) - elif isinstance(domain, schema_pb2.FloatDomain): - feature.float_domain.CopyFrom(domain) - else: - raise ValueError('Feature "{}" has invalid domain {}'.format(name, domain)) + """Set the domain of a Feature proto.""" + if domain is None: + return + + if isinstance(domain, schema_pb2.IntDomain): + feature.int_domain.CopyFrom(domain) + elif isinstance(domain, schema_pb2.StringDomain): + feature.string_domain.CopyFrom(domain) + elif isinstance(domain, schema_pb2.FloatDomain): + feature.float_domain.CopyFrom(domain) 
+ else: + raise ValueError(f'Feature "{name}" has invalid domain {domain}') @dataclasses.dataclass(frozen=True) class SchemaAsFeatureSpecResult: - feature_spec: Dict[str, common_types.FeatureSpecType] - domains: Dict[str, common_types.DomainType] - - # This is needed because many users unpack this with: - # `feature_spec, domains = schema_utils.schema_as_feature_spec()`. - def __iter__(self): - return (getattr(self, field.name) for field in dataclasses.fields(self)) - - -def _standardize_default_value( - spec: tf.io.FixedLenFeature) -> tf.io.FixedLenFeature: - """Converts bytes to strings and unwraps lists with a single element.""" - if spec.default_value is None: - return spec - default_value = spec.default_value - assert isinstance(default_value, list), spec.default_value - # Convert bytes to string - if spec.dtype == tf.string: - - # Handle bytes string by trying to decode them (for legacy backwards - # compatibility) and if failed, keep the default value as bytes. - def try_decode(value: bytes) -> Union[str, bytes]: - try: - return value.decode('utf-8') - except UnicodeError: - return value - - default_value = [try_decode(value) for value in default_value] - # Unwrap a list with a single element. - if len(default_value) == 1: - default_value = default_value[0] - return tf.io.FixedLenFeature( - shape=spec.shape, dtype=spec.dtype, default_value=default_value) + feature_spec: Dict[str, common_types.FeatureSpecType] + domains: Dict[str, common_types.DomainType] + + # This is needed because many users unpack this with: + # `feature_spec, domains = schema_utils.schema_as_feature_spec()`. + def __iter__(self): + return (getattr(self, field.name) for field in dataclasses.fields(self)) + + +def _standardize_default_value(spec: tf.io.FixedLenFeature) -> tf.io.FixedLenFeature: + """Converts bytes to strings and unwraps lists with a single element.""" + if spec.default_value is None: + return spec + default_value = spec.default_value + assert isinstance(default_value, list), spec.default_value + # Convert bytes to string + if spec.dtype == tf.string: + # Handle bytes string by trying to decode them (for legacy backwards + # compatibility) and if failed, keep the default value as bytes. + def try_decode(value: bytes) -> Union[str, bytes]: + try: + return value.decode("utf-8") + except UnicodeError: + return value + + default_value = [try_decode(value) for value in default_value] + # Unwrap a list with a single element. + if len(default_value) == 1: + default_value = default_value[0] + return tf.io.FixedLenFeature( + shape=spec.shape, dtype=spec.dtype, default_value=default_value + ) def schema_as_feature_spec( - schema_proto: schema_pb2.Schema) -> SchemaAsFeatureSpecResult: - """Generates a feature spec from a Schema proto. - - For a Feature with a FixedShape we generate a FixedLenFeature with no default. - For a Feature without a FixedShape we generate a VarLenFeature. For a - SparseFeature we generate a SparseFeature. - If schema contains struct feature, then it must also contain - TensorRepresentations and is assumed to describe SequenceExample data. The - result in such case is union of context and sequence feature specs. - - Args: - schema_proto: A Schema proto. - - Returns: - A pair (feature spec, domains) where feature spec is a dict whose keys are - feature names and values are instances of FixedLenFeature, - VarLenFeature, SparseFeature or RaggedFeature, and `domains` is a dict - whose keys are feature names and values are one of the `domain_info` - oneof, e.g. IntDomain. 
- - Raises: - ValueError: If the schema proto is invalid. - """ - - # Presence of a struct means that data's physical format is tf.SequenceExample - # and the struct contains sequence features. - if any(feature.type == schema_pb2.STRUCT for feature in schema_proto.feature): - return _sequence_schema_as_feature_spec(schema_proto) - - tensor_representations = ( - tensor_representation_util.InferTensorRepresentationsFromMixedSchema( - schema_proto)) - - feature_spec = {} - # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature. For - # sparse features, will hold the domain_info of the values feature. Features - # that do not have a domain set will not be present in `domains`. - domains = {} - string_domains = _get_string_domains(schema_proto) - feature_by_name = {feature.name: feature for feature in schema_proto.feature} - for name, tensor_representation in tensor_representations.items(): - value_feature = str( - tensor_representation_util.GetSourceValueColumnFromTensorRepresentation( - tensor_representation)) - spec = ( - tensor_representation_util.CreateTfExampleParserConfig( - tensor_representation, feature_by_name[value_feature].type)) - if isinstance(spec, tf.io.FixedLenFeature): - spec = _standardize_default_value(spec) - feature_spec[name] = spec - domain = _get_domain(feature_by_name[value_feature], string_domains) - if domain is not None: - domains[name] = domain - return SchemaAsFeatureSpecResult(feature_spec, domains) + schema_proto: schema_pb2.Schema, +) -> SchemaAsFeatureSpecResult: + """Generates a feature spec from a Schema proto. + + For a Feature with a FixedShape we generate a FixedLenFeature with no default. + For a Feature without a FixedShape we generate a VarLenFeature. For a + SparseFeature we generate a SparseFeature. + If schema contains struct feature, then it must also contain + TensorRepresentations and is assumed to describe SequenceExample data. The + result in such case is union of context and sequence feature specs. + + Args: + ---- + schema_proto: A Schema proto. + + Returns: + ------- + A pair (feature spec, domains) where feature spec is a dict whose keys are + feature names and values are instances of FixedLenFeature, + VarLenFeature, SparseFeature or RaggedFeature, and `domains` is a dict + whose keys are feature names and values are one of the `domain_info` + oneof, e.g. IntDomain. + + Raises: + ------ + ValueError: If the schema proto is invalid. + """ + # Presence of a struct means that data's physical format is tf.SequenceExample + # and the struct contains sequence features. + if any(feature.type == schema_pb2.STRUCT for feature in schema_proto.feature): + return _sequence_schema_as_feature_spec(schema_proto) + + tensor_representations = ( + tensor_representation_util.InferTensorRepresentationsFromMixedSchema( + schema_proto + ) + ) + + feature_spec = {} + # Will hold the domain_info (IntDomain, FloatDomain etc.) of the feature. For + # sparse features, will hold the domain_info of the values feature. Features + # that do not have a domain set will not be present in `domains`. 
+ domains = {} + string_domains = _get_string_domains(schema_proto) + feature_by_name = {feature.name: feature for feature in schema_proto.feature} + for name, tensor_representation in tensor_representations.items(): + value_feature = str( + tensor_representation_util.GetSourceValueColumnFromTensorRepresentation( + tensor_representation + ) + ) + spec = tensor_representation_util.CreateTfExampleParserConfig( + tensor_representation, feature_by_name[value_feature].type + ) + if isinstance(spec, tf.io.FixedLenFeature): + spec = _standardize_default_value(spec) + feature_spec[name] = spec + domain = _get_domain(feature_by_name[value_feature], string_domains) + if domain is not None: + domains[name] = domain + return SchemaAsFeatureSpecResult(feature_spec, domains) def _sequence_schema_as_feature_spec( - schema: schema_pb2.Schema) -> SchemaAsFeatureSpecResult: - """Generates a feature spec from a Schema describing tf.SequenceExample data. - - See `tensor_representation_util.CreateTfSequenceExampleParserConfig`s - docstring for feature spec generation rules. - We mix context and sequence feature specs to replicate how preprocessing_fn - sees input features -- as top-level values of a single `inputs` dict. Note - that this makes the feature spec generation irreversible without additional - input since it's no longer possible to distinguish context and sequence - features to produce the original schema. - - Args: - schema: A TFMD Schema proto. - - Returns: - A pair (feature spec, domains) where feature spec is a dict whose keys are - feature names and values are instances of FixedLenFeature, - VarLenFeature, SparseFeature or RaggedFeature, and `domains` is a dict - whose keys are feature names and values are one of the `domain_info` - oneof, e.g. IntDomain. - - Raises: - ValueError: If `TensorRepresentation`s in the schema result in feature specs - that are not supported. - """ - (context_feature_spec, sequence_feature_spec - ) = tensor_representation_util.CreateTfSequenceExampleParserConfig(schema) - feature_spec = {**context_feature_spec, **sequence_feature_spec} - string_domains = _get_string_domains(schema) - domain_by_feature_name = _get_source_feature_domains(schema, string_domains) - domains = {} - for name, spec in feature_spec.items(): - if isinstance(spec, (tf.io.FixedLenFeature, tf.io.VarLenFeature)): - source_feature_name = name - elif isinstance(spec, (tf.io.SparseFeature, tf.io.RaggedFeature)): - source_feature_name = spec.value_key - else: - raise ValueError('spec is not recognized') - if source_feature_name in domain_by_feature_name: - domains[name] = domain_by_feature_name[source_feature_name] - return SchemaAsFeatureSpecResult(feature_spec, domains) + schema: schema_pb2.Schema, +) -> SchemaAsFeatureSpecResult: + """Generates a feature spec from a Schema describing tf.SequenceExample data. + + See `tensor_representation_util.CreateTfSequenceExampleParserConfig`s + docstring for feature spec generation rules. + We mix context and sequence feature specs to replicate how preprocessing_fn + sees input features -- as top-level values of a single `inputs` dict. Note + that this makes the feature spec generation irreversible without additional + input since it's no longer possible to distinguish context and sequence + features to produce the original schema. + + Args: + ---- + schema: A TFMD Schema proto. 
+ + Returns: + ------- + A pair (feature spec, domains) where feature spec is a dict whose keys are + feature names and values are instances of FixedLenFeature, + VarLenFeature, SparseFeature or RaggedFeature, and `domains` is a dict + whose keys are feature names and values are one of the `domain_info` + oneof, e.g. IntDomain. + + Raises: + ------ + ValueError: If `TensorRepresentation`s in the schema result in feature specs + that are not supported. + """ + (context_feature_spec, sequence_feature_spec) = ( + tensor_representation_util.CreateTfSequenceExampleParserConfig(schema) + ) + feature_spec = {**context_feature_spec, **sequence_feature_spec} + string_domains = _get_string_domains(schema) + domain_by_feature_name = _get_source_feature_domains(schema, string_domains) + domains = {} + for name, spec in feature_spec.items(): + if isinstance(spec, (tf.io.FixedLenFeature, tf.io.VarLenFeature)): + source_feature_name = name + elif isinstance(spec, (tf.io.SparseFeature, tf.io.RaggedFeature)): + source_feature_name = spec.value_key + else: + raise ValueError("spec is not recognized") + if source_feature_name in domain_by_feature_name: + domains[name] = domain_by_feature_name[source_feature_name] + return SchemaAsFeatureSpecResult(feature_spec, domains) def _get_source_feature_domains( schema_or_domain: Union[schema_pb2.Schema, schema_pb2.StructDomain], - string_domains: Dict[str, schema_pb2.StringDomain] + string_domains: Dict[str, schema_pb2.StringDomain], ) -> Dict[str, common_types.DomainType]: - """Recursively extracts domains of all source features in the schema.""" - result = {} - for feature in schema_or_domain.feature: - domain_info = feature.WhichOneof('domain_info') - if domain_info == 'struct_domain': - result.update( - _get_source_feature_domains(feature.struct_domain, string_domains)) - else: - domain = _get_domain(feature, string_domains) - if domain is not None: - result[feature.name] = domain - return result + """Recursively extracts domains of all source features in the schema.""" + result = {} + for feature in schema_or_domain.feature: + domain_info = feature.WhichOneof("domain_info") + if domain_info == "struct_domain": + result.update( + _get_source_feature_domains(feature.struct_domain, string_domains) + ) + else: + domain = _get_domain(feature, string_domains) + if domain is not None: + result[feature.name] = domain + return result def _get_string_domains( - schema: schema_pb2.Schema) -> Dict[str, schema_pb2.StringDomain]: - return {domain.name: domain for domain in schema.string_domain} + schema: schema_pb2.Schema, +) -> Dict[str, schema_pb2.StringDomain]: + return {domain.name: domain for domain in schema.string_domain} def _get_domain(feature, string_domains): - """Get the domain of a feature, possibly looking up a schema-level domain.""" - domain_info = feature.WhichOneof('domain_info') - if domain_info is None: - return None - if domain_info == 'domain': - try: - return string_domains[feature.domain] - except KeyError: - tf.compat.v1.logging.warn( - 'Feature "%s" referred to string domain "%s" which did not exist', - feature.name, feature.domain) - return None - return getattr(feature, domain_info) + """Get the domain of a feature, possibly looking up a schema-level domain.""" + domain_info = feature.WhichOneof("domain_info") + if domain_info is None: + return None + if domain_info == "domain": + try: + return string_domains[feature.domain] + except KeyError: + tf.compat.v1.logging.warn( + 'Feature "%s" referred to string domain "%s" which did not exist', + 
feature.name, + feature.domain, + ) + return None + return getattr(feature, domain_info) def pop_ragged_source_columns( - name: str, tensor_representation: schema_pb2.TensorRepresentation, - feature_by_name: Dict[str, schema_pb2.Feature]) -> schema_pb2.Feature: - """Removes source columns of a ragged tensor from the given features dict. - - Args: - name: Name of the ragged tensor. - tensor_representation: Ragged TensorRepresentation. - feature_by_name: Dict of features that contains source columns of the ragged - TensorRepresentation. - - Returns: - Value feature of the ragged tensor. - - Raises: - ValueError: If any of the source columns are missing in the features dict. - """ - source_columns = ( - tensor_representation_util.GetSourceColumnsFromTensorRepresentation( - tensor_representation)) - missing_column_error_format = ( - 'Ragged feature "{}" referred to value feature "{}" which did not exist ' - 'in the schema or was referred to as an index or value multiple times.') - - assert source_columns - assert len(source_columns[0].steps()) == 1, (name, source_columns[0].steps()) - try: - value_feature = feature_by_name.pop(source_columns[0].steps()[0]) - except KeyError: - raise ValueError( - missing_column_error_format.format(name, source_columns[0].steps()[0])) - for column_path in source_columns[1:]: - assert len(column_path.steps()) == 1, (name, column_path.steps()) + name: str, + tensor_representation: schema_pb2.TensorRepresentation, + feature_by_name: Dict[str, schema_pb2.Feature], +) -> schema_pb2.Feature: + """Removes source columns of a ragged tensor from the given features dict. + + Args: + ---- + name: Name of the ragged tensor. + tensor_representation: Ragged TensorRepresentation. + feature_by_name: Dict of features that contains source columns of the ragged + TensorRepresentation. + + Returns: + ------- + Value feature of the ragged tensor. + + Raises: + ------ + ValueError: If any of the source columns are missing in the features dict. + """ + source_columns = ( + tensor_representation_util.GetSourceColumnsFromTensorRepresentation( + tensor_representation + ) + ) + missing_column_error_format = ( + 'Ragged feature "{}" referred to value feature "{}" which did not exist ' + "in the schema or was referred to as an index or value multiple times." + ) + + assert source_columns + assert len(source_columns[0].steps()) == 1, (name, source_columns[0].steps()) try: - row_length_feature = feature_by_name.pop(column_path.steps()[0]) + value_feature = feature_by_name.pop(source_columns[0].steps()[0]) except KeyError: - raise ValueError( - missing_column_error_format.format(name, - column_path.steps()[0])) - if row_length_feature.type != schema_pb2.FeatureType.INT: - raise ValueError( - 'Row length feature "{}" is not an integer feature.'.format( - row_length_feature.name)) - return value_feature + raise ValueError( + missing_column_error_format.format(name, source_columns[0].steps()[0]) + ) + for column_path in source_columns[1:]: + assert len(column_path.steps()) == 1, (name, column_path.steps()) + try: + row_length_feature = feature_by_name.pop(column_path.steps()[0]) + except KeyError: + raise ValueError( + missing_column_error_format.format(name, column_path.steps()[0]) + ) + if row_length_feature.type != schema_pb2.FeatureType.INT: + raise ValueError( + f'Row length feature "{row_length_feature.name}" is not an integer feature.' 
+ ) + return value_feature def _ragged_tensor_representation_as_feature_spec( - name: str, tensor_representation: schema_pb2.TensorRepresentation, + name: str, + tensor_representation: schema_pb2.TensorRepresentation, feature_by_name: Dict[str, schema_pb2.Feature], - string_domains: Dict[str, common_types.DomainType] + string_domains: Dict[str, common_types.DomainType], ) -> Tuple[tf.io.RaggedFeature, Optional[common_types.DomainType]]: - """Returns a representation of a RaggedTensor as a feature spec.""" - value_feature = pop_ragged_source_columns(name, tensor_representation, - feature_by_name) - spec = tensor_representation_util.CreateTfExampleParserConfig( - tensor_representation, value_feature.type) - domain = _get_domain(value_feature, string_domains) - return typing.cast(tf.io.RaggedFeature, spec), domain + """Returns a representation of a RaggedTensor as a feature spec.""" + value_feature = pop_ragged_source_columns( + name, tensor_representation, feature_by_name + ) + spec = tensor_representation_util.CreateTfExampleParserConfig( + tensor_representation, value_feature.type + ) + domain = _get_domain(value_feature, string_domains) + return typing.cast(tf.io.RaggedFeature, spec), domain def _legacy_schema_from_feature_spec(feature_spec, domains=None): - """Infer a Schema from a feature spec, using the legacy feature spec logic. - - Infers a Schema proto that with generate_legacy_feature_spec set to true, - which will result in the given feature spec and domains when - schema_as_feature_spec is called. This is used to represent feature specs - that can only be represented when generate_legacy_feature_spec is true. In - particular, feature specs with a default value set. - - Args: - feature_spec: A TensorFlow feature spec - domains: A dict from key names to `IntDomain`s - - Returns: - A Schema proto. - - Raises: - ValueError: If a default value is invalid. - TypeError: If an unknown type of feature spec is encountered. - """ - result = schema_pb2.Schema() - result.generate_legacy_feature_spec = True - for name, spec in sorted(feature_spec.items()): - if isinstance(spec, tf.io.FixedLenFeature): - # Validate shape first as shape governs which default values are valid. 
- if len(spec.shape) == 0: # pylint: disable=g-explicit-length-test - size = 1 - expected_default_value = '' if spec.dtype == tf.string else -1 - elif len(spec.shape) == 1 and spec.shape[0] > 1: - size = spec.shape[0] - expected_default_value = ['' if spec.dtype == tf.string else -1] * size - else: - raise ValueError( - 'When inferring legacy schema from feature spec, feature "{}" had ' - 'shape {}, but FixedLenFeature must have shape [] or [k] where ' - 'k > 1.'.format(name, spec.shape)) - - if spec.default_value is None: - min_fraction = 1 - elif spec.default_value == expected_default_value: - min_fraction = 0 - else: - raise ValueError( - 'When inferring legacy schema from feature spec, feature "{}" had ' - 'default_value {}, but FixedLenFeature must have ' - 'default_value=None or {}'.format(name, spec.default_value, - expected_default_value)) - - feature = result.feature.add( - name=name, - presence=schema_pb2.FeaturePresence(min_fraction=min_fraction), - value_count=schema_pb2.ValueCount(min=size, max=size)) - elif isinstance(spec, tf.io.VarLenFeature): - feature = result.feature.add(name=name) - else: - raise TypeError( - 'When inferring legacy schema from feature spec, spec for feature ' - '"{}" was {} of type {}, expected a FixedLenFeature or ' - 'VarLenFeature '.format(name, spec, type(spec))) - - _set_type(name, feature, spec.dtype) - _set_domain(name, feature, domains.get(name)) - - return result + """Infer a Schema from a feature spec, using the legacy feature spec logic. + + Infers a Schema proto that with generate_legacy_feature_spec set to true, + which will result in the given feature spec and domains when + schema_as_feature_spec is called. This is used to represent feature specs + that can only be represented when generate_legacy_feature_spec is true. In + particular, feature specs with a default value set. + + Args: + ---- + feature_spec: A TensorFlow feature spec + domains: A dict from key names to `IntDomain`s + + Returns: + ------- + A Schema proto. + + Raises: + ------ + ValueError: If a default value is invalid. + TypeError: If an unknown type of feature spec is encountered. + """ + result = schema_pb2.Schema() + result.generate_legacy_feature_spec = True + for name, spec in sorted(feature_spec.items()): + if isinstance(spec, tf.io.FixedLenFeature): + # Validate shape first as shape governs which default values are valid. + if len(spec.shape) == 0: # pylint: disable=g-explicit-length-test + size = 1 + expected_default_value = "" if spec.dtype == tf.string else -1 + elif len(spec.shape) == 1 and spec.shape[0] > 1: + size = spec.shape[0] + expected_default_value = ["" if spec.dtype == tf.string else -1] * size + else: + raise ValueError( + f'When inferring legacy schema from feature spec, feature "{name}" had ' + f"shape {spec.shape}, but FixedLenFeature must have shape [] or [k] where " + "k > 1." 
+ ) + + if spec.default_value is None: + min_fraction = 1 + elif spec.default_value == expected_default_value: + min_fraction = 0 + else: + raise ValueError( + f'When inferring legacy schema from feature spec, feature "{name}" had ' + f"default_value {spec.default_value}, but FixedLenFeature must have " + f"default_value=None or {expected_default_value}" + ) + + feature = result.feature.add( + name=name, + presence=schema_pb2.FeaturePresence(min_fraction=min_fraction), + value_count=schema_pb2.ValueCount(min=size, max=size), + ) + elif isinstance(spec, tf.io.VarLenFeature): + feature = result.feature.add(name=name) + else: + raise TypeError( + "When inferring legacy schema from feature spec, spec for feature " + f'"{name}" was {spec} of type {type(spec)}, expected a FixedLenFeature or ' + "VarLenFeature " + ) + + _set_type(name, feature, spec.dtype) + _set_domain(name, feature, domains.get(name)) + + return result diff --git a/tensorflow_transform/tf_metadata/schema_utils_legacy.py b/tensorflow_transform/tf_metadata/schema_utils_legacy.py index 66dce16..ba3e6b0 100644 --- a/tensorflow_transform/tf_metadata/schema_utils_legacy.py +++ b/tensorflow_transform/tf_metadata/schema_utils_legacy.py @@ -15,13 +15,14 @@ def should_set_generate_legacy_feature_spec(feature_spec): - del feature_spec # unused - return False + del feature_spec # unused + return False def set_generate_legacy_feature_spec(schema_proto, value): - del schema_proto # unused - if value: - raise NotImplementedError( - 'The generate_legacy_feature_spec is a legacy field that is not part ' - 'of the OSS tf.Transform codebase') + del schema_proto # unused + if value: + raise NotImplementedError( + "The generate_legacy_feature_spec is a legacy field that is not part " + "of the OSS tf.Transform codebase" + ) diff --git a/tensorflow_transform/tf_metadata/schema_utils_test.py b/tensorflow_transform/tf_metadata/schema_utils_test.py index 90d4814..de2bc4f 100644 --- a/tensorflow_transform/tf_metadata/schema_utils_test.py +++ b/tensorflow_transform/tf_metadata/schema_utils_test.py @@ -13,72 +13,99 @@ # limitations under the License. 
"""Tests for tensorflow_transform.tf_metadata.schema_utils.""" -from absl.testing import parameterized -from tensorflow_transform.tf_metadata import schema_utils_legacy -from tensorflow_transform.tf_metadata import schema_utils_test_cases -from tensorflow_transform.tf_metadata import schema_utils +import unittest +from absl.testing import parameterized from google.protobuf import text_format -import unittest from tensorflow_metadata.proto.v0 import schema_pb2 +from tensorflow_transform.tf_metadata import ( + schema_utils, + schema_utils_legacy, + schema_utils_test_cases, +) -class SchemaUtilsTest(parameterized.TestCase): - @parameterized.named_parameters( - *schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS) - def test_schema_from_feature_spec( - self, ascii_proto, feature_spec, domains=None, - generate_legacy_feature_spec=False): - expected_schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) - schema_utils_legacy.set_generate_legacy_feature_spec( - expected_schema_proto, generate_legacy_feature_spec) - result = schema_utils.schema_from_feature_spec(feature_spec, domains) - self.assertEqual(result, expected_schema_proto) +class SchemaUtilsTest(parameterized.TestCase): + @parameterized.named_parameters( + *schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS + ) + def test_schema_from_feature_spec( + self, + ascii_proto, + feature_spec, + domains=None, + generate_legacy_feature_spec=False, + ): + expected_schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) + schema_utils_legacy.set_generate_legacy_feature_spec( + expected_schema_proto, generate_legacy_feature_spec + ) + result = schema_utils.schema_from_feature_spec(feature_spec, domains) + self.assertEqual(result, expected_schema_proto) - @parameterized.named_parameters( - *(schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS + - schema_utils_test_cases.NON_ROUNDTRIP_SCHEMAS)) - def test_schema_as_feature_spec( - self, ascii_proto, feature_spec, domains=None, - generate_legacy_feature_spec=False): - schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) - schema_utils_legacy.set_generate_legacy_feature_spec( - schema_proto, generate_legacy_feature_spec) - result = schema_utils.schema_as_feature_spec(schema_proto) - self.assertEqual( - result, - schema_utils.SchemaAsFeatureSpecResult(feature_spec, domains or {}), + @parameterized.named_parameters( + *( + schema_utils_test_cases.EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS + + schema_utils_test_cases.NON_ROUNDTRIP_SCHEMAS + ) ) + def test_schema_as_feature_spec( + self, + ascii_proto, + feature_spec, + domains=None, + generate_legacy_feature_spec=False, + ): + schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) + schema_utils_legacy.set_generate_legacy_feature_spec( + schema_proto, generate_legacy_feature_spec + ) + result = schema_utils.schema_as_feature_spec(schema_proto) + self.assertEqual( + result, + schema_utils.SchemaAsFeatureSpecResult(feature_spec, domains or {}), + ) - @parameterized.named_parameters( - *schema_utils_test_cases.INVALID_SCHEMA_PROTOS) - def test_schema_as_feature_spec_fails( - self, ascii_proto, error_msg, error_class=ValueError, - generate_legacy_feature_spec=False): - schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) - schema_utils_legacy.set_generate_legacy_feature_spec( - schema_proto, generate_legacy_feature_spec) - with self.assertRaisesRegex(error_class, error_msg): - schema_utils.schema_as_feature_spec(schema_proto) + 
@parameterized.named_parameters(*schema_utils_test_cases.INVALID_SCHEMA_PROTOS) + def test_schema_as_feature_spec_fails( + self, + ascii_proto, + error_msg, + error_class=ValueError, + generate_legacy_feature_spec=False, + ): + schema_proto = text_format.Parse(ascii_proto, schema_pb2.Schema()) + schema_utils_legacy.set_generate_legacy_feature_spec( + schema_proto, generate_legacy_feature_spec + ) + with self.assertRaisesRegex(error_class, error_msg): + schema_utils.schema_as_feature_spec(schema_proto) - @parameterized.named_parameters( - *schema_utils_test_cases.INVALID_FEATURE_SPECS) - def test_schema_from_feature_spec_fails( - self, feature_spec, error_msg, domain=None, error_class=ValueError): - with self.assertRaisesRegex(error_class, error_msg): - schema_utils.schema_from_feature_spec(feature_spec, domain) + @parameterized.named_parameters(*schema_utils_test_cases.INVALID_FEATURE_SPECS) + def test_schema_from_feature_spec_fails( + self, feature_spec, error_msg, domain=None, error_class=ValueError + ): + with self.assertRaisesRegex(error_class, error_msg): + schema_utils.schema_from_feature_spec(feature_spec, domain) - @parameterized.named_parameters( - *schema_utils_test_cases.RAGGED_VALUE_FEATURES_AND_TENSOR_REPRESENTATIONS) - def test_pop_ragged_source_columns(self, name, tensor_representation, - feature_by_name, expected_value_feature, - truncated_feature_by_name): - value_feature = schema_utils.pop_ragged_source_columns( - name, tensor_representation, feature_by_name) - self.assertEqual(value_feature, expected_value_feature) - self.assertEqual(feature_by_name, truncated_feature_by_name) + @parameterized.named_parameters( + *schema_utils_test_cases.RAGGED_VALUE_FEATURES_AND_TENSOR_REPRESENTATIONS + ) + def test_pop_ragged_source_columns( + self, + name, + tensor_representation, + feature_by_name, + expected_value_feature, + truncated_feature_by_name, + ): + value_feature = schema_utils.pop_ragged_source_columns( + name, tensor_representation, feature_by_name + ) + self.assertEqual(value_feature, expected_value_feature) + self.assertEqual(feature_by_name, truncated_feature_by_name) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow_transform/tf_metadata/schema_utils_test_cases.py b/tensorflow_transform/tf_metadata/schema_utils_test_cases.py index 902873c..ccac9c6 100644 --- a/tensorflow_transform/tf_metadata/schema_utils_test_cases.py +++ b/tensorflow_transform/tf_metadata/schema_utils_test_cases.py @@ -15,73 +15,54 @@ import tensorflow as tf from google.protobuf import text_format - from tensorflow_metadata.proto.v0 import schema_pb2 EQUIVALENT_FEATURE_SPEC_AND_SCHEMAS = [ # Test different dtypes { - 'testcase_name': 'int', - 'ascii_proto': """feature: {name: "x" type: INT}""", - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.int64) - } + "testcase_name": "int", + "ascii_proto": """feature: {name: "x" type: INT}""", + "feature_spec": {"x": tf.io.VarLenFeature(tf.int64)}, }, { - 'testcase_name': 'string', - 'ascii_proto': """feature: {name: "x" type: BYTES}""", - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.string) - } + "testcase_name": "string", + "ascii_proto": """feature: {name: "x" type: BYTES}""", + "feature_spec": {"x": tf.io.VarLenFeature(tf.string)}, }, { - 'testcase_name': 'float', - 'ascii_proto': """feature: {name: "x" type: FLOAT}""", - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.float32) - } + "testcase_name": "float", + "ascii_proto": """feature: {name: "x" type: FLOAT}""", + "feature_spec": 
{"x": tf.io.VarLenFeature(tf.float32)}, }, # Test different shapes { - 'testcase_name': - 'fixed_len_vector', - 'ascii_proto': - """ + "testcase_name": "fixed_len_vector", + "ascii_proto": """ feature: { name: "x" type: INT shape: {dim {size: 1}} presence: {min_fraction: 1} } """, - 'feature_spec': { - 'x': tf.io.FixedLenFeature([1], tf.int64, None) - } + "feature_spec": {"x": tf.io.FixedLenFeature([1], tf.int64, None)}, }, { - 'testcase_name': - 'fixed_len_matrix', - 'ascii_proto': - """ + "testcase_name": "fixed_len_matrix", + "ascii_proto": """ feature: { name: "x" type: INT shape: {dim {size: 2} dim {size: 2}} presence: {min_fraction: 1} } """, - 'feature_spec': { - 'x': tf.io.FixedLenFeature([2, 2], tf.int64, None) - } + "feature_spec": {"x": tf.io.FixedLenFeature([2, 2], tf.int64, None)}, }, { - 'testcase_name': 'var_len', - 'ascii_proto': """feature: {name: "x" type: INT}""", - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.int64) - } + "testcase_name": "var_len", + "ascii_proto": """feature: {name: "x" type: INT}""", + "feature_spec": {"x": tf.io.VarLenFeature(tf.int64)}, }, { - 'testcase_name': - 'sparse', - 'ascii_proto': - """ + "testcase_name": "sparse", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -97,19 +78,15 @@ value_feature {name: "value_key"} } """, - 'feature_spec': { - 'x': - tf.io.SparseFeature(['index_key'], - 'value_key', - tf.int64, [10], - already_sorted=False) - } + "feature_spec": { + "x": tf.io.SparseFeature( + ["index_key"], "value_key", tf.int64, [10], already_sorted=False + ) + }, }, { - 'testcase_name': - 'sparse_sorted', - 'ascii_proto': - """ + "testcase_name": "sparse_sorted", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -126,71 +103,49 @@ value_feature {name: "value_key"} } """, - 'feature_spec': { - 'x': - tf.io.SparseFeature(['index_key'], - 'value_key', - tf.int64, [10], - already_sorted=True) - } + "feature_spec": { + "x": tf.io.SparseFeature( + ["index_key"], "value_key", tf.int64, [10], already_sorted=True + ) + }, }, # Test domains { - 'testcase_name': - 'int_domain', - 'ascii_proto': - """ + "testcase_name": "int_domain", + "ascii_proto": """ feature: { name: "x" type: INT int_domain {min: 0 max: 5 is_categorical: true} } """, - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.int64) - }, - 'domains': { - 'x': schema_pb2.IntDomain(min=0, max=5, is_categorical=True) - } + "feature_spec": {"x": tf.io.VarLenFeature(tf.int64)}, + "domains": {"x": schema_pb2.IntDomain(min=0, max=5, is_categorical=True)}, }, { - 'testcase_name': - 'string_domain', - 'ascii_proto': - """ + "testcase_name": "string_domain", + "ascii_proto": """ feature: { name: "x" type: BYTES string_domain {value: "a" value: "b"} } """, - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.string) - }, - 'domains': { - 'x': schema_pb2.StringDomain(value=['a', 'b']) - } + "feature_spec": {"x": tf.io.VarLenFeature(tf.string)}, + "domains": {"x": schema_pb2.StringDomain(value=["a", "b"])}, }, { - 'testcase_name': - 'float_domain', - 'ascii_proto': - """ + "testcase_name": "float_domain", + "ascii_proto": """ feature: { name: "x" type: FLOAT float_domain {min: 0.0 max: 0.5} } """, - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.float32) - }, - 'domains': { - 'x': schema_pb2.FloatDomain(min=0.0, max=0.5) - } + "feature_spec": {"x": tf.io.VarLenFeature(tf.float32)}, + "domains": {"x": schema_pb2.FloatDomain(min=0.0, max=0.5)}, }, { - 'testcase_name': - 'sparse_feature_rank_0', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_rank_0", + "ascii_proto": 
""" feature { name: "value_key" type: INT @@ -200,15 +155,11 @@ value_feature {name: "value_key"} } """, - 'feature_spec': { - 'x': tf.io.SparseFeature([], 'value_key', tf.int64, []) - } + "feature_spec": {"x": tf.io.SparseFeature([], "value_key", tf.int64, [])}, }, { - 'testcase_name': - 'sparse_feature_rank_2', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_rank_2", + "ascii_proto": """ feature { name: "index_key_1" type: INT @@ -230,17 +181,15 @@ value_feature {name: "value_key"} } """, - 'feature_spec': { - 'x': - tf.io.SparseFeature(['index_key_1', 'index_key_2'], 'value_key', - tf.int64, [1, 1]) - } + "feature_spec": { + "x": tf.io.SparseFeature( + ["index_key_1", "index_key_2"], "value_key", tf.int64, [1, 1] + ) + }, }, { - 'testcase_name': - 'sparse_feature_no_index_int_domain', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_no_index_int_domain", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -255,17 +204,13 @@ value_feature {name: "value_key"} } """, - 'feature_spec': { - 'x': - tf.io.SparseFeature(['index_key'], 'value_key', tf.int64, - [-1]) - } + "feature_spec": { + "x": tf.io.SparseFeature(["index_key"], "value_key", tf.int64, [-1]) + }, }, { - 'testcase_name': - 'ragged_float', - 'ascii_proto': - """ + "testcase_name": "ragged_float", + "ascii_proto": """ feature { name: "value" type: FLOAT @@ -284,20 +229,15 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.float32, - value_key='value', - partitions=[], - row_splits_dtype=tf.int64), + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.float32, value_key="value", partitions=[], row_splits_dtype=tf.int64 + ), }, }, { - 'testcase_name': - 'ragged_int', - 'ascii_proto': - """ + "testcase_name": "ragged_int", + "ascii_proto": """ feature { name: "value" type: INT @@ -316,20 +256,15 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.int64, - value_key='value', - partitions=[], - row_splits_dtype=tf.int64), + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.int64, value_key="value", partitions=[], row_splits_dtype=tf.int64 + ), }, }, { - 'testcase_name': - 'ragged_uniform_row_length', - 'ascii_proto': - """ + "testcase_name": "ragged_uniform_row_length", + "ascii_proto": """ feature { name: "value" type: FLOAT @@ -349,22 +284,22 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.float32, - value_key='value', - partitions=[ - tf.io.RaggedFeature.UniformRowLength(length=4), # pytype: disable=attribute-error - ], - row_splits_dtype=tf.int64), + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.float32, + value_key="value", + partitions=[ + tf.io.RaggedFeature.UniformRowLength( + length=4 + ), # pytype: disable=attribute-error + ], + row_splits_dtype=tf.int64, + ), }, }, { - 'testcase_name': - 'ragged_uniform_row_length_3d', - 'ascii_proto': - """ + "testcase_name": "ragged_uniform_row_length_3d", + "ascii_proto": """ feature { name: "value" type: FLOAT @@ -389,23 +324,25 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.float32, - value_key='value', - partitions=[ - tf.io.RaggedFeature.RowLengths(key='row_length_1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.UniformRowLength(length=4), # pytype: disable=attribute-error - ], - row_splits_dtype=tf.int64), + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.float32, + value_key="value", + partitions=[ + tf.io.RaggedFeature.RowLengths( + key="row_length_1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.UniformRowLength( + length=4 + ), # pytype: 
disable=attribute-error + ], + row_splits_dtype=tf.int64, + ), }, }, { - 'testcase_name': - 'ragged_row_lengths', - 'ascii_proto': - """ + "testcase_name": "ragged_row_lengths", + "ascii_proto": """ feature { name: "value" type: FLOAT @@ -434,23 +371,25 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.float32, - value_key='value', - partitions=[ - tf.io.RaggedFeature.RowLengths(key='row_length_1'), # pytype: disable=attribute-error - tf.io.RaggedFeature.RowLengths(key='row_length_2'), # pytype: disable=attribute-error - ], - row_splits_dtype=tf.int64), + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.float32, + value_key="value", + partitions=[ + tf.io.RaggedFeature.RowLengths( + key="row_length_1" + ), # pytype: disable=attribute-error + tf.io.RaggedFeature.RowLengths( + key="row_length_2" + ), # pytype: disable=attribute-error + ], + row_splits_dtype=tf.int64, + ), }, }, { - 'testcase_name': - 'ragged_tensor_and_feature_same_name', - 'ascii_proto': - """ + "testcase_name": "ragged_tensor_and_feature_same_name", + "ascii_proto": """ feature { name: "ragged" type: FLOAT @@ -469,75 +408,53 @@ } } """, - 'feature_spec': { - 'ragged': - tf.io.RaggedFeature( - tf.float32, - value_key='ragged', - partitions=[], - row_splits_dtype=tf.int64), + "feature_spec": { + "ragged": tf.io.RaggedFeature( + tf.float32, value_key="ragged", partitions=[], row_splits_dtype=tf.int64 + ), }, }, ] NON_ROUNDTRIP_SCHEMAS = [ { - 'testcase_name': - 'deprecated_feature', - 'ascii_proto': - """ + "testcase_name": "deprecated_feature", + "ascii_proto": """ feature: {name: "x" type: INT lifecycle_stage: DEPRECATED} """, - 'feature_spec': {} + "feature_spec": {}, }, { - 'testcase_name': - 'schema_level_string_domain', - 'ascii_proto': - """ + "testcase_name": "schema_level_string_domain", + "ascii_proto": """ feature: {name: "x" type: BYTES domain: "my_domain"} string_domain {name: "my_domain" value: "a" value: "b"} """, - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.string) - }, - 'domains': { - 'x': schema_pb2.StringDomain(name='my_domain', value=['a', 'b']) - } + "feature_spec": {"x": tf.io.VarLenFeature(tf.string)}, + "domains": {"x": schema_pb2.StringDomain(name="my_domain", value=["a", "b"])}, }, { - 'testcase_name': - 'missing_schema_level_string_domain', - 'ascii_proto': - """ + "testcase_name": "missing_schema_level_string_domain", + "ascii_proto": """ feature: {name: "x" type: BYTES domain: "my_domain"} """, - 'feature_spec': { - 'x': tf.io.VarLenFeature(tf.string) - } + "feature_spec": {"x": tf.io.VarLenFeature(tf.string)}, }, { - 'testcase_name': - 'varlen_ragged', - 'ascii_proto': - """ + "testcase_name": "varlen_ragged", + "ascii_proto": """ feature: {name: "x" type: INT} represent_variable_length_as_ragged: true """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - tf.int64, - value_key='x', - partitions=[], - row_splits_dtype=tf.int64) - } + "feature_spec": { + "x": tf.io.RaggedFeature( + tf.int64, value_key="x", partitions=[], row_splits_dtype=tf.int64 + ) + }, }, { - 'testcase_name': - 'sequence', - 'ascii_proto': - """ + "testcase_name": "sequence", + "ascii_proto": """ feature { name: "int_feature" type: INT @@ -579,23 +496,20 @@ } } """, - 'feature_spec': { - 'int_feature': - tf.io.VarLenFeature(dtype=tf.int64), - 'seq_int_feature': - tf.io.RaggedFeature( - dtype=tf.int64, - value_key='int_feature', - partitions=[], - row_splits_dtype=tf.int64, - validate=False), + "feature_spec": { + "int_feature": tf.io.VarLenFeature(dtype=tf.int64), + "seq_int_feature": 
tf.io.RaggedFeature( + dtype=tf.int64, + value_key="int_feature", + partitions=[], + row_splits_dtype=tf.int64, + validate=False, + ), }, }, { - 'testcase_name': - 'sequence_no_context', - 'ascii_proto': - """ + "testcase_name": "sequence_no_context", + "ascii_proto": """ feature { name: "##SEQUENCE##" type: STRUCT @@ -621,21 +535,19 @@ } } """, - 'feature_spec': { - 'x': - tf.io.RaggedFeature( - dtype=tf.int64, - value_key='x', - partitions=[], - row_splits_dtype=tf.int64, - validate=False), + "feature_spec": { + "x": tf.io.RaggedFeature( + dtype=tf.int64, + value_key="x", + partitions=[], + row_splits_dtype=tf.int64, + validate=False, + ), }, }, { - 'testcase_name': - 'sequence_with_domains', - 'ascii_proto': - """ + "testcase_name": "sequence_with_domains", + "ascii_proto": """ feature { name: "int_feature" type: INT @@ -679,27 +591,24 @@ } } """, - 'feature_spec': { - 'int_feature': - tf.io.VarLenFeature(dtype=tf.int64), - 'seq_float_feature': - tf.io.RaggedFeature( - dtype=tf.float32, - value_key='float_feature', - partitions=[], - row_splits_dtype=tf.int64, - validate=False), + "feature_spec": { + "int_feature": tf.io.VarLenFeature(dtype=tf.int64), + "seq_float_feature": tf.io.RaggedFeature( + dtype=tf.float32, + value_key="float_feature", + partitions=[], + row_splits_dtype=tf.int64, + validate=False, + ), + }, + "domains": { + "int_feature": schema_pb2.IntDomain(min=0, max=9), + "seq_float_feature": schema_pb2.FloatDomain(min=1.0), }, - 'domains': { - 'int_feature': schema_pb2.IntDomain(min=0, max=9), - 'seq_float_feature': schema_pb2.FloatDomain(min=1.0) - } }, { - 'testcase_name': - 'sequence_with_string_domain', - 'ascii_proto': - """ + "testcase_name": "sequence_with_string_domain", + "ascii_proto": """ feature { name: "int_feature" type: INT @@ -738,24 +647,21 @@ } } """, - 'feature_spec': { - 'int_feature': - tf.io.VarLenFeature(dtype=tf.int64), - 'seq_string_feature': - tf.io.RaggedFeature( - dtype=tf.string, - value_key='string_feature', - partitions=[], - row_splits_dtype=tf.int64, - validate=False), + "feature_spec": { + "int_feature": tf.io.VarLenFeature(dtype=tf.int64), + "seq_string_feature": tf.io.RaggedFeature( + dtype=tf.string, + value_key="string_feature", + partitions=[], + row_splits_dtype=tf.int64, + validate=False, + ), }, - 'domains': { - 'seq_string_feature': schema_pb2.StringDomain(value=['a', 'b']) - } + "domains": {"seq_string_feature": schema_pb2.StringDomain(value=["a", "b"])}, }, { - 'testcase_name': 'fixed_len_bytes_encoding', - 'ascii_proto': """ + "testcase_name": "fixed_len_bytes_encoding", + "ascii_proto": """ feature { name: "x" type: BYTES @@ -780,34 +686,29 @@ } } """, - 'feature_spec': {'x': tf.io.FixedLenFeature([1], tf.string, b'\xd0')}, + "feature_spec": {"x": tf.io.FixedLenFeature([1], tf.string, b"\xd0")}, }, ] INVALID_SCHEMA_PROTOS = [ { - 'testcase_name': 'no_type', - 'ascii_proto': """ + "testcase_name": "no_type", + "ascii_proto": """ feature: {name: "x"} """, - 'error_msg': 'The feature_type: 0 is not supported.' + "error_msg": "The feature_type: 0 is not supported.", }, { - 'testcase_name': - 'feature_has_shape_but_not_always_present', - 'ascii_proto': - """ + "testcase_name": "feature_has_shape_but_not_always_present", + "ascii_proto": """ feature: {name: "x" type: INT shape: {}} """, - 'error_msg': - r'Feature x had shape set but min_fraction 0.0 != 1. ' - r'Use value_count not shape field when min_fraction != 1.' + "error_msg": r"Feature x had shape set but min_fraction 0.0 != 1. 
" + r"Use value_count not shape field when min_fraction != 1.", }, { - 'testcase_name': - 'sparse_feature_no_index_int_domain_min', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_no_index_int_domain_min", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -823,16 +724,13 @@ value_feature {name: "value_key"} } """, - 'error_msg': - r'Cannot determine dense shape of sparse feature x. ' - r'The minimum domain value of index feature index_key' - r' is not set.' + "error_msg": r"Cannot determine dense shape of sparse feature x. " + r"The minimum domain value of index feature index_key" + r" is not set.", }, { - 'testcase_name': - 'sparse_feature_non_zero_index_int_domain_min', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_non_zero_index_int_domain_min", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -848,16 +746,13 @@ value_feature {name: "value_key"} } """, - 'error_msg': - r'Only 0-based index features are supported. Sparse ' - r'feature x has index feature index_key whose ' - r'minimum domain value is 1' + "error_msg": r"Only 0-based index features are supported. Sparse " + r"feature x has index feature index_key whose " + r"minimum domain value is 1", }, { - 'testcase_name': - 'sparse_feature_no_index_int_domain_max', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_no_index_int_domain_max", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -873,16 +768,13 @@ value_feature {name: "value_key"} } """, - 'error_msg': - r'Cannot determine dense shape of sparse feature x. ' - r'The maximum domain value of index feature index_key ' - r'is not set.' + "error_msg": r"Cannot determine dense shape of sparse feature x. " + r"The maximum domain value of index feature index_key " + r"is not set.", }, { - 'testcase_name': - 'sparse_feature_missing_index_key', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_missing_index_key", + "ascii_proto": """ feature { name: "value_key" type: INT @@ -894,15 +786,12 @@ value_feature {name: "value_key"} } """, - 'error_msg': - r'sparse_feature x referred to index feature ' - r'index_key which did not exist in the schema' + "error_msg": r"sparse_feature x referred to index feature " + r"index_key which did not exist in the schema", }, { - 'testcase_name': - 'sparse_feature_missing_value_key', - 'ascii_proto': - """ + "testcase_name": "sparse_feature_missing_value_key", + "ascii_proto": """ feature { name: "index_key" type: INT @@ -915,170 +804,154 @@ value_feature {name: "value_key"} } """, - 'error_msg': - r'sparse_feature x referred to value feature ' - r'value_key which did not exist in the schema' + "error_msg": r"sparse_feature x referred to value feature " + r"value_key which did not exist in the schema", }, ] INVALID_FEATURE_SPECS = [ { - 'testcase_name': 'bad_type', - 'feature_spec': { - 'x': tf.io.FixedLenFeature([], tf.bool) - }, - 'error_msg': 'Feature "x" has invalid dtype' + "testcase_name": "bad_type", + "feature_spec": {"x": tf.io.FixedLenFeature([], tf.bool)}, + "error_msg": 'Feature "x" has invalid dtype', }, { - 'testcase_name': 'unsupported_type', - 'feature_spec': { - 'x': tf.io.FixedLenSequenceFeature([], tf.int64) - }, - 'error_msg': r'Spec for feature "x" was .* of type .*, expected a ' - r'FixedLenFeature, VarLenFeature or SparseFeature', - 'error_class': TypeError + "testcase_name": "unsupported_type", + "feature_spec": {"x": tf.io.FixedLenSequenceFeature([], tf.int64)}, + "error_msg": r'Spec for feature "x" was .* of type .*, expected a ' + r"FixedLenFeature, 
VarLenFeature or SparseFeature", + "error_class": TypeError, }, ] _FEATURE_BY_NAME = { - 'x': - text_format.Parse( - """ + "x": text_format.Parse( + """ name: "x" type: INT int_domain { min: 0 max: 9 } - """, schema_pb2.Feature()), - 'ragged$value': - text_format.Parse( - """ + """, + schema_pb2.Feature(), + ), + "ragged$value": text_format.Parse( + """ name: "ragged$value" type: FLOAT - """, schema_pb2.Feature()), - 'ragged$row_lengths_1': - text_format.Parse( - """ + """, + schema_pb2.Feature(), + ), + "ragged$row_lengths_1": text_format.Parse( + """ name: "ragged$row_lengths_1" type: INT - """, schema_pb2.Feature()), - 'ragged$row_lengths_2': - text_format.Parse( - """ + """, + schema_pb2.Feature(), + ), + "ragged$row_lengths_2": text_format.Parse( + """ name: "ragged$row_lengths_2" type: INT - """, schema_pb2.Feature()), + """, + schema_pb2.Feature(), + ), } RAGGED_VALUE_FEATURES_AND_TENSOR_REPRESENTATIONS = [ { - 'testcase_name': - '1d', - 'name': - 'ragged_1d', - 'tensor_representation': - text_format.Parse( - """ + "testcase_name": "1d", + "name": "ragged_1d", + "tensor_representation": text_format.Parse( + """ ragged_tensor { feature_path { step: "ragged$value" } } - """, schema_pb2.TensorRepresentation()), - 'feature_by_name': - _FEATURE_BY_NAME.copy(), - 'expected_value_feature': - _FEATURE_BY_NAME['ragged$value'], - 'truncated_feature_by_name': { - 'x': _FEATURE_BY_NAME['x'], - 'ragged$row_lengths_1': _FEATURE_BY_NAME['ragged$row_lengths_1'], - 'ragged$row_lengths_2': _FEATURE_BY_NAME['ragged$row_lengths_2'], + """, + schema_pb2.TensorRepresentation(), + ), + "feature_by_name": _FEATURE_BY_NAME.copy(), + "expected_value_feature": _FEATURE_BY_NAME["ragged$value"], + "truncated_feature_by_name": { + "x": _FEATURE_BY_NAME["x"], + "ragged$row_lengths_1": _FEATURE_BY_NAME["ragged$row_lengths_1"], + "ragged$row_lengths_2": _FEATURE_BY_NAME["ragged$row_lengths_2"], }, }, { - 'testcase_name': - '2d', - 'name': - 'ragged_2d', - 'tensor_representation': - text_format.Parse( - """ + "testcase_name": "2d", + "name": "ragged_2d", + "tensor_representation": text_format.Parse( + """ ragged_tensor { feature_path { step: "ragged$value" } partition { row_length: "ragged$row_lengths_1" } } - """, schema_pb2.TensorRepresentation()), - 'feature_by_name': - _FEATURE_BY_NAME.copy(), - 'expected_value_feature': - _FEATURE_BY_NAME['ragged$value'], - 'truncated_feature_by_name': { - 'x': _FEATURE_BY_NAME['x'], - 'ragged$row_lengths_2': _FEATURE_BY_NAME['ragged$row_lengths_2'], + """, + schema_pb2.TensorRepresentation(), + ), + "feature_by_name": _FEATURE_BY_NAME.copy(), + "expected_value_feature": _FEATURE_BY_NAME["ragged$value"], + "truncated_feature_by_name": { + "x": _FEATURE_BY_NAME["x"], + "ragged$row_lengths_2": _FEATURE_BY_NAME["ragged$row_lengths_2"], }, }, { - 'testcase_name': - '3d', - 'name': - 'ragged_3d', - 'tensor_representation': - text_format.Parse( - """ + "testcase_name": "3d", + "name": "ragged_3d", + "tensor_representation": text_format.Parse( + """ ragged_tensor { feature_path { step: "ragged$value" } partition { row_length: "ragged$row_lengths_1" } partition { row_length: "ragged$row_lengths_2" } } - """, schema_pb2.TensorRepresentation()), - 'feature_by_name': - _FEATURE_BY_NAME.copy(), - 'expected_value_feature': - _FEATURE_BY_NAME['ragged$value'], - 'truncated_feature_by_name': { - 'x': _FEATURE_BY_NAME['x'], + """, + schema_pb2.TensorRepresentation(), + ), + "feature_by_name": _FEATURE_BY_NAME.copy(), + "expected_value_feature": _FEATURE_BY_NAME["ragged$value"], + 
"truncated_feature_by_name": { + "x": _FEATURE_BY_NAME["x"], }, }, { - 'testcase_name': - 'uniform', - 'name': - 'ragged_uniform', - 'tensor_representation': - text_format.Parse( - """ + "testcase_name": "uniform", + "name": "ragged_uniform", + "tensor_representation": text_format.Parse( + """ ragged_tensor { feature_path { step: "ragged$value" } partition { uniform_row_length: 3 } } - """, schema_pb2.TensorRepresentation()), - 'feature_by_name': - _FEATURE_BY_NAME.copy(), - 'expected_value_feature': - _FEATURE_BY_NAME['ragged$value'], - 'truncated_feature_by_name': { - 'x': _FEATURE_BY_NAME['x'], - 'ragged$row_lengths_1': _FEATURE_BY_NAME['ragged$row_lengths_1'], - 'ragged$row_lengths_2': _FEATURE_BY_NAME['ragged$row_lengths_2'], + """, + schema_pb2.TensorRepresentation(), + ), + "feature_by_name": _FEATURE_BY_NAME.copy(), + "expected_value_feature": _FEATURE_BY_NAME["ragged$value"], + "truncated_feature_by_name": { + "x": _FEATURE_BY_NAME["x"], + "ragged$row_lengths_1": _FEATURE_BY_NAME["ragged$row_lengths_1"], + "ragged$row_lengths_2": _FEATURE_BY_NAME["ragged$row_lengths_2"], }, }, { - 'testcase_name': - 'uniform_3d', - 'name': - 'ragged_uniform_3d', - 'tensor_representation': - text_format.Parse( - """ + "testcase_name": "uniform_3d", + "name": "ragged_uniform_3d", + "tensor_representation": text_format.Parse( + """ ragged_tensor { feature_path { step: "ragged$value" } partition { row_length: "ragged$row_lengths_1" } partition { uniform_row_length: 3 } } - """, schema_pb2.TensorRepresentation()), - 'feature_by_name': - _FEATURE_BY_NAME.copy(), - 'expected_value_feature': - _FEATURE_BY_NAME['ragged$value'], - 'truncated_feature_by_name': { - 'x': _FEATURE_BY_NAME['x'], - 'ragged$row_lengths_2': _FEATURE_BY_NAME['ragged$row_lengths_2'], + """, + schema_pb2.TensorRepresentation(), + ), + "feature_by_name": _FEATURE_BY_NAME.copy(), + "expected_value_feature": _FEATURE_BY_NAME["ragged$value"], + "truncated_feature_by_name": { + "x": _FEATURE_BY_NAME["x"], + "ragged$row_lengths_2": _FEATURE_BY_NAME["ragged$row_lengths_2"], }, }, ] diff --git a/tensorflow_transform/tf_metadata/test_common.py b/tensorflow_transform/tf_metadata/test_common.py index f00fb0d..a03cd27 100644 --- a/tensorflow_transform/tf_metadata/test_common.py +++ b/tensorflow_transform/tf_metadata/test_common.py @@ -17,27 +17,18 @@ from tensorflow_transform.tf_metadata import schema_utils - test_feature_spec = { # FixedLenFeatures - 'fixed_categorical_int_with_range': - tf.io.FixedLenFeature(shape=[], dtype=tf.int64), - 'fixed_int': - tf.io.FixedLenFeature(shape=[5], dtype=tf.int64), - 'fixed_float': - tf.io.FixedLenFeature(shape=[5], dtype=tf.float32), - 'fixed_string': - tf.io.FixedLenFeature(shape=[5], dtype=tf.string), - + "fixed_categorical_int_with_range": tf.io.FixedLenFeature(shape=[], dtype=tf.int64), + "fixed_int": tf.io.FixedLenFeature(shape=[5], dtype=tf.int64), + "fixed_float": tf.io.FixedLenFeature(shape=[5], dtype=tf.float32), + "fixed_string": tf.io.FixedLenFeature(shape=[5], dtype=tf.string), # VarLenFeatures - 'var_int': - tf.io.VarLenFeature(dtype=tf.int64), - 'var_float': - tf.io.VarLenFeature(dtype=tf.float32), - 'var_string': - tf.io.VarLenFeature(dtype=tf.string), + "var_int": tf.io.VarLenFeature(dtype=tf.int64), + "var_float": tf.io.VarLenFeature(dtype=tf.float32), + "var_string": tf.io.VarLenFeature(dtype=tf.string), } def get_test_schema(): - return schema_utils.schema_from_feature_spec(test_feature_spec) + return schema_utils.schema_from_feature_spec(test_feature_spec) diff --git 
a/tensorflow_transform/tf_utils.py b/tensorflow_transform/tf_utils.py index f717b86..229577f 100644 --- a/tensorflow_transform/tf_utils.py +++ b/tensorflow_transform/tf_utils.py @@ -19,181 +19,187 @@ from typing import Callable, Optional, Sequence, Tuple, Union import tensorflow as tf -from tensorflow_transform import annotators -from tensorflow_transform import common_types -# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` -# once the Spark issue is resolved. -from tfx_bsl.types import tfx_namedtuple # pylint: disable=g-direct-tensorflow-import -from tensorflow.python.framework import composite_tensor -from tensorflow.python.framework import ops +from tensorflow.python.framework import composite_tensor, ops from tensorflow.python.ops import lookup_ops from tensorflow.python.util import object_identity + +# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple` +# once the Spark issue is resolved. +from tfx_bsl.types import tfx_namedtuple + +from tensorflow_transform import annotators, common_types + # pylint: enable=g-direct-tensorflow-import _AssetFileType = Union[tf.Tensor, str] -_FLOATING_NAN = float('nan') +_FLOATING_NAN = float("nan") # Global sentinels used to keep track of the total counts of y -GLOBAL_Y_COUNT_SENTINEL_STRING = b'global_y_count_sentinel' +GLOBAL_Y_COUNT_SENTINEL_STRING = b"global_y_count_sentinel" GLOBAL_Y_COUNT_SENTINEL_INT = tf.int64.limits[1] # Key for graph collection containing tuple of a key to the eager tensor # representing asset path and the graph tensor tracking the analyzer in # `analyzer_nodes.TENSOR_REPLACEMENTS`. -_ASSET_REPLACEMENTS = 'tft_asset_replacements' +_ASSET_REPLACEMENTS = "tft_asset_replacements" # Key for graph collection containing string IDs for vocabulary extra tokens. -_VOCABULARY_RESERVED_TOKENS_IDS = 'tft_vocab_extra_tokens_ids' +_VOCABULARY_RESERVED_TOKENS_IDS = "tft_vocab_extra_tokens_ids" # Key for graph collection containing extra tokens to include in a vocabulary. -_VOCABULARY_RESERVED_TOKENS = 'tft_vocab_extra_tokens' +_VOCABULARY_RESERVED_TOKENS = "tft_vocab_extra_tokens" # The default value used when densifying sparse labels. _MISSING_MASK_VALUE = -1 ReducedBatchWeightedCounts = tfx_namedtuple.namedtuple( - 'ReducedBatchCounts', + "ReducedBatchCounts", [ - 'unique_x', - 'summed_weights_per_x', - 'summed_positive_per_x_and_y', - 'counts_per_x', + "unique_x", + "summed_weights_per_x", + "summed_positive_per_x_and_y", + "counts_per_x", ], ) _CompositeTensorRef = tfx_namedtuple.namedtuple( - '_CompositeTensorRef', ['type_spec', 'list_of_refs'] + "_CompositeTensorRef", ["type_spec", "list_of_refs"] ) @dataclasses.dataclass(frozen=True) class _BatchWeightedVocabulary: - """Contains values of a vocabulary for categorical features. + """Contains values of a vocabulary for categorical features. - * unique_x_values: All distinct values of X observed within the batch. - * summed_weights_per_x: The sum of weights for each value of X. - * unique_count: An array of the same size as unique_x_values reflecting the - frequency of each unique x value within the batch. - * unique_idx: An array the size of the original bag, where each value is an - integer from 0 to N corresponding (positionally) to a value in - unique_x_values. - * max_x_idx: The total number of distinct elements seen in the batch. - """ + * unique_x_values: All distinct values of X observed within the batch. + * summed_weights_per_x: The sum of weights for each value of X. 
+ * unique_count: An array of the same size as unique_x_values reflecting the + frequency of each unique x value within the batch. + * unique_idx: An array the size of the original bag, where each value is an + integer from 0 to N corresponding (positionally) to a value in + unique_x_values. + * max_x_idx: The total number of distinct elements seen in the batch. + """ - unique_x_values: tf.Tensor - summed_weights_per_x: tf.Tensor - unique_count: int - unique_idx: int - max_x_idx: int + unique_x_values: tf.Tensor + summed_weights_per_x: tf.Tensor + unique_count: int + unique_idx: int + max_x_idx: int def get_values(x: common_types.TensorType) -> tf.Tensor: - """Extracts values if the given tensor is composite.""" - if isinstance(x, tf.SparseTensor): - return x.values - elif isinstance(x, tf.RaggedTensor): - return x.flat_values - else: - return x + """Extracts values if the given tensor is composite.""" + if isinstance(x, tf.SparseTensor): + return x.values + elif isinstance(x, tf.RaggedTensor): + return x.flat_values + else: + return x def copy_tensors(tensors): - """Makes deep copies of a dict of tensors. - - Makes deep copies (using tf.identity or its equivalent for `CompositeTensor`s) - of the values of `tensors`. - - Args: - tensors: A a dict whose keys are strings and values are `Tensors`s or - `CompositeTensor`s. - - Returns: - A copy of `tensors` with values replaced by tf.identity applied to the - value, or the equivalent for `CompositeTensor`s. - """ - return { - name: _copy_tensor_or_composite_tensor(tensor) - for name, tensor in tensors.items() - } + """Makes deep copies of a dict of tensors. + + Makes deep copies (using tf.identity or its equivalent for `CompositeTensor`s) + of the values of `tensors`. + + Args: + ---- + tensors: A a dict whose keys are strings and values are `Tensors`s or + `CompositeTensor`s. + + Returns: + ------- + A copy of `tensors` with values replaced by tf.identity applied to the + value, or the equivalent for `CompositeTensor`s. 
+ """ + return { + name: _copy_tensor_or_composite_tensor(tensor) + for name, tensor in tensors.items() + } def _copy_tensor(tensor): - return tf.identity(tensor, name='{}_copy'.format(tensor.op.name)) + return tf.identity(tensor, name=f"{tensor.op.name}_copy") def _copy_tensor_or_composite_tensor(tensor): - if isinstance(tensor, composite_tensor.CompositeTensor): - return tf.nest.map_structure(_copy_tensor, tensor, expand_composites=True) - return _copy_tensor(tensor) + if isinstance(tensor, composite_tensor.CompositeTensor): + return tf.nest.map_structure(_copy_tensor, tensor, expand_composites=True) + return _copy_tensor(tensor) def _get_ragged_batch_value_rowids(tensor: tf.RaggedTensor) -> tf.Tensor: - nested_value_rowids = tensor.nested_value_rowids() - result = nested_value_rowids[-1] - for value_rowids in reversed(nested_value_rowids[:-1]): - result = tf.gather(value_rowids, result) - return result + nested_value_rowids = tensor.nested_value_rowids() + result = nested_value_rowids[-1] + for value_rowids in reversed(nested_value_rowids[:-1]): + result = tf.gather(value_rowids, result) + return result def _make_regex_filter_fn( - x: tf.Tensor, - filter_regex: Optional[str]) -> Callable[[tf.Tensor], tf.Tensor]: - """Returns a filter function that applies `x`'s mask.""" - if filter_regex is None: - return lambda values: values - else: - if x.dtype != tf.string: - raise ValueError('Regex filtering is only possible with string input, ' - f'got {x.dtype}') - filter_mask = tf.logical_not(tf.strings.regex_full_match(x, filter_regex)) - return lambda values: tf.boolean_mask(values, filter_mask) + x: tf.Tensor, filter_regex: Optional[str] +) -> Callable[[tf.Tensor], tf.Tensor]: + """Returns a filter function that applies `x`'s mask.""" + if filter_regex is None: + return lambda values: values + else: + if x.dtype != tf.string: + raise ValueError( + "Regex filtering is only possible with string input, " f"got {x.dtype}" + ) + filter_mask = tf.logical_not(tf.strings.regex_full_match(x, filter_regex)) + return lambda values: tf.boolean_mask(values, filter_mask) def reduce_batch_weighted_counts( x: common_types.TensorType, weights: Optional[tf.Tensor] = None, force: bool = False, - filter_regex: Optional[str] = None) -> ReducedBatchWeightedCounts: - """Performs batch-wise reduction to produce (possibly weighted) counts. - - Args: - x: Input `Tensor` or `CompositeTensor`. - weights: (Optional) Input weights. - force: If True, reduces input tensor without weights to unique elements and - counts. - filter_regex: (Optional) Regex that matches tokens that have to be filtered - out. May only be specified if `x` has string dtype. - - Returns: - a named tuple of... - The unique values in x - The sum of the weights for each unique value in x if weights are provided, - else None - """ - if isinstance(x, tf.SparseTensor): - x = x.values - elif isinstance(x, tf.RaggedTensor): - x = x.flat_values - flat_x = tf.reshape(x, [-1]) - filter_fn = _make_regex_filter_fn(flat_x, filter_regex) - flat_x = filter_fn(flat_x) - if weights is None: - if force: - unique, _, counts = tf.unique_with_counts(flat_x) - return ReducedBatchWeightedCounts(unique, None, None, counts) - else: - # TODO(b/112916494): Always do batch wise reduction once possible. 
- return ReducedBatchWeightedCounts(flat_x, None, None, None) - weights = filter_fn(tf.reshape(weights, [-1])) - unique_x_values, unique_idx, _ = tf.unique_with_counts( - flat_x, out_idx=tf.int64) - summed_weights_per_x = tf.math.unsorted_segment_sum( - weights, unique_idx, tf.size(input=unique_x_values)) - return ReducedBatchWeightedCounts(unique_x_values, summed_weights_per_x, None, - None) + filter_regex: Optional[str] = None, +) -> ReducedBatchWeightedCounts: + """Performs batch-wise reduction to produce (possibly weighted) counts. + + Args: + ---- + x: Input `Tensor` or `CompositeTensor`. + weights: (Optional) Input weights. + force: If True, reduces input tensor without weights to unique elements and + counts. + filter_regex: (Optional) Regex that matches tokens that have to be filtered + out. May only be specified if `x` has string dtype. + + Returns: + ------- + a named tuple of... + The unique values in x + The sum of the weights for each unique value in x if weights are provided, + else None + """ + if isinstance(x, tf.SparseTensor): + x = x.values + elif isinstance(x, tf.RaggedTensor): + x = x.flat_values + flat_x = tf.reshape(x, [-1]) + filter_fn = _make_regex_filter_fn(flat_x, filter_regex) + flat_x = filter_fn(flat_x) + if weights is None: + if force: + unique, _, counts = tf.unique_with_counts(flat_x) + return ReducedBatchWeightedCounts(unique, None, None, counts) + else: + # TODO(b/112916494): Always do batch wise reduction once possible. + return ReducedBatchWeightedCounts(flat_x, None, None, None) + weights = filter_fn(tf.reshape(weights, [-1])) + unique_x_values, unique_idx, _ = tf.unique_with_counts(flat_x, out_idx=tf.int64) + summed_weights_per_x = tf.math.unsorted_segment_sum( + weights, unique_idx, tf.size(input=unique_x_values) + ) + return ReducedBatchWeightedCounts(unique_x_values, summed_weights_per_x, None, None) def reduce_batch_weighted_cooccurrences( @@ -201,417 +207,433 @@ def reduce_batch_weighted_cooccurrences( y_input: Union[tf.Tensor, tf.SparseTensor], weights_input: Optional[tf.Tensor] = None, extend_with_sentinel_counts: bool = True, - filter_regex: Optional[str] = None) -> ReducedBatchWeightedCounts: - """Performs batch-wise reduction to produce weighted co-occurrences. - - Computes the weighted co-occurrence of each feature value in x, for each value - in the range [0, max(y)). If extend_with_sentinel_counts is true, the return - value will include an additional sentinel token (not in the true vocabulary) - that is used to accumulate the global distribution of y values. - - Args: - x_input: Input `Tensor` or `CompositeTensor`. - y_input: Integer `Tensor` or `SparseTensor` (in case of - multi-label/multi-task) with which to compute the co-occurrence with - x_input. - weights_input: (Optional) Weights input `Tensor`. - extend_with_sentinel_counts: If True, the reduced batch will be extended - a sentinel value that accumlate the total distribution of y values. Should - be True except when called recursively with the sentinel value as input. - filter_regex: (Optional) Regex that matches tokens that have to be filtered - out. Can only be specified if `x_input` has string dtype. - - Returns: - a namedtuple of... - unique_x_values: the unique values in x - summed_weights_per_x: sum of the weights for each unique value in x - summed_positive_per_x_and_y: If tensor y is provided, the sum of - positive weights for each unique y value, for each unique value in x. - If y tensor is not provided, value is None. 
- counts_per_x: if y is provided, counts of each of the unique values in x, - otherwise, None. - """ - - if isinstance(y_input, tf.SparseTensor): - # This is a multi-label/multi-task problem. - - # All cooccurrences tensors will have the same size to allow merging. - # This should not influence the value of MI. - max_y_value = tf.cast( - tf.reduce_max(input_tensor=tf.sparse.to_dense(y_input)), tf.int64 - ) - # To handle this case, we densify the label with a padding value to compute - # the cooccurences treating each "column" in the multivariant label as a - # distinct Y label. Then we reduce by summing. - result = tf.map_fn( - lambda y_column: _compute_weighted_counts( # pylint: disable=g-long-lambda - x_input, y_column, weights_input, filter_regex, max_y_value - ), - # Set the default dense value to -1 to avoid confusing missing labels - # with class 0. - # Per segment_sum docs: If the given segment ID i is negative, the value - # is dropped and will not be added to the sum of the segment. - tf.transpose( - tf.sparse.to_dense( - tf.cast(y_input, tf.int64), default_value=_MISSING_MASK_VALUE - ) - ), - fn_output_signature={ - 'unique_x_values': tf.TensorSpec( - shape=[None], - dtype=x_input.dtype, - ), - 'summed_weights_per_x': tf.TensorSpec( - shape=[None], - dtype=tf.float32, - ), - 'summed_positive_per_x_and_y': tf.TensorSpec( - shape=[None, None], - dtype=tf.float32, + filter_regex: Optional[str] = None, +) -> ReducedBatchWeightedCounts: + """Performs batch-wise reduction to produce weighted co-occurrences. + + Computes the weighted co-occurrence of each feature value in x, for each value + in the range [0, max(y)). If extend_with_sentinel_counts is true, the return + value will include an additional sentinel token (not in the true vocabulary) + that is used to accumulate the global distribution of y values. + + Args: + ---- + x_input: Input `Tensor` or `CompositeTensor`. + y_input: Integer `Tensor` or `SparseTensor` (in case of + multi-label/multi-task) with which to compute the co-occurrence with + x_input. + weights_input: (Optional) Weights input `Tensor`. + extend_with_sentinel_counts: If True, the reduced batch will be extended + a sentinel value that accumlate the total distribution of y values. Should + be True except when called recursively with the sentinel value as input. + filter_regex: (Optional) Regex that matches tokens that have to be filtered + out. Can only be specified if `x_input` has string dtype. + + Returns: + ------- + a namedtuple of... + unique_x_values: the unique values in x + summed_weights_per_x: sum of the weights for each unique value in x + summed_positive_per_x_and_y: If tensor y is provided, the sum of + positive weights for each unique y value, for each unique value in x. + If y tensor is not provided, value is None. + counts_per_x: if y is provided, counts of each of the unique values in x, + otherwise, None. + """ + if isinstance(y_input, tf.SparseTensor): + # This is a multi-label/multi-task problem. + + # All cooccurrences tensors will have the same size to allow merging. + # This should not influence the value of MI. + max_y_value = tf.cast( + tf.reduce_max(input_tensor=tf.sparse.to_dense(y_input)), tf.int64 + ) + # To handle this case, we densify the label with a padding value to compute + # the cooccurences treating each "column" in the multivariant label as a + # distinct Y label. Then we reduce by summing. 
+ result = tf.map_fn( + lambda y_column: _compute_weighted_counts( # pylint: disable=g-long-lambda + x_input, y_column, weights_input, filter_regex, max_y_value ), - 'unique_count': tf.TensorSpec( - shape=[None], - dtype=tf.int64, + # Set the default dense value to -1 to avoid confusing missing labels + # with class 0. + # Per segment_sum docs: If the given segment ID i is negative, the value + # is dropped and will not be added to the sum of the segment. + tf.transpose( + tf.sparse.to_dense( + tf.cast(y_input, tf.int64), default_value=_MISSING_MASK_VALUE + ) ), - }, - ) - # summed_positive_per_x_and_y requires summing across the N "sub-labels". - # For the other values, they should be a constant for each invocation and - # thus we simply take the first: these values are computed on the X tensor - # only, and won't depend on y_column (on which we're iterating). - result['summed_positive_per_x_and_y'] = tf.reduce_sum( - result['summed_positive_per_x_and_y'], axis=0 - ) - result['unique_x_values'] = result['unique_x_values'][0] - result['summed_weights_per_x'] = result['summed_weights_per_x'][0] - result['unique_count'] = result['unique_count'][0] - else: - max_y_value = tf.cast(tf.reduce_max(input_tensor=y_input), tf.int64) - result = _compute_weighted_counts( - x_input, y_input, weights_input, filter_regex, max_y_value - ) - reduced_batch = ReducedBatchWeightedCounts( - unique_x=result['unique_x_values'], - summed_weights_per_x=result['summed_weights_per_x'], - summed_positive_per_x_and_y=result['summed_positive_per_x_and_y'], - counts_per_x=result['unique_count'], - ) - # Add a sentinel token tracking the full distribution of y values. - if extend_with_sentinel_counts: - reduced_batch = extend_reduced_batch_with_y_counts( - reduced_batch, y_input, weights_input + fn_output_signature={ + "unique_x_values": tf.TensorSpec( + shape=[None], + dtype=x_input.dtype, + ), + "summed_weights_per_x": tf.TensorSpec( + shape=[None], + dtype=tf.float32, + ), + "summed_positive_per_x_and_y": tf.TensorSpec( + shape=[None, None], + dtype=tf.float32, + ), + "unique_count": tf.TensorSpec( + shape=[None], + dtype=tf.int64, + ), + }, + ) + # summed_positive_per_x_and_y requires summing across the N "sub-labels". + # For the other values, they should be a constant for each invocation and + # thus we simply take the first: these values are computed on the X tensor + # only, and won't depend on y_column (on which we're iterating). + result["summed_positive_per_x_and_y"] = tf.reduce_sum( + result["summed_positive_per_x_and_y"], axis=0 + ) + result["unique_x_values"] = result["unique_x_values"][0] + result["summed_weights_per_x"] = result["summed_weights_per_x"][0] + result["unique_count"] = result["unique_count"][0] + else: + max_y_value = tf.cast(tf.reduce_max(input_tensor=y_input), tf.int64) + result = _compute_weighted_counts( + x_input, y_input, weights_input, filter_regex, max_y_value + ) + reduced_batch = ReducedBatchWeightedCounts( + unique_x=result["unique_x_values"], + summed_weights_per_x=result["summed_weights_per_x"], + summed_positive_per_x_and_y=result["summed_positive_per_x_and_y"], + counts_per_x=result["unique_count"], ) - return reduced_batch + # Add a sentinel token tracking the full distribution of y values. 
+ if extend_with_sentinel_counts: + reduced_batch = extend_reduced_batch_with_y_counts( + reduced_batch, y_input, weights_input + ) + return reduced_batch def _compute_weighted_counts( x_input, y_input, weights_input, filter_regex, max_y_value ): - """Computes weighted counts of feature/label values to support AMI. - - Args: - x_input: Input `Tensor` or `CompositeTensor`. - y_input: Integer `Tensor` with which to compute the co-occurrence with - x_input. - weights_input: (Optional) Weights input `Tensor`. - filter_regex: (Optional) Regex that matches tokens that have to be filtered - out. Can only be specified if `x_input` has string dtype. - max_y_value: The maximum index for labels (basically, the number of labels). - - Returns: - A dictionary containing: - - summed_positive_per_x_and_y: The occurrences of each X value, for each - label value. - - unique_x_values: All the X values in the vocabulary. - - summed_weights_per_x: The sum of weights for each value in - unique_x_values. - - unique_count: The count of the values seen for X - (basically, len(unique_x_values)). - """ - # TODO(b/297854080): Evaluate the redundant calls. - x, y, weights = _preprocess_tensors_for_cooccurences( - x_input, y_input, weights_input, filter_regex - ) - # TODO(b/297854080): Evaluate the redundant calls. - vocabulary = _compute_vocabulary_values(x, weights) - - # Get a mask for the missing labels - missing_labels_mask = tf.equal(y, _MISSING_MASK_VALUE) - # Ultimately we want to compute the summed weights per x and y, as a - # 2D [x_dim, y_dim] tensor. To accomplish this with a single call to - # tf.math.unsorted_segment_sum, we flatten the [x_dim, y_dim] into a single - # x*y dim vector and then re-shape back into 2d matrix form after computing - # the segmented sum. - flattened_index = (max_y_value + 1) * vocabulary.unique_idx + y - # and apply it to the flattened index to keep missing values at -1. - flattened_index = tf.where( - missing_labels_mask, - tf.constant(_MISSING_MASK_VALUE, tf.int64), - flattened_index, - ) - - summed_positive_per_x_and_y = _compute_summed_positive_per_x_and_y( - weights, flattened_index, vocabulary.max_x_idx, max_y_value - ) - return { - 'summed_positive_per_x_and_y': tf.cast( - summed_positive_per_x_and_y, tf.float32 - ), - 'unique_x_values': vocabulary.unique_x_values, - 'summed_weights_per_x': tf.cast( - vocabulary.summed_weights_per_x, tf.float32 - ), - 'unique_count': tf.cast(vocabulary.unique_count, tf.int64), - } - - -def _preprocess_tensors_for_cooccurences( - x_input, y_input, weights_input, filter_regex -): - """Wrangles the tensors to make them compatible with AMI computation. - - Args: - x_input: Input `Tensor` or `CompositeTensor`. - y_input: Integer `Tensor` with which to compute the co-occurrence with - x_input. - weights_input: (Optional) Weights input `Tensor`. - filter_regex: (Optional) Regex that matches tokens that have to be filtered - out. Can only be specified if `x_input` has string dtype. - - Returns: - A tuple containing the re-shaped, verified (x, y, weights). - """ - tf.compat.v1.assert_type(y_input, tf.int64) - # TODO(b/134075780): Revisit expected weights shape when input is sparse. - if isinstance(x_input, tf.SparseTensor): - batch_indices = x_input.indices[:, 0] - # y and densified x should have the same batch dimension. 
- assert_eq = tf.compat.v1.assert_equal( - tf.shape(y_input)[0], tf.cast(x_input.dense_shape[0], tf.int32)) - with tf.control_dependencies([assert_eq]): - y = tf.gather(y_input, batch_indices) - x = x_input.values - elif isinstance(x_input, tf.RaggedTensor): - # Each batch instance in x corresponds to a single value in y. - x_row_indices = _get_ragged_batch_value_rowids(x_input) - assert_compatible = tf.debugging.assert_greater_equal( - tf.shape(y_input, out_type=tf.int64)[0], x_input.bounding_shape(axis=0)) - with tf.control_dependencies([assert_compatible]): - x = tf.ensure_shape(x_input.flat_values, [None]) - y = tf.gather(y_input, x_row_indices) - else: - y = y_input - x = x_input - if weights_input is None: - weights = tf.ones_like(x, dtype=tf.float32) - else: - x, weights_input = assert_same_shape(x, weights_input) - weights = weights_input - y = _broadcast_to_x_shape(x, y) - x = tf.reshape(x, [-1]) - filter_fn = _make_regex_filter_fn(x, filter_regex) - x = filter_fn(x) - y = filter_fn(tf.reshape(y, [-1])) - weights = filter_fn(tf.reshape(weights, [-1])) - return x, y, weights - - -def _compute_summed_positive_per_x_and_y( - weights, dummy_index, max_x_idx, max_y_value -): - """Computes a segment sum to retrieve weighted co-occurrences of X and Y. - - Args: - weights: The example weights, or a ones_like Tensor (=unweighted). - dummy_index: An index the size of X containing shifted Y values. See - cl/251659477. - max_x_idx: The total number of distinct elements seen in the batch. - max_y_value: The maximum id for Y values. Basically, the number of distinct - values in Y. - - Returns: - A 2D Tensor of shape [x_dim, y_dim] containing the weighted cooccurences of - each Y value for each X value. - """ - summed_positive_per_x_and_y = tf.cast( - tf.math.unsorted_segment_sum( - weights, dummy_index, max_x_idx * (max_y_value + 1) - ), - dtype=tf.float32, - ) - summed_positive_per_x_and_y = tf.reshape( - summed_positive_per_x_and_y, [max_x_idx, max_y_value + 1] - ) - return summed_positive_per_x_and_y + """Computes weighted counts of feature/label values to support AMI. + + Args: + ---- + x_input: Input `Tensor` or `CompositeTensor`. + y_input: Integer `Tensor` with which to compute the co-occurrence with + x_input. + weights_input: (Optional) Weights input `Tensor`. + filter_regex: (Optional) Regex that matches tokens that have to be filtered + out. Can only be specified if `x_input` has string dtype. + max_y_value: The maximum index for labels (basically, the number of labels). + + Returns: + ------- + A dictionary containing: + - summed_positive_per_x_and_y: The occurrences of each X value, for each + label value. + - unique_x_values: All the X values in the vocabulary. + - summed_weights_per_x: The sum of weights for each value in + unique_x_values. + - unique_count: The count of the values seen for X + (basically, len(unique_x_values)). + """ + # TODO(b/297854080): Evaluate the redundant calls. + x, y, weights = _preprocess_tensors_for_cooccurences( + x_input, y_input, weights_input, filter_regex + ) + # TODO(b/297854080): Evaluate the redundant calls. + vocabulary = _compute_vocabulary_values(x, weights) + + # Get a mask for the missing labels + missing_labels_mask = tf.equal(y, _MISSING_MASK_VALUE) + # Ultimately we want to compute the summed weights per x and y, as a + # 2D [x_dim, y_dim] tensor. 
To accomplish this with a single call to + # tf.math.unsorted_segment_sum, we flatten the [x_dim, y_dim] into a single + # x*y dim vector and then re-shape back into 2d matrix form after computing + # the segmented sum. + flattened_index = (max_y_value + 1) * vocabulary.unique_idx + y + # and apply it to the flattened index to keep missing values at -1. + flattened_index = tf.where( + missing_labels_mask, + tf.constant(_MISSING_MASK_VALUE, tf.int64), + flattened_index, + ) + summed_positive_per_x_and_y = _compute_summed_positive_per_x_and_y( + weights, flattened_index, vocabulary.max_x_idx, max_y_value + ) + return { + "summed_positive_per_x_and_y": tf.cast(summed_positive_per_x_and_y, tf.float32), + "unique_x_values": vocabulary.unique_x_values, + "summed_weights_per_x": tf.cast(vocabulary.summed_weights_per_x, tf.float32), + "unique_count": tf.cast(vocabulary.unique_count, tf.int64), + } -def _compute_vocabulary_values(x, weights) -> _BatchWeightedVocabulary: - """Computes a vocabulary for X values. - - Args: - x: Input `Tensor` or `CompositeTensor`. - weights: Weights input `Tensor`. - - Returns: - An XVocabulary object containing the results of the computation. See more in - the XVocabulary docstring. - """ - unique_x_values, unique_idx, unique_count = tf.unique_with_counts( - x, out_idx=tf.int64 - ) - # Counts the occurrences of a given X value - summed_weights_per_x = tf.math.unsorted_segment_sum( - weights, unique_idx, tf.size(input=unique_x_values) - ) - max_x_idx = tf.size(unique_x_values, out_type=tf.int64) - return _BatchWeightedVocabulary( - unique_x_values=unique_x_values, - summed_weights_per_x=summed_weights_per_x, - unique_count=unique_count, - unique_idx=unique_idx, - max_x_idx=max_x_idx, - ) +def _preprocess_tensors_for_cooccurences(x_input, y_input, weights_input, filter_regex): + """Wrangles the tensors to make them compatible with AMI computation. + + Args: + ---- + x_input: Input `Tensor` or `CompositeTensor`. + y_input: Integer `Tensor` with which to compute the co-occurrence with + x_input. + weights_input: (Optional) Weights input `Tensor`. + filter_regex: (Optional) Regex that matches tokens that have to be filtered + out. Can only be specified if `x_input` has string dtype. + + Returns: + ------- + A tuple containing the re-shaped, verified (x, y, weights). + """ + tf.compat.v1.assert_type(y_input, tf.int64) + # TODO(b/134075780): Revisit expected weights shape when input is sparse. + if isinstance(x_input, tf.SparseTensor): + batch_indices = x_input.indices[:, 0] + # y and densified x should have the same batch dimension. + assert_eq = tf.compat.v1.assert_equal( + tf.shape(y_input)[0], tf.cast(x_input.dense_shape[0], tf.int32) + ) + with tf.control_dependencies([assert_eq]): + y = tf.gather(y_input, batch_indices) + x = x_input.values + elif isinstance(x_input, tf.RaggedTensor): + # Each batch instance in x corresponds to a single value in y. 
+ x_row_indices = _get_ragged_batch_value_rowids(x_input) + assert_compatible = tf.debugging.assert_greater_equal( + tf.shape(y_input, out_type=tf.int64)[0], x_input.bounding_shape(axis=0) + ) + with tf.control_dependencies([assert_compatible]): + x = tf.ensure_shape(x_input.flat_values, [None]) + y = tf.gather(y_input, x_row_indices) + else: + y = y_input + x = x_input + if weights_input is None: + weights = tf.ones_like(x, dtype=tf.float32) + else: + x, weights_input = assert_same_shape(x, weights_input) + weights = weights_input + y = _broadcast_to_x_shape(x, y) + x = tf.reshape(x, [-1]) + filter_fn = _make_regex_filter_fn(x, filter_regex) + x = filter_fn(x) + y = filter_fn(tf.reshape(y, [-1])) + weights = filter_fn(tf.reshape(weights, [-1])) + return x, y, weights + + +def _compute_summed_positive_per_x_and_y(weights, dummy_index, max_x_idx, max_y_value): + """Computes a segment sum to retrieve weighted co-occurrences of X and Y. + + Args: + ---- + weights: The example weights, or a ones_like Tensor (=unweighted). + dummy_index: An index the size of X containing shifted Y values. See + cl/251659477. + max_x_idx: The total number of distinct elements seen in the batch. + max_y_value: The maximum id for Y values. Basically, the number of distinct + values in Y. + + Returns: + ------- + A 2D Tensor of shape [x_dim, y_dim] containing the weighted cooccurences of + each Y value for each X value. + """ + summed_positive_per_x_and_y = tf.cast( + tf.math.unsorted_segment_sum( + weights, dummy_index, max_x_idx * (max_y_value + 1) + ), + dtype=tf.float32, + ) + summed_positive_per_x_and_y = tf.reshape( + summed_positive_per_x_and_y, [max_x_idx, max_y_value + 1] + ) + return summed_positive_per_x_and_y -def extend_reduced_batch_with_y_counts(reduced_batch, y, weights=None): - """Extend the ReducedBatchWeightedCounts with global counts for y. - - This is used to maintain an accurate count of global frequencies of each value - in y. When x is multivalent, the sum over the summed_positive_per_x_and_y - will over-count the occurrence of y. To keep track of the true distribution - of y values, we add a sentinel value that tracks the global counts of each - distinct value in y. This is useful for computing the mutual information - between values in x and y. - - Args: - reduced_batch: A ReducedBatchWeightedCounts instance. - y: A `Tensor` representing a batch of y values. - weights: Optional `Tensor` representing a batch of weight values. - - Returns: - A new ReducedBatchWeightedCounts instance with sentinel values appended. - """ - if isinstance(y, tf.SparseTensor): - y_shape = y.dense_shape - # Need to slice instead of directly using [0] as this runs in graph mode. - shape = tf.slice(y_shape, [0], [1]) - else: - shape = tf.shape(y) - # Create a dummy sentinel token that is present in every record. - if reduced_batch.unique_x.dtype.is_integer: - sentinel_values = tf.cast( - tf.fill(shape, GLOBAL_Y_COUNT_SENTINEL_INT), tf.int64 + +def _compute_vocabulary_values(x, weights) -> _BatchWeightedVocabulary: + """Computes a vocabulary for X values. + + Args: + ---- + x: Input `Tensor` or `CompositeTensor`. + weights: Weights input `Tensor`. + + Returns: + ------- + An XVocabulary object containing the results of the computation. See more in + the XVocabulary docstring. 
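# --- Illustrative sketch (standalone example, not part of this patch) ---
# _compute_summed_positive_per_x_and_y above relies on flattening a
# [x_dim, y_dim] co-occurrence grid into a single segment id so that one
# tf.math.unsorted_segment_sum call can accumulate the weights. A minimal
# version of the same trick, with made-up data:
import tensorflow as tf

x_idx = tf.constant([0, 0, 1, 1, 0], dtype=tf.int64)  # vocabulary indices of x
y = tf.constant([0, 1, 1, 0, 1], dtype=tf.int64)      # integer labels
weights = tf.ones_like(x_idx, dtype=tf.float32)
max_x_idx, max_y_value = 2, 1

flat_index = (max_y_value + 1) * x_idx + y            # shift each x index by its label
flat_sum = tf.math.unsorted_segment_sum(
    weights, flat_index, max_x_idx * (max_y_value + 1))
cooccurrence = tf.reshape(flat_sum, [max_x_idx, max_y_value + 1])
# cooccurrence == [[1., 2.], [1., 1.]]: x=0 was seen once with y=0 and twice with y=1.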
+ """ + unique_x_values, unique_idx, unique_count = tf.unique_with_counts( + x, out_idx=tf.int64 + ) + # Counts the occurrences of a given X value + summed_weights_per_x = tf.math.unsorted_segment_sum( + weights, unique_idx, tf.size(input=unique_x_values) + ) + max_x_idx = tf.size(unique_x_values, out_type=tf.int64) + return _BatchWeightedVocabulary( + unique_x_values=unique_x_values, + summed_weights_per_x=summed_weights_per_x, + unique_count=unique_count, + unique_idx=unique_idx, + max_x_idx=max_x_idx, ) - else: - sentinel_values = tf.fill(shape, GLOBAL_Y_COUNT_SENTINEL_STRING) - # Computing the batch reduction over this sentinel token will reduce to a - # single sentinel value in sentinel_batch.unique_x, with the - # summed_positive_per_x_and_y thus capturing the total summed positive per - # value in y. - sentinel_batch = reduce_batch_weighted_cooccurrences( - sentinel_values, y, weights, extend_with_sentinel_counts=False) - - # Concatenate the sentinel counts with the existing reduced batch. - return ReducedBatchWeightedCounts( - unique_x=tf.concat([reduced_batch.unique_x, sentinel_batch.unique_x], - axis=0), - summed_weights_per_x=tf.concat([ - reduced_batch.summed_weights_per_x, - sentinel_batch.summed_weights_per_x - ], - axis=0), - summed_positive_per_x_and_y=tf.concat([ - reduced_batch.summed_positive_per_x_and_y, - sentinel_batch.summed_positive_per_x_and_y - ], - axis=0), - counts_per_x=tf.concat( - [reduced_batch.counts_per_x, sentinel_batch.counts_per_x], axis=0)) -def hashable_tensor_or_op(tensor_or_op): - """Returns a hashable reference to a Tensor if given a Tensor/CompositeTensor. +def extend_reduced_batch_with_y_counts(reduced_batch, y, weights=None): + """Extend the ReducedBatchWeightedCounts with global counts for y. + + This is used to maintain an accurate count of global frequencies of each value + in y. When x is multivalent, the sum over the summed_positive_per_x_and_y + will over-count the occurrence of y. To keep track of the true distribution + of y values, we add a sentinel value that tracks the global counts of each + distinct value in y. This is useful for computing the mutual information + between values in x and y. + + Args: + ---- + reduced_batch: A ReducedBatchWeightedCounts instance. + y: A `Tensor` representing a batch of y values. + weights: Optional `Tensor` representing a batch of weight values. + + Returns: + ------- + A new ReducedBatchWeightedCounts instance with sentinel values appended. + """ + if isinstance(y, tf.SparseTensor): + y_shape = y.dense_shape + # Need to slice instead of directly using [0] as this runs in graph mode. + shape = tf.slice(y_shape, [0], [1]) + else: + shape = tf.shape(y) + # Create a dummy sentinel token that is present in every record. + if reduced_batch.unique_x.dtype.is_integer: + sentinel_values = tf.cast(tf.fill(shape, GLOBAL_Y_COUNT_SENTINEL_INT), tf.int64) + else: + sentinel_values = tf.fill(shape, GLOBAL_Y_COUNT_SENTINEL_STRING) + # Computing the batch reduction over this sentinel token will reduce to a + # single sentinel value in sentinel_batch.unique_x, with the + # summed_positive_per_x_and_y thus capturing the total summed positive per + # value in y. + sentinel_batch = reduce_batch_weighted_cooccurrences( + sentinel_values, y, weights, extend_with_sentinel_counts=False + ) - Use deref_tensor_or_op on the result to get the Tensor (or SparseTensor). + # Concatenate the sentinel counts with the existing reduced batch. 
+ return ReducedBatchWeightedCounts( + unique_x=tf.concat([reduced_batch.unique_x, sentinel_batch.unique_x], axis=0), + summed_weights_per_x=tf.concat( + [reduced_batch.summed_weights_per_x, sentinel_batch.summed_weights_per_x], + axis=0, + ), + summed_positive_per_x_and_y=tf.concat( + [ + reduced_batch.summed_positive_per_x_and_y, + sentinel_batch.summed_positive_per_x_and_y, + ], + axis=0, + ), + counts_per_x=tf.concat( + [reduced_batch.counts_per_x, sentinel_batch.counts_per_x], axis=0 + ), + ) - Args: - tensor_or_op: A `tf.Tensor`, `tf.CompositeTensor`, or other type. - Returns: - A hashable representation for the Tensor or CompositeTensor, or the original - value for other types. - """ - if isinstance(tensor_or_op, tf.Tensor): - return tensor_or_op.ref() - if isinstance(tensor_or_op, composite_tensor.CompositeTensor): - return _CompositeTensorRef( - type_spec=tf.type_spec_from_value(tensor_or_op), - list_of_refs=tuple( - hashable_tensor_or_op(component) for component in tf.nest.flatten( - tensor_or_op, expand_composites=True))) - return tensor_or_op +def hashable_tensor_or_op(tensor_or_op): + """Returns a hashable reference to a Tensor if given a Tensor/CompositeTensor. + + Use deref_tensor_or_op on the result to get the Tensor (or SparseTensor). + + Args: + ---- + tensor_or_op: A `tf.Tensor`, `tf.CompositeTensor`, or other type. + + Returns: + ------- + A hashable representation for the Tensor or CompositeTensor, or the original + value for other types. + """ + if isinstance(tensor_or_op, tf.Tensor): + return tensor_or_op.ref() + if isinstance(tensor_or_op, composite_tensor.CompositeTensor): + return _CompositeTensorRef( + type_spec=tf.type_spec_from_value(tensor_or_op), + list_of_refs=tuple( + hashable_tensor_or_op(component) + for component in tf.nest.flatten(tensor_or_op, expand_composites=True) + ), + ) + return tensor_or_op def deref_tensor_or_op(tensor_or_op): - """Returns a Tensor or CompositeTensor if given a reference, otherwise input. - - Args: - tensor_or_op: An output of `hashable_tensor_or_op`. - - Returns: - A Tensor, CompositeTensor, or the given tensor_or_op. - """ - if isinstance(tensor_or_op, object_identity.Reference): - return tensor_or_op.deref() - if isinstance(tensor_or_op, _CompositeTensorRef): - return tf.nest.pack_sequence_as( - structure=tensor_or_op.type_spec, - flat_sequence=[ - deref_tensor_or_op(component) - for component in tensor_or_op.list_of_refs - ], - expand_composites=True) - return tensor_or_op + """Returns a Tensor or CompositeTensor if given a reference, otherwise input. + + Args: + ---- + tensor_or_op: An output of `hashable_tensor_or_op`. + + Returns: + ------- + A Tensor, CompositeTensor, or the given tensor_or_op. + """ + if isinstance(tensor_or_op, object_identity.Reference): + return tensor_or_op.deref() + if isinstance(tensor_or_op, _CompositeTensorRef): + return tf.nest.pack_sequence_as( + structure=tensor_or_op.type_spec, + flat_sequence=[ + deref_tensor_or_op(component) for component in tensor_or_op.list_of_refs + ], + expand_composites=True, + ) + return tensor_or_op def _broadcast_to_x_shape(x, y): - """Broadcasts y to same shape as x as needed. - - Args: - x: An input feature. - y: A feature that is either the same shape as x or has the same outer - dimensions as x. If the latter, y is broadcast to the same shape as x. - - Returns: - A Tensor that contains the broadcasted feature, y. - """ - # The batch dimension of x and y must be the same, and y must be 1D. 
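# --- Illustrative sketch (standalone example, not part of this patch) ---
# hashable_tensor_or_op/deref_tensor_or_op above build on Tensor.ref(), which
# returns a hashable wrapper around an (unhashable) TF2 tensor:
import tensorflow as tf

t = tf.constant([1.0, 2.0])
cache = {t.ref(): "seen"}        # a tensor reference can key a dict or set
assert cache[t.ref()] == "seen"
assert t.ref().deref() is t      # dereferencing recovers the original tensor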
- x_shape = tf.shape(input=x) - y_shape = tf.shape(input=y) - assert_eq = tf.compat.v1.assert_equal(x_shape[0], y_shape[0]) - with tf.control_dependencies([assert_eq]): - rank_delta = tf.rank(x) - tf.rank(y) - target_shape = tf.concat( - [tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0) - matched_rank = tf.reshape(y, target_shape) - return tf.broadcast_to(matched_rank, x_shape) + """Broadcasts y to same shape as x as needed. + + Args: + ---- + x: An input feature. + y: A feature that is either the same shape as x or has the same outer + dimensions as x. If the latter, y is broadcast to the same shape as x. + + Returns: + ------- + A Tensor that contains the broadcasted feature, y. + """ + # The batch dimension of x and y must be the same, and y must be 1D. + x_shape = tf.shape(input=x) + y_shape = tf.shape(input=y) + assert_eq = tf.compat.v1.assert_equal(x_shape[0], y_shape[0]) + with tf.control_dependencies([assert_eq]): + rank_delta = tf.rank(x) - tf.rank(y) + target_shape = tf.concat([tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0) + matched_rank = tf.reshape(y, target_shape) + return tf.broadcast_to(matched_rank, x_shape) def assert_same_shape(x, y): - """Asserts two tensors have the same dynamic and static shape. - - Args: - x: A `Tensor`. - y: A `Tensor` - - Returns: - The elements `x` and `y`, the results must be used in order to ensure that - the dynamic check is executed. - """ - x.shape.assert_is_compatible_with(y.shape) - assert_eq = tf.compat.v1.assert_equal(tf.shape(input=x), tf.shape(input=y)) - with tf.control_dependencies([assert_eq]): - return tf.identity(x), tf.identity(y) + """Asserts two tensors have the same dynamic and static shape. + + Args: + ---- + x: A `Tensor`. + y: A `Tensor` + + Returns: + ------- + The elements `x` and `y`, the results must be used in order to ensure that + the dynamic check is executed. + """ + x.shape.assert_is_compatible_with(y.shape) + assert_eq = tf.compat.v1.assert_equal(tf.shape(input=x), tf.shape(input=y)) + with tf.control_dependencies([assert_eq]): + return tf.identity(x), tf.identity(y) # TODO(b/178189903): This is needed because tf.sparse.reduce_* produces a dense @@ -619,1405 +641,1666 @@ def assert_same_shape(x, y): def _sparse_reduce_batch_keep_shape( sparse_reduce_fn: Callable[..., tf.Tensor], sparse_tensor: tf.SparseTensor ) -> tf.Tensor: # pylint: disable=g-bare-generic - """Applies a tf.sparse.reduce_* method on the given sparse_tensor.""" - result = sparse_reduce_fn(sparse_tensor, axis=0) - result.set_shape(sparse_tensor.get_shape()[1:]) - return result - - -def reduce_batch_count(x: common_types.TensorType, - reduce_instance_dims: bool) -> tf.Tensor: - """Counts elements in the given tensor. - - Args: - x: A `Tensor` or `CompositeTensor`. - reduce_instance_dims: A bool, if True - collapses the batch and instance - dimensions to arrive at a single scalar output. Otherwise, only collapses - the batch dimension and outputs a `Tensor` of the same shape as the input. - - Returns: - The element count of `x`. The result is either a scalar if - reduce_instance_dims is True, otherwise a `Tensor` having shape of `x` - without the first (batch) dimension. NaNs and infinite input values are - ignored. - """ - if isinstance(x, tf.SparseTensor): - if reduce_instance_dims: - x = x.values - else: - ones_like = tf.SparseTensor( - indices=x.indices, - values=tf.cast(_is_finite(x.values), tf.int64), - dense_shape=x.dense_shape) - # TODO(b/178189903): Remove this once we no longer lose static shape - # information. 
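# --- Illustrative sketch (standalone example, not part of this patch) ---
# _broadcast_to_x_shape above pads y with trailing singleton dimensions and
# then broadcasts it against x. The same steps on concrete shapes:
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])     # shape [2, 3]
y = tf.constant([10, 20])                   # shape [2], one value per batch row
rank_delta = tf.rank(x) - tf.rank(y)        # 1
target_shape = tf.concat([tf.shape(y), tf.ones(rank_delta, dtype=tf.int32)], axis=0)
broadcast_y = tf.broadcast_to(tf.reshape(y, target_shape), tf.shape(x))
# broadcast_y == [[10, 10, 10], [20, 20, 20]]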
- ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access - return _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, ones_like) - elif isinstance(x, tf.RaggedTensor): - if reduce_instance_dims: - x = x.flat_values - else: - finite_mask = tf.cast(_is_finite(x), tf.int64) - return tf.math.reduce_sum(finite_mask, axis=0).to_tensor() + """Applies a tf.sparse.reduce_* method on the given sparse_tensor.""" + result = sparse_reduce_fn(sparse_tensor, axis=0) + result.set_shape(sparse_tensor.get_shape()[1:]) + return result - # Exlude NaNs and infinite elements from size calculation. They can only occur - # in tensors with floating data types. - if x.dtype.is_floating: - finite_mask = tf.cast(tf.math.is_finite(x), tf.int64) - return tf.reduce_sum(finite_mask, axis=None if reduce_instance_dims else 0) - if reduce_instance_dims: - return tf.size(input=x) +def reduce_batch_count( + x: common_types.TensorType, reduce_instance_dims: bool +) -> tf.Tensor: + """Counts elements in the given tensor. + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + reduce_instance_dims: A bool, if True - collapses the batch and instance + dimensions to arrive at a single scalar output. Otherwise, only collapses + the batch dimension and outputs a `Tensor` of the same shape as the input. + + Returns: + ------- + The element count of `x`. The result is either a scalar if + reduce_instance_dims is True, otherwise a `Tensor` having shape of `x` + without the first (batch) dimension. NaNs and infinite input values are + ignored. + """ + if isinstance(x, tf.SparseTensor): + if reduce_instance_dims: + x = x.values + else: + ones_like = tf.SparseTensor( + indices=x.indices, + values=tf.cast(_is_finite(x.values), tf.int64), + dense_shape=x.dense_shape, + ) + # TODO(b/178189903): Remove this once we no longer lose static shape + # information. + ones_like._dense_shape_default = x._dense_shape_default # pylint: disable=protected-access + return _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, ones_like) + elif isinstance(x, tf.RaggedTensor): + if reduce_instance_dims: + x = x.flat_values + else: + finite_mask = tf.cast(_is_finite(x), tf.int64) + return tf.math.reduce_sum(finite_mask, axis=0).to_tensor() + + # Exlude NaNs and infinite elements from size calculation. They can only occur + # in tensors with floating data types. + if x.dtype.is_floating: + finite_mask = tf.cast(tf.math.is_finite(x), tf.int64) + return tf.reduce_sum(finite_mask, axis=None if reduce_instance_dims else 0) + + if reduce_instance_dims: + return tf.size(input=x) - # Fill a tensor shaped like x except batch_size=1 with batch_size. - x_shape = tf.shape(input=x) - return tf.fill(x_shape[1:], x_shape[0]) + # Fill a tensor shaped like x except batch_size=1 with batch_size. 
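# --- Illustrative sketch (standalone example, not part of this patch) ---
# The finite-mask counting in reduce_batch_count above can be reproduced with
# plain ops; NaN and +/-inf entries simply do not contribute to the count:
import tensorflow as tf

x = tf.constant([[1.0, float("nan")], [float("inf"), 4.0]])
finite = tf.cast(tf.math.is_finite(x), tf.int64)
print(tf.reduce_sum(finite))          # 2      (reduce_instance_dims=True -> scalar)
print(tf.reduce_sum(finite, axis=0))  # [1 1]  (only the batch dimension collapsed)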
+ x_shape = tf.shape(input=x) + return tf.fill(x_shape[1:], x_shape[0]) def _map_values( - map_function: Callable[[Union[tf.Tensor, tf.RaggedTensor]], - Union[tf.Tensor, tf.RaggedTensor]], + map_function: Callable[ + [Union[tf.Tensor, tf.RaggedTensor]], Union[tf.Tensor, tf.RaggedTensor] + ], tensor: common_types.ConsistentTensorType, ) -> common_types.ConsistentTensorType: - values = tensor if isinstance(tensor, tf.Tensor) else tensor.values - result = map_function(values) - if not isinstance(tensor, tf.Tensor): - return tensor.with_values(result) - else: - return result + values = tensor if isinstance(tensor, tf.Tensor) else tensor.values + result = map_function(values) + if not isinstance(tensor, tf.Tensor): + return tensor.with_values(result) + else: + return result def maybe_format_vocabulary_input( x: common_types.ConsistentTensorType, ) -> common_types.ConsistentTensorType: - """Formats string vocabulary input. + """Formats string vocabulary input. + + Args: + ---- + x: a tensor containing the vocabulary. + + Returns: + ------- + A similarly typed tensor. + """ + if x.dtype == tf.string: + # b/62379925: This is a workaround to allow tokens to contain spaces when + # store_frequency=True, which should eventaully be removed. + def map_spaces( + t: Union[tf.Tensor, tf.RaggedTensor], + ) -> Union[tf.Tensor, tf.RaggedTensor]: + return tf.strings.regex_replace(t, " ", "__SPACE__") + + return _map_values(map_spaces, x) + return x + + +def _to_string(x: common_types.TensorType) -> common_types.TensorType: + """Converts values in the given `Tensor` or `CompositeTensor` to strings.""" + if x.dtype is tf.string: + return x + return _map_values(tf.strings.as_string, x) - Args: - x: a tensor containing the vocabulary. - Returns: - A similarly typed tensor. - """ +def reduce_batch_count_per_key( + key: common_types.TensorType, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes per-key counts in the given tensor. - if x.dtype == tf.string: - # b/62379925: This is a workaround to allow tokens to contain spaces when - # store_frequency=True, which should eventaully be removed. - def map_spaces(t: Union[tf.Tensor, tf.RaggedTensor] - ) -> Union[tf.Tensor, tf.RaggedTensor]: - return tf.strings.regex_replace(t, ' ', '__SPACE__') + Args: + ---- + key: A `Tensor` or `CompositeTensor`. - return _map_values(map_spaces, x) - return x + Returns: + ------- + A 2-tuple containing the tensor's (key_vocab, count_per_key). + """ + key = _to_string(key) + if isinstance(key, tf.SparseTensor): + key = key.values + elif isinstance(key, tf.RaggedTensor): + key = key.flat_values + key.set_shape([None]) + unique = tf.unique_with_counts(key, out_idx=tf.int64) -def _to_string(x: common_types.TensorType) -> common_types.TensorType: - """Converts values in the given `Tensor` or `CompositeTensor` to strings.""" - if x.dtype is tf.string: - return x - return _map_values(tf.strings.as_string, x) + return unique.y, unique.count -def reduce_batch_count_per_key( - key: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes per-key counts in the given tensor. - - Args: - key: A `Tensor` or `CompositeTensor`. - - Returns: - A 2-tuple containing the tensor's (key_vocab, count_per_key). 
- """ - key = _to_string(key) - - if isinstance(key, tf.SparseTensor): - key = key.values - elif isinstance(key, tf.RaggedTensor): - key = key.flat_values - key.set_shape([None]) - unique = tf.unique_with_counts(key, out_idx=tf.int64) - - return unique.y, unique.count - - -def reorder_histogram(bucket_vocab: tf.Tensor, counts: tf.Tensor, - boundary_size: int) -> tf.Tensor: - """Return the histogram counts in indexed order, and zero out missing values. - - The count_elements analyzer returns counts in alphanumeric order, only for the - values that are present. To construct a well-formed histogram, we need to - rearrange them in numerical order, and fill in the missing values. - - Ex: The data contains values in the following form: [0, 1, 0, 1, 0, 3, 0, 1] - bucket_indices happen to be the same as these values, and - count_elements(tf.strings.as_string(bucket_indices)) returns: - bucket_vocab=['1', '3', '0'], - counts=[3, 1, 4] - - If boundaries=[0, 1, 2, 3, 4], we expect counts=[4, 3, 0, 1, 0], - which this function will return. - - Args: - bucket_vocab: A `Tensor` that names the buckets corresponding to the count - information returned. - counts: A `Tensor` that matches the bucket_vocab. - boundary_size: A scalar that provides information about how big the returned - counts should be. - - Returns: - counts: A `Tensor` of size boundary_size corresponding to counts of all - available buckets. - """ - if bucket_vocab.dtype == tf.string: - bucket_vocab = tf.strings.to_number(bucket_vocab, tf.int32) - # counts/bucket_vocab may be out of order and missing values (empty buckets). - ordering = tf.argsort( - tf.concat([bucket_vocab, - tf.sets.difference([tf.range(boundary_size)], - [bucket_vocab]).values], axis=-1)) - counts = tf.pad(counts, [[0, boundary_size - tf.size(counts)]]) - return tf.gather(counts, ordering) +def reorder_histogram( + bucket_vocab: tf.Tensor, counts: tf.Tensor, boundary_size: int +) -> tf.Tensor: + """Return the histogram counts in indexed order, and zero out missing values. + + The count_elements analyzer returns counts in alphanumeric order, only for the + values that are present. To construct a well-formed histogram, we need to + rearrange them in numerical order, and fill in the missing values. + + Ex: The data contains values in the following form: [0, 1, 0, 1, 0, 3, 0, 1] + bucket_indices happen to be the same as these values, and + count_elements(tf.strings.as_string(bucket_indices)) returns: + bucket_vocab=['1', '3', '0'], + counts=[3, 1, 4] + + If boundaries=[0, 1, 2, 3, 4], we expect counts=[4, 3, 0, 1, 0], + which this function will return. + + Args: + ---- + bucket_vocab: A `Tensor` that names the buckets corresponding to the count + information returned. + counts: A `Tensor` that matches the bucket_vocab. + boundary_size: A scalar that provides information about how big the returned + counts should be. + + Returns: + ------- + counts: A `Tensor` of size boundary_size corresponding to counts of all + available buckets. + """ + if bucket_vocab.dtype == tf.string: + bucket_vocab = tf.strings.to_number(bucket_vocab, tf.int32) + # counts/bucket_vocab may be out of order and missing values (empty buckets). + ordering = tf.argsort( + tf.concat( + [ + bucket_vocab, + tf.sets.difference([tf.range(boundary_size)], [bucket_vocab]).values, + ], + axis=-1, + ) + ) + counts = tf.pad(counts, [[0, boundary_size - tf.size(counts)]]) + return tf.gather(counts, ordering) # Used to decide which bucket boundary index to assign to a value. 
class Side(enum.Enum): - RIGHT = 'right' - LEFT = 'left' - - -def assign_buckets(x: tf.Tensor, - bucket_boundaries: tf.Tensor, - side: Side = Side.LEFT) -> tf.Tensor: - """Assigns every value in x to a bucket index defined by bucket_boundaries. - - Note that `x` and `bucket_boundaries` will be cast to a common type that can - hold the largest of values. - - Args: - x: a `Tensor` of values to be bucketized. - bucket_boundaries: The bucket boundaries `Tensor`. Note that the boundaries - are going to be flattened. - side: Controlls index of a bucket that is being assigned: LEFT means that - a value is going to be assigned index of the rightmost boundary such that - boundary <= value; RIGHT means that a value is assigned index of the - leftmost boundary such that value < boundary. - - Returns: - A `Tensor` of dtype int64 with the same shape as `x`, and each element in - the returned tensor representing the bucketized value. Bucketized value is - in the range [0, len(bucket_boundaries)]. - """ - with tf.compat.v1.name_scope(None, 'assign_buckets'): - flat_x = tf.reshape(x, [-1]) - flat_boundaries = tf.reshape(bucket_boundaries, [-1]) + RIGHT = "right" + LEFT = "left" - # Cast values or boundaries to the "largest" dtype to avoid truncating - # larger values and avoid casting if dtypes are the same. - if flat_x.dtype.max > flat_boundaries.dtype.max: - flat_boundaries = tf.cast(flat_boundaries, flat_x.dtype) - else: - flat_x = tf.cast(flat_x, flat_boundaries.dtype) - - if side == Side.LEFT: - # Ignore the last boundary to replicate behavior of the previously used - # `BoostedTreesBucketize` for backwards compatibility. - flat_boundaries = flat_boundaries[:-1] - buckets = tf.searchsorted( - flat_boundaries, flat_x, side=side.value, out_type=tf.int64) - return tf.reshape(buckets, tf.shape(x)) +def assign_buckets( + x: tf.Tensor, bucket_boundaries: tf.Tensor, side: Side = Side.LEFT +) -> tf.Tensor: + """Assigns every value in x to a bucket index defined by bucket_boundaries. + + Note that `x` and `bucket_boundaries` will be cast to a common type that can + hold the largest of values. + + Args: + ---- + x: a `Tensor` of values to be bucketized. + bucket_boundaries: The bucket boundaries `Tensor`. Note that the boundaries + are going to be flattened. + side: Controlls index of a bucket that is being assigned: LEFT means that + a value is going to be assigned index of the rightmost boundary such that + boundary <= value; RIGHT means that a value is assigned index of the + leftmost boundary such that value < boundary. + + Returns: + ------- + A `Tensor` of dtype int64 with the same shape as `x`, and each element in + the returned tensor representing the bucketized value. Bucketized value is + in the range [0, len(bucket_boundaries)]. + """ + with tf.compat.v1.name_scope(None, "assign_buckets"): + flat_x = tf.reshape(x, [-1]) + flat_boundaries = tf.reshape(bucket_boundaries, [-1]) + + # Cast values or boundaries to the "largest" dtype to avoid truncating + # larger values and avoid casting if dtypes are the same. + if flat_x.dtype.max > flat_boundaries.dtype.max: + flat_boundaries = tf.cast(flat_boundaries, flat_x.dtype) + else: + flat_x = tf.cast(flat_x, flat_boundaries.dtype) + + if side == Side.LEFT: + # Ignore the last boundary to replicate behavior of the previously used + # `BoostedTreesBucketize` for backwards compatibility. 
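# --- Illustrative sketch (standalone example, not part of this patch) ---
# assign_buckets above is built on tf.searchsorted; the `side` argument only
# changes the result for values that land exactly on a boundary. Raw
# searchsorted behaviour (this is not the full assign_buckets contract, which
# additionally drops the last boundary for Side.LEFT):
import tensorflow as tf

boundaries = tf.constant([0.0, 10.0, 20.0])
values = tf.constant([-3.0, 5.0, 10.0, 25.0])
print(tf.searchsorted(boundaries, values, side="left"))   # [0 1 1 3]
print(tf.searchsorted(boundaries, values, side="right"))  # [0 1 2 3]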
+ flat_boundaries = flat_boundaries[:-1] + + buckets = tf.searchsorted( + flat_boundaries, flat_x, side=side.value, out_type=tf.int64 + ) + return tf.reshape(buckets, tf.shape(x)) # TODO(b/62379925): Remove this once all supported TF versions have # tf.data.experimental.DatasetInitializer. class _DatasetInitializerCompat( - getattr(tf.data.experimental, 'DatasetInitializer', - getattr(tf.lookup.experimental, 'DatasetInitializer', object))): - """Extends DatasetInitializer when possible and registers the init_op.""" - - def __init__(self, *args, **kwargs): - if self.__class__.mro()[1] == object: - raise NotImplementedError( - 'Cannot create a DatasetInitializer with this version of TF: {}' - .format(tf.__version__)) - super().__init__(*args, **kwargs) - - def initialize(self, table): - init_op = super().initialize(table) - collection_ref = tf.compat.v1.get_collection_ref( - tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) - if init_op not in collection_ref: - collection_ref.append(init_op) - return init_op + getattr( + tf.data.experimental, + "DatasetInitializer", + getattr(tf.lookup.experimental, "DatasetInitializer", object), + ) +): + """Extends DatasetInitializer when possible and registers the init_op.""" + + def __init__(self, *args, **kwargs): + if self.__class__.mro()[1] == object: + raise NotImplementedError( + f"Cannot create a DatasetInitializer with this version of TF: {tf.__version__}" + ) + super().__init__(*args, **kwargs) + + def initialize(self, table): + init_op = super().initialize(table) + collection_ref = tf.compat.v1.get_collection_ref( + tf.compat.v1.GraphKeys.TABLE_INITIALIZERS + ) + if init_op not in collection_ref: + collection_ref.append(init_op) + return init_op def _make_vocab_entry_to_dtype_fn(dtype): + def vocab_entry_to_dtype(key): + return key if dtype is tf.string else tf.strings.to_number(key, out_type=dtype) + + return vocab_entry_to_dtype + + +def _make_tfrecord_vocabulary_dataset( + vocab_path, + key_dtype=tf.string, + value_dtype=tf.int64, + return_indicator_as_value=False, + has_indicator=False, +): + """Makes a (key, value) dataset from a compressed tfrecord file.""" + if not (value_dtype.is_floating or value_dtype.is_integer): + raise ValueError("value_dtype must be numeric. Got: %s" % value_dtype) + dataset = tf.data.TFRecordDataset(vocab_path, compression_type="GZIP") + key_dtype_fn = _make_vocab_entry_to_dtype_fn(key_dtype) + value_dtype_fn = _make_vocab_entry_to_dtype_fn(value_dtype) + + if return_indicator_as_value: + assert has_indicator + + def convert_dtype(k, v): + return key_dtype_fn(k), value_dtype_fn(v) + + return dataset.map( + _split_vocabulary_entries, num_parallel_calls=tf.data.experimental.AUTOTUNE + ).map(convert_dtype) + + else: + if has_indicator: + drop_indicator = lambda k, v: k + dataset = dataset.map( + _split_vocabulary_entries, + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ).map(drop_indicator) - def vocab_entry_to_dtype(key): - return key if dtype is tf.string else tf.strings.to_number( - key, out_type=dtype) - - return vocab_entry_to_dtype - - -def _make_tfrecord_vocabulary_dataset(vocab_path, - key_dtype=tf.string, - value_dtype=tf.int64, - return_indicator_as_value=False, - has_indicator=False): - """Makes a (key, value) dataset from a compressed tfrecord file.""" - if not (value_dtype.is_floating or value_dtype.is_integer): - raise ValueError('value_dtype must be numeric. 
Got: %s' % value_dtype) - dataset = tf.data.TFRecordDataset(vocab_path, compression_type='GZIP') - key_dtype_fn = _make_vocab_entry_to_dtype_fn(key_dtype) - value_dtype_fn = _make_vocab_entry_to_dtype_fn(value_dtype) - - if return_indicator_as_value: - assert has_indicator - - def convert_dtype(k, v): - return key_dtype_fn(k), value_dtype_fn(v) - - return dataset.map( - _split_vocabulary_entries, - num_parallel_calls=tf.data.experimental.AUTOTUNE).map(convert_dtype) - - else: - if has_indicator: - drop_indicator = lambda k, v: k - dataset = dataset.map( - _split_vocabulary_entries, - num_parallel_calls=tf.data.experimental.AUTOTUNE).map(drop_indicator) - - def convert_dtype_and_swap(v, k): - return key_dtype_fn(k), tf.cast(v, value_dtype) - - return dataset.enumerate().map(convert_dtype_and_swap) - - -def make_tfrecord_vocabulary_lookup_initializer(filename_tensor, - key_dtype=tf.string, - value_dtype=tf.int64, - return_indicator_as_value=False, - has_indicator=False): - """Makes a lookup table initializer from a compressed tfrecord file.""" - with contextlib.ExitStack() as stack: - # If filename_tensor is a graph tensor (e.g. temporary analyzer output), the - # following operation cannot be lifted to init scope. Hence, check it is an - # eager tensor or a string constant. - if (tf.inside_function() and - isinstance(filename_tensor, (ops.EagerTensor, str))): - # Lift the dataset creation out of graph construction to avoid - # repeated initialization in TF2. - stack.enter_context(tf.init_scope()) - - dataset = _make_tfrecord_vocabulary_dataset(filename_tensor, key_dtype, - value_dtype, - return_indicator_as_value, - has_indicator) - if tf.inside_function(): - annotators.track_object(dataset, name=None) - return _DatasetInitializerCompat(dataset) + def convert_dtype_and_swap(v, k): + return key_dtype_fn(k), tf.cast(v, value_dtype) + + return dataset.enumerate().map(convert_dtype_and_swap) + + +def make_tfrecord_vocabulary_lookup_initializer( + filename_tensor, + key_dtype=tf.string, + value_dtype=tf.int64, + return_indicator_as_value=False, + has_indicator=False, +): + """Makes a lookup table initializer from a compressed tfrecord file.""" + with contextlib.ExitStack() as stack: + # If filename_tensor is a graph tensor (e.g. temporary analyzer output), the + # following operation cannot be lifted to init scope. Hence, check it is an + # eager tensor or a string constant. + if tf.inside_function() and isinstance(filename_tensor, (ops.EagerTensor, str)): + # Lift the dataset creation out of graph construction to avoid + # repeated initialization in TF2. + stack.enter_context(tf.init_scope()) + + dataset = _make_tfrecord_vocabulary_dataset( + filename_tensor, + key_dtype, + value_dtype, + return_indicator_as_value, + has_indicator, + ) + if tf.inside_function(): + annotators.track_object(dataset, name=None) + return _DatasetInitializerCompat(dataset) def _split_vocabulary_entries(batched_vocab_lines): - """Splits vocabulary entries separated by a single space. - - Vocabulary entries that include indicators are formatted as: - "" - - Args: - batched_vocab_lines: A possible batched string tensor. - - Returns: - A pair of (indicator, key) tensors. - """ - # Setting maxsplit=1 allows the vocabulary entries to include space - # characters. 
- split = tf.strings.split(batched_vocab_lines, sep=' ', maxsplit=1) - if isinstance(split, tf.RaggedTensor): - split_tensor = split.to_tensor() - return split_tensor[:, 1], split_tensor[:, 0] - else: - return split[1], split[0] - - -def apply_per_key_vocabulary(per_key_filename: tf.Tensor, - key: tf.Tensor, - default_value: Optional[str] = None, - target_ndims: Optional[int] = None) -> tf.Tensor: - """Apply a stored key-value mapping to a set of keys. - - We expect the values stored in per_key_filename to be two comma-delimited - numbers, such that it has the following form: - a 1,3 - b 2,4 - if a and b are the keys corresponding to each row. - - Args: - per_key_filename: The file name for the per-key vocabulary file. - key: A `Tensor` of dtype tf.string, which will determine which values are - returned. - default_value: (Optional) A string that determines the default output for - keys that are not found. - target_ndims: (Optional) The requested rank of each returned value (wrapped - in a single Tensor). - - Returns: - A `Tensor` representing the mapped values of shape [None, 2, ...], where - extra dimensions are added according to `target_dims`. - If no default value is given, maps oov keys to [0, 0]. - """ - if default_value is None: - default_value = '0,0' - - def _construct_table(asset_filepath): - initializer = tf.lookup.TextFileInitializer( - asset_filepath, - key_dtype=tf.string, - key_index=1, - value_dtype=tf.string, - value_index=0, - delimiter=' ') - return tf.lookup.StaticHashTable(initializer, default_value=default_value) - - table_lookup, unused_table_size = construct_and_lookup_table( - _construct_table, per_key_filename, key) - - sparse_result = tf.compat.v1.strings.split(table_lookup, sep=',') - dense_result = tf.sparse.to_dense(sparse_result, '0') - # Add 0s where dense_result has empty strings. - number_strings = tf.where( - tf.strings.length(dense_result) > 0, dense_result, - tf.fill(tf.shape(dense_result), '0')) - numbers = tf.strings.to_number(number_strings) - # We add 1 to represent the dimension of the multiple associated values found - # in the vocabulary file (the d values present for every key). - return numbers if not target_ndims else _align_dims(numbers, target_ndims + 1) + """Splits vocabulary entries separated by a single space. + + Vocabulary entries that include indicators are formatted as: + "" + + Args: + ---- + batched_vocab_lines: A possible batched string tensor. + + Returns: + ------- + A pair of (indicator, key) tensors. + """ + # Setting maxsplit=1 allows the vocabulary entries to include space + # characters. + split = tf.strings.split(batched_vocab_lines, sep=" ", maxsplit=1) + if isinstance(split, tf.RaggedTensor): + split_tensor = split.to_tensor() + return split_tensor[:, 1], split_tensor[:, 0] + else: + return split[1], split[0] + + +def apply_per_key_vocabulary( + per_key_filename: tf.Tensor, + key: tf.Tensor, + default_value: Optional[str] = None, + target_ndims: Optional[int] = None, +) -> tf.Tensor: + """Apply a stored key-value mapping to a set of keys. + + We expect the values stored in per_key_filename to be two comma-delimited + numbers, such that it has the following form: + a 1,3 + b 2,4 + if a and b are the keys corresponding to each row. + + Args: + ---- + per_key_filename: The file name for the per-key vocabulary file. + key: A `Tensor` of dtype tf.string, which will determine which values are + returned. + default_value: (Optional) A string that determines the default output for + keys that are not found. 
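# --- Illustrative sketch (standalone example, not part of this patch) ---
# _split_vocabulary_entries above splits only on the first space (maxsplit=1),
# so the trailing field may itself contain spaces. Assuming an
# "<indicator> <token>"-style line layout:
import tensorflow as tf

lines = tf.constant(["17 new york", "42 tokyo"])
split = tf.strings.split(lines, sep=" ", maxsplit=1).to_tensor()
print(split[:, 0])  # leading field:  [b'17' b'42']
print(split[:, 1])  # trailing field: [b'new york' b'tokyo'] (space preserved)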
+ target_ndims: (Optional) The requested rank of each returned value (wrapped + in a single Tensor). + + Returns: + ------- + A `Tensor` representing the mapped values of shape [None, 2, ...], where + extra dimensions are added according to `target_dims`. + If no default value is given, maps oov keys to [0, 0]. + """ + if default_value is None: + default_value = "0,0" + + def _construct_table(asset_filepath): + initializer = tf.lookup.TextFileInitializer( + asset_filepath, + key_dtype=tf.string, + key_index=1, + value_dtype=tf.string, + value_index=0, + delimiter=" ", + ) + return tf.lookup.StaticHashTable(initializer, default_value=default_value) + + table_lookup, unused_table_size = construct_and_lookup_table( + _construct_table, per_key_filename, key + ) + + sparse_result = tf.compat.v1.strings.split(table_lookup, sep=",") + dense_result = tf.sparse.to_dense(sparse_result, "0") + # Add 0s where dense_result has empty strings. + number_strings = tf.where( + tf.strings.length(dense_result) > 0, + dense_result, + tf.fill(tf.shape(dense_result), "0"), + ) + numbers = tf.strings.to_number(number_strings) + # We add 1 to represent the dimension of the multiple associated values found + # in the vocabulary file (the d values present for every key). + return numbers if not target_ndims else _align_dims(numbers, target_ndims + 1) def _is_finite(x: common_types.TensorType) -> common_types.TensorType: - """Extension of `tf.math.is_finite` that works with all dtypes.""" - if x.dtype.is_floating: - return tf.math.is_finite(x) - return tf.ones_like(x, dtype=tf.bool) + """Extension of `tf.math.is_finite` that works with all dtypes.""" + if x.dtype.is_floating: + return tf.math.is_finite(x) + return tf.ones_like(x, dtype=tf.bool) def _reduce_batch_count_mean_and_var_sparse( - x: tf.SparseTensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes elementwise count, mean and var for the given sparse tensor.""" - x_count = tf.cast(reduce_batch_count(x, reduce_instance_dims=False), x.dtype) - finite_x = tf.SparseTensor( - indices=x.indices, - values=tf.where(_is_finite(x.values), x.values, tf.zeros_like(x.values)), - dense_shape=x.dense_shape) - x_sum = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, finite_x) - x_mean = tf.math.divide_no_nan(x_sum, x_count) - x_minus_mean = tf.sparse.add(finite_x, -tf.broadcast_to(x_mean, tf.shape(x))) - x_minus_mean_sparse = tf.SparseTensor(x.indices, - tf.gather_nd(x_minus_mean, x.indices), - x.dense_shape) - sum_of_squares = tf.math.reduce_sum( - tf.square(tf.sparse.to_dense(x_minus_mean_sparse)), axis=0) - x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) - return (x_count, x_mean, x_variance) + x: tf.SparseTensor, +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Computes elementwise count, mean and var for the given sparse tensor.""" + x_count = tf.cast(reduce_batch_count(x, reduce_instance_dims=False), x.dtype) + finite_x = tf.SparseTensor( + indices=x.indices, + values=tf.where(_is_finite(x.values), x.values, tf.zeros_like(x.values)), + dense_shape=x.dense_shape, + ) + x_sum = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_sum, finite_x) + x_mean = tf.math.divide_no_nan(x_sum, x_count) + x_minus_mean = tf.sparse.add(finite_x, -tf.broadcast_to(x_mean, tf.shape(x))) + x_minus_mean_sparse = tf.SparseTensor( + x.indices, tf.gather_nd(x_minus_mean, x.indices), x.dense_shape + ) + sum_of_squares = tf.math.reduce_sum( + tf.square(tf.sparse.to_dense(x_minus_mean_sparse)), axis=0 + ) + x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) + 
return (x_count, x_mean, x_variance) def _reduce_batch_count_mean_and_var_ragged( - x: tf.RaggedTensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes elementwise count, mean and var for the given ragged tensor.""" - zeros_like_x = tf.zeros_like(x) - x_is_finite = _is_finite(x) - x_sum = tf.reduce_sum(tf.where(x_is_finite, x, zeros_like_x), axis=0) - dense_x_count = tf.cast( - reduce_batch_count(x, reduce_instance_dims=False), x.dtype) - x_count = tf.RaggedTensor.from_tensor( - dense_x_count, lengths=x_sum.nested_row_lengths()) - x_mean = tf.math.divide_no_nan(x_sum, x_count).to_tensor() - dense_x = x.to_tensor() - dense_x_is_finite = _is_finite(dense_x) - x_minus_mean = tf.where(dense_x_is_finite, dense_x - x_mean, - tf.zeros_like(dense_x)) - x_minus_mean = tf.RaggedTensor.from_tensor( - x_minus_mean, lengths=x.nested_row_lengths()) - sum_of_squares = tf.reduce_sum(input_tensor=tf.square(x_minus_mean), axis=0) - x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) - return (dense_x_count, x_mean, x_variance.to_tensor()) + x: tf.RaggedTensor, +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Computes elementwise count, mean and var for the given ragged tensor.""" + zeros_like_x = tf.zeros_like(x) + x_is_finite = _is_finite(x) + x_sum = tf.reduce_sum(tf.where(x_is_finite, x, zeros_like_x), axis=0) + dense_x_count = tf.cast(reduce_batch_count(x, reduce_instance_dims=False), x.dtype) + x_count = tf.RaggedTensor.from_tensor( + dense_x_count, lengths=x_sum.nested_row_lengths() + ) + x_mean = tf.math.divide_no_nan(x_sum, x_count).to_tensor() + dense_x = x.to_tensor() + dense_x_is_finite = _is_finite(dense_x) + x_minus_mean = tf.where(dense_x_is_finite, dense_x - x_mean, tf.zeros_like(dense_x)) + x_minus_mean = tf.RaggedTensor.from_tensor( + x_minus_mean, lengths=x.nested_row_lengths() + ) + sum_of_squares = tf.reduce_sum(input_tensor=tf.square(x_minus_mean), axis=0) + x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) + return (dense_x_count, x_mean, x_variance.to_tensor()) def _reduce_batch_count_mean_and_var_dense( - x: tf.Tensor, - reduce_instance_dims: bool) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes count, mean and var for the given dense tensor.""" - axis = None if reduce_instance_dims else 0 - x_count = tf.cast(reduce_batch_count(x, reduce_instance_dims), x.dtype) - zeros_like_x = tf.zeros_like(x) - x_is_finite = _is_finite(x) - x_sum = tf.reduce_sum(tf.where(x_is_finite, x, zeros_like_x), axis=axis) - x_mean = tf.math.divide_no_nan(x_sum, x_count) - x_minus_mean = tf.where(x_is_finite, x - x_mean, zeros_like_x) - sum_of_squares = tf.reduce_sum( - input_tensor=tf.square(x_minus_mean), axis=axis) - x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) - return (x_count, x_mean, x_variance) + x: tf.Tensor, reduce_instance_dims: bool +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Computes count, mean and var for the given dense tensor.""" + axis = None if reduce_instance_dims else 0 + x_count = tf.cast(reduce_batch_count(x, reduce_instance_dims), x.dtype) + zeros_like_x = tf.zeros_like(x) + x_is_finite = _is_finite(x) + x_sum = tf.reduce_sum(tf.where(x_is_finite, x, zeros_like_x), axis=axis) + x_mean = tf.math.divide_no_nan(x_sum, x_count) + x_minus_mean = tf.where(x_is_finite, x - x_mean, zeros_like_x) + sum_of_squares = tf.reduce_sum(input_tensor=tf.square(x_minus_mean), axis=axis) + x_variance = tf.math.divide_no_nan(sum_of_squares, x_count) + return (x_count, x_mean, x_variance) def reduce_batch_count_mean_and_var( - x: 
common_types.TensorType, - reduce_instance_dims: bool) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes element count, mean and var for the given tensor. - - Args: - x: A `Tensor` or `CompositeTensor`. - reduce_instance_dims: A bool, if True - collapses the batch and instance - dimensions to arrive at a single scalar output. Otherwise, only - collapses the batch dimension and outputs a `Tensor` of the same shape - as the input. - - Returns: - A 3-tuple containing the tensor's (count, mean, var). NaNs and infinite - input values are ignored. - """ - if isinstance(x, tf.SparseTensor): - if reduce_instance_dims: - return _reduce_batch_count_mean_and_var_dense( - x.values, reduce_instance_dims=True) - else: - return _reduce_batch_count_mean_and_var_sparse(x) - elif isinstance(x, tf.RaggedTensor): - if reduce_instance_dims: - return _reduce_batch_count_mean_and_var_dense( - x.flat_values, reduce_instance_dims=True) + x: common_types.TensorType, reduce_instance_dims: bool +) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + """Computes element count, mean and var for the given tensor. + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + reduce_instance_dims: A bool, if True - collapses the batch and instance + dimensions to arrive at a single scalar output. Otherwise, only + collapses the batch dimension and outputs a `Tensor` of the same shape + as the input. + + Returns: + ------- + A 3-tuple containing the tensor's (count, mean, var). NaNs and infinite + input values are ignored. + """ + if isinstance(x, tf.SparseTensor): + if reduce_instance_dims: + return _reduce_batch_count_mean_and_var_dense( + x.values, reduce_instance_dims=True + ) + else: + return _reduce_batch_count_mean_and_var_sparse(x) + elif isinstance(x, tf.RaggedTensor): + if reduce_instance_dims: + return _reduce_batch_count_mean_and_var_dense( + x.flat_values, reduce_instance_dims=True + ) + else: + return _reduce_batch_count_mean_and_var_ragged(x) else: - return _reduce_batch_count_mean_and_var_ragged(x) - else: - return _reduce_batch_count_mean_and_var_dense(x, reduce_instance_dims) + return _reduce_batch_count_mean_and_var_dense(x, reduce_instance_dims) def _num_terms_and_factors(num_samples, dtype): - """Computes counts and sample multipliers for the given number of samples. - - Args: - num_samples: An integral type scalar `Tensor` containing the number of - samples used to compute the L-moments. This must be non-negative. - dtype: The dtype of the samples to process. This determines the output - `Tensor`s dtype. - - Returns: - The tuple (current_samples, current_pairs, current_triplets, - current_quadruplets, l1_factors, l2_factors, l3_factors, l4_factors). - Entries are `Tensor`s with the given dtype containing counters for each - moment and the factors to use to compute the moments. 
- """ - has_pairs = tf.math.greater(num_samples, 1) - has_triplets = tf.math.greater(num_samples, 2) - has_quadruplets = tf.math.greater(num_samples, 3) - - current_samples = tf.cast(num_samples, dtype=dtype) - current_pairs = tf.cast( - current_samples * (current_samples - 1.0) / 2.0, dtype=dtype) - current_triplets = tf.cast( - current_pairs * (current_samples - 2.0) / 3.0, dtype=dtype) - current_quadruplets = tf.cast( - current_triplets * (current_samples - 3.0) / 4.0, dtype=dtype) - - term_up = tf.range(0, current_samples, 1, dtype=dtype) - term_up_delay_1 = tf.range(-1, current_samples - 1, 1, dtype=dtype) - term_up_delay_2 = tf.range(-2, current_samples - 2, 1, dtype=dtype) - term_down = tf.range(current_samples - 1, -1, -1, dtype=dtype) - term_down_delay_1 = tf.range(current_samples - 2, -2, -1, dtype=dtype) - term_down_delay_2 = tf.range(current_samples - 3, -3, -1, dtype=dtype) - - l1_denominator = tf.cond(tf.math.greater(num_samples, 0), - lambda: current_samples, - lambda: tf.constant(1, dtype)) - l1_factors = tf.ones([num_samples], dtype=dtype) / l1_denominator - l2_denominator = tf.cond(has_pairs, - lambda: tf.cast(current_pairs * 2.0, dtype=dtype), - lambda: tf.constant(1, dtype)) - l2_factors = (term_up - term_down) / l2_denominator - l3_denominator = tf.cond(has_triplets, - lambda: tf.cast(current_triplets * 6, dtype=dtype), - lambda: tf.constant(1, dtype)) - l3_factors = ((term_up * term_up_delay_1 - 4.0 * term_up * term_down + - term_down * term_down_delay_1) / l3_denominator) - l4_denominator = tf.cond( - has_quadruplets, - lambda: tf.cast(current_quadruplets * 24, dtype=dtype), - lambda: tf.constant(1, dtype)) - l4_factors = ((term_up * term_up_delay_1 * term_up_delay_2 - - 9.0 * term_up * term_up_delay_1 * term_down + - 9.0 * term_up * term_down * term_down_delay_1 - - term_down * term_down_delay_1 * term_down_delay_2) / - l4_denominator) - return (current_samples, current_pairs, current_triplets, current_quadruplets, - l1_factors, l2_factors, l3_factors, l4_factors) + """Computes counts and sample multipliers for the given number of samples. + + Args: + ---- + num_samples: An integral type scalar `Tensor` containing the number of + samples used to compute the L-moments. This must be non-negative. + dtype: The dtype of the samples to process. This determines the output + `Tensor`s dtype. + + Returns: + ------- + The tuple (current_samples, current_pairs, current_triplets, + current_quadruplets, l1_factors, l2_factors, l3_factors, l4_factors). + Entries are `Tensor`s with the given dtype containing counters for each + moment and the factors to use to compute the moments. 
+ """ + has_pairs = tf.math.greater(num_samples, 1) + has_triplets = tf.math.greater(num_samples, 2) + has_quadruplets = tf.math.greater(num_samples, 3) + + current_samples = tf.cast(num_samples, dtype=dtype) + current_pairs = tf.cast( + current_samples * (current_samples - 1.0) / 2.0, dtype=dtype + ) + current_triplets = tf.cast( + current_pairs * (current_samples - 2.0) / 3.0, dtype=dtype + ) + current_quadruplets = tf.cast( + current_triplets * (current_samples - 3.0) / 4.0, dtype=dtype + ) + + term_up = tf.range(0, current_samples, 1, dtype=dtype) + term_up_delay_1 = tf.range(-1, current_samples - 1, 1, dtype=dtype) + term_up_delay_2 = tf.range(-2, current_samples - 2, 1, dtype=dtype) + term_down = tf.range(current_samples - 1, -1, -1, dtype=dtype) + term_down_delay_1 = tf.range(current_samples - 2, -2, -1, dtype=dtype) + term_down_delay_2 = tf.range(current_samples - 3, -3, -1, dtype=dtype) + + l1_denominator = tf.cond( + tf.math.greater(num_samples, 0), + lambda: current_samples, + lambda: tf.constant(1, dtype), + ) + l1_factors = tf.ones([num_samples], dtype=dtype) / l1_denominator + l2_denominator = tf.cond( + has_pairs, + lambda: tf.cast(current_pairs * 2.0, dtype=dtype), + lambda: tf.constant(1, dtype), + ) + l2_factors = (term_up - term_down) / l2_denominator + l3_denominator = tf.cond( + has_triplets, + lambda: tf.cast(current_triplets * 6, dtype=dtype), + lambda: tf.constant(1, dtype), + ) + l3_factors = ( + term_up * term_up_delay_1 + - 4.0 * term_up * term_down + + term_down * term_down_delay_1 + ) / l3_denominator + l4_denominator = tf.cond( + has_quadruplets, + lambda: tf.cast(current_quadruplets * 24, dtype=dtype), + lambda: tf.constant(1, dtype), + ) + l4_factors = ( + term_up * term_up_delay_1 * term_up_delay_2 + - 9.0 * term_up * term_up_delay_1 * term_down + + 9.0 * term_up * term_down * term_down_delay_1 + - term_down * term_down_delay_1 * term_down_delay_2 + ) / l4_denominator + return ( + current_samples, + current_pairs, + current_triplets, + current_quadruplets, + l1_factors, + l2_factors, + l3_factors, + l4_factors, + ) @tf.function def _condition_l_moments_sparse( - current_index, unused_l1_sum, unused_l2_sum, unused_l3_sum, unused_l4_sum, - unused_count_samples, unused_count_pairs, unused_count_triplets, - unused_count_quadruplets, x_rank_2): - """Condition for the loop that computes L-moments for a `SparseTensor`.""" - return tf.less(current_index, x_rank_2.dense_shape[1]) + current_index, + unused_l1_sum, + unused_l2_sum, + unused_l3_sum, + unused_l4_sum, + unused_count_samples, + unused_count_pairs, + unused_count_triplets, + unused_count_quadruplets, + x_rank_2, +): + """Condition for the loop that computes L-moments for a `SparseTensor`.""" + return tf.less(current_index, x_rank_2.dense_shape[1]) @tf.function def _iteration_l_moments_sparse( - current_index, l1_sum, l2_sum, l3_sum, l4_sum, count_samples, - count_pairs, count_triplets, count_quadruplets, x_rank_2): - """Process one column of a `SparseTensor` and updates L-moments variables.""" - current_x = tf.boolean_mask( - x_rank_2.values, - tf.math.equal(x_rank_2.indices[:, 1], [current_index])) - sorted_x = tf.sort(current_x, axis=0) - num_samples = tf.shape(current_x)[0] - (current_samples, current_pairs, current_triplets, current_quadruplets, - l1_factors, l2_factors, l3_factors, - l4_factors) = _num_terms_and_factors(num_samples, x_rank_2.values.dtype) - - dim_1 = x_rank_2.dense_shape[1] - new_l1_sum = l1_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l1_factors), 
axis=0)], [dim_1]) - new_l2_sum = l2_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l2_factors), axis=0)], [dim_1]) - new_l3_sum = l3_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l3_factors), axis=0)], [dim_1]) - new_l4_sum = l4_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l4_factors), axis=0)], [dim_1]) - - new_count_samples = count_samples + tf.scatter_nd( - [[current_index]], [current_samples], [dim_1]) - new_count_pairs = count_pairs + tf.scatter_nd( - [[current_index]], [current_pairs], [dim_1]) - new_count_triplets = count_triplets + tf.scatter_nd( - [[current_index]], [current_triplets], [dim_1]) - new_count_quadruplets = count_quadruplets + tf.scatter_nd( - [[current_index]], [current_quadruplets], [dim_1]) - - return (tf.add(current_index, 1), - new_l1_sum, new_l2_sum, new_l3_sum, new_l4_sum, - new_count_samples, new_count_pairs, new_count_triplets, - new_count_quadruplets, x_rank_2) + current_index, + l1_sum, + l2_sum, + l3_sum, + l4_sum, + count_samples, + count_pairs, + count_triplets, + count_quadruplets, + x_rank_2, +): + """Process one column of a `SparseTensor` and updates L-moments variables.""" + current_x = tf.boolean_mask( + x_rank_2.values, tf.math.equal(x_rank_2.indices[:, 1], [current_index]) + ) + sorted_x = tf.sort(current_x, axis=0) + num_samples = tf.shape(current_x)[0] + ( + current_samples, + current_pairs, + current_triplets, + current_quadruplets, + l1_factors, + l2_factors, + l3_factors, + l4_factors, + ) = _num_terms_and_factors(num_samples, x_rank_2.values.dtype) + + dim_1 = x_rank_2.dense_shape[1] + new_l1_sum = l1_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l1_factors), axis=0)], + [dim_1], + ) + new_l2_sum = l2_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l2_factors), axis=0)], + [dim_1], + ) + new_l3_sum = l3_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l3_factors), axis=0)], + [dim_1], + ) + new_l4_sum = l4_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l4_factors), axis=0)], + [dim_1], + ) + + new_count_samples = count_samples + tf.scatter_nd( + [[current_index]], [current_samples], [dim_1] + ) + new_count_pairs = count_pairs + tf.scatter_nd( + [[current_index]], [current_pairs], [dim_1] + ) + new_count_triplets = count_triplets + tf.scatter_nd( + [[current_index]], [current_triplets], [dim_1] + ) + new_count_quadruplets = count_quadruplets + tf.scatter_nd( + [[current_index]], [current_quadruplets], [dim_1] + ) + + return ( + tf.add(current_index, 1), + new_l1_sum, + new_l2_sum, + new_l3_sum, + new_l4_sum, + new_count_samples, + new_count_pairs, + new_count_triplets, + new_count_quadruplets, + x_rank_2, + ) @tf.function def _condition_l_moments_dense( - current_index, unused_l1_sum, unused_l2_sum, unused_l3_sum, unused_l4_sum, - unused_l1_factors, unused_l2_factors, unused_l3_factors, unused_l4_factors, - x_rank_2): - """Condition for the loop that computes L-moments for a `Tensor`.""" - return tf.less(current_index, tf.shape(x_rank_2)[1]) + current_index, + unused_l1_sum, + unused_l2_sum, + unused_l3_sum, + unused_l4_sum, + unused_l1_factors, + unused_l2_factors, + unused_l3_factors, + unused_l4_factors, + x_rank_2, +): + """Condition for the loop that computes L-moments for a `Tensor`.""" + return tf.less(current_index, tf.shape(x_rank_2)[1]) @tf.function def _iteration_l_moments_dense( - current_index, 
l1_sum, l2_sum, l3_sum, l4_sum, l1_factors, l2_factors, - l3_factors, l4_factors, x_rank_2): - """Process one column of a `Tensor` and updates L-moments variables.""" - current_x = x_rank_2[:, current_index] - sorted_x = tf.sort(current_x) - - dim_1 = tf.shape(x_rank_2)[1] - new_l1_sum = l1_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l1_factors), axis=0)], [dim_1]) - new_l2_sum = l2_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l2_factors), axis=0)], [dim_1]) - new_l3_sum = l3_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l3_factors), axis=0)], [dim_1]) - new_l4_sum = l4_sum + tf.scatter_nd( - [[current_index]], - [tf.reduce_sum(tf.multiply(sorted_x, l4_factors), axis=0)], [dim_1]) - return (tf.add(current_index, 1), - new_l1_sum, new_l2_sum, new_l3_sum, new_l4_sum, l1_factors, - l2_factors, l3_factors, l4_factors, x_rank_2) + current_index, + l1_sum, + l2_sum, + l3_sum, + l4_sum, + l1_factors, + l2_factors, + l3_factors, + l4_factors, + x_rank_2, +): + """Process one column of a `Tensor` and updates L-moments variables.""" + current_x = x_rank_2[:, current_index] + sorted_x = tf.sort(current_x) + + dim_1 = tf.shape(x_rank_2)[1] + new_l1_sum = l1_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l1_factors), axis=0)], + [dim_1], + ) + new_l2_sum = l2_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l2_factors), axis=0)], + [dim_1], + ) + new_l3_sum = l3_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l3_factors), axis=0)], + [dim_1], + ) + new_l4_sum = l4_sum + tf.scatter_nd( + [[current_index]], + [tf.reduce_sum(tf.multiply(sorted_x, l4_factors), axis=0)], + [dim_1], + ) + return ( + tf.add(current_index, 1), + new_l1_sum, + new_l2_sum, + new_l3_sum, + new_l4_sum, + l1_factors, + l2_factors, + l3_factors, + l4_factors, + x_rank_2, + ) def reduce_batch_count_l_moments( x: common_types.TensorType, reduce_instance_dims: bool -) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor, - tf.Tensor, tf.Tensor]: - """Computes element first 4 L-moments and the corresponding counts. - - Computes the first 4 L-moments (https://en.wikipedia.org/wiki/L-moment) and - the number of samples, pairs, etc. used to compute them. - - Args: - x: A `Tensor` or `CompositeTensor`. - reduce_instance_dims: A bool, if True - collapses the batch and instance - dimensions to arrive at a single scalar output. Otherwise, only - collapses the batch dimension and outputs a `Tensor` of the same shape - as the input. - - Returns: - The tuple (count_samples, l1, count_pairs, l2, count_triplets, l3, - count_quadruplets, l4). Each entry is a `Tensor` with the same dtype as x. - If reduce_instance_dims is True, the tensors are scalars; otherwise the - shape is x.shape[1:], i.e. the batch dimension is removed. - """ - if isinstance(x, tf.SparseTensor) and reduce_instance_dims: - x = x.values - elif isinstance(x, tf.RaggedTensor): - if reduce_instance_dims: - x = x.flat_values - else: - raise NotImplementedError( - 'L-moments only support reduced dims for RaggedTensors') +) -> Tuple[ + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, + tf.Tensor, +]: + """Computes element first 4 L-moments and the corresponding counts. + + Computes the first 4 L-moments (https://en.wikipedia.org/wiki/L-moment) and + the number of samples, pairs, etc. used to compute them. 
+ + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + reduce_instance_dims: A bool, if True - collapses the batch and instance + dimensions to arrive at a single scalar output. Otherwise, only + collapses the batch dimension and outputs a `Tensor` of the same shape + as the input. + + Returns: + ------- + The tuple (count_samples, l1, count_pairs, l2, count_triplets, l3, + count_quadruplets, l4). Each entry is a `Tensor` with the same dtype as x. + If reduce_instance_dims is True, the tensors are scalars; otherwise the + shape is x.shape[1:], i.e. the batch dimension is removed. + """ + if isinstance(x, tf.SparseTensor) and reduce_instance_dims: + x = x.values + elif isinstance(x, tf.RaggedTensor): + if reduce_instance_dims: + x = x.flat_values + else: + raise NotImplementedError( + "L-moments only support reduced dims for RaggedTensors" + ) + + if isinstance(x, tf.SparseTensor): + batch_size = x.dense_shape[0] + x_rank_2 = tf.sparse.reshape(x, [batch_size, -1]) + dim_1 = x_rank_2.dense_shape[1] + initial_values = tf.zeros([dim_1], dtype=x.dtype) + ( + unused_current_index, + l1_sum, + l2_sum, + l3_sum, + l4_sum, + count_samples, + count_pairs, + count_triplets, + count_quadruplets, + unused_x_rank_2, + ) = tf.while_loop( + _condition_l_moments_sparse, + _iteration_l_moments_sparse, + [tf.constant(0, dim_1.dtype)] + [initial_values] * 8 + [x_rank_2], + ) + if reduce_instance_dims: + final_shape = () + elif x.get_shape().ndims and x.get_shape()[1:].is_fully_defined(): + final_shape = x.get_shape()[1:] + else: + final_shape = tf.shape(x)[1:] + l1 = tf.reshape(l1_sum, final_shape) + l2 = tf.reshape(l2_sum, final_shape) + l3 = tf.reshape(l3_sum, final_shape) + l4 = tf.reshape(l4_sum, final_shape) + count_l1 = tf.reshape(count_samples, final_shape) + count_l2 = tf.reshape(count_pairs, final_shape) + count_l3 = tf.reshape(count_triplets, final_shape) + count_l4 = tf.reshape(count_quadruplets, final_shape) - if isinstance(x, tf.SparseTensor): - batch_size = x.dense_shape[0] - x_rank_2 = tf.sparse.reshape(x, [batch_size, -1]) - dim_1 = x_rank_2.dense_shape[1] - initial_values = tf.zeros([dim_1], dtype=x.dtype) - (unused_current_index, l1_sum, l2_sum, l3_sum, l4_sum, - count_samples, count_pairs, count_triplets, - count_quadruplets, unused_x_rank_2) = tf.while_loop( - _condition_l_moments_sparse, - _iteration_l_moments_sparse, - [tf.constant(0, dim_1.dtype)] + [initial_values] * 8 + [x_rank_2]) - if reduce_instance_dims: - final_shape = () - elif x.get_shape().ndims and x.get_shape()[1:].is_fully_defined(): - final_shape = x.get_shape()[1:] else: - final_shape = tf.shape(x)[1:] - l1 = tf.reshape(l1_sum, final_shape) - l2 = tf.reshape(l2_sum, final_shape) - l3 = tf.reshape(l3_sum, final_shape) - l4 = tf.reshape(l4_sum, final_shape) - count_l1 = tf.reshape(count_samples, final_shape) - count_l2 = tf.reshape(count_pairs, final_shape) - count_l3 = tf.reshape(count_triplets, final_shape) - count_l4 = tf.reshape(count_quadruplets, final_shape) - - else: - num_samples = tf.size(x) if reduce_instance_dims else tf.shape(x)[0] - (count_samples, count_pairs, count_triplets, count_quadruplets, - l1_factors, l2_factors, l3_factors, l4_factors) = _num_terms_and_factors( - num_samples, x.dtype) - x_rank_2 = tf.reshape(x, [num_samples, -1]) - dim_1 = tf.shape(x_rank_2)[1] - initial_moment_values = tf.zeros([dim_1], dtype=x.dtype) - (unused_current_index, l1_sum, l2_sum, l3_sum, l4_sum, unused_l1_factors, - unused_l2_factors, unused_l3_factors, unused_l4_factors, - unused_x_rank_2) = tf.while_loop( - 
_condition_l_moments_dense, - _iteration_l_moments_dense, - [tf.constant(0, dim_1.dtype)] + [initial_moment_values] * 4 + - [l1_factors, l2_factors, l3_factors, l4_factors, x_rank_2]) - final_shape = (() if reduce_instance_dims else tf.shape(x)[1:]) - l1 = tf.reshape(l1_sum, final_shape) - l2 = tf.reshape(l2_sum, final_shape) - l3 = tf.reshape(l3_sum, final_shape) - l4 = tf.reshape(l4_sum, final_shape) - count_l1 = tf.fill(final_shape, count_samples) - count_l2 = tf.fill(final_shape, count_pairs) - count_l3 = tf.fill(final_shape, count_triplets) - count_l4 = tf.fill(final_shape, count_quadruplets) - - return (count_l1, l1, count_l2, l2, count_l3, l3, count_l4, l4) + num_samples = tf.size(x) if reduce_instance_dims else tf.shape(x)[0] + ( + count_samples, + count_pairs, + count_triplets, + count_quadruplets, + l1_factors, + l2_factors, + l3_factors, + l4_factors, + ) = _num_terms_and_factors(num_samples, x.dtype) + x_rank_2 = tf.reshape(x, [num_samples, -1]) + dim_1 = tf.shape(x_rank_2)[1] + initial_moment_values = tf.zeros([dim_1], dtype=x.dtype) + ( + unused_current_index, + l1_sum, + l2_sum, + l3_sum, + l4_sum, + unused_l1_factors, + unused_l2_factors, + unused_l3_factors, + unused_l4_factors, + unused_x_rank_2, + ) = tf.while_loop( + _condition_l_moments_dense, + _iteration_l_moments_dense, + [tf.constant(0, dim_1.dtype)] + + [initial_moment_values] * 4 + + [l1_factors, l2_factors, l3_factors, l4_factors, x_rank_2], + ) + final_shape = () if reduce_instance_dims else tf.shape(x)[1:] + l1 = tf.reshape(l1_sum, final_shape) + l2 = tf.reshape(l2_sum, final_shape) + l3 = tf.reshape(l3_sum, final_shape) + l4 = tf.reshape(l4_sum, final_shape) + count_l1 = tf.fill(final_shape, count_samples) + count_l2 = tf.fill(final_shape, count_pairs) + count_l3 = tf.fill(final_shape, count_triplets) + count_l4 = tf.fill(final_shape, count_quadruplets) + + return (count_l1, l1, count_l2, l2, count_l3, l3, count_l4, l4) def _validate_and_get_dense_value_key_inputs( - x: common_types.TensorType, - key: common_types.TensorType) -> Tuple[tf.Tensor, tf.Tensor]: - """Validate x and key and returns dense representations if feasible. - - Check if sparse x and sparse key have identical indices, map key if dense. - - Args: - x: A `Tensor` or `CompositeTensor`. - key: A `Tensor` or `CompositeTensor`. Must be `Tensor` if x is `Tensor`. - - Returns: - The values of x and key if both are composite, the values of x and a mapped - key if only x is composite, or the original x and key if both are dense. - """ - - if isinstance(x, tf.Tensor) and isinstance(key, tf.Tensor): - return x, key - elif isinstance(x, tf.Tensor): - raise ValueError('A dense key is required if x is dense') - - elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.SparseTensor): - assert_shape = tf.debugging.assert_equal(x.dense_shape, key.dense_shape) - assert_eq = tf.debugging.assert_equal(x.indices, key.indices) - with tf.control_dependencies([assert_eq, assert_shape]): - return tf.identity(x.values), tf.identity(key.values) - elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.Tensor): - # In this case, the row of x corresponds to the key at that row. 
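For a quick sanity check of what the factor-weighted sums above estimate: the first L-moment is the sample mean, and the second is half the average gap over all ordered pairs of the sorted sample. The exact per-order-statistic factors come from `_num_terms_and_factors`, defined earlier in this module, so the sketch below (made-up data, plain NumPy) only illustrates the statistics themselves:

import itertools
import numpy as np

xs = np.sort(np.array([1.0, 2.0, 3.0, 4.0]))
n = len(xs)

# First L-moment: the sample mean.
l1 = xs.mean()

# Second L-moment: half the mean difference over all ordered pairs,
# i.e. an estimate of 0.5 * E[X_(2:2) - X_(1:2)].
pairs = list(itertools.combinations(range(n), 2))
l2 = sum(xs[j] - xs[i] for i, j in pairs) / (2 * len(pairs))

print(l1, l2)  # 2.5 and 10 / 12, roughly 0.833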
- x_row_indices = x.indices[:, 0] - assert_compatible = tf.debugging.assert_greater_equal( - tf.shape(key, out_type=tf.int64)[0], x.dense_shape[0]) - with tf.control_dependencies([assert_compatible]): - return x.values, tf.gather(key, x_row_indices) - elif isinstance(x, tf.SparseTensor): - raise ValueError('A sparse or dense key is required if x is sparse') - - elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.RaggedTensor): - x.shape.assert_is_compatible_with(key.shape) - assert_ops = [ - tf.debugging.assert_equal(x_split, key_split) for x_split, key_split in - zip(x.nested_row_splits, key.nested_row_splits) - ] - with tf.control_dependencies(assert_ops): - return (tf.ensure_shape(tf.identity(x.flat_values), [None]), - tf.ensure_shape(tf.identity(key.flat_values), [None])) - elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.Tensor): - # Each batch instance in x corresponds to a single element in key. - x_row_indices = _get_ragged_batch_value_rowids(x) - assert_compatible = tf.debugging.assert_greater_equal( - tf.shape(key, out_type=tf.int64)[0], x.bounding_shape(axis=0)) - with tf.control_dependencies([assert_compatible]): - return (tf.ensure_shape(x.flat_values, - [None]), tf.gather(key, x_row_indices)) - else: - raise ValueError('A ragged or dense key is required if x is ragged') + x: common_types.TensorType, key: common_types.TensorType +) -> Tuple[tf.Tensor, tf.Tensor]: + """Validate x and key and returns dense representations if feasible. + + Check if sparse x and sparse key have identical indices, map key if dense. + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + key: A `Tensor` or `CompositeTensor`. Must be `Tensor` if x is `Tensor`. + + Returns: + ------- + The values of x and key if both are composite, the values of x and a mapped + key if only x is composite, or the original x and key if both are dense. + """ + if isinstance(x, tf.Tensor) and isinstance(key, tf.Tensor): + return x, key + elif isinstance(x, tf.Tensor): + raise ValueError("A dense key is required if x is dense") + + elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.SparseTensor): + assert_shape = tf.debugging.assert_equal(x.dense_shape, key.dense_shape) + assert_eq = tf.debugging.assert_equal(x.indices, key.indices) + with tf.control_dependencies([assert_eq, assert_shape]): + return tf.identity(x.values), tf.identity(key.values) + elif isinstance(x, tf.SparseTensor) and isinstance(key, tf.Tensor): + # In this case, the row of x corresponds to the key at that row. + x_row_indices = x.indices[:, 0] + assert_compatible = tf.debugging.assert_greater_equal( + tf.shape(key, out_type=tf.int64)[0], x.dense_shape[0] + ) + with tf.control_dependencies([assert_compatible]): + return x.values, tf.gather(key, x_row_indices) + elif isinstance(x, tf.SparseTensor): + raise ValueError("A sparse or dense key is required if x is sparse") + + elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.RaggedTensor): + x.shape.assert_is_compatible_with(key.shape) + assert_ops = [ + tf.debugging.assert_equal(x_split, key_split) + for x_split, key_split in zip(x.nested_row_splits, key.nested_row_splits) + ] + with tf.control_dependencies(assert_ops): + return ( + tf.ensure_shape(tf.identity(x.flat_values), [None]), + tf.ensure_shape(tf.identity(key.flat_values), [None]), + ) + elif isinstance(x, tf.RaggedTensor) and isinstance(key, tf.Tensor): + # Each batch instance in x corresponds to a single element in key. 
+ x_row_indices = _get_ragged_batch_value_rowids(x) + assert_compatible = tf.debugging.assert_greater_equal( + tf.shape(key, out_type=tf.int64)[0], x.bounding_shape(axis=0) + ) + with tf.control_dependencies([assert_compatible]): + return ( + tf.ensure_shape(x.flat_values, [None]), + tf.gather(key, x_row_indices), + ) + else: + raise ValueError("A ragged or dense key is required if x is ragged") def lookup_key(query: tf.Tensor, key_vocab: tf.Tensor) -> tf.Tensor: - """Look up the index of each element in query in key_vocab. - - Args: - query: A `Tensor`. - key_vocab: A 1-D `Tensor` of unique keys. - - Returns: - The indices of the keys in query, determined by position in key_vocab. - """ - - def _lookup_key(): - # Obtain 0-indexed int64 positions for the keys in key_vocab. - indices = tf.cast(tf.range(tf.size(key_vocab)), tf.int64) - - expanded_vocab_size = tf.expand_dims(tf.size(key_vocab), axis=0) - matrix_shape = tf.concat([expanded_vocab_size, tf.shape(query)], axis=0) - # Expand dims of key_vocab to rank of query. - vocab_shape = tf.concat( - [expanded_vocab_size, - tf.ones(tf.rank(query), dtype=tf.int32)], axis=0) - # Make copies of key_vocab to fill matrix_shape. - expand_vocab = tf.broadcast_to( - tf.reshape(key_vocab, vocab_shape), matrix_shape) - # Make copies of indices to fill matrix_shape. - expand_indices = tf.broadcast_to( - tf.reshape(indices, vocab_shape), matrix_shape) - # Make copies of query to fill matrix_shape. - expand_query = tf.broadcast_to(query, matrix_shape) - - # Indices where expand_query equals expand_vocab is set to the key's - # index. All the other indices are -1. - expand_result = tf.where( - tf.math.equal(expand_query, expand_vocab), expand_indices, - tf.cast(tf.fill(matrix_shape, -1), tf.int64)) - # Reduce matrix above to desired 1-D shape. - result = tf.math.reduce_max(expand_result, axis=0) - result.set_shape(query.shape) - return result - - def _check_vocab_size_and_lookup_key(): - return tf.cond( - tf.math.equal(tf.size(key_vocab), 0), - lambda: tf.cast(tf.fill(tf.shape(query), -1), tf.int64), _lookup_key) + """Look up the index of each element in query in key_vocab. + + Args: + ---- + query: A `Tensor`. + key_vocab: A 1-D `Tensor` of unique keys. + + Returns: + ------- + The indices of the keys in query, determined by position in key_vocab. + """ + + def _lookup_key(): + # Obtain 0-indexed int64 positions for the keys in key_vocab. + indices = tf.cast(tf.range(tf.size(key_vocab)), tf.int64) + + expanded_vocab_size = tf.expand_dims(tf.size(key_vocab), axis=0) + matrix_shape = tf.concat([expanded_vocab_size, tf.shape(query)], axis=0) + # Expand dims of key_vocab to rank of query. + vocab_shape = tf.concat( + [expanded_vocab_size, tf.ones(tf.rank(query), dtype=tf.int32)], axis=0 + ) + # Make copies of key_vocab to fill matrix_shape. + expand_vocab = tf.broadcast_to(tf.reshape(key_vocab, vocab_shape), matrix_shape) + # Make copies of indices to fill matrix_shape. + expand_indices = tf.broadcast_to(tf.reshape(indices, vocab_shape), matrix_shape) + # Make copies of query to fill matrix_shape. + expand_query = tf.broadcast_to(query, matrix_shape) + + # Indices where expand_query equals expand_vocab is set to the key's + # index. All the other indices are -1. + expand_result = tf.where( + tf.math.equal(expand_query, expand_vocab), + expand_indices, + tf.cast(tf.fill(matrix_shape, -1), tf.int64), + ) + # Reduce matrix above to desired 1-D shape. 
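In the sparse-x / dense-key branch above, every sparse value simply inherits the key of its batch row. A small eager sketch of that alignment with made-up values (assuming TF 2.x):

import tensorflow as tf

# Row 0 holds two values, row 1 holds one; one key per batch row.
x = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1]],
                    values=[10.0, 11.0, 12.0],
                    dense_shape=[2, 3])
key = tf.constant(["a", "b"])

# Mirror of the branch above: the row index of each sparse value picks its key.
row_ids = x.indices[:, 0]
aligned_keys = tf.gather(key, row_ids)
print(x.values.numpy(), aligned_keys.numpy())  # [10. 11. 12.] [b'a' b'a' b'b']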
+ result = tf.math.reduce_max(expand_result, axis=0) + result.set_shape(query.shape) + return result + + def _check_vocab_size_and_lookup_key(): + return tf.cond( + tf.math.equal(tf.size(key_vocab), 0), + lambda: tf.cast(tf.fill(tf.shape(query), -1), tf.int64), + _lookup_key, + ) - def _check_input_size_and_lookup_key(): - return tf.cond( - tf.math.equal(tf.size(query), - 0), lambda: tf.constant([], dtype=tf.int64), - _check_vocab_size_and_lookup_key) + def _check_input_size_and_lookup_key(): + return tf.cond( + tf.math.equal(tf.size(query), 0), + lambda: tf.constant([], dtype=tf.int64), + _check_vocab_size_and_lookup_key, + ) - return _check_input_size_and_lookup_key() + return _check_input_size_and_lookup_key() def _align_dims(tensor: tf.Tensor, target_ndims: int) -> tf.Tensor: - """Expand the rank of input tensor until it matches the target rank. + """Expand the rank of input tensor until it matches the target rank. + + Non-elementwise per-key reduce returns a tensor with rank 1 (batch). + The dimension count needs to match with x to finish the final mapping, because + we want to broadcast each reduction with x. To do so we need to add singleton + dimensions, otherwise TF will try to broadcast along the wrong dimensions. + + Args: + ---- + tensor: A `Tensor`. + target_ndims: The count of dims we want the output to meet or exceed. + + Returns: + ------- + The original input, with dimension count >= target_ndims. + """ + if target_ndims is None or target_ndims <= tensor.get_shape().ndims: + return tensor + for _ in range(target_ndims - tensor.get_shape().ndims): + tensor = tf.expand_dims(tensor, -1) + return tensor - Non-elementwise per-key reduce returns a tensor with rank 1 (batch). - The dimension count needs to match with x to finish the final mapping, because - we want to broadcast each reduction with x. To do so we need to add singleton - dimensions, otherwise TF will try to broadcast along the wrong dimensions. - Args: - tensor: A `Tensor`. - target_ndims: The count of dims we want the output to meet or exceed. +def map_per_key_reductions( + tensors_to_map: Tuple[tf.Tensor, ...], + key: common_types.TensorType, + key_vocab: tf.Tensor, + original_input: common_types.TensorType, + reduce_instance_dims: bool, +) -> Tuple[tf.Tensor, ...]: + """Rearrange the reduced per-key result to correspond to the original keys. + + Args: + ---- + tensors_to_map: A tuple of 1-D `Tensor`s that are same shape as key_vocab, + to be mapped to respective key. + key: A `Tensor` or `CompositeTensor`. + key_vocab: A 1-D `Tensor`. + original_input: A `Tensor` or `CompositeTensor`. + reduce_instance_dims: A `bool`. True if tensors_to_map are reduced in + dimension, else False. + + Returns: + ------- + A tuple same length as tensors_to_map, of `Tensor`s the same dimension as + original_input. We are mapping using the key for each original_input, + but output rank needs to match original_input in the dense case. + For the sparse case, it is enough for output to match original_input.values. + Any missing key would result in a mapping to 0. + """ + _, key = _validate_and_get_dense_value_key_inputs(original_input, key) + key_indices = lookup_key(key, key_vocab) + + ndims = ( + None + if isinstance(original_input, (tf.SparseTensor, tf.RaggedTensor)) + else original_input.get_shape().ndims + ) - Returns: - The original input, with dimension count >= target_ndims. 
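A short usage sketch for `lookup_key` as defined above, assuming eager TF 2.x and that `tensorflow_transform.tf_utils` is importable (as in the test module at the end of this patch); the query and vocabulary values are made up:

import tensorflow as tf
from tensorflow_transform import tf_utils

key_vocab = tf.constant(["a", "b", "c"])
query = tf.constant([["b", "z"], ["a", "c"]])

# Out-of-vocabulary entries ("z") come back as -1.
indices = tf_utils.lookup_key(query, key_vocab)
print(indices.numpy())  # [[ 1 -1]
                        #  [ 0  2]]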
- """ - if target_ndims is None or target_ndims <= tensor.get_shape().ndims: - return tensor - for _ in range(target_ndims - tensor.get_shape().ndims): - tensor = tf.expand_dims(tensor, -1) - return tensor - - -def map_per_key_reductions(tensors_to_map: Tuple[tf.Tensor, ...], - key: common_types.TensorType, key_vocab: tf.Tensor, - original_input: common_types.TensorType, - reduce_instance_dims: bool) -> Tuple[tf.Tensor, ...]: - """Rearrange the reduced per-key result to correspond to the original keys. - - Args: - tensors_to_map: A tuple of 1-D `Tensor`s that are same shape as key_vocab, - to be mapped to respective key. - key: A `Tensor` or `CompositeTensor`. - key_vocab: A 1-D `Tensor`. - original_input: A `Tensor` or `CompositeTensor`. - reduce_instance_dims: A `bool`. True if tensors_to_map are reduced in - dimension, else False. - - Returns: - A tuple same length as tensors_to_map, of `Tensor`s the same dimension as - original_input. We are mapping using the key for each original_input, - but output rank needs to match original_input in the dense case. - For the sparse case, it is enough for output to match original_input.values. - Any missing key would result in a mapping to 0. - """ - - _, key = _validate_and_get_dense_value_key_inputs(original_input, key) - key_indices = lookup_key(key, key_vocab) - - ndims = (None if isinstance(original_input, - (tf.SparseTensor, tf.RaggedTensor)) else - original_input.get_shape().ndims) - - # Append 0s to allow mapping OOVs to it. - tensors_to_map = [ - tf.concat([t, tf.expand_dims(tf.zeros_like(t[0]), 0)], axis=0) - for t in tensors_to_map - ] - - # Replace `-1`s due to OOV with size of key_vocab. - adjusted_indices = tf.where( - key_indices >= 0, key_indices, - tf.cast( - tf.fill(tf.shape(key_indices), tf.size(key_vocab)), dtype=tf.int64)) - axis = -1 if reduce_instance_dims else 0 - mapped_result = [ - _align_dims(tf.gather(t, adjusted_indices, axis=axis), ndims) - for t in tensors_to_map - ] - - return tuple(mapped_result) + # Append 0s to allow mapping OOVs to it. + tensors_to_map = [ + tf.concat([t, tf.expand_dims(tf.zeros_like(t[0]), 0)], axis=0) + for t in tensors_to_map + ] + + # Replace `-1`s due to OOV with size of key_vocab. + adjusted_indices = tf.where( + key_indices >= 0, + key_indices, + tf.cast(tf.fill(tf.shape(key_indices), tf.size(key_vocab)), dtype=tf.int64), + ) + axis = -1 if reduce_instance_dims else 0 + mapped_result = [ + _align_dims(tf.gather(t, adjusted_indices, axis=axis), ndims) + for t in tensors_to_map + ] + + return tuple(mapped_result) def reduce_batch_count_mean_and_var_per_key( - x: common_types.TensorType, key: common_types.TensorType, - reduce_instance_dims: bool + x: common_types.TensorType, key: common_types.TensorType, reduce_instance_dims: bool ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes per-key element count, mean and var for the given tensor. - - Args: - x: A `Tensor` or `CompositeTensor`. - key: A `Tensor` or `CompositeTensor` (cannot be None). - Must meet one of the following conditions: - 1. Both x and key are dense, - 2. Both x and key are composite and `key` must exactly match `x` in - everything except values, - 3. The axis=1 index of each element of sparse x matches its index of - dense key. - reduce_instance_dims: A bool, if True - collapses the batch and instance - dimensions to arrive at a single scalar output. Otherwise, only - collapses the batch dimension and outputs a `Tensor` of the same shape - as the input. Not supported for `CompositeTensor`s. 
- - Returns: - A 4-tuple containing the `Tensor`s (key_vocab, count, mean, var). NaNs and - infinite input values are ignored. - """ - - if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): - if not reduce_instance_dims: - raise NotImplementedError( - 'Mean and var per key only support reduced dims for CompositeTensors') - - x, key = _validate_and_get_dense_value_key_inputs(x, key) - - unique = tf.unique(key, out_idx=tf.int64) - x_is_finite = _is_finite(x) - - finite_x = tf.where(x_is_finite, x, tf.zeros_like(x)) - if reduce_instance_dims: - x_count = tf.cast(x_is_finite, x.dtype) - if x.get_shape().ndims != 1: - x_count = tf.reduce_sum(x_count, axis=1) - x_count = tf.math.unsorted_segment_sum(x_count, unique.idx, - tf.size(unique.y)) - sums = ( - tf.reduce_sum(finite_x, axis=1) - if x.get_shape().ndims != 1 else finite_x) - sums = tf.math.unsorted_segment_sum(sums, unique.idx, tf.size(unique.y)) - else: - sums = tf.math.unsorted_segment_sum(finite_x, unique.idx, tf.size(unique.y)) - x_count = tf.math.unsorted_segment_sum( - tf.cast(x_is_finite, tf.float32), unique.idx, tf.size(unique.y)) - - means = tf.math.divide_no_nan(tf.cast(sums, x.dtype), x_count) - sum_sqs = tf.math.unsorted_segment_sum( - tf.square(finite_x), unique.idx, tf.size(input=unique.y)) - if sum_sqs.get_shape().ndims != 1 and reduce_instance_dims: - sum_sqs = tf.reduce_sum(sum_sqs, axis=1) - - variances = tf.math.divide_no_nan(sum_sqs, x_count) - tf.square(means) - - return unique.y, tf.cast(x_count, tf.int64), means, variances + """Computes per-key element count, mean and var for the given tensor. + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + key: A `Tensor` or `CompositeTensor` (cannot be None). + Must meet one of the following conditions: + 1. Both x and key are dense, + 2. Both x and key are composite and `key` must exactly match `x` in + everything except values, + 3. The axis=1 index of each element of sparse x matches its index of + dense key. + reduce_instance_dims: A bool, if True - collapses the batch and instance + dimensions to arrive at a single scalar output. Otherwise, only + collapses the batch dimension and outputs a `Tensor` of the same shape + as the input. Not supported for `CompositeTensor`s. + + Returns: + ------- + A 4-tuple containing the `Tensor`s (key_vocab, count, mean, var). NaNs and + infinite input values are ignored. 
+ """ + if isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + if not reduce_instance_dims: + raise NotImplementedError( + "Mean and var per key only support reduced dims for CompositeTensors" + ) + + x, key = _validate_and_get_dense_value_key_inputs(x, key) + + unique = tf.unique(key, out_idx=tf.int64) + x_is_finite = _is_finite(x) + + finite_x = tf.where(x_is_finite, x, tf.zeros_like(x)) + if reduce_instance_dims: + x_count = tf.cast(x_is_finite, x.dtype) + if x.get_shape().ndims != 1: + x_count = tf.reduce_sum(x_count, axis=1) + x_count = tf.math.unsorted_segment_sum(x_count, unique.idx, tf.size(unique.y)) + sums = tf.reduce_sum(finite_x, axis=1) if x.get_shape().ndims != 1 else finite_x + sums = tf.math.unsorted_segment_sum(sums, unique.idx, tf.size(unique.y)) + else: + sums = tf.math.unsorted_segment_sum(finite_x, unique.idx, tf.size(unique.y)) + x_count = tf.math.unsorted_segment_sum( + tf.cast(x_is_finite, tf.float32), unique.idx, tf.size(unique.y) + ) + + means = tf.math.divide_no_nan(tf.cast(sums, x.dtype), x_count) + sum_sqs = tf.math.unsorted_segment_sum( + tf.square(finite_x), unique.idx, tf.size(input=unique.y) + ) + if sum_sqs.get_shape().ndims != 1 and reduce_instance_dims: + sum_sqs = tf.reduce_sum(sum_sqs, axis=1) + + variances = tf.math.divide_no_nan(sum_sqs, x_count) - tf.square(means) + + return unique.y, tf.cast(x_count, tf.int64), means, variances # Code for serializing and example proto -_DEFAULT_VALUE_BY_DTYPE = { - tf.string: '', - tf.float32: 0, - tf.int64: 0 -} - - -def _encode_proto(values_dict, message_type, descriptor_source=''): - """A wrapper around tf.raw_ops.EncodeProto.""" - field_names = [] - sizes = [] - values = [] - for field_name, value in sorted(values_dict.items(), key=lambda x: x[0]): - if isinstance(value, tf.SparseTensor): - size = tf.sparse.reduce_sum( - tf.SparseTensor(value.indices, - tf.ones_like(value.values, dtype=tf.int32), - value.dense_shape), - axis=1) - value = tf.sparse.to_dense(value, _DEFAULT_VALUE_BY_DTYPE[value.dtype]) - else: - value = tf.reshape(value, [tf.shape(input=value)[0], -1]) - size = tf.fill((tf.shape(input=value)[0],), tf.shape(input=value)[1]) - field_names.append(field_name) - values.append(value) - sizes.append(size) +_DEFAULT_VALUE_BY_DTYPE = {tf.string: "", tf.float32: 0, tf.int64: 0} - sizes = tf.stack(sizes, axis=1) - return tf.raw_ops.EncodeProto( - sizes=sizes, - values=values, - field_names=field_names, - message_type=message_type, - descriptor_source=descriptor_source) + +def _encode_proto(values_dict, message_type, descriptor_source=""): + """A wrapper around tf.raw_ops.EncodeProto.""" + field_names = [] + sizes = [] + values = [] + for field_name, value in sorted(values_dict.items(), key=lambda x: x[0]): + if isinstance(value, tf.SparseTensor): + size = tf.sparse.reduce_sum( + tf.SparseTensor( + value.indices, + tf.ones_like(value.values, dtype=tf.int32), + value.dense_shape, + ), + axis=1, + ) + value = tf.sparse.to_dense(value, _DEFAULT_VALUE_BY_DTYPE[value.dtype]) + else: + value = tf.reshape(value, [tf.shape(input=value)[0], -1]) + size = tf.fill((tf.shape(input=value)[0],), tf.shape(input=value)[1]) + field_names.append(field_name) + values.append(value) + sizes.append(size) + + sizes = tf.stack(sizes, axis=1) + return tf.raw_ops.EncodeProto( + sizes=sizes, + values=values, + field_names=field_names, + message_type=message_type, + descriptor_source=descriptor_source, + ) def _serialize_feature(values): - """Serialize a Tensor or SparseTensor as `Feature` protos. 
- - `values` should be a Tensor of rank >=1 or SparseTensor of rank 2. We will - refer to the size of the first dimension as batch_size. - - This function encodes each row of the `Tensor` as a list of values (flattening - the other dimensions) and each row of the `SparseTensor` as a list of values, - where the indices within each row are ignored and assumed to be 0, 1, .... - - Args: - values: A `Tensor` or `SparseTensor`. - - Returns: - A tensor of shape (batch_size,) and type `tf.string` where each element is - a serialized `Feature` proto. - - Raises: - ValueError: If the dtype is of `values` is not `tf.string`, `tf.float32` - or `tf.int64`. - """ - values = tf.compat.v1.convert_to_tensor_or_sparse_tensor(values) - if values.dtype == tf.string: - values_dict = { - 'bytes_list': _encode_proto({'value': values}, 'tensorflow.BytesList') - } - elif values.dtype == tf.float32: - values_dict = { - 'float_list': _encode_proto({'value': values}, 'tensorflow.FloatList') - } - elif values.dtype == tf.int64: - values_dict = { - 'int64_list': _encode_proto({'value': values}, 'tensorflow.Int64List') - } - else: - raise ValueError('Cannot encode values of dtype {}'.format(values.dtype)) - return _encode_proto(values_dict, 'tensorflow.Feature') + """Serialize a Tensor or SparseTensor as `Feature` protos. + + `values` should be a Tensor of rank >=1 or SparseTensor of rank 2. We will + refer to the size of the first dimension as batch_size. + + This function encodes each row of the `Tensor` as a list of values (flattening + the other dimensions) and each row of the `SparseTensor` as a list of values, + where the indices within each row are ignored and assumed to be 0, 1, .... + + Args: + ---- + values: A `Tensor` or `SparseTensor`. + + Returns: + ------- + A tensor of shape (batch_size,) and type `tf.string` where each element is + a serialized `Feature` proto. + + Raises: + ------ + ValueError: If the dtype is of `values` is not `tf.string`, `tf.float32` + or `tf.int64`. + """ + values = tf.compat.v1.convert_to_tensor_or_sparse_tensor(values) + if values.dtype == tf.string: + values_dict = { + "bytes_list": _encode_proto({"value": values}, "tensorflow.BytesList") + } + elif values.dtype == tf.float32: + values_dict = { + "float_list": _encode_proto({"value": values}, "tensorflow.FloatList") + } + elif values.dtype == tf.int64: + values_dict = { + "int64_list": _encode_proto({"value": values}, "tensorflow.Int64List") + } + else: + raise ValueError(f"Cannot encode values of dtype {values.dtype}") + return _encode_proto(values_dict, "tensorflow.Feature") def serialize_example(features): - """Serialized a dict of `Tensor` or `SparseTensor`s as example protos. - - `features` should be a dict where each value is a Tensor of rank >=1 or - SparseTensor of rank 2. The sizes of the first dimension of each value should - be the same, and we refer to this size as batch_size. - - Args: - features: A dictionary whose values are `Tensor`s or `SparseTensor`s. - - Returns: - A tensor of shape (batch_size,) and type `tf.string` where each element is - a serialized `Example` proto. 
- """ - features_dict = [] - for key, value in sorted(features.items(), key=lambda x: x[0]): - serialized_value = _serialize_feature(value) - features_dict.append( - _encode_proto({ - 'key': tf.fill((tf.shape(input=serialized_value)[0],), key), - 'value': serialized_value, - }, 'tensorflow.Features.FeatureEntry')) - features_dict = tf.stack(features_dict, axis=1) - features = _encode_proto({'feature': features_dict}, 'tensorflow.Features') - return _encode_proto({'features': features}, 'tensorflow.Example') + """Serialized a dict of `Tensor` or `SparseTensor`s as example protos. + + `features` should be a dict where each value is a Tensor of rank >=1 or + SparseTensor of rank 2. The sizes of the first dimension of each value should + be the same, and we refer to this size as batch_size. + + Args: + ---- + features: A dictionary whose values are `Tensor`s or `SparseTensor`s. + + Returns: + ------- + A tensor of shape (batch_size,) and type `tf.string` where each element is + a serialized `Example` proto. + """ + features_dict = [] + for key, value in sorted(features.items(), key=lambda x: x[0]): + serialized_value = _serialize_feature(value) + features_dict.append( + _encode_proto( + { + "key": tf.fill((tf.shape(input=serialized_value)[0],), key), + "value": serialized_value, + }, + "tensorflow.Features.FeatureEntry", + ) + ) + features_dict = tf.stack(features_dict, axis=1) + features = _encode_proto({"feature": features_dict}, "tensorflow.Features") + return _encode_proto({"features": features}, "tensorflow.Example") def _get_missing_value(dtype: tf.DType) -> tf.Tensor: - if dtype.is_floating: - return tf.constant(_FLOATING_NAN, dtype) - else: - return tf.constant(dtype.min + 1, dtype) + if dtype.is_floating: + return tf.constant(_FLOATING_NAN, dtype) + else: + return tf.constant(dtype.min + 1, dtype) def _sparse_minus_reduce_min_and_reduce_max( - x: tf.SparseTensor) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes the -min and max of a SparseTensor x. - - It differs from sparse_reduce_max in that sparse_reduce_max returns 0 when all - elements are missing along axis 0. - We replace the 0 with NaN when x's dtype is float and dtype.min+1 when it's - int. - - Args: - x: A `SparseTensor`. - - Returns: - Two `Tensors' which are the -min and max. - - Raises: - TypeError: If the type of `x` is not supported. - """ - minus_x = tf.SparseTensor( - indices=x.indices, values=0 - x.values, dense_shape=x.dense_shape) - x_count = reduce_batch_count(x, reduce_instance_dims=False) - batch_has_no_values = tf.equal(x_count, tf.constant(0, dtype=tf.int64)) - x_batch_max = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_max, x) - x_batch_minus_min = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_max, - minus_x) - missing_value = _get_missing_value(x.dtype) - x_batch_max = tf.where(batch_has_no_values, - tf.fill(tf.shape(input=x_batch_max), missing_value), - x_batch_max) - x_batch_minus_min = tf.where( - batch_has_no_values, - tf.fill(tf.shape(input=x_batch_minus_min), missing_value), - x_batch_minus_min) - return x_batch_minus_min, x_batch_max + x: tf.SparseTensor, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the -min and max of a SparseTensor x. + + It differs from sparse_reduce_max in that sparse_reduce_max returns 0 when all + elements are missing along axis 0. + We replace the 0 with NaN when x's dtype is float and dtype.min+1 when it's + int. + + Args: + ---- + x: A `SparseTensor`. + + Returns: + ------- + Two `Tensors' which are the -min and max. 
+ + Raises: + ------ + TypeError: If the type of `x` is not supported. + """ + minus_x = tf.SparseTensor( + indices=x.indices, values=0 - x.values, dense_shape=x.dense_shape + ) + x_count = reduce_batch_count(x, reduce_instance_dims=False) + batch_has_no_values = tf.equal(x_count, tf.constant(0, dtype=tf.int64)) + x_batch_max = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_max, x) + x_batch_minus_min = _sparse_reduce_batch_keep_shape(tf.sparse.reduce_max, minus_x) + missing_value = _get_missing_value(x.dtype) + x_batch_max = tf.where( + batch_has_no_values, + tf.fill(tf.shape(input=x_batch_max), missing_value), + x_batch_max, + ) + x_batch_minus_min = tf.where( + batch_has_no_values, + tf.fill(tf.shape(input=x_batch_minus_min), missing_value), + x_batch_minus_min, + ) + return x_batch_minus_min, x_batch_max def reduce_batch_minus_min_and_max( - x: common_types.TensorType, - reduce_instance_dims: bool) -> Tuple[tf.Tensor, tf.Tensor]: - """Computes the -min and max of a tensor x. - - NOTE: For TF versions < 2.4, if all feature values are NaNs, the -min and max - will both be -inf (consistent with`tf.reduce_max`). - - Args: - x: A `Tensor` or `CompositeTensor`. - reduce_instance_dims: A bool indicating whether this should collapse the - batch and instance dimensions to arrive at a single scalar output, or only - collapse the batch dimension and outputs a vector of the same shape as the - input. - - Returns: - The computed tensor's (batch -min, batch max) pair. - """ - # In TF < 2.3, neg(x) would throw an exception, if x was tf.int16. Hence, cast - # to tf.int32. - if x.dtype in (tf.uint8, tf.uint16, tf.int16): - x = tf.cast(x, tf.int32) - - elif x.dtype == tf.uint32 or x.dtype == tf.uint64: - raise TypeError('Tensor type %r is not supported' % x.dtype) - - if reduce_instance_dims: - if isinstance(x, tf.SparseTensor): - x = x.values - elif isinstance(x, tf.RaggedTensor): - x = x.flat_values - - x_batch_max = tf.reduce_max(input_tensor=x) - x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x) - return x_batch_minus_min, x_batch_max - - elif isinstance(x, tf.SparseTensor): - return _sparse_minus_reduce_min_and_reduce_max(x) + x: common_types.TensorType, reduce_instance_dims: bool +) -> Tuple[tf.Tensor, tf.Tensor]: + """Computes the -min and max of a tensor x. + + NOTE: For TF versions < 2.4, if all feature values are NaNs, the -min and max + will both be -inf (consistent with`tf.reduce_max`). + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + reduce_instance_dims: A bool indicating whether this should collapse the + batch and instance dimensions to arrive at a single scalar output, or only + collapse the batch dimension and outputs a vector of the same shape as the + input. + + Returns: + ------- + The computed tensor's (batch -min, batch max) pair. + """ + # In TF < 2.3, neg(x) would throw an exception, if x was tf.int16. Hence, cast + # to tf.int32. 
+ if x.dtype in (tf.uint8, tf.uint16, tf.int16): + x = tf.cast(x, tf.int32) + + elif x.dtype == tf.uint32 or x.dtype == tf.uint64: + raise TypeError("Tensor type %r is not supported" % x.dtype) - x_batch_max = tf.reduce_max(input_tensor=x, axis=0) - if isinstance(x, tf.RaggedTensor): - x_batch_minus_min = tf.reduce_max(input_tensor=tf.math.negative(x), axis=0) - missing_value = _get_missing_value(x.dtype) - return (x_batch_minus_min.to_tensor(default_value=missing_value), - x_batch_max.to_tensor(default_value=missing_value)) - else: - # TODO(iindyk): switch to `tf.math.negative` when analyzer cache will get - # invalidated next time. - return (tf.reduce_max(input_tensor=0 - x, axis=0), x_batch_max) + if reduce_instance_dims: + if isinstance(x, tf.SparseTensor): + x = x.values + elif isinstance(x, tf.RaggedTensor): + x = x.flat_values + + x_batch_max = tf.reduce_max(input_tensor=x) + x_batch_minus_min = tf.reduce_max(input_tensor=tf.zeros_like(x) - x) + return x_batch_minus_min, x_batch_max + + elif isinstance(x, tf.SparseTensor): + return _sparse_minus_reduce_min_and_reduce_max(x) + + x_batch_max = tf.reduce_max(input_tensor=x, axis=0) + if isinstance(x, tf.RaggedTensor): + x_batch_minus_min = tf.reduce_max(input_tensor=tf.math.negative(x), axis=0) + missing_value = _get_missing_value(x.dtype) + return ( + x_batch_minus_min.to_tensor(default_value=missing_value), + x_batch_max.to_tensor(default_value=missing_value), + ) + else: + # TODO(iindyk): switch to `tf.math.negative` when analyzer cache will get + # invalidated next time. + return (tf.reduce_max(input_tensor=0 - x, axis=0), x_batch_max) def reduce_batch_minus_min_and_max_per_key( x: common_types.TensorType, key: common_types.TensorType, - reduce_instance_dims: bool = True + reduce_instance_dims: bool = True, ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: - """Computes the -min and max of a tensor x. - - Args: - x: A `Tensor` or `CompositeTensor`. - key: A `Tensor` or `CompositeTensor`. Must meet one of the following - conditions - 1. Both x and key are dense, 2. Both x and key are composite - and `key` must exactly match `x` in everything except values, 3. The - axis=1 index of each element of sparse x matches its index of dense key. - reduce_instance_dims: A bool indicating whether this should collapse the - batch and instance dimensions to arrive at a single scalar output, or only - collapse the batch dimension and outputs a vector of the same shape as the - input. - Returns: - A 3-tuple containing the `Tensor`s (key_vocab, min_per_key, max_per_key). - """ - if x.dtype == tf.uint8 or x.dtype == tf.uint16: - x = tf.cast(x, tf.int32) - - elif x.dtype == tf.uint32 or x.dtype == tf.uint64: - raise TypeError('Tensor type %r is not supported' % x.dtype) - - if not reduce_instance_dims and isinstance( - x, (tf.SparseTensor, tf.RaggedTensor)): - raise NotImplementedError( - 'Elementwise reduction of composite tensors is not supported' - ) + """Computes the -min and max of a tensor x. + + Args: + ---- + x: A `Tensor` or `CompositeTensor`. + key: A `Tensor` or `CompositeTensor`. Must meet one of the following + conditions - 1. Both x and key are dense, 2. Both x and key are composite + and `key` must exactly match `x` in everything except values, 3. The + axis=1 index of each element of sparse x matches its index of dense key. 
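Returning `(max(-x), max(x))` rather than `(min, max)` means partial results from different batches can be merged with a single elementwise maximum; the final minimum is just the negation of the first component. A small eager sketch with made-up batches:

import tensorflow as tf
from tensorflow_transform import tf_utils

batch_1 = tf.constant([[1.0, -3.0], [2.0, 5.0]])
batch_2 = tf.constant([[0.5, 7.0]])

acc_1 = tf_utils.reduce_batch_minus_min_and_max(batch_1, reduce_instance_dims=False)
acc_2 = tf_utils.reduce_batch_minus_min_and_max(batch_2, reduce_instance_dims=False)

# Merging accumulators is an elementwise max; negate to recover the minima.
minus_min = tf.maximum(acc_1[0], acc_2[0])
global_max = tf.maximum(acc_1[1], acc_2[1])
print((-minus_min).numpy(), global_max.numpy())  # [ 0.5 -3. ] [2. 7.]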
+ reduce_instance_dims: A bool indicating whether this should collapse the + batch and instance dimensions to arrive at a single scalar output, or only + collapse the batch dimension and outputs a vector of the same shape as the + input. + + Returns: + ------- + A 3-tuple containing the `Tensor`s (key_vocab, min_per_key, max_per_key). + """ + if x.dtype == tf.uint8 or x.dtype == tf.uint16: + x = tf.cast(x, tf.int32) + + elif x.dtype == tf.uint32 or x.dtype == tf.uint64: + raise TypeError("Tensor type %r is not supported" % x.dtype) + + if not reduce_instance_dims and isinstance(x, (tf.SparseTensor, tf.RaggedTensor)): + raise NotImplementedError( + "Elementwise reduction of composite tensors is not supported" + ) - x, key = _validate_and_get_dense_value_key_inputs(x, key) + x, key = _validate_and_get_dense_value_key_inputs(x, key) - def get_batch_max_per_key(tensor, key_uniques): # pylint: disable=missing-docstring - if not reduce_instance_dims or tensor.get_shape().ndims < 2: - row_maxes = tensor - else: - row_maxes = tf.reduce_max( - tensor, axis=tf.range(1, tensor.get_shape().ndims)) - return tf.math.unsorted_segment_max(row_maxes, key_uniques.idx, - tf.size(input=key_uniques.y)) + def get_batch_max_per_key(tensor, key_uniques): # pylint: disable=missing-docstring + if not reduce_instance_dims or tensor.get_shape().ndims < 2: + row_maxes = tensor + else: + row_maxes = tf.reduce_max( + tensor, axis=tf.range(1, tensor.get_shape().ndims) + ) + return tf.math.unsorted_segment_max( + row_maxes, key_uniques.idx, tf.size(input=key_uniques.y) + ) - unique = tf.unique_with_counts(key, out_idx=tf.int64) - x_batch_maxes = get_batch_max_per_key(x, unique) - x_batch_minus_mins = get_batch_max_per_key(-x, unique) + unique = tf.unique_with_counts(key, out_idx=tf.int64) + x_batch_maxes = get_batch_max_per_key(x, unique) + x_batch_minus_mins = get_batch_max_per_key(-x, unique) - return (unique.y, x_batch_minus_mins, x_batch_maxes) + return (unique.y, x_batch_minus_mins, x_batch_maxes) -def track_asset_analyzer_output(eager_asset_path: ops.EagerTensor, - graph_tensor: tf.Tensor): - """Track `graph_tensor` representing analyzer output written to `eager_asset_path`.""" - graph = ops.get_default_graph() - graph.add_to_collection( - _ASSET_REPLACEMENTS, - (hashable_tensor_or_op(graph_tensor), eager_asset_path)) +def track_asset_analyzer_output( + eager_asset_path: ops.EagerTensor, graph_tensor: tf.Tensor +): + """Track `graph_tensor` representing analyzer output written to `eager_asset_path`.""" + graph = ops.get_default_graph() + graph.add_to_collection( + _ASSET_REPLACEMENTS, (hashable_tensor_or_op(graph_tensor), eager_asset_path) + ) def _get_asset_analyzer_output_and_control_dependency( - asset_filepath: _AssetFileType + asset_filepath: _AssetFileType, ) -> Tuple[_AssetFileType, Optional[tf.Tensor]]: - """Returns a tuple of (asset filepath, control dependency).""" - control_dependency = None - asset_replacements_coll = ops.get_default_graph().get_collection( - _ASSET_REPLACEMENTS) - if not asset_replacements_coll: - return asset_filepath, control_dependency + """Returns a tuple of (asset filepath, control dependency).""" + control_dependency = None + asset_replacements_coll = ops.get_default_graph().get_collection( + _ASSET_REPLACEMENTS + ) + if not asset_replacements_coll: + return asset_filepath, control_dependency - if not isinstance(asset_filepath, tf.Tensor): - raise ValueError('Expected asset_filepath ({}) to be a tf.Tensor.'.format( - asset_filepath)) - eager_asset_filepath = 
dict(asset_replacements_coll).get( - hashable_tensor_or_op(asset_filepath), None) - if eager_asset_filepath: - control_dependency = asset_filepath - asset_filepath = eager_asset_filepath - return asset_filepath, control_dependency - - -def _lookup_table(table: lookup_ops.LookupInterface, x: tf.Tensor, - control_dependency: Optional[tf.Tensor]) -> tf.Tensor: - """Look up x in table with an optional depndency on control_dependency.""" - with contextlib.ExitStack() as stack: - # tf.control_dependencies([tensor]) adds a dependency to tensor.op. Wrap the - # tensor in an identity op to ensure that walking the graph from `result` - # encounters the control_dependency tensor. - if control_dependency is not None: - stack.enter_context( - tf.control_dependencies([tf.identity(control_dependency)])) - result = table.lookup(x) - return result - - -def construct_and_lookup_table(construct_table_callable: Callable[ - [_AssetFileType], lookup_ops.LookupInterface], - asset_filepath: _AssetFileType, - x: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: - """Construct a table and look x up in it. - - Args: - construct_table_callable: A Callable that takes a path to an asset file and - constructs a lookup table. - asset_filepath: Path to an asset used to construct the table. Can be a - python string, a `tf.Tensor`, a `tf.Placeholder`. - x: A categorical `Tensor` of type tf.string or tf.int[8|16|32|64] to which - the table lookup should be applied. - - Returns: - A tuple of the result from looking x up in a table and the table's size. - - """ - # If table is lifted into an initialization scope, add a control dependency - # on the graph tensor used to track this analyzer in - # `analyzer_nodes.TENSOR_REPLACEMENTS`. - asset_filepath, control_dependency = ( - _get_asset_analyzer_output_and_control_dependency(asset_filepath)) - with contextlib.ExitStack() as stack: - if (tf.inside_function() and - isinstance(asset_filepath, (ops.EagerTensor, str))): - # Lift the table initialization out of graph construction to avoid - # repeated initialization in TF2. - stack.enter_context(tf.init_scope()) - - table = construct_table_callable(asset_filepath) - table_size = table.size() - return _lookup_table(table, x, control_dependency), table_size - - -def lookup_table(lookup_fn: Callable[[common_types.TensorType, tf.Tensor], - Tuple[tf.Tensor, tf.Tensor]], - asset_filepath: _AssetFileType, x: common_types.TensorType): - """Takes a `lookup_fn` and invokes it on `x` and `asset_filepath`. - - If an eager tensor is being tracked by `asset_filepath`, `lookup_fn` is - invoked on it instead. - - Args: - lookup_fn: A Callable that should take a tensor and a deferred vocab - filename as an input and return a lookup `op` along with the table size. - asset_filepath: Path to an asset used to construct the table. Can be a - python string, a `tf.Tensor`, a `tf.Placeholder`. - x: A categorical `Tensor` or `SparseTensor` of type tf.string or - tf.int[8|16|32|64] to which the table lookup should be applied. - - Returns: - A tuple of the result from looking x up and the table size. - """ - # If table is lifted into an initialization scope, add a control dependency - # on the graph tensor used to track this analyzer in - # `analyzer_nodes.TENSOR_REPLACEMENTS`. - asset_filepath, control_dependency = ( - _get_asset_analyzer_output_and_control_dependency(asset_filepath)) - lookup_result, table_size = lookup_fn(x, asset_filepath) - with contextlib.ExitStack() as stack: - # tf.control_dependencies([tensor]) adds a dependency to tensor.op. 
Wrap the - # `lookup_result` in an identity op to ensure that walking the graph from - # it encounters the `control_dependency` tensor. The table size should not - # have the `control_dependency` tensor as its parent, hence it is returned - # as is. - if control_dependency is not None: - stack.enter_context( - tf.control_dependencies([tf.identity(control_dependency)])) - return tf.identity(lookup_result), table_size - - -def to_vocab_range(x: tf.SparseTensor, - vocab_size: Union[int, tf.Tensor]) -> tf.SparseTensor: - """Mods x's int values to enforce that the vocab_ids in x are in range. - - Args: - x: A int-valued SparseTensor typically representing the vocab indices of - terms. This is usually the output of tft.compute_and_apply_vocabulary. - vocab_size: An int or scalar tensor representing the size of vocab. Values - in x will be mod by this size to avoid negative or out-of-vocab indices. - - Returns: - A sparse tensor of the same size as x with negative or out-of-vocab values - normalized. - """ - return tf.SparseTensor( - indices=x.indices, - values=tf.math.mod(x.values, vocab_size), - dense_shape=x.dense_shape) - - -def document_frequency_to_idf(document_frequency: tf.Tensor, - corpus_size: Union[int, tf.Tensor], - smooth: bool = True, - add_baseline: bool = True) -> tf.Tensor: - """Computes inverse document frequency given document frequency. - - The inverse document frequency of a term, by default, is calculated as - 1 + log ((corpus size + 1) / (document frequency + 1)), where document - frequency is the number of documents that contain this term. - - Args: - document_frequency: A tensor storing the document frequency of each term. - corpus_size: An int or int scalar tensor representing the size of the entire - dataset, i.e., number of examples. - smooth: A bool indicating if the inverse document frequency should be - smoothed. If True, which is the default, then the idf is calculated as 1 + - log((corpus size + 1) / (document frequency of term + 1)). Otherwise, the - idf is 1 + log((corpus size) / (document frequency of term)), which could - result in a division by zero error. - add_baseline: A bool indicating if the inverse document frequency should be - added with a constant baseline 1.0. If True, which is the default, then - the idf is calculated as 1 + log(*). Otherwise, the idf is log(*) without - the constant 1 baseline. Keeping the baseline reduces the discrepancy in - idf between commonly seen terms and rare terms. - - Returns: - A tensor of the inverse document frequency of input document frequency. - """ - baseline = 1.0 if add_baseline else 0.0 - if smooth: - return ( - tf.math.log( - (tf.cast(corpus_size, dtype=tf.float32) + 1.0) - / (1.0 + tf.cast(document_frequency, dtype=tf.float32)) + if not isinstance(asset_filepath, tf.Tensor): + raise ValueError( + f"Expected asset_filepath ({asset_filepath}) to be a tf.Tensor." 
) - + baseline + eager_asset_filepath = dict(asset_replacements_coll).get( + hashable_tensor_or_op(asset_filepath), None ) - else: - return ( - tf.math.log( - tf.cast(corpus_size, dtype=tf.float32) - / (tf.cast(document_frequency, dtype=tf.float32)) - ) - + baseline + if eager_asset_filepath: + control_dependency = asset_filepath + asset_filepath = eager_asset_filepath + return asset_filepath, control_dependency + + +def _lookup_table( + table: lookup_ops.LookupInterface, + x: tf.Tensor, + control_dependency: Optional[tf.Tensor], +) -> tf.Tensor: + """Look up x in table with an optional depndency on control_dependency.""" + with contextlib.ExitStack() as stack: + # tf.control_dependencies([tensor]) adds a dependency to tensor.op. Wrap the + # tensor in an identity op to ensure that walking the graph from `result` + # encounters the control_dependency tensor. + if control_dependency is not None: + stack.enter_context( + tf.control_dependencies([tf.identity(control_dependency)]) + ) + result = table.lookup(x) + return result + + +def construct_and_lookup_table( + construct_table_callable: Callable[[_AssetFileType], lookup_ops.LookupInterface], + asset_filepath: _AssetFileType, + x: tf.Tensor, +) -> Tuple[tf.Tensor, tf.Tensor]: + """Construct a table and look x up in it. + + Args: + ---- + construct_table_callable: A Callable that takes a path to an asset file and + constructs a lookup table. + asset_filepath: Path to an asset used to construct the table. Can be a + python string, a `tf.Tensor`, a `tf.Placeholder`. + x: A categorical `Tensor` of type tf.string or tf.int[8|16|32|64] to which + the table lookup should be applied. + + Returns: + ------- + A tuple of the result from looking x up in a table and the table's size. + + """ + # If table is lifted into an initialization scope, add a control dependency + # on the graph tensor used to track this analyzer in + # `analyzer_nodes.TENSOR_REPLACEMENTS`. + asset_filepath, control_dependency = ( + _get_asset_analyzer_output_and_control_dependency(asset_filepath) + ) + with contextlib.ExitStack() as stack: + if tf.inside_function() and isinstance(asset_filepath, (ops.EagerTensor, str)): + # Lift the table initialization out of graph construction to avoid + # repeated initialization in TF2. + stack.enter_context(tf.init_scope()) + + table = construct_table_callable(asset_filepath) + table_size = table.size() + return _lookup_table(table, x, control_dependency), table_size + + +def lookup_table( + lookup_fn: Callable[ + [common_types.TensorType, tf.Tensor], Tuple[tf.Tensor, tf.Tensor] + ], + asset_filepath: _AssetFileType, + x: common_types.TensorType, +): + """Takes a `lookup_fn` and invokes it on `x` and `asset_filepath`. + + If an eager tensor is being tracked by `asset_filepath`, `lookup_fn` is + invoked on it instead. + + Args: + ---- + lookup_fn: A Callable that should take a tensor and a deferred vocab + filename as an input and return a lookup `op` along with the table size. + asset_filepath: Path to an asset used to construct the table. Can be a + python string, a `tf.Tensor`, a `tf.Placeholder`. + x: A categorical `Tensor` or `SparseTensor` of type tf.string or + tf.int[8|16|32|64] to which the table lookup should be applied. + + Returns: + ------- + A tuple of the result from looking x up and the table size. + """ + # If table is lifted into an initialization scope, add a control dependency + # on the graph tensor used to track this analyzer in + # `analyzer_nodes.TENSOR_REPLACEMENTS`. 
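A usage sketch for `construct_and_lookup_table` defined above, mirroring the `_construct_table` helper in the test module further down. The vocabulary file path and contents are made up; assumes eager TF 2.x:

import os
import tempfile

import tensorflow as tf
from tensorflow_transform import tf_utils

# Tiny tab-delimited "token<TAB>id" vocabulary file (hypothetical contents).
vocab_path = os.path.join(tempfile.mkdtemp(), "vocab.txt")
with open(vocab_path, "w") as f:
    f.write("apple\t0\nbanana\t1\n")

def _make_table(path):
    initializer = tf.lookup.TextFileInitializer(
        path, key_dtype=tf.string, key_index=0,
        value_dtype=tf.int64, value_index=1)
    return tf.lookup.StaticHashTable(initializer, default_value=-1)

x = tf.constant(["banana", "cherry"])
lookup_result, table_size = tf_utils.construct_and_lookup_table(
    _make_table, vocab_path, x)
print(lookup_result.numpy(), table_size.numpy())  # [ 1 -1] 2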
+ asset_filepath, control_dependency = ( + _get_asset_analyzer_output_and_control_dependency(asset_filepath) ) + lookup_result, table_size = lookup_fn(x, asset_filepath) + with contextlib.ExitStack() as stack: + # tf.control_dependencies([tensor]) adds a dependency to tensor.op. Wrap the + # `lookup_result` in an identity op to ensure that walking the graph from + # it encounters the `control_dependency` tensor. The table size should not + # have the `control_dependency` tensor as its parent, hence it is returned + # as is. + if control_dependency is not None: + stack.enter_context( + tf.control_dependencies([tf.identity(control_dependency)]) + ) + return tf.identity(lookup_result), table_size + + +def to_vocab_range( + x: tf.SparseTensor, vocab_size: Union[int, tf.Tensor] +) -> tf.SparseTensor: + """Mods x's int values to enforce that the vocab_ids in x are in range. + + Args: + ---- + x: A int-valued SparseTensor typically representing the vocab indices of + terms. This is usually the output of tft.compute_and_apply_vocabulary. + vocab_size: An int or scalar tensor representing the size of vocab. Values + in x will be mod by this size to avoid negative or out-of-vocab indices. + + Returns: + ------- + A sparse tensor of the same size as x with negative or out-of-vocab values + normalized. + """ + return tf.SparseTensor( + indices=x.indices, + values=tf.math.mod(x.values, vocab_size), + dense_shape=x.dense_shape, + ) + + +def document_frequency_to_idf( + document_frequency: tf.Tensor, + corpus_size: Union[int, tf.Tensor], + smooth: bool = True, + add_baseline: bool = True, +) -> tf.Tensor: + """Computes inverse document frequency given document frequency. + + The inverse document frequency of a term, by default, is calculated as + 1 + log ((corpus size + 1) / (document frequency + 1)), where document + frequency is the number of documents that contain this term. + + Args: + ---- + document_frequency: A tensor storing the document frequency of each term. + corpus_size: An int or int scalar tensor representing the size of the entire + dataset, i.e., number of examples. + smooth: A bool indicating if the inverse document frequency should be + smoothed. If True, which is the default, then the idf is calculated as 1 + + log((corpus size + 1) / (document frequency of term + 1)). Otherwise, the + idf is 1 + log((corpus size) / (document frequency of term)), which could + result in a division by zero error. + add_baseline: A bool indicating if the inverse document frequency should be + added with a constant baseline 1.0. If True, which is the default, then + the idf is calculated as 1 + log(*). Otherwise, the idf is log(*) without + the constant 1 baseline. Keeping the baseline reduces the discrepancy in + idf between commonly seen terms and rare terms. + + Returns: + ------- + A tensor of the inverse document frequency of input document frequency. 
+ """ + baseline = 1.0 if add_baseline else 0.0 + if smooth: + return ( + tf.math.log( + (tf.cast(corpus_size, dtype=tf.float32) + 1.0) + / (1.0 + tf.cast(document_frequency, dtype=tf.float32)) + ) + + baseline + ) + else: + return ( + tf.math.log( + tf.cast(corpus_size, dtype=tf.float32) + / (tf.cast(document_frequency, dtype=tf.float32)) + ) + + baseline + ) def register_vocabulary_reserved_tokens( name: str, reserved_tokens: Union[Sequence[str], tf.Tensor] ) -> tf.Tensor: - """Registers a reserved_tokens tensor to a vocabulary.""" - if not isinstance(reserved_tokens, tf.Tensor): - reserved_tokens = tf.constant(reserved_tokens, dtype=tf.string) - tf.compat.v1.add_to_collection(_VOCABULARY_RESERVED_TOKENS_IDS, name) - tf.compat.v1.add_to_collection(_VOCABULARY_RESERVED_TOKENS, reserved_tokens) - return tf.size(reserved_tokens, out_type=tf.int64) + """Registers a reserved_tokens tensor to a vocabulary.""" + if not isinstance(reserved_tokens, tf.Tensor): + reserved_tokens = tf.constant(reserved_tokens, dtype=tf.string) + tf.compat.v1.add_to_collection(_VOCABULARY_RESERVED_TOKENS_IDS, name) + tf.compat.v1.add_to_collection(_VOCABULARY_RESERVED_TOKENS, reserved_tokens) + return tf.size(reserved_tokens, out_type=tf.int64) def fetch_vocabulary_reserved_tokens(graph, name: str) -> Sequence[str]: - """Fetches an evaluated reserved_tokens tensor for a vocabulary.""" - name_collection = graph.get_collection(_VOCABULARY_RESERVED_TOKENS_IDS) - tokens_collection = graph.get_collection(_VOCABULARY_RESERVED_TOKENS) - assert len(name_collection) == len(tokens_collection) - tensor = tokens_collection[name_collection.index(name)] - with tf.compat.v1.Session(graph=graph) as session: - return session.run(tensor) + """Fetches an evaluated reserved_tokens tensor for a vocabulary.""" + name_collection = graph.get_collection(_VOCABULARY_RESERVED_TOKENS_IDS) + tokens_collection = graph.get_collection(_VOCABULARY_RESERVED_TOKENS) + assert len(name_collection) == len(tokens_collection) + tensor = tokens_collection[name_collection.index(name)] + with tf.compat.v1.Session(graph=graph) as session: + return session.run(tensor) diff --git a/tensorflow_transform/tf_utils_test.py b/tensorflow_transform/tf_utils_test.py index 2676f6f..329b13d 100644 --- a/tensorflow_transform/tf_utils_test.py +++ b/tensorflow_transform/tf_utils_test.py @@ -16,2227 +16,2734 @@ import os import numpy as np -from packaging import version import tensorflow as tf -from tensorflow_transform import analyzers -from tensorflow_transform import annotators -from tensorflow_transform import tf_utils -from tensorflow_transform import test_case +from packaging import version +from tensorflow.python.framework import ( + composite_tensor, # pylint: disable=g-direct-tensorflow-import +) -from tensorflow.python.framework import composite_tensor # pylint: disable=g-direct-tensorflow-import +from tensorflow_transform import analyzers, annotators, test_case, tf_utils _CONSTRUCT_TABLE_PARAMETERS = [ - dict(testcase_name='_string', asset_path_input_fn=lambda x: x), - dict(testcase_name='_string_tensor', asset_path_input_fn=tf.constant), + dict(testcase_name="_string", asset_path_input_fn=lambda x: x), + dict(testcase_name="_string_tensor", asset_path_input_fn=tf.constant), ] -def _construct_table(asset_file_path, - key_dtype=tf.string, - key_index=0, - value_dtype=tf.int64, - value_index=1, - default_value=-1): - initializer = tf.lookup.TextFileInitializer( - asset_file_path, - key_dtype=key_dtype, - key_index=key_index, - value_dtype=value_dtype, - 
value_index=value_index) - return tf.lookup.StaticHashTable(initializer, default_value=default_value) +def _construct_table( + asset_file_path, + key_dtype=tf.string, + key_index=0, + value_dtype=tf.int64, + value_index=1, + default_value=-1, +): + initializer = tf.lookup.TextFileInitializer( + asset_file_path, + key_dtype=key_dtype, + key_index=key_index, + value_dtype=value_dtype, + value_index=value_index, + ) + return tf.lookup.StaticHashTable(initializer, default_value=default_value) def _value_to_tensor(value): - if isinstance(value, tf.compat.v1.SparseTensorValue): - return tf.compat.v1.convert_to_tensor_or_sparse_tensor(value) - elif isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): - return tf.ragged.constant(value.to_list()) - else: - return tf.constant(value) + if isinstance(value, tf.compat.v1.SparseTensorValue): + return tf.compat.v1.convert_to_tensor_or_sparse_tensor(value) + elif isinstance(value, tf.compat.v1.ragged.RaggedTensorValue): + return tf.ragged.constant(value.to_list()) + else: + return tf.constant(value) class _SparseTensorSpec: + def __init__(self, shape, dtype): + self._shape = shape + self._dtype = dtype - def __init__(self, shape, dtype): - self._shape = shape - self._dtype = dtype -if not hasattr(tf, 'SparseTensorSpec'): - tf.SparseTensorSpec = _SparseTensorSpec +if not hasattr(tf, "SparseTensorSpec"): + tf.SparseTensorSpec = _SparseTensorSpec class TFUtilsTest(test_case.TransformTestCase): - - def _assertCompositeRefEqual(self, left, right): - """Asserts that a two `tf_util._CompositeTensorRef`s are equal.""" - self.assertEqual(left.type_spec, right.type_spec) - self.assertAllEqual(left.list_of_refs, right.list_of_refs) - - def test_copy_tensors_produces_different_tensors(self): - with tf.compat.v1.Graph().as_default(): - tensors = { - 'dense': - tf.compat.v1.placeholder( - tf.int64, (None,), name='my_dense_input'), - 'sparse': - tf.compat.v1.sparse_placeholder(tf.int64, name='my_sparse_input'), - 'ragged': - tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=2, name='my_ragged_input') - } - copied_tensors = tf_utils.copy_tensors(tensors) - - self.assertNotEqual(tensors['dense'], copied_tensors['dense']) - self.assertNotEqual(tensors['sparse'].indices, - copied_tensors['sparse'].indices) - self.assertNotEqual(tensors['sparse'].values, - copied_tensors['sparse'].values) - self.assertNotEqual(tensors['sparse'].dense_shape, - copied_tensors['sparse'].dense_shape) - self.assertNotEqual(tensors['ragged'].values, - copied_tensors['ragged'].values) - self.assertNotEqual(tensors['ragged'].row_splits, - copied_tensors['ragged'].row_splits) - - def test_copy_tensors_produces_equivalent_tensors(self): - with tf.compat.v1.Graph().as_default(): - tensors = { - 'dense': - tf.compat.v1.placeholder( - tf.int64, (None,), name='my_dense_input'), - 'sparse': - tf.compat.v1.sparse_placeholder(tf.int64, name='my_sparse_input'), - 'ragged': - tf.compat.v1.ragged.placeholder( - tf.int64, ragged_rank=1, name='my_ragged_input') - } - copied_tensors = tf_utils.copy_tensors(tensors) - - with tf.compat.v1.Session() as session: - dense_value = [1, 2] - sparse_value = tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1]], - values=[3, 4, 5], - dense_shape=[2, 3]) - ragged_value = tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([3, 4, 5], dtype=np.int64), - row_splits=np.array([0, 2, 3], dtype=np.int64)) - sample_tensors = session.run( - copied_tensors, - feed_dict={ - tensors['dense']: dense_value, - tensors['sparse']: sparse_value, - 
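# --- Editor's note (illustrative sketch only; not part of the patch hunks above or below) ---
# The reformatted _construct_table test helper above wraps a standard TF lookup
# pattern: a TextFileInitializer that reads delimiter-separated key/value columns
# from an asset file, backing a StaticHashTable with a default for missing keys.
# A minimal sketch under the same defaults (key column 0, int64 value column 1);
# the temp-file contents here are made up for illustration.
import tempfile

import tensorflow as tf

with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
    f.write("a\t0\nb\t1\n")  # two string keys mapped to int64 ids
    vocab_path = f.name

initializer = tf.lookup.TextFileInitializer(
    vocab_path,
    key_dtype=tf.string,
    key_index=0,    # first tab-separated column holds the key
    value_dtype=tf.int64,
    value_index=1,  # second column holds the value
)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
table.lookup(tf.constant(["a", "b", "oov"]))  # -> [0, 1, -1]
# --- end editor's note ---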
tensors['ragged']: ragged_value - }) - self.assertAllEqual(sample_tensors['dense'], dense_value) - self.assertAllEqual(sample_tensors['sparse'].indices, - sparse_value.indices) - self.assertAllEqual(sample_tensors['sparse'].values, - sparse_value.values) - self.assertAllEqual(sample_tensors['sparse'].dense_shape, - sparse_value.dense_shape) - self.assertAllEqual(sample_tensors['ragged'].values, - ragged_value.values) - self.assertAllEqual(sample_tensors['ragged'].row_splits, - ragged_value.row_splits) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='2d', - tensor=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 2, 4])), - rowids=[0, 0, 1, 1], - tensor_spec=tf.RaggedTensorSpec([None, None], tf.float32)), - dict( - testcase_name='3d', - tensor=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 3, 4])), - row_splits=np.array([0, 1, 1, 2])), - rowids=[0, 0, 0, 2], - tensor_spec=tf.RaggedTensorSpec([None, None, None], tf.float32)), - ])) - def test_get_ragged_batch_value_rowids(self, tensor, rowids, tensor_spec, - function_handler): - - @function_handler(input_signature=[tensor_spec]) - def get_ragged_batch_value_rowids(tensor): - return tf_utils._get_ragged_batch_value_rowids(tensor) - - self.assertAllEqual(get_ragged_batch_value_rowids(tensor), rowids) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='rank1', - x=['a', 'b', 'a'], - x_spec=tf.TensorSpec(None, tf.string), - weights=[1, 1, 2], - filter_regex=None, - expected_unique_x=[b'a', b'b'], - expected_summed_weights_per_x=[3, 1]), - dict( - testcase_name='rank2', - x=[['a', 'b\n', 'a'], ['b\n', 'a', 'b\n']], - x_spec=tf.TensorSpec(None, tf.string), - weights=[[1, 2, 1], [1, 2, 2]], - filter_regex=None, - expected_unique_x=[b'a', b'b\n'], - expected_summed_weights_per_x=[4, 5]), - dict( - testcase_name='rank3', - x=[[['a', 'b', 'a'], ['b', 'a', 'b']], - [['a', 'b', 'a'], ['b', 'a', 'b']]], - x_spec=tf.TensorSpec(None, tf.string), - weights=[[[1, 1, 2], [1, 2, 1]], [[1, 2, 1], [1, 2, 1]]], - filter_regex=None, - expected_unique_x=[b'a', b'b'], - expected_summed_weights_per_x=[9, 7]), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [2, 1]], - values=['a', 'a', 'b'], - dense_shape=[4, 2]), - x_spec=tf.SparseTensorSpec([4, 2], tf.string), - weights=[2, 3, 4], - filter_regex=None, - expected_unique_x=[b'a', b'b'], - expected_summed_weights_per_x=[5, 4]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( # pylint: disable=g-long-lambda - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array(['a', 'b', 'b', 'a']), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 2])), - x_spec=tf.RaggedTensorSpec([None, None, None], tf.string), - weights=[2, 3, 4, 6], - filter_regex=None, - expected_unique_x=[b'a', b'b'], - expected_summed_weights_per_x=[8, 7]), - dict( - testcase_name='regex_filtering', - x=[['a\n', '', '\n\r'], ['\r', 'a', 'b']], - x_spec=tf.TensorSpec(None, tf.string), - weights=[[1, 2, 1], [1, 2, 2]], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - expected_unique_x=[b'a', b'b'], - expected_summed_weights_per_x=[2, 2]), - dict( - testcase_name='regex_filtering_invalid_utf8', - x=[[b'\xe1\n', b'\xa9', b'\n\xb8\r'], - [b'\xe8\r', b'\xc6', b'\n\xb3']], - 
x_spec=tf.TensorSpec(None, tf.string), - weights=[[1, 3, 1], [1, 4, 2]], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - expected_unique_x=[b'\xa9', b'\xc6'], - expected_summed_weights_per_x=[3, 4]), - ])) - def test_reduce_batch_weighted_counts(self, x, x_spec, weights, filter_regex, - expected_unique_x, - expected_summed_weights_per_x, - function_handler): - input_signature = [x_spec, tf.TensorSpec(None, tf.float32)] - @function_handler(input_signature=input_signature) - def _reduce_batch_weighted_counts(x, weights): - (unique_x, summed_weights_per_x, summed_positive_per_x_and_y, - counts_per_x) = tf_utils.reduce_batch_weighted_counts( - x, weights, filter_regex=filter_regex) - self.assertIsNone(summed_positive_per_x_and_y) - self.assertIsNone(counts_per_x) - return unique_x, summed_weights_per_x - - unique_x, summed_weights_per_x = _reduce_batch_weighted_counts(x, weights) - - self.assertAllEqual(unique_x, - expected_unique_x) - self.assertAllEqual(summed_weights_per_x, - expected_summed_weights_per_x) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='rank1', - x=['a', 'b', 'a'], - filter_regex=None, - expected_result=[b'a', b'b', b'a'], - ), - dict( - testcase_name='rank2', - x=[['a', 'b\r', 'a'], ['b\r', 'a', 'b\r']], - filter_regex=None, - expected_result=[b'a', b'b\r', b'a', b'b\r', b'a', b'b\r'], - ), - dict( - testcase_name='rank3', - x=[[['a', 'b', 'a'], ['b', 'a', 'b']], - [['a', 'b', 'a'], ['b', 'a', 'b']]], - filter_regex=None, - expected_result=[ - b'a', b'b', b'a', b'b', b'a', b'b', b'a', b'b', b'a', b'b', - b'a', b'b' - ], - ), - dict( - testcase_name='regex_filtering_empty_result', - x=['a\n\r', 'b\n', 'a\r', '', 'a\rsd', ' \r', '\nas'], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - expected_result=[], - ), - ])) - def test_reduce_batch_weighted_counts_weights_none(self, x, filter_regex, - expected_result, - function_handler): - input_signature = [tf.TensorSpec(None, tf.string)] - - @function_handler(input_signature=input_signature) - def _reduce_batch_weighted_counts(x): - (unique_x, summed_weights_per_x, summed_positive_per_x_and_y, - counts_per_x) = tf_utils.reduce_batch_weighted_counts( - x, force=False, filter_regex=filter_regex) - self.assertIsNone(summed_weights_per_x) - self.assertIsNone(summed_positive_per_x_and_y) - self.assertIsNone(counts_per_x) - return unique_x - - unique_x = _reduce_batch_weighted_counts(x) - self.assertAllEqual(unique_x, expected_result) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='rank1', - x=['a', 'b', 'a'], - filter_regex=None, - expected_result=([b'a', b'b'], [2, 1]), - ), - dict( - testcase_name='rank3', - x=[[['a', 'b', 'a'], ['b', 'a', 'b']], - [['a', 'b', 'a'], ['b', 'a', 'b']]], - filter_regex=None, - expected_result=([b'a', b'b'], [6, 6]), - ), - dict( - testcase_name='regex_filtering', - x=['a\n\r', 'b\n', 'a\r', '', 'asd', ' ', '\nas'], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - expected_result=([b'asd', b' '], [1, 1]), - ), - dict( - testcase_name='regex_filtering_empty_result', - x=['a\n\r', 'b\n', 'a\r', '', 'a\rsd', ' \r', '\nas'], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - expected_result=([], []), - ), - ])) - def test_reduce_batch_weighted_counts_weights_none_force( - self, x, filter_regex, expected_result, function_handler): - input_signature = [tf.TensorSpec(None, tf.string)] - - @function_handler(input_signature=input_signature) - def 
_reduce_batch_weighted_counts(x): - (unique_x, summed_weights_per_x, summed_positive_per_x_and_y, - counts_per_x) = tf_utils.reduce_batch_weighted_counts( - x, force=True, filter_regex=filter_regex) - self.assertIsNone(summed_weights_per_x) - self.assertIsNone(summed_positive_per_x_and_y) - return unique_x, counts_per_x - - expected_unique_x, expected_counts_per_x = expected_result - unique_x, counts_per_x = _reduce_batch_weighted_counts(x) - self.assertAllEqual(unique_x, expected_unique_x) - self.assertAllEqual(counts_per_x, expected_counts_per_x) - - @test_case.named_parameters([ - dict(testcase_name='constant', get_value_fn=lambda: tf.constant([1.618])), - dict(testcase_name='op', get_value_fn=lambda: tf.identity), - dict(testcase_name='int', get_value_fn=lambda: 4), - dict(testcase_name='object', get_value_fn=object), - dict( - testcase_name='sparse', - get_value_fn=lambda: tf.SparseTensor( # pylint: disable=g-long-lambda - indices=[[0, 0], [2, 1]], - values=['a', 'b'], - dense_shape=[4, 2])), - dict( - testcase_name='ragged', - get_value_fn=lambda: tf.RaggedTensor.from_row_splits( # pylint: disable=g-long-lambda - values=['a', 'b'], - row_splits=[0, 1, 2])), - dict( - testcase_name='ragged_multi_dimension', - get_value_fn=lambda: tf.RaggedTensor.from_row_splits( # pylint: disable=g-long-lambda - values=tf.RaggedTensor.from_row_splits( - values=[[0, 1], [2, 3]], row_splits=[0, 1, 2]), - row_splits=[0, 2])), - ]) - def test_hashable_tensor_or_op(self, get_value_fn): - with tf.compat.v1.Graph().as_default(): - input_value = get_value_fn() - input_ref = tf_utils.hashable_tensor_or_op(input_value) - input_dict = {input_ref: input_value} - input_deref = tf_utils.deref_tensor_or_op(input_ref) - if isinstance(input_value, composite_tensor.CompositeTensor): - self._assertCompositeRefEqual( - input_ref, tf_utils.hashable_tensor_or_op(input_deref)) - else: - self.assertAllEqual(input_ref, - tf_utils.hashable_tensor_or_op(input_deref)) - - if isinstance(input_value, tf.SparseTensor): - input_deref = input_deref.values - input_dict[input_ref] = input_dict[input_ref].values - input_value = input_value.values - - self.assertAllEqual(input_value, input_deref) - self.assertAllEqual(input_value, input_dict[input_ref]) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='rank1_with_weights_and_binary_y', - x=['a', 'b', 'a'], - weights=[1, 1, 2], - y=[0, 1, 1], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], [3, 1, 4], - [[1, 2], [0, 1], [1, 3]], [2, 1, 3]), - filter_regex=None, - ), - dict( - testcase_name='rank1_with_weights_and_multi_class_y', - x=['a', 'b\n', 'a', 'a'], - weights=[1, 1, 2, 2], - y=[0, 2, 1, 1], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b\n', b'global_y_count_sentinel'], [5, 1, 6], - [[1, 4, 0], [0, 0, 1], [1, 4, 1]], [3, 1, 4]), - filter_regex=None, - ), - dict( - testcase_name='rank1_with_weights_and_missing_y_values', - x=['a', 'b', 'a', 'a'], - weights=[1, 1, 2, 2], - y=[3, 5, 6, 6], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], [5, 1, 6], - [[0, 0, 0, 1, 0, 0, 4], [0, 0, 0, 0, 0, 1, 0], - [0, 0, 0, 1, 0, 1, 4]], [3, 1, 4]), - filter_regex=None, - ), - dict( - testcase_name='rank2_with_weights_and_binary_y', - x=[['a', 'b', 'a'], ['b', 'a', 'b']], - weights=[[1, 2, 1], [1, 2, 2]], - y=[[1, 0, 1], [1, 0, 0]], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], 
[4, 5, 9], - [[2, 2], [4, 1], [6, 3]], [3, 3, 6]), - filter_regex=None, - ), - dict( - testcase_name='rank3_with_weights_and_binary_y', - x=[[['a', 'b', 'a'], ['b', 'a', 'b']], - [['a', 'b', 'a'], ['b', 'a', 'b']]], - weights=[[[1, 1, 2], [1, 2, 1]], [[1, 2, 1], [1, 2, 1]]], - y=[[[1, 1, 0], [1, 0, 1]], [[1, 0, 1], [1, 0, 1]]], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], [9, 7, 16], - [[6, 3], [2, 5], [8, 8]], [6, 6, 12]), - filter_regex=None, - ), - dict( - testcase_name='rank1_with_weights_multi_class_y_and_filtering', - x=['\na\r', '', '\na\r', 'a', ''], - weights=[1, 1, 2, 2, 3], - y=[0, 2, 1, 1, 2], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'global_y_count_sentinel'], [2, 9], - [[0, 2, 0], [1, 4, 4]], [1, 5]), - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - ), - dict( - testcase_name='rank1_with_weights_filtering_empty_result', - x=['\na\r', '', '\na\r', '\ra', ''], - weights=[1, 1, 2, 2, 3], - y=[0, 2, 1, 1, 2], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'global_y_count_sentinel'], [9], [[1, 4, 4]], [5]), - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - ), - ])) - def test_reduce_batch_coocurrences(self, x, weights, y, expected_result, - filter_regex, function_handler): - input_signature = [tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - tf.TensorSpec(None, tf.int64)] - - @function_handler(input_signature=input_signature) - def _reduce_batch_weighted_cooccurrences(x, y, weights): - return tf_utils.reduce_batch_weighted_cooccurrences( - x, y, weights, filter_regex=filter_regex) - - result = _reduce_batch_weighted_cooccurrences(x, y, weights) - - self.assertAllEqual(result.unique_x, - expected_result.unique_x) - self.assertAllEqual(result.summed_weights_per_x, - expected_result.summed_weights_per_x) - self.assertAllEqual(result.summed_positive_per_x_and_y, - expected_result.summed_positive_per_x_and_y) - self.assertAllEqual(result.counts_per_x, - expected_result.counts_per_x) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='rank1_with_binary_y', - x=['a', 'b', 'a'], - y=[0, 1, 1], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [2, 1, 3], - [[1, 1], [0, 1], [1, 2]], - [2, 1, 3], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank1_with_multi_class_y', - x=['yes', 'no', 'yes', 'may\rbe', 'yes'], - y=[1, 1, 0, 2, 3], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'yes', b'no', b'may\rbe', b'global_y_count_sentinel'], - [3, 1, 1, 5], - [[1, 1, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [1, 2, 1, 1]], - [3, 1, 1, 5], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank1_with_sparse_y', - x=['yes', 'no', 'yes', 'may\rbe', 'yes'], - # 5 examples, 4 labels: - # 0: (3,2) - # 1: (1) - # 2: (2,0) - # 3: (1) - # 4: (2, 1) - y=tf.compat.v1.SparseTensorValue( - indices=( - (0, 0), - (0, 1), - (1, 0), - (2, 0), - (2, 1), - (3, 1), - (4, 0), - (4, 1), - ), - values=[3, 2, 1, 2, 0, 1, 2, 1], - dense_shape=[5, 2], - ), - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'yes', b'no', b'may\rbe', b'global_y_count_sentinel'], - [3, 1, 1, 5], - [[1, 1, 3, 1], [0, 1, 0, 0], [0, 1, 0, 0], [1, 3, 3, 1]], - [3, 1, 1, 5], - ), - input_signature=[ - 
tf.TensorSpec(None, tf.string), - tf.SparseTensorSpec([None, 2], tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank1_with_sparse_y_missing_labels', - x=['yes', 'no', 'yes', 'may\rbe', 'yes'], - # 5 examples, 4 labels: - # 0: (3,2) - # 1: () - # 2: (2,0) - # 3: (1) - # 4: (2, 1) - y=tf.compat.v1.SparseTensorValue( - indices=( - (0, 0), - (0, 1), - (2, 0), - (2, 1), - (3, 1), - (4, 0), - (4, 1), - ), - values=[3, 2, 2, 0, 1, 2, 1], - dense_shape=[5, 2], - ), - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'yes', b'no', b'may\rbe', b'global_y_count_sentinel'], - [3, 1, 1, 5], - [[1, 1, 3, 1], [0, 0, 0, 0], [0, 1, 0, 0], [1, 2, 3, 1]], - [3, 1, 1, 5], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.SparseTensorSpec([None, 2], tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank2_with_binary_y', - x=[['a', 'b', 'a'], ['b', 'a', 'b']], - y=[[1, 0, 1], [1, 0, 0]], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [3, 3, 6], - [[1, 2], [2, 1], [3, 3]], - [3, 3, 6], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank2_with_missing_y_values', - x=[['a', 'b', 'a'], ['b', 'a', 'b']], - y=[[2, 0, 2], [2, 0, 0]], - # The label 1 isn't in the batch but it will have a position (with - # weights of 0) in the resulting array. - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [3, 3, 6], - [[1, 0, 2], [2, 0, 1], [3, 0, 3]], - [3, 3, 6], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank2_with_multi_class_y', - x=[['a', 'b', 'a'], ['b', 'a', 'b']], - y=[[1, 0, 1], [1, 0, 2]], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [3, 3, 6], - [[1, 2, 0], [1, 1, 1], [2, 3, 1]], - [3, 3, 6], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank3_with_binary_y', - x=[ - [['a', 'b', 'a'], ['b', 'a', 'b']], - [['a', 'b', 'a'], ['b', 'a', 'b']], - ], - y=[[[1, 1, 0], [1, 0, 1]], [[1, 0, 1], [1, 0, 1]]], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [6, 6, 12], - [[3, 3], [1, 5], [4, 8]], - [6, 6, 12], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [2, 1]], - values=['a', 'b'], - dense_shape=[4, 2], - ), - y=[0, 1, 0, 0], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [1, 1, 4], - [[1, 0], [1, 0], [3, 1]], - [1, 1, 4], - ), - input_signature=[ - tf.SparseTensorSpec([None, 2], tf.string), - tf.TensorSpec([None], tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='empty_sparse', - x=tf.compat.v1.SparseTensorValue( - indices=np.empty([0, 2]), values=[], dense_shape=[4, 2] - ), - y=[1, 0, 1, 1], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'global_y_count_sentinel'], [4], [[1, 3]], [4] - ), - input_signature=[ - tf.SparseTensorSpec([None, 2], tf.string), - tf.TensorSpec([None], tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - 
values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array(['a', 'b', 'a', 'b', 'b']), - row_splits=np.array([0, 2, 3, 4, 5]), - ), - row_splits=np.array([0, 2, 3, 4]), - ), - row_splits=np.array([0, 2, 3]), - ), - y=[1, 0], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'a', b'b', b'global_y_count_sentinel'], - [2, 3, 2], - [[0, 2], [1, 2], [1, 1]], - [2, 3, 2], - ), - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.string), - tf.TensorSpec([None], tf.int64), - ], - filter_regex=None, - ), - dict( - testcase_name='rank1_with_filtering', - x=['yes\n', 'no', 'yes\n', '', 'yes\n'], - y=[1, 1, 0, 2, 3], - expected_result=tf_utils.ReducedBatchWeightedCounts( - [b'no', b'global_y_count_sentinel'], - [1, 5], - [[0, 1, 0, 0], [1, 2, 1, 1]], - [1, 5], - ), - input_signature=[ - tf.TensorSpec(None, tf.string), - tf.TensorSpec(None, tf.int64), - ], - filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, - ), - ]) - ) - def test_reduce_batch_coocurrences_no_weights(self, x, y, expected_result, - input_signature, filter_regex, - function_handler): - @function_handler(input_signature=input_signature) - def _reduce_batch_weighted_cooccurrences_no_weights(x, y): - return tf_utils.reduce_batch_weighted_cooccurrences( - x, y, filter_regex=filter_regex) - - result = _reduce_batch_weighted_cooccurrences_no_weights(x, y) - - self.assertAllEqual(result.unique_x, - expected_result.unique_x) - self.assertAllEqual(result.summed_weights_per_x, - expected_result.summed_weights_per_x) - self.assertAllEqual(result.summed_positive_per_x_and_y, - expected_result.summed_positive_per_x_and_y) - self.assertAllEqual(result.counts_per_x, - expected_result.counts_per_x) - - @test_case.parameters( - ([[1], [2]], [[1], [2], [3]], None, None, tf.errors.InvalidArgumentError, - 'Condition x == y did not hold element-wise:'), - ([[1], [2], [3]], [[1], [2], [3]], [None, None], [None], ValueError, - r'Shapes \(None, None\) and \(None,\) are incompatible'), - ) - def test_same_shape_exceptions(self, x_input, y_input, x_shape, y_shape, - exception_cls, error_string): - - with tf.compat.v1.Graph().as_default(): - x = tf.compat.v1.placeholder(tf.int32, x_shape) - y = tf.compat.v1.placeholder(tf.int32, y_shape) - with tf.compat.v1.Session() as sess: - with self.assertRaisesRegex(exception_cls, error_string): - sess.run(tf_utils.assert_same_shape(x, y), {x: x_input, y: y_input}) - - @test_case.named_parameters(test_case.FUNCTION_HANDLERS) - def test_same_shape(self, function_handler): - input_signature = [tf.TensorSpec(None, tf.int64), - tf.TensorSpec(None, tf.int64)] - - @function_handler(input_signature=input_signature) - def _assert_shape(x, y): - x_return, _ = tf_utils.assert_same_shape(x, y) - return x_return - - input_list = [[1], [2], [3]] - x_return = _assert_shape(input_list, input_list) - self.assertAllEqual(x_return, input_list) - - @test_case.named_parameters([ - dict( - testcase_name='_all_keys_in_vocab', - query_list=['a', 'a', 'b', 'a', 'b'], - key_vocab_list=['a', 'b'], - query_shape=[None], - expected_output=[0, 0, 1, 0, 1]), - dict( - testcase_name='_missing_keys_in_vocab', - query_list=['a', 'c', 'b', 'a', 'b'], - key_vocab_list=['a', 'b'], - query_shape=[None], - expected_output=[0, -1, 1, 0, 1]), - dict( - testcase_name='_nd_keys', - query_list=[['a', 'c', 'b'], ['a', 'b', 'a']], - key_vocab_list=['a', 'b'], - query_shape=[None, None], - expected_output=[[0, -1, 1], [0, 1, 0]]), - dict( - testcase_name='_empty_vocab', - 
query_list=['a', 'c', 'b', 'a', 'b'], - key_vocab_list=[], - query_shape=[None], - expected_output=[-1, -1, -1, -1, -1]), - dict( - testcase_name='_empty_query', - query_list=[], - key_vocab_list=['a'], - query_shape=[None], - expected_output=[]), - ]) - def test_lookup_key(self, query_list, key_vocab_list, query_shape, - expected_output): - with tf.compat.v1.Graph().as_default(): - query_ph = tf.compat.v1.placeholder( - dtype=tf.string, shape=query_shape, name='query') - key_vocab_ph = tf.compat.v1.placeholder( - dtype=tf.string, shape=[None], name='key_vocab') - key_indices = tf_utils.lookup_key(query_ph, key_vocab_ph) - with tf.compat.v1.Session().as_default() as sess: - output = sess.run( - key_indices, - feed_dict={ - query_ph.name: query_list, - key_vocab_ph.name: key_vocab_list - }) - self.assertAllEqual(expected_output, output) - - @test_case.named_parameters([ - dict( - testcase_name='_with_default', - with_default_value=True, - input_keys=['a', 'b', 'c', 'd', 'e']), - dict( - testcase_name='_wihout_default', - with_default_value=False, - input_keys=['a', 'b', 'c', 'd', 'e']), - dict( - testcase_name='_single_oov_key', - with_default_value=False, - input_keys=['e']) - ]) - def test_apply_per_key_vocab(self, with_default_value, input_keys): - default_value = '-7,-5' if with_default_value else None - vocab_data = [('0,0', 'a'), ('1,-1', 'b'), ('-1,1', 'c'), ('-2,2', 'd')] - expected_missing_key_result = [-7, -5] if default_value else [0, 0] - expected_lookup_results = { - 'a': [0, 0], - 'b': [1, -1], - 'c': [-1, 1], - 'd': [-2, 2], - } - - with tf.compat.v1.Graph().as_default(): - input_tensor = _value_to_tensor(input_keys) - vocab_filename = os.path.join(self.get_temp_dir(), 'test.txt') - encoded_vocab = '\n'.join([' '.join(pair) for pair in vocab_data]) - with tf.io.gfile.GFile(vocab_filename, 'w') as f: - f.write(encoded_vocab) - - output_tensor = tf_utils.apply_per_key_vocabulary( - tf.constant(vocab_filename), - input_tensor, - default_value=default_value) - - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - output = output_tensor.eval() - - expected_data = [ - expected_lookup_results.get(key, expected_missing_key_result) - for key in input_keys - ] - self.assertAllEqual(output, expected_data) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='dense', - x=[[[1], [2]], [[1], [2]]], - expected_result=4, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.int64)]), - dict( - testcase_name='dense_with_nans', - x=[[[1], [np.nan]], [[1], [2]]], - expected_result=3, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_elementwise', - x=[[[1], [2]], [[1], [2]]], - expected_result=[[2], [2]], - reduce_instance_dims=False, - input_signature=[tf.TensorSpec(None, tf.int64)]), - dict( - testcase_name='dense_elementwise_with_nans', - x=[[[1], [2]], [[1], [np.nan]]], - expected_result=[[2], [1]], - reduce_instance_dims=False, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0]], - values=[1., 2., 3., 4.], - dense_shape=[2, 4, 1]), - expected_result=4, - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)]), - dict( - testcase_name='sparse_with_nans', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0], - [1, 3, 0]], - values=[1., 
2., 3., 4., np.nan], - dense_shape=[2, 4, 1]), - expected_result=4, - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)]), - dict( - testcase_name='sparse_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0]], - values=[1., 2., 3., 4.], - dense_shape=[2, 4, 1]), - expected_result=[[1], [1], [2], [0]], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)]), - dict( - testcase_name='sparse_elementwise_with_nans', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0], - [1, 3, 0]], - values=[1., 2., 3., 4., np.nan], - dense_shape=[2, 4, 1]), - expected_result=[[1], [1], [2], [0]], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_result=5, - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_with_nans', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5., np.nan], - np.float32), - row_splits=np.array([0, 2, 3, 4, 6])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_result=5, - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_elementwise', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 2, 4, 5])), - row_splits=np.array([0, 3, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_result=[[[2, 1], [0., 0], [1, 1]], - [[0, 0], [0, 0], [0, 0]]], - reduce_instance_dims=False, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_elementwise_with_nans', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5., np.nan], - np.float32), - row_splits=np.array([0, 2, 2, 4, 6])), - row_splits=np.array([0, 3, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_result=[[[2, 1], [0., 0], [1, 1]], - [[0, 0], [0, 0], [0, 0]]], - reduce_instance_dims=False, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - ])) - def test_reduce_batch_count(self, x, input_signature, expected_result, - reduce_instance_dims, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_count(x): - result = tf_utils.reduce_batch_count( - x, reduce_instance_dims=reduce_instance_dims) - # Verify that the output shape is maintained. - # TODO(b/178189903): This will fail if _dense_shape_default isn't set in - # reduce_batch_count. 
- if (not isinstance(x, tf.RaggedTensor) and not reduce_instance_dims and - x.get_shape().ndims): - self.assertEqual(x.get_shape()[1:].as_list(), - result.get_shape().as_list()) - return result - - result = _reduce_batch_count(x) - self.assertAllEqual(result, expected_result) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='dense', - x=[[[1], [2]], [[3], [4]]], - expected_count=4, - expected_mean=2.5, - expected_var=1.25, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_with_nans', - x=[[[1], [2]], [[3], [np.nan]], [[np.nan], [4]]], - expected_count=4, - expected_mean=2.5, - expected_var=1.25, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_elementwise', - x=[[[1], [2]], [[3], [4]]], - expected_count=[[2.], [2.]], - expected_mean=[[2.], [3.]], - expected_var=[[1.], [1.]], - reduce_instance_dims=False, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_elementwise_with_nans', - x=[[[1], [2]], [[3], [np.nan]], [[np.nan], [4]]], - expected_count=[[2.], [2.]], - expected_mean=[[2.], [3.]], - expected_var=[[1.], [1.]], - reduce_instance_dims=False, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2]], - values=[1., 2., 3., 4.], - dense_shape=[2, 4]), - expected_count=4, - expected_mean=2.5, - expected_var=1.25, - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)]), - dict( - testcase_name='sparse_with_nans', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [1, 3]], - values=[1., 2., 3., 4., np.nan], - dense_shape=[2, 4]), - expected_count=4, - expected_mean=2.5, - expected_var=1.25, - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)]), - dict( - testcase_name='sparse_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 3], [1, 1], [1, 3]], - values=[1., 2., 3., 4.], - dense_shape=[2, 5]), - expected_count=[1.0, 1.0, 0.0, 2.0, 0.0], - expected_mean=[1.0, 3.0, 0.0, 3.0, 0.0], - expected_var=[0.0, 0.0, 0.0, 1.0, 0.0], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)]), - dict( - testcase_name='sparse_elementwise_with_nans', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 3], [1, 1], [1, 2], [1, 3]], - values=[1., 2., 3., np.nan, 4.], - dense_shape=[2, 5]), - expected_count=[1.0, 1.0, 0.0, 2.0, 0.0], - expected_mean=[1.0, 3.0, 0.0, 3.0, 0.0], - expected_var=[0.0, 0.0, 0.0, 1.0, 0.0], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)]), - dict( - testcase_name='sparse_3d_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 3], [0, 1, 0], [0, 1, 3], [1, 1, 1], - [1, 1, 3]], - values=[-10., 1., 2., 3., 4.], - dense_shape=[2, 3, 5]), - expected_count=[[0, 0, 0, 1, 0], [1, 1, 0, 2, 0], [0] * 5], - expected_mean=[[0, 0, 0, -10, 0], [1, 3, 0, 3, 0], [0] * 5], - expected_var=[[0] * 5, [0, 0, 0, 1, 0], [0] * 5], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 3, 5], tf.float32)]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - 
row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_count=5, - expected_mean=3, - expected_var=2, - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_with_nans', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5., np.nan], - np.float32), - row_splits=np.array([0, 2, 3, 4, 6])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_count=5, - expected_mean=3, - expected_var=2, - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_elementwise', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 2, 4, 5])), - row_splits=np.array([0, 3, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_count=[[[2., 1.], [0., 0.], [1., 1.]], - [[0., 0.], [0., 0.], [0., 0.]]], - expected_mean=[[[3., 2.], [0., 0.], [3., 4.]], - [[0., 0.], [0., 0.], [0., 0.]]], - expected_var=[[[4., 0.], [0., 0.], [0., 0.]], - [[0., 0.], [0., 0.], [0., 0.]]], - reduce_instance_dims=False, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_elementwise_with_nans', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5., np.nan], - np.float32), - row_splits=np.array([0, 2, 2, 4, 6])), - row_splits=np.array([0, 3, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_count=[[[2., 1.], [0., 0.], [1., 1.]], - [[0., 0.], [0., 0.], [0., 0.]]], - expected_mean=[[[3., 2.], [0., 0.], [3., 4.]], - [[0., 0.], [0., 0.], [0., 0.]]], - expected_var=[[[4., 0.], [0., 0.], [0., 0.]], - [[0., 0.], [0., 0.], [0., 0.]]], - reduce_instance_dims=False, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - ])) - def test_reduce_batch_count_mean_and_var( - self, x, input_signature, expected_count, expected_mean, expected_var, - reduce_instance_dims, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_count_mean_and_var(x): - result = tf_utils.reduce_batch_count_mean_and_var( - x, reduce_instance_dims=reduce_instance_dims) - # Verify that the output shapes are maintained. - # TODO(b/178189903): This will fail if _dense_shape_default isn't set in - # reduce_batch_count. 
- if (not isinstance(x, tf.RaggedTensor) and not reduce_instance_dims and - x.get_shape().ndims): - for tensor in result: - self.assertEqual(x.get_shape()[1:].as_list(), - tensor.get_shape().as_list()) - return result - - count, mean, var = _reduce_batch_count_mean_and_var(x) - self.assertAllEqual(expected_count, count) - self.assertAllEqual(expected_mean, mean) - self.assertAllEqual(expected_var, var) - - @test_case.named_parameters([ - dict( - testcase_name='num_samples_1', - num_samples=1, - dtype=tf.float32, - expected_counts=np.array([1, 0, 0, 0], np.float32), - expected_factors=np.array([[1.0], [0.0], [0.0], [0.0]], np.float32)), - dict( - testcase_name='num_samples_2', - num_samples=2, - dtype=tf.float32, - expected_counts=np.array([2, 1, 0, 0], np.float32), - expected_factors=np.array( - [[1. / 2., 1. / 2.], [-1. / 2., 1. / 2.], [0., 0.], [0., 0.]], - np.float32)), - dict( - testcase_name='num_samples_3', - num_samples=3, - dtype=tf.float32, - expected_counts=np.array([3, 3, 1, 0], np.float32), - expected_factors=np.array( - [[1. / 3., 1. / 3., 1. / 3.], [-1. / 3., 0., 1. / 3.], - [1. / 3., -2. / 3., 1. / 3.], [0., 0., 0.]], np.float32)), - dict( - testcase_name='num_samples_4', - num_samples=4, - dtype=tf.float32, - expected_counts=np.array([4, 6, 4, 1], np.float32), - expected_factors=np.array( - [[1. / 4., 1. / 4., 1. / 4., 1. / 4.], - [-3. / 12., -1. / 12., 1. / 12., 3. / 12.], - [1. / 4., -1. / 4., -1. / 4., 1. / 4.], - [-1. / 4., 3. / 4., -3. / 4., 1. / 4.]], np.float32)) - ]) - def test_num_terms_and_factors( - self, num_samples, dtype, expected_counts, expected_factors): - results = tf_utils._num_terms_and_factors(num_samples, dtype) - counts = results[0:4] - assert len(expected_counts) == len(counts), (expected_counts, counts) - for result, expected_count in zip(counts, expected_counts): - self.assertEqual(result.dtype, dtype) - self.assertAllClose(result, expected_count) - - factors = results[4:] - assert len(expected_factors) == len(factors), (expected_factors, factors) - for result, expected_factor in zip(factors, expected_factors): - self.assertEqual(result.dtype, dtype) - self.assertAllClose(result, expected_factor) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='dense', - x=[[[1], [2]], [[3], [4]]], - expected_counts=np.array([4., 6., 4., 1.], np.float32), - expected_moments=np.array([2.5, 10.0 / 12.0, 0.0, 0.0], - np.float32), - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_large', - x=[2.0, 3.0, 4.0, 2.4, 5.5, 1.2, 5.4, 2.2, 7.1, 1.3, 1.5], - expected_counts=np.array( - [11, 11 * 10 // 2, 11 * 10 * 9 // 6, 11 * 10 * 9 * 8 // 24], - np.float32), - expected_moments=np.array([ - 3.2363636363636363, 1.141818181818182, 0.31272727272727263, - 0.026666666666666616 - ], np.float32), - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_very_large', - x=-np.log(1.0 - np.arange(0, 1, 1e-6, dtype=np.float32)), - expected_counts=np.array([ - 1000000, 499999500000.0, 1.66666166667e+17, - 4.1666416667125e+22 - ], np.float32), - expected_moments=np.array([ - 0.99999217330, 0.4999936732947, 0.166660839941, - 0.0833278399134 - ], np.float32), - reduce_instance_dims=True, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='dense_elementwise', - x=[[[1], [2]], [[3], [4]]], - expected_counts=np.array( - [[[2], [2]], [[1], [1]], [[0], [0]], [[0], [0]]], np.float32), - 
expected_moments=np.array([[[2.0], [3.0]], [[1.0], [1.0]], - [[0.0], [0.0]], [[0.0], [0.0]]], - np.float32), - reduce_instance_dims=False, - input_signature=[tf.TensorSpec(None, tf.float32)]), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [2, 0], [2, 2]], - values=[1., 2., 3., 4.], - dense_shape=[3, 4]), - expected_counts=np.array([4, 6, 4, 1], np.float32), - expected_moments=np.array([2.5, 10.0 / 12.0, 0.0, 0.0], - np.float32), - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)]), - dict( - testcase_name='sparse_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 2, 0], [2, 0, 0], [2, 2, 0], - [3, 3, 0]], - values=[1., 2., 3., 4., 5.], - dense_shape=[3, 5, 1]), - expected_counts=np.array( - [[[2], [0], [2], [1], [0]], [[1], [0], [1], [0], [0]], - [[0], [0], [0], [0], [0]], [[0], [0], [0], [0], [0]]], - np.float32), - expected_moments=np.array([[[2.0], [0.0], [3.0], [5.0], [0.0]], - [[1.0], [0.0], [1.0], [0.0], [0.0]], - [[0.0], [0.0], [0.0], [0.0], [0.0]], - [[0.0], [0.0], [0.0], [0.0], [0.0]]], - np.float32), - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, 5, 1], tf.float32)]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_counts=np.array([5., 10., 10., 5.], np.float32), - expected_moments=np.array([3., 1., 0., 0.], np.float32), - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32) - ]), - ])) - def test_reduce_batch_count_l_moments( - self, x, input_signature, expected_counts, expected_moments, - reduce_instance_dims, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_count_l_moments(x): - result = tf_utils.reduce_batch_count_l_moments( - x, reduce_instance_dims=reduce_instance_dims) - for tensor in result: - if not reduce_instance_dims and x.get_shape().ndims: - self.assertEqual(x.get_shape()[1:].as_list(), - tensor.get_shape().as_list()) - return result - - count_and_moments = _reduce_batch_count_l_moments(x) - counts = count_and_moments[0::2] - moments = count_and_moments[1::2] - for i in range(0, 4): - self.assertEqual(counts[i].dtype, expected_counts[i].dtype) - self.assertAllClose(counts[i], expected_counts[i], rtol=1e-8) - self.assertEqual(moments[i].dtype, expected_moments[i].dtype) - self.assertAllClose(moments[i], expected_moments[i], rtol=1e-8) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='dense', - x=[[1], [2], [3], [4], [4]], - key=['a', 'a', 'a', 'b', 'a'], - expected_key_vocab=[b'a', b'b'], - expected_count=[4., 1.], - expected_mean=[2.5, 4.], - expected_var=[1.25, 0.], - reduce_instance_dims=True, - input_signature=[ - tf.TensorSpec([None, 1], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='dense_with_nans', - x=[[1], [2], [3], [4], [4], [np.nan], [np.nan]], - key=['a', 'a', 'a', 'b', 'a', 'a', 'b'], - expected_key_vocab=[b'a', b'b'], - expected_count=[4., 1.], - expected_mean=[2.5, 4.], - expected_var=[1.25, 0.], - reduce_instance_dims=True, - input_signature=[ - tf.TensorSpec([None, 1], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - 
testcase_name='dense_elementwise', - x=[[1, 2], [3, 4], [1, 2]], - key=['a', 'a', 'b'], - expected_key_vocab=[b'a', b'b'], - expected_count=[[2., 2.], [1., 1.]], - expected_mean=[[2., 3.], [1., 2.]], - expected_var=[[1., 1.], [0., 0.]], - reduce_instance_dims=False, - input_signature=[ - tf.TensorSpec([None, 2], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='dense_elementwise_with_nans', - x=[[1, 2], [3, 4], [1, 2], [np.nan, np.nan]], - key=['a', 'a', 'b', 'a'], - expected_key_vocab=[b'a', b'b'], - expected_count=[[2., 2.], [1., 1.]], - expected_mean=[[2., 3.], [1., 2.]], - expected_var=[[1., 1.], [0., 0.]], - reduce_instance_dims=False, - input_signature=[ - tf.TensorSpec([None, 2], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], - values=[1., 2., 3., 4., 4.], - dense_shape=[3, 4]), - key=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], - values=['a', 'a', 'a', 'a', 'b'], - dense_shape=[3, 4]), - expected_key_vocab=[b'a', b'b'], - expected_count=[4, 1], - expected_mean=[2.5, 4], - expected_var=[1.25, 0], - reduce_instance_dims=True, - input_signature=[ - tf.SparseTensorSpec([None, 4], tf.float32), - tf.SparseTensorSpec([None, 4], tf.string) - ]), - dict( - testcase_name='sparse_with_nans', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 2], [2, 3]], - values=[1., 2., 3., 4., np.nan, 4.], - dense_shape=[3, 4]), - key=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 2], [2, 3]], - values=['a', 'a', 'a', 'a', 'a', 'b'], - dense_shape=[3, 4]), - expected_key_vocab=[b'a', b'b'], - expected_count=[4, 1], - expected_mean=[2.5, 4], - expected_var=[1.25, 0], - reduce_instance_dims=True, - input_signature=[ - tf.SparseTensorSpec([None, 4], tf.float32), - tf.SparseTensorSpec([None, 4], tf.string) - ]), - dict( - testcase_name='sparse_x_dense_key', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], - values=[1., 2., 3., 4., 4.], - dense_shape=[3, 4]), - key=['a', 'a', 'b'], - expected_key_vocab=[b'a', b'b'], - expected_count=[4, 1], - expected_mean=[2.5, 4], - expected_var=[1.25, 0], - reduce_instance_dims=True, - input_signature=[ - tf.SparseTensorSpec([None, 4], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([3., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - key=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array(['a', 'a', 'b', 'a', 'b']), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - expected_key_vocab=[b'a', b'b'], - expected_count=[3, 2], - expected_mean=[3, 4], - expected_var=[np.float32(0.666667), 1.], - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32), - tf.RaggedTensorSpec([None, None, None, None], tf.string) - ]), - dict( - testcase_name='ragged_x_dense_key', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - 
values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([3., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - key=['a', 'b'], - expected_key_vocab=[b'a', b'b'], - expected_count=[4, 1], - expected_mean=[3, 5], - expected_var=[.5, 0.], - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([2, None, None, None], tf.float32), - tf.TensorSpec([2], tf.string) - ]), - dict( - testcase_name='ragged_with_nans', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([3., 2., 3., 4., 5., np.nan], - np.float32), - row_splits=np.array([0, 2, 3, 4, 6])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - key=['a', 'b'], - expected_key_vocab=[b'a', b'b'], - expected_count=[4, 1], - expected_mean=[3, 5], - expected_var=[.5, 0.], - reduce_instance_dims=True, - input_signature=[ - tf.RaggedTensorSpec([2, None, None, None], tf.float32), - tf.TensorSpec([2], tf.string) - ]), - ])) - def test_reduce_batch_count_mean_and_var_per_key( - self, x, key, input_signature, expected_key_vocab, expected_count, - expected_mean, expected_var, reduce_instance_dims, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_count_mean_and_var_per_key(x, key): - return tf_utils.reduce_batch_count_mean_and_var_per_key( - x, key, reduce_instance_dims=reduce_instance_dims) - - key_vocab, count, mean, var = _reduce_batch_count_mean_and_var_per_key( - x, key) - - self.assertAllEqual(key_vocab, expected_key_vocab) - self.assertAllEqual(count, expected_count) - self.assertAllEqual(mean, expected_mean) - self.assertAllEqual(var, expected_var) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [0, 2]], - values=[3, 2, -1], - dense_shape=[1, 5]), - expected_x_minus_min=1, - expected_x_max=3, - reduce_instance_dims=True, - input_signature=[tf.SparseTensorSpec([None, None], tf.int64)]), - dict( - testcase_name='float', - x=[[1, 5, 2]], - expected_x_minus_min=-1, - expected_x_max=5, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec([None, None], tf.float32)]), - dict( - testcase_name='sparse_float_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [1, 0]], - values=[3, 2, -1], - dense_shape=[2, 3]), - expected_x_minus_min=[1, -2, np.nan], - expected_x_max=[3, 2, np.nan], - reduce_instance_dims=False, - input_signature=[tf.SparseTensorSpec([None, None], tf.float32)]), - dict( - testcase_name='float_elementwise', - x=[[1, 5, 2], [2, 3, 4]], - reduce_instance_dims=False, - expected_x_minus_min=[-1, -3, -2], - expected_x_max=[2, 5, 4], - input_signature=[tf.TensorSpec([None, None], tf.float32)]), - dict( - testcase_name='sparse_int64_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [1, 0]], - values=[3, 2, -1], - dense_shape=[2, 3]), - reduce_instance_dims=False, - expected_x_minus_min=[1, -2, tf.int64.min + 1], - expected_x_max=[3, 2, tf.int64.min + 1], - input_signature=[tf.SparseTensorSpec([None, None], tf.int64)]), - dict( - testcase_name='sparse_int32_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [1, 0]], - values=[3, 2, -1], - dense_shape=[2, 3]), - reduce_instance_dims=False, - expected_x_minus_min=[1, -2, tf.int32.min + 1], - 
expected_x_max=[3, 2, tf.int32.min + 1], - input_signature=[tf.SparseTensorSpec([None, None], tf.int32)]), - dict( - testcase_name='sparse_float64_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [1, 0]], - values=[3, 2, -1], - dense_shape=[2, 3]), - reduce_instance_dims=False, - expected_x_minus_min=[1, -2, np.nan], - expected_x_max=[3, 2, np.nan], - input_signature=[tf.SparseTensorSpec([None, None], tf.float64)]), - dict( - testcase_name='sparse_float32_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [1, 0]], - values=[3, 2, -1], - dense_shape=[2, 3]), - reduce_instance_dims=False, - expected_x_minus_min=[1, -2, np.nan], - expected_x_max=[3, 2, np.nan], - input_signature=[tf.SparseTensorSpec([None, None], tf.float32)]), - dict( - testcase_name='sparse_3d_elementwise', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [0, 0, 1], [1, 0, 1]], - values=[3, 2, -1], - dense_shape=[2, 3, 3]), - reduce_instance_dims=False, - expected_x_minus_min=[[-3, 1, np.nan], [np.nan] * 3, - [np.nan] * 3], - expected_x_max=[[3, 2, np.nan], [np.nan] * 3, [np.nan] * 3], - input_signature=[ - tf.SparseTensorSpec([None, None, None], tf.float32) - ]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 5])), - row_splits=np.array([0, 2, 3])), - reduce_instance_dims=True, - expected_x_minus_min=-1., - expected_x_max=5., - input_signature=[ - tf.RaggedTensorSpec([2, None, None], tf.float32) - ]), - dict( - testcase_name='ragged_elementwise', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 2, 4, 5])), - row_splits=np.array([0, 3, 3, 4])), - row_splits=np.array([0, 2, 3])), - reduce_instance_dims=False, - expected_x_minus_min=[[[-1.0, -2.0], [np.nan, np.nan], - [-3.0, -4.0]], - [[np.nan, np.nan], [np.nan, np.nan], - [np.nan, np.nan]]], - expected_x_max=[[[5.0, 2.0], [np.nan, np.nan], [3.0, 4.0]], - [[np.nan, np.nan], [np.nan, np.nan], - [np.nan, np.nan]]], - input_signature=[ - tf.RaggedTensorSpec([2, None, None, None], tf.float32) - ]), - dict( - testcase_name='all_nans', - x=[[np.nan, np.nan, np.nan]], - # Output of `tf.reduce_max` if all inputs are NaNs for older - # versions of TF is -inf. - expected_x_minus_min=(-np.inf if version.parse(tf.__version__) < - version.parse('2.4') else np.nan), - expected_x_max=(-np.inf if version.parse(tf.__version__) < - version.parse('2.4') else np.nan), - reduce_instance_dims=True, - input_signature=[tf.TensorSpec([None, None], tf.float32)]), - dict( - testcase_name='empty_batch', - x=[[]], - expected_x_minus_min=-np.inf, - expected_x_max=-np.inf, - reduce_instance_dims=True, - input_signature=[tf.TensorSpec([None, None], tf.float32)]), - ])) - def test_reduce_batch_minus_min_and_max( - self, x, expected_x_minus_min, expected_x_max, reduce_instance_dims, - input_signature, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_minus_min_and_max(x): - result = tf_utils.reduce_batch_minus_min_and_max( - x, reduce_instance_dims=reduce_instance_dims) - # Verify that the output shapes are maintained. 
- if (not reduce_instance_dims and not isinstance(x, tf.RaggedTensor)): - for tensor in result: - self.assertEqual(x.get_shape()[1:].as_list(), - tensor.get_shape().as_list()) - return result - - x_minus_min, x_max = _reduce_batch_minus_min_and_max(x) - - self.assertAllEqual(x_minus_min, expected_x_minus_min) - self.assertAllEqual(x_max, expected_x_max) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='sparse', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 1], [2, 2], [3, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 5]), - key=['a', 'a', 'a', 'b'], - reduce_instance_dims=True, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[1, -3], - expected_x_max=[3, 3], - input_signature=[ - tf.SparseTensorSpec([None, None], tf.int64), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='float', - x=[[1], [5], [2], [3]], - key=['a', 'a', 'a', 'b'], - reduce_instance_dims=True, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[-1, -3], - expected_x_max=[5, 3], - input_signature=[ - tf.TensorSpec([None, None], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='float_elementwise', - x=[[1], [5], [2], [3]], - key=['a', 'a', 'a', 'b'], - reduce_instance_dims=False, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[[-1], [-3]], - expected_x_max=[[5], [3]], - input_signature=[ - tf.TensorSpec([None, None], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='float3dims', - x=[[[1, 5], [1, 1]], [[5, 1], [5, 5]], [[2, 2], [2, 5]], - [[3, -3], [3, 3]]], - key=['a', 'a', 'a', 'b'], - reduce_instance_dims=True, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[-1, 3], - expected_x_max=[5, 3], - input_signature=[ - tf.TensorSpec([None, None, None], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='float3dims_elementwise', - x=[[[1, 5], [1, 1]], [[5, 1], [5, 5]], [[2, 2], [2, 5]], - [[3, -3], [3, 3]]], - key=['a', 'a', 'a', 'b'], - reduce_instance_dims=False, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[[[-1, -1], [-1, -1]], [[-3, 3], [-3, -3]]], - expected_x_max=[[[5, 5], [5, 5]], [[3, -3], [3, 3]]], - input_signature=[ - tf.TensorSpec([None, None, None], tf.float32), - tf.TensorSpec([None], tf.string) - ]), - dict( - testcase_name='ragged', - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([3., 2., 3., 4., 5.], np.float32), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - key=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array(['a', 'a', 'b', 'a', 'b']), - row_splits=np.array([0, 2, 3, 4, 5])), - row_splits=np.array([0, 2, 3, 4])), - row_splits=np.array([0, 2, 3])), - reduce_instance_dims=True, - expected_key_vocab=[b'a', b'b'], - expected_x_minus_min=[-2., -3.], - expected_x_max=[4., 5.], - input_signature=[ - tf.RaggedTensorSpec([None, None, None, None], tf.float32), - tf.RaggedTensorSpec([None, None, None, None], tf.string) - ]), - ])) - def test_reduce_batch_minus_min_and_max_per_key( - self, x, key, reduce_instance_dims, expected_key_vocab, - expected_x_minus_min, expected_x_max, input_signature, function_handler): - - @function_handler(input_signature=input_signature) - def _reduce_batch_minus_min_and_max_per_key(x, key): - 
return tf_utils.reduce_batch_minus_min_and_max_per_key( - x, key, reduce_instance_dims=reduce_instance_dims) - - key_vocab, x_minus_min, x_max = _reduce_batch_minus_min_and_max_per_key( - x, key) - - self.assertAllEqual(key_vocab, expected_key_vocab) - self.assertAllEqual(x_minus_min, expected_x_minus_min) - self.assertAllEqual(x_max, expected_x_max) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='dense', - key=['a', 'a', 'a', 'b'], - spec=tf.TensorSpec([None], tf.string), - expected_key_vocab=[b'a', b'b'], - expected_count=[3, 1]), - dict( - testcase_name='sparse', - key=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 1], [2, 2], [3, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 5]), - spec=tf.SparseTensorSpec([4, 5], tf.int64), - expected_key_vocab=[b'3', b'2', b'-1'], - expected_count=[2, 1, 1]), - dict( - testcase_name='ragged', - key=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 2])), - spec=tf.RaggedTensorSpec([1, None, None], tf.float32), - expected_key_vocab=[b'1.200000', b'1.000000'], - expected_count=[2, 2]), - ])) - def test_reduce_batch_count_per_key(self, key, spec, expected_key_vocab, - expected_count, function_handler): - - @function_handler(input_signature=[spec]) - def _reduce_batch_count_per_key(key): - return tf_utils.reduce_batch_count_per_key(key) - - key_vocab, key_counts = _reduce_batch_count_per_key(key) - - self.assertAllEqual(key_vocab, expected_key_vocab) - self.assertAllEqual(key_counts, expected_count) - - @test_case.named_parameters(test_case.cross_with_function_handlers([ - dict( - testcase_name='full', - bucket_vocab=['1', '2', '0'], - counts=[3, 1, 4], - boundary_size=3, - expected_counts=[4, 3, 1]), - dict( - testcase_name='missing', - bucket_vocab=['1', '3', '0'], - counts=[3, 1, 4], - boundary_size=5, - expected_counts=[4, 3, 0, 1, 0]), - ])) - def test_reorder_histogram( - self, bucket_vocab, counts, boundary_size, - expected_counts, function_handler): - input_signature = [tf.TensorSpec([None], tf.string), - tf.TensorSpec([None], tf.int64), - tf.TensorSpec([], tf.int32)] - @function_handler(input_signature=input_signature) - def _reorder_histogram(bucket_vocab, counts, boundary_size): - return tf_utils.reorder_histogram(bucket_vocab, counts, boundary_size) - - counts = _reorder_histogram(bucket_vocab, counts, boundary_size) - self.assertAllEqual(counts, expected_counts) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='simple', - x=[0.0, 2.0, 3.5, 4.0], - x_spec=tf.TensorSpec([None], tf.float32), - boundaries=[[1.0, 2.0, 3.0, 3.9]], - boundaries_spec=tf.TensorSpec([1, None], tf.float32), - side=tf_utils.Side.LEFT, - expected_buckets=[0, 1, 3, 3]), - dict( - testcase_name='simple_right', - x=[0.0, 2.0, 3.5, 4.0], - x_spec=tf.TensorSpec([None], tf.float32), - boundaries=[1.0, 2.0, 3.0, 3.9], - boundaries_spec=tf.TensorSpec([None], tf.float32), - side=tf_utils.Side.RIGHT, - expected_buckets=[0, 2, 3, 4]), - dict( - testcase_name='2dim', - x=[[0.0, 4.0, 3.5, 2.0, 1.7]], - x_spec=tf.TensorSpec([1, None], tf.float32), - boundaries=[[1.0, 2.0, 3.0, 5.0]], - boundaries_spec=tf.TensorSpec([1, None], tf.float32), - side=tf_utils.Side.LEFT, - expected_buckets=[[0, 3, 3, 1, 1]]), - dict( - testcase_name='large_buckets', - x=[[50_000_000]], - x_spec=tf.TensorSpec([1, None], tf.int64), - boundaries=[0, 50_000_001, 
100_000_001], - boundaries_spec=tf.TensorSpec([None], tf.int64), - side=tf_utils.Side.RIGHT, - expected_buckets=[[1]]), - ])) - def test_assign_buckets(self, x, x_spec, boundaries, boundaries_spec, side, - expected_buckets, function_handler): - - @function_handler(input_signature=[x_spec, boundaries_spec]) - def _assign_buckets(x, boundaries): - return tf_utils.assign_buckets(x, boundaries, side) - - buckets = _assign_buckets(x, boundaries) - self.assertAllEqual(buckets, expected_buckets) - - def test_sparse_indices(self): - exception_cls = tf.errors.InvalidArgumentError - error_string = 'Condition x == y did not hold element-wise:' - value = tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 1], [2, 2], [3, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 5]) - key_value = tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 2], [2, 2], [3, 1]], - values=['a', 'a', 'a', 'b'], - dense_shape=[4, 5]) - with tf.compat.v1.Graph().as_default(): - x = tf.compat.v1.sparse_placeholder(tf.int64, shape=[None, None]) - key = tf.compat.v1.sparse_placeholder(tf.string, shape=[None, None]) - with tf.compat.v1.Session() as sess: - with self.assertRaisesRegex(exception_cls, error_string): - sess.run(tf_utils.reduce_batch_minus_min_and_max_per_key(x, key), - feed_dict={x: value, key: key_value}) - - def test_convert_sparse_indices(self): - exception_cls = tf.errors.InvalidArgumentError - error_string = 'Condition x == y did not hold element-wise:' - sparse = tf.SparseTensor( - indices=[[0, 0, 0], [1, 0, 1], [2, 0, 2], [3, 0, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 2, 5]) - dense = tf.constant(['a', 'b', 'c', 'd']) - x, key = tf_utils._validate_and_get_dense_value_key_inputs(sparse, sparse) - self.assertAllEqual(self.evaluate(x), sparse.values) - self.assertAllEqual(self.evaluate(key), sparse.values) - - x, key = tf_utils._validate_and_get_dense_value_key_inputs(sparse, dense) - self.assertAllEqual(self.evaluate(x), sparse.values) - self.assertAllEqual(self.evaluate(key), dense) - - with tf.compat.v1.Graph().as_default(): - sparse1 = tf.compat.v1.sparse_placeholder( - tf.int64, shape=[None, None, None]) - sparse2 = tf.compat.v1.sparse_placeholder( - tf.int64, shape=[None, None, None]) - sparse_value1 = tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [1, 0, 1], [2, 0, 2], [3, 0, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 2, 5]) - sparse_value2 = tf.compat.v1.SparseTensorValue( - indices=[[0, 0, 0], [1, 0, 2], [2, 0, 2], [3, 0, 1]], - values=[3, 2, -1, 3], - dense_shape=[4, 2, 5]) - - with tf.compat.v1.Session() as sess: - with self.assertRaisesRegex(exception_cls, error_string): - sess.run(tf_utils._validate_and_get_dense_value_key_inputs(sparse1, - sparse2), - feed_dict={sparse1: sparse_value1, sparse2: sparse_value2}) - - def test_convert_ragged_indices(self): - exception_cls = tf.errors.InvalidArgumentError - error_string = 'Condition x == y did not hold element-wise:' - ragged = tf.RaggedTensor.from_row_splits( - values=tf.RaggedTensor.from_row_splits( - values=np.array([1.2, 1., 1.2, 1.]), row_splits=np.array([0, 2, - 4])), - row_splits=np.array([0, 1, 2])) - dense = tf.constant(['a', 'b']) - dense_result = tf.constant(['a', 'a', 'b', 'b']) - x, key = tf_utils._validate_and_get_dense_value_key_inputs(ragged, ragged) - self.assertAllEqual(self.evaluate(x), ragged.flat_values) - self.assertAllEqual(self.evaluate(key), ragged.flat_values) - - x, key = tf_utils._validate_and_get_dense_value_key_inputs(ragged, dense) - self.assertAllEqual(self.evaluate(x), ragged.flat_values) - 
self.assertAllEqual(self.evaluate(key), dense_result) - - with tf.compat.v1.Graph().as_default(): - ragged1 = tf.compat.v1.ragged.placeholder(tf.float32, 2) - ragged2 = tf.compat.v1.ragged.placeholder(tf.float32, 2) - ragged_value1 = tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 2])) - ragged_value2 = tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 3, 4])), - row_splits=np.array([0, 2])) - - with tf.compat.v1.Session() as sess: - with self.assertRaisesRegex(exception_cls, error_string): - sess.run( - tf_utils._validate_and_get_dense_value_key_inputs( - ragged1, ragged2), - feed_dict={ - ragged1: ragged_value1, - ragged2: ragged_value2 - }) - - @test_case.named_parameters( - dict( - testcase_name='dense_tensor', - key=['b', 'a', 'b'], - key_vocab=['a', 'b'], - reductions=([1, 2], [3, 4]), - x=[5, 6, 7], - reduce_instance_dims=True, - expected_results=([2, 1, 2], [4, 3, 4])), - dict( - testcase_name='sparse_tensor_dense_key', - key=['b', 'a', 'b'], - key_vocab=['a', 'b'], - reductions=([1, 2], [3, 4]), - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 2], [2, 2], [2, 3]], - values=[3, 2, -1, 3], - dense_shape=[3, 5]), - reduce_instance_dims=True, - expected_results=([2, 1, 2, 2], [4, 3, 4, 4])), - dict( - testcase_name='sparse_tensor_sparse_key', - key=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 2], [2, 2], [2, 3]], - values=['b', 'a', 'b', 'b'], - dense_shape=[3, 5]), - key_vocab=['a', 'b'], - reductions=([1, 2], [3, 4]), - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [1, 2], [2, 2], [2, 3]], - values=[3, 2, -1, 3], - dense_shape=[3, 5]), - reduce_instance_dims=True, - expected_results=([2, 1, 2, 2], [4, 3, 4, 4])), - dict( - testcase_name='ragged_tensor_dense_key', - key=['a', 'b', 'a'], - key_vocab=['a', 'b'], - reductions=([1, 2], [3, 4]), - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 1, 2, 2])), - reduce_instance_dims=True, - expected_results=([1, 1, 2, 2], [3, 3, 4, 4])), - dict( - testcase_name='ragged_tensor_ragged_key', - key=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array(['a', 'b', 'b', 'a']), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 2])), - key_vocab=['a', 'b'], - reductions=([1, 2], [3, 4]), - x=tf.compat.v1.ragged.RaggedTensorValue( - values=tf.compat.v1.ragged.RaggedTensorValue( - values=np.array([1.2, 1., 1.2, 1.]), - row_splits=np.array([0, 2, 4])), - row_splits=np.array([0, 2])), - reduce_instance_dims=True, - expected_results=([1, 2, 2, 1], [3, 4, 4, 3])), - dict( - testcase_name='missing_key', - key=['b', 'a', 'c'], - key_vocab=['z', 'a', 'b'], - reductions=([-77, 1, 2], [-99, 3, 4]), - x=[5, 6, 7], - reduce_instance_dims=True, - expected_results=([2, 1, 0], [4, 3, 0])), - dict( - testcase_name='_dense_tensor_2d_elementwise', - key=['a'], - key_vocab=['a', 'b'], - reductions=([[1, 5], [-2, 0]], [[5, 9], [2, 4]]), - x=[[4, 8]], - reduce_instance_dims=False, - expected_results=([[1, 5]], [[5, 9]])), - dict( - testcase_name='_dense_tensor_3d_elementwise', - key=['a'], - key_vocab=['a', 'b'], - reductions=([[[1, 1], [1, 1]], [[3, -3], [3, 3]]], [[[5, 5], [5, 5]], - [[3, -3], [3, - 3]]]), - 
x=[[[1, 5], [1, 1]]], - reduce_instance_dims=False, - expected_results=([[[1, 1], [1, 1]]], [[[5, 5], [5, 5]]])), - ) - def test_map_per_key_reductions(self, key, key_vocab, reductions, x, - reduce_instance_dims, expected_results): - with tf.compat.v1.Graph().as_default(): - key = _value_to_tensor(key) - key_vocab = tf.constant(key_vocab) - reductions = tuple([tf.constant(t) for t in reductions]) - x = _value_to_tensor(x) - expected_results = tuple(tf.constant(t) for t in expected_results) - results = tf_utils.map_per_key_reductions(reductions, key, key_vocab, x, - reduce_instance_dims) - with tf.compat.v1.Session() as sess: - sess.run(tf.compat.v1.tables_initializer()) - output = sess.run(results) - for result, expected_result in zip(output, expected_results): - self.assertAllEqual(result, expected_result) - - @test_case.named_parameters(test_case.cross_with_function_handlers([ - dict( - testcase_name='sparse_tensor', - feature=tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [0, 2], [1, 0]], - values=[1., 2., 3., 4.], - dense_shape=[2, 5]), - input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)], - ascii_protos=[ - 'float_list { value: [1.0, 2.0, 3.0] }', - 'float_list { value: [4.0] }', - ]), - dict( - testcase_name='dense_scalar_int', - feature=[0, 1, 2], - input_signature=[tf.TensorSpec([None], tf.int64)], - ascii_protos=[ - 'int64_list { value: [0] }', - 'int64_list { value: [1] }', - 'int64_list { value: [2] }', - ]), - dict( - testcase_name='dense_scalar_float', - feature=[0.5, 1.5, 2.5], - input_signature=[tf.TensorSpec([None], tf.float32)], - ascii_protos=[ - 'float_list { value: [0.5] }', - 'float_list { value: [1.5] }', - 'float_list { value: [2.5] }', - ]), - dict( - testcase_name='dense_scalar_string', - feature=['hello', 'world'], - input_signature=[tf.TensorSpec([None], tf.string)], - ascii_protos=[ - 'bytes_list { value: "hello" }', - 'bytes_list { value: "world" }', - ]), - dict( - testcase_name='dense_vector_int', - feature=[[0, 1], [2, 3]], - input_signature=[tf.TensorSpec([None, 2], tf.int64)], - ascii_protos=[ - 'int64_list { value: [0, 1] }', - 'int64_list { value: [2, 3] }', - ]), - dict( - testcase_name='dense_matrix_int', - feature=[[[0, 1], [2, 3]], [[4, 5], [6, 7]]], - input_signature=[tf.TensorSpec([None, 2, 2], tf.int64)], - ascii_protos=[ - 'int64_list { value: [0, 1, 2, 3] }', - 'int64_list { value: [4, 5, 6, 7] }', - ]), - ])) - def test_serialize_feature( - self, feature, input_signature, ascii_protos, function_handler): - - @function_handler(input_signature=input_signature) - def _serialize_feature(feature): - return tf_utils._serialize_feature(feature) - - serialized_features = _serialize_feature(feature) - - self.assertEqual(len(ascii_protos), len(serialized_features)) - for ascii_proto, serialized_feature in zip(ascii_protos, - serialized_features): - feature_proto = tf.train.Feature() - feature_proto.ParseFromString(serialized_feature) - self.assertProtoEquals(ascii_proto, feature_proto) - - @test_case.named_parameters( - dict( - testcase_name='multiple_features', - examples={ - 'my_value': - tf.compat.v1.SparseTensorValue( - indices=[[0, 0], [0, 1], [0, 2], [1, 0]], - values=[1., 2., 3., 4.], - dense_shape=[2, 5]), - 'my_other_value': - np.array([1, 2], np.int64), - }, - ascii_protos=[ - """ + def _assertCompositeRefEqual(self, left, right): + """Asserts that a two `tf_util._CompositeTensorRef`s are equal.""" + self.assertEqual(left.type_spec, right.type_spec) + self.assertAllEqual(left.list_of_refs, right.list_of_refs) + + def 
test_copy_tensors_produces_different_tensors(self): + with tf.compat.v1.Graph().as_default(): + tensors = { + "dense": tf.compat.v1.placeholder( + tf.int64, (None,), name="my_dense_input" + ), + "sparse": tf.compat.v1.sparse_placeholder( + tf.int64, name="my_sparse_input" + ), + "ragged": tf.compat.v1.ragged.placeholder( + tf.int64, ragged_rank=2, name="my_ragged_input" + ), + } + copied_tensors = tf_utils.copy_tensors(tensors) + + self.assertNotEqual(tensors["dense"], copied_tensors["dense"]) + self.assertNotEqual( + tensors["sparse"].indices, copied_tensors["sparse"].indices + ) + self.assertNotEqual( + tensors["sparse"].values, copied_tensors["sparse"].values + ) + self.assertNotEqual( + tensors["sparse"].dense_shape, copied_tensors["sparse"].dense_shape + ) + self.assertNotEqual( + tensors["ragged"].values, copied_tensors["ragged"].values + ) + self.assertNotEqual( + tensors["ragged"].row_splits, copied_tensors["ragged"].row_splits + ) + + def test_copy_tensors_produces_equivalent_tensors(self): + with tf.compat.v1.Graph().as_default(): + tensors = { + "dense": tf.compat.v1.placeholder( + tf.int64, (None,), name="my_dense_input" + ), + "sparse": tf.compat.v1.sparse_placeholder( + tf.int64, name="my_sparse_input" + ), + "ragged": tf.compat.v1.ragged.placeholder( + tf.int64, ragged_rank=1, name="my_ragged_input" + ), + } + copied_tensors = tf_utils.copy_tensors(tensors) + + with tf.compat.v1.Session() as session: + dense_value = [1, 2] + sparse_value = tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1]], + values=[3, 4, 5], + dense_shape=[2, 3], + ) + ragged_value = tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([3, 4, 5], dtype=np.int64), + row_splits=np.array([0, 2, 3], dtype=np.int64), + ) + sample_tensors = session.run( + copied_tensors, + feed_dict={ + tensors["dense"]: dense_value, + tensors["sparse"]: sparse_value, + tensors["ragged"]: ragged_value, + }, + ) + self.assertAllEqual(sample_tensors["dense"], dense_value) + self.assertAllEqual( + sample_tensors["sparse"].indices, sparse_value.indices + ) + self.assertAllEqual( + sample_tensors["sparse"].values, sparse_value.values + ) + self.assertAllEqual( + sample_tensors["sparse"].dense_shape, sparse_value.dense_shape + ) + self.assertAllEqual( + sample_tensors["ragged"].values, ragged_value.values + ) + self.assertAllEqual( + sample_tensors["ragged"].row_splits, ragged_value.row_splits + ) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="2d", + tensor=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 2, 4]), + ), + rowids=[0, 0, 1, 1], + tensor_spec=tf.RaggedTensorSpec([None, None], tf.float32), + ), + dict( + testcase_name="3d", + tensor=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 3, 4]), + ), + row_splits=np.array([0, 1, 1, 2]), + ), + rowids=[0, 0, 0, 2], + tensor_spec=tf.RaggedTensorSpec([None, None, None], tf.float32), + ), + ] + ) + ) + def test_get_ragged_batch_value_rowids( + self, tensor, rowids, tensor_spec, function_handler + ): + @function_handler(input_signature=[tensor_spec]) + def get_ragged_batch_value_rowids(tensor): + return tf_utils._get_ragged_batch_value_rowids(tensor) + + self.assertAllEqual(get_ragged_batch_value_rowids(tensor), rowids) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="rank1", + x=["a", "b", "a"], 
+ x_spec=tf.TensorSpec(None, tf.string), + weights=[1, 1, 2], + filter_regex=None, + expected_unique_x=[b"a", b"b"], + expected_summed_weights_per_x=[3, 1], + ), + dict( + testcase_name="rank2", + x=[["a", "b\n", "a"], ["b\n", "a", "b\n"]], + x_spec=tf.TensorSpec(None, tf.string), + weights=[[1, 2, 1], [1, 2, 2]], + filter_regex=None, + expected_unique_x=[b"a", b"b\n"], + expected_summed_weights_per_x=[4, 5], + ), + dict( + testcase_name="rank3", + x=[ + [["a", "b", "a"], ["b", "a", "b"]], + [["a", "b", "a"], ["b", "a", "b"]], + ], + x_spec=tf.TensorSpec(None, tf.string), + weights=[[[1, 1, 2], [1, 2, 1]], [[1, 2, 1], [1, 2, 1]]], + filter_regex=None, + expected_unique_x=[b"a", b"b"], + expected_summed_weights_per_x=[9, 7], + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [2, 1]], + values=["a", "a", "b"], + dense_shape=[4, 2], + ), + x_spec=tf.SparseTensorSpec([4, 2], tf.string), + weights=[2, 3, 4], + filter_regex=None, + expected_unique_x=[b"a", b"b"], + expected_summed_weights_per_x=[5, 4], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( # pylint: disable=g-long-lambda + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array(["a", "b", "b", "a"]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 2]), + ), + x_spec=tf.RaggedTensorSpec([None, None, None], tf.string), + weights=[2, 3, 4, 6], + filter_regex=None, + expected_unique_x=[b"a", b"b"], + expected_summed_weights_per_x=[8, 7], + ), + dict( + testcase_name="regex_filtering", + x=[["a\n", "", "\n\r"], ["\r", "a", "b"]], + x_spec=tf.TensorSpec(None, tf.string), + weights=[[1, 2, 1], [1, 2, 2]], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + expected_unique_x=[b"a", b"b"], + expected_summed_weights_per_x=[2, 2], + ), + dict( + testcase_name="regex_filtering_invalid_utf8", + x=[ + [b"\xe1\n", b"\xa9", b"\n\xb8\r"], + [b"\xe8\r", b"\xc6", b"\n\xb3"], + ], + x_spec=tf.TensorSpec(None, tf.string), + weights=[[1, 3, 1], [1, 4, 2]], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + expected_unique_x=[b"\xa9", b"\xc6"], + expected_summed_weights_per_x=[3, 4], + ), + ] + ) + ) + def test_reduce_batch_weighted_counts( + self, + x, + x_spec, + weights, + filter_regex, + expected_unique_x, + expected_summed_weights_per_x, + function_handler, + ): + input_signature = [x_spec, tf.TensorSpec(None, tf.float32)] + + @function_handler(input_signature=input_signature) + def _reduce_batch_weighted_counts(x, weights): + ( + unique_x, + summed_weights_per_x, + summed_positive_per_x_and_y, + counts_per_x, + ) = tf_utils.reduce_batch_weighted_counts( + x, weights, filter_regex=filter_regex + ) + self.assertIsNone(summed_positive_per_x_and_y) + self.assertIsNone(counts_per_x) + return unique_x, summed_weights_per_x + + unique_x, summed_weights_per_x = _reduce_batch_weighted_counts(x, weights) + + self.assertAllEqual(unique_x, expected_unique_x) + self.assertAllEqual(summed_weights_per_x, expected_summed_weights_per_x) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="rank1", + x=["a", "b", "a"], + filter_regex=None, + expected_result=[b"a", b"b", b"a"], + ), + dict( + testcase_name="rank2", + x=[["a", "b\r", "a"], ["b\r", "a", "b\r"]], + filter_regex=None, + expected_result=[b"a", b"b\r", b"a", b"b\r", b"a", b"b\r"], + ), + dict( + testcase_name="rank3", + x=[ + [["a", "b", "a"], ["b", "a", "b"]], + [["a", "b", "a"], ["b", "a", "b"]], + ], + filter_regex=None, + 
expected_result=[ + b"a", + b"b", + b"a", + b"b", + b"a", + b"b", + b"a", + b"b", + b"a", + b"b", + b"a", + b"b", + ], + ), + dict( + testcase_name="regex_filtering_empty_result", + x=["a\n\r", "b\n", "a\r", "", "a\rsd", " \r", "\nas"], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + expected_result=[], + ), + ] + ) + ) + def test_reduce_batch_weighted_counts_weights_none( + self, x, filter_regex, expected_result, function_handler + ): + input_signature = [tf.TensorSpec(None, tf.string)] + + @function_handler(input_signature=input_signature) + def _reduce_batch_weighted_counts(x): + ( + unique_x, + summed_weights_per_x, + summed_positive_per_x_and_y, + counts_per_x, + ) = tf_utils.reduce_batch_weighted_counts( + x, force=False, filter_regex=filter_regex + ) + self.assertIsNone(summed_weights_per_x) + self.assertIsNone(summed_positive_per_x_and_y) + self.assertIsNone(counts_per_x) + return unique_x + + unique_x = _reduce_batch_weighted_counts(x) + self.assertAllEqual(unique_x, expected_result) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="rank1", + x=["a", "b", "a"], + filter_regex=None, + expected_result=([b"a", b"b"], [2, 1]), + ), + dict( + testcase_name="rank3", + x=[ + [["a", "b", "a"], ["b", "a", "b"]], + [["a", "b", "a"], ["b", "a", "b"]], + ], + filter_regex=None, + expected_result=([b"a", b"b"], [6, 6]), + ), + dict( + testcase_name="regex_filtering", + x=["a\n\r", "b\n", "a\r", "", "asd", " ", "\nas"], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + expected_result=([b"asd", b" "], [1, 1]), + ), + dict( + testcase_name="regex_filtering_empty_result", + x=["a\n\r", "b\n", "a\r", "", "a\rsd", " \r", "\nas"], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + expected_result=([], []), + ), + ] + ) + ) + def test_reduce_batch_weighted_counts_weights_none_force( + self, x, filter_regex, expected_result, function_handler + ): + input_signature = [tf.TensorSpec(None, tf.string)] + + @function_handler(input_signature=input_signature) + def _reduce_batch_weighted_counts(x): + ( + unique_x, + summed_weights_per_x, + summed_positive_per_x_and_y, + counts_per_x, + ) = tf_utils.reduce_batch_weighted_counts( + x, force=True, filter_regex=filter_regex + ) + self.assertIsNone(summed_weights_per_x) + self.assertIsNone(summed_positive_per_x_and_y) + return unique_x, counts_per_x + + expected_unique_x, expected_counts_per_x = expected_result + unique_x, counts_per_x = _reduce_batch_weighted_counts(x) + self.assertAllEqual(unique_x, expected_unique_x) + self.assertAllEqual(counts_per_x, expected_counts_per_x) + + @test_case.named_parameters( + [ + dict(testcase_name="constant", get_value_fn=lambda: tf.constant([1.618])), + dict(testcase_name="op", get_value_fn=lambda: tf.identity), + dict(testcase_name="int", get_value_fn=lambda: 4), + dict(testcase_name="object", get_value_fn=object), + dict( + testcase_name="sparse", + get_value_fn=lambda: tf.SparseTensor( # pylint: disable=g-long-lambda + indices=[[0, 0], [2, 1]], values=["a", "b"], dense_shape=[4, 2] + ), + ), + dict( + testcase_name="ragged", + get_value_fn=lambda: tf.RaggedTensor.from_row_splits( # pylint: disable=g-long-lambda + values=["a", "b"], row_splits=[0, 1, 2] + ), + ), + dict( + testcase_name="ragged_multi_dimension", + get_value_fn=lambda: tf.RaggedTensor.from_row_splits( # pylint: disable=g-long-lambda + values=tf.RaggedTensor.from_row_splits( + values=[[0, 1], [2, 3]], row_splits=[0, 1, 2] + ), + row_splits=[0, 2], + ), + ), + 
] + ) + def test_hashable_tensor_or_op(self, get_value_fn): + with tf.compat.v1.Graph().as_default(): + input_value = get_value_fn() + input_ref = tf_utils.hashable_tensor_or_op(input_value) + input_dict = {input_ref: input_value} + input_deref = tf_utils.deref_tensor_or_op(input_ref) + if isinstance(input_value, composite_tensor.CompositeTensor): + self._assertCompositeRefEqual( + input_ref, tf_utils.hashable_tensor_or_op(input_deref) + ) + else: + self.assertAllEqual( + input_ref, tf_utils.hashable_tensor_or_op(input_deref) + ) + + if isinstance(input_value, tf.SparseTensor): + input_deref = input_deref.values + input_dict[input_ref] = input_dict[input_ref].values + input_value = input_value.values + + self.assertAllEqual(input_value, input_deref) + self.assertAllEqual(input_value, input_dict[input_ref]) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="rank1_with_weights_and_binary_y", + x=["a", "b", "a"], + weights=[1, 1, 2], + y=[0, 1, 1], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [3, 1, 4], + [[1, 2], [0, 1], [1, 3]], + [2, 1, 3], + ), + filter_regex=None, + ), + dict( + testcase_name="rank1_with_weights_and_multi_class_y", + x=["a", "b\n", "a", "a"], + weights=[1, 1, 2, 2], + y=[0, 2, 1, 1], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b\n", b"global_y_count_sentinel"], + [5, 1, 6], + [[1, 4, 0], [0, 0, 1], [1, 4, 1]], + [3, 1, 4], + ), + filter_regex=None, + ), + dict( + testcase_name="rank1_with_weights_and_missing_y_values", + x=["a", "b", "a", "a"], + weights=[1, 1, 2, 2], + y=[3, 5, 6, 6], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [5, 1, 6], + [ + [0, 0, 0, 1, 0, 0, 4], + [0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 1, 4], + ], + [3, 1, 4], + ), + filter_regex=None, + ), + dict( + testcase_name="rank2_with_weights_and_binary_y", + x=[["a", "b", "a"], ["b", "a", "b"]], + weights=[[1, 2, 1], [1, 2, 2]], + y=[[1, 0, 1], [1, 0, 0]], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [4, 5, 9], + [[2, 2], [4, 1], [6, 3]], + [3, 3, 6], + ), + filter_regex=None, + ), + dict( + testcase_name="rank3_with_weights_and_binary_y", + x=[ + [["a", "b", "a"], ["b", "a", "b"]], + [["a", "b", "a"], ["b", "a", "b"]], + ], + weights=[[[1, 1, 2], [1, 2, 1]], [[1, 2, 1], [1, 2, 1]]], + y=[[[1, 1, 0], [1, 0, 1]], [[1, 0, 1], [1, 0, 1]]], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [9, 7, 16], + [[6, 3], [2, 5], [8, 8]], + [6, 6, 12], + ), + filter_regex=None, + ), + dict( + testcase_name="rank1_with_weights_multi_class_y_and_filtering", + x=["\na\r", "", "\na\r", "a", ""], + weights=[1, 1, 2, 2, 3], + y=[0, 2, 1, 1, 2], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"global_y_count_sentinel"], + [2, 9], + [[0, 2, 0], [1, 4, 4]], + [1, 5], + ), + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + ), + dict( + testcase_name="rank1_with_weights_filtering_empty_result", + x=["\na\r", "", "\na\r", "\ra", ""], + weights=[1, 1, 2, 2, 3], + y=[0, 2, 1, 1, 2], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"global_y_count_sentinel"], [9], [[1, 4, 4]], [5] + ), + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + ), + ] + ) + ) + def test_reduce_batch_coocurrences( + self, x, weights, y, expected_result, filter_regex, function_handler + ): + input_signature = [ + 
tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + tf.TensorSpec(None, tf.int64), + ] + + @function_handler(input_signature=input_signature) + def _reduce_batch_weighted_cooccurrences(x, y, weights): + return tf_utils.reduce_batch_weighted_cooccurrences( + x, y, weights, filter_regex=filter_regex + ) + + result = _reduce_batch_weighted_cooccurrences(x, y, weights) + + self.assertAllEqual(result.unique_x, expected_result.unique_x) + self.assertAllEqual( + result.summed_weights_per_x, expected_result.summed_weights_per_x + ) + self.assertAllEqual( + result.summed_positive_per_x_and_y, + expected_result.summed_positive_per_x_and_y, + ) + self.assertAllEqual(result.counts_per_x, expected_result.counts_per_x) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="rank1_with_binary_y", + x=["a", "b", "a"], + y=[0, 1, 1], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [2, 1, 3], + [[1, 1], [0, 1], [1, 2]], + [2, 1, 3], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank1_with_multi_class_y", + x=["yes", "no", "yes", "may\rbe", "yes"], + y=[1, 1, 0, 2, 3], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"yes", b"no", b"may\rbe", b"global_y_count_sentinel"], + [3, 1, 1, 5], + [[1, 1, 0, 1], [0, 1, 0, 0], [0, 0, 1, 0], [1, 2, 1, 1]], + [3, 1, 1, 5], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank1_with_sparse_y", + x=["yes", "no", "yes", "may\rbe", "yes"], + # 5 examples, 4 labels: + # 0: (3,2) + # 1: (1) + # 2: (2,0) + # 3: (1) + # 4: (2, 1) + y=tf.compat.v1.SparseTensorValue( + indices=( + (0, 0), + (0, 1), + (1, 0), + (2, 0), + (2, 1), + (3, 1), + (4, 0), + (4, 1), + ), + values=[3, 2, 1, 2, 0, 1, 2, 1], + dense_shape=[5, 2], + ), + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"yes", b"no", b"may\rbe", b"global_y_count_sentinel"], + [3, 1, 1, 5], + [[1, 1, 3, 1], [0, 1, 0, 0], [0, 1, 0, 0], [1, 3, 3, 1]], + [3, 1, 1, 5], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.SparseTensorSpec([None, 2], tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank1_with_sparse_y_missing_labels", + x=["yes", "no", "yes", "may\rbe", "yes"], + # 5 examples, 4 labels: + # 0: (3,2) + # 1: () + # 2: (2,0) + # 3: (1) + # 4: (2, 1) + y=tf.compat.v1.SparseTensorValue( + indices=( + (0, 0), + (0, 1), + (2, 0), + (2, 1), + (3, 1), + (4, 0), + (4, 1), + ), + values=[3, 2, 2, 0, 1, 2, 1], + dense_shape=[5, 2], + ), + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"yes", b"no", b"may\rbe", b"global_y_count_sentinel"], + [3, 1, 1, 5], + [[1, 1, 3, 1], [0, 0, 0, 0], [0, 1, 0, 0], [1, 2, 3, 1]], + [3, 1, 1, 5], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.SparseTensorSpec([None, 2], tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank2_with_binary_y", + x=[["a", "b", "a"], ["b", "a", "b"]], + y=[[1, 0, 1], [1, 0, 0]], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [3, 3, 6], + [[1, 2], [2, 1], [3, 3]], + [3, 3, 6], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank2_with_missing_y_values", + x=[["a", "b", "a"], ["b", "a", "b"]], + y=[[2, 0, 2], [2, 0, 
0]], + # The label 1 isn't in the batch but it will have a position (with + # weights of 0) in the resulting array. + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [3, 3, 6], + [[1, 0, 2], [2, 0, 1], [3, 0, 3]], + [3, 3, 6], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank2_with_multi_class_y", + x=[["a", "b", "a"], ["b", "a", "b"]], + y=[[1, 0, 1], [1, 0, 2]], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [3, 3, 6], + [[1, 2, 0], [1, 1, 1], [2, 3, 1]], + [3, 3, 6], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank3_with_binary_y", + x=[ + [["a", "b", "a"], ["b", "a", "b"]], + [["a", "b", "a"], ["b", "a", "b"]], + ], + y=[[[1, 1, 0], [1, 0, 1]], [[1, 0, 1], [1, 0, 1]]], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [6, 6, 12], + [[3, 3], [1, 5], [4, 8]], + [6, 6, 12], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [2, 1]], + values=["a", "b"], + dense_shape=[4, 2], + ), + y=[0, 1, 0, 0], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [1, 1, 4], + [[1, 0], [1, 0], [3, 1]], + [1, 1, 4], + ), + input_signature=[ + tf.SparseTensorSpec([None, 2], tf.string), + tf.TensorSpec([None], tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="empty_sparse", + x=tf.compat.v1.SparseTensorValue( + indices=np.empty([0, 2]), values=[], dense_shape=[4, 2] + ), + y=[1, 0, 1, 1], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"global_y_count_sentinel"], [4], [[1, 3]], [4] + ), + input_signature=[ + tf.SparseTensorSpec([None, 2], tf.string), + tf.TensorSpec([None], tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array(["a", "b", "a", "b", "b"]), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + y=[1, 0], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"a", b"b", b"global_y_count_sentinel"], + [2, 3, 2], + [[0, 2], [1, 2], [1, 1]], + [2, 3, 2], + ), + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.string), + tf.TensorSpec([None], tf.int64), + ], + filter_regex=None, + ), + dict( + testcase_name="rank1_with_filtering", + x=["yes\n", "no", "yes\n", "", "yes\n"], + y=[1, 1, 0, 2, 3], + expected_result=tf_utils.ReducedBatchWeightedCounts( + [b"no", b"global_y_count_sentinel"], + [1, 5], + [[0, 1, 0, 0], [1, 2, 1, 1]], + [1, 5], + ), + input_signature=[ + tf.TensorSpec(None, tf.string), + tf.TensorSpec(None, tf.int64), + ], + filter_regex=analyzers._EMPTY_STRING_OR_NEWLINE_CHARS_REGEX, + ), + ] + ) + ) + def test_reduce_batch_coocurrences_no_weights( + self, x, y, expected_result, input_signature, filter_regex, function_handler + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_weighted_cooccurrences_no_weights(x, y): + return tf_utils.reduce_batch_weighted_cooccurrences( + x, y, 
filter_regex=filter_regex + ) + + result = _reduce_batch_weighted_cooccurrences_no_weights(x, y) + + self.assertAllEqual(result.unique_x, expected_result.unique_x) + self.assertAllEqual( + result.summed_weights_per_x, expected_result.summed_weights_per_x + ) + self.assertAllEqual( + result.summed_positive_per_x_and_y, + expected_result.summed_positive_per_x_and_y, + ) + self.assertAllEqual(result.counts_per_x, expected_result.counts_per_x) + + @test_case.parameters( + ( + [[1], [2]], + [[1], [2], [3]], + None, + None, + tf.errors.InvalidArgumentError, + "Condition x == y did not hold element-wise:", + ), + ( + [[1], [2], [3]], + [[1], [2], [3]], + [None, None], + [None], + ValueError, + r"Shapes \(None, None\) and \(None,\) are incompatible", + ), + ) + def test_same_shape_exceptions( + self, x_input, y_input, x_shape, y_shape, exception_cls, error_string + ): + with tf.compat.v1.Graph().as_default(): + x = tf.compat.v1.placeholder(tf.int32, x_shape) + y = tf.compat.v1.placeholder(tf.int32, y_shape) + with tf.compat.v1.Session() as sess: + with self.assertRaisesRegex(exception_cls, error_string): + sess.run(tf_utils.assert_same_shape(x, y), {x: x_input, y: y_input}) + + @test_case.named_parameters(test_case.FUNCTION_HANDLERS) + def test_same_shape(self, function_handler): + input_signature = [tf.TensorSpec(None, tf.int64), tf.TensorSpec(None, tf.int64)] + + @function_handler(input_signature=input_signature) + def _assert_shape(x, y): + x_return, _ = tf_utils.assert_same_shape(x, y) + return x_return + + input_list = [[1], [2], [3]] + x_return = _assert_shape(input_list, input_list) + self.assertAllEqual(x_return, input_list) + + @test_case.named_parameters( + [ + dict( + testcase_name="_all_keys_in_vocab", + query_list=["a", "a", "b", "a", "b"], + key_vocab_list=["a", "b"], + query_shape=[None], + expected_output=[0, 0, 1, 0, 1], + ), + dict( + testcase_name="_missing_keys_in_vocab", + query_list=["a", "c", "b", "a", "b"], + key_vocab_list=["a", "b"], + query_shape=[None], + expected_output=[0, -1, 1, 0, 1], + ), + dict( + testcase_name="_nd_keys", + query_list=[["a", "c", "b"], ["a", "b", "a"]], + key_vocab_list=["a", "b"], + query_shape=[None, None], + expected_output=[[0, -1, 1], [0, 1, 0]], + ), + dict( + testcase_name="_empty_vocab", + query_list=["a", "c", "b", "a", "b"], + key_vocab_list=[], + query_shape=[None], + expected_output=[-1, -1, -1, -1, -1], + ), + dict( + testcase_name="_empty_query", + query_list=[], + key_vocab_list=["a"], + query_shape=[None], + expected_output=[], + ), + ] + ) + def test_lookup_key(self, query_list, key_vocab_list, query_shape, expected_output): + with tf.compat.v1.Graph().as_default(): + query_ph = tf.compat.v1.placeholder( + dtype=tf.string, shape=query_shape, name="query" + ) + key_vocab_ph = tf.compat.v1.placeholder( + dtype=tf.string, shape=[None], name="key_vocab" + ) + key_indices = tf_utils.lookup_key(query_ph, key_vocab_ph) + with tf.compat.v1.Session().as_default() as sess: + output = sess.run( + key_indices, + feed_dict={ + query_ph.name: query_list, + key_vocab_ph.name: key_vocab_list, + }, + ) + self.assertAllEqual(expected_output, output) + + @test_case.named_parameters( + [ + dict( + testcase_name="_with_default", + with_default_value=True, + input_keys=["a", "b", "c", "d", "e"], + ), + dict( + testcase_name="_wihout_default", + with_default_value=False, + input_keys=["a", "b", "c", "d", "e"], + ), + dict( + testcase_name="_single_oov_key", + with_default_value=False, + input_keys=["e"], + ), + ] + ) + def 
test_apply_per_key_vocab(self, with_default_value, input_keys): + default_value = "-7,-5" if with_default_value else None + vocab_data = [("0,0", "a"), ("1,-1", "b"), ("-1,1", "c"), ("-2,2", "d")] + expected_missing_key_result = [-7, -5] if default_value else [0, 0] + expected_lookup_results = { + "a": [0, 0], + "b": [1, -1], + "c": [-1, 1], + "d": [-2, 2], + } + + with tf.compat.v1.Graph().as_default(): + input_tensor = _value_to_tensor(input_keys) + vocab_filename = os.path.join(self.get_temp_dir(), "test.txt") + encoded_vocab = "\n".join([" ".join(pair) for pair in vocab_data]) + with tf.io.gfile.GFile(vocab_filename, "w") as f: + f.write(encoded_vocab) + + output_tensor = tf_utils.apply_per_key_vocabulary( + tf.constant(vocab_filename), input_tensor, default_value=default_value + ) + + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.tables_initializer()) + output = output_tensor.eval() + + expected_data = [ + expected_lookup_results.get(key, expected_missing_key_result) + for key in input_keys + ] + self.assertAllEqual(output, expected_data) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="dense", + x=[[[1], [2]], [[1], [2]]], + expected_result=4, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.int64)], + ), + dict( + testcase_name="dense_with_nans", + x=[[[1], [np.nan]], [[1], [2]]], + expected_result=3, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_elementwise", + x=[[[1], [2]], [[1], [2]]], + expected_result=[[2], [2]], + reduce_instance_dims=False, + input_signature=[tf.TensorSpec(None, tf.int64)], + ), + dict( + testcase_name="dense_elementwise_with_nans", + x=[[[1], [2]], [[1], [np.nan]]], + expected_result=[[2], [1]], + reduce_instance_dims=False, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 4, 1], + ), + expected_result=4, + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)], + ), + dict( + testcase_name="sparse_with_nans", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0]], + values=[1.0, 2.0, 3.0, 4.0, np.nan], + dense_shape=[2, 4, 1], + ), + expected_result=4, + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)], + ), + dict( + testcase_name="sparse_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 4, 1], + ), + expected_result=[[1], [1], [2], [0]], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)], + ), + dict( + testcase_name="sparse_elementwise_with_nans", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 2, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0]], + values=[1.0, 2.0, 3.0, 4.0, np.nan], + dense_shape=[2, 4, 1], + ), + expected_result=[[1], [1], [2], [0]], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 4, 1], tf.float32)], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + 
row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_result=5, + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_with_nans", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.float32 + ), + row_splits=np.array([0, 2, 3, 4, 6]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_result=5, + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_elementwise", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 2, 4, 5]), + ), + row_splits=np.array([0, 3, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_result=[ + [[2, 1], [0.0, 0], [1, 1]], + [[0, 0], [0, 0], [0, 0]], + ], + reduce_instance_dims=False, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_elementwise_with_nans", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.float32 + ), + row_splits=np.array([0, 2, 2, 4, 6]), + ), + row_splits=np.array([0, 3, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_result=[ + [[2, 1], [0.0, 0], [1, 1]], + [[0, 0], [0, 0], [0, 0]], + ], + reduce_instance_dims=False, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + ] + ) + ) + def test_reduce_batch_count( + self, + x, + input_signature, + expected_result, + reduce_instance_dims, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_count(x): + result = tf_utils.reduce_batch_count( + x, reduce_instance_dims=reduce_instance_dims + ) + # Verify that the output shape is maintained. + # TODO(b/178189903): This will fail if _dense_shape_default isn't set in + # reduce_batch_count. 
+ if ( + not isinstance(x, tf.RaggedTensor) + and not reduce_instance_dims + and x.get_shape().ndims + ): + self.assertEqual( + x.get_shape()[1:].as_list(), result.get_shape().as_list() + ) + return result + + result = _reduce_batch_count(x) + self.assertAllEqual(result, expected_result) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="dense", + x=[[[1], [2]], [[3], [4]]], + expected_count=4, + expected_mean=2.5, + expected_var=1.25, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_with_nans", + x=[[[1], [2]], [[3], [np.nan]], [[np.nan], [4]]], + expected_count=4, + expected_mean=2.5, + expected_var=1.25, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_elementwise", + x=[[[1], [2]], [[3], [4]]], + expected_count=[[2.0], [2.0]], + expected_mean=[[2.0], [3.0]], + expected_var=[[1.0], [1.0]], + reduce_instance_dims=False, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_elementwise_with_nans", + x=[[[1], [2]], [[3], [np.nan]], [[np.nan], [4]]], + expected_count=[[2.0], [2.0]], + expected_mean=[[2.0], [3.0]], + expected_var=[[1.0], [1.0]], + reduce_instance_dims=False, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 4], + ), + expected_count=4, + expected_mean=2.5, + expected_var=1.25, + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)], + ), + dict( + testcase_name="sparse_with_nans", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [1, 3]], + values=[1.0, 2.0, 3.0, 4.0, np.nan], + dense_shape=[2, 4], + ), + expected_count=4, + expected_mean=2.5, + expected_var=1.25, + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)], + ), + dict( + testcase_name="sparse_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 3], [1, 1], [1, 3]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 5], + ), + expected_count=[1.0, 1.0, 0.0, 2.0, 0.0], + expected_mean=[1.0, 3.0, 0.0, 3.0, 0.0], + expected_var=[0.0, 0.0, 0.0, 1.0, 0.0], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)], + ), + dict( + testcase_name="sparse_elementwise_with_nans", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 3], [1, 1], [1, 2], [1, 3]], + values=[1.0, 2.0, 3.0, np.nan, 4.0], + dense_shape=[2, 5], + ), + expected_count=[1.0, 1.0, 0.0, 2.0, 0.0], + expected_mean=[1.0, 3.0, 0.0, 3.0, 0.0], + expected_var=[0.0, 0.0, 0.0, 1.0, 0.0], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)], + ), + dict( + testcase_name="sparse_3d_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 3], [0, 1, 0], [0, 1, 3], [1, 1, 1], [1, 1, 3]], + values=[-10.0, 1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 3, 5], + ), + expected_count=[[0, 0, 0, 1, 0], [1, 1, 0, 2, 0], [0] * 5], + expected_mean=[[0, 0, 0, -10, 0], [1, 3, 0, 3, 0], [0] * 5], + expected_var=[[0] * 5, [0, 0, 0, 1, 0], [0] * 5], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 3, 5], tf.float32)], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + 
values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_count=5, + expected_mean=3, + expected_var=2, + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_with_nans", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.float32 + ), + row_splits=np.array([0, 2, 3, 4, 6]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_count=5, + expected_mean=3, + expected_var=2, + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_elementwise", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 2, 4, 5]), + ), + row_splits=np.array([0, 3, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_count=[ + [[2.0, 1.0], [0.0, 0.0], [1.0, 1.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + expected_mean=[ + [[3.0, 2.0], [0.0, 0.0], [3.0, 4.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + expected_var=[ + [[4.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + reduce_instance_dims=False, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged_elementwise_with_nans", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array( + [1.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.float32 + ), + row_splits=np.array([0, 2, 2, 4, 6]), + ), + row_splits=np.array([0, 3, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_count=[ + [[2.0, 1.0], [0.0, 0.0], [1.0, 1.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + expected_mean=[ + [[3.0, 2.0], [0.0, 0.0], [3.0, 4.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + expected_var=[ + [[4.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]], + ], + reduce_instance_dims=False, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + ] + ) + ) + def test_reduce_batch_count_mean_and_var( + self, + x, + input_signature, + expected_count, + expected_mean, + expected_var, + reduce_instance_dims, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_count_mean_and_var(x): + result = tf_utils.reduce_batch_count_mean_and_var( + x, reduce_instance_dims=reduce_instance_dims + ) + # Verify that the output shapes are maintained. + # TODO(b/178189903): This will fail if _dense_shape_default isn't set in + # reduce_batch_count. 
+ if ( + not isinstance(x, tf.RaggedTensor) + and not reduce_instance_dims + and x.get_shape().ndims + ): + for tensor in result: + self.assertEqual( + x.get_shape()[1:].as_list(), tensor.get_shape().as_list() + ) + return result + + count, mean, var = _reduce_batch_count_mean_and_var(x) + self.assertAllEqual(expected_count, count) + self.assertAllEqual(expected_mean, mean) + self.assertAllEqual(expected_var, var) + + @test_case.named_parameters( + [ + dict( + testcase_name="num_samples_1", + num_samples=1, + dtype=tf.float32, + expected_counts=np.array([1, 0, 0, 0], np.float32), + expected_factors=np.array([[1.0], [0.0], [0.0], [0.0]], np.float32), + ), + dict( + testcase_name="num_samples_2", + num_samples=2, + dtype=tf.float32, + expected_counts=np.array([2, 1, 0, 0], np.float32), + expected_factors=np.array( + [ + [1.0 / 2.0, 1.0 / 2.0], + [-1.0 / 2.0, 1.0 / 2.0], + [0.0, 0.0], + [0.0, 0.0], + ], + np.float32, + ), + ), + dict( + testcase_name="num_samples_3", + num_samples=3, + dtype=tf.float32, + expected_counts=np.array([3, 3, 1, 0], np.float32), + expected_factors=np.array( + [ + [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0], + [-1.0 / 3.0, 0.0, 1.0 / 3.0], + [1.0 / 3.0, -2.0 / 3.0, 1.0 / 3.0], + [0.0, 0.0, 0.0], + ], + np.float32, + ), + ), + dict( + testcase_name="num_samples_4", + num_samples=4, + dtype=tf.float32, + expected_counts=np.array([4, 6, 4, 1], np.float32), + expected_factors=np.array( + [ + [1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0, 1.0 / 4.0], + [-3.0 / 12.0, -1.0 / 12.0, 1.0 / 12.0, 3.0 / 12.0], + [1.0 / 4.0, -1.0 / 4.0, -1.0 / 4.0, 1.0 / 4.0], + [-1.0 / 4.0, 3.0 / 4.0, -3.0 / 4.0, 1.0 / 4.0], + ], + np.float32, + ), + ), + ] + ) + def test_num_terms_and_factors( + self, num_samples, dtype, expected_counts, expected_factors + ): + results = tf_utils._num_terms_and_factors(num_samples, dtype) + counts = results[0:4] + assert len(expected_counts) == len(counts), (expected_counts, counts) + for result, expected_count in zip(counts, expected_counts): + self.assertEqual(result.dtype, dtype) + self.assertAllClose(result, expected_count) + + factors = results[4:] + assert len(expected_factors) == len(factors), (expected_factors, factors) + for result, expected_factor in zip(factors, expected_factors): + self.assertEqual(result.dtype, dtype) + self.assertAllClose(result, expected_factor) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="dense", + x=[[[1], [2]], [[3], [4]]], + expected_counts=np.array([4.0, 6.0, 4.0, 1.0], np.float32), + expected_moments=np.array([2.5, 10.0 / 12.0, 0.0, 0.0], np.float32), + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_large", + x=[2.0, 3.0, 4.0, 2.4, 5.5, 1.2, 5.4, 2.2, 7.1, 1.3, 1.5], + expected_counts=np.array( + [11, 11 * 10 // 2, 11 * 10 * 9 // 6, 11 * 10 * 9 * 8 // 24], + np.float32, + ), + expected_moments=np.array( + [ + 3.2363636363636363, + 1.141818181818182, + 0.31272727272727263, + 0.026666666666666616, + ], + np.float32, + ), + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="dense_very_large", + x=-np.log(1.0 - np.arange(0, 1, 1e-6, dtype=np.float32)), + expected_counts=np.array( + [1000000, 499999500000.0, 1.66666166667e17, 4.1666416667125e22], + np.float32, + ), + expected_moments=np.array( + [ + 0.99999217330, + 0.4999936732947, + 0.166660839941, + 0.0833278399134, + ], + np.float32, + ), + reduce_instance_dims=True, + input_signature=[tf.TensorSpec(None, 
tf.float32)], + ), + dict( + testcase_name="dense_elementwise", + x=[[[1], [2]], [[3], [4]]], + expected_counts=np.array( + [[[2], [2]], [[1], [1]], [[0], [0]], [[0], [0]]], np.float32 + ), + expected_moments=np.array( + [ + [[2.0], [3.0]], + [[1.0], [1.0]], + [[0.0], [0.0]], + [[0.0], [0.0]], + ], + np.float32, + ), + reduce_instance_dims=False, + input_signature=[tf.TensorSpec(None, tf.float32)], + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [2, 0], [2, 2]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[3, 4], + ), + expected_counts=np.array([4, 6, 4, 1], np.float32), + expected_moments=np.array([2.5, 10.0 / 12.0, 0.0, 0.0], np.float32), + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, 4], tf.float32)], + ), + dict( + testcase_name="sparse_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 2, 0], [2, 0, 0], [2, 2, 0], [3, 3, 0]], + values=[1.0, 2.0, 3.0, 4.0, 5.0], + dense_shape=[3, 5, 1], + ), + expected_counts=np.array( + [ + [[2], [0], [2], [1], [0]], + [[1], [0], [1], [0], [0]], + [[0], [0], [0], [0], [0]], + [[0], [0], [0], [0], [0]], + ], + np.float32, + ), + expected_moments=np.array( + [ + [[2.0], [0.0], [3.0], [5.0], [0.0]], + [[1.0], [0.0], [1.0], [0.0], [0.0]], + [[0.0], [0.0], [0.0], [0.0], [0.0]], + [[0.0], [0.0], [0.0], [0.0], [0.0]], + ], + np.float32, + ), + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, 5, 1], tf.float32)], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + expected_counts=np.array([5.0, 10.0, 10.0, 5.0], np.float32), + expected_moments=np.array([3.0, 1.0, 0.0, 0.0], np.float32), + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32) + ], + ), + ] + ) + ) + def test_reduce_batch_count_l_moments( + self, + x, + input_signature, + expected_counts, + expected_moments, + reduce_instance_dims, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_count_l_moments(x): + result = tf_utils.reduce_batch_count_l_moments( + x, reduce_instance_dims=reduce_instance_dims + ) + for tensor in result: + if not reduce_instance_dims and x.get_shape().ndims: + self.assertEqual( + x.get_shape()[1:].as_list(), tensor.get_shape().as_list() + ) + return result + + count_and_moments = _reduce_batch_count_l_moments(x) + counts = count_and_moments[0::2] + moments = count_and_moments[1::2] + for i in range(0, 4): + self.assertEqual(counts[i].dtype, expected_counts[i].dtype) + self.assertAllClose(counts[i], expected_counts[i], rtol=1e-8) + self.assertEqual(moments[i].dtype, expected_moments[i].dtype) + self.assertAllClose(moments[i], expected_moments[i], rtol=1e-8) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="dense", + x=[[1], [2], [3], [4], [4]], + key=["a", "a", "a", "b", "a"], + expected_key_vocab=[b"a", b"b"], + expected_count=[4.0, 1.0], + expected_mean=[2.5, 4.0], + expected_var=[1.25, 0.0], + reduce_instance_dims=True, + input_signature=[ + tf.TensorSpec([None, 1], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="dense_with_nans", + x=[[1], [2], [3], 
[4], [4], [np.nan], [np.nan]], + key=["a", "a", "a", "b", "a", "a", "b"], + expected_key_vocab=[b"a", b"b"], + expected_count=[4.0, 1.0], + expected_mean=[2.5, 4.0], + expected_var=[1.25, 0.0], + reduce_instance_dims=True, + input_signature=[ + tf.TensorSpec([None, 1], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="dense_elementwise", + x=[[1, 2], [3, 4], [1, 2]], + key=["a", "a", "b"], + expected_key_vocab=[b"a", b"b"], + expected_count=[[2.0, 2.0], [1.0, 1.0]], + expected_mean=[[2.0, 3.0], [1.0, 2.0]], + expected_var=[[1.0, 1.0], [0.0, 0.0]], + reduce_instance_dims=False, + input_signature=[ + tf.TensorSpec([None, 2], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="dense_elementwise_with_nans", + x=[[1, 2], [3, 4], [1, 2], [np.nan, np.nan]], + key=["a", "a", "b", "a"], + expected_key_vocab=[b"a", b"b"], + expected_count=[[2.0, 2.0], [1.0, 1.0]], + expected_mean=[[2.0, 3.0], [1.0, 2.0]], + expected_var=[[1.0, 1.0], [0.0, 0.0]], + reduce_instance_dims=False, + input_signature=[ + tf.TensorSpec([None, 2], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], + values=[1.0, 2.0, 3.0, 4.0, 4.0], + dense_shape=[3, 4], + ), + key=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], + values=["a", "a", "a", "a", "b"], + dense_shape=[3, 4], + ), + expected_key_vocab=[b"a", b"b"], + expected_count=[4, 1], + expected_mean=[2.5, 4], + expected_var=[1.25, 0], + reduce_instance_dims=True, + input_signature=[ + tf.SparseTensorSpec([None, 4], tf.float32), + tf.SparseTensorSpec([None, 4], tf.string), + ], + ), + dict( + testcase_name="sparse_with_nans", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 2], [2, 3]], + values=[1.0, 2.0, 3.0, 4.0, np.nan, 4.0], + dense_shape=[3, 4], + ), + key=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 2], [2, 3]], + values=["a", "a", "a", "a", "a", "b"], + dense_shape=[3, 4], + ), + expected_key_vocab=[b"a", b"b"], + expected_count=[4, 1], + expected_mean=[2.5, 4], + expected_var=[1.25, 0], + reduce_instance_dims=True, + input_signature=[ + tf.SparseTensorSpec([None, 4], tf.float32), + tf.SparseTensorSpec([None, 4], tf.string), + ], + ), + dict( + testcase_name="sparse_x_dense_key", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 2], [1, 1], [1, 2], [2, 3]], + values=[1.0, 2.0, 3.0, 4.0, 4.0], + dense_shape=[3, 4], + ), + key=["a", "a", "b"], + expected_key_vocab=[b"a", b"b"], + expected_count=[4, 1], + expected_mean=[2.5, 4], + expected_var=[1.25, 0], + reduce_instance_dims=True, + input_signature=[ + tf.SparseTensorSpec([None, 4], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([3.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + key=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array(["a", "a", "b", "a", "b"]), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + 
expected_key_vocab=[b"a", b"b"], + expected_count=[3, 2], + expected_mean=[3, 4], + expected_var=[np.float32(0.666667), 1.0], + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32), + tf.RaggedTensorSpec([None, None, None, None], tf.string), + ], + ), + dict( + testcase_name="ragged_x_dense_key", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([3.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + key=["a", "b"], + expected_key_vocab=[b"a", b"b"], + expected_count=[4, 1], + expected_mean=[3, 5], + expected_var=[0.5, 0.0], + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([2, None, None, None], tf.float32), + tf.TensorSpec([2], tf.string), + ], + ), + dict( + testcase_name="ragged_with_nans", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array( + [3.0, 2.0, 3.0, 4.0, 5.0, np.nan], np.float32 + ), + row_splits=np.array([0, 2, 3, 4, 6]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + key=["a", "b"], + expected_key_vocab=[b"a", b"b"], + expected_count=[4, 1], + expected_mean=[3, 5], + expected_var=[0.5, 0.0], + reduce_instance_dims=True, + input_signature=[ + tf.RaggedTensorSpec([2, None, None, None], tf.float32), + tf.TensorSpec([2], tf.string), + ], + ), + ] + ) + ) + def test_reduce_batch_count_mean_and_var_per_key( + self, + x, + key, + input_signature, + expected_key_vocab, + expected_count, + expected_mean, + expected_var, + reduce_instance_dims, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_count_mean_and_var_per_key(x, key): + return tf_utils.reduce_batch_count_mean_and_var_per_key( + x, key, reduce_instance_dims=reduce_instance_dims + ) + + key_vocab, count, mean, var = _reduce_batch_count_mean_and_var_per_key(x, key) + + self.assertAllEqual(key_vocab, expected_key_vocab) + self.assertAllEqual(count, expected_count) + self.assertAllEqual(mean, expected_mean) + self.assertAllEqual(var, expected_var) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [0, 2]], + values=[3, 2, -1], + dense_shape=[1, 5], + ), + expected_x_minus_min=1, + expected_x_max=3, + reduce_instance_dims=True, + input_signature=[tf.SparseTensorSpec([None, None], tf.int64)], + ), + dict( + testcase_name="float", + x=[[1, 5, 2]], + expected_x_minus_min=-1, + expected_x_max=5, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec([None, None], tf.float32)], + ), + dict( + testcase_name="sparse_float_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [1, 0]], + values=[3, 2, -1], + dense_shape=[2, 3], + ), + expected_x_minus_min=[1, -2, np.nan], + expected_x_max=[3, 2, np.nan], + reduce_instance_dims=False, + input_signature=[tf.SparseTensorSpec([None, None], tf.float32)], + ), + dict( + testcase_name="float_elementwise", + x=[[1, 5, 2], [2, 3, 4]], + reduce_instance_dims=False, + expected_x_minus_min=[-1, -3, -2], + expected_x_max=[2, 5, 4], + input_signature=[tf.TensorSpec([None, None], tf.float32)], + ), + dict( + testcase_name="sparse_int64_elementwise", + 
x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [1, 0]], + values=[3, 2, -1], + dense_shape=[2, 3], + ), + reduce_instance_dims=False, + expected_x_minus_min=[1, -2, tf.int64.min + 1], + expected_x_max=[3, 2, tf.int64.min + 1], + input_signature=[tf.SparseTensorSpec([None, None], tf.int64)], + ), + dict( + testcase_name="sparse_int32_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [1, 0]], + values=[3, 2, -1], + dense_shape=[2, 3], + ), + reduce_instance_dims=False, + expected_x_minus_min=[1, -2, tf.int32.min + 1], + expected_x_max=[3, 2, tf.int32.min + 1], + input_signature=[tf.SparseTensorSpec([None, None], tf.int32)], + ), + dict( + testcase_name="sparse_float64_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [1, 0]], + values=[3, 2, -1], + dense_shape=[2, 3], + ), + reduce_instance_dims=False, + expected_x_minus_min=[1, -2, np.nan], + expected_x_max=[3, 2, np.nan], + input_signature=[tf.SparseTensorSpec([None, None], tf.float64)], + ), + dict( + testcase_name="sparse_float32_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [1, 0]], + values=[3, 2, -1], + dense_shape=[2, 3], + ), + reduce_instance_dims=False, + expected_x_minus_min=[1, -2, np.nan], + expected_x_max=[3, 2, np.nan], + input_signature=[tf.SparseTensorSpec([None, None], tf.float32)], + ), + dict( + testcase_name="sparse_3d_elementwise", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [0, 0, 1], [1, 0, 1]], + values=[3, 2, -1], + dense_shape=[2, 3, 3], + ), + reduce_instance_dims=False, + expected_x_minus_min=[[-3, 1, np.nan], [np.nan] * 3, [np.nan] * 3], + expected_x_max=[[3, 2, np.nan], [np.nan] * 3, [np.nan] * 3], + input_signature=[ + tf.SparseTensorSpec([None, None, None], tf.float32) + ], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 3, 5]), + ), + row_splits=np.array([0, 2, 3]), + ), + reduce_instance_dims=True, + expected_x_minus_min=-1.0, + expected_x_max=5.0, + input_signature=[tf.RaggedTensorSpec([2, None, None], tf.float32)], + ), + dict( + testcase_name="ragged_elementwise", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.0, 2.0, 3.0, 4.0, 5.0], np.float32), + row_splits=np.array([0, 2, 2, 4, 5]), + ), + row_splits=np.array([0, 3, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + reduce_instance_dims=False, + expected_x_minus_min=[ + [[-1.0, -2.0], [np.nan, np.nan], [-3.0, -4.0]], + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], + ], + expected_x_max=[ + [[5.0, 2.0], [np.nan, np.nan], [3.0, 4.0]], + [[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]], + ], + input_signature=[ + tf.RaggedTensorSpec([2, None, None, None], tf.float32) + ], + ), + dict( + testcase_name="all_nans", + x=[[np.nan, np.nan, np.nan]], + # Output of `tf.reduce_max` if all inputs are NaNs for older + # versions of TF is -inf. 
+ expected_x_minus_min=( + -np.inf + if version.parse(tf.__version__) < version.parse("2.4") + else np.nan + ), + expected_x_max=( + -np.inf + if version.parse(tf.__version__) < version.parse("2.4") + else np.nan + ), + reduce_instance_dims=True, + input_signature=[tf.TensorSpec([None, None], tf.float32)], + ), + dict( + testcase_name="empty_batch", + x=[[]], + expected_x_minus_min=-np.inf, + expected_x_max=-np.inf, + reduce_instance_dims=True, + input_signature=[tf.TensorSpec([None, None], tf.float32)], + ), + ] + ) + ) + def test_reduce_batch_minus_min_and_max( + self, + x, + expected_x_minus_min, + expected_x_max, + reduce_instance_dims, + input_signature, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_minus_min_and_max(x): + result = tf_utils.reduce_batch_minus_min_and_max( + x, reduce_instance_dims=reduce_instance_dims + ) + # Verify that the output shapes are maintained. + if not reduce_instance_dims and not isinstance(x, tf.RaggedTensor): + for tensor in result: + self.assertEqual( + x.get_shape()[1:].as_list(), tensor.get_shape().as_list() + ) + return result + + x_minus_min, x_max = _reduce_batch_minus_min_and_max(x) + + self.assertAllEqual(x_minus_min, expected_x_minus_min) + self.assertAllEqual(x_max, expected_x_max) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="sparse", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 1], [2, 2], [3, 1]], + values=[3, 2, -1, 3], + dense_shape=[4, 5], + ), + key=["a", "a", "a", "b"], + reduce_instance_dims=True, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[1, -3], + expected_x_max=[3, 3], + input_signature=[ + tf.SparseTensorSpec([None, None], tf.int64), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="float", + x=[[1], [5], [2], [3]], + key=["a", "a", "a", "b"], + reduce_instance_dims=True, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[-1, -3], + expected_x_max=[5, 3], + input_signature=[ + tf.TensorSpec([None, None], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="float_elementwise", + x=[[1], [5], [2], [3]], + key=["a", "a", "a", "b"], + reduce_instance_dims=False, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[[-1], [-3]], + expected_x_max=[[5], [3]], + input_signature=[ + tf.TensorSpec([None, None], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="float3dims", + x=[ + [[1, 5], [1, 1]], + [[5, 1], [5, 5]], + [[2, 2], [2, 5]], + [[3, -3], [3, 3]], + ], + key=["a", "a", "a", "b"], + reduce_instance_dims=True, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[-1, 3], + expected_x_max=[5, 3], + input_signature=[ + tf.TensorSpec([None, None, None], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="float3dims_elementwise", + x=[ + [[1, 5], [1, 1]], + [[5, 1], [5, 5]], + [[2, 2], [2, 5]], + [[3, -3], [3, 3]], + ], + key=["a", "a", "a", "b"], + reduce_instance_dims=False, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[[[-1, -1], [-1, -1]], [[-3, 3], [-3, -3]]], + expected_x_max=[[[5, 5], [5, 5]], [[3, -3], [3, 3]]], + input_signature=[ + tf.TensorSpec([None, None, None], tf.float32), + tf.TensorSpec([None], tf.string), + ], + ), + dict( + testcase_name="ragged", + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([3.0, 2.0, 3.0, 4.0, 5.0], 
np.float32), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + key=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array(["a", "a", "b", "a", "b"]), + row_splits=np.array([0, 2, 3, 4, 5]), + ), + row_splits=np.array([0, 2, 3, 4]), + ), + row_splits=np.array([0, 2, 3]), + ), + reduce_instance_dims=True, + expected_key_vocab=[b"a", b"b"], + expected_x_minus_min=[-2.0, -3.0], + expected_x_max=[4.0, 5.0], + input_signature=[ + tf.RaggedTensorSpec([None, None, None, None], tf.float32), + tf.RaggedTensorSpec([None, None, None, None], tf.string), + ], + ), + ] + ) + ) + def test_reduce_batch_minus_min_and_max_per_key( + self, + x, + key, + reduce_instance_dims, + expected_key_vocab, + expected_x_minus_min, + expected_x_max, + input_signature, + function_handler, + ): + @function_handler(input_signature=input_signature) + def _reduce_batch_minus_min_and_max_per_key(x, key): + return tf_utils.reduce_batch_minus_min_and_max_per_key( + x, key, reduce_instance_dims=reduce_instance_dims + ) + + key_vocab, x_minus_min, x_max = _reduce_batch_minus_min_and_max_per_key(x, key) + + self.assertAllEqual(key_vocab, expected_key_vocab) + self.assertAllEqual(x_minus_min, expected_x_minus_min) + self.assertAllEqual(x_max, expected_x_max) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="dense", + key=["a", "a", "a", "b"], + spec=tf.TensorSpec([None], tf.string), + expected_key_vocab=[b"a", b"b"], + expected_count=[3, 1], + ), + dict( + testcase_name="sparse", + key=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 1], [2, 2], [3, 1]], + values=[3, 2, -1, 3], + dense_shape=[4, 5], + ), + spec=tf.SparseTensorSpec([4, 5], tf.int64), + expected_key_vocab=[b"3", b"2", b"-1"], + expected_count=[2, 1, 1], + ), + dict( + testcase_name="ragged", + key=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 2]), + ), + spec=tf.RaggedTensorSpec([1, None, None], tf.float32), + expected_key_vocab=[b"1.200000", b"1.000000"], + expected_count=[2, 2], + ), + ] + ) + ) + def test_reduce_batch_count_per_key( + self, key, spec, expected_key_vocab, expected_count, function_handler + ): + @function_handler(input_signature=[spec]) + def _reduce_batch_count_per_key(key): + return tf_utils.reduce_batch_count_per_key(key) + + key_vocab, key_counts = _reduce_batch_count_per_key(key) + + self.assertAllEqual(key_vocab, expected_key_vocab) + self.assertAllEqual(key_counts, expected_count) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="full", + bucket_vocab=["1", "2", "0"], + counts=[3, 1, 4], + boundary_size=3, + expected_counts=[4, 3, 1], + ), + dict( + testcase_name="missing", + bucket_vocab=["1", "3", "0"], + counts=[3, 1, 4], + boundary_size=5, + expected_counts=[4, 3, 0, 1, 0], + ), + ] + ) + ) + def test_reorder_histogram( + self, bucket_vocab, counts, boundary_size, expected_counts, function_handler + ): + input_signature = [ + tf.TensorSpec([None], tf.string), + tf.TensorSpec([None], tf.int64), + tf.TensorSpec([], tf.int32), + ] + + @function_handler(input_signature=input_signature) + def _reorder_histogram(bucket_vocab, counts, boundary_size): + return tf_utils.reorder_histogram(bucket_vocab, counts, 
boundary_size) + + counts = _reorder_histogram(bucket_vocab, counts, boundary_size) + self.assertAllEqual(counts, expected_counts) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="simple", + x=[0.0, 2.0, 3.5, 4.0], + x_spec=tf.TensorSpec([None], tf.float32), + boundaries=[[1.0, 2.0, 3.0, 3.9]], + boundaries_spec=tf.TensorSpec([1, None], tf.float32), + side=tf_utils.Side.LEFT, + expected_buckets=[0, 1, 3, 3], + ), + dict( + testcase_name="simple_right", + x=[0.0, 2.0, 3.5, 4.0], + x_spec=tf.TensorSpec([None], tf.float32), + boundaries=[1.0, 2.0, 3.0, 3.9], + boundaries_spec=tf.TensorSpec([None], tf.float32), + side=tf_utils.Side.RIGHT, + expected_buckets=[0, 2, 3, 4], + ), + dict( + testcase_name="2dim", + x=[[0.0, 4.0, 3.5, 2.0, 1.7]], + x_spec=tf.TensorSpec([1, None], tf.float32), + boundaries=[[1.0, 2.0, 3.0, 5.0]], + boundaries_spec=tf.TensorSpec([1, None], tf.float32), + side=tf_utils.Side.LEFT, + expected_buckets=[[0, 3, 3, 1, 1]], + ), + dict( + testcase_name="large_buckets", + x=[[50_000_000]], + x_spec=tf.TensorSpec([1, None], tf.int64), + boundaries=[0, 50_000_001, 100_000_001], + boundaries_spec=tf.TensorSpec([None], tf.int64), + side=tf_utils.Side.RIGHT, + expected_buckets=[[1]], + ), + ] + ) + ) + def test_assign_buckets( + self, + x, + x_spec, + boundaries, + boundaries_spec, + side, + expected_buckets, + function_handler, + ): + @function_handler(input_signature=[x_spec, boundaries_spec]) + def _assign_buckets(x, boundaries): + return tf_utils.assign_buckets(x, boundaries, side) + + buckets = _assign_buckets(x, boundaries) + self.assertAllEqual(buckets, expected_buckets) + + def test_sparse_indices(self): + exception_cls = tf.errors.InvalidArgumentError + error_string = "Condition x == y did not hold element-wise:" + value = tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 1], [2, 2], [3, 1]], + values=[3, 2, -1, 3], + dense_shape=[4, 5], + ) + key_value = tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 2], [2, 2], [3, 1]], + values=["a", "a", "a", "b"], + dense_shape=[4, 5], + ) + with tf.compat.v1.Graph().as_default(): + x = tf.compat.v1.sparse_placeholder(tf.int64, shape=[None, None]) + key = tf.compat.v1.sparse_placeholder(tf.string, shape=[None, None]) + with tf.compat.v1.Session() as sess: + with self.assertRaisesRegex(exception_cls, error_string): + sess.run( + tf_utils.reduce_batch_minus_min_and_max_per_key(x, key), + feed_dict={x: value, key: key_value}, + ) + + def test_convert_sparse_indices(self): + exception_cls = tf.errors.InvalidArgumentError + error_string = "Condition x == y did not hold element-wise:" + sparse = tf.SparseTensor( + indices=[[0, 0, 0], [1, 0, 1], [2, 0, 2], [3, 0, 1]], + values=[3, 2, -1, 3], + dense_shape=[4, 2, 5], + ) + dense = tf.constant(["a", "b", "c", "d"]) + x, key = tf_utils._validate_and_get_dense_value_key_inputs(sparse, sparse) + self.assertAllEqual(self.evaluate(x), sparse.values) + self.assertAllEqual(self.evaluate(key), sparse.values) + + x, key = tf_utils._validate_and_get_dense_value_key_inputs(sparse, dense) + self.assertAllEqual(self.evaluate(x), sparse.values) + self.assertAllEqual(self.evaluate(key), dense) + + with tf.compat.v1.Graph().as_default(): + sparse1 = tf.compat.v1.sparse_placeholder( + tf.int64, shape=[None, None, None] + ) + sparse2 = tf.compat.v1.sparse_placeholder( + tf.int64, shape=[None, None, None] + ) + sparse_value1 = tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [1, 0, 1], [2, 0, 2], [3, 0, 1]], + values=[3, 2, -1, 3], + 
dense_shape=[4, 2, 5], + ) + sparse_value2 = tf.compat.v1.SparseTensorValue( + indices=[[0, 0, 0], [1, 0, 2], [2, 0, 2], [3, 0, 1]], + values=[3, 2, -1, 3], + dense_shape=[4, 2, 5], + ) + + with tf.compat.v1.Session() as sess: + with self.assertRaisesRegex(exception_cls, error_string): + sess.run( + tf_utils._validate_and_get_dense_value_key_inputs( + sparse1, sparse2 + ), + feed_dict={sparse1: sparse_value1, sparse2: sparse_value2}, + ) + + def test_convert_ragged_indices(self): + exception_cls = tf.errors.InvalidArgumentError + error_string = "Condition x == y did not hold element-wise:" + ragged = tf.RaggedTensor.from_row_splits( + values=tf.RaggedTensor.from_row_splits( + values=np.array([1.2, 1.0, 1.2, 1.0]), row_splits=np.array([0, 2, 4]) + ), + row_splits=np.array([0, 1, 2]), + ) + dense = tf.constant(["a", "b"]) + dense_result = tf.constant(["a", "a", "b", "b"]) + x, key = tf_utils._validate_and_get_dense_value_key_inputs(ragged, ragged) + self.assertAllEqual(self.evaluate(x), ragged.flat_values) + self.assertAllEqual(self.evaluate(key), ragged.flat_values) + + x, key = tf_utils._validate_and_get_dense_value_key_inputs(ragged, dense) + self.assertAllEqual(self.evaluate(x), ragged.flat_values) + self.assertAllEqual(self.evaluate(key), dense_result) + + with tf.compat.v1.Graph().as_default(): + ragged1 = tf.compat.v1.ragged.placeholder(tf.float32, 2) + ragged2 = tf.compat.v1.ragged.placeholder(tf.float32, 2) + ragged_value1 = tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 2]), + ) + ragged_value2 = tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 3, 4]), + ), + row_splits=np.array([0, 2]), + ) + + with tf.compat.v1.Session() as sess: + with self.assertRaisesRegex(exception_cls, error_string): + sess.run( + tf_utils._validate_and_get_dense_value_key_inputs( + ragged1, ragged2 + ), + feed_dict={ragged1: ragged_value1, ragged2: ragged_value2}, + ) + + @test_case.named_parameters( + dict( + testcase_name="dense_tensor", + key=["b", "a", "b"], + key_vocab=["a", "b"], + reductions=([1, 2], [3, 4]), + x=[5, 6, 7], + reduce_instance_dims=True, + expected_results=([2, 1, 2], [4, 3, 4]), + ), + dict( + testcase_name="sparse_tensor_dense_key", + key=["b", "a", "b"], + key_vocab=["a", "b"], + reductions=([1, 2], [3, 4]), + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 2], [2, 2], [2, 3]], + values=[3, 2, -1, 3], + dense_shape=[3, 5], + ), + reduce_instance_dims=True, + expected_results=([2, 1, 2, 2], [4, 3, 4, 4]), + ), + dict( + testcase_name="sparse_tensor_sparse_key", + key=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 2], [2, 2], [2, 3]], + values=["b", "a", "b", "b"], + dense_shape=[3, 5], + ), + key_vocab=["a", "b"], + reductions=([1, 2], [3, 4]), + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [1, 2], [2, 2], [2, 3]], + values=[3, 2, -1, 3], + dense_shape=[3, 5], + ), + reduce_instance_dims=True, + expected_results=([2, 1, 2, 2], [4, 3, 4, 4]), + ), + dict( + testcase_name="ragged_tensor_dense_key", + key=["a", "b", "a"], + key_vocab=["a", "b"], + reductions=([1, 2], [3, 4]), + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 1, 2, 2]), + ), + reduce_instance_dims=True, + 
expected_results=([1, 1, 2, 2], [3, 3, 4, 4]), + ), + dict( + testcase_name="ragged_tensor_ragged_key", + key=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array(["a", "b", "b", "a"]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 2]), + ), + key_vocab=["a", "b"], + reductions=([1, 2], [3, 4]), + x=tf.compat.v1.ragged.RaggedTensorValue( + values=tf.compat.v1.ragged.RaggedTensorValue( + values=np.array([1.2, 1.0, 1.2, 1.0]), + row_splits=np.array([0, 2, 4]), + ), + row_splits=np.array([0, 2]), + ), + reduce_instance_dims=True, + expected_results=([1, 2, 2, 1], [3, 4, 4, 3]), + ), + dict( + testcase_name="missing_key", + key=["b", "a", "c"], + key_vocab=["z", "a", "b"], + reductions=([-77, 1, 2], [-99, 3, 4]), + x=[5, 6, 7], + reduce_instance_dims=True, + expected_results=([2, 1, 0], [4, 3, 0]), + ), + dict( + testcase_name="_dense_tensor_2d_elementwise", + key=["a"], + key_vocab=["a", "b"], + reductions=([[1, 5], [-2, 0]], [[5, 9], [2, 4]]), + x=[[4, 8]], + reduce_instance_dims=False, + expected_results=([[1, 5]], [[5, 9]]), + ), + dict( + testcase_name="_dense_tensor_3d_elementwise", + key=["a"], + key_vocab=["a", "b"], + reductions=( + [[[1, 1], [1, 1]], [[3, -3], [3, 3]]], + [[[5, 5], [5, 5]], [[3, -3], [3, 3]]], + ), + x=[[[1, 5], [1, 1]]], + reduce_instance_dims=False, + expected_results=([[[1, 1], [1, 1]]], [[[5, 5], [5, 5]]]), + ), + ) + def test_map_per_key_reductions( + self, key, key_vocab, reductions, x, reduce_instance_dims, expected_results + ): + with tf.compat.v1.Graph().as_default(): + key = _value_to_tensor(key) + key_vocab = tf.constant(key_vocab) + reductions = tuple([tf.constant(t) for t in reductions]) + x = _value_to_tensor(x) + expected_results = tuple(tf.constant(t) for t in expected_results) + results = tf_utils.map_per_key_reductions( + reductions, key, key_vocab, x, reduce_instance_dims + ) + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.tables_initializer()) + output = sess.run(results) + for result, expected_result in zip(output, expected_results): + self.assertAllEqual(result, expected_result) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="sparse_tensor", + feature=tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [0, 2], [1, 0]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 5], + ), + input_signature=[tf.SparseTensorSpec([None, 5], tf.float32)], + ascii_protos=[ + "float_list { value: [1.0, 2.0, 3.0] }", + "float_list { value: [4.0] }", + ], + ), + dict( + testcase_name="dense_scalar_int", + feature=[0, 1, 2], + input_signature=[tf.TensorSpec([None], tf.int64)], + ascii_protos=[ + "int64_list { value: [0] }", + "int64_list { value: [1] }", + "int64_list { value: [2] }", + ], + ), + dict( + testcase_name="dense_scalar_float", + feature=[0.5, 1.5, 2.5], + input_signature=[tf.TensorSpec([None], tf.float32)], + ascii_protos=[ + "float_list { value: [0.5] }", + "float_list { value: [1.5] }", + "float_list { value: [2.5] }", + ], + ), + dict( + testcase_name="dense_scalar_string", + feature=["hello", "world"], + input_signature=[tf.TensorSpec([None], tf.string)], + ascii_protos=[ + 'bytes_list { value: "hello" }', + 'bytes_list { value: "world" }', + ], + ), + dict( + testcase_name="dense_vector_int", + feature=[[0, 1], [2, 3]], + input_signature=[tf.TensorSpec([None, 2], tf.int64)], + ascii_protos=[ + "int64_list { value: [0, 1] }", + "int64_list { value: [2, 3] }", + ], + ), + dict( + 
testcase_name="dense_matrix_int", + feature=[[[0, 1], [2, 3]], [[4, 5], [6, 7]]], + input_signature=[tf.TensorSpec([None, 2, 2], tf.int64)], + ascii_protos=[ + "int64_list { value: [0, 1, 2, 3] }", + "int64_list { value: [4, 5, 6, 7] }", + ], + ), + ] + ) + ) + def test_serialize_feature( + self, feature, input_signature, ascii_protos, function_handler + ): + @function_handler(input_signature=input_signature) + def _serialize_feature(feature): + return tf_utils._serialize_feature(feature) + + serialized_features = _serialize_feature(feature) + + self.assertEqual(len(ascii_protos), len(serialized_features)) + for ascii_proto, serialized_feature in zip(ascii_protos, serialized_features): + feature_proto = tf.train.Feature() + feature_proto.ParseFromString(serialized_feature) + self.assertProtoEquals(ascii_proto, feature_proto) + + @test_case.named_parameters( + dict( + testcase_name="multiple_features", + examples={ + "my_value": tf.compat.v1.SparseTensorValue( + indices=[[0, 0], [0, 1], [0, 2], [1, 0]], + values=[1.0, 2.0, 3.0, 4.0], + dense_shape=[2, 5], + ), + "my_other_value": np.array([1, 2], np.int64), + }, + ascii_protos=[ + """ features { feature { key: "my_value" @@ -2247,7 +2754,8 @@ def _serialize_feature(feature): value: { int64_list { value: [1] } } } } - """, """ + """, + """ features { feature { key: "my_value" @@ -2258,268 +2766,307 @@ def _serialize_feature(feature): value: { int64_list { value: [2] } } } } - """ - ])) - def test_serialize_example(self, examples, ascii_protos): - with tf.compat.v1.Graph().as_default(): - serialized_examples_tensor = tf_utils.serialize_example(examples) - with tf.compat.v1.Session(): - serialized_examples = serialized_examples_tensor.eval() - example_proto = tf.train.Example() - self.assertEqual(len(serialized_examples), len(ascii_protos)) - for ascii_proto, serialized_example in zip(ascii_protos, - serialized_examples): - example_proto.ParseFromString(serialized_example) - self.assertProtoEquals(ascii_proto, example_proto) - - def test_extend_reduced_batch_with_y_counts(self): - initial_reduction = tf_utils.ReducedBatchWeightedCounts( - unique_x=tf.constant(['foo', 'bar']), - summed_weights_per_x=tf.constant([3.0, 4.0]), - summed_positive_per_x_and_y=tf.constant([[1.0, 4.0], [1.0, 1.0]]), - counts_per_x=tf.constant([2, 5], tf.int64), - ) - y = tf.constant([0, 1, 1, 1, 0, 1, 1], tf.int64) - extended_batch = tf_utils.extend_reduced_batch_with_y_counts( - initial_reduction, y) - self.assertAllEqual(self.evaluate(extended_batch.unique_x), - np.array([b'foo', b'bar', b'global_y_count_sentinel'])) - self.assertAllClose( - self.evaluate(extended_batch.summed_weights_per_x), - np.array([3.0, 4.0, 7.0]), + """, + ], + ) ) - self.assertAllClose( - self.evaluate(extended_batch.summed_positive_per_x_and_y), - np.array([[1.0, 4.0], [1.0, 1.0], [2.0, 5.0]]), + def test_serialize_example(self, examples, ascii_protos): + with tf.compat.v1.Graph().as_default(): + serialized_examples_tensor = tf_utils.serialize_example(examples) + with tf.compat.v1.Session(): + serialized_examples = serialized_examples_tensor.eval() + example_proto = tf.train.Example() + self.assertEqual(len(serialized_examples), len(ascii_protos)) + for ascii_proto, serialized_example in zip(ascii_protos, serialized_examples): + example_proto.ParseFromString(serialized_example) + self.assertProtoEquals(ascii_proto, example_proto) + + def test_extend_reduced_batch_with_y_counts(self): + initial_reduction = tf_utils.ReducedBatchWeightedCounts( + unique_x=tf.constant(["foo", "bar"]), + 
summed_weights_per_x=tf.constant([3.0, 4.0]), + summed_positive_per_x_and_y=tf.constant([[1.0, 4.0], [1.0, 1.0]]), + counts_per_x=tf.constant([2, 5], tf.int64), + ) + y = tf.constant([0, 1, 1, 1, 0, 1, 1], tf.int64) + extended_batch = tf_utils.extend_reduced_batch_with_y_counts( + initial_reduction, y + ) + self.assertAllEqual( + self.evaluate(extended_batch.unique_x), + np.array([b"foo", b"bar", b"global_y_count_sentinel"]), + ) + self.assertAllClose( + self.evaluate(extended_batch.summed_weights_per_x), + np.array([3.0, 4.0, 7.0]), + ) + self.assertAllClose( + self.evaluate(extended_batch.summed_positive_per_x_and_y), + np.array([[1.0, 4.0], [1.0, 1.0], [2.0, 5.0]]), + ) + self.assertAllClose( + self.evaluate(extended_batch.counts_per_x), np.array([2.0, 5.0, 7.0]) + ) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="vocab_size_1", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 1], [0, 2], [0, 3], [1, 5], [100, 2]], + values=[0, 1, 9, -4, 100], + dense_shape=(100, 10), + ), + vocab_size=1, + input_signature=[ + tf.SparseTensorSpec([None, 10], tf.int64), + tf.TensorSpec([], tf.int64), + ], + expected_output_values=[0, 0, 0, 0, 0], + ), + dict( + testcase_name="vocab_size_9", + x=tf.compat.v1.SparseTensorValue( + indices=[[0, 1], [0, 2], [0, 3], [1, 5], [100, 2]], + values=[0, 1, 9, -4, 100], + dense_shape=(100, 6), + ), + vocab_size=9, + input_signature=[ + tf.SparseTensorSpec([None, 6], tf.int64), + tf.TensorSpec([], tf.int64), + ], + expected_output_values=[0, 1, 0, 5, 1], + ), + ] + ) ) - self.assertAllClose( - self.evaluate(extended_batch.counts_per_x), np.array([2.0, 5.0, 7.0]) + def test_to_vocab_range( + self, x, vocab_size, input_signature, expected_output_values, function_handler + ): + @function_handler(input_signature=input_signature) + def _to_vocab_range(x, vocab_size): + cleaned_x = tf_utils.to_vocab_range(x, vocab_size) + self.assertIsInstance(cleaned_x, tf.SparseTensor) + return cleaned_x.indices, cleaned_x.values, cleaned_x.dense_shape + + output_indices, output_values, output_dense_shape = _to_vocab_range( + x, vocab_size + ) + self.assertAllEqual(output_indices, x.indices) + self.assertAllEqual(output_values, expected_output_values) + self.assertAllEqual(output_dense_shape, x.dense_shape) + + @test_case.named_parameters( + test_case.cross_with_function_handlers( + [ + dict( + testcase_name="df_to_idf", + df_input=[0, 1, 10], + corpus_size=10, + smooth=True, + add_baseline=True, + expected_idf=[3.3978952728, 2.70474809224, 1.0], + ), + dict( + testcase_name="df_to_idf_zero_corpus", + df_input=[0, 1, 10], + corpus_size=0, + smooth=True, + add_baseline=True, + expected_idf=[1.0, 0.30685281944, -1.3978953], + ), + dict( + testcase_name="df_to_idf_non_smooth", + df_input=[1, 2, 10], + corpus_size=10, + smooth=False, + add_baseline=True, + expected_idf=[3.30258509299, 2.60943791243, 1.0], + ), + dict( + testcase_name="df_to_idf_no_baseline", + df_input=[0, 1, 10], + corpus_size=10, + smooth=True, + add_baseline=False, + expected_idf=[2.3978952728, 1.70474809224, 0.0], + ), + ] + ) ) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='vocab_size_1', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 1], [0, 2], [0, 3], [1, 5], [100, 2]], - values=[0, 1, 9, -4, 100], - dense_shape=(100, 10)), - vocab_size=1, - input_signature=[ - tf.SparseTensorSpec([None, 10], tf.int64), - tf.TensorSpec([], tf.int64), - ], - expected_output_values=[0, 0, 0, 0, 0], - ), - dict( - 
testcase_name='vocab_size_9', - x=tf.compat.v1.SparseTensorValue( - indices=[[0, 1], [0, 2], [0, 3], [1, 5], [100, 2]], - values=[0, 1, 9, -4, 100], - dense_shape=(100, 6)), - vocab_size=9, - input_signature=[ - tf.SparseTensorSpec([None, 6], tf.int64), - tf.TensorSpec([], tf.int64), - ], - expected_output_values=[0, 1, 0, 5, 1], - ), - ])) - def test_to_vocab_range(self, x, vocab_size, input_signature, - expected_output_values, function_handler): - - @function_handler(input_signature=input_signature) - def _to_vocab_range(x, vocab_size): - cleaned_x = tf_utils.to_vocab_range(x, vocab_size) - self.assertIsInstance(cleaned_x, tf.SparseTensor) - return cleaned_x.indices, cleaned_x.values, cleaned_x.dense_shape - - output_indices, output_values, output_dense_shape = _to_vocab_range( - x, vocab_size) - self.assertAllEqual(output_indices, x.indices) - self.assertAllEqual(output_values, expected_output_values) - self.assertAllEqual(output_dense_shape, x.dense_shape) - - @test_case.named_parameters( - test_case.cross_with_function_handlers([ - dict( - testcase_name='df_to_idf', - df_input=[0, 1, 10], - corpus_size=10, - smooth=True, - add_baseline=True, - expected_idf=[3.3978952728, 2.70474809224, 1.0]), - dict( - testcase_name='df_to_idf_zero_corpus', - df_input=[0, 1, 10], - corpus_size=0, - smooth=True, - add_baseline=True, - expected_idf=[1.0, 0.30685281944, -1.3978953]), - dict( - testcase_name='df_to_idf_non_smooth', - df_input=[1, 2, 10], - corpus_size=10, - smooth=False, - add_baseline=True, - expected_idf=[3.30258509299, 2.60943791243, 1.0]), - dict( - testcase_name='df_to_idf_no_baseline', - df_input=[0, 1, 10], - corpus_size=10, - smooth=True, - add_baseline=False, - expected_idf=[2.3978952728, 1.70474809224, 0.0]), - ])) - def test_document_frequency_to_idf(self, df_input, corpus_size, smooth, - add_baseline, expected_idf, - function_handler): - input_signature = [ - tf.TensorSpec([None], tf.int64), - tf.TensorSpec([], tf.int64), - ] - - @function_handler(input_signature=input_signature) - def _to_idf(df, corpus_size): - return tf_utils.document_frequency_to_idf( - df, corpus_size, smooth=smooth, add_baseline=add_baseline) - - idf_output = _to_idf(df_input, corpus_size) - self.assertAllClose(idf_output, expected_idf) + def test_document_frequency_to_idf( + self, + df_input, + corpus_size, + smooth, + add_baseline, + expected_idf, + function_handler, + ): + input_signature = [ + tf.TensorSpec([None], tf.int64), + tf.TensorSpec([], tf.int64), + ] + + @function_handler(input_signature=input_signature) + def _to_idf(df, corpus_size): + return tf_utils.document_frequency_to_idf( + df, corpus_size, smooth=smooth, add_baseline=add_baseline + ) + + idf_output = _to_idf(df_input, corpus_size) + self.assertAllClose(idf_output, expected_idf) class VocabTFUtilsTest(test_case.TransformTestCase): - - def _write_tfrecords(self, path, bytes_records): - with tf.io.TFRecordWriter(path, 'GZIP') as writer: - for record in bytes_records: - writer.write(record) - - def test_split_vocabulary_entries(self): - x = tf.constant([b'1 a b ', b'2 c', b'3 . . . ']) - keys, values = tf_utils._split_vocabulary_entries(x) - expected_keys = [b' a b ', b'c', b' . . . 
'] - expected_values = [b'1', b'2', b'3'] - self.assertAllEqual(self.evaluate(keys), np.array(expected_keys)) - self.assertAllEqual(self.evaluate(values), np.array(expected_values)) - - def test_read_tfrecord_vocabulary_dataset(self): - vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz') - contents = [b'a', b'b', b'c'] - self._write_tfrecords(vocab_file, contents) - self.AssertVocabularyContents(vocab_file, contents) - - ds = tf.data.TFRecordDataset(vocab_file, compression_type='GZIP') - self.assertAllEqual(np.array(contents), list(ds.as_numpy_iterator())) - - @test_case.named_parameters([ - dict( - testcase_name='_common', - contents=[b'a', b'b', b' c '], - expected=[(b'a', 0), (b'b', 1), (b' c ', 2)], - key_dtype=tf.string, - value_dtype=tf.int64, - return_indicator_as_value=False, - has_indicator=False), - dict( - testcase_name='_dtypes', - contents=[b'17', b'42'], - expected=[(17, 0.), (42, 1.)], - key_dtype=tf.int64, - value_dtype=tf.float32, - return_indicator_as_value=False, - has_indicator=False), - dict( - testcase_name='_drop_indicator', - contents=[b'17 a', b'42 b'], - expected=[(b'a', 0), (b'b', 1)], - key_dtype=tf.string, - value_dtype=tf.int64, - return_indicator_as_value=False, - has_indicator=True), - dict( - testcase_name='_indicator_value', - contents=[b'17 a', b'42 b '], - expected=[(b'a', 17), (b'b ', 42)], - key_dtype=tf.string, - value_dtype=tf.int64, - return_indicator_as_value=True, - has_indicator=True), - dict( - testcase_name='_indicator_value_dtype', - contents=[b'17 a', b'42 b'], - expected=[(b'a', 17.), (b'b', 42.)], - key_dtype=tf.string, - value_dtype=tf.float32, - return_indicator_as_value=True, - has_indicator=True), - ]) - def test_make_tfrecord_vocabulary_dataset(self, contents, expected, key_dtype, - value_dtype, - return_indicator_as_value, - has_indicator): - vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz') - self._write_tfrecords(vocab_file, contents) - - ds = tf_utils._make_tfrecord_vocabulary_dataset( - vocab_file, - key_dtype=key_dtype, - value_dtype=value_dtype, - return_indicator_as_value=return_indicator_as_value, - has_indicator=has_indicator) - - def validate_dtypes(key, value): - self.assertEqual(key.dtype, key_dtype) - self.assertEqual(value.dtype, value_dtype) - return key, value - - ds = ds.map(validate_dtypes) - - vocabulary = list(ds.as_numpy_iterator()) - self.assertAllEqual(expected, vocabulary) - - @test_case.named_parameters(test_case.FUNCTION_HANDLERS) - def test_make_tfrecord_vocabulary_lookup_initializer(self, function_handler): - vocab_file = os.path.join(self.get_temp_dir(), 'vocab.tfrecord.gz') - contents = [b'%i' % idx for idx in range(1000)] - self._write_tfrecords(vocab_file, contents) - - input_signature = [tf.TensorSpec(None, tf.string)] - - @function_handler(input_signature=input_signature) - def lookup(x): - initializer = tf_utils.make_tfrecord_vocabulary_lookup_initializer( - vocab_file) - table = tf.lookup.StaticHashTable(initializer, -1) - return table.lookup(x) - - # make_tfrecord_vocabulary_lookup_initializer calls annotators.track_object - # which expects to be invoked inside an object_tracker_scope. 
- with annotators.object_tracker_scope(annotators.ObjectTracker()): - self.assertEqual(lookup('5'), 5) - self.assertEqual(lookup('1000'), -1) - - @test_case.named_parameters( - test_case.cross_with_function_handlers(_CONSTRUCT_TABLE_PARAMETERS)) - def test_construct_and_lookup_table(self, asset_path_input_fn, - function_handler): - vocab_filename = os.path.join(self.get_temp_dir(), 'test.txt') - vocab_data = [('a', '0'), ('b', '1'), ('c', '1'), ('d', '2'), ()] - encoded_vocab = '\n'.join(['\t'.join(pair) for pair in vocab_data]) - with tf.io.gfile.GFile(vocab_filename, 'w') as writer: - writer.write(encoded_vocab) - - @function_handler( - input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)]) - def foo(input_tensor): - output_tensor, unused_table_size = tf_utils.construct_and_lookup_table( - _construct_table, asset_path_input_fn(vocab_filename), input_tensor) - return output_tensor - - expected_data = [0, 1, 1, 2, -1] - output_tensor = foo(['a', 'b', 'c', 'd', 'e']) - self.assertAllEqual(output_tensor, expected_data) - - -if __name__ == '__main__': - test_case.main() + def _write_tfrecords(self, path, bytes_records): + with tf.io.TFRecordWriter(path, "GZIP") as writer: + for record in bytes_records: + writer.write(record) + + def test_split_vocabulary_entries(self): + x = tf.constant([b"1 a b ", b"2 c", b"3 . . . "]) + keys, values = tf_utils._split_vocabulary_entries(x) + expected_keys = [b" a b ", b"c", b" . . . "] + expected_values = [b"1", b"2", b"3"] + self.assertAllEqual(self.evaluate(keys), np.array(expected_keys)) + self.assertAllEqual(self.evaluate(values), np.array(expected_values)) + + def test_read_tfrecord_vocabulary_dataset(self): + vocab_file = os.path.join(self.get_temp_dir(), "vocab.tfrecord.gz") + contents = [b"a", b"b", b"c"] + self._write_tfrecords(vocab_file, contents) + self.AssertVocabularyContents(vocab_file, contents) + + ds = tf.data.TFRecordDataset(vocab_file, compression_type="GZIP") + self.assertAllEqual(np.array(contents), list(ds.as_numpy_iterator())) + + @test_case.named_parameters( + [ + dict( + testcase_name="_common", + contents=[b"a", b"b", b" c "], + expected=[(b"a", 0), (b"b", 1), (b" c ", 2)], + key_dtype=tf.string, + value_dtype=tf.int64, + return_indicator_as_value=False, + has_indicator=False, + ), + dict( + testcase_name="_dtypes", + contents=[b"17", b"42"], + expected=[(17, 0.0), (42, 1.0)], + key_dtype=tf.int64, + value_dtype=tf.float32, + return_indicator_as_value=False, + has_indicator=False, + ), + dict( + testcase_name="_drop_indicator", + contents=[b"17 a", b"42 b"], + expected=[(b"a", 0), (b"b", 1)], + key_dtype=tf.string, + value_dtype=tf.int64, + return_indicator_as_value=False, + has_indicator=True, + ), + dict( + testcase_name="_indicator_value", + contents=[b"17 a", b"42 b "], + expected=[(b"a", 17), (b"b ", 42)], + key_dtype=tf.string, + value_dtype=tf.int64, + return_indicator_as_value=True, + has_indicator=True, + ), + dict( + testcase_name="_indicator_value_dtype", + contents=[b"17 a", b"42 b"], + expected=[(b"a", 17.0), (b"b", 42.0)], + key_dtype=tf.string, + value_dtype=tf.float32, + return_indicator_as_value=True, + has_indicator=True, + ), + ] + ) + def test_make_tfrecord_vocabulary_dataset( + self, + contents, + expected, + key_dtype, + value_dtype, + return_indicator_as_value, + has_indicator, + ): + vocab_file = os.path.join(self.get_temp_dir(), "vocab.tfrecord.gz") + self._write_tfrecords(vocab_file, contents) + + ds = tf_utils._make_tfrecord_vocabulary_dataset( + vocab_file, + key_dtype=key_dtype, + 
value_dtype=value_dtype, + return_indicator_as_value=return_indicator_as_value, + has_indicator=has_indicator, + ) + + def validate_dtypes(key, value): + self.assertEqual(key.dtype, key_dtype) + self.assertEqual(value.dtype, value_dtype) + return key, value + + ds = ds.map(validate_dtypes) + + vocabulary = list(ds.as_numpy_iterator()) + self.assertAllEqual(expected, vocabulary) + + @test_case.named_parameters(test_case.FUNCTION_HANDLERS) + def test_make_tfrecord_vocabulary_lookup_initializer(self, function_handler): + vocab_file = os.path.join(self.get_temp_dir(), "vocab.tfrecord.gz") + contents = [b"%i" % idx for idx in range(1000)] + self._write_tfrecords(vocab_file, contents) + + input_signature = [tf.TensorSpec(None, tf.string)] + + @function_handler(input_signature=input_signature) + def lookup(x): + initializer = tf_utils.make_tfrecord_vocabulary_lookup_initializer( + vocab_file + ) + table = tf.lookup.StaticHashTable(initializer, -1) + return table.lookup(x) + + # make_tfrecord_vocabulary_lookup_initializer calls annotators.track_object + # which expects to be invoked inside an object_tracker_scope. + with annotators.object_tracker_scope(annotators.ObjectTracker()): + self.assertEqual(lookup("5"), 5) + self.assertEqual(lookup("1000"), -1) + + @test_case.named_parameters( + test_case.cross_with_function_handlers(_CONSTRUCT_TABLE_PARAMETERS) + ) + def test_construct_and_lookup_table(self, asset_path_input_fn, function_handler): + vocab_filename = os.path.join(self.get_temp_dir(), "test.txt") + vocab_data = [("a", "0"), ("b", "1"), ("c", "1"), ("d", "2"), ()] + encoded_vocab = "\n".join(["\t".join(pair) for pair in vocab_data]) + with tf.io.gfile.GFile(vocab_filename, "w") as writer: + writer.write(encoded_vocab) + + @function_handler( + input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)] + ) + def foo(input_tensor): + output_tensor, unused_table_size = tf_utils.construct_and_lookup_table( + _construct_table, asset_path_input_fn(vocab_filename), input_tensor + ) + return output_tensor + + expected_data = [0, 1, 1, 2, -1] + output_tensor = foo(["a", "b", "c", "d", "e"]) + self.assertAllEqual(output_tensor, expected_data) + + +if __name__ == "__main__": + test_case.main() diff --git a/tensorflow_transform/version.py b/tensorflow_transform/version.py index ffec4a9..4ec0b9d 100644 --- a/tensorflow_transform/version.py +++ b/tensorflow_transform/version.py @@ -14,4 +14,4 @@ """Contains the version string of TF.Transform.""" # Note that setup.py uses this version. 
-__version__ = '1.17.0.dev' +__version__ = "1.17.0.dev" From da3548c29ec4bae53a15cd6a928880639fda7006 Mon Sep 17 00:00:00 2001 From: andrewfulton9 Date: Wed, 14 May 2025 11:08:06 -0600 Subject: [PATCH 3/4] adds failing codes to ignores list to be fixed later --- pyproject.toml | 49 ++++++++++++++++++- .../beam/combiner_packing_util.py | 28 +++++------ tensorflow_transform/beam/impl.py | 5 +- 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 10cdddd..609d3a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,8 +78,55 @@ ignore = [ "D206", # Docstrings should be indented with spaces; unnecessary when running ruff-format "E501", # Line length too long; unnecessary when running ruff-format "W191", # Indentation contains tabs; unnecessary when running ruff-format -] + # FIX AND REMOVE BELOW CODES: + "ANN202", # Missing return type annotation for private function + "ANN001", # Missing type annotation for function argument + "D102", # Missing docstring in public method + "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements + "PT027", # Use `pytest.raises` instead of unittest-style `assertRaisesRegex` + "PT009", # Use a regular `assert` instead of unittest-style `assertEqual` / `assertIsInstance` + "PD011", # Use `.to_numpy()` instead of `.values` + "D101", # Missing docstring in public class + "D401", # First line of docstring should be in imperative mood + "FIX002", # Line contains TODO, consider resolving the issue + "RET505", # Unnecessary `else` after `return` statement + "E721", # Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks + "UP031", # Use format specifiers instead of percent format + "E731", # Do not assign a `lambda` expression, use a `def` + "ARG005", # Unused lambda argument + "ARG001", # Unused function argument + "SIM102", # Use a single `if` statement instead of nested `if` statements + "RET504", # Unnecessary assignment to `result` before `return` statement + "N802", # Function name should be lowercase + "RET506", # Unnecessary `else` after `raise` statement + "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... 
from None` + "ARG002", # Unused method argument + "ERA001", # Found commented-out code + "RET503", # Missing explicit `return` at the end of function able to return non-`None` value + "F401", # `module` imported but unused + "D103", # Missing docstring in public function + "F403", # `from module import *` used; unable to detect undefined names + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ANN206", # Missing return type annotation for classmethod + "ANN102", # Missing type annotation for `cls` in classmethod + "D107", # Missing docstring in `__init__` + "UP028", # Replace `yield` over `for` loop with `yield from` + "B023", # Function definition does not bind loop variable + "UP032", # Use f-string instead of `format` call + "E741", # Ambiguous variable name + "N803", # Argument name should be lowercase + "ANN205", # Missing return type annotation for staticmethod + "UP029", # Unnecessary builtin import + "SIM105", # Use `contextlib.suppress(KeyError)` instead of `try`-`except`-`pass` + "SIM118", # Use `key in dict` instead of `key in dict.keys()` + "F811", # Redefinition of unused name + "UP008", # Use `super()` instead of `super(__class__, self)` + "D417", # Missing argument description in the docstring + "SIM103", # Return the condition directly + "D404", # First word of the docstring should not be "This" + "NPY002", # Replace legacy `np.random.uniform` call with `np.random.Generator` +] [tool.ruff.lint.per-file-ignores] "__init__.py" = ["F401"] diff --git a/tensorflow_transform/beam/combiner_packing_util.py b/tensorflow_transform/beam/combiner_packing_util.py index 465f58d..e40bb8a 100644 --- a/tensorflow_transform/beam/combiner_packing_util.py +++ b/tensorflow_transform/beam/combiner_packing_util.py @@ -120,8 +120,8 @@ def _maybe_add_packable_combine(self, operation_def, input_values): class _PackAccumulateCombineVisitor(_ValidationVisitor): r"""A visitor that packs combine nodes in the graph. - This visitor takes the grouped combines and performs the packing of those - combines. + This visitor takes the grouped combines and performs the packing of those + combines. Before packing GrandParentNode / \ @@ -136,9 +136,9 @@ class _PackAccumulateCombineVisitor(_ValidationVisitor): / \ ExtractFromDict1' ExtractFromDict2' - The ExtractFromDict nodes after packing extracts the accumulator corresponding - to the individual combines. - """ + The ExtractFromDict nodes after packing extracts the accumulator corresponding + to the individual combines. + """ def __init__(self, packable_combines): super().__init__() @@ -246,8 +246,8 @@ def _maybe_add_packable_combine(self, operation_def, input_values): class _PackMergeCombineVisitor(_ValidationVisitor): r"""A visitor that inspects the graph and looks for combine nodes. - This visitor takes the grouped combines and performs the packing of those - combines. + This visitor takes the grouped combines and performs the packing of those + combines. Before packing ... ... / \ @@ -270,13 +270,13 @@ class _PackMergeCombineVisitor(_ValidationVisitor): / \ ExtractPackedCombineMergeOutputs1 ExtractPackedCombineMergeOutputs2 - Since the inputs to the final flatten node before the packed merge come from - different paths, we add redundant flatten and packed merge nodes each time we - visit a new input of the final flatten node. 
At the end of this traversal, - we would have one final packed merge node with a corresponding flatten node - having all the needed inputs, and in addition to this we would have a set of - redundant packed merge and flatten nodes which needs to be removed. - """ + Since the inputs to the final flatten node before the packed merge come from + different paths, we add redundant flatten and packed merge nodes each time we + visit a new input of the final flatten node. At the end of this traversal, + we would have one final packed merge node with a corresponding flatten node + having all the needed inputs, and in addition to this we would have a set of + redundant packed merge and flatten nodes which needs to be removed. + """ def __init__(self, packable_combine_extract_outputs): super().__init__() diff --git a/tensorflow_transform/beam/impl.py b/tensorflow_transform/beam/impl.py index ef9eb20..43d60b5 100644 --- a/tensorflow_transform/beam/impl.py +++ b/tensorflow_transform/beam/impl.py @@ -1181,9 +1181,8 @@ def expand(self, dataset): if graph.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES): raise ValueError( - "The preprocessing function contained trainable variables " "{}".format( - graph.get_collection_ref(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES) - ) + "The preprocessing function contained trainable variables " + f"{graph.get_collection_ref(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)}" ) pipeline = ( From 9ae8239678562d6f4c1cca201d5037b9d7902d43 Mon Sep 17 00:00:00 2001 From: andrewfulton9 Date: Wed, 14 May 2025 11:10:11 -0600 Subject: [PATCH 4/4] order codes --- pyproject.toml | 68 +++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 609d3a0..247a98d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,52 +80,52 @@ ignore = [ "W191", # Indentation contains tabs; unnecessary when running ruff-format # FIX AND REMOVE BELOW CODES: - "ANN202", # Missing return type annotation for private function "ANN001", # Missing type annotation for function argument - "D102", # Missing docstring in public method - "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements - "PT027", # Use `pytest.raises` instead of unittest-style `assertRaisesRegex` - "PT009", # Use a regular `assert` instead of unittest-style `assertEqual` / `assertIsInstance` - "PD011", # Use `.to_numpy()` instead of `.values` + "ANN102", # Missing type annotation for `cls` in classmethod + "ANN202", # Missing return type annotation for private function + "ANN205", # Missing return type annotation for staticmethod + "ANN206", # Missing return type annotation for classmethod + "ANN401", # Dynamically typed expressions (typing.Any) are disallowed + "ARG001", # Unused function argument + "ARG002", # Unused method argument + "ARG005", # Unused lambda argument + "B023", # Function definition does not bind loop variable + "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... 
from None` "D101", # Missing docstring in public class + "D102", # Missing docstring in public method + "D103", # Missing docstring in public function + "D107", # Missing docstring in `__init__` "D401", # First line of docstring should be in imperative mood - "FIX002", # Line contains TODO, consider resolving the issue - "RET505", # Unnecessary `else` after `return` statement + "D404", # First word of the docstring should not be "This" + "D417", # Missing argument description in the docstring "E721", # Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks - "UP031", # Use format specifiers instead of percent format "E731", # Do not assign a `lambda` expression, use a `def` - "ARG005", # Unused lambda argument - "ARG001", # Unused function argument - "SIM102", # Use a single `if` statement instead of nested `if` statements - "RET504", # Unnecessary assignment to `result` before `return` statement - "N802", # Function name should be lowercase - "RET506", # Unnecessary `else` after `raise` statement - "B904", # Within an `except` clause, raise exceptions with `raise ... from err` or `raise ... from None` - "ARG002", # Unused method argument + "E741", # Ambiguous variable name "ERA001", # Found commented-out code - "RET503", # Missing explicit `return` at the end of function able to return non-`None` value "F401", # `module` imported but unused - "D103", # Missing docstring in public function "F403", # `from module import *` used; unable to detect undefined names - "ANN401", # Dynamically typed expressions (typing.Any) are disallowed - "ANN206", # Missing return type annotation for classmethod - "ANN102", # Missing type annotation for `cls` in classmethod - "D107", # Missing docstring in `__init__` - "UP028", # Replace `yield` over `for` loop with `yield from` - "B023", # Function definition does not bind loop variable - "UP032", # Use f-string instead of `format` call - "E741", # Ambiguous variable name + "F811", # Redefinition of unused name + "FIX002", # Line contains TODO, consider resolving the issue + "N802", # Function name should be lowercase "N803", # Argument name should be lowercase - "ANN205", # Missing return type annotation for staticmethod - "UP029", # Unnecessary builtin import + "NPY002", # Replace legacy `np.random.uniform` call with `np.random.Generator` + "PD011", # Use `.to_numpy()` instead of `.values` + "PT009", # Use a regular `assert` instead of unittest-style `assertEqual` / `assertIsInstance` + "PT027", # Use `pytest.raises` instead of unittest-style `assertRaisesRegex` + "RET503", # Missing explicit `return` at the end of function able to return non-`None` value + "RET504", # Unnecessary assignment to `result` before `return` statement + "RET505", # Unnecessary `else` after `return` statement + "RET506", # Unnecessary `else` after `raise` statement + "SIM102", # Use a single `if` statement instead of nested `if` statements + "SIM103", # Return the condition directly "SIM105", # Use `contextlib.suppress(KeyError)` instead of `try`-`except`-`pass` + "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements "SIM118", # Use `key in dict` instead of `key in dict.keys()` - "F811", # Redefinition of unused name "UP008", # Use `super()` instead of `super(__class__, self)` - "D417", # Missing argument description in the docstring - "SIM103", # Return the condition directly - "D404", # First word of the docstring should not be "This" - "NPY002", # Replace legacy `np.random.uniform` call with 
`np.random.Generator` + "UP028", # Replace `yield` over `for` loop with `yield from` + "UP029", # Unnecessary builtin import + "UP031", # Use format specifiers instead of percent format + "UP032", # Use f-string instead of `format` call ] [tool.ruff.lint.per-file-ignores]
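
Illustrative note on the deferred ruff codes: most of the rules parked under the "FIX AND REMOVE BELOW CODES" comment have mechanical fixes. The Python sketch below is a minimal illustration of the shape of the fix for two of them, B904 (exception chaining) and SIM117 (nested `with` statements). The helper names `load_config` and `open_metadata` are hypothetical examples, not functions from the TF.Transform codebase.

    # B904: inside an `except` clause, re-raise with explicit chaining so the
    # original traceback is preserved.
    def load_config(path):
        try:
            with open(path) as f:
                return f.read()
        except OSError as err:
            # Before (flagged by B904): raise ValueError(f"bad config: {path}")
            raise ValueError(f"bad config: {path}") from err

    # SIM117: collapse nested `with` statements into a single statement.
    def open_metadata(schema_path, stats_path):
        # Before (flagged by SIM117):
        #     with open(schema_path) as schema:
        #         with open(stats_path) as stats:
        #             ...
        with open(schema_path) as schema, open(stats_path) as stats:
            return schema.readline(), stats.readline()

Once a given rule's violations are fixed in the codebase, its entry can be deleted from the `ignore` list above and ruff will start enforcing that rule again.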