You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by mx...@apache.org on 2019/01/10 21:17:50 UTC

[beam] 01/02: [BEAM-6326] Fix portable stateful processing with side input

This is an automated email from the ASF dual-hosted git repository.

mxm pushed a commit to branch release-2.10.0
in repository https://gitbox.apache.org/repos/asf/beam.git

commit ae38b89c9c4f4f42ca7a896c7fbb1b2fed2c7abb
Author: Maximilian Michels <mx...@apache.org>
AuthorDate: Wed Jan 9 19:55:31 2019 -0500

    [BEAM-6326] Fix portable stateful processing with side input
    
    This adds support in the portable Flink Runner for stateful ParDo processing
    when side inputs are present. The processing would previously fail because
    operator state was not partitioned by key when the operator had been connected
    to the operator producing the side input.
    
    This is tested by ParDoTest#testBagStateSideInput.
---
 .../FlinkStreamingPortablePipelineTranslator.java  | 41 +++++++++++++++++++---
 .../state/FlinkKeyGroupStateInternals.java         |  6 ++--
 2 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
index 838baa8..51d467b 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java
@@ -95,6 +95,7 @@ import org.apache.flink.streaming.api.datastream.KeyedStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.OutputTag;
 
@@ -583,9 +584,11 @@ public class FlinkStreamingPortablePipelineTranslator
     final Coder<WindowedValue<InputT>> windowedInputCoder =
         instantiateCoder(inputPCollectionId, components);
 
+    final boolean stateful =
+        stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0;
     Coder keyCoder = null;
     KeySelector<WindowedValue<InputT>, ?> keySelector = null;
-    if (stagePayload.getUserStatesCount() > 0 || stagePayload.getTimersCount() > 0) {
+    if (stateful) {
      // Stateful stages are only allowed for KV input
       Coder valueCoder =
           ((WindowedValue.FullWindowedValueCoder) windowedInputCoder).getValueCoder();
@@ -632,10 +635,38 @@ public class FlinkStreamingPortablePipelineTranslator
     if (transformedSideInputs.unionTagToView.isEmpty()) {
       outputStream = inputDataStream.transform(operatorName, outputTypeInformation, doFnOperator);
     } else {
-      outputStream =
-          inputDataStream
-              .connect(transformedSideInputs.unionedSideInputs.broadcast())
-              .transform(operatorName, outputTypeInformation, doFnOperator);
+      DataStream<RawUnionValue> sideInputStream =
+          transformedSideInputs.unionedSideInputs.broadcast();
+      if (stateful) {
+        // We have to manually construct the two-input transform because we're not
+        // allowed to have only one input keyed, normally. Since Flink 1.5.0 it's
+        // possible to use the Broadcast State Pattern which provides a more elegant
+        // way to process keyed main input with broadcast state, but it's not feasible
+        // here because it breaks the DoFnOperator abstraction.
+        TwoInputTransformation<WindowedValue<KV<?, InputT>>, RawUnionValue, WindowedValue<OutputT>>
+            rawFlinkTransform =
+                new TwoInputTransformation(
+                    inputDataStream.getTransformation(),
+                    sideInputStream.getTransformation(),
+                    transform.getUniqueName(),
+                    doFnOperator,
+                    outputTypeInformation,
+                    inputDataStream.getParallelism());
+
+        rawFlinkTransform.setStateKeyType(((KeyedStream) inputDataStream).getKeyType());
+        rawFlinkTransform.setStateKeySelectors(
+            ((KeyedStream) inputDataStream).getKeySelector(), null);
+
+        outputStream =
+            new SingleOutputStreamOperator(
+                inputDataStream.getExecutionEnvironment(),
+                rawFlinkTransform) {}; // we have to cheat around the ctor being protected
+      } else {
+        outputStream =
+            inputDataStream
+                .connect(sideInputStream)
+                .transform(operatorName, outputTypeInformation, doFnOperator);
+      }
     }
 
     if (mainOutputTag != null) {
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
index 0ca767f..01026db 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java
@@ -71,8 +71,10 @@ public class FlinkKeyGroupStateInternals<K> implements StateInternals {
   private final Map<String, Tuple2<Coder<?>, Map<String, ?>>>[] stateTables;
 
   public FlinkKeyGroupStateInternals(Coder<K> keyCoder, KeyedStateBackend keyedStateBackend) {
-    this.keyCoder = keyCoder;
-    this.keyedStateBackend = keyedStateBackend;
+    this.keyCoder = Preconditions.checkNotNull(keyCoder, "Coder for key must be provided.");
+    this.keyedStateBackend =
+        Preconditions.checkNotNull(
+            keyedStateBackend, "KeyedStateBackend must not be null. Missing keyBy call?");
     this.localKeyGroupRange = keyedStateBackend.getKeyGroupRange();
     // find the starting index of the local key-group range
     int startIdx = Integer.MAX_VALUE;