You are viewing a plain-text version of this content; the hyperlink to its canonical version is not preserved in plain text.
Posted to commits@beam.apache.org by jk...@apache.org on 2017/04/18 21:21:30 UTC

[4/6] beam git commit: Converts TFRecordIO.Read to AutoValue

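Converts TFRecordIO.Read to AutoValue.

Previously, the instance methods from(), withoutValidation() and withCompressionType() called themselves on a fresh Read (e.g. `new Read().from(filepattern)`), which recursed without end. They now return a modified copy via toBuilder(). The public accessors getFilepattern() (which returned String), needsValidation() and getCompressionType() are replaced by package-private AutoValue properties: getFilepattern() (now returning ValueProvider<String>), getValidate() and getCompressionType().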


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f5d92a49
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f5d92a49
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f5d92a49

Branch: refs/heads/master
Commit: f5d92a49c91c4555f6ad527425077f92bb2bb377
Parents: 88d2c69
Author: Eugene Kirpichov <ki...@google.com>
Authored: Mon Apr 17 17:29:06 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue Apr 18 14:03:42 2017 -0700

----------------------------------------------------------------------
 .../java/org/apache/beam/sdk/io/TFRecordIO.java | 99 +++++++++-----------
 1 file changed, 44 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/f5d92a49/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
index b4fd93c..fb4ff5b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.io;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 
+import com.google.auto.value.AutoValue;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.hash.HashFunction;
 import com.google.common.hash.Hashing;
@@ -59,19 +60,32 @@ public class TFRecordIO {
    * the decoding of each of the records of the TFRecord file(s) as a byte array.
    */
   public static Read read() {
-    return new Read();
+    return new AutoValue_TFRecordIO_Read.Builder()
+        .setValidate(true)
+        .setCompressionType(CompressionType.AUTO)
+        .build();
   }
 
   /** Implementation of {@link #read}. */
-  public static class Read extends PTransform<PBegin, PCollection<byte[]>> {
-    /** The filepattern to read from. */
-    @Nullable private final ValueProvider<String> filepattern;
+  @AutoValue
+  public abstract static class Read extends PTransform<PBegin, PCollection<byte[]>> {
+    @Nullable
+    abstract ValueProvider<String> getFilepattern();
 
-    /** An option to indicate if input validation is desired. Default is true. */
-    private final boolean validate;
+    abstract boolean getValidate();
 
-    /** Option to indicate the input source's compression type. Default is AUTO. */
-    private final TFRecordIO.CompressionType compressionType;
+    abstract CompressionType getCompressionType();
+
+    abstract Builder toBuilder();
+
+    @AutoValue.Builder
+    abstract static class Builder {
+      abstract Builder setFilepattern(ValueProvider<String> filepattern);
+      abstract Builder setValidate(boolean validate);
+      abstract Builder setCompressionType(CompressionType compressionType);
+
+      abstract Read build();
+    }
 
     /**
      * Returns a transform for reading TFRecord files that reads from the file(s)
@@ -82,14 +96,14 @@ public class TFRecordIO {
      * >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
      */
     public Read from(String filepattern) {
-      return new Read().from(filepattern);
+      return from(StaticValueProvider.of(filepattern));
     }
 
     /**
      * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
      */
     public Read from(ValueProvider<String> filepattern) {
-      return new Read().from(filepattern);
+      return toBuilder().setFilepattern(filepattern).build();
     }
 
     /**
@@ -101,7 +115,7 @@ public class TFRecordIO {
      * available at execution time.
      */
     public Read withoutValidation() {
-      return new Read().withoutValidation();
+      return toBuilder().setValidate(false).build();
     }
 
     /**
@@ -115,41 +129,28 @@ public class TFRecordIO {
      * extensions are uncompressed).
      */
     public Read withCompressionType(TFRecordIO.CompressionType compressionType) {
-      return new Read().withCompressionType(compressionType);
-    }
-
-    private Read() {
-      this(null, null, true, TFRecordIO.CompressionType.AUTO);
-    }
-
-    private Read(
-        @Nullable String name,
-        @Nullable ValueProvider<String> filepattern,
-        boolean validate,
-        TFRecordIO.CompressionType compressionType) {
-      super(name);
-      this.filepattern = filepattern;
-      this.validate = validate;
-      this.compressionType = compressionType;
+      return toBuilder().setCompressionType(compressionType).build();
     }
 
     @Override
     public PCollection<byte[]> expand(PBegin input) {
-      if (filepattern == null) {
+      if (getFilepattern() == null) {
         throw new IllegalStateException(
             "Need to set the filepattern of a TFRecordIO.Read transform");
       }
 
-      if (validate) {
-        checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+      if (getValidate()) {
+        checkState(getFilepattern().isAccessible(), "Cannot validate with a RVP.");
         try {
           checkState(
-              !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+              !IOChannelUtils.getFactory(getFilepattern().get())
+                  .match(getFilepattern().get())
+                  .isEmpty(),
               "Unable to find any files matching %s",
-              filepattern);
+              getFilepattern());
         } catch (IOException e) {
           throw new IllegalStateException(
-              String.format("Failed to validate %s", filepattern.get()), e);
+              String.format("Failed to validate %s", getFilepattern().get()), e);
         }
       }
 
@@ -162,21 +163,21 @@ public class TFRecordIO {
 
     // Helper to create a source specific to the requested compression type.
     protected FileBasedSource<byte[]> getSource() {
-      switch (compressionType) {
+      switch (getCompressionType()) {
         case NONE:
-          return new TFRecordSource(filepattern);
+          return new TFRecordSource(getFilepattern());
         case AUTO:
-          return CompressedSource.from(new TFRecordSource(filepattern));
+          return CompressedSource.from(new TFRecordSource(getFilepattern()));
         case GZIP:
           return
-              CompressedSource.from(new TFRecordSource(filepattern))
+              CompressedSource.from(new TFRecordSource(getFilepattern()))
                   .withDecompression(CompressedSource.CompressionMode.GZIP);
         case ZLIB:
           return
-              CompressedSource.from(new TFRecordSource(filepattern))
+              CompressedSource.from(new TFRecordSource(getFilepattern()))
                   .withDecompression(CompressedSource.CompressionMode.DEFLATE);
         default:
-          throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+          throw new IllegalArgumentException("Unknown compression type: " + getCompressionType());
       }
     }
 
@@ -184,12 +185,12 @@ public class TFRecordIO {
     public void populateDisplayData(DisplayData.Builder builder) {
       super.populateDisplayData(builder);
 
-      String filepatternDisplay = filepattern.isAccessible()
-          ? filepattern.get() : filepattern.toString();
+      String filepatternDisplay = getFilepattern().isAccessible()
+          ? getFilepattern().get() : getFilepattern().toString();
       builder
-          .add(DisplayData.item("compressionType", compressionType.toString())
+          .add(DisplayData.item("compressionType", getCompressionType().toString())
               .withLabel("Compression Type"))
-          .addIfNotDefault(DisplayData.item("validation", validate)
+          .addIfNotDefault(DisplayData.item("validation", getValidate())
               .withLabel("Validation Enabled"), true)
           .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
               .withLabel("File Pattern"));
@@ -199,18 +200,6 @@ public class TFRecordIO {
     protected Coder<byte[]> getDefaultOutputCoder() {
       return DEFAULT_BYTE_ARRAY_CODER;
     }
-
-    public String getFilepattern() {
-      return filepattern.get();
-    }
-
-    public boolean needsValidation() {
-      return validate;
-    }
-
-    public TFRecordIO.CompressionType getCompressionType() {
-      return compressionType;
-    }
   }
 
   /////////////////////////////////////////////////////////////////////////////