You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by bo...@apache.org on 2020/10/21 18:36:08 UTC

[beam] branch master updated: Lengthprefix any input coder for an ProcessBundleDescriptor.

This is an automated email from the ASF dual-hosted git repository.

boyuanz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bc955de  Lengthprefix any input coder for an ProcessBundleDescriptor.
     new 5419c3b  Merge pull request #13120 from [BEAM-10940] Lengthprefix any input coder for an ProcessBundleDescriptor.
bc955de is described below

commit bc955ded10e0a054d437adf5c7117004de978d46
Author: Boyuan Zhang <bo...@google.com>
AuthorDate: Wed Oct 14 13:47:55 2020 -0700

    Lengthprefix any input coder for an ProcessBundleDescriptor.
---
 .../control/ProcessBundleDescriptors.java          |  36 +++-----
 .../control/ProcessBundleDescriptorsTest.java      | 101 +++++++++++++++++++++
 2 files changed, 113 insertions(+), 24 deletions(-)

diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
index e76c130..ac3b882 100644
--- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
+++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptors.java
@@ -37,7 +37,6 @@ import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
 import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.WireCoderSetting;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
 import org.apache.beam.model.pipeline.v1.RunnerApi.PTransform;
-import org.apache.beam.runners.core.construction.ModelCoders;
 import org.apache.beam.runners.core.construction.RehydratedComponents;
 import org.apache.beam.runners.core.construction.Timer;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
@@ -59,7 +58,6 @@ import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.InvalidProtocolBufferException;
-import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableTable;
@@ -141,9 +139,7 @@ public class ProcessBundleDescriptors {
 
     Map<String, Map<String, TimerSpec>> timerSpecs = forTimerSpecs(stage, components);
 
-    if (bagUserStateSpecs.size() > 0 || timerSpecs.size() > 0) {
-      lengthPrefixKeyCoder(stage.getInputPCollection().getId(), components);
-    }
+    lengthPrefixAnyInputCoder(stage.getInputPCollection().getId(), components);
 
     // Copy data from components to ProcessBundleDescriptor.
     ProcessBundleDescriptor.Builder bundleDescriptorBuilder =
@@ -174,26 +170,18 @@ public class ProcessBundleDescriptors {
   }
 
   /**
-   * Patches the input coder of a stateful transform to ensure that the byte representation of a key
-   * used to partition the input element at the Runner, matches the key byte representation received
-   * for state requests and timers from the SDK Harness. Stateful transforms always have a KvCoder
-   * as input.
+   * Patches the input coder of the transform to ensure that the byte representation of the input
+   * used at the Runner matches the byte representation received from the SDK Harness.
    */
-  private static void lengthPrefixKeyCoder(
-      String inputColId, Components.Builder componentsBuilder) {
-    RunnerApi.PCollection pcollection = componentsBuilder.getPcollectionsOrThrow(inputColId);
-    RunnerApi.Coder kvCoder = componentsBuilder.getCodersOrThrow(pcollection.getCoderId());
-    Preconditions.checkState(
-        ModelCoders.KV_CODER_URN.equals(kvCoder.getSpec().getUrn()),
-        "Stateful executable stages must use a KV coder, but is: %s",
-        kvCoder.getSpec().getUrn());
-    String keyCoderId = ModelCoders.getKvCoderComponents(kvCoder).keyCoderId();
-    // Retain the original coder, but wrap in LengthPrefixCoder
-    String newKeyCoderId =
-        LengthPrefixUnknownCoders.addLengthPrefixedCoder(keyCoderId, componentsBuilder, false);
-    // Replace old key coder with LengthPrefixCoder<old_key_coder>
-    kvCoder = kvCoder.toBuilder().setComponentCoderIds(0, newKeyCoderId).build();
-    componentsBuilder.putCoders(pcollection.getCoderId(), kvCoder);
+  private static void lengthPrefixAnyInputCoder(
+      String inputPCollectionId, Components.Builder componentsBuilder) {
+    RunnerApi.PCollection pcollection =
+        componentsBuilder.getPcollectionsOrThrow(inputPCollectionId);
+    String newInputCoderId =
+        LengthPrefixUnknownCoders.addLengthPrefixedCoder(
+            pcollection.getCoderId(), componentsBuilder, false);
+    componentsBuilder.putPcollections(
+        inputPCollectionId, pcollection.toBuilder().setCoderId(newInputCoderId).build());
   }
 
   private static Map<String, Coder<WindowedValue<?>>> addStageOutputs(
diff --git a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
index 98fe899..9337c63 100644
--- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
+++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/control/ProcessBundleDescriptorsTest.java
@@ -29,11 +29,14 @@ import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.CoderTranslation;
 import org.apache.beam.runners.core.construction.ModelCoderRegistrar;
 import org.apache.beam.runners.core.construction.ModelCoders;
+import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.core.construction.graph.FusedPipeline;
 import org.apache.beam.runners.core.construction.graph.GreedyPipelineFuser;
 import org.apache.beam.runners.core.construction.graph.PipelineNode;
+import org.apache.beam.runners.core.construction.graph.ProtoOverrides;
+import org.apache.beam.runners.core.construction.graph.SplittableParDoExpander;
 import org.apache.beam.runners.core.construction.graph.TimerReference;
 import org.apache.beam.sdk.Pipeline;
 import org.apache.beam.sdk.coders.Coder;
@@ -48,9 +51,12 @@ import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
 import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.DoFn.ProcessContext;
+import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.Impulse;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -151,6 +157,99 @@ public class ProcessBundleDescriptorsTest implements Serializable {
     ensureLengthPrefixed(timerKeyCoder, originalKeyCoder, pbsCoderMap);
   }
 
+  @Test
+  public void testLengthPrefixingOfInputCoderExecutableStage() throws Exception {
+    Pipeline p = Pipeline.create();
+    Coder<Void> voidCoder = VoidCoder.of();
+    assertThat(ModelCoderRegistrar.isKnownCoder(voidCoder), is(false));
+    p.apply("impulse", Impulse.create())
+        .apply(
+            ParDo.of(
+                new DoFn<byte[], Void>() {
+                  @ProcessElement
+                  public void process(ProcessContext ctxt) {}
+                }))
+        .setCoder(voidCoder)
+        .apply(
+            ParDo.of(
+                new DoFn<Void, Void>() {
+                  @ProcessElement
+                  public void processElement(
+                      ProcessContext context, RestrictionTracker<Void, Void> tracker) {}
+
+                  @GetInitialRestriction
+                  public Void getInitialRestriction() {
+                    return null;
+                  }
+
+                  @NewTracker
+                  public SomeTracker newTracker(@Restriction Void restriction) {
+                    return null;
+                  }
+                }))
+        .setCoder(voidCoder);
+    RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
+    RunnerApi.Pipeline pipelineWithSdfExpanded =
+        ProtoOverrides.updateTransform(
+            PTransformTranslation.PAR_DO_TRANSFORM_URN,
+            pipelineProto,
+            SplittableParDoExpander.createSizedReplacement());
+    FusedPipeline fused = GreedyPipelineFuser.fuse(pipelineWithSdfExpanded);
+    Optional<ExecutableStage> optionalStage =
+        Iterables.tryFind(
+            fused.getFusedStages(),
+            (ExecutableStage stage) ->
+                stage.getTransforms().stream()
+                    .anyMatch(
+                        transform ->
+                            transform
+                                .getTransform()
+                                .getSpec()
+                                .getUrn()
+                                .equals(
+                                    PTransformTranslation
+                                        .SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN)));
+    checkState(
+        optionalStage.isPresent(),
+        "Expected a stage with SPLITTABLE_PROCESS_SIZED_ELEMENTS_AND_RESTRICTIONS_URN.");
+
+    ExecutableStage stage = optionalStage.get();
+    PipelineNode.PCollectionNode inputPCollection = stage.getInputPCollection();
+    Map<String, RunnerApi.Coder> stageCoderMap = stage.getComponents().getCodersMap();
+    RunnerApi.Coder originalMainInputCoder =
+        stageCoderMap.get(inputPCollection.getPCollection().getCoderId());
+
+    BeamFnApi.ProcessBundleDescriptor pbd =
+        ProcessBundleDescriptors.fromExecutableStage(
+                "test_stage", stage, Endpoints.ApiServiceDescriptor.getDefaultInstance())
+            .getProcessBundleDescriptor();
+    Map<String, RunnerApi.Coder> pbsCoderMap = pbd.getCodersMap();
+
+    RunnerApi.Coder pbsMainInputCoder =
+        pbsCoderMap.get(pbd.getPcollectionsOrThrow(inputPCollection.getId()).getCoderId());
+
+    RunnerApi.Coder kvCoder =
+        pbsCoderMap.get(ModelCoders.getKvCoderComponents(pbsMainInputCoder).keyCoderId());
+    RunnerApi.Coder keyCoder =
+        pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).keyCoderId());
+    RunnerApi.Coder valueKvCoder =
+        pbsCoderMap.get(ModelCoders.getKvCoderComponents(kvCoder).valueCoderId());
+    RunnerApi.Coder valueCoder =
+        pbsCoderMap.get(ModelCoders.getKvCoderComponents(valueKvCoder).keyCoderId());
+
+    RunnerApi.Coder originalKvCoder =
+        stageCoderMap.get(ModelCoders.getKvCoderComponents(originalMainInputCoder).keyCoderId());
+    RunnerApi.Coder originalKeyCoder =
+        stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).keyCoderId());
+    RunnerApi.Coder originalvalueKvCoder =
+        stageCoderMap.get(ModelCoders.getKvCoderComponents(originalKvCoder).valueCoderId());
+    RunnerApi.Coder originalvalueCoder =
+        stageCoderMap.get(ModelCoders.getKvCoderComponents(originalvalueKvCoder).keyCoderId());
+
+    ensureLengthPrefixed(keyCoder, originalKeyCoder, pbsCoderMap);
+    ensureLengthPrefixed(valueCoder, originalvalueCoder, pbsCoderMap);
+  }
+
   private static void ensureLengthPrefixed(
       RunnerApi.Coder coder,
       RunnerApi.Coder originalCoder,
@@ -160,4 +259,6 @@ public class ProcessBundleDescriptorsTest implements Serializable {
     String lengthPrefixedWrappedCoderId = coder.getComponentCoderIds(0);
     assertThat(pbsCoderMap.get(lengthPrefixedWrappedCoderId), is(originalCoder));
   }
+
+  private abstract static class SomeTracker extends RestrictionTracker<Void, Void> {}
 }