Posted to commits@beam.apache.org by go...@apache.org on 2019/05/28 22:14:11 UTC
[beam] 01/01: Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation
This is an automated email from the ASF dual-hosted git repository.
goenka pushed a commit to branch spark-cache-stage
in repository https://gitbox.apache.org/repos/asf/beam.git
commit 4cb97fb3cf08ba9e2a57f8ca5f1a32496d90afbe
Author: Robert Bradshaw <ro...@gmail.com>
AuthorDate: Fri May 24 10:49:00 2019 +0200
Merge pull request #8558 [BEAM-7131] Spark: cache output to prevent re-computation
---
.../SparkBatchPortablePipelineTranslator.java | 94 +++++++++++----
.../spark/translation/SparkTranslationContext.java | 24 +++-
.../runners/spark/SparkPortableExecutionTest.java | 126 +++++++++++++++++++--
3 files changed, 211 insertions(+), 33 deletions(-)
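For readers skimming the diff below, the change works in two passes: the translator first walks the topologically ordered transforms and counts how many times each pCollection is consumed as an input (an executable stage additionally counts one consumer per output, since its computation yields an intermediate RDD that every extraction reads), and the translation context then persists any RDD whose count exceeds one before it is reused. A minimal standalone sketch of that decision follows; SimpleNode and SimpleDataset are hypothetical stand-ins, not the actual Beam or Spark classes.

import java.util.HashMap;
import java.util.List;
import java.util.Map;

class CachingDecisionSketch {
  // pCollection id -> number of transforms that consume it as an input.
  private final Map<String, Integer> consumptionCount = new HashMap<>();

  void incrementConsumptionCountBy(String pCollectionId, int addend) {
    consumptionCount.merge(pCollectionId, addend, Integer::sum);
  }

  // Pass 1 (pre-scan): every input edge of every transform bumps the count.
  void preScan(List<SimpleNode> topologicallyOrderedTransforms) {
    for (SimpleNode node : topologicallyOrderedTransforms) {
      for (String inputId : node.inputIds()) {
        incrementConsumptionCountBy(inputId, 1);
      }
    }
  }

  // Pass 2 (translation): a dataset consumed more than once is persisted when registered.
  void pushDataset(String pCollectionId, SimpleDataset dataset, String storageLevel) {
    if (consumptionCount.getOrDefault(pCollectionId, 0) > 1) {
      dataset.cache(storageLevel);
    }
  }

  interface SimpleNode {
    List<String> inputIds();
  }

  interface SimpleDataset {
    void cache(String storageLevel);
  }
}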
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
index 8e7796f..0180496 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkBatchPortablePipelineTranslator.java
@@ -64,6 +64,7 @@ import org.apache.spark.Partitioner;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
+import org.apache.spark.storage.StorageLevel;
import scala.Tuple2;
/** Translates a bounded portable pipeline into a Spark job. */
@@ -112,6 +113,24 @@ public class SparkBatchPortablePipelineTranslator {
QueryablePipeline.forTransforms(
pipeline.getRootTransformIdsList(), pipeline.getComponents());
for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
+ // Pre-scan pipeline to count which pCollections are consumed as inputs more than once so
+ // their corresponding RDDs can later be cached.
+ for (String inputId : transformNode.getTransform().getInputsMap().values()) {
+ context.incrementConsumptionCountBy(inputId, 1);
+ }
+ // Executable stage consists of two parts: computation and extraction. This means the result
+ // of computation is an intermediate RDD, which we might also need to cache.
+ if (transformNode.getTransform().getSpec().getUrn().equals(ExecutableStage.URN)) {
+ context.incrementConsumptionCountBy(
+ getExecutableStageIntermediateId(transformNode),
+ transformNode.getTransform().getOutputsMap().size());
+ }
+ for (String outputId : transformNode.getTransform().getOutputsMap().values()) {
+ WindowedValueCoder outputCoder = getWindowedValueCoder(outputId, pipeline.getComponents());
+ context.putCoder(outputId, outputCoder);
+ }
+ }
+ for (PipelineNode.PTransformNode transformNode : p.getTopologicallyOrderedTransforms()) {
urnToTransformTranslator
.getOrDefault(
transformNode.getTransform().getSpec().getUrn(),
@@ -141,18 +160,9 @@ public class SparkBatchPortablePipelineTranslator {
RunnerApi.Components components = pipeline.getComponents();
String inputId = getInputId(transformNode);
- PCollection inputPCollection = components.getPcollectionsOrThrow(inputId);
Dataset inputDataset = context.popDataset(inputId);
JavaRDD<WindowedValue<KV<K, V>>> inputRdd = ((BoundedDataset<KV<K, V>>) inputDataset).getRDD();
- PCollectionNode inputPCollectionNode = PipelineNode.pCollection(inputId, inputPCollection);
- WindowedValueCoder<KV<K, V>> inputCoder;
- try {
- inputCoder =
- (WindowedValueCoder)
- WireCoders.instantiateRunnerWireCoder(inputPCollectionNode, components);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ WindowedValueCoder<KV<K, V>> inputCoder = getWindowedValueCoder(inputId, components);
KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) inputCoder.getValueCoder();
Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
Coder<V> inputValueCoder = inputKvCoder.getValueCoder();
@@ -200,18 +210,18 @@ public class SparkBatchPortablePipelineTranslator {
Dataset inputDataset = context.popDataset(inputPCollectionId);
JavaRDD<WindowedValue<InputT>> inputRdd = ((BoundedDataset<InputT>) inputDataset).getRDD();
Map<String, String> outputs = transformNode.getTransform().getOutputsMap();
- BiMap<String, Integer> outputMap = createOutputMap(outputs.values());
+ BiMap<String, Integer> outputExtractionMap = createOutputMap(outputs.values());
ImmutableMap.Builder<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
broadcastVariablesBuilder = ImmutableMap.builder();
for (SideInputId sideInputId : stagePayload.getSideInputsList()) {
- RunnerApi.Components components = stagePayload.getComponents();
+ RunnerApi.Components stagePayloadComponents = stagePayload.getComponents();
String collectionId =
- components
+ stagePayloadComponents
.getTransformsOrThrow(sideInputId.getTransformId())
.getInputsOrThrow(sideInputId.getLocalName());
Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> tuple2 =
- broadcastSideInput(collectionId, components, context);
+ broadcastSideInput(collectionId, stagePayloadComponents, context);
broadcastVariablesBuilder.put(collectionId, tuple2);
}
@@ -219,14 +229,38 @@ public class SparkBatchPortablePipelineTranslator {
new SparkExecutableStageFunction<>(
stagePayload,
context.jobInfo,
- outputMap,
+ outputExtractionMap,
broadcastVariablesBuilder.build(),
MetricsAccumulator.getInstance());
JavaRDD<RawUnionValue> staged = inputRdd.mapPartitions(function);
+ String intermediateId = getExecutableStageIntermediateId(transformNode);
+ context.pushDataset(
+ intermediateId,
+ new Dataset() {
+ @Override
+ public void cache(String storageLevel, Coder<?> coder) {
+ StorageLevel level = StorageLevel.fromString(storageLevel);
+ staged.persist(level);
+ }
+
+ @Override
+ public void action() {
+ // Empty function to force computation of RDD.
+ staged.foreach(TranslationUtils.emptyVoidFunction());
+ }
+
+ @Override
+ public void setName(String name) {
+ staged.setName(name);
+ }
+ });
+ // pop dataset to mark RDD as used
+ context.popDataset(intermediateId);
for (String outputId : outputs.values()) {
JavaRDD<WindowedValue<OutputT>> outputRdd =
- staged.flatMap(new SparkExecutableStageExtractionFunction<>(outputMap.get(outputId)));
+ staged.flatMap(
+ new SparkExecutableStageExtractionFunction<>(outputExtractionMap.get(outputId)));
context.pushDataset(outputId, new BoundedDataset<>(outputRdd));
}
if (outputs.isEmpty()) {
@@ -249,17 +283,9 @@ public class SparkBatchPortablePipelineTranslator {
*/
private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
- PCollection collection = components.getPcollectionsOrThrow(collectionId);
@SuppressWarnings("unchecked")
BoundedDataset<T> dataset = (BoundedDataset<T>) context.popDataset(collectionId);
- PCollectionNode collectionNode = PipelineNode.pCollection(collectionId, collection);
- WindowedValueCoder<T> coder;
- try {
- coder =
- (WindowedValueCoder<T>) WireCoders.instantiateRunnerWireCoder(collectionNode, components);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ WindowedValueCoder<T> coder = getWindowedValueCoder(collectionId, components);
List<byte[]> bytes = dataset.getBytes(coder);
Broadcast<List<byte[]>> broadcast = context.getSparkContext().broadcast(bytes);
return new Tuple2<>(broadcast, coder);
@@ -324,4 +350,22 @@ public class SparkBatchPortablePipelineTranslator {
private static String getOutputId(PTransformNode transformNode) {
return Iterables.getOnlyElement(transformNode.getTransform().getOutputsMap().values());
}
+
+ private static <T> WindowedValueCoder<T> getWindowedValueCoder(
+ String pCollectionId, RunnerApi.Components components) {
+ PCollection pCollection = components.getPcollectionsOrThrow(pCollectionId);
+ PCollectionNode pCollectionNode = PipelineNode.pCollection(pCollectionId, pCollection);
+ WindowedValueCoder<T> coder;
+ try {
+ coder =
+ (WindowedValueCoder) WireCoders.instantiateRunnerWireCoder(pCollectionNode, components);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return coder;
+ }
+
+ private static String getExecutableStageIntermediateId(PTransformNode transformNode) {
+ return transformNode.getId();
+ }
}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
index 772e0d2..8c2cee8 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkTranslationContext.java
@@ -17,12 +17,16 @@
*/
package org.apache.beam.runners.spark.translation;
+import javax.annotation.Nullable;
+import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.fnexecution.provisioning.JobInfo;
+import org.apache.beam.runners.spark.SparkPipelineOptions;
+import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.spark.api.java.JavaSparkContext;
@@ -33,6 +37,9 @@ import org.apache.spark.api.java.JavaSparkContext;
public class SparkTranslationContext {
private final JavaSparkContext jsc;
final JobInfo jobInfo;
+ // Map pCollection IDs to the number of times they are consumed as inputs.
+ private final Map<String, Integer> consumptionCount = new HashMap<>();
+ private final Map<String, Coder> coderMap = new HashMap<>();
private final Map<String, Dataset> datasets = new LinkedHashMap<>();
private final Set<Dataset> leaves = new LinkedHashSet<>();
final SerializablePipelineOptions serializablePipelineOptions;
@@ -51,7 +58,13 @@ public class SparkTranslationContext {
/** Add output of transform to context. */
public void pushDataset(String pCollectionId, Dataset dataset) {
dataset.setName(pCollectionId);
- // TODO cache
+ SparkPipelineOptions sparkOptions =
+ serializablePipelineOptions.get().as(SparkPipelineOptions.class);
+ if (!sparkOptions.isCacheDisabled() && consumptionCount.getOrDefault(pCollectionId, 0) > 1) {
+ String storageLevel = sparkOptions.getStorageLevel();
+ @Nullable Coder coder = coderMap.get(pCollectionId);
+ dataset.cache(storageLevel, coder);
+ }
datasets.put(pCollectionId, dataset);
leaves.add(dataset);
}
@@ -70,6 +83,15 @@ public class SparkTranslationContext {
}
}
+ void incrementConsumptionCountBy(String pCollectionId, int addend) {
+ int count = consumptionCount.getOrDefault(pCollectionId, 0);
+ consumptionCount.put(pCollectionId, count + addend);
+ }
+
+ void putCoder(String pCollectionId, Coder coder) {
+ coderMap.put(pCollectionId, coder);
+ }
+
/** Generate a unique pCollection id number to identify runner-generated sinks. */
public int nextSinkId() {
return sinkId++;
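The pushDataset change above consults two existing Spark runner options, isCacheDisabled() and getStorageLevel(). A hedged usage sketch of setting them from user code follows; the setter names are assumed to follow the usual matching-setter convention for Beam pipeline options.

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class CacheOptionsExample {
  public static void main(String[] args) {
    SparkPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).as(SparkPipelineOptions.class);
    // Choose the Spark storage level used when a multiply-consumed RDD is persisted.
    options.setStorageLevel("MEMORY_ONLY");
    // Or opt out of caching reused RDDs entirely.
    options.setCacheDisabled(true);
  }
}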
diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
index d7d3428..eb9dce0 100644
--- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
+++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkPortableExecutionTest.java
@@ -17,9 +17,11 @@
*/
package org.apache.beam.runners.spark;
+import java.io.File;
import java.io.Serializable;
+import java.nio.file.FileSystems;
import java.util.Collections;
-import java.util.concurrent.Executors;
+import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.apache.beam.model.jobmanagement.v1.JobApi.JobState.Enum;
import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -46,8 +48,11 @@ import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableLis
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.ListeningExecutorService;
import org.apache.beam.vendor.guava.v20_0.com.google.common.util.concurrent.MoreExecutors;
import org.junit.AfterClass;
+import org.junit.Assert;
import org.junit.BeforeClass;
+import org.junit.ClassRule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -56,15 +61,13 @@ import org.slf4j.LoggerFactory;
*/
public class SparkPortableExecutionTest implements Serializable {
+ @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
private static final Logger LOG = LoggerFactory.getLogger(SparkPortableExecutionTest.class);
-
private static ListeningExecutorService sparkJobExecutor;
@BeforeClass
public static void setup() {
- // Restrict this to only one thread to avoid multiple Spark clusters up at the same time
- // which is not suitable for memory-constraint environments, i.e. Jenkins.
- sparkJobExecutor = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(1));
+ sparkJobExecutor = MoreExecutors.newDirectExecutorService();
}
@AfterClass
@@ -159,8 +162,117 @@ public class SparkPortableExecutionTest implements Serializable {
pipelineProto,
options.as(SparkPipelineOptions.class));
jobInvocation.start();
- while (jobInvocation.getState() != Enum.DONE) {
- Thread.sleep(1000);
+ Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+ }
+
+ /**
+ * Verifies that each executable stage runs exactly once, even if that executable stage has
+ * multiple immediate outputs. While re-computation may be necessary in the event of failure,
+ * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+ * the executable stage has side effects (BEAM-7131).
+ *
+ * <pre>
+ * |-> B -> GBK
+ * A -|
+ * |-> C -> GBK
+ * </pre>
+ */
+ @Test(timeout = 120_000)
+ public void testExecStageWithMultipleOutputs() throws Exception {
+ PipelineOptions options = PipelineOptionsFactory.create();
+ options.setRunner(CrashingRunner.class);
+ options
+ .as(PortablePipelineOptions.class)
+ .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+ Pipeline pipeline = Pipeline.create(options);
+ PCollection<KV<String, String>> a =
+ pipeline
+ .apply("impulse", Impulse.create())
+ .apply("A", ParDo.of(new DoFnWithSideEffect<>("A")));
+ PCollection<KV<String, String>> b = a.apply("B", ParDo.of(new DoFnWithSideEffect<>("B")));
+ PCollection<KV<String, String>> c = a.apply("C", ParDo.of(new DoFnWithSideEffect<>("C")));
+ // Use GBKs to force re-computation of executable stage unless cached.
+ b.apply(GroupByKey.create());
+ c.apply(GroupByKey.create());
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+ JobInvocation jobInvocation =
+ SparkJobInvoker.createJobInvocation(
+ "testExecStageWithMultipleOutputs",
+ "testExecStageWithMultipleOutputsRetrievalToken",
+ sparkJobExecutor,
+ pipelineProto,
+ options.as(SparkPipelineOptions.class));
+ jobInvocation.start();
+ Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+ }
+
+ /**
+ * Verifies that each executable stage runs exactly once, even if that executable stage has
+ * multiple downstream consumers. While re-computation may be necessary in the event of failure,
+ * re-computation of a whole executable stage is expensive and can cause unexpected behavior when
+ * the executable stage has side effects (BEAM-7131).
+ *
+ * <pre>
+ * |-> G
+ * F -> GBK -|
+ * |-> H
+ * </pre>
+ */
+ @Test(timeout = 120_000)
+ public void testExecStageWithMultipleConsumers() throws Exception {
+ PipelineOptions options = PipelineOptionsFactory.create();
+ options.setRunner(CrashingRunner.class);
+ options
+ .as(PortablePipelineOptions.class)
+ .setDefaultEnvironmentType(Environments.ENVIRONMENT_EMBEDDED);
+ Pipeline pipeline = Pipeline.create(options);
+ PCollection<KV<String, Iterable<String>>> f =
+ pipeline
+ .apply("impulse", Impulse.create())
+ .apply("F", ParDo.of(new DoFnWithSideEffect<>("F")))
+ // use GBK to prevent fusion of F, G, and H
+ .apply(GroupByKey.create());
+ f.apply("G", ParDo.of(new DoFnWithSideEffect<>("G")));
+ f.apply("H", ParDo.of(new DoFnWithSideEffect<>("H")));
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(pipeline);
+ JobInvocation jobInvocation =
+ SparkJobInvoker.createJobInvocation(
+ "testExecStageWithMultipleConsumers",
+ "testExecStageWithMultipleConsumersRetrievalToken",
+ sparkJobExecutor,
+ pipelineProto,
+ options.as(SparkPipelineOptions.class));
+ jobInvocation.start();
+ Assert.assertEquals(Enum.DONE, jobInvocation.getState());
+ }
+
+ /** A non-idempotent DoFn that cannot be run more than once without error. */
+ private class DoFnWithSideEffect<InputT> extends DoFn<InputT, KV<String, String>> {
+
+ private final String name;
+ private final File file;
+
+ DoFnWithSideEffect(String name) {
+ this.name = name;
+ String path =
+ FileSystems.getDefault()
+ .getPath(
+ temporaryFolder.getRoot().getAbsolutePath(),
+ String.format("%s-%s", this.name, UUID.randomUUID().toString()))
+ .toString();
+ file = new File(path);
+ }
+
+ @ProcessElement
+ public void process(ProcessContext context) throws Exception {
+ context.output(KV.of(name, name));
+ // Verify this DoFn has not run more than once by enacting a side effect via the local file
+ // system.
+ Assert.assertTrue(
+ String.format(
+ "Create file %s failed (DoFn %s should only have been run once).",
+ file.getAbsolutePath(), name),
+ file.createNewFile());
}
}
}