You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ch...@apache.org on 2017/03/21 18:25:45 UTC
[1/2] beam git commit: add TensorFlow TFRecordIO
Repository: beam
Updated Branches:
refs/heads/master bea4f5aec -> 92d1a6635
add TensorFlow TFRecordIO
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/68d42f9b
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/68d42f9b
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/68d42f9b
Branch: refs/heads/master
Commit: 68d42f9b941dab420947a8aaa616070b1f3fb5f8
Parents: bea4f5a
Author: Neville Li <ne...@spotify.com>
Authored: Tue Feb 21 15:51:18 2017 -0500
Committer: Chamikara Jayalath <ch...@google.com>
Committed: Tue Mar 21 11:23:02 2017 -0700
----------------------------------------------------------------------
.../apache/beam/sdk/io/CompressedSource.java | 13 +-
.../java/org/apache/beam/sdk/io/TFRecordIO.java | 905 +++++++++++++++++++
.../java/org/apache/beam/sdk/io/TextIO.java | 2 +-
.../org/apache/beam/sdk/io/TFRecordIOTest.java | 368 ++++++++
.../java/org/apache/beam/sdk/io/TextIOTest.java | 2 +-
5 files changed, 1285 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
index 6de22f9..ecd0fd9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
@@ -334,8 +334,14 @@ public class CompressedSource<T> extends FileBasedSource<T> {
super(filePatternOrSpec, minBundleSize, startOffset, endOffset);
this.sourceDelegate = sourceDelegate;
this.channelFactory = channelFactory;
+ boolean splittable = false;
+ try {
+ splittable = isSplittable();
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to determine if the source is splittable", e);
+ }
checkArgument(
- isSplittable() || startOffset == 0,
+ splittable || startOffset == 0,
"CompressedSources must start reading at offset 0. Requested offset: " + startOffset);
}
@@ -366,11 +372,12 @@ public class CompressedSource<T> extends FileBasedSource<T> {
* from the requested file name that the file is not compressed.
*/
@Override
- protected final boolean isSplittable() {
+ protected final boolean isSplittable() throws Exception {
if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) {
FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory =
(FileNameBasedDecompressingChannelFactory) channelFactory;
- return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec());
+ // Splitting requires both an uncompressed file AND a delegate source that can itself split;
+ // the delegate check lets non-splittable formats (e.g. TFRecordSource) opt out.
+ return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec())
+ && sourceDelegate.isSplittable();
}
return false;
}
http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
new file mode 100644
index 0000000..243506c
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -0,0 +1,905 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.hash.HashFunction;
+import com.google.common.hash.Hashing;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.util.Locale;
+import java.util.NoSuchElementException;
+import java.util.regex.Pattern;
+
+import javax.annotation.Nullable;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.io.Read.Bounded;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.IOChannelUtils;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PDone;
+
+/**
+ * {@link PTransform}s for reading and writing TensorFlow TFRecord files.
+ */
+public class TFRecordIO {
+  /** The default coder, which returns each record of the input file as a byte array. */
+  public static final Coder<byte[]> DEFAULT_BYTE_ARRAY_CODER = ByteArrayCoder.of();
+
+  /**
+   * A {@link PTransform} that reads from a TFRecord file (or multiple TFRecord
+   * files matching a pattern) and returns a {@link PCollection} containing
+   * the decoding of each of the records of the TFRecord file(s) as a byte array.
+   */
+  public static class Read {
+
+    /**
+     * Returns a transform for reading TFRecord files that reads from the file(s)
+     * with the given filename or filename pattern. This can be a local path (if running locally),
+     * or a Google Cloud Storage filename or filename pattern of the form
+     * {@code "gs://<bucket>/<filepath>"} (if running locally or via the Google Cloud Dataflow
+     * service). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html"
+     * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
+     */
+    public static Bound from(String filepattern) {
+      return new Bound().from(filepattern);
+    }
+
+    /**
+     * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
+     */
+    public static Bound from(ValueProvider<String> filepattern) {
+      return new Bound().from(filepattern);
+    }
+
+    /**
+     * Returns a transform for reading TFRecord files that has GCS path validation on
+     * pipeline creation disabled.
+     *
+     * <p>This can be useful in the case where the GCS input does not
+     * exist at the pipeline creation time, but is expected to be
+     * available at execution time.
+     */
+    public static Bound withoutValidation() {
+      return new Bound().withoutValidation();
+    }
+
+    /**
+     * Returns a transform for reading TFRecord files that decompresses all input files
+     * using the specified compression type.
+     *
+     * <p>If no compression type is specified, the default is
+     * {@link TFRecordIO.CompressionType#AUTO}.
+     * In this mode, the compression type of the file is determined by its extension
+     * (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other
+     * extensions are uncompressed).
+     */
+    public static Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
+      return new Bound().withCompressionType(compressionType);
+    }
+
+    /**
+     * A {@link PTransform} that reads from one or more TFRecord files and returns a bounded
+     * {@link PCollection} containing one element for each record of the input files.
+     */
+    public static class Bound extends PTransform<PBegin, PCollection<byte[]>> {
+      /** The filepattern to read from; {@code null} until {@link #from} is called. */
+      @Nullable private final ValueProvider<String> filepattern;
+
+      /** An option to indicate if input validation is desired. Default is true. */
+      private final boolean validate;
+
+      /** Option to indicate the input source's compression type. Default is AUTO. */
+      private final TFRecordIO.CompressionType compressionType;
+
+      private Bound() {
+        this(null, null, true, TFRecordIO.CompressionType.AUTO);
+      }
+
+      private Bound(
+          @Nullable String name,
+          @Nullable ValueProvider<String> filepattern,
+          boolean validate,
+          TFRecordIO.CompressionType compressionType) {
+        super(name);
+        this.filepattern = filepattern;
+        this.validate = validate;
+        this.compressionType = compressionType;
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but that
+       * reads from the file(s) with the given name or pattern. See {@link TFRecordIO.Read#from}
+       * for a description of filepatterns.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound from(String filepattern) {
+        checkNotNull(filepattern, "Filepattern cannot be empty.");
+        return new Bound(name, StaticValueProvider.of(filepattern), validate, compressionType);
+      }
+
+      /**
+       * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
+       */
+      public Bound from(ValueProvider<String> filepattern) {
+        checkNotNull(filepattern, "Filepattern cannot be empty.");
+        return new Bound(name, filepattern, validate, compressionType);
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but
+       * that has GCS path validation on pipeline creation disabled.
+       *
+       * <p>This can be useful in the case where the GCS input does not
+       * exist at the pipeline creation time, but is expected to be
+       * available at execution time.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutValidation() {
+        return new Bound(name, filepattern, false, compressionType);
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but
+       * reads from input sources using the specified compression type.
+       *
+       * <p>If no compression type is specified, the default is
+       * {@link TFRecordIO.CompressionType#AUTO}.
+       * See {@link TFRecordIO.Read#withCompressionType} for more details.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
+        return new Bound(name, filepattern, validate, compressionType);
+      }
+
+      /**
+       * Expands into a bounded {@link org.apache.beam.sdk.io.Read} of the source selected by
+       * the compression type.
+       *
+       * @throws IllegalStateException if no filepattern was set, or if validation is enabled
+       *     and the filepattern is not accessible or matches no files
+       */
+      @Override
+      public PCollection<byte[]> expand(PBegin input) {
+        if (filepattern == null) {
+          throw new IllegalStateException(
+              "Need to set the filepattern of a TFRecordIO.Read transform");
+        }
+
+        if (validate) {
+          checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+          try {
+            checkState(
+                !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+                "Unable to find any files matching %s",
+                filepattern);
+          } catch (IOException e) {
+            throw new IllegalStateException(
+                String.format("Failed to validate %s", filepattern.get()), e);
+          }
+        }
+
+        final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
+        PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
+        // Honor the default output coder that would have been used by this PTransform.
+        pcol.setCoder(getDefaultOutputCoder());
+        return pcol;
+      }
+
+      /**
+       * Helper to create a source specific to the requested compression type: a raw
+       * {@link TFRecordSource} for NONE, otherwise one wrapped in a {@link CompressedSource}.
+       */
+      protected FileBasedSource<byte[]> getSource() {
+        switch (compressionType) {
+          case NONE:
+            return new TFRecordSource(filepattern);
+          case AUTO:
+            return CompressedSource.from(new TFRecordSource(filepattern));
+          case GZIP:
+            return
+                CompressedSource.from(new TFRecordSource(filepattern))
+                    .withDecompression(CompressedSource.CompressionMode.GZIP);
+          case ZLIB:
+            return
+                CompressedSource.from(new TFRecordSource(filepattern))
+                    .withDecompression(CompressedSource.CompressionMode.DEFLATE);
+          default:
+            throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+        }
+      }
+
+      @Override
+      public void populateDisplayData(DisplayData.Builder builder) {
+        super.populateDisplayData(builder);
+
+        // The filepattern is null until from() is called; display data must not fail on that,
+        // which is why the item is added with addIfNotNull.
+        String filepatternDisplay = null;
+        if (filepattern != null) {
+          filepatternDisplay =
+              filepattern.isAccessible() ? filepattern.get() : filepattern.toString();
+        }
+        builder
+            .add(DisplayData.item("compressionType", compressionType.toString())
+                .withLabel("Compression Type"))
+            .addIfNotDefault(DisplayData.item("validation", validate)
+                .withLabel("Validation Enabled"), true)
+            .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
+                .withLabel("File Pattern"));
+      }
+
+      @Override
+      protected Coder<byte[]> getDefaultOutputCoder() {
+        return DEFAULT_BYTE_ARRAY_CODER;
+      }
+
+      /** Returns the configured filepattern; requires that {@link #from} has been called. */
+      public String getFilepattern() {
+        return filepattern.get();
+      }
+
+      public boolean needsValidation() {
+        return validate;
+      }
+
+      public TFRecordIO.CompressionType getCompressionType() {
+        return compressionType;
+      }
+    }
+
+    /** Disallow construction of utility classes. */
+    private Read() {}
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A {@link PTransform} that writes a {@link PCollection} to TFRecord file (or
+   * multiple TFRecord files matching a sharding pattern), with each
+   * element of the input collection encoded into its own record.
+   */
+  public static class Write {
+
+    /**
+     * Returns a transform for writing to TFRecord files that writes to the file(s)
+     * with the given prefix. This can be a local filename
+     * (if running locally), or a Google Cloud Storage filename of
+     * the form {@code "gs://<bucket>/<filepath>"}
+     * (if running locally or via the Google Cloud Dataflow service).
+     *
+     * <p>The files written will begin with this prefix, followed by
+     * a shard identifier (see {@link TFRecordIO.Write.Bound#withNumShards(int)}, and end
+     * in a common extension, if given by {@link TFRecordIO.Write.Bound#withSuffix(String)}.
+     */
+    public static Bound to(String prefix) {
+      return new Bound().to(prefix);
+    }
+
+    /**
+     * Like {@link #to(String)}, but with a {@link ValueProvider}.
+     */
+    public static Bound to(ValueProvider<String> prefix) {
+      return new Bound().to(prefix);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that appends the specified suffix
+     * to the created files.
+     */
+    public static Bound withSuffix(String nameExtension) {
+      return new Bound().withSuffix(nameExtension);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that uses the provided shard count.
+     *
+     * <p>Constraining the number of shards is likely to reduce
+     * the performance of a pipeline. Setting this value is not recommended
+     * unless you require a specific number of output files.
+     *
+     * @param numShards the number of shards to use, or 0 to let the system
+     *     decide.
+     */
+    public static Bound withNumShards(int numShards) {
+      return new Bound().withNumShards(numShards);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that uses the given shard name
+     * template.
+     *
+     * <p>See {@link ShardNameTemplate} for a description of shard templates.
+     */
+    public static Bound withShardNameTemplate(String shardTemplate) {
+      return new Bound().withShardNameTemplate(shardTemplate);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that forces a single file as
+     * output.
+     */
+    public static Bound withoutSharding() {
+      return new Bound().withoutSharding();
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that has GCS path validation on
+     * pipeline creation disabled.
+     *
+     * <p>This can be useful in the case where the GCS output location does
+     * not exist at the pipeline creation time, but is expected to be available
+     * at execution time.
+     */
+    public static Bound withoutValidation() {
+      return new Bound().withoutValidation();
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files like this one but writes to output files
+     * using the specified compression type.
+     *
+     * <p>If no compression type is specified, the default is
+     * {@link TFRecordIO.CompressionType#NONE}.
+     * See {@link TFRecordIO.Read#withCompressionType} for more details.
+     */
+    public static Bound withCompressionType(CompressionType compressionType) {
+      return new Bound().withCompressionType(compressionType);
+    }
+
+    /**
+     * A PTransform that writes a bounded PCollection to a TFRecord file (or
+     * multiple TFRecord files matching a sharding pattern), with each
+     * PCollection element being encoded into its own record.
+     */
+    public static class Bound extends PTransform<PCollection<byte[]>, PDone> {
+      private static final String DEFAULT_SHARD_TEMPLATE = ShardNameTemplate.INDEX_OF_MAX;
+
+      /**
+       * The prefix of each file written, combined with suffix and shardTemplate;
+       * {@code null} until {@link #to} is called.
+       */
+      @Nullable private final ValueProvider<String> filenamePrefix;
+      /** The suffix of each file written, combined with prefix and shardTemplate. */
+      private final String filenameSuffix;
+
+      /** Requested number of shards. 0 for automatic. */
+      private final int numShards;
+
+      /** The shard template of each file written, combined with prefix and suffix. */
+      private final String shardTemplate;
+
+      /** An option to indicate if output validation is desired. Default is true. */
+      private final boolean validate;
+
+      /** Option to indicate the output sink's compression type. Default is NONE. */
+      private final TFRecordIO.CompressionType compressionType;
+
+      private Bound() {
+        this(null, null, "", 0, DEFAULT_SHARD_TEMPLATE, true, TFRecordIO.CompressionType.NONE);
+      }
+
+      private Bound(@Nullable String name, @Nullable ValueProvider<String> filenamePrefix,
+          String filenameSuffix, int numShards, String shardTemplate, boolean validate,
+          CompressionType compressionType) {
+        super(name);
+        this.filenamePrefix = filenamePrefix;
+        this.filenameSuffix = filenameSuffix;
+        this.numShards = numShards;
+        this.shardTemplate = shardTemplate;
+        this.validate = validate;
+        this.compressionType = compressionType;
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that writes to the file(s) with the given filename prefix.
+       *
+       * <p>See {@link TFRecordIO.Write#to(String) Write.to(String)} for more information.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound to(String filenamePrefix) {
+        validateOutputComponent(filenamePrefix);
+        return new Bound(name, StaticValueProvider.of(filenamePrefix), filenameSuffix, numShards,
+            shardTemplate, validate, compressionType);
+      }
+
+      /**
+       * Like {@link #to(String)}, but with a {@link ValueProvider}.
+       */
+      public Bound to(ValueProvider<String> filenamePrefix) {
+        // Mirrors Read.Bound#from(ValueProvider): reject null up front rather than at expand().
+        checkNotNull(filenamePrefix, "Filename prefix cannot be empty.");
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that writes to the file(s) with the given filename suffix.
+       *
+       * <p>Does not modify this object.
+       *
+       * @see ShardNameTemplate
+       */
+      public Bound withSuffix(String nameExtension) {
+        validateOutputComponent(nameExtension);
+        return new Bound(name, filenamePrefix, nameExtension, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that uses the provided shard count.
+       *
+       * <p>Constraining the number of shards is likely to reduce
+       * the performance of a pipeline. Setting this value is not recommended
+       * unless you require a specific number of output files.
+       *
+       * <p>Does not modify this object.
+       *
+       * @param numShards the number of shards to use, or 0 to let the system
+       *     decide.
+       * @throws IllegalArgumentException if {@code numShards} is negative
+       * @see ShardNameTemplate
+       */
+      public Bound withNumShards(int numShards) {
+        checkArgument(numShards >= 0, "numShards must be non-negative, but was %s", numShards);
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that uses the given shard name template.
+       *
+       * <p>Does not modify this object.
+       *
+       * @see ShardNameTemplate
+       */
+      public Bound withShardNameTemplate(String shardTemplate) {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that forces a single file as output.
+       *
+       * <p>Constraining the number of shards is likely to reduce
+       * the performance of a pipeline. Using this setting is not recommended
+       * unless you truly require a single output file.
+       *
+       * <p>This is a shortcut for
+       * {@code .withNumShards(1).withShardNameTemplate("")}
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutSharding() {
+        return new Bound(name, filenamePrefix, filenameSuffix, 1, "",
+            validate, compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that has GCS output path validation on pipeline creation disabled.
+       *
+       * <p>This can be useful in the case where the GCS output location does
+       * not exist at the pipeline creation time, but is expected to be
+       * available at execution time.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutValidation() {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, false,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files like this one but writes to output files
+       * using the specified compression type.
+       *
+       * <p>If no compression type is specified, the default is
+       * {@link TFRecordIO.CompressionType#NONE}.
+       * See {@link TFRecordIO.Read#withCompressionType} for more details.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withCompressionType(CompressionType compressionType) {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Expands into a {@link org.apache.beam.sdk.io.Write} of a {@link TFRecordSink}, applying
+       * the shard count only when one was explicitly requested.
+       *
+       * @throws IllegalStateException if no filename prefix was set
+       */
+      @Override
+      public PDone expand(PCollection<byte[]> input) {
+        if (filenamePrefix == null) {
+          throw new IllegalStateException(
+              "need to set the filename prefix of a TFRecordIO.Write transform");
+        }
+        org.apache.beam.sdk.io.Write<byte[]> write =
+            org.apache.beam.sdk.io.Write.to(
+                new TFRecordSink(filenamePrefix, filenameSuffix, shardTemplate, compressionType));
+        if (getNumShards() > 0) {
+          write = write.withNumShards(getNumShards());
+        }
+        return input.apply("Write", write);
+      }
+
+      @Override
+      public void populateDisplayData(DisplayData.Builder builder) {
+        super.populateDisplayData(builder);
+
+        // The prefix is null until to() is called; display data must not fail on that,
+        // which is why the item is added with addIfNotNull.
+        String prefixString = null;
+        if (filenamePrefix != null) {
+          prefixString =
+              filenamePrefix.isAccessible() ? filenamePrefix.get() : filenamePrefix.toString();
+        }
+        builder
+            .addIfNotNull(DisplayData.item("filePrefix", prefixString)
+                .withLabel("Output File Prefix"))
+            .addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix)
+                .withLabel("Output File Suffix"), "")
+            .addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate)
+                    .withLabel("Output Shard Name Template"),
+                DEFAULT_SHARD_TEMPLATE)
+            .addIfNotDefault(DisplayData.item("validation", validate)
+                .withLabel("Validation Enabled"), true)
+            .addIfNotDefault(DisplayData.item("numShards", numShards)
+                .withLabel("Maximum Output Shards"), 0)
+            .add(DisplayData
+                .item("compressionType", compressionType.toString())
+                .withLabel("Compression Type"));
+      }
+
+      /**
+       * Returns the current shard name template string.
+       */
+      public String getShardNameTemplate() {
+        return shardTemplate;
+      }
+
+      @Override
+      protected Coder<Void> getDefaultOutputCoder() {
+        return VoidCoder.of();
+      }
+
+      /** Returns the configured filename prefix; requires that {@link #to} has been called. */
+      public String getFilenamePrefix() {
+        return filenamePrefix.get();
+      }
+
+      public String getShardTemplate() {
+        return shardTemplate;
+      }
+
+      public int getNumShards() {
+        return numShards;
+      }
+
+      public String getFilenameSuffix() {
+        return filenameSuffix;
+      }
+
+      public boolean needsValidation() {
+        return validate;
+      }
+    }
+
+    /** Disallow construction of utility classes. */
+    private Write() {}
+  }
+
+  /**
+   * Possible TFRecord file compression types.
+   */
+  public enum CompressionType {
+    /**
+     * Automatically determine the compression type based on filename extension.
+     */
+    AUTO(""),
+    /**
+     * Uncompressed.
+     */
+    NONE(""),
+    /**
+     * GZipped.
+     */
+    GZIP(".gz"),
+    /**
+     * ZLIB compressed.
+     */
+    ZLIB(".zlib");
+
+    private final String filenameSuffix;
+
+    CompressionType(String suffix) {
+      this.filenameSuffix = suffix;
+    }
+
+    /**
+     * Determine if a given filename matches a compression type based on its extension.
+     * The comparison is case-insensitive and locale-independent. For AUTO and NONE, whose
+     * suffix is empty, every filename matches.
+     *
+     * @param filename the filename to match
+     * @return true iff the filename ends with the compression type's known extension.
+     */
+    public boolean matches(String filename) {
+      // Locale.ROOT: under e.g. a Turkish default locale, "ZLIB".toLowerCase() yields "zlıb".
+      return filename.toLowerCase(Locale.ROOT).endsWith(filenameSuffix.toLowerCase(Locale.ROOT));
+    }
+  }
+
+  // Pattern which matches old-style shard output patterns, which are now
+  // disallowed.
+  private static final Pattern SHARD_OUTPUT_PATTERN = Pattern.compile("@([0-9]+|\\*)");
+
+  /**
+   * Rejects output filename components that contain legacy {@code @*} / {@code @N} shard
+   * patterns.
+   *
+   * @throws IllegalArgumentException if such a pattern is present
+   */
+  private static void validateOutputComponent(String partialFilePattern) {
+    checkArgument(
+        !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(),
+        "Output name components are not allowed to contain @* or @N patterns: "
+            + partialFilePattern);
+  }
+
+  //////////////////////////////////////////////////////////////////////////////
+
+  /** Disable construction of utility class. */
+  private TFRecordIO() {}
+
+  /**
+   * A {@link FileBasedSource} which can decode records in TFRecord files.
+   * The source never reports itself splittable, so each file is read whole.
+   */
+  @VisibleForTesting
+  static class TFRecordSource extends FileBasedSource<byte[]> {
+    @VisibleForTesting
+    TFRecordSource(String fileSpec) {
+      super(fileSpec, 1L);
+    }
+
+    @VisibleForTesting
+    TFRecordSource(ValueProvider<String> fileSpec) {
+      super(fileSpec, Long.MAX_VALUE);
+    }
+
+    private TFRecordSource(String fileName, long start, long end) {
+      super(fileName, Long.MAX_VALUE, start, end);
+    }
+
+    /**
+     * Creates a source for one file. Only whole-file ranges (starting at offset 0) are
+     * accepted because records cannot be located from an arbitrary offset.
+     */
+    @Override
+    protected FileBasedSource<byte[]> createForSubrangeOfFile(
+        String fileName,
+        long start,
+        long end) {
+      checkArgument(start == 0, "TFRecordSource is not splittable");
+      return new TFRecordSource(fileName, start, end);
+    }
+
+    @Override
+    protected FileBasedReader<byte[]> createSingleFileReader(PipelineOptions options) {
+      return new TFRecordReader(this);
+    }
+
+    @Override
+    public Coder<byte[]> getDefaultOutputCoder() {
+      return DEFAULT_BYTE_ARRAY_CODER;
+    }
+
+    @Override
+    protected boolean isSplittable() throws Exception {
+      // TFRecord files are not splittable
+      return false;
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSource.FileBasedReader FileBasedReader}
+     * which can decode records in TFRecord files.
+     *
+     * <p>See {@link TFRecordIO.TFRecordSource} for further details.
+     */
+    @VisibleForTesting
+    static class TFRecordReader extends FileBasedReader<byte[]> {
+      private long startOfRecord;
+      private volatile long startOfNextRecord;
+      private volatile boolean elementIsPresent;
+      private byte[] currentValue;
+      private ReadableByteChannel inChannel;
+      private TFRecordCodec codec;
+
+      private TFRecordReader(TFRecordSource source) {
+        super(source);
+      }
+
+      @Override
+      protected long getCurrentOffset() throws NoSuchElementException {
+        if (!elementIsPresent) {
+          throw new NoSuchElementException();
+        }
+        return startOfRecord;
+      }
+
+      @Override
+      public byte[] getCurrent() throws NoSuchElementException {
+        if (!elementIsPresent) {
+          throw new NoSuchElementException();
+        }
+        return currentValue;
+      }
+
+      @Override
+      protected void startReading(ReadableByteChannel channel) throws IOException {
+        this.inChannel = channel;
+        this.codec = new TFRecordCodec();
+      }
+
+      /**
+       * Decodes the next record, advancing the offset of the following record by the full
+       * encoded size (header + payload + footer).
+       */
+      @Override
+      protected boolean readNextRecord() throws IOException {
+        startOfRecord = startOfNextRecord;
+        currentValue = codec.read(inChannel);
+        if (currentValue != null) {
+          elementIsPresent = true;
+          startOfNextRecord = startOfRecord + codec.recordLength(currentValue);
+          return true;
+        } else {
+          elementIsPresent = false;
+          return false;
+        }
+      }
+    }
+  }
+
+  /**
+   * A {@link FileBasedSink} for TFRecord files. Produces TFRecord files.
+   */
+  @VisibleForTesting
+  static class TFRecordSink extends FileBasedSink<byte[]> {
+    @VisibleForTesting
+    TFRecordSink(ValueProvider<String> baseOutputFilename,
+        String extension,
+        String fileNameTemplate,
+        TFRecordIO.CompressionType compressionType) {
+      super(baseOutputFilename, extension, fileNameTemplate,
+          writableByteChannelFactory(compressionType));
+    }
+
+    @Override
+    public FileBasedWriteOperation<byte[]> createWriteOperation(PipelineOptions options) {
+      return new TFRecordWriteOperation(this);
+    }
+
+    /**
+     * Maps a {@link TFRecordIO.CompressionType} to the sink's channel factory.
+     *
+     * <p>Note: inside this class the unqualified {@code CompressionType} is the sink-side type
+     * (the one with UNCOMPRESSED/DEFLATE), not {@link TFRecordIO.CompressionType}.
+     *
+     * @throws IllegalArgumentException for AUTO, which has no meaning when writing, or for any
+     *     unknown value
+     */
+    private static WritableByteChannelFactory writableByteChannelFactory(
+        TFRecordIO.CompressionType compressionType) {
+      switch (compressionType) {
+        case AUTO:
+          throw new IllegalArgumentException("Unsupported compression type AUTO");
+        case NONE:
+          return CompressionType.UNCOMPRESSED;
+        case GZIP:
+          return CompressionType.GZIP;
+        case ZLIB:
+          return CompressionType.DEFLATE;
+        default:
+          // Fail loudly instead of silently writing uncompressed output.
+          throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+      }
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriteOperation
+     * FileBasedWriteOperation} for TFRecord files.
+     */
+    private static class TFRecordWriteOperation extends FileBasedWriteOperation<byte[]> {
+      private TFRecordWriteOperation(TFRecordSink sink) {
+        super(sink);
+      }
+
+      @Override
+      public FileBasedWriter<byte[]> createWriter(PipelineOptions options) throws Exception {
+        return new TFRecordWriter(this);
+      }
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter}
+     * for TFRecord files.
+     */
+    private static class TFRecordWriter extends FileBasedWriter<byte[]> {
+      private WritableByteChannel outChannel;
+      private TFRecordCodec codec;
+
+      private TFRecordWriter(FileBasedWriteOperation<byte[]> writeOperation) {
+        super(writeOperation);
+        this.mimeType = MimeTypes.BINARY;
+      }
+
+      @Override
+      protected void prepareWrite(WritableByteChannel channel) throws Exception {
+        this.outChannel = channel;
+        this.codec = new TFRecordCodec();
+      }
+
+      @Override
+      public void write(byte[] value) throws Exception {
+        codec.write(outChannel, value);
+      }
+    }
+  }
+
+  //////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Codec for TFRecords file format.
+   * See https://www.tensorflow.org/api_guides/python/python_io#TFRecords_Format_Details
+   *
+   * <p>Each record is encoded little-endian as: an 8-byte payload length, a 4-byte masked
+   * CRC32C of that length, the payload bytes, and a 4-byte masked CRC32C of the payload.
+   *
+   * <p>Instances reuse their header/footer buffers across calls and are therefore not
+   * thread-safe; each reader/writer creates its own codec.
+   */
+  private static class TFRecordCodec {
+    private static final int HEADER_LEN = (Long.SIZE + Integer.SIZE) / Byte.SIZE;
+    private static final int FOOTER_LEN = Integer.SIZE / Byte.SIZE;
+    private static final HashFunction crc32c = Hashing.crc32c();
+
+    private final ByteBuffer header =
+        ByteBuffer.allocate(HEADER_LEN).order(ByteOrder.LITTLE_ENDIAN);
+    private final ByteBuffer footer =
+        ByteBuffer.allocate(FOOTER_LEN).order(ByteOrder.LITTLE_ENDIAN);
+
+    /** Applies TFRecord CRC masking: rotate right by 15 bits, then add a fixed constant. */
+    private int mask(int crc) {
+      return ((crc >>> 15) | (crc << 17)) + 0xa282ead8;
+    }
+
+    private int hashLong(long x) {
+      return mask(crc32c.hashLong(x).asInt());
+    }
+
+    private int hashBytes(byte[] x) {
+      return mask(crc32c.hashBytes(x).asInt());
+    }
+
+    /** Returns the total encoded size of a record with the given payload. */
+    public int recordLength(byte[] data) {
+      return HEADER_LEN + data.length + FOOTER_LEN;
+    }
+
+    /**
+     * Reads and verifies the next record from {@code inChannel}.
+     *
+     * @return the record payload, or {@code null} if the channel is exhausted before any
+     *     header byte is read
+     * @throws IllegalStateException if the record is truncated, its length is out of range,
+     *     or a checksum does not match
+     */
+    @Nullable
+    public byte[] read(ReadableByteChannel inChannel) throws IOException {
+      header.clear();
+      int headerBytes = readFully(inChannel, header);
+      if (headerBytes <= 0) {
+        return null;
+      }
+      checkState(
+          headerBytes == HEADER_LEN,
+          "Not a valid TFRecord. Fewer than 12 bytes.");
+      header.flip();
+      long length = header.getLong();
+      int maskedCrc32OfLength = header.getInt();
+      checkState(
+          hashLong(length) == maskedCrc32OfLength,
+          "Mismatch of length mask");
+      // The length passed its checksum, but it must still fit in a single Java array;
+      // otherwise the int cast below would silently truncate it.
+      checkState(
+          length >= 0 && length <= Integer.MAX_VALUE,
+          "Invalid TFRecord length %s; must be between 0 and %s bytes",
+          length, Integer.MAX_VALUE);
+
+      ByteBuffer data = ByteBuffer.allocate((int) length);
+      checkState(readFully(inChannel, data) == length, "Invalid data");
+
+      footer.clear();
+      checkState(
+          readFully(inChannel, footer) == FOOTER_LEN,
+          "Not a valid TFRecord. Fewer than 4 bytes of data checksum.");
+      footer.flip();
+      int maskedCrc32OfData = footer.getInt();
+
+      checkState(
+          hashBytes(data.array()) == maskedCrc32OfData,
+          "Mismatch of data mask");
+      return data.array();
+    }
+
+    /** Encodes {@code data} as a single TFRecord and writes it fully to {@code outChannel}. */
+    public void write(WritableByteChannel outChannel, byte[] data) throws IOException {
+      int maskedCrc32OfLength = hashLong(data.length);
+      int maskedCrc32OfData = hashBytes(data);
+
+      header.clear();
+      header.putLong(data.length).putInt(maskedCrc32OfLength);
+      header.flip();
+      writeFully(outChannel, header);
+
+      writeFully(outChannel, ByteBuffer.wrap(data));
+
+      footer.clear();
+      footer.putInt(maskedCrc32OfData);
+      footer.flip();
+      writeFully(outChannel, footer);
+    }
+
+    /**
+     * Reads from {@code in} until {@code buffer} is full or end-of-stream is reached.
+     * The {@link ReadableByteChannel#read} contract allows a single call to return fewer
+     * bytes than requested, so one call is not enough to fill a buffer.
+     *
+     * @return the number of bytes read into {@code buffer}; 0 if the stream was already
+     *     exhausted
+     */
+    private static int readFully(ReadableByteChannel in, ByteBuffer buffer) throws IOException {
+      int total = 0;
+      while (buffer.hasRemaining()) {
+        int n = in.read(buffer);
+        if (n < 0) {
+          break;
+        }
+        total += n;
+      }
+      return total;
+    }
+
+    /**
+     * Writes all remaining bytes of {@code buffer}; {@link WritableByteChannel#write} may
+     * write only part of the buffer per call.
+     */
+    private static void writeFully(WritableByteChannel out, ByteBuffer buffer)
+        throws IOException {
+      while (buffer.hasRemaining()) {
+        out.write(buffer);
+      }
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
index fe8d0fd..f8943a5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
@@ -682,7 +682,7 @@ public class TextIO {
.addIfNotNull(DisplayData.item("filePrefix", prefixString)
.withLabel("Output File Prefix"))
.addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix)
- .withLabel("Output Fix Suffix"), "")
+ .withLabel("Output File Suffix"), "")
.addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate)
.withLabel("Output Shard Name Template"),
DEFAULT_SHARD_TEMPLATE)
http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
new file mode 100644
index 0000000..70620fb
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.AUTO;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.GZIP;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.NONE;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.ZLIB;
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.BaseEncoding;
+import com.google.common.io.ByteStreams;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.FileVisitResult;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for TFRecordIO Read and Write transforms.
+ */
+@RunWith(JUnit4.class)
+public class TFRecordIOTest {
+
+  /*
+  From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py
+  Created by running the following code in Python:
+  >>> import tensorflow as tf
+  >>> import base64
+  >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+  >>> writer.write('foo')
+  >>> writer.close()
+  >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+  ...   data = base64.b64encode(f.read())
+  ...   print data
+  */
+  private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";
+
+  // Same as above but containing two records ['foo', 'bar'].
+  private static final String FOO_BAR_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";
+  // The same two records in the opposite order ['bar', 'foo']; a pipeline does not guarantee
+  // the order in which elements are written.
+  private static final String BAR_FOO_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";
+
+  private static final String[] FOO_RECORDS = {"foo"};
+  private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};
+
+  private static final Iterable<String> EMPTY = Collections.emptyList();
+  private static final Iterable<String> LARGE = makeLines(5000);
+
+  private static Path tempFolder;
+
+  @Rule
+  public TestPipeline p = TestPipeline.create();
+
+  @Rule
+  public TestPipeline p2 = TestPipeline.create();
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @BeforeClass
+  public static void setupClass() throws IOException {
+    tempFolder = Files.createTempDirectory("TFRecordIOTest");
+  }
+
+  /** Recursively deletes the temporary folder shared by all tests. */
+  @AfterClass
+  public static void teardownClass() throws IOException {
+    Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
+      @Override
+      public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
+        Files.delete(file);
+        return FileVisitResult.CONTINUE;
+      }
+
+      @Override
+      public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
+        Files.delete(dir);
+        return FileVisitResult.CONTINUE;
+      }
+    });
+  }
+
+  @Test
+  public void testReadNamed() {
+    p.enableAbandonedNodeEnforcement(false);
+
+    assertEquals(
+        "TFRecordIO.Read/Read.out",
+        p.apply(TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+    assertEquals(
+        "MyRead/Read.out",
+        p.apply("MyRead", TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+  }
+
+  @Test
+  public void testReadDisplayData() {
+    TFRecordIO.Read.Bound read = TFRecordIO.Read
+        .from("foo.*")
+        .withCompressionType(GZIP)
+        .withoutValidation();
+
+    DisplayData displayData = DisplayData.from(read);
+
+    assertThat(displayData, hasDisplayItem("filePattern", "foo.*"));
+    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
+    assertThat(displayData, hasDisplayItem("validation", false));
+  }
+
+  @Test
+  public void testWriteDisplayData() {
+    TFRecordIO.Write.Bound write = TFRecordIO.Write
+        .to("foo")
+        .withSuffix("bar")
+        .withShardNameTemplate("-SS-of-NN-")
+        .withNumShards(100)
+        .withCompressionType(GZIP)
+        .withoutValidation();
+
+    DisplayData displayData = DisplayData.from(write);
+
+    assertThat(displayData, hasDisplayItem("filePrefix", "foo"));
+    assertThat(displayData, hasDisplayItem("fileSuffix", "bar"));
+    assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
+    assertThat(displayData, hasDisplayItem("numShards", 100));
+    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
+    assertThat(displayData, hasDisplayItem("validation", false));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadOne() throws Exception {
+    runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadTwo() throws Exception {
+    runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteOne() throws Exception {
+    runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteTwo() throws Exception {
+    runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidRecord() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");
+    runTestRead("bar".getBytes(StandardCharsets.UTF_8), new String[0]);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidLengthMask() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Mismatch of length mask");
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    // Corrupt the length checksum, which follows the 8-byte length in the header.
+    data[9] += 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidDataMask() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Mismatch of data mask");
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    // Corrupt the data checksum: footer starts after the 12-byte header and 3-byte payload.
+    data[16] += 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  /** Decodes {@code base64} and checks that reading it with TFRecordIO yields {@code expected}. */
+  private void runTestRead(String base64, String[] expected) throws IOException {
+    runTestRead(BaseEncoding.base64().decode(base64), expected);
+  }
+
+  /** Writes {@code data} verbatim to a temp file and checks that TFRecordIO reads it back. */
+  private void runTestRead(byte[] data, String[] expected) throws IOException {
+    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    try (FileOutputStream fos = new FileOutputStream(tmpFile)) {
+      fos.write(data);
+    }
+
+    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(filename);
+    PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString()));
+
+    PAssert.that(output).containsInAnyOrder(expected);
+    p.run();
+  }
+
+  /**
+   * Writes {@code elems} to a single unsharded file and checks that its base64 encoding equals
+   * one of {@code base64}.
+   */
+  private void runTestWrite(String[] elems, String... base64) throws IOException {
+    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    PCollection<byte[]> input = p.apply(Create.of(Arrays.asList(elems)))
+        .apply(ParDo.of(new StringToByteArray()));
+
+    TFRecordIO.Write.Bound write = TFRecordIO.Write.to(filename).withoutSharding();
+    input.apply(write);
+
+    p.run();
+
+    String written;
+    try (FileInputStream fis = new FileInputStream(tmpFile)) {
+      written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
+    }
+    // The bytes written may vary depending on the order of elems.
+    assertThat(written, isIn(base64));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTrip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithEmptyData() throws IOException {
+    runTestRoundTrip(EMPTY, 10, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithOneShards() throws IOException {
+    runTestRoundTrip(LARGE, 1, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithSuffix() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".suffix", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlib() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, ZLIB);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripUncompressedFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, AUTO);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzipFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, AUTO);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlibFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, AUTO);
+  }
+
+  /**
+   * Writes {@code elems} with pipeline {@code p} and reads them back with pipeline {@code p2},
+   * asserting that the same elements come out.
+   */
+  private void runTestRoundTrip(Iterable<String> elems,
+                                int numShards,
+                                String suffix,
+                                CompressionType writeCompressionType,
+                                CompressionType readCompressionType) throws IOException {
+    String outputName = "file";
+    Path baseDir = Files.createTempDirectory(tempFolder, "test-rt");
+    String baseFilename = baseDir.resolve(outputName).toString();
+
+    TFRecordIO.Write.Bound write = TFRecordIO.Write.to(baseFilename)
+        .withNumShards(numShards)
+        .withSuffix(suffix)
+        .withCompressionType(writeCompressionType);
+    p.apply(Create.of(elems).withCoder(StringUtf8Coder.of()))
+        .apply(ParDo.of(new StringToByteArray()))
+        .apply(write);
+    p.run();
+
+    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(baseFilename + "*")
+        .withCompressionType(readCompressionType);
+    PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString()));
+
+    PAssert.that(output).containsInAnyOrder(elems);
+    p2.run();
+  }
+
+  /** Returns {@code n} distinct lines "word0" .. "word(n-1)". */
+  private static Iterable<String> makeLines(int n) {
+    List<String> ret = Lists.newArrayList();
+    for (int i = 0; i < n; ++i) {
+      ret.add("word" + i);
+    }
+    return ret;
+  }
+
+  /** Decodes each UTF-8 byte array into a String. */
+  static class ByteArrayToString extends DoFn<byte[], String> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(new String(c.element(), StandardCharsets.UTF_8));
+    }
+  }
+
+  /** Encodes each String as UTF-8 bytes. */
+  static class StringToByteArray extends DoFn<String, byte[]> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(c.element().getBytes(StandardCharsets.UTF_8));
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
index cd94dc5..713cb71 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
@@ -206,7 +206,7 @@ public class TextIOTest {
}
@AfterClass
- public static void testdownClass() throws IOException {
+ public static void teardownClass() throws IOException {
Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
@Override
public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
[2/2] beam git commit: This closes #2066
Posted by ch...@apache.org.
This closes #2066
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/92d1a663
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/92d1a663
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/92d1a663
Branch: refs/heads/master
Commit: 92d1a66351512ca2d65125a173903d031e4ac2d6
Parents: bea4f5a 68d42f9
Author: Chamikara Jayalath <ch...@google.com>
Authored: Tue Mar 21 11:24:04 2017 -0700
Committer: Chamikara Jayalath <ch...@google.com>
Committed: Tue Mar 21 11:24:04 2017 -0700
----------------------------------------------------------------------
.../apache/beam/sdk/io/CompressedSource.java | 13 +-
.../java/org/apache/beam/sdk/io/TFRecordIO.java | 905 +++++++++++++++++++
.../java/org/apache/beam/sdk/io/TextIO.java | 2 +-
.../org/apache/beam/sdk/io/TFRecordIOTest.java | 368 ++++++++
.../java/org/apache/beam/sdk/io/TextIOTest.java | 2 +-
5 files changed, 1285 insertions(+), 5 deletions(-)
----------------------------------------------------------------------