You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by dh...@apache.org on 2016/07/19 00:08:01 UTC
[4/4] incubator-beam git commit: Dynamically choose number of shards
in the DirectRunner
Dynamically choose number of shards in the DirectRunner
Add a Write override factory that limits the number of shards when it
is unspecified. This ensures that we do not write one output file per
key as a result of bundling.
It does so by counting the input elements and choosing the number of
shards based on that count.
Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/1535d732
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/1535d732
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/1535d732
Branch: refs/heads/master
Commit: 1535d732fe384a72e5ea46c065779a22dad1e2f6
Parents: 674d831
Author: Thomas Groh <tg...@google.com>
Authored: Wed Jul 13 14:26:10 2016 -0700
Committer: Dan Halperin <dh...@google.com>
Committed: Mon Jul 18 17:07:41 2016 -0700
----------------------------------------------------------------------
.../beam/runners/direct/DirectRunner.java | 2 +
.../direct/WriteWithShardingFactory.java | 141 +++++++++
.../direct/WriteWithShardingFactoryTest.java | 285 +++++++++++++++++++
.../java/org/apache/beam/sdk/io/TextIOTest.java | 24 +-
.../java/org/apache/beam/sdk/io/WriteTest.java | 4 +-
5 files changed, 444 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
index 7408c0b..7fd38c2 100644
--- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectRunner.java
@@ -24,6 +24,7 @@ import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.io.Write;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AggregatorPipelineExtractor;
import org.apache.beam.sdk.runners.AggregatorRetrievalException;
@@ -78,6 +79,7 @@ public class DirectRunner
ImmutableMap.<Class<? extends PTransform>, PTransformOverrideFactory>builder()
.put(GroupByKey.class, new DirectGroupByKeyOverrideFactory())
.put(CreatePCollectionView.class, new ViewOverrideFactory())
+ .put(Write.Bound.class, new WriteWithShardingFactory())
.build();
/**
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
new file mode 100644
index 0000000..93f2408
--- /dev/null
+++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/WriteWithShardingFactory.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.io.Write.Bound;
+import org.apache.beam.sdk.transforms.Count;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
+import org.apache.beam.sdk.transforms.GroupByKey;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.Values;
+import org.apache.beam.sdk.transforms.windowing.DefaultTrigger;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
+import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+import org.apache.beam.sdk.values.PInput;
+import org.apache.beam.sdk.values.POutput;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.joda.time.Duration;
+
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * A {@link PTransformOverrideFactory} that overrides {@link Write} {@link PTransform PTransforms}
+ * with an unspecified number of shards. The replacement groups the input under a bounded number of
+ * integer shard keys and then expands the original {@link Write} on the regrouped input, so that
+ * the number of output files is not dictated by how the runner bundles elements.
+ *
+ * <p>For {@code n} input records the number of shard keys is {@code n} when
+ * {@code n < KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + extra}, and otherwise
+ * {@code max(floor(log10(n)), KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG) + extra}, where {@code extra}
+ * is drawn once per expansion from {@code [0, MAX_RANDOM_EXTRA_SHARDS)} (that is, up to 2
+ * additional shards).
+ */
+class WriteWithShardingFactory implements PTransformOverrideFactory {
+  /** Exclusive upper bound on the random number of extra shards added to the computed count. */
+  static final int MAX_RANDOM_EXTRA_SHARDS = 3;
+
+  /**
+   * Returns a {@link DynamicallyReshardedWrite} in place of {@code transform} when it is a
+   * {@link Write.Bound} with an unspecified ({@code 0}) number of shards, and returns
+   * {@code transform} unchanged otherwise.
+   */
+  @Override
+  @SuppressWarnings("unchecked")
+  public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+      PTransform<InputT, OutputT> transform) {
+    if (transform instanceof Write.Bound) {
+      // Unchecked casts: the Write.Bound and its replacement both consume a PCollection and
+      // produce PDone, so the replacement can stand in for the original at this call site.
+      Write.Bound<InputT> that = (Write.Bound<InputT>) transform;
+      if (that.getNumShards() == 0) {
+        return (PTransform<InputT, OutputT>) new DynamicallyReshardedWrite<InputT>(that);
+      }
+    }
+    return transform;
+  }
+
+  /**
+   * Regroups its input under a bounded number of shard keys (see {@link KeyBasedOnCountFn}) and
+   * then expands the wrapped {@link Write.Bound} on the regrouped elements.
+   */
+  private static class DynamicallyReshardedWrite<T> extends PTransform<PCollection<T>, PDone> {
+    private final transient Write.Bound<T> original;
+
+    private DynamicallyReshardedWrite(Bound<T> original) {
+      this.original = original;
+    }
+
+    @Override
+    public PDone apply(PCollection<T> input) {
+      // Move all elements into the global window so a single count covers the whole input.
+      PCollection<T> records = input.apply("RewindowInputs",
+          Window.<T>into(new GlobalWindows()).triggering(DefaultTrigger.of())
+              .withAllowedLateness(Duration.ZERO)
+              .discardingFiredPanes());
+      final PCollectionView<Long> numRecords =
+          records.apply(Count.<T>globally().asSingletonView());
+      // Drawn once per expansion and handed to the DoFn as a constant.
+      int extraShards = ThreadLocalRandom.current().nextInt(MAX_RANDOM_EXTRA_SHARDS);
+      PCollection<T> resharded =
+          records
+              .apply(
+                  "ApplySharding",
+                  ParDo.withSideInputs(numRecords)
+                      .of(new KeyBasedOnCountFn<T>(numRecords, extraShards)))
+              .apply("GroupIntoShards", GroupByKey.<Integer, T>create())
+              .apply("DropShardingKeys", Values.<Iterable<T>>create())
+              .apply("FlattenShardIterables", Flatten.<T>iterables());
+      // This is an inverted application to apply the expansion of the original Write PTransform
+      // without adding a new Write Transform Node, which would be overwritten the same way, leading
+      // to an infinite recursion. We cannot modify the number of shards, because that is determined
+      // at runtime.
+      return original.apply(resharded);
+    }
+  }
+
+  /**
+   * Keys each element with an integer shard in {@code [0, maxShards)}, assigned round-robin
+   * starting from a random shard. {@code maxShards} is computed from the side-input record count
+   * the first time this instance processes an element.
+   */
+  @VisibleForTesting
+  static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> {
+    @VisibleForTesting
+    static final int MIN_SHARDS_FOR_LOG = 3;
+
+    private final PCollectionView<Long> numRecords;
+    private final int randomExtraShards;
+    private int currentShard;
+    // 0 until the first element is processed; afterwards the number of shard keys in use.
+    private int maxShards;
+
+    KeyBasedOnCountFn(PCollectionView<Long> numRecords, int extraShards) {
+      this.numRecords = numRecords;
+      this.randomExtraShards = extraShards;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      if (maxShards == 0) {
+        maxShards = calculateShards(c.sideInput(numRecords));
+        currentShard = ThreadLocalRandom.current().nextInt(maxShards);
+      }
+      int shard = currentShard;
+      currentShard = (currentShard + 1) % maxShards;
+      c.output(KV.of(shard, c.element()));
+    }
+
+    /**
+     * Returns the number of shards for {@code totalRecords} records: one per record when there are
+     * fewer than {@code MIN_SHARDS_FOR_LOG + randomExtraShards}, and otherwise
+     * {@code max(floor(log10(totalRecords)), MIN_SHARDS_FOR_LOG) + randomExtraShards}.
+     *
+     * @throws IllegalArgumentException if {@code totalRecords} is not positive
+     */
+    private int calculateShards(long totalRecords) {
+      checkArgument(
+          totalRecords > 0,
+          "KeyBasedOnCountFn cannot be invoked on an element if there are no elements");
+      if (totalRecords < MIN_SHARDS_FOR_LOG + randomExtraShards) {
+        return (int) totalRecords;
+      }
+      // The truncating cast floors the (non-negative) logarithm; Math.log10(10^k) is exactly k.
+      // Fewer than 100 million records therefore yield a base of at most 7 shards.
+      int floorLogRecs = (int) Math.log10(totalRecords);
+      return Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + randomExtraShards;
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
----------------------------------------------------------------------
diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
new file mode 100644
index 0000000..a53bc64
--- /dev/null
+++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/WriteWithShardingFactoryTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.direct;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertThat;
+
+import org.apache.beam.runners.direct.WriteWithShardingFactory.KeyBasedOnCountFn;
+import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.io.Sink;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.io.Write;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.DoFnTester;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.util.IOChannelUtils;
+import org.apache.beam.sdk.util.PCollectionViews;
+import org.apache.beam.sdk.util.WindowingStrategy;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionView;
+import org.apache.beam.sdk.values.PDone;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.FileReader;
+import java.io.Reader;
+import java.nio.CharBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link WriteWithShardingFactory}: which transforms the factory replaces, the number of
+ * files produced by a resharded {@link TextIO} write, and the shard keys that
+ * {@link KeyBasedOnCountFn} assigns for various record counts and extra-shard settings.
+ */
+public class WriteWithShardingFactoryTest {
+  public static final int INPUT_SIZE = 10000;
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private final WriteWithShardingFactory factory = new WriteWithShardingFactory();
+
+  @Test
+  public void dynamicallyReshardedWrite() throws Exception {
+    List<String> strs = new ArrayList<>(INPUT_SIZE);
+    for (int i = 0; i < INPUT_SIZE; i++) {
+      strs.add(UUID.randomUUID().toString());
+    }
+    Collections.shuffle(strs);
+
+    String fileName = "resharded_write";
+    String outputPath = tmp.getRoot().getAbsolutePath();
+    String targetLocation = IOChannelUtils.resolve(outputPath, fileName);
+    TestPipeline p = TestPipeline.create();
+    // TextIO is implemented in terms of the Write PTransform. When sharding is not specified,
+    // resharding should be automatically applied
+    p.apply(Create.of(strs)).apply(TextIO.Write.to(targetLocation));
+
+    p.run();
+
+    Collection<String> files = IOChannelUtils.getFactory(outputPath).match(targetLocation + "*");
+    List<String> actuals = new ArrayList<>(strs.size());
+    for (String file : files) {
+      actuals.addAll(readNonEmptyLines(file));
+    }
+
+    assertThat(actuals, containsInAnyOrder(strs.toArray()));
+    // With INPUT_SIZE records the factory chooses fewer than
+    // log10(INPUT_SIZE) + MAX_RANDOM_EXTRA_SHARDS shards.
+    assertThat(
+        files,
+        hasSize(
+            allOf(
+                greaterThan(1),
+                lessThan(
+                    (int)
+                        (Math.log10(INPUT_SIZE)
+                            + WriteWithShardingFactory.MAX_RANDOM_EXTRA_SHARDS)))));
+  }
+
+  @Test
+  public void withShardingSpecifiesOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink()).withNumShards(3);
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNonWriteReturnsOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original =
+        new PTransform<PCollection<Object>, PDone>() {
+          @Override
+          public PDone apply(PCollection<Object> input) {
+            return PDone.in(input.getPipeline());
+          }
+        };
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNoShardingSpecifiedReturnsNewTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink());
+    assertThat(factory.override(original), not(equalTo(original)));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithOneElement() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 1L);
+
+    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar", "bazbar");
+    assertThat(
+        outputs, containsInAnyOrder(KV.of(0, "foo"), KV.of(0, "bar"), KV.of(0, "bazbar")));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithTwoElements() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 2L);
+
+    List<KV<Integer, String>> outputs = fnTester.processBundle("foo", "bar");
+    assertThat(
+        outputs,
+        anyOf(
+            containsInAnyOrder(KV.of(0, "foo"), KV.of(1, "bar")),
+            containsInAnyOrder(KV.of(1, "foo"), KV.of(0, "bar"))));
+  }
+
+  @Test
+  public void keyBasedOnCountFnFewElementsThreeShards() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, 100L);
+
+    List<KV<Integer, String>> outputs =
+        fnTester.processBundle("foo", "bar", "baz", "foobar", "foobaz", "barbaz");
+    assertThat(
+        Iterables.transform(
+            outputs,
+            new Function<KV<Integer, String>, Integer>() {
+              @Override
+              public Integer apply(KV<Integer, String> input) {
+                return input.getKey();
+              }
+            }),
+        containsInAnyOrder(0, 0, 1, 1, 2, 2));
+  }
+
+  @Test
+  public void keyBasedOnCountFnManyElements() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    double count = Math.pow(10, 10);
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
+
+    List<KV<Integer, String>> kvs = fnTester.processBundle(randomHexStrings(100));
+    // floor(log10(10^10)) = 10 shards, keyed 0 through 9.
+    assertThat(maxKey(kvs), equalTo(9L));
+  }
+
+  @Test
+  public void keyBasedOnCountFnFewElementsExtraShards() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 10);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    long countValue = (long) KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + 3;
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, countValue);
+
+    List<KV<Integer, String>> kvs = fnTester.processBundle(randomHexStrings(100));
+    // 0 to n-1 shard ids.
+    assertThat(maxKey(kvs), equalTo(countValue - 1));
+  }
+
+  @Test
+  public void keyBasedOnCountFnManyElementsExtraShards() throws Exception {
+    PCollectionView<Long> elementCountView = newElementCountView();
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 3);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    double count = Math.pow(10, 10);
+    fnTester.setSideInput(elementCountView, GlobalWindow.INSTANCE, (long) count);
+
+    List<KV<Integer, String>> kvs = fnTester.processBundle(randomHexStrings(100));
+    // 10 log-based shards plus 3 extra shards, keyed 0 through 12.
+    assertThat(maxKey(kvs), equalTo(12L));
+  }
+
+  /** Returns a fresh singleton view of a {@code Long} element count, used as a side input. */
+  private static PCollectionView<Long> newElementCountView() {
+    return PCollectionViews.singletonView(
+        TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+  }
+
+  /** Returns {@code count} random hexadecimal strings. */
+  private static List<String> randomHexStrings(int count) {
+    List<String> strings = new ArrayList<>(count);
+    for (int i = 0; i < count; i++) {
+      strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong()));
+    }
+    return strings;
+  }
+
+  /** Returns the largest shard key in {@code kvs}, or -1 if {@code kvs} is empty. */
+  private static long maxKey(List<KV<Integer, String>> kvs) {
+    long maxKey = -1L;
+    for (KV<Integer, String> kv : kvs) {
+      maxKey = Math.max(maxKey, kv.getKey());
+    }
+    return maxKey;
+  }
+
+  /**
+   * Reads {@code file} with the platform default charset and returns its non-empty lines, split on
+   * {@code '\n'}.
+   */
+  private static List<String> readNonEmptyLines(String file) throws Exception {
+    CharBuffer buf = CharBuffer.allocate((int) new File(file).length());
+    try (Reader reader = new FileReader(file)) {
+      // A single read() may return before the buffer is full, so keep reading until the buffer
+      // is full or the end of the file is reached.
+      while (buf.hasRemaining()) {
+        if (reader.read(buf) == -1) {
+          break;
+        }
+      }
+    }
+    buf.flip();
+    List<String> lines = new ArrayList<>();
+    for (String line : buf.toString().split("\n")) {
+      if (line.length() > 0) {
+        lines.add(line);
+      }
+    }
+    return lines;
+  }
+
+  private static class TestSink extends Sink<Object> {
+    @Override
+    public void validate(PipelineOptions options) {}
+
+    @Override
+    public WriteOperation<Object, ?> createWriteOperation(PipelineOptions options) {
+      throw new IllegalArgumentException("Should not be used");
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
index babb50a..a1f1f70 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/TextIOTest.java
@@ -23,7 +23,6 @@ import static org.apache.beam.sdk.TestUtils.NO_INTS_ARRAY;
import static org.apache.beam.sdk.TestUtils.NO_LINES_ARRAY;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasValue;
-
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.hamcrest.Matchers.startsWith;
@@ -239,14 +238,11 @@ public class TextIOTest {
} else if (numShards > 0) {
write = write.withNumShards(numShards).withShardNameTemplate(ShardNameTemplate.INDEX_OF_MAX);
}
- int numOutputShards = (numShards > 0) ? numShards : 1;
-
input.apply(write);
p.run();
- assertOutputFiles(elems, coder, numOutputShards, tmpFolder, outputName,
- write.getShardNameTemplate());
+ assertOutputFiles(elems, coder, numShards, tmpFolder, outputName, write.getShardNameTemplate());
}
public static <T> void assertOutputFiles(
@@ -258,11 +254,19 @@ public class TextIOTest {
String shardNameTemplate)
throws Exception {
List<File> expectedFiles = new ArrayList<>();
- for (int i = 0; i < numShards; i++) {
- expectedFiles.add(
- new File(
- rootLocation.getRoot(),
- IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+ if (numShards == 0) {
+ String pattern =
+ IOChannelUtils.resolve(rootLocation.getRoot().getAbsolutePath(), outputName + "*");
+ for (String expected : IOChannelUtils.getFactory(pattern).match(pattern)) {
+ expectedFiles.add(new File(expected));
+ }
+ } else {
+ for (int i = 0; i < numShards; i++) {
+ expectedFiles.add(
+ new File(
+ rootLocation.getRoot(),
+ IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+ }
}
List<String> actual = new ArrayList<>();
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/1535d732/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
----------------------------------------------------------------------
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
index 56643f2..0af0744 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/WriteTest.java
@@ -19,10 +19,10 @@ package org.apache.beam.sdk.io;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFrom;
-
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
@@ -146,7 +146,7 @@ public class WriteTest {
public void testEmptyWrite() {
runWrite(Collections.<String>emptyList(), IDENTITY_MAP);
// Note we did not request a sharded write, so runWrite will not validate the number of shards.
- assertEquals(1, numShards.intValue());
+ assertThat(numShards.intValue(), greaterThan(0));
}
/**