You are viewing a plain-text rendering of this content; the hyperlink to the canonical version was lost in conversion and is available from the mailing-list archive.
Posted to commits@beam.apache.org by lc...@apache.org on 2018/08/13 21:24:02 UTC
[beam] branch master updated: [BEAM-4826] Sanitize pCollections
before sending to SDKHarness
This is an automated email from the ASF dual-hosted git repository.
lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new a1af90a [BEAM-4826] Sanitize pCollections before sending to SDKHarness
a1af90a is described below
commit a1af90a9debdfc7d5564045c25081c28dc14384f
Author: Ankur <an...@users.noreply.github.com>
AuthorDate: Mon Aug 13 14:23:57 2018 -0700
[BEAM-4826] Sanitize pCollections before sending to SDKHarness
---
.../construction/graph/GreedyPipelineFuser.java | 98 +++++++++++++++
.../graph/GreedyPipelineFuserTest.java | 131 +++++++++++++++++++++
.../control/ProcessBundleDescriptors.java | 2 +-
.../fnexecution/control/RemoteExecutionTest.java | 102 ++++++++++++++++
4 files changed, 332 insertions(+), 1 deletion(-)
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
index 945c194..c248fd2 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
@@ -24,6 +24,7 @@ import static com.google.common.base.Preconditions.checkState;
import com.google.auto.value.AutoValue;
import com.google.common.collect.ComparisonChain;
import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
import java.util.ArrayDeque;
@@ -33,13 +34,16 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
+import java.util.Map.Entry;
import java.util.NavigableSet;
import java.util.Queue;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.OutputDeduplicator.DeduplicationResult;
@@ -165,6 +169,7 @@ public class GreedyPipelineFuser {
stages
.stream()
.map(stage -> deduplicated.getDeduplicatedStages().getOrDefault(stage, stage))
+ .map(GreedyPipelineFuser::sanitizeDanglingPTransformInputs)
.collect(Collectors.toSet()),
Sets.union(
deduplicated.getIntroducedTransforms(),
@@ -342,6 +347,99 @@ public class GreedyPipelineFuser {
.collect(Collectors.toSet()));
}
+ private static ExecutableStage sanitizeDanglingPTransformInputs(ExecutableStage stage) {
+ /* Possible inputs to a PTransform can only be those which are:
+ * <ul>
+ * <li>Explicit input PCollection to the stage
+ * <li>Outputs of a PTransform within the same stage
+ * <li>Timer PCollections
+ * <li>Side input PCollections
+ * <li>Explicit outputs from the stage
+ * </ul>
+ */
+ Set<String> possibleInputs = new HashSet<>();
+ possibleInputs.add(stage.getInputPCollection().getId());
+ possibleInputs.addAll(
+ stage
+ .getOutputPCollections()
+ .stream()
+ .map(PCollectionNode::getId)
+ .collect(Collectors.toSet()));
+ possibleInputs.addAll(
+ stage.getTimers().stream().map(t -> t.collection().getId()).collect(Collectors.toSet()));
+ possibleInputs.addAll(
+ stage
+ .getSideInputs()
+ .stream()
+ .map(s -> s.collection().getId())
+ .collect(Collectors.toSet()));
+ possibleInputs.addAll(
+ stage
+ .getTransforms()
+ .stream()
+ .flatMap(t -> t.getTransform().getOutputsMap().values().stream())
+ .collect(Collectors.toSet()));
+ Set<String> danglingInputs =
+ stage
+ .getTransforms()
+ .stream()
+ .flatMap(t -> t.getTransform().getInputsMap().values().stream())
+ .filter(in -> !possibleInputs.contains(in))
+ .collect(Collectors.toSet());
+
+ ImmutableList.Builder<PTransformNode> pTransformNodesBuilder = ImmutableList.builder();
+ for (PTransformNode transformNode : stage.getTransforms()) {
+ PTransform transform = transformNode.getTransform();
+ Map<String, String> validInputs =
+ transform
+ .getInputsMap()
+ .entrySet()
+ .stream()
+ .filter(e -> !danglingInputs.contains(e.getValue()))
+ .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
+
+ if (!validInputs.equals(transform.getInputsMap())) {
+ // Dangling inputs found so recreate pTransform without the dangling inputs.
+ transformNode =
+ PipelineNode.pTransform(
+ transformNode.getId(),
+ transform.toBuilder().clearInputs().putAllInputs(validInputs).build());
+ }
+
+ pTransformNodesBuilder.add(transformNode);
+ }
+ ImmutableList<PTransformNode> pTransformNodes = pTransformNodesBuilder.build();
+ Components.Builder componentBuilder = stage.getComponents().toBuilder();
+ // Update the pTransforms in components.
+ componentBuilder
+ .clearTransforms()
+ .putAllTransforms(
+ pTransformNodes
+ .stream()
+ .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform)));
+ Map<String, PCollection> validPCollectionMap =
+ stage
+ .getComponents()
+ .getPcollectionsMap()
+ .entrySet()
+ .stream()
+ .filter(e -> !danglingInputs.contains(e.getKey()))
+ .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
+
+ // Update pCollections in the components.
+ componentBuilder.clearPcollections().putAllPcollections(validPCollectionMap);
+
+ return ImmutableExecutableStage.of(
+ componentBuilder.build(),
+ stage.getEnvironment(),
+ stage.getInputPCollection(),
+ stage.getSideInputs(),
+ stage.getUserStates(),
+ stage.getTimers(),
+ pTransformNodes,
+ stage.getOutputPCollections());
+ }
+
/**
* A ({@link PCollectionNode}, {@link PTransformNode}) pair representing a single {@link
* PTransformNode} consuming a single materialized {@link PCollectionNode}.
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
index 04b7107..c47f743 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
@@ -1144,4 +1144,135 @@ public class GreedyPipelineFuserTest {
.withNoOutputs()
.withTransforms("goTransform")));
}
+
+ @Test
+ public void sanitizedTransforms() throws Exception {
+
+ PCollection flattenOutput = pc("flatten.out");
+ PCollection read1Output = pc("read1.out");
+ PCollection read2Output = pc("read2.out");
+ PCollection impulse1Output = pc("impulse1.out");
+ PCollection impulse2Output = pc("impulse2.out");
+ PTransform flattenTransform =
+ PTransform.newBuilder()
+ .setUniqueName("Flatten")
+ .putInputs(read1Output.getUniqueName(), read1Output.getUniqueName())
+ .putInputs(read2Output.getUniqueName(), read2Output.getUniqueName())
+ .putOutputs(flattenOutput.getUniqueName(), flattenOutput.getUniqueName())
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)
+ .setPayload(
+ WindowIntoPayload.newBuilder()
+ .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+ .build()
+ .toByteString()))
+ .build();
+
+ PTransform read1Transform =
+ PTransform.newBuilder()
+ .setUniqueName("read1")
+ .putInputs(impulse1Output.getUniqueName(), impulse1Output.getUniqueName())
+ .putOutputs(read1Output.getUniqueName(), read1Output.getUniqueName())
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
+ .setPayload(
+ WindowIntoPayload.newBuilder()
+ .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+ .build()
+ .toByteString()))
+ .build();
+ PTransform read2Transform =
+ PTransform.newBuilder()
+ .setUniqueName("read2")
+ .putInputs(impulse2Output.getUniqueName(), impulse2Output.getUniqueName())
+ .putOutputs(read2Output.getUniqueName(), read2Output.getUniqueName())
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
+ .setPayload(
+ WindowIntoPayload.newBuilder()
+ .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+ .build()
+ .toByteString()))
+ .build();
+
+ PTransform impulse1Transform =
+ PTransform.newBuilder()
+ .setUniqueName("impulse1")
+ .putOutputs(impulse1Output.getUniqueName(), impulse1Output.getUniqueName())
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)
+ .setPayload(
+ WindowIntoPayload.newBuilder()
+ .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+ .build()
+ .toByteString()))
+ .build();
+ PTransform impulse2Transform =
+ PTransform.newBuilder()
+ .setUniqueName("impulse2")
+ .putOutputs(impulse2Output.getUniqueName(), impulse2Output.getUniqueName())
+ .setSpec(
+ FunctionSpec.newBuilder()
+ .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)
+ .setPayload(
+ WindowIntoPayload.newBuilder()
+ .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+ .build()
+ .toByteString()))
+ .build();
+ Pipeline impulse =
+ Pipeline.newBuilder()
+ .addRootTransformIds(impulse1Transform.getUniqueName())
+ .addRootTransformIds(impulse2Transform.getUniqueName())
+ .addRootTransformIds(flattenTransform.getUniqueName())
+ .setComponents(
+ Components.newBuilder()
+ .putCoders("coder", Coder.newBuilder().build())
+ .putCoders("windowCoder", Coder.newBuilder().build())
+ .putWindowingStrategies(
+ "ws",
+ WindowingStrategy.newBuilder().setWindowCoderId("windowCoder").build())
+ .putEnvironments("py", Environment.newBuilder().setUrl("py").build())
+ .putPcollections(flattenOutput.getUniqueName(), flattenOutput)
+ .putTransforms(flattenTransform.getUniqueName(), flattenTransform)
+ .putPcollections(read1Output.getUniqueName(), read1Output)
+ .putTransforms(read1Transform.getUniqueName(), read1Transform)
+ .putPcollections(read2Output.getUniqueName(), read2Output)
+ .putTransforms(read2Transform.getUniqueName(), read2Transform)
+ .putPcollections(impulse1Output.getUniqueName(), impulse1Output)
+ .putTransforms(impulse1Transform.getUniqueName(), impulse1Transform)
+ .putPcollections(impulse2Output.getUniqueName(), impulse2Output)
+ .putTransforms(impulse2Transform.getUniqueName(), impulse2Transform)
+ .build())
+ .build();
+ FusedPipeline fused = GreedyPipelineFuser.fuse(impulse);
+
+ assertThat(fused.getRunnerExecutedTransforms(), hasSize(2));
+ assertThat(fused.getFusedStages(), hasSize(2));
+
+ assertThat(
+ fused.getFusedStages(),
+ containsInAnyOrder(
+ ExecutableStageMatcher.withInput(impulse1Output.getUniqueName())
+ .withTransforms(flattenTransform.getUniqueName(), read1Transform.getUniqueName()),
+ ExecutableStageMatcher.withInput(impulse2Output.getUniqueName())
+ .withTransforms(flattenTransform.getUniqueName(), read2Transform.getUniqueName())));
+ assertThat(
+ fused
+ .getFusedStages()
+ .stream()
+ .flatMap(
+ s ->
+ s.getComponents()
+ .getTransformsOrThrow(flattenTransform.getUniqueName())
+ .getInputsMap()
+ .values()
+ .stream())
+ .collect(Collectors.toList()),
+ containsInAnyOrder(read1Output.getUniqueName(), read2Output.getUniqueName()));
+ }
}
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index bc4bdc4..e52c9d8 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -56,7 +56,6 @@ import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead;
import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite;
import org.apache.beam.sdk.state.TimeDomain;
-import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
@@ -146,6 +145,7 @@ public class ProcessBundleDescriptors {
if (stateEndpoint != null) {
bundleDescriptorBuilder.setStateApiServiceDescriptor(stateEndpoint);
}
+
bundleDescriptorBuilder
.putAllCoders(components.getCodersMap())
.putAllEnvironments(components.getEnvironmentsMap())
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 593390e..96ae9de 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.fnexecution.control;
import static com.google.common.base.Preconditions.checkState;
import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
@@ -47,6 +48,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
+import java.util.function.Function;
import org.apache.beam.fn.harness.FnHarness;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.Target;
import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -92,6 +94,7 @@ import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.testing.ResetDateTimeProvider;
import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.ParDo;
@@ -103,6 +106,7 @@ import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.ByteString;
import org.hamcrest.collection.IsEmptyIterable;
@@ -709,6 +713,104 @@ public class RemoteExecutionTest implements Serializable {
timerStructuralValue(WindowedValue.valueInGlobalWindow(timerBytes("Z", 22L)))));
}
+ @Test
+ public void testExecutionWithMultipleStages() throws Exception {
+ Pipeline p = Pipeline.create();
+
+ Function<String, PCollection<String>> pCollectionGenerator =
+ suffix ->
+ p.apply("impulse" + suffix, Impulse.create())
+ .apply(
+ "create" + suffix,
+ ParDo.of(
+ new DoFn<byte[], String>() {
+ @ProcessElement
+ public void process(ProcessContext c) {
+ try {
+ c.output(
+ CoderUtils.decodeFromByteArray(
+ StringUtf8Coder.of(), c.element()));
+ } catch (CoderException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }))
+ .setCoder(StringUtf8Coder.of())
+ .apply(
+ ParDo.of(
+ new DoFn<String, String>() {
+ @ProcessElement
+ public void processElement(ProcessContext c) {
+ c.output("stream" + suffix + c.element());
+ }
+ }));
+ PCollection<String> input1 = pCollectionGenerator.apply("1");
+ PCollection<String> input2 = pCollectionGenerator.apply("2");
+
+ PCollection<String> outputMerged =
+ PCollectionList.of(input1).and(input2).apply(Flatten.pCollections());
+ outputMerged
+ .apply(
+ "createKV",
+ ParDo.of(
+ new DoFn<String, KV<String, String>>() {
+ @ProcessElement
+ public void process(ProcessContext c) {
+ c.output(KV.of(c.element(), ""));
+ }
+ }))
+ .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+ .apply("gbk", GroupByKey.create());
+
+ RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+ FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+ Set<ExecutableStage> stages = fused.getFusedStages();
+
+ assertThat(stages.size(), equalTo(2));
+
+ List<WindowedValue<?>> outputValues = Collections.synchronizedList(new ArrayList<>());
+
+ for (ExecutableStage stage : stages) {
+ ExecutableProcessBundleDescriptor descriptor =
+ ProcessBundleDescriptors.fromExecutableStage(
+ stage.toString(),
+ stage,
+ dataServer.getApiServiceDescriptor(),
+ stateServer.getApiServiceDescriptor());
+
+ BundleProcessor processor =
+ controlClient.getProcessor(
+ descriptor.getProcessBundleDescriptor(),
+ descriptor.getRemoteInputDestinations(),
+ stateDelegator);
+ Map<Target, Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders();
+ Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+ for (Entry<Target, Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet()) {
+ outputReceivers.putIfAbsent(
+ targetCoder.getKey(),
+ RemoteOutputReceiver.of(targetCoder.getValue(), outputValues::add));
+ }
+
+ try (ActiveBundle bundle =
+ processor.newBundle(
+ outputReceivers,
+ StateRequestHandler.unsupported(),
+ BundleProgressHandler.unsupported())) {
+ bundle
+ .getInputReceivers()
+ .get(stage.getInputPCollection().getId())
+ .accept(
+ WindowedValue.valueInGlobalWindow(
+ CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+ }
+ }
+ assertThat(
+ outputValues,
+ containsInAnyOrder(
+ WindowedValue.valueInGlobalWindow(kvBytes("stream1X", "")),
+ WindowedValue.valueInGlobalWindow(kvBytes("stream2X", ""))));
+ }
+
private KV<byte[], byte[]> kvBytes(String key, long value) throws CoderException {
return KV.of(
CoderUtils.encodeToByteArray(StringUtf8Coder.of(), key),