Skip to content

Commit

Permalink
Add Java TFRecord read/write pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
Amar3tto committed Feb 19, 2024
1 parent 64118dc commit 967dd41
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 0 deletions.
75 changes: 75 additions & 0 deletions Java/src/main/java/tfrecord/ReadTFRecord.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tfrecord;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example pipeline that reads records from TFRecord files, decodes every
 * record as UTF-8 text, logs it and emits it as a {@code String}.
 */
public class ReadTFRecord {

  private static final Logger LOG =
      LoggerFactory.getLogger(ReadTFRecord.class);

  /**
   * Pipeline options for reading from TFRecord files.
   */
  public interface ReadTFRecordOptions extends PipelineOptions {

    /** Returns the file glob pattern that selects the TFRecord files to read. */
    @Description("A file glob pattern to read TFRecords from")
    @Validation.Required
    String getFilePattern();

    /** Sets the file glob pattern that selects the TFRecord files to read. */
    void setFilePattern(String filePattern);
  }

  /**
   * Builds and runs the read pipeline.
   *
   * @param args command-line arguments parsed into {@link ReadTFRecordOptions};
   *     the file pattern option is marked as required
   */
  public static void main(String[] args) {
    ReadTFRecordOptions options =
        PipelineOptionsFactory.fromArgs(args)
            .withValidation()
            .as(ReadTFRecordOptions.class);

    Pipeline p = Pipeline.create(options);

    p.apply(
            "Read from TFRecord",
            TFRecordIO.read()
                .from(options.getFilePattern())
                .withCompression(Compression.UNCOMPRESSED))
        .apply(
            "Convert to string and log",
            ParDo.of(
                new DoFn<byte[], String>() {
                  @DoFn.ProcessElement
                  public void processElement(ProcessContext c) {
                    // Use the JDK charset constant rather than the one from
                    // Beam's vendored Guava: the vendored package name embeds
                    // a Guava version and is not a stable API.
                    String output =
                        new String(
                            c.element(),
                            java.nio.charset.StandardCharsets.UTF_8);
                    LOG.info("Output: {}", output);
                    c.output(output);
                  }
                }));

    p.run();
  }
}
77 changes: 77 additions & 0 deletions Java/src/main/java/tfrecord/WriteTFRecord.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tfrecord;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.TFRecordIO;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.options.Validation;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Charsets;

import java.util.Arrays;
import java.util.List;

/**
 * Example pipeline that creates a small in-memory collection of strings,
 * encodes each one as UTF-8 bytes and writes the results to TFRecord files.
 */
public class WriteTFRecord {

  /**
   * Pipeline options for writing to TFRecord files.
   */
  public interface WriteTFRecordOptions extends PipelineOptions {

    /** Returns the file path prefix used for the written TFRecord files. */
    @Description("A file path prefix to write TFRecords files to")
    @Validation.Required
    String getFilePathPrefix();

    /** Sets the file path prefix used for the written TFRecord files. */
    void setFilePathPrefix(String filePathPrefix);
  }

  /**
   * Builds and runs the write pipeline.
   *
   * @param args command-line arguments parsed into {@link WriteTFRecordOptions};
   *     the file path prefix option is marked as required
   */
  public static void main(String[] args) {
    WriteTFRecordOptions options =
        PipelineOptionsFactory.fromArgs(args)
            .withValidation()
            .as(WriteTFRecordOptions.class);

    Pipeline p = Pipeline.create(options);

    // Sample input elements; each one becomes a single TFRecord entry.
    List<String> rows = Arrays.asList(
        "Charles", "Alice", "Bob", "Amanda", "Alex", "Eliza"
    );

    p.apply("Create", Create.of(rows))
        .apply(
            "Convert to bytes",
            ParDo.of(
                new DoFn<String, byte[]>() {
                  @DoFn.ProcessElement
                  public void processElement(ProcessContext c) {
                    // Use the JDK charset constant rather than the one from
                    // Beam's vendored Guava: the vendored package name embeds
                    // a Guava version and is not a stable API.
                    c.output(
                        c.element()
                            .getBytes(java.nio.charset.StandardCharsets.UTF_8));
                  }
                }))
        .apply(
            "Write to TFRecord",
            TFRecordIO.write()
                .to(options.getFilePathPrefix())
                .withCompression(Compression.UNCOMPRESSED)
                // Request a single output shard.
                .withNumShards(1));

    p.run();
  }
}

0 comments on commit 967dd41

Please sign in to comment.