You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by lc...@apache.org on 2016/04/11 22:20:56 UTC

[1/2] incubator-beam git commit: Add ShardControlledWrite override

Repository: incubator-beam
Updated Branches:
  refs/heads/master cd2a3a18f -> e1a471176


Add ShardControlledWrite override

This is used for TextIO and AvroIO, which provide withNumShards
methods to control the number of output files. Apply this override in
the InProcessPipelineRunner.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/a8ce5fd7
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/a8ce5fd7
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/a8ce5fd7

Branch: refs/heads/master
Commit: a8ce5fd7950a32271a30613c3c41679576b869ca
Parents: cd2a3a1
Author: Thomas Groh <tg...@google.com>
Authored: Tue Mar 29 10:06:02 2016 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Mon Apr 11 13:06:53 2016 -0700

----------------------------------------------------------------------
 .../inprocess/AvroIOShardedWriteFactory.java    |  76 +++++++++++++
 .../inprocess/InProcessPipelineRunner.java      |   4 +
 .../runners/inprocess/ShardControlledWrite.java |  81 ++++++++++++++
 .../inprocess/TextIOShardedWriteFactory.java    |  78 +++++++++++++
 .../cloud/dataflow/sdk/io/AvroIOTest.java       |  30 ++---
 .../cloud/dataflow/sdk/io/TextIOTest.java       |  33 +++---
 .../AvroIOShardedWriteFactoryTest.java          | 112 +++++++++++++++++++
 .../TextIOShardedWriteFactoryTest.java          | 112 +++++++++++++++++++
 8 files changed, 498 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
new file mode 100644
index 0000000..49576e5
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactory.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+
+/** Replaces explicitly sharded {@link AvroIO.Write.Bound}s with one unsharded write per shard. */
+class AvroIOShardedWriteFactory implements PTransformOverrideFactory {
+  /** Overrides writes to many shards, or to one shard with a non-empty shard name template. */
+  @Override
+  public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+      PTransform<InputT, OutputT> transform) {
+    if (transform instanceof AvroIO.Write.Bound) {
+      @SuppressWarnings("unchecked")
+      AvroIO.Write.Bound<InputT> originalWrite = (AvroIO.Write.Bound<InputT>) transform;
+      int numShards = originalWrite.getNumShards();
+      if (numShards > 1 || (numShards == 1 && !"".equals(originalWrite.getShardNameTemplate()))) {
+        @SuppressWarnings("unchecked")
+        PTransform<InputT, OutputT> override =
+            (PTransform<InputT, OutputT>) new AvroIOShardedWrite<InputT>(originalWrite);
+        return override;
+      }
+    }
+    return transform;
+  }
+  /** Writes every shard of the wrapped write as its own file. Static: needs no factory state. */
+  private static class AvroIOShardedWrite<InputT> extends ShardControlledWrite<InputT> {
+    private final AvroIO.Write.Bound<InputT> initial;
+
+    private AvroIOShardedWrite(AvroIO.Write.Bound<InputT> initial) {
+      this.initial = initial;
+    }
+    /** Returns the shard count configured on the wrapped write. */
+    @Override
+    int getNumShards() {
+      return initial.getNumShards();
+    }
+    /** Returns the write for one shard; the suffix is passed into the constructed name. */
+    @Override
+    PTransform<? super PCollection<InputT>, PDone> getSingleShardTransform(int shardNum) {
+      String shardName = IOChannelUtils.constructName(
+          initial.getFilenamePrefix(),
+          initial.getShardNameTemplate(),
+          initial.getFilenameSuffix(),
+          shardNum,
+          getNumShards());
+      return initial.withoutSharding().to(shardName).withSuffix("");
+    }
+    /** Returns the wrapped write as this transform's delegate. */
+    @Override
+    protected PTransform<PCollection<InputT>, PDone> delegate() {
+      return initial;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index fa93994..764ae09 100644
--- a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -21,6 +21,8 @@ import com.google.cloud.dataflow.sdk.Pipeline;
 import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException;
 import com.google.cloud.dataflow.sdk.PipelineResult;
 import com.google.cloud.dataflow.sdk.annotations.Experimental;
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.io.TextIO;
 import com.google.cloud.dataflow.sdk.options.PipelineOptions;
 import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor;
 import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
@@ -83,6 +85,8 @@ public class InProcessPipelineRunner
               .put(Create.Values.class, new InProcessCreateOverrideFactory())
               .put(GroupByKey.class, new InProcessGroupByKeyOverrideFactory())
               .put(CreatePCollectionView.class, new InProcessViewOverrideFactory())
+              .put(AvroIO.Write.Bound.class, new AvroIOShardedWriteFactory())
+              .put(TextIO.Write.Bound.class, new TextIOShardedWriteFactory())
               .build();
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
new file mode 100644
index 0000000..fc6419e
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/ShardControlledWrite.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.Partition;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionList;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * A write that partitions its input into a fixed number of shards and writes each one separately.
+ */
+abstract class ShardControlledWrite<InputT>
+    extends ForwardingPTransform<PCollection<InputT>, PDone> {
+  /** Partitions {@code input} into the requested shards and applies one write to each shard. */
+  @Override
+  public PDone apply(PCollection<InputT> input) {
+    int numShards = getNumShards();
+    checkArgument(
+        numShards >= 1,
+        "%s should only be applied if the output has a controlled number of shards (>= 1); got %s",
+        getClass().getSimpleName(),
+        numShards);
+    PCollectionList<InputT> shards =
+        input.apply(
+            "PartitionInto" + numShards + "Shards",
+            Partition.of(numShards, new RandomSeedPartitionFn<InputT>()));
+    for (int i = 0; i < shards.size(); i++) {
+      PCollection<InputT> shard = shards.get(i);
+      PTransform<? super PCollection<InputT>, PDone> writeShard = getSingleShardTransform(i);
+      shard.apply(String.format("%s(Shard:%s)", writeShard.getName(), i), writeShard);
+    }
+    return PDone.in(input.getPipeline());
+  }
+
+  /**
+   * Returns the number of shards this {@link PTransform} should write to.
+   */
+  abstract int getNumShards();
+
+  /**
+   * Returns a {@link PTransform} that performs a write to the shard with the specified shard
+   * number.
+   *
+   * <p>This method will be called n times, where n is the value of {@link #getNumShards()}, for
+   * shard numbers {@code [0...n)}.
+   */
+  abstract PTransform<? super PCollection<InputT>, PDone> getSingleShardTransform(int shardNum);
+  /** Deals elements to partitions round-robin, beginning at a randomly chosen offset. */
+  private static class RandomSeedPartitionFn<T> implements Partition.PartitionFn<T> {
+    private int nextPartition = -1;
+    @Override
+    public int partitionFor(T elem, int numPartitions) {
+      if (nextPartition < 0) {
+        nextPartition = ThreadLocalRandom.current().nextInt(numPartitions);
+      }
+      nextPartition = (nextPartition + 1) % numPartitions;
+      return nextPartition;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
new file mode 100644
index 0000000..af433c4
--- /dev/null
+++ b/sdks/java/core/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactory.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.io.TextIO.Write.Bound;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+
+/** Replaces explicitly sharded {@link TextIO.Write.Bound}s with one unsharded write per shard. */
+class TextIOShardedWriteFactory implements PTransformOverrideFactory {
+  /** Overrides writes to many shards, or to one shard with a non-empty shard name template. */
+  @Override
+  public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+      PTransform<InputT, OutputT> transform) {
+    if (transform instanceof TextIO.Write.Bound) {
+      @SuppressWarnings("unchecked")
+      TextIO.Write.Bound<InputT> originalWrite = (TextIO.Write.Bound<InputT>) transform;
+      int numShards = originalWrite.getNumShards();
+      if (numShards > 1 || (numShards == 1 && !"".equals(originalWrite.getShardNameTemplate()))) {
+        @SuppressWarnings("unchecked")
+        PTransform<InputT, OutputT> override =
+            (PTransform<InputT, OutputT>) new TextIOShardedWrite<InputT>(originalWrite);
+        return override;
+      }
+    }
+    return transform;
+  }
+  /** Writes every shard of the wrapped text write as its own file. */
+  private static class TextIOShardedWrite<InputT> extends ShardControlledWrite<InputT> {
+    private final TextIO.Write.Bound<InputT> initial;
+
+    private TextIOShardedWrite(Bound<InputT> initial) {
+      this.initial = initial;
+    }
+    /** Returns the shard count configured on the wrapped write. */
+    @Override
+    int getNumShards() {
+      return initial.getNumShards();
+    }
+    /** Returns a new unsharded write that targets the fully constructed name of one shard. */
+    @Override
+    PTransform<PCollection<InputT>, PDone> getSingleShardTransform(int shardNum) {
+      String shardName = IOChannelUtils.constructName(
+          initial.getFilenamePrefix(),
+          initial.getShardNameTemplate(),
+          initial.getFilenameSuffix(),
+          shardNum,
+          getNumShards());
+      // NOTE(review): only the coder is carried over from initial; confirm nothing else is needed.
+      return TextIO.Write.withCoder(initial.getCoder()).to(shardName).withoutSharding();
+    }
+    /** Returns the wrapped write as this transform's delegate. */
+    @Override
+    protected PTransform<PCollection<InputT>, PDone> delegate() {
+      return initial;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
index 7949b9f..f9e4dba 100644
--- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/AvroIOTest.java
@@ -206,35 +206,35 @@ public class AvroIOTest {
     TestPipeline p = TestPipeline.create();
     Bound<String> write = AvroIO.Write.to(outputFilePrefix).withSchema(String.class);
     if (numShards > 1) {
-      write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX);
+      write = write.withNumShards(numShards);
     } else {
       write = write.withoutSharding();
     }
     p.apply(Create.<String>of(expectedElements)).apply(write);
     p.run();
 
+    String shardNameTemplate = write.getShardNameTemplate();
+
+    assertTestOutputs(expectedElements, numShards, outputFilePrefix, shardNameTemplate);
+  }
+
+  public static void assertTestOutputs(
+      String[] expectedElements, int numShards, String outputFilePrefix, String shardNameTemplate)
+      throws IOException {
     // Validate that the data written matches the expected elements in the expected order
     List<File> expectedFiles = new ArrayList<>();
-    if (numShards == 1) {
-      expectedFiles.add(baseOutputFile);
-    } else {
-      for (int i = 0; i < numShards; i++) {
-        expectedFiles.add(
-            new File(
-                IOChannelUtils.constructName(
-                    outputFilePrefix,
-                    ShardNameTemplate.INDEX_OF_MAX,
-                    "" /* no suffix */,
-                    i,
-                    numShards)));
-      }
+    for (int i = 0; i < numShards; i++) {
+      expectedFiles.add(
+          new File(
+              IOChannelUtils.constructName(
+                  outputFilePrefix, shardNameTemplate, "" /* no suffix */, i, numShards)));
     }
 
     List<String> actualElements = new ArrayList<>();
     for (File outputFile : expectedFiles) {
       assertTrue("Expected output file " + outputFile.getName(), outputFile.exists());
       try (DataFileReader<String> reader =
-              new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
+          new DataFileReader<>(outputFile, AvroCoder.of(String.class).createDatumReader())) {
         Iterators.addAll(actualElements, reader);
       }
     }

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
index bc3a6fe..6a660bf 100644
--- a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
@@ -158,7 +158,8 @@ public class TextIOTest {
   }
 
   <T> void runTestWrite(T[] elems, Coder<T> coder, int numShards) throws Exception {
-    String filename = tmpFolder.newFile("file.txt").getPath();
+    String outputName = "file.txt";
+    String baseFilename = tmpFolder.newFile(outputName).getPath();
 
     Pipeline p = TestPipeline.create();
 
@@ -166,11 +167,11 @@ public class TextIOTest {
 
     TextIO.Write.Bound<T> write;
     if (coder.equals(StringUtf8Coder.of())) {
-      TextIO.Write.Bound<String> writeStrings = TextIO.Write.to(filename);
+      TextIO.Write.Bound<String> writeStrings = TextIO.Write.to(baseFilename);
       // T==String
       write = (TextIO.Write.Bound<T>) writeStrings;
     } else {
-      write = TextIO.Write.to(filename).withCoder(coder);
+      write = TextIO.Write.to(baseFilename).withCoder(coder);
     }
     if (numShards == 1) {
       write = write.withoutSharding();
@@ -182,17 +183,23 @@ public class TextIOTest {
 
     p.run();
 
+    assertOutputFiles(elems, coder, numShards, tmpFolder, outputName, write.getShardNameTemplate());
+  }
+
+  public static <T> void assertOutputFiles(
+      T[] elems,
+      Coder<T> coder,
+      int numShards,
+      TemporaryFolder rootLocation,
+      String outputName,
+      String shardNameTemplate)
+      throws Exception {
     List<File> expectedFiles = new ArrayList<>();
-    if (numShards == 1) {
-      expectedFiles.add(new File(filename));
-    } else {
-      for (int i = 0; i < numShards; i++) {
-        expectedFiles.add(
-            new File(
-                tmpFolder.getRoot(),
-                IOChannelUtils.constructName(
-                    "file.txt", ShardNameTemplate.INDEX_OF_MAX, "", i, numShards)));
-      }
+    for (int i = 0; i < numShards; i++) {
+      expectedFiles.add(
+          new File(
+              rootLocation.getRoot(),
+              IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
     }
 
     List<String> actual = new ArrayList<>();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
new file mode 100644
index 0000000..a90ba7b
--- /dev/null
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/AvroIOShardedWriteFactoryTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.io.AvroIO;
+import com.google.cloud.dataflow.sdk.io.AvroIOTest;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.File;
+
+/**
+ * Verifies when {@link AvroIOShardedWriteFactory} swaps in a sharded write and what it produces.
+ */
+@RunWith(JUnit4.class)
+public class AvroIOShardedWriteFactoryTest {
+
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private AvroIOShardedWriteFactory factory;
+
+  @Before
+  public void setup() {
+    factory = new AvroIOShardedWriteFactory();
+  }
+
+  @Test
+  public void originalWithoutShardingReturnsOriginal() throws Exception {
+    File dest = tmp.newFile("foo");
+    PTransform<PCollection<String>, PDone> write =
+        AvroIO.Write.withSchema(String.class).to(dest.getAbsolutePath()).withoutSharding();
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // An explicitly unsharded write must come back as the very same instance.
+    assertThat(result, theInstance(write));
+  }
+
+  @Test
+  public void originalShardingNotSpecifiedReturnsOriginal() throws Exception {
+    File dest = tmp.newFile("foo");
+    PTransform<PCollection<String>, PDone> write =
+        AvroIO.Write.withSchema(String.class).to(dest.getAbsolutePath());
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write with no sharding configured must come back as the very same instance.
+    assertThat(result, theInstance(write));
+  }
+
+  @Test
+  public void originalShardedToOneReturnsExplicitlySharded() throws Exception {
+    File dest = tmp.newFile("foo");
+    AvroIO.Write.Bound<String> write =
+        AvroIO.Write.withSchema(String.class).to(dest.getAbsolutePath()).withNumShards(1);
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write pinned to one shard must be replaced by the factory.
+    assertThat(result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(write)));
+
+    String[] items = {"foo", "bar", "baz"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(items)).apply(result);
+    // Drop the file created by newFile before the pipeline runs.
+    dest.delete();
+
+    pipeline.run();
+    AvroIOTest.assertTestOutputs(items, 1, dest.getAbsolutePath(), write.getShardNameTemplate());
+  }
+
+  @Test
+  public void originalShardedToManyReturnsExplicitlySharded() throws Exception {
+    File dest = tmp.newFile("foo");
+    AvroIO.Write.Bound<String> write =
+        AvroIO.Write.withSchema(String.class).to(dest.getAbsolutePath()).withNumShards(3);
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write pinned to three shards must be replaced by the factory.
+    assertThat(result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(write)));
+
+    String[] items = {"foo", "bar", "baz", "spam", "ham", "eggs"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(items)).apply(result);
+    // Drop the file created by newFile before the pipeline runs.
+    dest.delete();
+    pipeline.run();
+    AvroIOTest.assertTestOutputs(items, 3, dest.getAbsolutePath(), write.getShardNameTemplate());
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/a8ce5fd7/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
new file mode 100644
index 0000000..4c08777
--- /dev/null
+++ b/sdks/java/core/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/TextIOShardedWriteFactoryTest.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.io.TextIOTest;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PDone;
+
+import org.hamcrest.Matchers;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import java.io.File;
+
+/**
+ * Verifies when {@link TextIOShardedWriteFactory} swaps in a sharded write and what it produces.
+ */
+@RunWith(JUnit4.class)
+public class TextIOShardedWriteFactoryTest {
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private TextIOShardedWriteFactory factory;
+
+  @Before
+  public void setup() {
+    factory = new TextIOShardedWriteFactory();
+  }
+
+  @Test
+  public void originalWithoutShardingReturnsOriginal() throws Exception {
+    File dest = tmp.newFile("foo");
+    PTransform<PCollection<String>, PDone> write =
+        TextIO.Write.to(dest.getAbsolutePath()).withoutSharding();
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // An explicitly unsharded write must come back as the very same instance.
+    assertThat(result, theInstance(write));
+  }
+
+  @Test
+  public void originalShardingNotSpecifiedReturnsOriginal() throws Exception {
+    File dest = tmp.newFile("foo");
+    PTransform<PCollection<String>, PDone> write = TextIO.Write.to(dest.getAbsolutePath());
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write with no sharding configured must come back as the very same instance.
+    assertThat(result, theInstance(write));
+  }
+
+  @Test
+  public void originalShardedToOneReturnsExplicitlySharded() throws Exception {
+    File dest = tmp.newFile("foo");
+    TextIO.Write.Bound<String> write =
+        TextIO.Write.to(dest.getAbsolutePath()).withNumShards(1);
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write pinned to one shard must be replaced by the factory.
+    assertThat(result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(write)));
+
+    String[] items = {"foo", "bar", "baz"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(items)).apply(result);
+    // Drop the file created by newFile before the pipeline runs.
+    dest.delete();
+
+    pipeline.run();
+    TextIOTest.assertOutputFiles(
+        items, StringUtf8Coder.of(), 1, tmp, "foo", write.getShardNameTemplate());
+  }
+
+  @Test
+  public void originalShardedToManyReturnsExplicitlySharded() throws Exception {
+    File dest = tmp.newFile("foo");
+    TextIO.Write.Bound<String> write = TextIO.Write.to(dest.getAbsolutePath()).withNumShards(3);
+    PTransform<PCollection<String>, PDone> result = factory.override(write);
+    // A write pinned to three shards must be replaced by the factory.
+    assertThat(result, not(Matchers.<PTransform<PCollection<String>, PDone>>equalTo(write)));
+
+    String[] items = {"foo", "bar", "baz", "spam", "ham", "eggs"};
+    TestPipeline pipeline = TestPipeline.create();
+    pipeline.apply(Create.<String>of(items)).apply(result);
+    // Drop the file created by newFile before the pipeline runs.
+    dest.delete();
+    pipeline.run();
+    TextIOTest.assertOutputFiles(
+        items, StringUtf8Coder.of(), 3, tmp, "foo", write.getShardNameTemplate());
+  }
+}


[2/2] incubator-beam git commit: [BEAM-22] This closes #148

Posted by lc...@apache.org.
[BEAM-22] This closes #148


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/e1a47117
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/e1a47117
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/e1a47117

Branch: refs/heads/master
Commit: e1a471176ab6fdfb5b1d1cdbcdcf8ba7ab726d00
Parents: cd2a3a1 a8ce5fd
Author: Luke Cwik <lc...@google.com>
Authored: Mon Apr 11 13:20:43 2016 -0700
Committer: Luke Cwik <lc...@google.com>
Committed: Mon Apr 11 13:20:43 2016 -0700

----------------------------------------------------------------------
 .../inprocess/AvroIOShardedWriteFactory.java    |  76 +++++++++++++
 .../inprocess/InProcessPipelineRunner.java      |   4 +
 .../runners/inprocess/ShardControlledWrite.java |  81 ++++++++++++++
 .../inprocess/TextIOShardedWriteFactory.java    |  78 +++++++++++++
 .../cloud/dataflow/sdk/io/AvroIOTest.java       |  30 ++---
 .../cloud/dataflow/sdk/io/TextIOTest.java       |  33 +++---
 .../AvroIOShardedWriteFactoryTest.java          | 112 +++++++++++++++++++
 .../TextIOShardedWriteFactoryTest.java          | 112 +++++++++++++++++++
 8 files changed, 498 insertions(+), 28 deletions(-)
----------------------------------------------------------------------