Posted to commits@beam.apache.org by ib...@apache.org on 2020/05/11 21:45:50 UTC
[beam] branch master updated: [BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.
This is an automated email from the ASF dual-hosted git repository.
ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 126d728 [BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.
new 86cb5b0 Merge pull request #11644 from ibzib/BEAM-9835
126d728 is described below
commit 126d728da2f6c27b277360b6cb20d3f6755551de
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Fri May 8 12:11:04 2020 -0700
[BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.
---
.../SparkBatchPortablePipelineTranslator.java | 47 +++++++++++++++-------
1 file changed, 33 insertions(+), 14 deletions(-)
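For readers skimming the diff below: the change memoizes broadcasts by PCollection ID. The previous loop broadcast a collection once per side-input reference, so a PCollection consumed as a side input by several transforms in the same stage was collected, serialized, and shipped more than once. A minimal, self-contained sketch of the de-duplication pattern follows; the class name, the List<byte[]> payload, and the broadcastSideInput stand-in are illustrative placeholders, not the actual Beam or Spark APIs.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BroadcastOnceSketch {

  // Stand-in for the expensive step: in the real translator this
  // collects and serializes a PCollection and hands it to the Spark
  // context, which is why it should run at most once per collection.
  static List<byte[]> broadcastSideInput(String collectionId) {
    System.out.println("broadcasting " + collectionId);
    return new ArrayList<>();
  }

  // The pattern the commit introduces: before broadcasting, consult a
  // map keyed by PCollection ID and skip IDs already broadcast.
  static Map<String, List<byte[]>> broadcastSideInputs(List<String> collectionIds) {
    Map<String, List<byte[]>> broadcastVariables = new HashMap<>();
    for (String collectionId : collectionIds) {
      if (broadcastVariables.containsKey(collectionId)) {
        // This PCollection has already been broadcast.
        continue;
      }
      broadcastVariables.put(collectionId, broadcastSideInput(collectionId));
    }
    return broadcastVariables;
  }

  public static void main(String[] args) {
    // "pc1" is referenced twice (e.g., as a side input to two
    // transforms in the same stage) but is broadcast only once.
    List<String> ids = List.of("pc1", "pc2", "pc1");
    broadcastSideInputs(ids).keySet().forEach(System.out::println);
  }
}

In the actual change, the mutable HashMap exists only to support the containsKey check; the result is frozen with ImmutableMap.copyOf before being passed to the stage functions, preserving the immutable interface the callers had before.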
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 7402c00..d1e51ea 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -24,6 +24,7 @@ import static org.apache.beam.runners.fnexecution.translation.PipelineTranslator
import com.google.auto.service.AutoService;
import java.io.IOException;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
@@ -211,18 +212,8 @@ public class SparkBatchPortablePipelineTranslator {
Coder windowCoder =
getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
- ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
- broadcastVariablesBuilder = ImmutableMap.builder();
- for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
- RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
- String collectionId =
- stagePayloadComponents
- .getTransformsOrThrow(sideInputId.getTransformId())
- .getInputsOrThrow(sideInputId.getLocalName());
- Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
- broadcastSideInput(collectionId, stagePayloadComponents, context);
- broadcastVariablesBuilder.put(collectionId, tuple2);
- }
+ ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+ broadcastVariables = broadcastSideInputs(stagePayload, context);
JavaRDD<RawUnionValue> staged;
if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
@@ -254,7 +245,7 @@ public class SparkBatchPortablePipelineTranslator {
context.jobInfo,
outputExtractionMap,
SparkExecutableStageContextFactory.getInstance(),
- broadcastVariablesBuilder.build(),
+ broadcastVariables,
MetricsAccumulator.getInstance(),
windowCoder);
staged = groupedByKey.flatMap(function.forPair());
@@ -266,7 +257,7 @@ public class SparkBatchPortablePipelineTranslator {
context.jobInfo,
outputExtractionMap,
SparkExecutableStageContextFactory.getInstance(),
- broadcastVariablesBuilder.build(),
+ broadcastVariables,
MetricsAccumulator.getInstance(),
windowCoder);
staged = inputRdd2.mapPartitions(function2);
@@ -323,6 +314,34 @@ public class SparkBatchPortablePipelineTranslator {
}
/**
+ * Broadcast the side inputs of an executable stage. *This can be expensive.*
+ *
+ * @return Map from PCollection ID to Spark broadcast variable and coder to decode its contents.
+ */
+ private static <SideInputT>
+ ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+ broadcastSideInputs(
+ RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
+ Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+ broadcastVariables = new HashMap<>();
+ for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
+ RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
+ String collectionId =
+ stagePayloadComponents
+ .getTransformsOrThrow(sideInputId.getTransformId())
+ .getInputsOrThrow(sideInputId.getLocalName());
+ if (broadcastVariables.containsKey(collectionId)) {
+ // This PCollection has already been broadcast.
+ continue;
+ }
+ Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+ broadcastSideInput(collectionId, stagePayloadComponents, context);
+ broadcastVariables.put(collectionId, tuple2);
+ }
+ return ImmutableMap.copyOf(broadcastVariables);
+ }
+
+ /**
* Collect and serialize the data and then broadcast the result. *This can be expensive.*
*
* @return Spark broadcast variable and coder to decode its contents