16 | 16 | import functools
17 | 17 | import hashlib
18 | 18 | import os
19 |    | -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
   | 19 | +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union
20 | 20 |
21 | 21 | from absl import logging
22 | 22 | import apache_beam as beam
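The only import change is the addition of `Generator`, used later in this diff for the return annotation of the new `_RecordBatchToExamples` static method. As a reminder (a generic typing fact, not project-specific), the annotation has the shape `Generator[YieldType, SendType, ReturnType]`; a minimal sketch with a toy function name:

    from typing import Any, Generator, Tuple

    def emit_pairs() -> Generator[Tuple[Any, bytes], None, None]:
      # Yields (key, serialized payload) pairs; nothing is sent into the
      # generator and nothing is returned, hence the two trailing None types.
      yield ('key', b'serialized-example-bytes')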
@@ -275,24 +275,6 @@ def _InvokeStatsOptionsUpdaterFn(
275 | 275 |   return stats_options_updater_fn(stats_type, tfdv.StatsOptions(**options))
276 | 276 |
277 | 277 |
278 |     | -def _FilterInternalColumn(
279 |     | -    record_batch: pa.RecordBatch,
280 |     | -    internal_column_index: Optional[int] = None) -> pa.RecordBatch:
281 |     | -  """Returns shallow copy of a RecordBatch with internal column removed."""
282 |     | -  if (internal_column_index is None and
283 |     | -      _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
284 |     | -    return record_batch
285 |     | -  else:
286 |     | -    internal_column_index = (
287 |     | -        internal_column_index or
288 |     | -        record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
289 |     | -    # Making shallow copy since input modification is not allowed.
290 |     | -    filtered_columns = list(record_batch.columns)
291 |     | -    filtered_columns.pop(internal_column_index)
292 |     | -    filtered_schema = record_batch.schema.remove(internal_column_index)
293 |     | -    return pa.RecordBatch.from_arrays(filtered_columns, schema=filtered_schema)
294 |     | -
295 |     | -
296 | 278 | class Executor(base_beam_executor.BaseBeamExecutor):
297 | 279 |   """Transform executor."""
298 | 280 |
@@ -693,7 +675,7 @@ def _GenerateAndMaybeValidateStats(
693 | 675 |
694 | 676 |   generated_stats = (
695 | 677 |       pcoll
696 |     | -      | 'FilterInternalColumn' >> beam.Map(_FilterInternalColumn)
    | 678 | +      | 'FilterInternalColumn' >> beam.Map(Executor._FilterInternalColumn)
697 | 679 |       | 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options))
698 | 680 |
699 | 681 |   stats_result = (
@@ -751,42 +733,6 @@ def setup(self):
751 | 733 |     def process(self, element: List[bytes]) -> Iterable[pa.RecordBatch]:
752 | 734 |       yield self._decoder.DecodeBatch(element)
753 | 735 |
754 |     | -  @beam.typehints.with_input_types(Tuple[pa.RecordBatch, Dict[str, pa.Array]])
755 |     | -  @beam.typehints.with_output_types(Tuple[Any, bytes])
756 |     | -  class _RecordBatchToExamplesFn(beam.DoFn):
757 |     | -    """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
758 |     | -
759 |     | -    def __init__(self, schema: schema_pb2.Schema):
760 |     | -      self._coder = tfx_bsl.coders.example_coder.RecordBatchToExamplesEncoder(
761 |     | -          schema)
762 |     | -
763 |     | -    def process(
764 |     | -        self, data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
765 |     | -    ) -> Iterable[Tuple[Any, bytes]]:
766 |     | -      record_batch, unary_passthrough_features = data_batch
767 |     | -      if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
768 |     | -        keys_index = record_batch.schema.names.index(
769 |     | -            _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
770 |     | -        keys = record_batch.column(keys_index).to_pylist()
771 |     | -        # Filter the record batch to make sure that the internal column doesn't
772 |     | -        # get encoded.
773 |     | -        record_batch = _FilterInternalColumn(record_batch, keys_index)
774 |     | -        examples = self._coder.encode(record_batch)
775 |     | -        for key, example in zip(keys, examples):
776 |     | -          yield (None if key is None else key[0], example)
777 |     | -      else:
778 |     | -        # Internal feature key is not present in the record batch but may be
779 |     | -        # present in the unary pass-through features dict.
780 |     | -        key = unary_passthrough_features.get(
781 |     | -            _TRANSFORM_INTERNAL_FEATURE_FOR_KEY, None)
782 |     | -        if key is not None:
783 |     | -          # The key is `pa.large_list()` and is, therefore, doubly nested.
784 |     | -          key_list = key.to_pylist()[0]
785 |     | -          key = None if key_list is None else key_list[0]
786 |     | -        examples = self._coder.encode(record_batch)
787 |     | -        for example in examples:
788 |     | -          yield (key, example)
789 |     | -
790 | 736 |   @beam.typehints.with_input_types(beam.Pipeline)
791 | 737 |   class _OptimizeRun(beam.PTransform):
792 | 738 |     """Utilizes TFT cache if applicable and removes unused datasets."""
@@ -1461,15 +1407,15 @@ def _ExtractRawExampleBatches(record_batch):
1461 | 1407 |           | 'Transform[{}]'.format(infix) >>
1462 | 1408 |           tft_beam.TransformDataset(output_record_batches=True))
1463 | 1409 |
1464 |      | -      _, metadata = transform_fn
1465 |      | -
1466 |      | -      # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
1467 |      | -      # schema. Currently input dataset schema only contains dtypes,
1468 |      | -      # and other metadata is dropped due to roundtrip to tensors.
1469 |      | -      transformed_schema_proto = metadata.schema
1470 |      | -
1471 | 1410 |       if not disable_statistics:
1472 | 1411 |         # Aggregated feature stats after transformation.
     | 1412 | +        _, metadata = transform_fn
     | 1413 | +
     | 1414 | +        # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
     | 1415 | +        # schema. Currently input dataset schema only contains dtypes,
     | 1416 | +        # and other metadata is dropped due to roundtrip to tensors.
     | 1417 | +        transformed_schema_proto = metadata.schema
     | 1418 | +
1473 | 1419 |         for dataset in transform_data_list:
1474 | 1420 |           infix = 'TransformIndex{}'.format(dataset.index)
1475 | 1421 |           dataset.transformed_and_standardized = (
@@ -1543,8 +1489,8 @@ def _ExtractRawExampleBatches(record_batch):
1543 | 1489 |         for dataset in transform_data_list:
1544 | 1490 |           infix = 'TransformIndex{}'.format(dataset.index)
1545 | 1491 |           (dataset.transformed
1546 |      | -           | 'EncodeAndSerialize[{}]'.format(infix) >> beam.ParDo(
1547 |      | -               self._RecordBatchToExamplesFn(transformed_schema_proto))
     | 1492 | +           | 'EncodeAndSerialize[{}]'.format(infix) >> beam.FlatMap(
     | 1493 | +               Executor._RecordBatchToExamples)
1548 | 1494 |            | 'Materialize[{}]'.format(infix) >> self._WriteExamples(
1549 | 1495 |                materialization_format, dataset.materialize_output_path))
1550 | 1496 |
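A brief, hedged aside on the hunk above: `beam.FlatMap` accepts a plain callable that yields zero or more outputs per input element, so once the encoder no longer needs a per-instance schema, the stateful `DoFn` plus `beam.ParDo` can be replaced by a static generator function. A minimal standalone sketch with toy element types and names (not the executor's actual data or classes):

    import apache_beam as beam

    def _to_pairs(values):
      # Yields one (index, value) pair per item, mirroring how a generator
      # static method can stand in for a DoFn.process() under FlatMap.
      for i, value in enumerate(values):
        yield (i, value)

    with beam.Pipeline() as pipeline:
      _ = (
          pipeline
          | beam.Create([[10, 20], [30]])
          | beam.FlatMap(_to_pairs)  # previously: beam.ParDo(SomeDoFn(schema))
          | beam.Map(print))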
@@ -1732,6 +1678,54 @@ def _GetTFXIOPassthroughKeys() -> Optional[Set[str]]:
1732 | 1678 |     """Always returns None."""
1733 | 1679 |     return None
1734 | 1680 |
     | 1681 | +  @staticmethod
     | 1682 | +  def _FilterInternalColumn(
     | 1683 | +      record_batch: pa.RecordBatch,
     | 1684 | +      internal_column_index: Optional[int] = None) -> pa.RecordBatch:
     | 1685 | +    """Returns shallow copy of a batch with internal column removed."""
     | 1686 | +    if (internal_column_index is None and
     | 1687 | +        _TRANSFORM_INTERNAL_FEATURE_FOR_KEY not in record_batch.schema.names):
     | 1688 | +      return record_batch
     | 1689 | +    else:
     | 1690 | +      internal_column_index = (
     | 1691 | +          internal_column_index or
     | 1692 | +          record_batch.schema.names.index(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY))
     | 1693 | +      # Making shallow copy since input modification is not allowed.
     | 1694 | +      filtered_columns = list(record_batch.columns)
     | 1695 | +      filtered_columns.pop(internal_column_index)
     | 1696 | +      filtered_schema = record_batch.schema.remove(internal_column_index)
     | 1697 | +      return pa.RecordBatch.from_arrays(
     | 1698 | +          filtered_columns, schema=filtered_schema)
     | 1699 | +
     | 1700 | +  @staticmethod
     | 1701 | +  def _RecordBatchToExamples(
     | 1702 | +      data_batch: Tuple[pa.RecordBatch, Dict[str, pa.Array]]
     | 1703 | +  ) -> Generator[Tuple[Any, bytes], None, None]:
     | 1704 | +    """Maps `pa.RecordBatch` to a generator of serialized `tf.Example`s."""
     | 1705 | +    record_batch, unary_passthrough_features = data_batch
     | 1706 | +    if _TRANSFORM_INTERNAL_FEATURE_FOR_KEY in record_batch.schema.names:
     | 1707 | +      keys_index = record_batch.schema.names.index(
     | 1708 | +          _TRANSFORM_INTERNAL_FEATURE_FOR_KEY)
     | 1709 | +      keys = record_batch.column(keys_index).to_pylist()
     | 1710 | +      # Filter the record batch to make sure that the internal column doesn't
     | 1711 | +      # get encoded.
     | 1712 | +      record_batch = Executor._FilterInternalColumn(record_batch, keys_index)
     | 1713 | +      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
     | 1714 | +          record_batch)
     | 1715 | +      for key, example in zip(keys, examples):
     | 1716 | +        yield (None if key is None else key[0], example)
     | 1717 | +    else:
     | 1718 | +      # Internal feature key is not present in the record batch but may be
     | 1719 | +      # present in the unary pass-through features dict.
     | 1720 | +      key = unary_passthrough_features.get(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY,
     | 1721 | +                                           None)
     | 1722 | +      if key is not None:
     | 1723 | +        key = None if key.to_pylist()[0] is None else key.to_pylist()[0][0]
     | 1724 | +      examples = tfx_bsl.coders.example_coder.RecordBatchToExamples(
     | 1725 | +          record_batch)
     | 1726 | +      for example in examples:
     | 1727 | +        yield (key, example)
     | 1728 | +
1735 | 1729 |   # TODO(b/130885503): clean this up once the sketch-based generator is the
1736 | 1730 |   # default.
1737 | 1731 |   @staticmethod
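For reference, a small self-contained sketch of the shallow-copy column removal that `_FilterInternalColumn` performs, using hypothetical column names (the real internal column is named by `_TRANSFORM_INTERNAL_FEATURE_FOR_KEY`, whose value is defined elsewhere in the module):

    import pyarrow as pa

    batch = pa.RecordBatch.from_arrays(
        [pa.array([[1], [2]]), pa.array([['k1'], ['k2']])],
        names=['feature_x', 'internal_key'])  # toy names for illustration

    # Shallow copy without the internal column: drop it from the column list
    # and from the schema, then rebuild the RecordBatch around the same arrays.
    idx = batch.schema.names.index('internal_key')
    columns = list(batch.columns)
    columns.pop(idx)
    filtered = pa.RecordBatch.from_arrays(
        columns, schema=batch.schema.remove(idx))

    print(filtered.schema.names)  # ['feature_x']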