You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2016/07/19 00:08:01 UTC

[4/4] incubator-beam git commit: Dynamically choose number of shards in the DirectRunner

Dynamically choose number of shards in the DirectRunner

Add a Write Override Factory that limits the number of shards if
unspecified. This ensures that we will not write an output file per-key
due to bundling.

Do so by counting the input elements and deriving the number of
shards from that count.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1535d732
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1535d732
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1535d732

Branch: refs/heads/master
Commit: 1535d732fe384a72e5ea46c065779a22dad1e2f6
Parents: 674d831
Author: Thomas Groh <tg...@google.com>
Authored: Wed Jul 13 14:26:10 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Jul 18 17:07:41 2016 -0700

----------------------------------------------------------------------
 .../beam/runners/direct/DirectRunner.java       |   2 +
 .../direct/WriteWithShardingFactory.java        | 141 +++++++++
 .../direct/WriteWithShardingFactoryTest.java    | 285 +++++++++++++++++++
 .../java/org/apache/beam/sdk/io/TextIOTest.java |  24 +-
 .../java/org/apache/beam/sdk/io/WriteTest.java  |   4 +-
 5 files changed, 444 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 7408c0b..7fd38c2 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -24,6 +24,7 @@ import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.PipelineResult;
 import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.io.Write;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.runners.AggregatorPipelineExtractor;
 import org.apache.beam.sdk.runners.AggregatorRetrievalException;
@@ -78,6 +79,7 @@ public class DirectRunner
           ImmutableMap.<Class<? extends PTransform>, PTransformOverrideFactory>builder()
               .put(GroupByKey.class, new DirectGroupByKeyOverrideFactory())
               .put(CreatePCollectionView.class, new ViewOverrideFactory())
+              .put(Write.Bound.class, new WriteWithShardingFactory())
               .build();
 
   /**

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
new file mode 100644
index 0000000..93f2408
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.io.Write.Bound;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.joda.time.Duration;
+
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * A {@link PTransformOverrideFactory} that overrides {@link Write} {@link PTransform PTransforms}
+ * with an unspecified number of shards with a write with a specified number of shards. The number
+ * of shards is the log base 10 of the number of input records, with up to 2 additional shards.
+ */
+class WriteWithShardingFactory implements PTransformOverrideFactory {
+  static final int MAX_RANDOM_EXTRA_SHARDS = 3;
+
+  @Override
+  public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+      PTransform<InputT, OutputT> transform) {
+    if (transform instanceof Write.Bound) {
+      Write.Bound<InputT> that = (Write.Bound<InputT>) transform;
+      if (that.getNumShards() == 0) {
+        return (PTransform<InputT, OutputT>) new DynamicallyReshardedWrite<InputT>(that);
+      }
+    }
+    return transform;
+  }
+
+  private static class DynamicallyReshardedWrite <T> extends PTransform<PCollection<T>, PDone> {
+    private final transient Write.Bound<T> original;
+
+    private DynamicallyReshardedWrite(Bound<T> original) {
+      this.original = original;
+    }
+
+    @Override
+    public PDone apply(PCollection<T> input) {
+      PCollection<T> records = input.apply("RewindowInputs",
+          Window.<T>into(new GlobalWindows()).triggering(DefaultTrigger.of())
+              .withAllowedLateness(Duration.ZERO)
+              .discardingFiredPanes());
+      final PCollectionView<Long> numRecords = records.apply(Count.<T>globally().asSingletonView());
+      PCollection<T> resharded =
+          records
+              .apply(
+                  "ApplySharding",
+                  ParDo.withSideInputs(numRecords)
+                      .of(
+                          new KeyBasedOnCountFn<T>(
+                              numRecords,
+                              ThreadLocalRandom.current().nextInt(MAX_RANDOM_EXTRA_SHARDS))))
+              .apply("GroupIntoShards", GroupByKey.<Integer, T>create())
+              .apply("DropShardingKeys", Values.<Iterable<T>>create())
+              .apply("FlattenShardIterables", Flatten.<T>iterables());
+      // This is an inverted application to apply the expansion of the original Write PTransform
+      // without adding a new Write Transform Node, which would be overwritten the same way, leading
+      // to an infinite recursion. We cannot modify the number of shards, because that is determined
+      // at runtime.
+      return original.apply(resharded);
+    }
+  }
+
+  @VisibleForTesting
+  static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> {
+    @VisibleForTesting
+    static final int MIN_SHARDS_FOR_LOG = 3;
+
+    private final PCollectionView<Long> numRecords;
+    private final int randomExtraShards;
+    private int currentShard;
+    private int maxShards;
+
+    KeyBasedOnCountFn(PCollectionView<Long> numRecords, int extraShards) {
+      this.numRecords = numRecords;
+      this.randomExtraShards = extraShards;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      if (maxShards == 0L) {
+        maxShards = calculateShards(c.sideInput(numRecords));
+        currentShard = ThreadLocalRandom.current().nextInt(maxShards);
+      }
+      int shard = currentShard;
+      currentShard = (currentShard + 1) % maxShards;
+      c.output(KV.of(shard, c.element()));
+    }
+
+    private int calculateShards(long totalRecords) {
+      checkArgument(
+          totalRecords > 0,
+          "KeyBasedOnCountFn cannot be invoked on an element if there are no elements");
+      if (totalRecords < MIN_SHARDS_FOR_LOG + randomExtraShards) {
+        return (int) totalRecords;
+      }
+      // 100mil records before >7 output files
+      int floorLogRecs = Double.valueOf(Math.log10(totalRecords)).intValue();
+      int shards = Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + randomExtraShards;
+      return shards;
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
new file mode 100644
index 0000000..a53bc64
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.direct.WriteWithShardingFactory.KeyBasedOnCountFn;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.Sink;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.IOChannelUtils;
+import org.apache.beam.sdk.util.PCollectionViews;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.FileReader;
+import java.io.Reader;
+import java.nio.CharBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link WriteWithShardingFactory}.
+ */
+public class WriteWithShardingFactoryTest {
+  public static final int INPUT_SIZE = 10000;
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private WriteWithShardingFactory factory = new WriteWithShardingFactory();
+
+  @Test
+  public void dynamicallyReshardedWrite() throws Exception {
+    List<String> strs = new ArrayList<>(INPUT_SIZE);
+    for (int i = 0; i < INPUT_SIZE; i++) {
+      strs.add(UUID.randomUUID().toString());
+    }
+    Collections.shuffle(strs);
+
+    String fileName = "resharded_write";
+    String outputPath = tmp.getRoot().getAbsolutePath();
+    String targetLocation = IOChannelUtils.resolve(outputPath, fileName);
+    TestPipeline p = TestPipeline.create();
+    // TextIO is implemented in terms of the Write PTransform. When sharding is not specified,
+    // resharding should be automatically applied
+    p.apply(Create.of(strs)).apply(TextIO.Write.to(targetLocation));
+
+    p.run();
+
+    Collection<String> files = IOChannelUtils.getFactory(outputPath).match(targetLocation + "*");
+    List<String> actuals = new ArrayList<>(strs.size());
+    for (String file : files) {
+      CharBuffer buf = CharBuffer.allocate((int) new File(file).length());
+      try (Reader reader = new FileReader(file)) {
+        // A single read() may return fewer chars than requested; read until full or at EOF.
+        while (buf.hasRemaining() && reader.read(buf) != -1) {}
+        buf.flip();
+      }
+
+      for (String read : buf.toString().split("\n")) {
+        if (!read.isEmpty()) {
+          actuals.add(read);
+        }
+      }
+    }
+
+    assertThat(actuals, containsInAnyOrder(strs.toArray()));
+    // At most log10(INPUT_SIZE) + MAX_RANDOM_EXTRA_SHARDS - 1 shards are expected.
+    int maxShardsExclusive =
+        (int) (Math.log10(INPUT_SIZE) + WriteWithShardingFactory.MAX_RANDOM_EXTRA_SHARDS);
+    assertThat(files, hasSize(allOf(greaterThan(1), lessThan(maxShardsExclusive))));
+  }
+
+  @Test
+  public void withShardingSpecifiesOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink()).withNumShards(3);
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNonWriteReturnsOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original =
+        new PTransform<PCollection<Object>, PDone>() {
+          @Override
+          public PDone apply(PCollection<Object> input) {
+            return PDone.in(input.getPipeline());
+          }
+        };
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNoShardingSpecifiedReturnsNewTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink());
+    assertThat(factory.override(original), not(equalTo(original)));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithOneElement() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 1L);
+
+    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar", "bazbar");
+    // A count of 1 yields a single shard, so every element is keyed 0.
+    assertThat(
+        outputs, containsInAnyOrder(KV.of(0, "foo"), KV.of(0, "bar"), KV.of(0, "bazbar")));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithTwoElements() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 2L);
+
+    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar");
+    assertThat(
+        outputs,
+        anyOf(
+            containsInAnyOrder(KV.of(0, "foo"), KV.of(1, "bar")),
+            containsInAnyOrder(KV.of(1, "foo"), KV.of(0, "bar"))));
+  }
+
+  @Test
+  public void keyBasedOnCountFnFewElementsThreeShards() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    // floor(log10(100)) = 2 is below MIN_SHARDS_FOR_LOG, so exactly 3 shards are used.
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 100L);
+
+    List<KV<Integer, String>> outputs =
+        fnTester.processBundle("foo", "bar", "baz", "foobar", "foobaz", "barbaz");
+    assertThat(
+        Iterables.transform(
+            outputs,
+            new Function<KV<Integer, String>, Integer>() {
+              @Override
+              public Integer apply(KV<Integer, String> input) {
+                return input.getKey();
+              }
+            }),
+        containsInAnyOrder(0, 0, 1, 1, 2, 2));
+  }
+
+  @Test
+  public void keyBasedOnCountFnManyElements() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    // floor(log10(1e10)) = 10 shards with no extra shards, so keys span 0..9.
+    double count = Math.pow(10, 10);
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
+
+    List<String> strings = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
+    }
+    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
+    long maxKey = -1L;
+    for (KV<Integer, String> kv : kvs) {
+      maxKey = Math.max(maxKey, kv.getKey());
+    }
+    assertThat(maxKey, equalTo(9L));
+  }
+
+  @Test
+  public void keyBasedOnCountFnFewElementsExtraShards() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 10);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    long countValue = (long) KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + 3;
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, countValue);
+
+    List<String> strings = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
+    }
+    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
+    long maxKey = -1L;
+    for (KV<Integer, String> kv : kvs) {
+      maxKey = Math.max(maxKey, kv.getKey());
+    }
+    // 0 to n-1 shard ids.
+    assertThat(maxKey, equalTo(countValue - 1));
+  }
+
+  @Test
+  public void keyBasedOnCountFnManyElementsExtraShards() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 3);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    // floor(log10(1e10)) = 10 shards plus 3 extra shards, so keys span 0..12.
+    double count = Math.pow(10, 10);
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
+
+    List<String> strings = new ArrayList<>();
+    for (int i = 0; i < 100; i++) {
+      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
+    }
+    List<KV<Integer, String>> kvs = fnTester.processBundle(strings);
+    long maxKey = -1L;
+    for (KV<Integer, String> kv : kvs) {
+      maxKey = Math.max(maxKey, kv.getKey());
+    }
+    assertThat(maxKey, equalTo(12L));
+  }
+
+  /** A {@link Sink} stub; its write operation must never be created by these tests. */
+  private static class TestSink extends Sink<Object> {
+    @Override
+    public void validate(PipelineOptions options) {}
+
+    @Override
+    public WriteOperation<Object, ?> createWriteOperation(PipelineOptions options) {
+      throw new IllegalArgumentException("Should not be used");
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
index babb50a..a1f1f70 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
@@ -23,7 +23,6 @@ import static org.apache.beam.sdk.TestUtils.NO_INTS_ARRAY;
 import static org.apache.beam.sdk.TestUtils.NO_LINES_ARRAY;
 import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
 import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasValue;
-
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.startsWith;
@@ -239,14 +238,11 @@ public class TextIOTest {
     } else if (numShards > 0) {
       write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX);
     }
-    int numOutputShards = (numShards > 0) ? numShards : 1;
-
     input.apply(write);
 
     p.run();
 
-    assertOutputFiles(elems, coder, numOutputShards, tmpFolder, outputName,
-        write.getShardNameTemplate());
+    assertOutputFiles(elems, coder, numShards, tmpFolder, outputName, write.getShardNameTemplate());
   }
 
   public static <T> void assertOutputFiles(
@@ -258,11 +254,19 @@ public class TextIOTest {
       String shardNameTemplate)
       throws Exception {
     List<File> expectedFiles = new ArrayList<>();
-    for (int i = 0; i < numShards; i++) {
-      expectedFiles.add(
-          new File(
-              rootLocation.getRoot(),
-              IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+    if (numShards == 0) {
+      String pattern =
+          IOChannelUtils.resolve(rootLocation.getRoot().getAbsolutePath(), outputName + "*");
+      for (String expected : IOChannelUtils.getFactory(pattern).match(pattern)) {
+        expectedFiles.add(new File(expected));
+      }
+    } else {
+      for (int i = 0; i < numShards; i++) {
+        expectedFiles.add(
+            new File(
+                rootLocation.getRoot(),
+                IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+      }
     }
 
     List<String> actual = new ArrayList<>();

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
index 56643f2..0af0744 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
@@ -19,10 +19,10 @@ package org.apache.beam.sdk.io;
 
 import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
 import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFrom;
-
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
@@ -146,7 +146,7 @@ public class WriteTest {
   public void testEmptyWrite() {
     runWrite(Collections.<String>emptyList(), IDENTITY_MAP);
     // Note we did not request a sharded write, so runWrite will not validate the number of shards.
-    assertEquals(1, numShards.intValue());
+    assertThat(numShards.intValue(), greaterThan(0));
   }
 
   /**