You are viewing a plain-text version of this content. The link to the canonical version was a hyperlink that is not preserved in this plain-text rendering.
Posted to commits@beam.apache.org by lc...@apache.org on 2018/08/13 21:24:02 UTC

[beam] branch master updated: [BEAM-4826] Sanitize pCollections before sending to SDKHarness

This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new a1af90a  [BEAM-4826] Sanitize pCollections before sending to SDKHarness
a1af90a is described below

commit a1af90a9debdfc7d5564045c25081c28dc14384f
Author: Ankur <an...@users.noreply.github.com>
AuthorDate: Mon Aug 13 14:23:57 2018 -0700

    [BEAM-4826] Sanitize pCollections before sending to SDKHarness
---
 .../construction/graph/GreedyPipelineFuser.java    |  98 +++++++++++++++
 .../graph/GreedyPipelineFuserTest.java             | 131 +++++++++++++++++++++
 .../control/ProcessBundleDescriptors.java          |   2 +-
 .../fnexecution/control/RemoteExecutionTest.java   | 102 ++++++++++++++++
 4 files changed, 332 insertions(+), 1 deletion(-)

diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
index 945c194..c248fd2 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuser.java
@@ -24,6 +24,7 @@ import static com.google.common.base.Preconditions.checkState;
 import com.google.auto.value.AutoValue;
 import com.google.common.collect.ComparisonChain;
 import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Sets;
 import java.util.ArrayDeque;
@@ -33,13 +34,16 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.NavigableSet;
 import java.util.Queue;
 import java.util.Set;
 import java.util.TreeSet;
 import java.util.stream.Collectors;
+import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Environment;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
+import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
 import org.apache.beam.model.pipeline.v1.RunnerApi.Pipeline;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.graph.OutputDeduplicator.DeduplicationResult;
@@ -165,6 +169,7 @@ public class GreedyPipelineFuser {
         stages
             .stream()
             .map(stage -> deduplicated.getDeduplicatedStages().getOrDefault(stage, stage))
+            .map(GreedyPipelineFuser::sanitizeDanglingPTransformInputs)
             .collect(Collectors.toSet()),
         Sets.union(
             deduplicated.getIntroducedTransforms(),
@@ -342,6 +347,99 @@ public class GreedyPipelineFuser {
             .collect(Collectors.toSet()));
   }
 
+  private static ExecutableStage sanitizeDanglingPTransformInputs(ExecutableStage stage) {
+    /* Returns a copy of the stage with "dangling" PTransform inputs removed. An input of a
+     * PTransform in the stage is kept only if it is one of:
+     * <ul>
+     *  <li>the explicit input PCollection of the stage, or an explicit output of the stage
+     *  <li>an output of a PTransform within the same stage
+     *  <li>a timer PCollection or a side input PCollection
+     * </ul>
+     * Every other input is dangling and is dropped from its PTransform and from the components.
+     */
+    Set<String> possibleInputs = new HashSet<>();
+    possibleInputs.add(stage.getInputPCollection().getId());
+    possibleInputs.addAll(
+        stage
+            .getOutputPCollections()
+            .stream()
+            .map(PCollectionNode::getId)
+            .collect(Collectors.toSet()));
+    possibleInputs.addAll(
+        stage.getTimers().stream().map(t -> t.collection().getId()).collect(Collectors.toSet()));
+    possibleInputs.addAll(
+        stage
+            .getSideInputs()
+            .stream()
+            .map(s -> s.collection().getId())
+            .collect(Collectors.toSet()));
+    possibleInputs.addAll(
+        stage
+            .getTransforms()
+            .stream()
+            .flatMap(t -> t.getTransform().getOutputsMap().values().stream())
+            .collect(Collectors.toSet()));
+    Set<String> danglingInputs =
+        stage
+            .getTransforms()
+            .stream()
+            .flatMap(t -> t.getTransform().getInputsMap().values().stream())
+            .filter(in -> !possibleInputs.contains(in))
+            .collect(Collectors.toSet());
+    // Rebuild each transform that has dangling inputs; transforms without any are reused as-is.
+    ImmutableList.Builder<PTransformNode> pTransformNodesBuilder = ImmutableList.builder();
+    for (PTransformNode transformNode : stage.getTransforms()) {
+      PTransform transform = transformNode.getTransform();
+      Map<String, String> validInputs =
+          transform
+              .getInputsMap()
+              .entrySet()
+              .stream()
+              .filter(e -> !danglingInputs.contains(e.getValue()))
+              .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
+
+      if (!validInputs.equals(transform.getInputsMap())) {
+        // Some inputs were dangling, so rebuild the PTransform (same id) with only the valid ones.
+        transformNode =
+            PipelineNode.pTransform(
+                transformNode.getId(),
+                transform.toBuilder().clearInputs().putAllInputs(validInputs).build());
+      }
+
+      pTransformNodesBuilder.add(transformNode);
+    }
+    ImmutableList<PTransformNode> pTransformNodes = pTransformNodesBuilder.build();
+    Components.Builder componentBuilder = stage.getComponents().toBuilder();
+    // Replace every transform in the components with exactly this stage's (sanitized) transforms.
+    componentBuilder
+        .clearTransforms()
+        .putAllTransforms(
+            pTransformNodes
+                .stream()
+                .collect(Collectors.toMap(PTransformNode::getId, PTransformNode::getTransform)));
+    Map<String, PCollection> validPCollectionMap =
+        stage
+            .getComponents()
+            .getPcollectionsMap()
+            .entrySet()
+            .stream()
+            .filter(e -> !danglingInputs.contains(e.getKey()))
+            .collect(Collectors.toMap(Entry::getKey, Entry::getValue));
+
+    // Keep every PCollection in the components except the dangling inputs.
+    componentBuilder.clearPcollections().putAllPcollections(validPCollectionMap);
+    // Apart from the components and the transform list, everything is copied from the input stage.
+    return ImmutableExecutableStage.of(
+        componentBuilder.build(),
+        stage.getEnvironment(),
+        stage.getInputPCollection(),
+        stage.getSideInputs(),
+        stage.getUserStates(),
+        stage.getTimers(),
+        pTransformNodes,
+        stage.getOutputPCollections());
+  }
+
   /**
    * A ({@link PCollectionNode}, {@link PTransformNode}) pair representing a single {@link
    * PTransformNode} consuming a single materialized {@link PCollectionNode}.
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
index 04b7107..c47f743 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/graph/GreedyPipelineFuserTest.java
@@ -1144,4 +1144,135 @@ public class GreedyPipelineFuserTest {
                 .withNoOutputs()
                 .withTransforms("goTransform")));
   }
+
+  @Test
+  public void sanitizedTransforms() throws Exception {
+    // Pipeline under test: impulse1 -> read1 -> Flatten <- read2 <- impulse2.
+    PCollection flattenOutput = pc("flatten.out");
+    PCollection read1Output = pc("read1.out");
+    PCollection read2Output = pc("read2.out");
+    PCollection impulse1Output = pc("impulse1.out");
+    PCollection impulse2Output = pc("impulse2.out");
+    PTransform flattenTransform =
+        PTransform.newBuilder()
+            .setUniqueName("Flatten")
+            .putInputs(read1Output.getUniqueName(), read1Output.getUniqueName())
+            .putInputs(read2Output.getUniqueName(), read2Output.getUniqueName())
+            .putOutputs(flattenOutput.getUniqueName(), flattenOutput.getUniqueName())
+            .setSpec(
+                FunctionSpec.newBuilder()
+                    .setUrn(PTransformTranslation.FLATTEN_TRANSFORM_URN)
+                    .setPayload(
+                        WindowIntoPayload.newBuilder()
+                            .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+                            .build()
+                            .toByteString()))
+            .build();
+
+    PTransform read1Transform =
+        PTransform.newBuilder()
+            .setUniqueName("read1")
+            .putInputs(impulse1Output.getUniqueName(), impulse1Output.getUniqueName())
+            .putOutputs(read1Output.getUniqueName(), read1Output.getUniqueName())
+            .setSpec(
+                FunctionSpec.newBuilder()
+                    .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
+                    .setPayload(
+                        WindowIntoPayload.newBuilder()
+                            .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+                            .build()
+                            .toByteString()))
+            .build();
+    PTransform read2Transform =
+        PTransform.newBuilder()
+            .setUniqueName("read2")
+            .putInputs(impulse2Output.getUniqueName(), impulse2Output.getUniqueName())
+            .putOutputs(read2Output.getUniqueName(), read2Output.getUniqueName())
+            .setSpec(
+                FunctionSpec.newBuilder()
+                    .setUrn(PTransformTranslation.PAR_DO_TRANSFORM_URN)
+                    .setPayload(
+                        WindowIntoPayload.newBuilder()
+                            .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+                            .build()
+                            .toByteString()))
+            .build();
+
+    PTransform impulse1Transform =
+        PTransform.newBuilder()
+            .setUniqueName("impulse1")
+            .putOutputs(impulse1Output.getUniqueName(), impulse1Output.getUniqueName())
+            .setSpec(
+                FunctionSpec.newBuilder()
+                    .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)
+                    .setPayload(
+                        WindowIntoPayload.newBuilder()
+                            .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+                            .build()
+                            .toByteString()))
+            .build();
+    PTransform impulse2Transform =
+        PTransform.newBuilder()
+            .setUniqueName("impulse2")
+            .putOutputs(impulse2Output.getUniqueName(), impulse2Output.getUniqueName())
+            .setSpec(
+                FunctionSpec.newBuilder()
+                    .setUrn(PTransformTranslation.IMPULSE_TRANSFORM_URN)
+                    .setPayload(
+                        WindowIntoPayload.newBuilder()
+                            .setWindowFn(SdkFunctionSpec.newBuilder().setEnvironmentId("py"))
+                            .build()
+                            .toByteString()))
+            .build();
+    Pipeline impulse =
+        Pipeline.newBuilder()
+            .addRootTransformIds(impulse1Transform.getUniqueName())
+            .addRootTransformIds(impulse2Transform.getUniqueName())
+            .addRootTransformIds(flattenTransform.getUniqueName())
+            .setComponents(
+                Components.newBuilder()
+                    .putCoders("coder", Coder.newBuilder().build())
+                    .putCoders("windowCoder", Coder.newBuilder().build())
+                    .putWindowingStrategies(
+                        "ws",
+                        WindowingStrategy.newBuilder().setWindowCoderId("windowCoder").build())
+                    .putEnvironments("py", Environment.newBuilder().setUrl("py").build())
+                    .putPcollections(flattenOutput.getUniqueName(), flattenOutput)
+                    .putTransforms(flattenTransform.getUniqueName(), flattenTransform)
+                    .putPcollections(read1Output.getUniqueName(), read1Output)
+                    .putTransforms(read1Transform.getUniqueName(), read1Transform)
+                    .putPcollections(read2Output.getUniqueName(), read2Output)
+                    .putTransforms(read2Transform.getUniqueName(), read2Transform)
+                    .putPcollections(impulse1Output.getUniqueName(), impulse1Output)
+                    .putTransforms(impulse1Transform.getUniqueName(), impulse1Transform)
+                    .putPcollections(impulse2Output.getUniqueName(), impulse2Output)
+                    .putTransforms(impulse2Transform.getUniqueName(), impulse2Transform)
+                    .build())
+            .build();
+    FusedPipeline fused = GreedyPipelineFuser.fuse(impulse);
+    // Flatten consumes both read outputs, so it is expected to appear in both fused stages.
+    assertThat(fused.getRunnerExecutedTransforms(), hasSize(2));
+    assertThat(fused.getFusedStages(), hasSize(2));
+    // Each stage is rooted at one impulse output and holds that branch's read plus Flatten.
+    assertThat(
+        fused.getFusedStages(),
+        containsInAnyOrder(
+            ExecutableStageMatcher.withInput(impulse1Output.getUniqueName())
+                .withTransforms(flattenTransform.getUniqueName(), read1Transform.getUniqueName()),
+            ExecutableStageMatcher.withInput(impulse2Output.getUniqueName())
+                .withTransforms(flattenTransform.getUniqueName(), read2Transform.getUniqueName())));
+    assertThat( // across both stages, each read output remains a Flatten input exactly once
+        fused
+            .getFusedStages()
+            .stream()
+            .flatMap(
+                s ->
+                    s.getComponents()
+                        .getTransformsOrThrow(flattenTransform.getUniqueName())
+                        .getInputsMap()
+                        .values()
+                        .stream())
+            .collect(Collectors.toList()),
+        containsInAnyOrder(read1Output.getUniqueName(), read2Output.getUniqueName()));
+  }
 }
diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index bc4bdc4..e52c9d8 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -56,7 +56,6 @@ import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.fn.data.RemoteGrpcPortRead;
 import org.apache.beam.sdk.fn.data.RemoteGrpcPortWrite;
 import org.apache.beam.sdk.state.TimeDomain;
-import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.WindowedValue;
@@ -146,6 +145,7 @@ public class ProcessBundleDescriptors {
     if (stateEndpoint != null) {
       bundleDescriptorBuilder.setStateApiServiceDescriptor(stateEndpoint);
     }
+
     bundleDescriptorBuilder
         .putAllCoders(components.getCodersMap())
         .putAllEnvironments(components.getEnvironmentsMap())
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
index 593390e..96ae9de 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/RemoteExecutionTest.java
@@ -20,6 +20,7 @@ package org.apache.beam.runners.fnexecution.control;
 
 import static com.google.common.base.Preconditions.checkState;
 import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.fail;
 
@@ -47,6 +48,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.ThreadFactory;
+import java.util.function.Function;
 import org.apache.beam.fn.harness.FnHarness;
 import org.apache.beam.model.fnexecution.v1.BeamFnApi.Target;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
@@ -92,6 +94,7 @@ import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
 import org.apache.beam.sdk.testing.ResetDateTimeProvider;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -103,6 +106,7 @@ import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionList;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.ByteString;
 import org.hamcrest.collection.IsEmptyIterable;
@@ -709,6 +713,104 @@ public class RemoteExecutionTest implements Serializable {
             timerStructuralValue(WindowedValue.valueInGlobalWindow(timerBytes("Z", 22L)))));
   }
 
+  @Test
+  public void testExecutionWithMultipleStages() throws Exception {
+    Pipeline p = Pipeline.create();
+
+    Function<String, PCollection<String>> pCollectionGenerator =
+        suffix ->
+            p.apply("impulse" + suffix, Impulse.create())
+                .apply(
+                    "create" + suffix,
+                    ParDo.of(
+                        new DoFn<byte[], String>() {
+                          @ProcessElement
+                          public void process(ProcessContext c) {
+                            try {
+                              c.output(
+                                  CoderUtils.decodeFromByteArray(
+                                      StringUtf8Coder.of(), c.element()));
+                            } catch (CoderException e) {
+                              throw new RuntimeException(e);
+                            }
+                          }
+                        }))
+                .setCoder(StringUtf8Coder.of())
+                .apply(
+                    ParDo.of(
+                        new DoFn<String, String>() {
+                          @ProcessElement
+                          public void processElement(ProcessContext c) {
+                            c.output("stream" + suffix + c.element());
+                          }
+                        }));
+    PCollection<String> input1 = pCollectionGenerator.apply("1");
+    PCollection<String> input2 = pCollectionGenerator.apply("2");
+    // Merge both branches with a Flatten, key every element, then group by key.
+    PCollection<String> outputMerged =
+        PCollectionList.of(input1).and(input2).apply(Flatten.pCollections());
+    outputMerged
+        .apply(
+            "createKV",
+            ParDo.of(
+                new DoFn<String, KV<String, String>>() {
+                  @ProcessElement
+                  public void process(ProcessContext c) {
+                    c.output(KV.of(c.element(), ""));
+                  }
+                }))
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), StringUtf8Coder.of()))
+        .apply("gbk", GroupByKey.create());
+    // Fuse the translated pipeline; this test expects exactly two executable stages.
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineProto);
+    Set<ExecutableStage> stages = fused.getFusedStages();
+
+    assertThat(stages.size(), equalTo(2));
+    // Shared by the output receivers of every stage; synchronized (receivers may run off-thread).
+    List<WindowedValue<?>> outputValues = Collections.synchronizedList(new ArrayList<>());
+    // Run each stage in its own bundle, feeding the encoded string "X" to its remote input.
+    for (ExecutableStage stage : stages) {
+      ExecutableProcessBundleDescriptor descriptor =
+          ProcessBundleDescriptors.fromExecutableStage(
+              stage.toString(),
+              stage,
+              dataServer.getApiServiceDescriptor(),
+              stateServer.getApiServiceDescriptor());
+
+      BundleProcessor processor =
+          controlClient.getProcessor(
+              descriptor.getProcessBundleDescriptor(),
+              descriptor.getRemoteInputDestinations(),
+              stateDelegator);
+      Map<Target, Coder<WindowedValue<?>>> outputTargets = descriptor.getOutputTargetCoders();
+      Map<Target, RemoteOutputReceiver<?>> outputReceivers = new HashMap<>();
+      for (Entry<Target, Coder<WindowedValue<?>>> targetCoder : outputTargets.entrySet()) {
+        outputReceivers.putIfAbsent(
+            targetCoder.getKey(),
+            RemoteOutputReceiver.of(targetCoder.getValue(), outputValues::add));
+      }
+
+      try (ActiveBundle bundle =
+          processor.newBundle(
+              outputReceivers,
+              StateRequestHandler.unsupported(),
+              BundleProgressHandler.unsupported())) {
+        bundle
+            .getInputReceivers()
+            .get(stage.getInputPCollection().getId())
+            .accept(
+                WindowedValue.valueInGlobalWindow(
+                    CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "X")));
+      }
+    }
+    assertThat(
+        outputValues,
+        containsInAnyOrder(
+            WindowedValue.valueInGlobalWindow(kvBytes("stream1X", "")),
+            WindowedValue.valueInGlobalWindow(kvBytes("stream2X", ""))));
+  }
+
   private KV<byte[], byte[]> kvBytes(String key, long value) throws CoderException {
     return KV.of(
         CoderUtils.encodeToByteArray(StringUtf8Coder.of(), key),