You are viewing a plain-text version of this content. The canonical link to it (a hyperlink labeled "here" on the original page) is not preserved in this plain-text rendering.
Posted to commits@beam.apache.org by jk...@apache.org on 2017/04/18 21:21:29 UTC
[3/6] beam git commit: Gets rid of TFRecordIO.Read.Bound
Gets rid of TFRecordIO.Read.Bound
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/88d2c69c
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/88d2c69c
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/88d2c69c
Branch: refs/heads/master
Commit: 88d2c69c20b2869062991330ffa791daa85a3c84
Parents: 0be61db
Author: Eugene Kirpichov <ki...@google.com>
Authored: Mon Apr 17 17:21:34 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue Apr 18 14:03:42 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/beam/sdk/io/TFRecordIO.java | 261 ++++++++-----------
.../org/apache/beam/sdk/io/TFRecordIOTest.java | 10 +-
2 files changed, 107 insertions(+), 164 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/88d2c69c/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
index 0552236..b4fd93c 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -18,13 +18,11 @@
package org.apache.beam.sdk.io;
import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
-
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -32,9 +30,7 @@ import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.util.NoSuchElementException;
import java.util.regex.Pattern;
-
import javax.annotation.Nullable;
-
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.VoidCoder;
@@ -62,7 +58,20 @@ public class TFRecordIO {
* files matching a pattern) and returns a {@link PCollection} containing
* the decoding of each of the records of the TFRecord file(s) as a byte array.
*/
- public static class Read {
+ public static Read read() {
+ return new Read();
+ }
+
+ /** Implementation of {@link #read}. */
+ public static class Read extends PTransform<PBegin, PCollection<byte[]>> {
+ /** The filepattern to read from. */
+ @Nullable private final ValueProvider<String> filepattern;
+
+ /** An option to indicate if input validation is desired. Default is true. */
+ private final boolean validate;
+
+ /** Option to indicate the input source's compression type. Default is AUTO. */
+ private final TFRecordIO.CompressionType compressionType;
/**
* Returns a transform for reading TFRecord files that reads from the file(s)
@@ -72,16 +81,17 @@ public class TFRecordIO {
* execution). Standard <a href="http://docs.oracle.com/javase/tutorial/essential/io/find.html"
* >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
*/
- public static Bound from(String filepattern) {
- return new Bound().from(filepattern);
+ public Read from(String filepattern) {
+ return new Read().from(filepattern);
}
/**
* Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
*/
- public static Bound from(ValueProvider<String> filepattern) {
- return new Bound().from(filepattern);
+ public Read from(ValueProvider<String> filepattern) {
+ return new Read().from(filepattern);
}
+
/**
* Returns a transform for reading TFRecord files that has GCS path validation on
* pipeline creation disabled.
@@ -90,8 +100,8 @@ public class TFRecordIO {
* exist at the pipeline creation time, but is expected to be
* available at execution time.
*/
- public static Bound withoutValidation() {
- return new Bound().withoutValidation();
+ public Read withoutValidation() {
+ return new Read().withoutValidation();
}
/**
@@ -104,170 +114,103 @@ public class TFRecordIO {
* (e.g., {@code *.gz} is gzipped, {@code *.zlib} is zlib compressed, and all other
* extensions are uncompressed).
*/
- public static Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
- return new Bound().withCompressionType(compressionType);
+ public Read withCompressionType(TFRecordIO.CompressionType compressionType) {
+ return new Read().withCompressionType(compressionType);
}
- /**
- * A {@link PTransform} that reads from one or more TFRecord files and returns a bounded
- * {@link PCollection} containing one element for each record of the input files.
- */
- public static class Bound extends PTransform<PBegin, PCollection<byte[]>> {
- /** The filepattern to read from. */
- @Nullable private final ValueProvider<String> filepattern;
-
- /** An option to indicate if input validation is desired. Default is true. */
- private final boolean validate;
-
- /** Option to indicate the input source's compression type. Default is AUTO. */
- private final TFRecordIO.CompressionType compressionType;
-
- private Bound() {
- this(null, null, true, TFRecordIO.CompressionType.AUTO);
- }
-
- private Bound(
- @Nullable String name,
- @Nullable ValueProvider<String> filepattern,
- boolean validate,
- TFRecordIO.CompressionType compressionType) {
- super(name);
- this.filepattern = filepattern;
- this.validate = validate;
- this.compressionType = compressionType;
- }
-
- /**
- * Returns a new transform for reading from TFRecord files that's like this one but that
- * reads from the file(s) with the given name or pattern. See {@link TFRecordIO.Read#from}
- * for a description of filepatterns.
- *
- * <p>Does not modify this object.
-
- */
- public Bound from(String filepattern) {
- checkNotNull(filepattern, "Filepattern cannot be empty.");
- return new Bound(name, StaticValueProvider.of(filepattern), validate, compressionType);
- }
-
- /**
- * Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
- */
- public Bound from(ValueProvider<String> filepattern) {
- checkNotNull(filepattern, "Filepattern cannot be empty.");
- return new Bound(name, filepattern, validate, compressionType);
- }
-
- /**
- * Returns a new transform for reading from TFRecord files that's like this one but
- * that has GCS path validation on pipeline creation disabled.
- *
- * <p>This can be useful in the case where the GCS input does not
- * exist at the pipeline creation time, but is expected to be
- * available at execution time.
- *
- * <p>Does not modify this object.
- */
- public Bound withoutValidation() {
- return new Bound(name, filepattern, false, compressionType);
- }
+ private Read() {
+ this(null, null, true, TFRecordIO.CompressionType.AUTO);
+ }
- /**
- * Returns a new transform for reading from TFRecord files that's like this one but
- * reads from input sources using the specified compression type.
- *
- * <p>If no compression type is specified, the default is
- * {@link TFRecordIO.CompressionType#AUTO}.
- * See {@link TFRecordIO.Read#withCompressionType} for more details.
- *
- * <p>Does not modify this object.
- */
- public Bound withCompressionType(TFRecordIO.CompressionType compressionType) {
- return new Bound(name, filepattern, validate, compressionType);
- }
+ private Read(
+ @Nullable String name,
+ @Nullable ValueProvider<String> filepattern,
+ boolean validate,
+ TFRecordIO.CompressionType compressionType) {
+ super(name);
+ this.filepattern = filepattern;
+ this.validate = validate;
+ this.compressionType = compressionType;
+ }
- @Override
- public PCollection<byte[]> expand(PBegin input) {
- if (filepattern == null) {
+ @Override
+ public PCollection<byte[]> expand(PBegin input) {
+ if (filepattern == null) {
+ throw new IllegalStateException(
+ "Need to set the filepattern of a TFRecordIO.Read transform");
+ }
+
+ if (validate) {
+ checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+ try {
+ checkState(
+ !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+ "Unable to find any files matching %s",
+ filepattern);
+ } catch (IOException e) {
throw new IllegalStateException(
- "Need to set the filepattern of a TFRecordIO.Read transform");
- }
-
- if (validate) {
- checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
- try {
- checkState(
- !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
- "Unable to find any files matching %s",
- filepattern);
- } catch (IOException e) {
- throw new IllegalStateException(
- String.format("Failed to validate %s", filepattern.get()), e);
- }
- }
-
- final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
- PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
- // Honor the default output coder that would have been used by this PTransform.
- pcol.setCoder(getDefaultOutputCoder());
- return pcol;
- }
-
- // Helper to create a source specific to the requested compression type.
- protected FileBasedSource<byte[]> getSource() {
- switch (compressionType) {
- case NONE:
- return new TFRecordSource(filepattern);
- case AUTO:
- return CompressedSource.from(new TFRecordSource(filepattern));
- case GZIP:
- return
- CompressedSource.from(new TFRecordSource(filepattern))
- .withDecompression(CompressedSource.CompressionMode.GZIP);
- case ZLIB:
- return
- CompressedSource.from(new TFRecordSource(filepattern))
- .withDecompression(CompressedSource.CompressionMode.DEFLATE);
- default:
- throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+ String.format("Failed to validate %s", filepattern.get()), e);
}
}
- @Override
- public void populateDisplayData(DisplayData.Builder builder) {
- super.populateDisplayData(builder);
+ final Bounded<byte[]> read = org.apache.beam.sdk.io.Read.from(getSource());
+ PCollection<byte[]> pcol = input.getPipeline().apply("Read", read);
+ // Honor the default output coder that would have been used by this PTransform.
+ pcol.setCoder(getDefaultOutputCoder());
+ return pcol;
+ }
- String filepatternDisplay = filepattern.isAccessible()
- ? filepattern.get() : filepattern.toString();
- builder
- .add(DisplayData.item("compressionType", compressionType.toString())
- .withLabel("Compression Type"))
- .addIfNotDefault(DisplayData.item("validation", validate)
- .withLabel("Validation Enabled"), true)
- .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
- .withLabel("File Pattern"));
+ // Helper to create a source specific to the requested compression type.
+ protected FileBasedSource<byte[]> getSource() {
+ switch (compressionType) {
+ case NONE:
+ return new TFRecordSource(filepattern);
+ case AUTO:
+ return CompressedSource.from(new TFRecordSource(filepattern));
+ case GZIP:
+ return
+ CompressedSource.from(new TFRecordSource(filepattern))
+ .withDecompression(CompressedSource.CompressionMode.GZIP);
+ case ZLIB:
+ return
+ CompressedSource.from(new TFRecordSource(filepattern))
+ .withDecompression(CompressedSource.CompressionMode.DEFLATE);
+ default:
+ throw new IllegalArgumentException("Unknown compression type: " + compressionType);
}
+ }
- @Override
- protected Coder<byte[]> getDefaultOutputCoder() {
- return DEFAULT_BYTE_ARRAY_CODER;
- }
+ @Override
+ public void populateDisplayData(DisplayData.Builder builder) {
+ super.populateDisplayData(builder);
- public String getFilepattern() {
- return filepattern.get();
- }
+ String filepatternDisplay = filepattern.isAccessible()
+ ? filepattern.get() : filepattern.toString();
+ builder
+ .add(DisplayData.item("compressionType", compressionType.toString())
+ .withLabel("Compression Type"))
+ .addIfNotDefault(DisplayData.item("validation", validate)
+ .withLabel("Validation Enabled"), true)
+ .addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
+ .withLabel("File Pattern"));
+ }
- public boolean needsValidation() {
- return validate;
- }
+ @Override
+ protected Coder<byte[]> getDefaultOutputCoder() {
+ return DEFAULT_BYTE_ARRAY_CODER;
+ }
- public TFRecordIO.CompressionType getCompressionType() {
- return compressionType;
- }
+ public String getFilepattern() {
+ return filepattern.get();
}
- /** Disallow construction of utility classes. */
- private Read() {}
+ public boolean needsValidation() {
+ return validate;
+ }
+
+ public TFRecordIO.CompressionType getCompressionType() {
+ return compressionType;
+ }
}
/////////////////////////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/beam/blob/88d2c69c/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
index 94530ef..2a455d1 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TFRecordIOTest.java
@@ -136,15 +136,15 @@ public class TFRecordIOTest {
assertEquals(
"TFRecordIO.Read/Read.out",
- p.apply(TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+ p.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName());
assertEquals(
"MyRead/Read.out",
- p.apply("MyRead", TFRecordIO.Read.withoutValidation().from("foo.*")).getName());
+ p.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName());
}
@Test
public void testReadDisplayData() {
- TFRecordIO.Read.Bound read = TFRecordIO.Read
+ TFRecordIO.Read read = TFRecordIO.read()
.from("foo.*")
.withCompressionType(GZIP)
.withoutValidation();
@@ -241,7 +241,7 @@ public class TFRecordIOTest {
fos.write(data);
fos.close();
- TFRecordIO.Read.Bound read = TFRecordIO.Read.from(filename);
+ TFRecordIO.Read read = TFRecordIO.read().from(filename);
PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString()));
PAssert.that(output).containsInAnyOrder(expected);
@@ -338,7 +338,7 @@ public class TFRecordIOTest {
.apply(write);
p.run();
- TFRecordIO.Read.Bound read = TFRecordIO.Read.from(baseFilename + "*")
+ TFRecordIO.Read read = TFRecordIO.read().from(baseFilename + "*")
.withCompressionType(readCompressionType);
PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString()));