You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by go...@apache.org on 2019/05/28 22:14:10 UTC

[beam] branch spark-cache-stage created (now 4cb97fb)

This is an automated email from the ASF dual-hosted git repository.

goenka pushed a change to branch spark-cache-stage
in repository https://gitbox.apache.org/repos/asf/beam.git.


      at 4cb97fb  Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation

This branch includes the following new commits:

     new 4cb97fb  Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.



[beam] 01/01: Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation

Posted by go...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

goenka pushed a commit to branch spark-cache-stage
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 4cb97fb3cf08ba9e2a57f8ca5f1a32496d90afbe
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Fri May 24 10:49:00 2019 +0200

    Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation
---
 .../SparkBatchPortablePipelineTranslator.java      |  94 +++++++++++----
 .../spark/translation/SparkTranslationContext.java |  24 +++-
 .../runners/spark/SparkPortableExecutionTest.java  | 126 +++++++++++++++++++--
 3 files changed, 211 insertions(+), 33 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 8e7796f..0180496 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -64,6 +64,7 @@ import org.apache.spark.Partitioner;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.storage.StorageLevel;
 import scala.Tuple2;
 
 /** Translates a bounded portable pipeline into a Spark job. */
@@ -112,6 +113,24 @@ public class SparkBatchPortablePipelineTranslator {
         QueryablePipeline.forTransforms(
             pipeline.getRootTransformIdsList(), pipeline.getComponents());
     for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
+      // Pre-scan pipeline to count which pCollections are consumed as inputs more than once so
+      // their corresponding RDDs can later be cached.
+      for (String inputId : transformNode.getTransform().getInputsMap().values()) {
+        context.incrementConsumptionCountBy(inputId, 1);
+      }
+      // Executable stage consists of two parts: computation and extraction. This means the result
+      // of computation is an intermediate RDD, which we might also need to cache.
+      if (transformNode.getTransform().getSpec().getUrn().equals(ExecutableStage.URN)) {
+        context.incrementConsumptionCountBy(
+            getExecutableStageIntermediateId(transformNode),
+            transformNode.getTransform().getOutputsMap().size());
+      }
+      for (String outputId : transformNode.getTransform().getOutputsMap().values()) {
+        WindowedValueCoder outputCoder = getWindowedValueCoder(outputId, pipeline.getComponents());
+        context.putCoder(outputId, outputCoder);
+      }
+    }
+    for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
       urnToTransformTranslator
           .getOrDefault(
               transformNode.getTransform().getSpec().getUrn(),
@@ -141,18 +160,9 @@ public class SparkBatchPortablePipelineTranslator {
 
     RunnerApi.Components components = pipeline.getComponents();
     String inputId = getInputId(transformNode);
-    PCollection inputPCollection = components.getPcollectionsOrThrow(inputId);
     Dataset inputDataset = context.popDataset(inputId);
     JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
-    PCollectionNode inputPCollectionNode = PipelineNode.pCollection(inputId, inputPCollection);
-    WindowedValueCoder<KV<K, V>> inputCoder;
-    try {
-      inputCoder =
-          (WindowedValueCoder)
-              WireCoders.instantiateRunnerWireCoder(inputPCollectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<KV<K, V>> inputCoder = getWindowedValueCoder(inputId, components);
     KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
     Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
     Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
@@ -200,18 +210,18 @@ public class SparkBatchPortablePipelineTranslator {
     Dataset inputDataset = context.popDataset(inputPCollectionId);
     JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
     Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
-    BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
+    BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
 
     ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
         broadcastVariablesBuilder = ImmutableMap.builder();
     for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
-      RunnerApi.Components components = stagePayload.getComponents();
+      RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
       String collectionId =
-          components
+          stagePayloadComponents
               .getTransformsOrThrow(sideInputId.getTransformId())
               .getInputsOrThrow(sideInputId.getLocalName());
       Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
-          broadcastSideInput(collectionId, components, context);
+          broadcastSideInput(collectionId, stagePayloadComponents, context);
       broadcastVariablesBuilder.put(collectionId, tuple2);
     }
 
@@ -219,14 +229,38 @@ public class SparkBatchPortablePipelineTranslator {
         new SparkExecutableStageFunction<>(
             stagePayload,
             context.jobInfo,
-            outputMap,
+            outputExtractionMap,
             broadcastVariablesBuilder.build(),
             MetricsAccumulator.getInstance());
     JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
+    String intermediateId = getExecutableStageIntermediateId(transformNode);
+    context.pushDataset(
+        intermediateId,
+        new Dataset() {
+          @Override
+          public void cache(String storageLevel, Coder<?> coder) {
+            StorageLevel level = StorageLevel.fromString(storageLevel);
+            staged.persist(level);
+          }
+
+          @Override
+          public void action() {
+            // Empty function to force computation of RDD.
+            staged.foreach(TranslationUtils.emptyVoidFunction());
+          }
+
+          @Override
+          public void setName(String name) {
+            staged.setName(name);
+          }
+        });
+    // pop dataset to mark RDD as used
+    context.popDataset(intermediateId);
 
     for (String outputId : outputs.values()) {
       JavaRDD<WindowedValue<OutputT>> outputRdd =
-          staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
+          staged.flatMap(
+              new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
       context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
     }
     if (outputs.isEmpty()) {
@@ -249,17 +283,9 @@ public class SparkBatchPortablePipelineTranslator {
    */
   private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
       String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
-    PCollection collection = components.getPcollectionsOrThrow(collectionId);
     @SuppressWarnings("unchecked")
     BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
-    PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
-    WindowedValueCoder<T> coder;
-    try {
-      coder =
-          (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
-    } catch (IOException e) {
-      throw new RuntimeException(e);
-    }
+    WindowedValueCoder<T> coder = getWindowedValueCoder(collectionId, components);
     List<byte[]> bytes = dataset.getBytes(coder);
     Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
     return new Tuple2<>(broadcast, coder);
@@ -324,4 +350,22 @@ public class SparkBatchPortablePipelineTranslator {
   private static String getOutputId(PTransformNode transformNode) {
     return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
   }
+
+  private static <T> WindowedValueCoder<T> getWindowedValueCoder(
+      String pCollectionId, RunnerApi.Components components) {
+    PCollection pCollection = components.getPcollectionsOrThrow(pCollectionId);
+    PCollectionNode pCollectionNode = PipelineNode.pCollection(pCollectionId, pCollection);
+    WindowedValueCoder<T> coder;
+    try {
+      coder =
+          (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(pCollectionNode, components);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    return coder;
+  }
+
+  private static String getExecutableStageIntermediateId(PTransformNode transformNode) {
+    return transformNode.getId();
+  }
 }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
index 772e0d2..8c2cee8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
@@ -17,12 +17,16 @@
  */
 package org.apache.beam.runners.spark.translation;
 
+import com.sun.istack.Nullable;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
 import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.spark.api.java.JavaSparkContext;
 
@@ -33,6 +37,9 @@ import org.apache.spark.api.java.JavaSparkContext;
 public class SparkTranslationContext {
   private final JavaSparkContext jsc;
   final JobInfo jobInfo;
+  // Map pCollection IDs to the number of times they are consumed as inputs.
+  private final Map<String, Integer> consumptionCount = new HashMap<>();
+  private final Map<String, Coder> coderMap = new HashMap<>();
   private final Map<String, Dataset> datasets = new LinkedHashMap<>();
   private final Set<Dataset> leaves = new LinkedHashSet<>();
   final SerializablePipelineOptions serializablePipelineOptions;
@@ -51,7 +58,13 @@ public class SparkTranslationContext {
   /** Add output of transform to context. */
   public void pushDataset(String pCollectionId, Dataset dataset) {
     dataset.setName(pCollectionId);
-    // TODO cache
+    SparkPipelineOptions sparkOptions =
+        serializablePipelineOptions.get().as(SparkPipelineOptions.class);
+    if (!sparkOptions.isCacheDisabled() && consumptionCount.getOrDefault(pCollectionId, 0) > 1) {
+      String storageLevel = sparkOptions.getStorageLevel();
+      @Nullable Coder coder = coderMap.get(pCollectionId);
+      dataset.cache(storageLevel, coder);
+    }
     datasets.put(pCollectionId, dataset);
     leaves.add(dataset);
   }
@@ -70,6 +83,15 @@ public class SparkTranslationContext {
     }
   }
 
+  void incrementConsumptionCountBy(String pCollectionId, int addend) {
+    int count = consumptionCount.getOrDefault(pCollectionId, 0);
+    consumptionCount.put(pCollectionId, count + addend);
+  }
+
+  void putCoder(String pCollectionId, Coder coder) {
+    coderMap.put(pCollectionId, coder);
+  }
+
   /** Generate a unique pCollection id number to identify runner-generated sinks. */
   public int nextSinkId() {
     return sinkId++;
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index d7d3428..eb9dce0 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -17,9 +17,11 @@
  */
 package org.apache.beam.runners.spark;
 
+import java.io.File;
 import java.io.Serializable;
+import java.nio.file.FileSystems;
 import java.util.Collections;
-import java.util.concurrent.Executors;
+import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 import org.apache.beam.model.jobmanagement.v1.JobApi.JobState.Enum;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -46,8 +48,11 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableLis
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
+import org.junit.ClassRule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,15 +61,13 @@ import org.slf4j.LoggerFactory;
  */
 public class SparkPortableExecutionTest implements Serializable {
 
+  @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
   private static final Logger LOG = LoggerFactory.getLogger(SparkPortableExecutionTest.class);
-
   private static ListeningExecutorService sparkJobExecutor;
 
   @BeforeClass
   public static void setup() {
-    // Restrict this to only one thread to avoid multiple Spark clusters up at the same time
-    // which is not suitable for memory-constraint environments, i.e. Jenkins.
-    sparkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
+    sparkJobExecutor = MoreExecutors.newDirectExecutorService();
   }
 
   @AfterClass
@@ -159,8 +162,117 @@ public class SparkPortableExecutionTest implements Serializable {
             pipelineProto,
             options.as(SparkPipelineOptions.class));
     jobInvocation.start();
-    while (jobInvocation.getState() != Enum.DONE) {
-      Thread.sleep(1000);
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple immediate outputs. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *    |-> B -> GBK
+   * A -|
+   *    |-> C -> GBK
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleOutputs() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, String>> a =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("A", ParDo.of(new DoFnWithSideEffect<>("A")));
+    PCollection<KV<String, String>> b = a.apply("B", ParDo.of(new DoFnWithSideEffect<>("B")));
+    PCollection<KV<String, String>> c = a.apply("C", ParDo.of(new DoFnWithSideEffect<>("C")));
+    // Use GBKs to force re-computation of executable stage unless cached.
+    b.apply(GroupByKey.create());
+    c.apply(GroupByKey.create());
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleOutputs",
+            "testExecStageWithMultipleOutputsRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /**
+   * Verifies that each executable stage runs exactly once, even if that executable stage has
+   * multiple downstream consumers. While re-computation may be necessary in the event of failure,
+   * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+   * the executable stage has side effects (BEAM-7131).
+   *
+   * <pre>
+   *           |-> G
+   * F -> GBK -|
+   *           |-> H
+   * </pre>
+   */
+  @Test(timeout = 120_000)
+  public void testExecStageWithMultipleConsumers() throws Exception {
+    PipelineOptions options = PipelineOptionsFactory.create();
+    options.setRunner(CrashingRunner.class);
+    options
+        .as(PortablePipelineOptions.class)
+        .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<KV<String, Iterable<String>>> f =
+        pipeline
+            .apply("impulse", Impulse.create())
+            .apply("F", ParDo.of(new DoFnWithSideEffect<>("F")))
+            // use GBK to prevent fusion of F, G, and H
+            .apply(GroupByKey.create());
+    f.apply("G", ParDo.of(new DoFnWithSideEffect<>("G")));
+    f.apply("H", ParDo.of(new DoFnWithSideEffect<>("H")));
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+    JobInvocation jobInvocation =
+        SparkJobInvoker.createJobInvocation(
+            "testExecStageWithMultipleConsumers",
+            "testExecStageWithMultipleConsumersRetrievalToken",
+            sparkJobExecutor,
+            pipelineProto,
+            options.as(SparkPipelineOptions.class));
+    jobInvocation.start();
+    Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+  }
+
+  /** A non-idempotent DoFn that cannot be run more than once without error. */
+  private class DoFnWithSideEffect<InputT> extends DoFn<InputT, KV<String, String>> {
+
+    private final String name;
+    private final File file;
+
+    DoFnWithSideEffect(String name) {
+      this.name = name;
+      String path =
+          FileSystems.getDefault()
+              .getPath(
+                  temporaryFolder.getRoot().getAbsolutePath(),
+                  String.format("%s-%s", this.name, UUID.randomUUID().toString()))
+              .toString();
+      file = new File(path);
+    }
+
+    @ProcessElement
+    public void process(ProcessContext context) throws Exception {
+      context.output(KV.of(name, name));
+      // Verify this DoFn has not run more than once by enacting a side effect via the local file
+      // system.
+      Assert.assertTrue(
+          String.format(
+              "Create file %s failed (DoFn %s should only have been run once).",
+              file.getAbsolutePath(), name),
+          file.createNewFile());
     }
   }
 }