Skip to content

Commit 3a076ed

Browse files
committed
Update preprocess code to work with tf.transform v0.11.0
1 parent e2fde69 commit 3a076ed

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

examples/chicago_taxi/preprocess.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@
2727

2828
import tensorflow_transform as transform
2929

30-
from tensorflow_transform.beam import impl as beam_impl
31-
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
3230
from tensorflow_transform.coders import example_proto_coder
31+
import tensorflow_transform.beam as tft_beam
3332
from tensorflow_transform.tf_metadata import dataset_metadata
3433
from tensorflow_transform.tf_metadata import dataset_schema
3534

@@ -126,7 +125,7 @@ def preprocessing_fn(inputs):
126125
raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)
127126

128127
with beam.Pipeline(argv=pipeline_args) as pipeline:
129-
with beam_impl.Context(temp_dir=working_dir):
128+
with tft_beam.Context(temp_dir=working_dir):
130129
if input_handle.lower().endswith('csv'):
131130
csv_coder = taxi.make_csv_coder(schema)
132131
raw_data = (
@@ -146,22 +145,22 @@ def preprocessing_fn(inputs):
146145
if transform_dir is None:
147146
transform_fn = (
148147
(raw_data, raw_data_metadata)
149-
| ('Analyze' >> beam_impl.AnalyzeDataset(preprocessing_fn)))
148+
| ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))
150149

151150
_ = (
152151
transform_fn
153152
| ('WriteTransformFn' >>
154-
transform_fn_io.WriteTransformFn(working_dir)))
153+
tft_beam.WriteTransformFn(working_dir)))
155154
else:
156-
transform_fn = pipeline | transform_fn_io.ReadTransformFn(transform_dir)
155+
transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)
157156

158157
# Shuffling the data before materialization will improve Training
159158
# effectiveness downstream.
160159
shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()
161160

162161
(transformed_data, transformed_metadata) = (
163162
((shuffled_data, raw_data_metadata), transform_fn)
164-
| 'Transform' >> beam_impl.TransformDataset())
163+
| 'Transform' >> tft_beam.TransformDataset())
165164

166165
coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
167166
_ = (

examples/chicago_taxi/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
'tensorflow-metadata==0.9.0',
3232
'tensorflow-model-analysis==0.9.1',
3333
'tensorflow-serving-api==1.9.0',
34-
'tensorflow-transform==0.9.0',
34+
'tensorflow-transform==0.11.0',
3535
])

0 commit comments

Comments
 (0)