
Commit 7f432b9

Author: Jiyong Jung
Revert "Introducing RecordBatchToExamplesEncoder to encode nested lists representing tf.RaggedTensor as tf.Examples." (#4306)
This reverts commit f6beebf.
1 parent 9f05db3 commit 7f432b9

File tree: 2 files changed, 59 additions & 67 deletions

  RELEASE.md
  tfx/components/transform/executor.py

RELEASE.md

Lines changed: 0 additions & 2 deletions
@@ -13,8 +13,6 @@
     pre-defined schema file. ImportSchemaGen will replace `Importer` with
     simpler syntax and less constraints. You have to pass the file path to the
     schema file instead of the parent directory unlike `Importer`.
-*   Added support for outputting and encoding `tf.RaggedTensor`s in TFX
-    Transform component.
 
 ## Breaking Changes
 
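For context, the dropped release note referred to `preprocessing_fn` outputs like the sketch below: `tf.strings.split` on a batch of strings yields a `tf.RaggedTensor`, which is the kind of output the reverted encoder change was meant to materialize. The feature name `text` and the function body are illustrative, not taken from this commit.

import tensorflow as tf


def preprocessing_fn(inputs):
  # tf.strings.split returns a tf.RaggedTensor; emitting it from the
  # preprocessing_fn relies on the ragged-output support referenced by the
  # removed release note.
  return {'tokens': tf.strings.split(inputs['text'])}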
tfx/components/transform/executor.py

Lines changed: 59 additions & 65 deletions
@@ -16,7 +16,7 @@
 import functools
 import hashlib
 import os
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
 
 from absl import logging
 import apache_beam as beam
@@ -275,24 +275,6 @@ def _InvokeStatsOptionsUpdaterFn(
   return stats_options_updater_fn(stats_type, tfdv.StatsOptions(**options))
 
 
-def _FilterInternalColumn(
-    record_batch: pa.RecordBatch,
-    internal_column_index: Optional[int] = None) -> pa.RecordBatch:
-  """Returns shallow copy of a RecordBatch with internal column removed."""
-  if (internal_column_index is None and
-      _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
-    return record_batch
-  else:
-    internal_column_index = (
-        internal_column_index or
-        record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
-    # Making shallow copy since input modification is not allowed.
-    filtered_columns = list(record_batch.columns)
-    filtered_columns.pop(internal_column_index)
-    filtered_schema = record_batch.schema.remove(internal_column_index)
-    return pa.RecordBatch.from_arrays(filtered_columns, schema=filtered_schema)
-
-
 class Executor(base_beam_executor.BaseBeamExecutor):
   """Transform executor."""
 
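The module-level helper deleted above reappears verbatim as the `Executor._FilterInternalColumn` staticmethod near the end of this diff. A small pyarrow sketch of the shallow-copy filtering it performs; the column names, values, and the `_TRANSFORM_INTERNAL_FEATURE_FOR_KEY` placeholder value are illustrative (the real constant is defined elsewhere in executor.py and not shown here).

import pyarrow as pa

# Illustrative stand-in; the actual value lives in executor.py.
_TRANSFORM_INTERNAL_FEATURE_FOR_KEY = '__internal_key__'

# Toy batch: one real feature plus the internal key column.
record_batch = pa.RecordBatch.from_arrays(
    [
        pa.array([[1, 2], [3]], type=pa.large_list(pa.int64())),
        pa.array([[b'key_a'], [b'key_b']],
                 type=pa.large_list(pa.large_binary())),
    ],
    ['f', _TRANSFORM_INTERNAL_FEATURE_FOR_KEY])

# Same shallow-copy pattern as _FilterInternalColumn: pop the column and
# drop the corresponding schema field; the underlying arrays are not copied.
key_index = record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
filtered_columns = list(record_batch.columns)
filtered_columns.pop(key_index)
filtered = pa.RecordBatch.from_arrays(
    filtered_columns, schema=record_batch.schema.remove(key_index))
print(filtered.schema.names)  # ['f']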
@@ -693,7 +675,7 @@ def _GenerateAndMaybeValidateStats(
 
     generated_stats = (
         pcoll
-        | 'FilterInternalColumn' >> beam.Map(_FilterInternalColumn)
+        | 'FilterInternalColumn' >> beam.Map(Executor._FilterInternalColumn)
         | 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options))
 
     stats_result = (
@@ -751,42 +733,6 @@ def setup(self):
   def process(self, element: List[bytes]) -> Iterable[pa.RecordBatch]:
     yield self._decoder.DecodeBatch(element)
 
-@beam.typehints.with_input_types(Tuple[pa.RecordBatch, Dict[str, pa.Array]])
-@beam.typehints.with_output_types(Tuple[Any, bytes])
-class _RecordBatchToExamplesFn(beam.DoFn):
-  """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
-
-  def __init__(self, schema: schema_pb2.Schema):
-    self._coder = tfx_bsl.coders.example_coder.RecordBatchToExamplesEncoder(
-        schema)
-
-  def process(
-      self, data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
-  ) -> Iterable[Tuple[Any, bytes]]:
-    record_batch, unary_passthrough_features = data_batch
-    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
-      keys_index = record_batch.schema.names.index(
-          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
-      keys = record_batch.column(keys_index).to_pylist()
-      # Filter the record batch to make sure that the internal column doesn't
-      # get encoded.
-      record_batch = _FilterInternalColumn(record_batch, keys_index)
-      examples = self._coder.encode(record_batch)
-      for key, example in zip(keys, examples):
-        yield (None if key is None else key[0], example)
-    else:
-      # Internal feature key is not present in the record batch but may be
-      # present in the unary pass-through features dict.
-      key = unary_passthrough_features.get(
-          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY, None)
-      if key is not None:
-        # The key is `pa.large_list()` and is, therefore, doubly nested.
-        key_list = key.to_pylist()[0]
-        key = None if key_list is None else key_list[0]
-      examples = self._coder.encode(record_batch)
-      for example in examples:
-        yield (key, example)
-
 @beam.typehints.with_input_types(beam.Pipeline)
 class _OptimizeRun(beam.PTransform):
   """Utilizes TFT cache if applicable and removes unused datasets."""
@@ -1461,15 +1407,15 @@ def _ExtractRawExampleBatches(record_batch):
           | 'Transform[{}]'.format(infix) >>
           tft_beam.TransformDataset(output_record_batches=True))
 
-      _, metadata = transform_fn
-
-      # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
-      # schema. Currently input dataset schema only contains dtypes,
-      # and other metadata is dropped due to roundtrip to tensors.
-      transformed_schema_proto = metadata.schema
-
       if not disable_statistics:
         # Aggregated feature stats after transformation.
+        _, metadata = transform_fn
+
+        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
+        # schema. Currently input dataset schema only contains dtypes,
+        # and other metadata is dropped due to roundtrip to tensors.
+        transformed_schema_proto = metadata.schema
+
         for dataset in transform_data_list:
           infix = 'TransformIndex{}'.format(dataset.index)
           dataset.transformed_and_standardized = (
@@ -1543,8 +1489,8 @@ def _ExtractRawExampleBatches(record_batch):
       for dataset in transform_data_list:
         infix = 'TransformIndex{}'.format(dataset.index)
         (dataset.transformed
-         | 'EncodeAndSerialize[{}]'.format(infix) >> beam.ParDo(
-             self._RecordBatchToExamplesFn(transformed_schema_proto))
+         | 'EncodeAndSerialize[{}]'.format(infix) >> beam.FlatMap(
+             Executor._RecordBatchToExamples)
          | 'Materialize[{}]'.format(infix) >> self._WriteExamples(
              materialization_format, dataset.materialize_output_path))
 
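The hunk above swaps a `beam.ParDo` over a schema-carrying DoFn instance for a `beam.FlatMap` over the static `Executor._RecordBatchToExamples` generator; FlatMap flattens the yielded `(key, serialized_example)` tuples into the output PCollection. A rough standalone sketch of that wiring, with `_to_examples` standing in for the static method and the pipeline contents purely illustrative:

from typing import Any, Dict, Generator, Tuple

import apache_beam as beam
import pyarrow as pa
from tfx_bsl.coders import example_coder


def _to_examples(
    data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
) -> Generator[Tuple[Any, bytes], None, None]:
  # Stand-in for Executor._RecordBatchToExamples: ignore the pass-through
  # features and emit unkeyed serialized examples.
  record_batch, _ = data_batch
  for example in example_coder.RecordBatchToExamples(record_batch):
    yield (None, example)


with beam.Pipeline() as pipeline:
  batch = pa.RecordBatch.from_arrays(
      [pa.array([[1], [2, 3]], type=pa.large_list(pa.int64()))], ['f'])
  _ = (
      pipeline
      | beam.Create([(batch, {})])
      # A generator under FlatMap replaces the stateful DoFn under ParDo,
      # since no per-worker setup (building a coder from a schema) is needed.
      | 'EncodeAndSerialize' >> beam.FlatMap(_to_examples))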
@@ -1732,6 +1678,54 @@ def _GetTFXIOPassthroughKeys() -> Optional[Set[str]]:
     """Always returns None."""
     return None
 
+  @staticmethod
+  def _FilterInternalColumn(
+      record_batch: pa.RecordBatch,
+      internal_column_index: Optional[int] = None) -> pa.RecordBatch:
+    """Returns shallow copy of a batch with internal column removed."""
+    if (internal_column_index is None and
+        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
+      return record_batch
+    else:
+      internal_column_index = (
+          internal_column_index or
+          record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
+      # Making shallow copy since input modification is not allowed.
+      filtered_columns = list(record_batch.columns)
+      filtered_columns.pop(internal_column_index)
+      filtered_schema = record_batch.schema.remove(internal_column_index)
+      return pa.RecordBatch.from_arrays(
+          filtered_columns, schema=filtered_schema)
+
+  @staticmethod
+  def _RecordBatchToExamples(
+      data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
+  ) -> Generator[Tuple[Any, bytes], None, None]:
+    """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
+    record_batch, unary_passthrough_features = data_batch
+    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
+      keys_index = record_batch.schema.names.index(
+          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
+      keys = record_batch.column(keys_index).to_pylist()
+      # Filter the record batch to make sure that the internal column doesn't
+      # get encoded.
+      record_batch = Executor._FilterInternalColumn(record_batch, keys_index)
+      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
+          record_batch)
+      for key, example in zip(keys, examples):
+        yield (None if key is None else key[0], example)
+    else:
+      # Internal feature key is not present in the record batch but may be
+      # present in the unary pass-through features dict.
+      key = unary_passthrough_features.get(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY,
+                                           None)
+      if key is not None:
+        key = None if key.to_pylist()[0] is None else key.to_pylist()[0][0]
+      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
+          record_batch)
+      for example in examples:
+        yield (key, example)
+
   # TODO(b/130885503): clean this up once the sketch-based generator is the
   # default.
   @staticmethod
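The restored `_RecordBatchToExamples` drops the schema-aware `RecordBatchToExamplesEncoder` used by the deleted DoFn and instead calls the schema-less `tfx_bsl.coders.example_coder.RecordBatchToExamples` on each filtered batch, re-attaching the per-row keys it pulled out beforehand. A minimal sketch of both call patterns on a toy batch; the feature name, values, keys, and schema below are illustrative:

import pyarrow as pa
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx_bsl.coders import example_coder

# Toy filtered batch (internal key column already removed) plus its keys.
record_batch = pa.RecordBatch.from_arrays(
    [pa.array([[1, 2], [3]], type=pa.large_list(pa.int64()))], ['f'])
keys = [[b'key_a'], [b'key_b']]

# Reverted path: encoder built once from a schema, as in the deleted DoFn.
schema = schema_pb2.Schema()
schema.feature.add(name='f', type=schema_pb2.INT)
encoder = example_coder.RecordBatchToExamplesEncoder(schema)
serialized_with_schema = encoder.encode(record_batch)

# Restored path: schema-less helper; keys are re-attached per row, matching
# the `yield (None if key is None else key[0], example)` branch above.
for key, example in zip(keys, example_coder.RecordBatchToExamples(record_batch)):
  print(None if key is None else key[0], len(example))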