You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ib...@apache.org on 2020/05/11 21:45:50 UTC

[beam] branch master updated: [BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.

This is an automated email from the ASF dual-hosted git repository.

ibzib pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 126d728  [BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.
     new 86cb5b0  Merge pull request #11644 from ibzib/BEAM-9835
126d728 is described below

commit 126d728da2f6c27b277360b6cb20d3f6755551de
Author: Kyle Weaver <kc...@google.com>
AuthorDate: Fri May 8 12:11:04 2020 -0700

    [BEAM-9835] [Portable Spark] Broadcast a PCollection at most once.
---
 .../SparkBatchPortablePipelineTranslator.java      | 47 +++++++++++++++-------
 1 file changed, 33 insertions(+), 14 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 7402c00..d1e51ea 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -24,6 +24,7 @@ import static org.apache.beam.runners.fnexecution.translation.PipelineTranslator
 import com.google.auto.service.AutoService;
 import java.io.IOException;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -211,18 +212,8 @@ public class SparkBatchPortablePipelineTranslator {
     Coder windowCoder =
         getWindowingStrategy(inputPCollectionId, components).getWindowFn().windowCoder();
 
-    ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
-        broadcastVariablesBuilder = ImmutableMap.builder();
-    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
-      RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
-      String collectionId =
-          stagePayloadComponents
-              .getTransformsOrThrow(sideInputId.getTransformId())
-              .getInputsOrThrow(sideInputId.getLocalName());
-      Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
-          broadcastSideInput(collectionId, stagePayloadComponents, context);
-      broadcastVariablesBuilder.put(collectionId, tuple2);
-    }
+    ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+        broadcastVariables = broadcastSideInputs(stagePayload, context);
 
     JavaRDD<RawUnionValue> staged;
     if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
@@ -254,7 +245,7 @@ public class SparkBatchPortablePipelineTranslator {
               context.jobInfo,
               outputExtractionMap,
               SparkExecutableStageContextFactory.getInstance(),
-              broadcastVariablesBuilder.build(),
+              broadcastVariables,
               MetricsAccumulator.getInstance(),
               windowCoder);
       staged = groupedByKey.flatMap(function.forPair());
@@ -266,7 +257,7 @@ public class SparkBatchPortablePipelineTranslator {
               context.jobInfo,
               outputExtractionMap,
               SparkExecutableStageContextFactory.getInstance(),
-              broadcastVariablesBuilder.build(),
+              broadcastVariables,
               MetricsAccumulator.getInstance(),
               windowCoder);
       staged = inputRdd2.mapPartitions(function2);
@@ -323,6 +314,34 @@ public class SparkBatchPortablePipelineTranslator {
   }
 
   /**
+   * Broadcast each distinct side input PCollection of an executable stage. *This can be expensive.*
+   *
+   * @return Map from PCollection ID to Spark broadcast variable and coder to decode its contents.
+   */
+  private static <SideInputT>
+      ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+          broadcastSideInputs(
+              RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
+    Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
+        broadcastVariables = new HashMap<>();
+    for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
+      RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
+      String collectionId =
+          stagePayloadComponents
+              .getTransformsOrThrow(sideInputId.getTransformId())
+              .getInputsOrThrow(sideInputId.getLocalName());
+      if (broadcastVariables.containsKey(collectionId)) {
+        // Several side inputs can read this PCollection; reuse the broadcast made earlier.
+        continue;
+      }
+      Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
+          broadcastSideInput(collectionId, stagePayloadComponents, context);
+      broadcastVariables.put(collectionId, tuple2);
+    }
+    return ImmutableMap.copyOf(broadcastVariables);
+  }
+
+  /**
    * Collect and serialize the data and then broadcast the result. *This can be expensive.*
    *
    * @return Spark broadcast variable and coder to decode its contents