You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ro...@apache.org on 2020/11/24 18:30:21 UTC
[beam] branch master updated: Add an option to GroupIntoBatches to
output ShardedKeys. Update Dataflow pipeline translation accordingly.
This is an automated email from the ASF dual-hosted git repository.
robinyqiu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new b51d64e Add an option to GroupIntoBatches to output ShardedKeys. Update Dataflow pipeline translation accordingly.
new f4d889f Merge pull request #13208 from nehsyc/fix_override
b51d64e is described below
commit b51d64e0eee662a1cc75f1a558ef99c2e812813e
Author: sychen <sy...@google.com>
AuthorDate: Tue Oct 27 14:30:58 2020 -0700
Add an option to GroupIntoBatches to output ShardedKeys. Update Dataflow pipeline translation accordingly.
---
.../core/construction/PTransformMatchers.java | 2 +-
.../dataflow/DataflowPipelineTranslator.java | 14 ++-
.../beam/runners/dataflow/DataflowRunner.java | 29 ++++-
.../runners/dataflow/GroupIntoBatchesOverride.java | 132 ++++++++++++++++++---
.../beam/runners/dataflow/util/PropertyNames.java | 1 +
.../dataflow/DataflowPipelineTranslatorTest.java | 67 ++++++-----
.../beam/runners/dataflow/DataflowRunnerTest.java | 64 ++++++++--
.../beam/sdk/transforms/GroupIntoBatches.java | 70 ++++++++++-
.../beam/sdk/transforms/GroupIntoBatchesTest.java | 82 +++++++++++++
9 files changed, 392 insertions(+), 69 deletions(-)
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
index 0bfbc4b..781ad52 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/PTransformMatchers.java
@@ -487,7 +487,7 @@ public class PTransformMatchers {
return new PTransformMatcher() {
@Override
public boolean matches(AppliedPTransform<?, ?, ?> application) {
- return application.getTransform().getClass().equals(GroupIntoBatches.class);
+ return application.getTransform().getClass().equals(GroupIntoBatches.WithShardedKey.class);
}
@Override
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index a5c25e9..51f8fc2 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -654,10 +654,6 @@ public class DataflowPipelineTranslator {
if (value instanceof PValue) {
PValue pvalue = (PValue) value;
addInput(name, translator.asOutputReference(pvalue, translator.getProducer(pvalue)));
- if (value instanceof PCollection
- && translator.runner.doesPCollectionRequireAutoSharding((PCollection<?>) value)) {
- addInput(PropertyNames.ALLOWS_SHARDABLE_STATE, "true");
- }
} else {
throw new IllegalStateException("Input must be a PValue");
}
@@ -696,6 +692,16 @@ public class DataflowPipelineTranslator {
private void addOutput(String name, PValue value, Coder<?> valueCoder) {
translator.registerOutputName(value, name);
+ // If the output requires runner determined sharding, also append necessary input properties.
+ if (value instanceof PCollection
+ && translator.runner.doesPCollectionRequireAutoSharding((PCollection<?>) value)) {
+ addInput(PropertyNames.ALLOWS_SHARDABLE_STATE, "true");
+ // Currently we only allow auto-sharding to be enabled through the GroupIntoBatches
+ // transform. So we also add the following property which GroupIntoBatchesDoFn has, to allow
+ // the backend to perform graph optimization.
+ addInput(PropertyNames.PRESERVES_KEYS, "true");
+ }
+
Map<String, Object> properties = getProperties();
@Nullable List<Map<String, Object>> outputInfoList = null;
try {
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index 20e3272..428c998 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -494,7 +494,8 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
overridesBuilder.add(
PTransformOverride.of(
PTransformMatchers.groupWithShardableStates(),
- new GroupIntoBatchesOverride.StreamingGroupIntoBatchesOverrideFactory(this)));
+ new GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(
+ this)));
if (!fnApiEnabled) {
overridesBuilder
@@ -526,9 +527,14 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
// Replace GroupIntoBatches before the state/timer replacements below since
// GroupIntoBatches internally uses a stateful DoFn.
.add(
- PTransformOverride.of(
- PTransformMatchers.classEqualTo(GroupIntoBatches.class),
- new GroupIntoBatchesOverride.BatchGroupIntoBatchesOverrideFactory()));
+ PTransformOverride.of(
+ PTransformMatchers.classEqualTo(GroupIntoBatches.class),
+ new GroupIntoBatchesOverride.BatchGroupIntoBatchesOverrideFactory<>()))
+ .add(
+ PTransformOverride.of(
+ PTransformMatchers.classEqualTo(GroupIntoBatches.WithShardedKey.class),
+ new GroupIntoBatchesOverride
+ .BatchGroupIntoBatchesWithShardedKeyOverrideFactory<>()));
overridesBuilder
// State and timer pardos are implemented by expansion to GBK-then-ParDo
@@ -1281,8 +1287,19 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> {
}
void maybeRecordPCollectionWithAutoSharding(PCollection<?> pcol) {
- if (hasExperiment(options, "enable_streaming_auto_sharding")
- && !hasExperiment(options, "beam_fn_api")) {
+ if (hasExperiment(options, "beam_fn_api")) {
+ LOG.warn(
+ "Runner determined sharding not available in Dataflow for GroupIntoBatches for portable "
+ + "jobs. Default sharding will be applied.");
+ return;
+ }
+ if (!options.isEnableStreamingEngine()) {
+ LOG.warn(
+ "Runner determined sharding not available in Dataflow for GroupIntoBatches for Streaming "
+ + "Appliance jobs. Default sharding will be applied.");
+ return;
+ }
+ if (hasExperiment(options, "enable_streaming_auto_sharding")) {
pcollectionsRequiringAutoSharding.add(pcol);
}
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
index b6a13ef..ea92f9d 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/GroupIntoBatchesOverride.java
@@ -17,18 +17,24 @@
*/
package org.apache.beam.runners.dataflow;
+import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.UUID;
import org.apache.beam.runners.core.construction.PTransformReplacements;
import org.apache.beam.runners.core.construction.ReplacementOutputs;
+import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.PTransformOverrideFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;
@@ -48,7 +54,7 @@ public class GroupIntoBatchesOverride {
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
- new BatchGroupIntoBatches(transform.getTransform().getBatchSize()));
+ new BatchGroupIntoBatches<>(transform.getTransform().getBatchSize()));
}
@Override
@@ -92,54 +98,142 @@ public class GroupIntoBatchesOverride {
}
}
- static class StreamingGroupIntoBatchesOverrideFactory<K, V>
+ static class BatchGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
implements PTransformOverrideFactory<
- PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>> {
+ PCollection<KV<K, V>>,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+ GroupIntoBatches<K, V>.WithShardedKey> {
+
+ @Override
+ public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
+ getReplacementTransform(
+ AppliedPTransform<
+ PCollection<KV<K, V>>,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+ GroupIntoBatches<K, V>.WithShardedKey>
+ transform) {
+ return PTransformReplacement.of(
+ PTransformReplacements.getSingletonMainInput(transform),
+ new BatchGroupIntoBatchesWithShardedKey<>(transform.getTransform().getBatchSize()));
+ }
+
+ @Override
+ public Map<PCollection<?>, ReplacementOutput> mapOutputs(
+ Map<TupleTag<?>, PCollection<?>> outputs,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
+ return ReplacementOutputs.singleton(outputs, newOutput);
+ }
+ }
+
+ /**
+ * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for bounded Dataflow
+ * pipelines.
+ */
+ static class BatchGroupIntoBatchesWithShardedKey<K, V>
+ extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {
+
+ private final long batchSize;
+
+ private BatchGroupIntoBatchesWithShardedKey(long batchSize) {
+ this.batchSize = batchSize;
+ }
+
+ @Override
+ public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
+ return shardKeys(input).apply(new BatchGroupIntoBatches<>(batchSize));
+ }
+ }
+
+ static class StreamingGroupIntoBatchesWithShardedKeyOverrideFactory<K, V>
+ implements PTransformOverrideFactory<
+ PCollection<KV<K, V>>,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+ GroupIntoBatches<K, V>.WithShardedKey> {
private final DataflowRunner runner;
- StreamingGroupIntoBatchesOverrideFactory(DataflowRunner runner) {
+ StreamingGroupIntoBatchesWithShardedKeyOverrideFactory(DataflowRunner runner) {
this.runner = runner;
}
@Override
- public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>>
+ public PTransformReplacement<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>>
getReplacementTransform(
AppliedPTransform<
- PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>, GroupIntoBatches<K, V>>
+ PCollection<KV<K, V>>,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>>,
+ GroupIntoBatches<K, V>.WithShardedKey>
transform) {
return PTransformReplacement.of(
PTransformReplacements.getSingletonMainInput(transform),
- new StreamingGroupIntoBatches(runner, transform.getTransform()));
+ new StreamingGroupIntoBatchesWithShardedKey<>(
+ runner,
+ transform.getTransform(),
+ PTransformReplacements.getSingletonMainOutput(transform)));
}
@Override
public Map<PCollection<?>, ReplacementOutput> mapOutputs(
- Map<TupleTag<?>, PCollection<?>> outputs, PCollection<KV<K, Iterable<V>>> newOutput) {
+ Map<TupleTag<?>, PCollection<?>> outputs,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>> newOutput) {
return ReplacementOutputs.singleton(outputs, newOutput);
}
}
/**
- * Specialized implementation of {@link GroupIntoBatches} for unbounded Dataflow pipelines. The
- * override does the same thing as the original transform but additionally record the input to add
- * corresponding properties during the graph translation.
+ * Specialized implementation of {@link GroupIntoBatches.WithShardedKey} for unbounded Dataflow
+ * pipelines. The override does the same thing as the original transform but additionally records
+ * the output in order to append required step properties during the graph translation.
*/
- static class StreamingGroupIntoBatches<K, V>
- extends PTransform<PCollection<KV<K, V>>, PCollection<KV<K, Iterable<V>>>> {
+ static class StreamingGroupIntoBatchesWithShardedKey<K, V>
+ extends PTransform<PCollection<KV<K, V>>, PCollection<KV<ShardedKey<K>, Iterable<V>>>> {
private final transient DataflowRunner runner;
- private final GroupIntoBatches<K, V> original;
+ private final GroupIntoBatches<K, V>.WithShardedKey originalTransform;
+ private final transient PCollection<KV<ShardedKey<K>, Iterable<V>>> originalOutput;
- public StreamingGroupIntoBatches(DataflowRunner runner, GroupIntoBatches<K, V> original) {
+ public StreamingGroupIntoBatchesWithShardedKey(
+ DataflowRunner runner,
+ GroupIntoBatches<K, V>.WithShardedKey original,
+ PCollection<KV<ShardedKey<K>, Iterable<V>>> output) {
this.runner = runner;
- this.original = original;
+ this.originalTransform = original;
+ this.originalOutput = output;
}
@Override
- public PCollection<KV<K, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
- runner.maybeRecordPCollectionWithAutoSharding(input);
- return input.apply(original);
+ public PCollection<KV<ShardedKey<K>, Iterable<V>>> expand(PCollection<KV<K, V>> input) {
+ // Record the output PCollection of the original transform since the new output will be
+ // replaced by the original one when the replacement transform is wired to other nodes in the
+ // graph, although the old and the new outputs are effectively the same.
+ runner.maybeRecordPCollectionWithAutoSharding(originalOutput);
+ return input.apply(originalTransform);
}
}
+
+ private static final UUID workerUuid = UUID.randomUUID();
+
+ private static <K, V> PCollection<KV<ShardedKey<K>, V>> shardKeys(PCollection<KV<K, V>> input) {
+ KvCoder<K, V> inputCoder = (KvCoder<K, V>) input.getCoder();
+ org.apache.beam.sdk.coders.Coder<K> keyCoder =
+ (org.apache.beam.sdk.coders.Coder<K>) inputCoder.getCoderArguments().get(0);
+ org.apache.beam.sdk.coders.Coder<V> valueCoder =
+ (org.apache.beam.sdk.coders.Coder<V>) inputCoder.getCoderArguments().get(1);
+ return input
+ .apply(
+ "Shard Keys",
+ MapElements.via(
+ new SimpleFunction<KV<K, V>, KV<ShardedKey<K>, V>>() {
+ @Override
+ public KV<ShardedKey<K>, V> apply(KV<K, V> input) {
+ long tid = Thread.currentThread().getId();
+ ByteBuffer buffer = ByteBuffer.allocate(3 * Long.BYTES);
+ buffer.putLong(workerUuid.getMostSignificantBits());
+ buffer.putLong(workerUuid.getLeastSignificantBits());
+ buffer.putLong(tid);
+ return KV.of(ShardedKey.of(input.getKey(), buffer.array()), input.getValue());
+ }
+ }))
+ .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder));
+ }
}
diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
index dacfdc1..f48bcac 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/util/PropertyNames.java
@@ -64,6 +64,7 @@ public class PropertyNames {
public static final String VALUE = "value";
public static final String WINDOWING_STRATEGY = "windowing_strategy";
public static final String DISPLAY_DATA = "display_data";
+ public static final String PRESERVES_KEYS = "preserves_keys";
/**
* @deprecated Uses the incorrect terminology. {@link #RESTRICTION_ENCODING}. Should be removed
* once non FnAPI SplittableDoFn expansion for Dataflow is removed.
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
index 202c1e7..a140b1e 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslatorTest.java
@@ -78,6 +78,7 @@ import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.extensions.gcp.auth.TestCredential;
+import org.apache.beam.sdk.extensions.gcp.options.GcpOptions;
import org.apache.beam.sdk.extensions.gcp.util.GcsUtil;
import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath;
import org.apache.beam.sdk.io.FileSystems;
@@ -1120,17 +1121,28 @@ public class DataflowPipelineTranslatorTest implements Serializable {
assertEquals("CollectionToSingleton", collectionToSingletonStep.getKind());
}
- @Test
- public void testStreamingGroupIntoBatchesTranslation() throws Exception {
+ private Map<String, Object> runGroupIntoBatchesAndGetStepProperties(
+ Boolean withShardedKey, Boolean usesFnApi) throws IOException {
DataflowPipelineOptions options = buildPipelineOptions();
- options.setExperiments(Arrays.asList("enable_streaming_auto_sharding"));
+ options.setExperiments(
+ Arrays.asList(
+ "enable_streaming_auto_sharding",
+ GcpOptions.STREAMING_ENGINE_EXPERIMENT,
+ GcpOptions.WINDMILL_SERVICE_EXPERIMENT));
+ if (usesFnApi) {
+ options.setExperiments(Arrays.asList("beam_fn_api"));
+ }
options.setStreaming(true);
DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
Pipeline pipeline = Pipeline.create(options);
- pipeline
- .apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))))
- .apply(GroupIntoBatches.ofSize(2));
+ PCollection<KV<Integer, String>> input =
+ pipeline.apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))));
+ if (withShardedKey) {
+ input.apply(GroupIntoBatches.<Integer, String>ofSize(3).withShardedKey());
+ } else {
+ input.apply(GroupIntoBatches.ofSize(3));
+ }
DataflowRunner runner = DataflowRunner.fromOptions(options);
runner.replaceTransforms(pipeline);
@@ -1142,40 +1154,37 @@ public class DataflowPipelineTranslatorTest implements Serializable {
.getJob();
List<Step> steps = job.getSteps();
Step shardedStateStep = steps.get(steps.size() - 1);
- Map<String, Object> properties = shardedStateStep.getProperties();
+ return shardedStateStep.getProperties();
+ }
+
+ @Test
+ public void testStreamingGroupIntoBatchesTranslation() throws Exception {
+ Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(false, false);
+ assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
+ assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
+ assertFalse(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
+ assertFalse(properties.containsKey(PropertyNames.PRESERVES_KEYS));
+ }
+
+ @Test
+ public void testStreamingGroupIntoBatchesWithShardedKeyTranslation() throws Exception {
+ Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(true, false);
assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
assertTrue(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
assertEquals("true", getString(properties, PropertyNames.ALLOWS_SHARDABLE_STATE));
+ assertTrue(properties.containsKey(PropertyNames.PRESERVES_KEYS));
+ assertEquals("true", getString(properties, PropertyNames.PRESERVES_KEYS));
}
@Test
public void testStreamingGroupIntoBatchesTranslationFnApi() throws Exception {
- DataflowPipelineOptions options = buildPipelineOptions();
- options.setExperiments(Arrays.asList("enable_windmill_auto_sharding"));
- options.setExperiments(Arrays.asList("beam_fn_api"));
- options.setStreaming(true);
- DataflowPipelineTranslator translator = DataflowPipelineTranslator.fromOptions(options);
-
- Pipeline pipeline = Pipeline.create(options);
- pipeline
- .apply(Create.of(Arrays.asList(KV.of(1, "1"), KV.of(2, "2"), KV.of(3, "3"))))
- .apply(GroupIntoBatches.ofSize(2));
-
- DataflowRunner runner = DataflowRunner.fromOptions(options);
- runner.replaceTransforms(pipeline);
- SdkComponents sdkComponents = createSdkComponents(options);
- RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline, sdkComponents, true);
- Job job =
- translator
- .translate(pipeline, pipelineProto, sdkComponents, runner, Collections.emptyList())
- .getJob();
- List<Step> steps = job.getSteps();
- Step shardedStateStep = steps.get(steps.size() - 1);
- Map<String, Object> properties = shardedStateStep.getProperties();
+ Map<String, Object> properties = runGroupIntoBatchesAndGetStepProperties(true, true);
assertTrue(properties.containsKey(PropertyNames.USES_KEYED_STATE));
assertEquals("true", getString(properties, PropertyNames.USES_KEYED_STATE));
+ // "allows_shardable_state" is currently unsupported for portable jobs.
assertFalse(properties.containsKey(PropertyNames.ALLOWS_SHARDABLE_STATE));
+ assertFalse(properties.containsKey(PropertyNames.PRESERVES_KEYS));
}
@Test
diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index d5e8fa7..cb74bff 100644
--- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -38,6 +38,7 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeFalse;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyInt;
@@ -131,14 +132,17 @@ import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
+import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SerializableFunctions;
+import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValues;
@@ -1605,12 +1609,32 @@ public class DataflowRunnerTest implements Serializable {
verifyMergingStatefulParDoRejected(options);
}
- private void verifyGroupIntoBatchesOverride(Pipeline p) {
+ private void verifyGroupIntoBatchesOverride(
+ Pipeline p, Boolean withShardedKey, Boolean expectOverriden) {
final int batchSize = 2;
List<KV<String, Integer>> testValues =
Arrays.asList(KV.of("A", 1), KV.of("B", 0), KV.of("A", 2), KV.of("A", 4), KV.of("A", 8));
- PCollection<KV<String, Iterable<Integer>>> output =
- p.apply(Create.of(testValues)).apply(GroupIntoBatches.ofSize(batchSize));
+ PCollection<KV<String, Integer>> input = p.apply(Create.of(testValues));
+ PCollection<KV<String, Iterable<Integer>>> output;
+ if (withShardedKey) {
+ output =
+ input
+ .apply(GroupIntoBatches.<String, Integer>ofSize(batchSize).withShardedKey())
+ .apply(
+ "StripShardId",
+ MapElements.via(
+ new SimpleFunction<
+ KV<ShardedKey<String>, Iterable<Integer>>,
+ KV<String, Iterable<Integer>>>() {
+ @Override
+ public KV<String, Iterable<Integer>> apply(
+ KV<ShardedKey<String>, Iterable<Integer>> input) {
+ return KV.of(input.getKey().getKey(), input.getValue());
+ }
+ }));
+ } else {
+ output = input.apply(GroupIntoBatches.ofSize(batchSize));
+ }
PAssert.thatMultimap(output)
.satisfies(
new SerializableFunction<Map<String, Iterable<Iterable<Integer>>>, Void>() {
@@ -1642,23 +1666,41 @@ public class DataflowRunnerTest implements Serializable {
public CompositeBehavior enterCompositeTransform(Node node) {
if (p.getOptions().as(StreamingOptions.class).isStreaming()
&& node.getTransform()
- instanceof GroupIntoBatchesOverride.StreamingGroupIntoBatches) {
+ instanceof GroupIntoBatchesOverride.StreamingGroupIntoBatchesWithShardedKey) {
sawGroupIntoBatchesOverride.set(true);
}
if (!p.getOptions().as(StreamingOptions.class).isStreaming()
&& node.getTransform() instanceof GroupIntoBatchesOverride.BatchGroupIntoBatches) {
sawGroupIntoBatchesOverride.set(true);
}
+ if (!p.getOptions().as(StreamingOptions.class).isStreaming()
+ && node.getTransform()
+ instanceof GroupIntoBatchesOverride.BatchGroupIntoBatchesWithShardedKey) {
+ sawGroupIntoBatchesOverride.set(true);
+ }
return CompositeBehavior.ENTER_TRANSFORM;
}
});
- assertTrue(sawGroupIntoBatchesOverride.get());
+ if (expectOverriden) {
+ assertTrue(sawGroupIntoBatchesOverride.get());
+ } else {
+ assertFalse(sawGroupIntoBatchesOverride.get());
+ }
}
@Test
@Category(ValidatesRunner.class)
public void testBatchGroupIntoBatchesOverride() {
- verifyGroupIntoBatchesOverride(pipeline);
+ // Ignore this test for streaming pipelines.
+ assumeFalse(pipeline.getOptions().as(StreamingOptions.class).isStreaming());
+ verifyGroupIntoBatchesOverride(pipeline, false, true);
+ }
+
+ @Test
+ public void testBatchGroupIntoBatchesWithShardedKeyOverride() throws IOException {
+ PipelineOptions options = buildPipelineOptions();
+ Pipeline p = Pipeline.create(options);
+ verifyGroupIntoBatchesOverride(p, true, true);
}
@Test
@@ -1666,7 +1708,15 @@ public class DataflowRunnerTest implements Serializable {
PipelineOptions options = buildPipelineOptions();
options.as(StreamingOptions.class).setStreaming(true);
Pipeline p = Pipeline.create(options);
- verifyGroupIntoBatchesOverride(p);
+ verifyGroupIntoBatchesOverride(p, false, false);
+ }
+
+ @Test
+ public void testStreamingGroupIntoBatchesWithShardedKeyOverride() throws IOException {
+ PipelineOptions options = buildPipelineOptions();
+ options.as(StreamingOptions.class).setStreaming(true);
+ Pipeline p = Pipeline.create(options);
+ verifyGroupIntoBatchesOverride(p, true, true);
}
private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) {
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
index ae3b0f2..1ac6493 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
@@ -19,7 +19,10 @@ package org.apache.beam.sdk.transforms;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+import java.nio.ByteBuffer;
+import java.util.UUID;
import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.state.BagState;
@@ -32,6 +35,7 @@ import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
@@ -89,6 +93,7 @@ public class GroupIntoBatches<K, InputT>
private final long batchSize;
@Nullable private final Duration maxBufferingDuration;
+ private static final UUID workerUuid = UUID.randomUUID();
private GroupIntoBatches(long batchSize, @Nullable Duration maxBufferingDuration) {
this.batchSize = batchSize;
@@ -105,8 +110,8 @@ public class GroupIntoBatches<K, InputT>
}
/**
- * Set a time limit (in processing time) on how long an incomplete batch of elements is allowed to
- * be buffered. Once a batch is flushed to output, the timer is reset.
+ * Sets a time limit (in processing time) on how long an incomplete batch of elements is allowed
+ * to be buffered. Once a batch is flushed to output, the timer is reset.
*/
public GroupIntoBatches<K, InputT> withMaxBufferingDuration(Duration duration) {
checkArgument(
@@ -114,6 +119,65 @@ public class GroupIntoBatches<K, InputT>
return new GroupIntoBatches<>(batchSize, duration);
}
+ /**
+ * Outputs batched elements associated with sharded input keys. By default, keys are sharded
+ * such that the input elements with the same key are spread to all available threads executing
+ * the transform. Runners may override the default sharding to do a better load balancing during
+ * the execution time.
+ */
+ @Experimental
+ public WithShardedKey withShardedKey() {
+ return new WithShardedKey();
+ }
+
+ public class WithShardedKey
+ extends PTransform<
+ PCollection<KV<K, InputT>>, PCollection<KV<ShardedKey<K>, Iterable<InputT>>>> {
+ private WithShardedKey() {}
+
+ /** Returns the size of the batch. */
+ public long getBatchSize() {
+ return batchSize;
+ }
+
+ @Override
+ public PCollection<KV<ShardedKey<K>, Iterable<InputT>>> expand(
+ PCollection<KV<K, InputT>> input) {
+ Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
+
+ checkArgument(
+ input.getCoder() instanceof KvCoder,
+ "coder specified in the input PCollection is not a KvCoder");
+ KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
+ Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
+ Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1);
+
+ return input
+ .apply(
+ MapElements.via(
+ new SimpleFunction<KV<K, InputT>, KV<ShardedKey<K>, InputT>>() {
+ @Override
+ public KV<ShardedKey<K>, InputT> apply(KV<K, InputT> input) {
+ long tid = Thread.currentThread().getId();
+ ByteBuffer buffer = ByteBuffer.allocate(3 * Long.BYTES);
+ buffer.putLong(workerUuid.getMostSignificantBits());
+ buffer.putLong(workerUuid.getLeastSignificantBits());
+ buffer.putLong(tid);
+ return KV.of(ShardedKey.of(input.getKey(), buffer.array()), input.getValue());
+ }
+ }))
+ .setCoder(KvCoder.of(ShardedKey.Coder.of(keyCoder), valueCoder))
+ .apply(
+ ParDo.of(
+ new GroupIntoBatchesDoFn<>(
+ batchSize,
+ allowedLateness,
+ maxBufferingDuration,
+ ShardedKey.Coder.of(keyCoder),
+ valueCoder)));
+ }
+ }
+
@Override
public PCollection<KV<K, Iterable<InputT>>> expand(PCollection<KV<K, InputT>> input) {
Duration allowedLateness = input.getWindowingStrategy().getAllowedLateness();
@@ -121,7 +185,7 @@ public class GroupIntoBatches<K, InputT>
checkArgument(
input.getCoder() instanceof KvCoder,
"coder specified in the input PCollection is not a KvCoder");
- KvCoder inputCoder = (KvCoder) input.getCoder();
+ KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) input.getCoder();
Coder<K> keyCoder = (Coder<K>) inputCoder.getCoderArguments().get(0);
Coder<InputT> valueCoder = (Coder<InputT>) inputCoder.getCoderArguments().get(1);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
index f7b0906..3334875 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/GroupIntoBatchesTest.java
@@ -44,6 +44,7 @@ import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Repeatedly;
import org.apache.beam.sdk.transforms.windowing.Window;
+import org.apache.beam.sdk.util.ShardedKey;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TimestampedValue;
@@ -125,6 +126,87 @@ public class GroupIntoBatchesTest implements Serializable {
pipeline.run();
}
+ @Test
+ @Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class})
+ public void testWithShardedKeyInGlobalWindow() {
+ // Since with default sharding, the number of subshards of a key is nondeterministic, create
+ // a large number of input elements and a small batch size and check there is no batch larger
+ // than the specified size.
+ int numElements = 10000;
+ int batchSize = 5;
+ PCollection<KV<ShardedKey<String>, Iterable<String>>> collection =
+ pipeline
+ .apply("Input data", Create.of(createTestData(numElements)))
+ .apply(GroupIntoBatches.<String, String>ofSize(batchSize).withShardedKey())
+ .setCoder(
+ KvCoder.of(
+ ShardedKey.Coder.of(StringUtf8Coder.of()),
+ IterableCoder.of(StringUtf8Coder.of())));
+ PAssert.that("Incorrect batch size in one or more elements", collection)
+ .satisfies(
+ new SerializableFunction<Iterable<KV<ShardedKey<String>, Iterable<String>>>, Void>() {
+
+ private boolean checkBatchSizes(
+ Iterable<KV<ShardedKey<String>, Iterable<String>>> listToCheck) {
+ for (KV<ShardedKey<String>, Iterable<String>> element : listToCheck) {
+ if (Iterables.size(element.getValue()) > batchSize) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public Void apply(Iterable<KV<ShardedKey<String>, Iterable<String>>> input) {
+ assertTrue(checkBatchSizes(input));
+ return null;
+ }
+ });
+ PCollection<KV<Integer, Long>> numBatchesbyBatchSize =
+ collection
+ .apply(
+ "KeyByBatchSize",
+ MapElements.via(
+ new SimpleFunction<
+ KV<ShardedKey<String>, Iterable<String>>, KV<Integer, Integer>>() {
+ @Override
+ public KV<Integer, Integer> apply(
+ KV<ShardedKey<String>, Iterable<String>> input) {
+ int batchSize = 0;
+ for (String ignored : input.getValue()) {
+ batchSize++;
+ }
+ return KV.of(batchSize, 1);
+ }
+ }))
+ .apply("CountBatchesBySize", Count.perKey());
+ PAssert.that("Expecting majority of the batches are full", numBatchesbyBatchSize)
+ .satisfies(
+ (SerializableFunction<Iterable<KV<Integer, Long>>, Void>)
+ listOfBatchSize -> {
+ Long numFullBatches = 0L;
+ Long totalNumBatches = 0L;
+ for (KV<Integer, Long> batchSizeAndCount : listOfBatchSize) {
+ if (batchSizeAndCount.getKey() == batchSize) {
+ numFullBatches += batchSizeAndCount.getValue();
+ }
+ totalNumBatches += batchSizeAndCount.getValue();
+ }
+ assertTrue(
+ String.format(
+ "total number of batches should be in the range [%d, %d] but got %d",
+ numElements, numElements / batchSize, numFullBatches),
+ numFullBatches <= numElements && numFullBatches >= numElements / batchSize);
+ assertTrue(
+ String.format(
+ "number of full batches vs. total number of batches in total: %d vs. %d",
+ numFullBatches, totalNumBatches),
+ numFullBatches > totalNumBatches / 2);
+ return null;
+ });
+ pipeline.run();
+ }
+
/** test behavior when the number of input elements is not evenly divisible by batch size. */
@Test
@Category({NeedsRunner.class, UsesTimersInParDo.class, UsesStatefulParDo.class})