You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2016/04/11 22:20:56 UTC
[1/2] incubator-beam git commit: Add ShardControlledWrite override
Repository: incubator-beam
Updated Branches:
refs/heads/master cd2a3a18f -> e1a471176
Add ShardControlledWrite override
This is used for TextIO and AvroIO, which provide withNumOutputShards
methods to control the number of output files. Apply this override in
the InProcessPipelineRunner.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/a8ce5fd7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/a8ce5fd7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/a8ce5fd7
Branch: refs/heads/master
Commit: a8ce5fd7950a32271a30613c3c41679576b869ca
Parents: cd2a3a1
Author: Thomas Groh <tg...@google.com>
Authored: Tue Mar 29 10:06:02 2016 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Mon Apr 11 13:06:53 2016 -0700
----------------------------------------------------------------------
.../inprocess/AvroIOShardedWriteFactory.java | 76 +++++++++++++
.../inprocess/InProcessPipelineRunner.java | 4 +
.../runners/inprocess/ShardControlledWrite.java | 81 ++++++++++++++
.../inprocess/TextIOShardedWriteFactory.java | 78 +++++++++++++
.../cloud/dataflow/sdk/io/AvroIOTest.java | 30 ++---
.../cloud/dataflow/sdk/io/TextIOTest.java | 33 +++---
.../AvroIOShardedWriteFactoryTest.java | 112 +++++++++++++++++++
.../TextIOShardedWriteFactoryTest.java | 112 +++++++++++++++++++
8 files changed, 498 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
new file mode 100644
index 0000000..49576e5
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+
+class AvroIOShardedWriteFactory implements PTransformOverrideFactory {
+ @Override
+ public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+ PTransform<InputT, OutputT> transform) {
+ if (transform instanceof AvroIO.Write.Bound) {
+ @SuppressWarnings("unchecked")
+ AvroIO.Write.Bound<InputT> originalWrite = (AvroIO.Write.Bound<InputT>) transform;
+ if (originalWrite.getNumShards() > 1
+ || (originalWrite.getNumShards() == 1
+ && !"".equals(originalWrite.getShardNameTemplate()))) {
+ @SuppressWarnings("unchecked")
+ PTransform<InputT, OutputT> override =
+ (PTransform<InputT, OutputT>) new AvroIOShardedWrite<InputT>(originalWrite);
+ return override;
+ }
+ }
+ return transform;
+ }
+
+ private class AvroIOShardedWrite<InputT> extends ShardControlledWrite<InputT> {
+ private final AvroIO.Write.Bound<InputT> initial;
+
+ private AvroIOShardedWrite(AvroIO.Write.Bound<InputT> initial) {
+ this.initial = initial;
+ }
+
+ @Override
+ int getNumShards() {
+ return initial.getNumShards();
+ }
+
+ @Override
+ PTransform<? super PCollection<InputT>, PDone> getSingleShardTransform(int shardNum) {
+ String shardName =
+ IOChannelUtils.constructName(
+ initial.getFilenamePrefix(),
+ initial.getShardNameTemplate(),
+ initial.getFilenameSuffix(),
+ shardNum,
+ getNumShards());
+ return initial.withoutSharding().to(shardName).withSuffix("");
+ }
+
+ @Override
+ protected PTransform<PCollection<InputT>, PDone> delegate() {
+ return initial;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index fa93994..764ae09 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -21,6 +21,8 @@ import com.google.cloud.dataflow.sdk.Pipeline;
import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException;
import com.google.cloud.dataflow.sdk.PipelineResult;
import com.google.cloud.dataflow.sdk.annotations.Experimental;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor;
import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
@@ -83,6 +85,8 @@ public class InProcessPipelineRunner
.put(Create.Values.class, new InProcessCreateOverrideFactory())
.put(GroupByKey.class, new InProcessGroupByKeyOverrideFactory())
.put(CreatePCollectionView.class, new InProcessViewOverrideFactory())
+ .put(AvroIO.Write.Bound.class, new AvroIOShardedWriteFactory())
+ .put(TextIO.Write.Bound.class, new TextIOShardedWriteFactory())
.build();
/**
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
new file mode 100644
index 0000000..fc6419e
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.Partition;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * A write that explicitly controls its number of output shards.
+ */
+abstract class ShardControlledWrite<InputT>
+    extends ForwardingPTransform<PCollection<InputT>, PDone> {
+  /** Partitions {@code input} into {@link #getNumShards()} shards and writes each separately. */
+  @Override
+  public PDone apply(PCollection<InputT> input) {
+    int numShards = getNumShards();
+    checkArgument(
+        numShards >= 1,
+        "%s should only be applied if the output has a controlled number of shards (>= 1); got %s",
+        getClass().getSimpleName(), numShards);
+    PCollectionList<InputT> shards = input.apply(
+        "PartitionInto" + numShards + "Shards",
+        Partition.of(numShards, new RandomSeedPartitionFn<InputT>()));
+    for (int i = 0; i < shards.size(); i++) {
+      PCollection<InputT> shard = shards.get(i);
+      PTransform<? super PCollection<InputT>, PDone> writeShard = getSingleShardTransform(i);
+      shard.apply(String.format("%s(Shard:%s)", writeShard.getName(), i), writeShard);
+    }
+    return PDone.in(input.getPipeline());
+  }
+
+  /** Returns the number of shards this {@link PTransform} should write to. */
+  abstract int getNumShards();
+
+  /**
+   * Returns a {@link PTransform} that performs a write to the shard with the specified shard
+   * number.
+   *
+   * <p>This method will be called n times, where n is the value of {@link #getNumShards()}, for
+   * shard numbers {@code [0...n)}.
+   */
+  abstract PTransform<? super PCollection<InputT>, PDone> getSingleShardTransform(int shardNum);
+
+  /** Assigns elements to partitions round-robin, beginning at a randomly chosen offset. */
+  private static class RandomSeedPartitionFn<T> implements Partition.PartitionFn<T> {
+    private int nextPartition = -1;
+
+    @Override
+    public int partitionFor(T elem, int numPartitions) {
+      // A negative value means no element has been seen yet: pick the starting offset now.
+      if (nextPartition < 0) {
+        nextPartition = ThreadLocalRandom.current().nextInt(numPartitions);
+      }
+      nextPartition++;
+      nextPartition %= numPartitions;
+      return nextPartition;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
new file mode 100644
index 0000000..af433c4
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.io.TextIO.Write.Bound;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+
+class TextIOShardedWriteFactory implements PTransformOverrideFactory {
+
+ @Override
+ public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+ PTransform<InputT, OutputT> transform) {
+ if (transform instanceof TextIO.Write.Bound) {
+ @SuppressWarnings("unchecked")
+ TextIO.Write.Bound<InputT> originalWrite = (TextIO.Write.Bound<InputT>) transform;
+ if (originalWrite.getNumShards() > 1
+ || (originalWrite.getNumShards() == 1
+ && !"".equals(originalWrite.getShardNameTemplate()))) {
+ @SuppressWarnings("unchecked")
+ PTransform<InputT, OutputT> override =
+ (PTransform<InputT, OutputT>) new TextIOShardedWrite<InputT>(originalWrite);
+ return override;
+ }
+ }
+ return transform;
+ }
+
+ private static class TextIOShardedWrite<InputT> extends ShardControlledWrite<InputT> {
+ private final TextIO.Write.Bound<InputT> initial;
+
+ private TextIOShardedWrite(Bound<InputT> initial) {
+ this.initial = initial;
+ }
+
+ @Override
+ int getNumShards() {
+ return initial.getNumShards();
+ }
+
+ @Override
+ PTransform<PCollection<InputT>, PDone> getSingleShardTransform(int shardNum) {
+ String shardName =
+ IOChannelUtils.constructName(
+ initial.getFilenamePrefix(),
+ initial.getShardTemplate(),
+ initial.getFilenameSuffix(),
+ shardNum,
+ getNumShards());
+ return TextIO.Write.withCoder(initial.getCoder()).to(shardName).withoutSharding();
+ }
+
+ @Override
+ protected PTransform<PCollection<InputT>, PDone> delegate() {
+ return initial;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
index 7949b9f..f9e4dba 100644
--- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
@@ -206,35 +206,35 @@ public class AvroIOTest {
TestPipeline p = TestPipeline.create();
Bound<String> write = AvroIO.Write.to(outputFilePrefix).withSchema(String.class);
if (numShards > 1) {
- write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX);
+ write = write.withNumShards(numShards);
} else {
write = write.withoutSharding();
}
p.apply(Create.<String>of(expectedElements)).apply(write);
p.run();
+ String shardNameTemplate = write.getShardNameTemplate();
+
+ assertTestOutputs(expectedElements, numShards, outputFilePrefix, shardNameTemplate);
+ }
+
+ public static void assertTestOutputs(
+ String[] expectedElements, int numShards, String outputFilePrefix, String shardNameTemplate)
+ throws IOException {
// Validate that the data written matches the expected elements in the expected order
List<File> expectedFiles = new ArrayList<>();
- if (numShards == 1) {
- expectedFiles.add(baseOutputFile);
- } else {
- for (int i = 0; i < numShards; i++) {
- expectedFiles.add(
- new File(
- IOChannelUtils.constructName(
- outputFilePrefix,
- ShardNameTemplate.INDEX_OF_MAX,
- "" /* no suffix */,
- i,
- numShards)));
- }
+ for (int i = 0; i < numShards; i++) {
+ expectedFiles.add(
+ new File(
+ IOChannelUtils.constructName(
+ outputFilePrefix, shardNameTemplate, "" /* no suffix */, i, numShards)));
}
List<String> actualElements = new ArrayList<>();
for (File outputFile : expectedFiles) {
assertTrue("Expected output file " + outputFile.getName(), outputFile.exists());
try (DataFileReader<String> reader =
- new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
+ new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
Iterators.addAll(actualElements, reader);
}
}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
index bc3a6fe..6a660bf 100644
--- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
@@ -158,7 +158,8 @@ public class TextIOTest {
}
<T> void runTestWrite(T[] elems, Coder<T> coder, int numShards) throws Exception {
- String filename = tmpFolder.newFile("file.txt").getPath();
+ String outputName = "file.txt";
+ String baseFilename = tmpFolder.newFile(outputName).getPath();
Pipeline p = TestPipeline.create();
@@ -166,11 +167,11 @@ public class TextIOTest {
TextIO.Write.Bound<T> write;
if (coder.equals(StringUtf8Coder.of())) {
- TextIO.Write.Bound<String> writeStrings = TextIO.Write.to(filename);
+ TextIO.Write.Bound<String> writeStrings = TextIO.Write.to(baseFilename);
// T==String
write = (TextIO.Write.Bound<T>) writeStrings;
} else {
- write = TextIO.Write.to(filename).withCoder(coder);
+ write = TextIO.Write.to(baseFilename).withCoder(coder);
}
if (numShards == 1) {
write = write.withoutSharding();
@@ -182,17 +183,23 @@ public class TextIOTest {
p.run();
+ assertOutputFiles(elems, coder, numShards, tmpFolder, outputName, write.getShardNameTemplate());
+ }
+
+ public static <T> void assertOutputFiles(
+ T[] elems,
+ Coder<T> coder,
+ int numShards,
+ TemporaryFolder rootLocation,
+ String outputName,
+ String shardNameTemplate)
+ throws Exception {
List<File> expectedFiles = new ArrayList<>();
- if (numShards == 1) {
- expectedFiles.add(new File(filename));
- } else {
- for (int i = 0; i < numShards; i++) {
- expectedFiles.add(
- new File(
- tmpFolder.getRoot(),
- IOChannelUtils.constructName(
- "file.txt", ShardNameTemplate.INDEX_OF_MAX, "", i, numShards)));
- }
+ for (int i = 0; i < numShards; i++) {
+ expectedFiles.add(
+ new File(
+ rootLocation.getRoot(),
+ IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
}
List<String> actual = new ArrayList<>();
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
new file mode 100644
index 0000000..a90ba7b
--- /dev/null
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.io.AvroIOTest;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.File;
+
+/**
+ * Unit tests covering the overrides produced by {@link AvroIOShardedWriteFactory}.
+ */
+@RunWith(JUnit4.class)
+public class AvroIOShardedWriteFactoryTest {
+
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private AvroIOShardedWriteFactory overrideFactory;
+
+  @Before
+  public void setup() {
+    overrideFactory = new AvroIOShardedWriteFactory();
+  }
+
+  @Test
+  public void originalWithoutShardingReturnsOriginal() throws Exception {
+    String outputPath = tmp.newFile("foo").getAbsolutePath();
+    PTransform<PCollection<String>, PDone> unsharded =
+        AvroIO.Write.withSchema(String.class).to(outputPath).withoutSharding();
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(unsharded);
+    // An explicitly unsharded write must come back as the very same instance.
+    assertThat(result, theInstance(unsharded));
+  }
+
+  @Test
+  public void originalShardingNotSpecifiedReturnsOriginal() throws Exception {
+    String outputPath = tmp.newFile("foo").getAbsolutePath();
+    PTransform<PCollection<String>, PDone> defaultWrite =
+        AvroIO.Write.withSchema(String.class).to(outputPath);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(defaultWrite);
+    // A write that never specified sharding must come back as the very same instance.
+    assertThat(result, theInstance(defaultWrite));
+  }
+
+  @Test
+  public void originalShardedToOneReturnsExplicitlySharded() throws Exception {
+    File outputFile = tmp.newFile("foo");
+    String outputPath = outputFile.getAbsolutePath();
+    AvroIO.Write.Bound<String> singleShardWrite =
+        AvroIO.Write.withSchema(String.class).to(outputPath).withNumShards(1);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(singleShardWrite);
+    assertThat(
+        result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(singleShardWrite)));
+
+    String[] elements = {"foo", "bar", "baz"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(elements)).apply(result);
+    // newFile() created a placeholder at this path; remove it before the pipeline runs.
+    outputFile.delete();
+    pipeline.run();
+    AvroIOTest.assertTestOutputs(elements, 1, outputPath, singleShardWrite.getShardNameTemplate());
+  }
+
+  @Test
+  public void originalShardedToManyReturnsExplicitlySharded() throws Exception {
+    File outputFile = tmp.newFile("foo");
+    String outputPath = outputFile.getAbsolutePath();
+    AvroIO.Write.Bound<String> threeShardWrite =
+        AvroIO.Write.withSchema(String.class).to(outputPath).withNumShards(3);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(threeShardWrite);
+    assertThat(
+        result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(threeShardWrite)));
+
+    String[] elements = {"foo", "bar", "baz", "spam", "ham", "eggs"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(elements)).apply(result);
+    outputFile.delete();
+    pipeline.run();
+    AvroIOTest.assertTestOutputs(elements, 3, outputPath, threeShardWrite.getShardNameTemplate());
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
new file mode 100644
index 0000000..4c08777
--- /dev/null
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.io.TextIOTest;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.File;
+
+/**
+ * Unit tests covering the overrides produced by {@link TextIOShardedWriteFactory}.
+ */
+@RunWith(JUnit4.class)
+public class TextIOShardedWriteFactoryTest {
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private TextIOShardedWriteFactory overrideFactory;
+
+  @Before
+  public void setup() {
+    overrideFactory = new TextIOShardedWriteFactory();
+  }
+
+  @Test
+  public void originalWithoutShardingReturnsOriginal() throws Exception {
+    String outputPath = tmp.newFile("foo").getAbsolutePath();
+    PTransform<PCollection<String>, PDone> unsharded =
+        TextIO.Write.to(outputPath).withoutSharding();
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(unsharded);
+    // An explicitly unsharded write must come back as the very same instance.
+    assertThat(result, theInstance(unsharded));
+  }
+
+  @Test
+  public void originalShardingNotSpecifiedReturnsOriginal() throws Exception {
+    String outputPath = tmp.newFile("foo").getAbsolutePath();
+    PTransform<PCollection<String>, PDone> defaultWrite = TextIO.Write.to(outputPath);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(defaultWrite);
+    // A write that never specified sharding must come back as the very same instance.
+    assertThat(result, theInstance(defaultWrite));
+  }
+
+  @Test
+  public void originalShardedToOneReturnsExplicitlySharded() throws Exception {
+    File outputFile = tmp.newFile("foo");
+    TextIO.Write.Bound<String> singleShardWrite =
+        TextIO.Write.to(outputFile.getAbsolutePath()).withNumShards(1);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(singleShardWrite);
+    assertThat(
+        result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(singleShardWrite)));
+
+    String[] elements = {"foo", "bar", "baz"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(elements)).apply(result);
+
+    // newFile() created a placeholder at this path; remove it before the pipeline runs.
+    outputFile.delete();
+    pipeline.run();
+    TextIOTest.assertOutputFiles(
+        elements, StringUtf8Coder.of(), 1, tmp, "foo", singleShardWrite.getShardNameTemplate());
+  }
+
+  @Test
+  public void originalShardedToManyReturnsExplicitlySharded() throws Exception {
+    File outputFile = tmp.newFile("foo");
+    TextIO.Write.Bound<String> threeShardWrite =
+        TextIO.Write.to(outputFile.getAbsolutePath()).withNumShards(3);
+    PTransform<PCollection<String>, PDone> result = overrideFactory.override(threeShardWrite);
+    assertThat(
+        result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(threeShardWrite)));
+
+    String[] elements = {"foo", "bar", "baz", "spam", "ham", "eggs"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(elements)).apply(result);
+    outputFile.delete();
+    pipeline.run();
+    TextIOTest.assertOutputFiles(
+        elements, StringUtf8Coder.of(), 3, tmp, "foo", threeShardWrite.getShardNameTemplate());
+  }
+}
[2/2] incubator-beam git commit: [BEAM-22] This closes #148
Posted by lc...@apache.org.
[BEAM-22] This closes #148
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/e1a47117
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/e1a47117
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/e1a47117
Branch: refs/heads/master
Commit: e1a471176ab6fdfb5b1d1cdbcdcf8ba7ab726d00
Parents: cd2a3a1 a8ce5fd
Author: Luke Cwik <lc...@google.com>
Authored: Mon Apr 11 13:20:43 2016 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Mon Apr 11 13:20:43 2016 -0700
----------------------------------------------------------------------
.../inprocess/AvroIOShardedWriteFactory.java | 76 +++++++++++++
.../inprocess/InProcessPipelineRunner.java | 4 +
.../runners/inprocess/ShardControlledWrite.java | 81 ++++++++++++++
.../inprocess/TextIOShardedWriteFactory.java | 78 +++++++++++++
.../cloud/dataflow/sdk/io/AvroIOTest.java | 30 ++---
.../cloud/dataflow/sdk/io/TextIOTest.java | 33 +++---
.../AvroIOShardedWriteFactoryTest.java | 112 +++++++++++++++++++
.../TextIOShardedWriteFactoryTest.java | 112 +++++++++++++++++++
8 files changed, 498 insertions(+), 28 deletions(-)
----------------------------------------------------------------------