You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by jk...@apache.org on 2017/04/18 21:21:29 UTC

[3/6] beam git commit: Gets rid of TFRecordIO.Read.Bound

Gets rid of TFRecordIO.Read.Bound


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/88d2c69c
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/88d2c69c
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/88d2c69c

Branch: refs/heads/master
Commit: 88d2c69c20b2869062991330ffa791daa85a3c84
Parents: 0be61db
Author: Eugene Kirpichov <ki...@google.com>
Authored: Mon Apr 17 17:21:34 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue Apr 18 14:03:42 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/beam/sdk/io/TFRecordIO.java | 261 ++++++++-----------
 .../org/apache/beam/sdk/io/TFRecordIOTest.java  |  10 +-
 2 files changed, 107 insertions(+), 164 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/88d2c69c/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
index 0552236..b4fd93c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -18,13 +18,11 @@
 package org.apache.beam.sdk.io;
 
 import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.hash.HashFunction;
 import com.google.common.hash.Hashing;
-
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
@@ -32,9 +30,7 @@ import java.nio.channels.ReadableByteChannel;
 import java.nio.channels.WritableByteChannel;
 import java.util.NoSuchElementException;
 import java.util.regex.Pattern;
-
 import javax.annotation.Nullable;
-
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.VoidCoder;
@@ -62,7 +58,20 @@ public class TFRecordIO {
    * files matching a pattern) and returns a {@link PCollection} containing
    * the decoding of each of the records of the TFRecord file(s) as a byte array.
    */
-  public static class Read {
+  public static Read read() {
+    return new Read();
+  }
+
+  /** Implementation of {@link #read}. */
+  public static class Read extends PTransform<PBegin, PCollection<byte[]>> {
+    /** The filepattern to read from. */
+    @Nullable private final ValueProvider<String> filepattern;
+
+    /** An option to indicate if input validation is desired. Default is true. */
+    private final boolean validate;
+
+    /** Option to indicate the input source's compression type. Default is AUTO. */
+    private final TFRecordIO.CompressionType compressionType;
 
     /**
      * Returns a transform for reading TFRecord files that reads from the file(s)
@@ -72,16 +81,17 @@ public class TFRecordIO {
      * execution). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html"
      * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
      */
-    public static Bound from(String filepattern) {
-      return new Bound().from(filepattern);
+    public Read from(String filepattern) { // copy of this transform with a static pattern
+      return new Read(name, StaticValueProvider.of(filepattern), validate, compressionType);
+    }
 
     /**
      * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
      */
-    public static Bound from(ValueProvider<String> filepattern) {
-      return new Bound().from(filepattern);
+    public Read from(ValueProvider<String> filepattern) {
+      return new Read(name, filepattern, validate, compressionType); // copy, new pattern
+    }
+
     /**
      * Returns a transform for reading TFRecord files that has GCS path validation on
      * pipeline creation disabled.
@@ -90,8 +100,8 @@ public class TFRecordIO {
      * exist at the pipeline creation time, but is expected to be
      * available at execution time.
      */
-    public static Bound withoutValidation() {
-      return new Bound().withoutValidation();
+    public Read withoutValidation() {
+      return new Read(name, filepattern, false, compressionType); // copy, validation off
+    }
 
     /**
@@ -104,170 +114,103 @@ public class TFRecordIO {
      * (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other
      * extensions are uncompressed).
      */
-    public static Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
-      return new Bound().withCompressionType(compressionType);
+    public Read withCompressionType(TFRecordIO.CompressionType compressionType) {
+      return new Read(name, filepattern, validate, compressionType); // copy, new type
+    }
 
-    /**
-     * A {@link PTransform} that reads from one or more TFRecord files and returns a bounded
-     * {@link PCollection} containing one element for each record of the input files.
-     */
-    public static class Bound extends PTransform<PBegin, PCollection<byte[]>> {
-      /** The filepattern to read from. */
-      @Nullable private final ValueProvider<String> filepattern;
-
-      /** An option to indicate if input validation is desired. Default is true. */
-      private final boolean validate;
-
-      /** Option to indicate the input source's compression type. Default is AUTO. */
-      private final TFRecordIO.CompressionType compressionType;
-
-      private Bound() {
-        this(null, null, true, TFRecordIO.CompressionType.AUTO);
-      }
-
-      private Bound(
-          @Nullable String name,
-          @Nullable ValueProvider<String> filepattern,
-          boolean validate,
-          TFRecordIO.CompressionType compressionType) {
-        super(name);
-        this.filepattern = filepattern;
-        this.validate = validate;
-        this.compressionType = compressionType;
-      }
-
-      /**
-       * Returns a new transform for reading from TFRecord files that's like this one but that
-       * reads from the file(s) with the given name or pattern. See {@link TFRecordIO.Read#from}
-       * for a description of filepatterns.
-       *
-       * <p>Does not modify this object.
-
-       */
-      public Bound from(String filepattern) {
-        checkNotNull(filepattern, "Filepattern cannot be empty.");
-        return new Bound(name, StaticValueProvider.of(filepattern), validate, compressionType);
-      }
-
-      /**
-       * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
-       */
-      public Bound from(ValueProvider<String> filepattern) {
-        checkNotNull(filepattern, "Filepattern cannot be empty.");
-        return new Bound(name, filepattern, validate, compressionType);
-      }
-
-      /**
-       * Returns a new transform for reading from TFRecord files that's like this one but
-       * that has GCS path validation on pipeline creation disabled.
-       *
-       * <p>This can be useful in the case where the GCS input does not
-       * exist at the pipeline creation time, but is expected to be
-       * available at execution time.
-       *
-       * <p>Does not modify this object.
-       */
-      public Bound withoutValidation() {
-        return new Bound(name, filepattern, false, compressionType);
-      }
+    private Read() {
+      this(null, null, true, TFRecordIO.CompressionType.AUTO);
+    }
 
-      /**
-       * Returns a new transform for reading from TFRecord files that's like this one but
-       * reads from input sources using the specified compression type.
-       *
-       * <p>If no compression type is specified, the default is
-       * {@link TFRecordIO.CompressionType#AUTO}.
-       * See {@link TFRecordIO.Read#withCompressionType} for more details.
-       *
-       * <p>Does not modify this object.
-       */
-      public Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
-        return new Bound(name, filepattern, validate, compressionType);
-      }
+    private Read(
+        @Nullable String name,
+        @Nullable ValueProvider<String> filepattern,
+        boolean validate,
+        TFRecordIO.CompressionType compressionType) {
+      super(name);
+      this.filepattern = filepattern;
+      this.validate = validate;
+      this.compressionType = compressionType;
+    }
 
-      @Override
-      public PCollection<byte[]> expand(PBegin input) {
-        if (filepattern == null) {
+    @Override
+    public PCollection<byte[]> expand(PBegin input) {
+      if (filepattern == null) {
+        throw new IllegalStateException(
+            "Need to set the filepattern of a TFRecordIO.Read transform");
+      }
+
+      if (validate) {
+        checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+        try {
+          checkState(
+              !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+              "Unable to find any files matching %s",
+              filepattern);
+        } catch (IOException e) {
           throw new IllegalStateException(
-              "Need to set the filepattern of a TFRecordIO.Read transform");
-        }
-
-        if (validate) {
-          checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
-          try {
-            checkState(
-                !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
-                "Unable to find any files matching %s",
-                filepattern);
-          } catch (IOException e) {
-            throw new IllegalStateException(
-                String.format("Failed to validate %s", filepattern.get()), e);
-          }
-        }
-
-        final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
-        PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
-        // Honor the default output coder that would have been used by this PTransform.
-        pcol.setCoder(getDefaultOutputCoder());
-        return pcol;
-      }
-
-      // Helper to create a source specific to the requested compression type.
-      protected FileBasedSource<byte[]> getSource() {
-        switch (compressionType) {
-          case NONE:
-            return new TFRecordSource(filepattern);
-          case AUTO:
-            return CompressedSource.from(new TFRecordSource(filepattern));
-          case GZIP:
-            return
-                CompressedSource.from(new TFRecordSource(filepattern))
-                    .withDecompression(CompressedSource.CompressionMode.GZIP);
-          case ZLIB:
-            return
-                CompressedSource.from(new TFRecordSource(filepattern))
-                    .withDecompression(CompressedSource.CompressionMode.DEFLATE);
-          default:
-            throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+              String.format("Failed to validate %s", filepattern.get()), e);
         }
       }
 
-      @Override
-      public void populateDisplayData(DisplayData.Builder builder) {
-        super.populateDisplayData(builder);
+      final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
+      PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
+      // Honor the default output coder that would have been used by this PTransform.
+      pcol.setCoder(getDefaultOutputCoder());
+      return pcol;
+    }
 
-        String filepatternDisplay = filepattern.isAccessible()
-            ? filepattern.get() : filepattern.toString();
-        builder
-            .add(DisplayData.item("compressionType", compressionType.toString())
-                .withLabel("Compression Type"))
-            .addIfNotDefault(DisplayData.item("validation", validate)
-                .withLabel("Validation Enabled"), true)
-            .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
-                .withLabel("File Pattern"));
+    // Helper to create a source specific to the requested compression type.
+    protected FileBasedSource<byte[]> getSource() {
+      switch (compressionType) {
+        case NONE:
+          return new TFRecordSource(filepattern);
+        case AUTO:
+          return CompressedSource.from(new TFRecordSource(filepattern));
+        case GZIP:
+          return
+              CompressedSource.from(new TFRecordSource(filepattern))
+                  .withDecompression(CompressedSource.CompressionMode.GZIP);
+        case ZLIB:
+          return
+              CompressedSource.from(new TFRecordSource(filepattern))
+                  .withDecompression(CompressedSource.CompressionMode.DEFLATE);
+        default:
+          throw new IllegalArgumentException("Unknown compression type: " + compressionType);
       }
+    }
 
-      @Override
-      protected Coder<byte[]> getDefaultOutputCoder() {
-        return DEFAULT_BYTE_ARRAY_CODER;
-      }
+    @Override
+    public void populateDisplayData(DisplayData.Builder builder) {
+      super.populateDisplayData(builder);
 
-      public String getFilepattern() {
-        return filepattern.get();
-      }
+      String filepatternDisplay = filepattern.isAccessible()
+          ? filepattern.get() : filepattern.toString();
+      builder
+          .add(DisplayData.item("compressionType", compressionType.toString())
+              .withLabel("Compression Type"))
+          .addIfNotDefault(DisplayData.item("validation", validate)
+              .withLabel("Validation Enabled"), true)
+          .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
+              .withLabel("File Pattern"));
+    }
 
-      public boolean needsValidation() {
-        return validate;
-      }
+    @Override
+    protected Coder<byte[]> getDefaultOutputCoder() {
+      return DEFAULT_BYTE_ARRAY_CODER;
+    }
 
-      public TFRecordIO.CompressionType getCompressionType() {
-        return compressionType;
-      }
+    public String getFilepattern() {
+      return filepattern.get();
     }
 
-    /** Disallow construction of utility classes. */
-    private Read() {}
+    public boolean needsValidation() {
+      return validate;
+    }
+
+    public TFRecordIO.CompressionType getCompressionType() {
+      return compressionType;
+    }
   }
 
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/beam/blob/88d2c69c/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
index 94530ef..2a455d1 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
@@ -136,15 +136,15 @@ public class TFRecordIOTest {
 
     assertEquals(
         "TFRecordIO.Read/Read.out",
-        p.apply(TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+        p.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName());
     assertEquals(
         "MyRead/Read.out",
-        p.apply("MyRead", TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+        p.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName());
   }
 
   @Test
   public void testReadDisplayData() {
-    TFRecordIO.Read.Bound read = TFRecordIO.Read
+    TFRecordIO.Read read = TFRecordIO.read()
         .from("foo.*")
         .withCompressionType(GZIP)
         .withoutValidation();
@@ -241,7 +241,7 @@ public class TFRecordIOTest {
     fos.write(data);
     fos.close();
 
-    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(filename);
+    TFRecordIO.Read read = TFRecordIO.read().from(filename);
     PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString()));
 
     PAssert.that(output).containsInAnyOrder(expected);
@@ -338,7 +338,7 @@ public class TFRecordIOTest {
         .apply(write);
     p.run();
 
-    TFRecordIO.Read.Bound read = TFRecordIO.Read.from(baseFilename + "*")
+    TFRecordIO.Read read = TFRecordIO.read().from(baseFilename + "*")
         .withCompressionType(readCompressionType);
     PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString()));