Posted to commits@beam.apache.org by ch...@apache.org on 2017/03/21 18:25:45 UTC

[1/2] beam git commit: add TensorFlow TFRecordIO

Repository: beam
Updated Branches:
  refs/heads/master bea4f5aec -> 92d1a6635


add TensorFlow TFRecordIO


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/68d42f9b
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/68d42f9b
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/68d42f9b

Branch: refs/heads/master
Commit: 68d42f9b941dab420947a8aaa616070b1f3fb5f8
Parents: bea4f5a
Author: Neville Li <ne...@spotify.com>
Authored: Tue Feb 21 15:51:18 2017 -0500
Committer: Chamikara Jayalath <ch...@google.com>
Committed: Tue Mar 21 11:23:02 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/sdk/io/CompressedSource.java    |  13 +-
 .../java/org/apache/beam/sdk/io/TFRecordIO.java | 905 +++++++++++++++++++
 .../java/org/apache/beam/sdk/io/TextIO.java     |   2 +-
 .../org/apache/beam/sdk/io/TFRecordIOTest.java  | 368 ++++++++
 .../java/org/apache/beam/sdk/io/TextIOTest.java |   2 +-
 5 files changed, 1285 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
index 6de22f9..ecd0fd9 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/CompressedSource.java
@@ -334,8 +334,14 @@ public class CompressedSource<T> extends FileBasedSource<T> {
     super(filePatternOrSpec, minBundleSize, startOffset, endOffset);
     this.sourceDelegate = sourceDelegate;
     this.channelFactory = channelFactory;
+    boolean splittable = false;
+    try {
+      splittable = isSplittable();
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to determine if the source is splittable", e);
+    }
     checkArgument(
-        isSplittable() || startOffset == 0,
+        splittable || startOffset == 0,
         "CompressedSources must start reading at offset 0. Requested offset: " + startOffset);
   }
 
@@ -366,11 +372,12 @@ public class CompressedSource<T> extends FileBasedSource<T> {
    * from the requested file name that the file is not compressed.
    */
   @Override
-  protected final boolean isSplittable() {
+  protected final boolean isSplittable() throws Exception {
     if (channelFactory instanceof FileNameBasedDecompressingChannelFactory) {
       FileNameBasedDecompressingChannelFactory fileNameBasedChannelFactory =
           (FileNameBasedDecompressingChannelFactory) channelFactory;
-      return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec());
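+      // Also require the delegate to be splittable: even when the file name
+      // indicates no compression, the wrapped source may itself be unsplittable.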
+      return !fileNameBasedChannelFactory.isCompressed(getFileOrPatternSpec())
+          && sourceDelegate.isSplittable();
     }
     return false;
   }

http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
new file mode 100644
index 0000000..243506c
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -0,0 +1,905 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.hash.HashFunction;
+import com.google.common.hash.Hashing;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.ReadableByteChannel;
+import java.nio.channels.WritableByteChannel;
+import java.util.NoSuchElementException;
+import java.util.regex.Pattern;
+
+import javax.annotation.Nullable;
+
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.sdk.coders.VoidCoder;
+import org.apache.beam.sdk.io.Read.Bounded;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.ValueProvider;
+import org.apache.beam.sdk.options.ValueProvider.StaticValueProvider;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.util.IOChannelUtils;
+import org.apache.beam.sdk.util.MimeTypes;
+import org.apache.beam.sdk.values.PBegin;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PDone;
+
+/**
+ * {@link PTransform}s for reading and writing TensorFlow TFRecord files.
+ */
+public class TFRecordIO {
+  /** The default coder, which returns each record of the input file as a byte array. */
+  public static final Coder<byte[]> DEFAULT_BYTE_ARRAY_CODER = ByteArrayCoder.of();
+
+  /**
+   * A {@link PTransform} that reads from a TFRecord file (or multiple TFRecord
+   * files matching a pattern) and returns a {@link PCollection} containing
+   * the decoding of each of the records of the TFRecord file(s) as a byte array.
+   */
+  public static class Read {
+
+    /**
+     * Returns a transform for reading TFRecord files that reads from the file(s)
+     * with the given filename or filename pattern. This can be a local path (if running locally),
+     * or a Google Cloud Storage filename or filename pattern of the form
+     * {@code "gs://<bucket>/<filepath>"} (if running locally or via the Google Cloud Dataflow
+     * service). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html"
+     * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
+     */
+    public static Bound from(String filepattern) {
+      return new Bound().from(filepattern);
+    }
+
+    /**
+     * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
+     */
+    public static Bound from(ValueProvider<String> filepattern) {
+      return new Bound().from(filepattern);
+    }
+
+    /**
+     * Returns a transform for reading TFRecord files that has GCS path validation on
+     * pipeline creation disabled.
+     *
+     * <p>This can be useful in the case where the GCS input does not
+     * exist at the pipeline creation time, but is expected to be
+     * available at execution time.
+     */
+    public static Bound withoutValidation() {
+      return new Bound().withoutValidation();
+    }
+
+    /**
+     * Returns a transform for reading TFRecord files that decompresses all input files
+     * using the specified compression type.
+     *
+     * <p>If no compression type is specified, the default is
+     * {@link TFRecordIO.CompressionType#AUTO}.
+     * In this mode, the compression type of the file is determined by its extension
+     * (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other
+     * extensions are uncompressed).
+     */
+    public static Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
+      return new Bound().withCompressionType(compressionType);
+    }
+
+    /**
+     * A {@link PTransform} that reads from one or more TFRecord files and returns a bounded
+     * {@link PCollection} containing one element for each record of the input files.
+     */
+    public static class Bound extends PTransform<PBegin, PCollection<byte[]>> {
+      /** The filepattern to read from. */
+      @Nullable private final ValueProvider<String> filepattern;
+
+      /** An option to indicate if input validation is desired. Default is true. */
+      private final boolean validate;
+
+      /** Option to indicate the input source's compression type. Default is AUTO. */
+      private final TFRecordIO.CompressionType compressionType;
+
+      private Bound() {
+        this(null, null, true, TFRecordIO.CompressionType.AUTO);
+      }
+
+      private Bound(
+          @Nullable String name,
+          @Nullable ValueProvider<String> filepattern,
+          boolean validate,
+          TFRecordIO.CompressionType compressionType) {
+        super(name);
+        this.filepattern = filepattern;
+        this.validate = validate;
+        this.compressionType = compressionType;
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but that
+       * reads from the file(s) with the given name or pattern. See {@link TFRecordIO.Read#from}
+       * for a description of filepatterns.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound from(String filepattern) {
+        checkNotNull(filepattern, "Filepattern cannot be empty.");
+        return new Bound(name, StaticValueProvider.of(filepattern), validate, compressionType);
+      }
+
+      /**
+       * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
+       */
+      public Bound from(ValueProvider<String> filepattern) {
+        checkNotNull(filepattern, "Filepattern cannot be empty.");
+        return new Bound(name, filepattern, validate, compressionType);
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but
+       * that has GCS path validation on pipeline creation disabled.
+       *
+       * <p>This can be useful in the case where the GCS input does not
+       * exist at the pipeline creation time, but is expected to be
+       * available at execution time.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutValidation() {
+        return new Bound(name, filepattern, false, compressionType);
+      }
+
+      /**
+       * Returns a new transform for reading from TFRecord files that's like this one but
+       * reads from input sources using the specified compression type.
+       *
+       * <p>If no compression type is specified, the default is
+       * {@link TFRecordIO.CompressionType#AUTO}.
+       * See {@link TFRecordIO.Read#withCompressionType} for more details.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
+        return new Bound(name, filepattern, validate, compressionType);
+      }
+
+      @Override
+      public PCollection<byte[]> expand(PBegin input) {
+        if (filepattern == null) {
+          throw new IllegalStateException(
+              "Need to set the filepattern of a TFRecordIO.Read transform");
+        }
+
+        if (validate) {
+          checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+          try {
+            checkState(
+                !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+                "Unable to find any files matching %s",
+                filepattern);
+          } catch (IOException e) {
+            throw new IllegalStateException(
+                String.format("Failed to validate %s", filepattern.get()), e);
+          }
+        }
+
+        final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
+        PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
+        // Honor the default output coder that would have been used by this PTransform.
+        pcol.setCoder(getDefaultOutputCoder());
+        return pcol;
+      }
+
+      // Helper to create a source specific to the requested compression type.
+      protected FileBasedSource<byte[]> getSource() {
+        switch (compressionType) {
+          case NONE:
+            return new TFRecordSource(filepattern);
+          case AUTO:
+            return CompressedSource.from(new TFRecordSource(filepattern));
+          case GZIP:
+            return
+                CompressedSource.from(new TFRecordSource(filepattern))
+                    .withDecompression(CompressedSource.CompressionMode.GZIP);
+          case ZLIB:
+            return
+                CompressedSource.from(new TFRecordSource(filepattern))
+                    .withDecompression(CompressedSource.CompressionMode.DEFLATE);
+          default:
+            throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+        }
+      }
+
+      @Override
+      public void populateDisplayData(DisplayData.Builder builder) {
+        super.populateDisplayData(builder);
+
+        String filepatternDisplay = filepattern.isAccessible()
+            ? filepattern.get() : filepattern.toString();
+        builder
+            .add(DisplayData.item("compressionType", compressionType.toString())
+                .withLabel("Compression Type"))
+            .addIfNotDefault(DisplayData.item("validation", validate)
+                .withLabel("Validation Enabled"), true)
+            .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
+                .withLabel("File Pattern"));
+      }
+
+      @Override
+      protected Coder<byte[]> getDefaultOutputCoder() {
+        return DEFAULT_BYTE_ARRAY_CODER;
+      }
+
+      public String getFilepattern() {
+        return filepattern.get();
+      }
+
+      public boolean needsValidation() {
+        return validate;
+      }
+
+      public TFRecordIO.CompressionType getCompressionType() {
+        return compressionType;
+      }
+    }
+
+    /** Disallow construction of utility classes. */
+    private Read() {}
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * A {@link PTransform} that writes a {@link PCollection} to TFRecord file (or
+   * multiple TFRecord files matching a sharding pattern), with each
+   * element of the input collection encoded into its own record.
+   */
+  public static class Write {
+
+    /**
+     * Returns a transform for writing to TFRecord files that writes to the file(s)
+     * with the given prefix. This can be a local filename
+     * (if running locally), or a Google Cloud Storage filename of
+     * the form {@code "gs://<bucket>/<filepath>"}
+     * (if running locally or via the Google Cloud Dataflow service).
+     *
+     * <p>The files written will begin with this prefix, followed by
+     * a shard identifier (see {@link TFRecordIO.Write.Bound#withNumShards(int)}), and end
+     * in a common extension, if given by {@link TFRecordIO.Write.Bound#withSuffix(String)}.
+     */
+    public static Bound to(String prefix) {
+      return new Bound().to(prefix);
+    }
+
+    /**
+     * Like {@link #to(String)}, but with a {@link ValueProvider}.
+     */
+    public static Bound to(ValueProvider<String> prefix) {
+      return new Bound().to(prefix);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that appends the specified suffix
+     * to the created files.
+     */
+    public static Bound withSuffix(String nameExtension) {
+      return new Bound().withSuffix(nameExtension);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that uses the provided shard count.
+     *
+     * <p>Constraining the number of shards is likely to reduce
+     * the performance of a pipeline. Setting this value is not recommended
+     * unless you require a specific number of output files.
+     *
+     * @param numShards the number of shards to use, or 0 to let the system
+     *                  decide.
+     */
+    public static Bound withNumShards(int numShards) {
+      return new Bound().withNumShards(numShards);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that uses the given shard name
+     * template.
+     *
+     * <p>See {@link ShardNameTemplate} for a description of shard templates.
+     */
+    public static Bound withShardNameTemplate(String shardTemplate) {
+      return new Bound().withShardNameTemplate(shardTemplate);
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that forces a single file as
+     * output.
+     */
+    public static Bound withoutSharding() {
+      return new Bound().withoutSharding();
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files that has GCS path validation on
+     * pipeline creation disabled.
+     *
+     * <p>This can be useful in the case where the GCS output location does
+     * not exist at the pipeline creation time, but is expected to be available
+     * at execution time.
+     */
+    public static Bound withoutValidation() {
+      return new Bound().withoutValidation();
+    }
+
+    /**
+     * Returns a transform for writing to TFRecord files like this one but writes to output files
+     * using the specified compression type.
+     *
+     * <p>If no compression type is specified, the default is
+     * {@link TFRecordIO.CompressionType#NONE}.
+     * See {@link TFRecordIO.Read#withCompressionType} for more details.
+     */
+    public static Bound withCompressionType(CompressionType compressionType) {
+      return new Bound().withCompressionType(compressionType);
+    }
+
+    /**
+     * A PTransform that writes a bounded PCollection to a TFRecord file (or
+     * multiple TFRecord files matching a sharding pattern), with each
+     * PCollection element being encoded into its own record.
+     */
+    public static class Bound extends PTransform<PCollection<byte[]>, PDone> {
+      private static final String DEFAULT_SHARD_TEMPLATE = ShardNameTemplate.INDEX_OF_MAX;
+
+      /** The prefix of each file written, combined with suffix and shardTemplate. */
+      private final ValueProvider<String> filenamePrefix;
+      /** The suffix of each file written, combined with prefix and shardTemplate. */
+      private final String filenameSuffix;
+
+      /** Requested number of shards. 0 for automatic. */
+      private final int numShards;
+
+      /** The shard template of each file written, combined with prefix and suffix. */
+      private final String shardTemplate;
+
+      /** An option to indicate if output validation is desired. Default is true. */
+      private final boolean validate;
+
+      /** Option to indicate the output sink's compression type. Default is NONE. */
+      private final TFRecordIO.CompressionType compressionType;
+
+      private Bound() {
+        this(null, null, "", 0, DEFAULT_SHARD_TEMPLATE, true, TFRecordIO.CompressionType.NONE);
+      }
+
+      private Bound(String name, ValueProvider<String> filenamePrefix, String filenameSuffix,
+                   int numShards, String shardTemplate, boolean validate,
+                   CompressionType compressionType) {
+        super(name);
+        this.filenamePrefix = filenamePrefix;
+        this.filenameSuffix = filenameSuffix;
+        this.numShards = numShards;
+        this.shardTemplate = shardTemplate;
+        this.validate = validate;
+        this.compressionType = compressionType;
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that writes to the file(s) with the given filename prefix.
+       *
+       * <p>See {@link TFRecordIO.Write#to(String) Write.to(String)} for more information.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound to(String filenamePrefix) {
+        validateOutputComponent(filenamePrefix);
+        return new Bound(name, StaticValueProvider.of(filenamePrefix), filenameSuffix, numShards,
+            shardTemplate, validate, compressionType);
+      }
+
+      /**
+       * Like {@link #to(String)}, but with a {@link ValueProvider}.
+       */
+      public Bound to(ValueProvider<String> filenamePrefix) {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that writes to the file(s) with the given filename suffix.
+       *
+       * <p>Does not modify this object.
+       *
+       * @see ShardNameTemplate
+       */
+      public Bound withSuffix(String nameExtension) {
+        validateOutputComponent(nameExtension);
+        return new Bound(name, filenamePrefix, nameExtension, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that uses the provided shard count.
+       *
+       * <p>Constraining the number of shards is likely to reduce
+       * the performance of a pipeline. Setting this value is not recommended
+       * unless you require a specific number of output files.
+       *
+       * <p>Does not modify this object.
+       *
+       * @param numShards the number of shards to use, or 0 to let the system
+       *                  decide.
+       * @see ShardNameTemplate
+       */
+      public Bound withNumShards(int numShards) {
+        checkArgument(numShards >= 0);
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that uses the given shard name template.
+       *
+       * <p>Does not modify this object.
+       *
+       * @see ShardNameTemplate
+       */
+      public Bound withShardNameTemplate(String shardTemplate) {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that forces a single file as output.
+       *
+       * <p>Constraining the number of shards is likely to reduce
+       * the performance of a pipeline. Using this setting is not recommended
+       * unless you truly require a single output file.
+       *
+       * <p>This is a shortcut for
+       * {@code .withNumShards(1).withShardNameTemplate("")}
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutSharding() {
+        return new Bound(name, filenamePrefix, filenameSuffix, 1, "",
+            validate, compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files that's like this one but
+       * that has GCS output path validation on pipeline creation disabled.
+       *
+       * <p>This can be useful in the case where the GCS output location does
+       * not exist at the pipeline creation time, but is expected to be
+       * available at execution time.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withoutValidation() {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, false,
+            compressionType);
+      }
+
+      /**
+       * Returns a transform for writing to TFRecord files like this one but writes to output files
+       * using the specified compression type.
+       *
+       * <p>If no compression type is specified, the default is
+       * {@link TFRecordIO.CompressionType#NONE}.
+       * See {@link TFRecordIO.Read#withCompressionType} for more details.
+       *
+       * <p>Does not modify this object.
+       */
+      public Bound withCompressionType(CompressionType compressionType) {
+        return new Bound(name, filenamePrefix, filenameSuffix, numShards, shardTemplate, validate,
+            compressionType);
+      }
+
+      @Override
+      public PDone expand(PCollection<byte[]> input) {
+        if (filenamePrefix == null) {
+          throw new IllegalStateException(
+              "need to set the filename prefix of a TFRecordIO.Write transform");
+        }
+        org.apache.beam.sdk.io.Write<byte[]> write =
+            org.apache.beam.sdk.io.Write.to(
+                new TFRecordSink(filenamePrefix, filenameSuffix, shardTemplate, compressionType));
+        if (getNumShards() > 0) {
+          write = write.withNumShards(getNumShards());
+        }
+        return input.apply("Write", write);
+      }
+
+      @Override
+      public void populateDisplayData(DisplayData.Builder builder) {
+        super.populateDisplayData(builder);
+
+        String prefixString = filenamePrefix.isAccessible()
+            ? filenamePrefix.get() : filenamePrefix.toString();
+        builder
+            .addIfNotNull(DisplayData.item("filePrefix", prefixString)
+                .withLabel("Output File Prefix"))
+            .addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix)
+                .withLabel("Output File Suffix"), "")
+            .addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate)
+                    .withLabel("Output Shard Name Template"),
+                DEFAULT_SHARD_TEMPLATE)
+            .addIfNotDefault(DisplayData.item("validation", validate)
+                .withLabel("Validation Enabled"), true)
+            .addIfNotDefault(DisplayData.item("numShards", numShards)
+                .withLabel("Maximum Output Shards"), 0)
+            .add(DisplayData
+                .item("compressionType", compressionType.toString())
+                .withLabel("Compression Type"));
+      }
+
+      /**
+       * Returns the current shard name template string.
+       */
+      public String getShardNameTemplate() {
+        return shardTemplate;
+      }
+
+      @Override
+      protected Coder<Void> getDefaultOutputCoder() {
+        return VoidCoder.of();
+      }
+
+      public String getFilenamePrefix() {
+        return filenamePrefix.get();
+      }
+
+      public String getShardTemplate() {
+        return shardTemplate;
+      }
+
+      public int getNumShards() {
+        return numShards;
+      }
+
+      public String getFilenameSuffix() {
+        return filenameSuffix;
+      }
+
+      public boolean needsValidation() {
+        return validate;
+      }
+    }
+  }
+
+  /**
+   * Possible TFRecord file compression types.
+   */
+  public enum CompressionType {
+    /**
+     * Automatically determine the compression type based on filename extension.
+     */
+    AUTO(""),
+    /**
+     * Uncompressed.
+     */
+    NONE(""),
+    /**
+     * GZipped.
+     */
+    GZIP(".gz"),
+    /**
+     * ZLIB compressed.
+     */
+    ZLIB(".zlib");
+
+    private String filenameSuffix;
+
+    CompressionType(String suffix) {
+      this.filenameSuffix = suffix;
+    }
+
+    /**
+     * Determine if a given filename matches a compression type based on its extension.
+     * @param filename the filename to match
+     * @return true iff the filename ends with the compression type's known extension.
+     */
+    public boolean matches(String filename) {
+      return filename.toLowerCase().endsWith(filenameSuffix.toLowerCase());
+    }
+  }
+
+  // Pattern which matches old-style shard output patterns, which are now
+  // disallowed.
+  private static final Pattern SHARD_OUTPUT_PATTERN = Pattern.compile("@([0-9]+|\\*)");
+
+  private static void validateOutputComponent(String partialFilePattern) {
+    checkArgument(
+        !SHARD_OUTPUT_PATTERN.matcher(partialFilePattern).find(),
+        "Output name components are not allowed to contain @* or @N patterns: "
+            + partialFilePattern);
+  }
+
+  //////////////////////////////////////////////////////////////////////////////
+
+  /** Disable construction of utility class. */
+  private TFRecordIO() {}
+
+  /**
+   * A {@link FileBasedSource} which can decode records in TFRecord files.
+   */
+  @VisibleForTesting
+  static class TFRecordSource extends FileBasedSource<byte[]> {
+    @VisibleForTesting
+    TFRecordSource(String fileSpec) {
+      super(fileSpec, 1L);
+    }
+
+    @VisibleForTesting
+    TFRecordSource(ValueProvider<String> fileSpec) {
+      super(fileSpec, Long.MAX_VALUE);
+    }
+
+    private TFRecordSource(String fileName, long start, long end) {
+      super(fileName, Long.MAX_VALUE, start, end);
+    }
+
+    @Override
+    protected FileBasedSource<byte[]> createForSubrangeOfFile(
+        String fileName,
+        long start,
+        long end) {
+      checkArgument(start == 0, "TFRecordSource is not splittable");
+      return new TFRecordSource(fileName, start, end);
+    }
+
+    @Override
+    protected FileBasedReader<byte[]> createSingleFileReader(PipelineOptions options) {
+      return new TFRecordReader(this);
+    }
+
+    @Override
+    public Coder<byte[]> getDefaultOutputCoder() {
+      return DEFAULT_BYTE_ARRAY_CODER;
+    }
+
+    @Override
+    protected boolean isSplittable() throws Exception {
+      // TFRecord files are not splittable
+      return false;
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSource.FileBasedReader FileBasedReader}
+     * which can decode records in TFRecord files.
+     *
+     * <p>See {@link TFRecordIO.TFRecordSource} for further details.
+     */
+    @VisibleForTesting
+    static class TFRecordReader extends FileBasedReader<byte[]> {
+      private long startOfRecord;
+      private volatile long startOfNextRecord;
+      private volatile boolean elementIsPresent;
+      private byte[] currentValue;
+      private ReadableByteChannel inChannel;
+      private TFRecordCodec codec;
+
+      private TFRecordReader(TFRecordSource source) {
+        super(source);
+      }
+
+      @Override
+      protected long getCurrentOffset() throws NoSuchElementException {
+        if (!elementIsPresent) {
+          throw new NoSuchElementException();
+        }
+        return startOfRecord;
+      }
+
+      @Override
+      public byte[] getCurrent() throws NoSuchElementException {
+        if (!elementIsPresent) {
+          throw new NoSuchElementException();
+        }
+        return currentValue;
+      }
+
+      @Override
+      protected void startReading(ReadableByteChannel channel) throws IOException {
+        this.inChannel = channel;
+        this.codec = new TFRecordCodec();
+      }
+
+      @Override
+      protected boolean readNextRecord() throws IOException {
+        startOfRecord = startOfNextRecord;
+        currentValue = codec.read(inChannel);
+        if (currentValue != null) {
+          elementIsPresent = true;
+          startOfNextRecord = startOfRecord + codec.recordLength(currentValue);
+          return true;
+        } else {
+          elementIsPresent = false;
+          return false;
+        }
+      }
+    }
+  }
+
+  /**
+   * A {@link FileBasedSink} that writes TFRecord files.
+   */
+  @VisibleForTesting
+  static class TFRecordSink extends FileBasedSink<byte[]> {
+    @VisibleForTesting
+    TFRecordSink(ValueProvider<String> baseOutputFilename,
+                 String extension,
+                 String fileNameTemplate,
+                 TFRecordIO.CompressionType compressionType) {
+      super(baseOutputFilename, extension, fileNameTemplate,
+          writableByteChannelFactory(compressionType));
+    }
+
+    @Override
+    public FileBasedWriteOperation<byte[]> createWriteOperation(PipelineOptions options) {
+      return new TFRecordWriteOperation(this);
+    }
+
+    private static WritableByteChannelFactory writableByteChannelFactory(
+        TFRecordIO.CompressionType compressionType) {
+      switch (compressionType) {
+        case AUTO:
+          throw new IllegalArgumentException("Unsupported compression type AUTO");
+        case NONE:
+          return CompressionType.UNCOMPRESSED;
+        case GZIP:
+          return CompressionType.GZIP;
+        case ZLIB:
+          return CompressionType.DEFLATE;
+      }
+      return CompressionType.UNCOMPRESSED;
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriteOperation
+     * FileBasedWriteOperation} for TFRecord files.
+     */
+    private static class TFRecordWriteOperation extends FileBasedWriteOperation<byte[]> {
+      private TFRecordWriteOperation(TFRecordSink sink) {
+        super(sink);
+      }
+
+      @Override
+      public FileBasedWriter<byte[]> createWriter(PipelineOptions options) throws Exception {
+        return new TFRecordWriter(this);
+      }
+    }
+
+    /**
+     * A {@link org.apache.beam.sdk.io.FileBasedSink.FileBasedWriter FileBasedWriter}
+     * for TFRecord files.
+     */
+    private static class TFRecordWriter extends FileBasedWriter<byte[]> {
+      private WritableByteChannel outChannel;
+      private TFRecordCodec codec;
+
+      private TFRecordWriter(FileBasedWriteOperation<byte[]> writeOperation) {
+        super(writeOperation);
+        this.mimeType = MimeTypes.BINARY;
+      }
+
+      @Override
+      protected void prepareWrite(WritableByteChannel channel) throws Exception {
+        this.outChannel = channel;
+        this.codec = new TFRecordCodec();
+      }
+
+      @Override
+      public void write(byte[] value) throws Exception {
+        codec.write(outChannel, value);
+      }
+    }
+  }
+
+  //////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Codec for TFRecords file format.
+   * See https://www.tensorflow.org/api_guides/python/python_io#TFRecords_Format_Details
+   */
+  private static class TFRecordCodec {
+    private static final int HEADER_LEN = (Long.SIZE + Integer.SIZE) / Byte.SIZE;
+    private static final int FOOTER_LEN = Integer.SIZE / Byte.SIZE;
+    private static HashFunction crc32c = Hashing.crc32c();
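+
+    // On-disk framing handled by read() and write() below (all integers little-endian):
+    //   uint64 length | uint32 masked_crc32c(length) | byte[length] data | uint32 masked_crc32c(data)
+    // where masked_crc32c(x) = ((crc32c(x) >>> 15) | (crc32c(x) << 17)) + 0xa282ead8,
+    // giving HEADER_LEN = 12 and FOOTER_LEN = 4.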
+
+    private ByteBuffer header = ByteBuffer.allocate(HEADER_LEN).order(ByteOrder.LITTLE_ENDIAN);
+    private ByteBuffer footer = ByteBuffer.allocate(FOOTER_LEN).order(ByteOrder.LITTLE_ENDIAN);
+
+    private int mask(int crc) {
+      return ((crc >>> 15) | (crc << 17)) + 0xa282ead8;
+    }
+
+    private int hashLong(long x) {
+      return mask(crc32c.hashLong(x).asInt());
+    }
+
+    private int hashBytes(byte[] x) {
+      return mask(crc32c.hashBytes(x).asInt());
+    }
+
+    public int recordLength(byte[] data) {
+      return HEADER_LEN + data.length + FOOTER_LEN;
+    }
+
+    public byte[] read(ReadableByteChannel inChannel) throws IOException {
+      header.clear();
+      int headerBytes = inChannel.read(header);
+      if (headerBytes <= 0) {
+        return null;
+      }
+      checkState(
+          headerBytes == HEADER_LEN,
+          "Not a valid TFRecord. Fewer than 12 bytes.");
+      header.rewind();
+      long length = header.getLong();
+      int maskedCrc32OfLength = header.getInt();
+      checkState(
+          hashLong(length) == maskedCrc32OfLength,
+          "Mismatch of length mask");
+
+      ByteBuffer data = ByteBuffer.allocate((int) length);
+      checkState(inChannel.read(data) == length, "Invalid data");
+
+      footer.clear();
+      inChannel.read(footer);
+      footer.rewind();
+      int maskedCrc32OfData = footer.getInt();
+
+      checkState(
+          hashBytes(data.array()) == maskedCrc32OfData,
+          "Mismatch of data mask");
+      return data.array();
+    }
+
+    public void write(WritableByteChannel outChannel, byte[] data) throws IOException {
+      int maskedCrc32OfLength = hashLong(data.length);
+      int maskedCrc32OfData = hashBytes(data);
+
+      header.clear();
+      header.putLong(data.length).putInt(maskedCrc32OfLength);
+      header.rewind();
+      outChannel.write(header);
+
+      outChannel.write(ByteBuffer.wrap(data));
+
+      footer.clear();
+      footer.putInt(maskedCrc32OfData);
+      footer.rewind();
+      outChannel.write(footer);
+    }
+  }
+
+}
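
For reference, a minimal usage sketch of the new transforms (illustrative only;
"p" is an existing Pipeline and the file paths are placeholders):

    PCollection<byte[]> records = p.apply(
        TFRecordIO.Read.from("/tmp/input-*.tfrecords")
            .withCompressionType(TFRecordIO.CompressionType.AUTO));
    records.apply(
        TFRecordIO.Write.to("/tmp/output/records")
            .withSuffix(".tfrecords")
            .withCompressionType(TFRecordIO.CompressionType.GZIP));
    p.run();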

http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
index fe8d0fd..f8943a5 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
@@ -682,7 +682,7 @@ public class TextIO {
             .addIfNotNull(DisplayData.item("filePrefix", prefixString)
               .withLabel("Output File Prefix"))
             .addIfNotDefault(DisplayData.item("fileSuffix", filenameSuffix)
-              .withLabel("Output Fix Suffix"), "")
+              .withLabel("Output File Suffix"), "")
             .addIfNotDefault(DisplayData.item("shardNameTemplate", shardTemplate)
               .withLabel("Output Shard Name Template"),
                 DEFAULT_SHARD_TEMPLATE)

http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
new file mode 100644
index 0000000..70620fb
--- /dev/null
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io;
+
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.AUTO;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.GZIP;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.NONE;
+import static org.apache.beam.sdk.io.TFRecordIO.CompressionType.ZLIB;
+import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
+import static org.hamcrest.Matchers.isIn;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.BaseEncoding;
+import com.google.common.io.ByteStreams;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.file.FileVisitResult;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.sdk.testing.NeedsRunner;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.display.DisplayData;
+import org.apache.beam.sdk.values.PCollection;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for TFRecordIO Read and Write transforms.
+ */
+@RunWith(JUnit4.class)
+public class TFRecordIOTest {
+
+  /*
+  From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py
+  Created by running following code in python:
+  >>> import tensorflow as tf
+  >>> import base64
+  >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
+  >>> writer.write('foo')
+  >>> writer.close()
+  >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
+  ...   data =  base64.b64encode(f.read())
+  ...   print data
+  */
+  private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";
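+  // Decoded, FOO_RECORD_BASE64 is the 19 bytes:
+  //   03 00 00 00 00 00 00 00   length = 3 (little-endian uint64)
+  //   b0 99 49 0e               masked crc32c of the length
+  //   66 6f 6f                  "foo"
+  //   61 8a be fe               masked crc32c of the data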
+
+  // Same as above but containing two records ['foo', 'bar']
+  private static final String FOO_BAR_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";
+  private static final String BAR_FOO_RECORD_BASE64 =
+      "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";
+
+  private static final String[] FOO_RECORDS = {"foo"};
+  private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};
+
+  private static final Iterable<String> EMPTY = Collections.emptyList();
+  private static final Iterable<String> LARGE = makeLines(5000);
+
+  private static Path tempFolder;
+
+  @Rule
+  public TestPipeline p = TestPipeline.create();
+
+  @Rule
+  public TestPipeline p2 = TestPipeline.create();
+
+  @Rule
+  public ExpectedException expectedException = ExpectedException.none();
+
+  @BeforeClass
+  public static void setupClass() throws IOException {
+    tempFolder = Files.createTempDirectory("TFRecordIOTest");
+  }
+
+  @AfterClass
+  public static void teardownClass() throws IOException {
+    Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
+      @Override
+      public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
+        Files.delete(file);
+        return FileVisitResult.CONTINUE;
+      }
+
+      @Override
+      public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
+        Files.delete(dir);
+        return FileVisitResult.CONTINUE;
+      }
+    });
+  }
+
+  @Test
+  public void testReadNamed() {
+    p.enableAbandonedNodeEnforcement(false);
+
+    assertEquals(
+        "TFRecordIO.Read/Read.out",
+        p.apply(TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+    assertEquals(
+        "MyRead/Read.out",
+        p.apply("MyRead", TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+  }
+
+  @Test
+  public void testReadDisplayData() {
+    TFRecordIO.Read.Bound read = TFRecordIO.Read
+        .from("foo.*")
+        .withCompressionType(GZIP)
+        .withoutValidation();
+
+    DisplayData displayData = DisplayData.from(read);
+
+    assertThat(displayData, hasDisplayItem("filePattern", "foo.*"));
+    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
+    assertThat(displayData, hasDisplayItem("validation", false));
+  }
+
+  @Test
+  public void testWriteDisplayData() {
+    TFRecordIO.Write.Bound write = TFRecordIO.Write
+        .to("foo")
+        .withSuffix("bar")
+        .withShardNameTemplate("-SS-of-NN-")
+        .withNumShards(100)
+        .withCompressionType(GZIP)
+        .withoutValidation();
+
+    DisplayData displayData = DisplayData.from(write);
+
+    assertThat(displayData, hasDisplayItem("filePrefix", "foo"));
+    assertThat(displayData, hasDisplayItem("fileSuffix", "bar"));
+    assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
+    assertThat(displayData, hasDisplayItem("numShards", 100));
+    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
+    assertThat(displayData, hasDisplayItem("validation", false));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadOne() throws Exception {
+    runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadTwo() throws Exception {
+    runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteOne() throws Exception {
+    runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testWriteTwo() throws Exception {
+    runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidRecord() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");
+    runTestRead("bar".getBytes(), new String[0]);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidLengthMask() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Mismatch of length mask");
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[9] += 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void testReadInvalidDataMask() throws Exception {
+    expectedException.expect(IllegalStateException.class);
+    expectedException.expectMessage("Mismatch of data mask");
+    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
+    data[16] += 1;
+    runTestRead(data, FOO_RECORDS);
+  }
+
+  private void runTestRead(String base64, String[] expected) throws IOException {
+    runTestRead(BaseEncoding.base64().decode(base64), expected);
+  }
+
+  private void runTestRead(byte[] data, String[] expected) throws IOException {
+    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    FileOutputStream fos = new FileOutputStream(tmpFile);
+    fos.write(data);
+    fos.close();
+
+    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(filename);
+    PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString()));
+
+    PAssert.that(output).containsInAnyOrder(expected);
+    p.run();
+  }
+
+  private void runTestWrite(String[] elems, String ...base64) throws IOException {
+    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
+    String filename = tmpFile.getPath();
+
+    PCollection<byte[]> input = p.apply(Create.of(Arrays.asList(elems)))
+        .apply(ParDo.of(new StringToByteArray()));
+
+    TFRecordIO.Write.Bound write = TFRecordIO.Write.to(filename).withoutSharding();
+    input.apply(write);
+
+    p.run();
+
+    FileInputStream fis = new FileInputStream(tmpFile);
+    String written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
+    // bytes written may vary depending on the order of elems
+    assertThat(written, isIn(base64));
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTrip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithEmptyData() throws IOException {
+    runTestRoundTrip(EMPTY, 10, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithOneShards() throws IOException {
+    runTestRoundTrip(LARGE, 1, ".tfrecords", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripWithSuffix() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".suffix", NONE, NONE);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzip() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlib() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, ZLIB);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripUncompressedFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", NONE, AUTO);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripGzipFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, AUTO);
+  }
+
+  @Test
+  @Category(NeedsRunner.class)
+  public void runTestRoundTripZlibFilesWithAuto() throws IOException {
+    runTestRoundTrip(LARGE, 10, ".tfrecords", ZLIB, AUTO);
+  }
+
+  private void runTestRoundTrip(Iterable<String> elems,
+                                int numShards,
+                                String suffix,
+                                CompressionType writeCompressionType,
+                                CompressionType readCompressionType) throws IOException {
+    String outputName = "file";
+    Path baseDir = Files.createTempDirectory(tempFolder, "test-rt");
+    String baseFilename = baseDir.resolve(outputName).toString();
+
+    TFRecordIO.Write.Bound write = TFRecordIO.Write.to(baseFilename)
+        .withNumShards(numShards)
+        .withSuffix(suffix)
+        .withCompressionType(writeCompressionType);
+    p.apply(Create.of(elems).withCoder(StringUtf8Coder.of()))
+        .apply(ParDo.of(new StringToByteArray()))
+        .apply(write);
+    p.run();
+
+    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(baseFilename + "*")
+        .withCompressionType(readCompressionType);
+    PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString()));
+
+    PAssert.that(output).containsInAnyOrder(elems);
+    p2.run();
+  }
+
+  private static Iterable<String> makeLines(int n) {
+    List<String> ret = Lists.newArrayList();
+    for (int i = 0; i < n; ++i) {
+      ret.add("word" + i);
+    }
+    return ret;
+  }
+
+  static class ByteArrayToString extends DoFn<byte[], String> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(new String(c.element()));
+    }
+  }
+
+  static class StringToByteArray extends DoFn<String, byte[]> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(c.element().getBytes());
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/68d42f9b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
index cd94dc5..713cb71 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
@@ -206,7 +206,7 @@ public class TextIOTest {
   }
 
   @AfterClass
-  public static void testdownClass() throws IOException {
+  public static void teardownClass() throws IOException {
     Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
       @Override
       public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {


[2/2] beam git commit: This closes #2066

Posted by ch...@apache.org.
This closes #2066


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/92d1a663
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/92d1a663
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/92d1a663

Branch: refs/heads/master
Commit: 92d1a66351512ca2d65125a173903d031e4ac2d6
Parents: bea4f5a 68d42f9
Author: Chamikara Jayalath <ch...@google.com>
Authored: Tue Mar 21 11:24:04 2017 -0700
Committer: Chamikara Jayalath <ch...@google.com>
Committed: Tue Mar 21 11:24:04 2017 -0700

----------------------------------------------------------------------
 .../apache/beam/sdk/io/CompressedSource.java    |  13 +-
 .../java/org/apache/beam/sdk/io/TFRecordIO.java | 905 +++++++++++++++++++
 .../java/org/apache/beam/sdk/io/TextIO.java     |   2 +-
 .../org/apache/beam/sdk/io/TFRecordIOTest.java  | 368 ++++++++
 .../java/org/apache/beam/sdk/io/TextIOTest.java |   2 +-
 5 files changed, 1285 insertions(+), 5 deletions(-)
----------------------------------------------------------------------