Skip to content

Commit 429fd63

Browse files
authored
Merge pull request #34 from terrytangyuan/preprocess
Update preprocess code to work with tf.transform v0.11.0
2 parents 349d4ff + 18cf2cf commit 429fd63

File tree

2 files changed

+7 -8 lines changed (7 additions, 8 deletions)

2 files changed

+7 -8 lines changed (7 additions, 8 deletions)

examples/chicago_taxi/preprocess.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@
2626
from trainer import taxi
2727

2828
import tensorflow_transform as transform
29+
import tensorflow_transform.beam as tft_beam
2930

30-
from tensorflow_transform.beam import impl as beam_impl
31-
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
3231
from tensorflow_transform.coders import example_proto_coder
3332
from tensorflow_transform.tf_metadata import dataset_metadata
3433
from tensorflow_transform.tf_metadata import dataset_schema
@@ -127,7 +126,7 @@ def preprocessing_fn(inputs):
127126
raw_data_metadata = dataset_metadata.DatasetMetadata(raw_schema)
128127

129128
with beam.Pipeline(argv=pipeline_args) as pipeline:
130-
with beam_impl.Context(temp_dir=working_dir):
129+
with tft_beam.Context(temp_dir=working_dir):
131130
if input_handle.lower().endswith('csv'):
132131
csv_coder = taxi.make_csv_coder(schema)
133132
raw_data = (
@@ -147,22 +146,22 @@ def preprocessing_fn(inputs):
147146
if transform_dir is None:
148147
transform_fn = (
149148
(raw_data, raw_data_metadata)
150-
| ('Analyze' >> beam_impl.AnalyzeDataset(preprocessing_fn)))
149+
| ('Analyze' >> tft_beam.AnalyzeDataset(preprocessing_fn)))
151150

152151
_ = (
153152
transform_fn
154153
| ('WriteTransformFn' >>
155-
transform_fn_io.WriteTransformFn(working_dir)))
154+
tft_beam.WriteTransformFn(working_dir)))
156155
else:
157-
transform_fn = pipeline | transform_fn_io.ReadTransformFn(transform_dir)
156+
transform_fn = pipeline | tft_beam.ReadTransformFn(transform_dir)
158157

159158
# Shuffling the data before materialization will improve Training
160159
# effectiveness downstream.
161160
shuffled_data = raw_data | 'RandomizeData' >> beam.transforms.Reshuffle()
162161

163162
(transformed_data, transformed_metadata) = (
164163
((shuffled_data, raw_data_metadata), transform_fn)
165-
| 'Transform' >> beam_impl.TransformDataset())
164+
| 'Transform' >> tft_beam.TransformDataset())
166165

167166
coder = example_proto_coder.ExampleProtoCoder(transformed_metadata.schema)
168167
_ = (

examples/chicago_taxi/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,5 @@
3131
'tensorflow-metadata==0.9.0',
3232
'tensorflow-model-analysis==0.9.2',
3333
'tensorflow-serving-api==1.9.0',
34-
'tensorflow-transform==0.9.0',
34+
'tensorflow-transform==0.11.0',
3535
])

0 commit comments

Comments
 (0)