You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this plain-text copy.)
Posted to commits@beam.apache.org by jk...@apache.org on 2017/04/18 21:21:30 UTC
[4/6] beam git commit: Converts TFRecordIO.Read to AutoValue
Converts TFRecordIO.Read to AutoValue
Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/f5d92a49
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/f5d92a49
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/f5d92a49
Branch: refs/heads/master
Commit: f5d92a49c91c4555f6ad527425077f92bb2bb377
Parents: 88d2c69
Author: Eugene Kirpichov <ki...@google.com>
Authored: Mon Apr 17 17:29:06 2017 -0700
Committer: Eugene Kirpichov <ki...@google.com>
Committed: Tue Apr 18 14:03:42 2017 -0700
----------------------------------------------------------------------
.../java/org/apache/beam/sdk/io/TFRecordIO.java | 99 +++++++++-----------
1 file changed, 44 insertions(+), 55 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/beam/blob/f5d92a49/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
index b4fd93c..fb4ff5b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/TFRecordIO.java
@@ -20,6 +20,7 @@ package org.apache.beam.sdk.io;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
+import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.hash.HashFunction;
import com.google.common.hash.Hashing;
@@ -59,19 +60,32 @@ public class TFRecordIO {
* the decoding of each of the records of the TFRecord file(s) as a byte array.
*/
public static Read read() {
- return new Read();
+ return new AutoValue_TFRecordIO_Read.Builder()
+ .setValidate(true)
+ .setCompressionType(CompressionType.AUTO)
+ .build();
}
/** Implementation of {@link #read}. */
- public static class Read extends PTransform<PBegin, PCollection<byte[]>> {
- /** The filepattern to read from. */
- @Nullable private final ValueProvider<String> filepattern;
+ @AutoValue
+ public abstract static class Read extends PTransform<PBegin, PCollection<byte[]>> {
+ @Nullable
+ abstract ValueProvider<String> getFilepattern();
- /** An option to indicate if input validation is desired. Default is true. */
- private final boolean validate;
+ abstract boolean getValidate();
- /** Option to indicate the input source's compression type. Default is AUTO. */
- private final TFRecordIO.CompressionType compressionType;
+ abstract CompressionType getCompressionType();
+
+ abstract Builder toBuilder();
+
+ @AutoValue.Builder
+ abstract static class Builder {
+ abstract Builder setFilepattern(ValueProvider<String> filepattern);
+ abstract Builder setValidate(boolean validate);
+ abstract Builder setCompressionType(CompressionType compressionType);
+
+ abstract Read build();
+ }
/**
* Returns a transform for reading TFRecord files that reads from the file(s)
@@ -82,14 +96,14 @@ public class TFRecordIO {
* >Java Filesystem glob patterns</a> ("*", "?", "[..]") are supported.
*/
public Read from(String filepattern) {
- return new Read().from(filepattern);
+ return from(StaticValueProvider.of(filepattern));
}
/**
* Same as {@code from(filepattern)}, but accepting a {@link ValueProvider}.
*/
public Read from(ValueProvider<String> filepattern) {
- return new Read().from(filepattern);
+ return toBuilder().setFilepattern(filepattern).build();
}
/**
@@ -101,7 +115,7 @@ public class TFRecordIO {
* available at execution time.
*/
public Read withoutValidation() {
- return new Read().withoutValidation();
+ return toBuilder().setValidate(false).build();
}
/**
@@ -115,41 +129,28 @@ public class TFRecordIO {
* extensions are uncompressed).
*/
public Read withCompressionType(TFRecordIO.CompressionType compressionType) {
- return new Read().withCompressionType(compressionType);
- }
-
- private Read() {
- this(null, null, true, TFRecordIO.CompressionType.AUTO);
- }
-
- private Read(
- @Nullable String name,
- @Nullable ValueProvider<String> filepattern,
- boolean validate,
- TFRecordIO.CompressionType compressionType) {
- super(name);
- this.filepattern = filepattern;
- this.validate = validate;
- this.compressionType = compressionType;
+ return toBuilder().setCompressionType(compressionType).build();
}
@Override
public PCollection<byte[]> expand(PBegin input) {
- if (filepattern == null) {
+ if (getFilepattern() == null) {
throw new IllegalStateException(
"Need to set the filepattern of a TFRecordIO.Read transform");
}
- if (validate) {
- checkState(filepattern.isAccessible(), "Cannot validate with a RVP.");
+ if (getValidate()) {
+ checkState(getFilepattern().isAccessible(), "Cannot validate with a RVP.");
try {
checkState(
- !IOChannelUtils.getFactory(filepattern.get()).match(filepattern.get()).isEmpty(),
+ !IOChannelUtils.getFactory(getFilepattern().get())
+ .match(getFilepattern().get())
+ .isEmpty(),
"Unable to find any files matching %s",
- filepattern);
+ getFilepattern());
} catch (IOException e) {
throw new IllegalStateException(
- String.format("Failed to validate %s", filepattern.get()), e);
+ String.format("Failed to validate %s", getFilepattern().get()), e);
}
}
@@ -162,21 +163,21 @@ public class TFRecordIO {
// Helper to create a source specific to the requested compression type.
protected FileBasedSource<byte[]> getSource() {
- switch (compressionType) {
+ switch (getCompressionType()) {
case NONE:
- return new TFRecordSource(filepattern);
+ return new TFRecordSource(getFilepattern());
case AUTO:
- return CompressedSource.from(new TFRecordSource(filepattern));
+ return CompressedSource.from(new TFRecordSource(getFilepattern()));
case GZIP:
return
- CompressedSource.from(new TFRecordSource(filepattern))
+ CompressedSource.from(new TFRecordSource(getFilepattern()))
.withDecompression(CompressedSource.CompressionMode.GZIP);
case ZLIB:
return
- CompressedSource.from(new TFRecordSource(filepattern))
+ CompressedSource.from(new TFRecordSource(getFilepattern()))
.withDecompression(CompressedSource.CompressionMode.DEFLATE);
default:
- throw new IllegalArgumentException("Unknown compression type: " + compressionType);
+ throw new IllegalArgumentException("Unknown compression type: " + getCompressionType());
}
}
@@ -184,12 +185,12 @@ public class TFRecordIO {
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
- String filepatternDisplay = filepattern.isAccessible()
- ? filepattern.get() : filepattern.toString();
+ String filepatternDisplay = getFilepattern().isAccessible()
+ ? getFilepattern().get() : getFilepattern().toString();
builder
- .add(DisplayData.item("compressionType", compressionType.toString())
+ .add(DisplayData.item("compressionType", getCompressionType().toString())
.withLabel("Compression Type"))
- .addIfNotDefault(DisplayData.item("validation", validate)
+ .addIfNotDefault(DisplayData.item("validation", getValidate())
.withLabel("Validation Enabled"), true)
.addIfNotNull(DisplayData.item("filePattern", filepatternDisplay)
.withLabel("File Pattern"));
@@ -199,18 +200,6 @@ public class TFRecordIO {
protected Coder<byte[]> getDefaultOutputCoder() {
return DEFAULT_BYTE_ARRAY_CODER;
}
-
- public String getFilepattern() {
- return filepattern.get();
- }
-
- public boolean needsValidation() {
- return validate;
- }
-
- public TFRecordIO.CompressionType getCompressionType() {
- return compressionType;
- }
}
/////////////////////////////////////////////////////////////////////////////