From 967dd414f59f2c5a088243642551dd98cbf75aea Mon Sep 17 00:00:00 2001 From: "vitaly.terentyev" Date: Mon, 19 Feb 2024 13:18:44 +0400 Subject: [PATCH] Add Java TFRecord read/write pipelines --- Java/src/main/java/tfrecord/ReadTFRecord.java | 75 ++++++++++++++++++ .../src/main/java/tfrecord/WriteTFRecord.java | 77 +++++++++++++++++++ 2 files changed, 152 insertions(+) create mode 100644 Java/src/main/java/tfrecord/ReadTFRecord.java create mode 100644 Java/src/main/java/tfrecord/WriteTFRecord.java diff --git a/Java/src/main/java/tfrecord/ReadTFRecord.java b/Java/src/main/java/tfrecord/ReadTFRecord.java new file mode 100644 index 0000000..1e0afa0 --- /dev/null +++ b/Java/src/main/java/tfrecord/ReadTFRecord.java @@ -0,0 +1,75 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package tfrecord; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.Compression; +import org.apache.beam.sdk.io.TFRecordIO; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ReadTFRecord { + + private static final Logger LOG = + LoggerFactory.getLogger(ReadTFRecord.class); + + /** + * Pipeline options for read from TFRecord. + */ + public interface ReadTFRecordOptions extends PipelineOptions { + + @Description("A file glob pattern to read TFRecords from") + @Validation.Required + String getFilePattern(); + + void setFilePattern(String filePattern); + } + + public static void main(String[] args) { + ReadTFRecordOptions options = + PipelineOptionsFactory.fromArgs(args) + .withValidation().as(ReadTFRecordOptions.class); + + Pipeline p = Pipeline.create(options); + + p.apply( + "Read from TFRecord", + TFRecordIO.read() + .from(options.getFilePattern()) + .withCompression(Compression.UNCOMPRESSED)) + .apply( + "Convert to string and log", + ParDo.of( + new DoFn() { + @DoFn.ProcessElement + public void processElement(ProcessContext c) { + String output = + new String(c.element(), Charsets.UTF_8); + LOG.info("Output: {}", output); + c.output(output); + } + })); + + p.run(); + } +} diff --git a/Java/src/main/java/tfrecord/WriteTFRecord.java b/Java/src/main/java/tfrecord/WriteTFRecord.java new file mode 100644 index 0000000..aea44ff --- /dev/null +++ b/Java/src/main/java/tfrecord/WriteTFRecord.java @@ -0,0 +1,77 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package tfrecord; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.Compression; +import org.apache.beam.sdk.io.TFRecordIO; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets; + +import java.util.Arrays; +import java.util.List; + +public class WriteTFRecord { + + /** + * Pipeline options for write to TFRecord. + */ + public interface WriteTFRecordOptions extends PipelineOptions { + + @Description("A file path prefix to write TFRecords files to") + @Validation.Required + String getFilePathPrefix(); + + void setFilePathPrefix(String filePathPrefix); + } + + public static void main(String[] args) { + WriteTFRecordOptions options = + PipelineOptionsFactory.fromArgs(args) + .withValidation().as(WriteTFRecordOptions.class); + + Pipeline p = Pipeline.create(options); + + List rows = Arrays.asList( + "Charles", "Alice", "Bob", "Amanda", "Alex", "Eliza" + ); + + p.apply("Create", Create.of(rows)) + .apply( + "Convert to bytes", + ParDo.of( + new DoFn() { + @DoFn.ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element().getBytes(Charsets.UTF_8)); + } + })) + .apply( + "Write to TFRecord", + TFRecordIO.write() + .to(options.getFilePathPrefix()) + .withCompression(Compression.UNCOMPRESSED) + .withNumShards(1)); + + p.run(); + } +}