You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ie...@apache.org on 2019/03/29 14:41:14 UTC

[beam] branch master updated: [BEAM-6929] Prevent NullPointerException in Flink's CombiningState

This is an automated email from the ASF dual-hosted git repository.

iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ee97c8  [BEAM-6929] Prevent NullPointerException in Flink's CombiningState
     new fb33e22  Merge pull request #8162: [BEAM-6929] Prevent NullPointerException in Flink's CombiningState
7ee97c8 is described below

commit 7ee97c80291385cc5b688c00897de8b9ab9f89d8
Author: Maximilian Michels <mx...@apache.org>
AuthorDate: Thu Mar 28 20:48:24 2019 +0100

    [BEAM-6929] Prevent NullPointerException in Flink's CombiningState
    
    When the accumulator is retrieved before it has been initialized, Flink's
    StateInternals return null. This can lead to a NullPointerException for
    session windows with allowed lateness, where the state is cleared and the
    accumulator is retrieved afterwards. getAccum() now falls back to a newly
    created accumulator from the CombineFn when no value is stored.
---
 .../beam/runners/core/StateInternalsTest.java      | 92 +++++++++++++++++++++-
 .../state/FlinkBroadcastStateInternals.java        | 12 ++-
 .../streaming/state/FlinkStateInternals.java       | 20 +++--
 .../streaming/FlinkSplitStateInternalsTest.java    |  8 ++
 .../spark/stateful/SparkStateInternals.java        |  3 +-
 5 files changed, 121 insertions(+), 14 deletions(-)

diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
index f4093fa..fced746 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
@@ -17,14 +17,15 @@
  */
 package org.apache.beam.runners.core;
 
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItems;
 import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 
 import java.io.IOException;
@@ -46,6 +47,7 @@ import org.apache.beam.sdk.state.ReadableState;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.ValueState;
 import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.Sum;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -68,6 +70,10 @@ public abstract class StateInternalsTest {
       StateTags.value("stringValue", StringUtf8Coder.of());
   private static final StateTag<CombiningState<Integer, int[], Integer>> SUM_INTEGER_ADDR =
       StateTags.combiningValueFromInputInternal("sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+  private static final StateTag<CombiningState<Integer, Integer, Integer>>
+      SUM_INTEGER_CONTEXT_ADDR =
+          StateTags.combiningValueWithContext(
+              "sumIntegerWithContext", VarIntCoder.of(), new SummingContextFn());
   private static final StateTag<BagState<String>> STRING_BAG_ADDR =
       StateTags.bag("stringBag", StringUtf8Coder.of());
   private static final StateTag<SetState<String>> STRING_SET_ADDR =
@@ -427,6 +433,9 @@ public abstract class StateInternalsTest {
     CombiningState<Integer, int[], Integer> value1 = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
     CombiningState<Integer, int[], Integer> value2 = underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
 
+    assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+
     value1.add(5);
     value2.add(10);
     value1.add(6);
@@ -447,6 +456,59 @@ public abstract class StateInternalsTest {
     CombiningState<Integer, int[], Integer> value2 = underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
     CombiningState<Integer, int[], Integer> value3 = underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
 
+    assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value3.getAccum(), Matchers.is(notNullValue()));
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+    // Merging clears the old values and updates the result value.
+    assertThat(value1.read(), equalTo(0));
+    assertThat(value2.read(), equalTo(0));
+    assertThat(value3.read(), equalTo(21));
+  }
+
+  @Test
+  public void testMergeCombiningWithContextValueIntoSource() throws Exception {
+    CombiningState<Integer, Integer, Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_CONTEXT_ADDR);
+    CombiningState<Integer, Integer, Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_CONTEXT_ADDR);
+
+    assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+
+    value1.add(5);
+    value2.add(10);
+    value1.add(6);
+
+    assertThat(value1.read(), equalTo(11));
+    assertThat(value2.read(), equalTo(10));
+
+    // Merging clears the old values and updates the result value.
+    StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+    assertThat(value1.read(), equalTo(21));
+    assertThat(value2.read(), equalTo(0));
+  }
+
+  @Test
+  public void testMergeCombiningWithContextValueIntoNewNamespace() throws Exception {
+    CombiningState<Integer, Integer, Integer> value1 =
+        underTest.state(NAMESPACE_1, SUM_INTEGER_CONTEXT_ADDR);
+    CombiningState<Integer, Integer, Integer> value2 =
+        underTest.state(NAMESPACE_2, SUM_INTEGER_CONTEXT_ADDR);
+    CombiningState<Integer, Integer, Integer> value3 =
+        underTest.state(NAMESPACE_3, SUM_INTEGER_CONTEXT_ADDR);
+
+    assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+    assertThat(value3.getAccum(), Matchers.is(notNullValue()));
+
     value1.add(5);
     value2.add(10);
     value1.add(6);
@@ -619,4 +681,32 @@ public abstract class StateInternalsTest {
       return super.hashCode();
     }
   }
+
+  private static class SummingContextFn
+      extends CombineWithContext.CombineFnWithContext<Integer, Integer, Integer> {
+
+    @Override
+    public Integer createAccumulator(CombineWithContext.Context c) {
+      return 0;
+    }
+
+    @Override
+    public Integer addInput(Integer accumulator, Integer input, CombineWithContext.Context c) {
+      return accumulator + input;
+    }
+
+    @Override
+    public Integer mergeAccumulators(Iterable<Integer> accumulators, CombineWithContext.Context c) {
+      int sum = createAccumulator(c);
+      for (Integer accumulator : accumulators) {
+        sum += accumulator;
+      }
+      return sum;
+    }
+
+    @Override
+    public Integer extractOutput(Integer accumulator, CombineWithContext.Context c) {
+      return accumulator;
+    }
+  }
 }
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
index 9dca8b2..69ff465 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
@@ -465,7 +465,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
 
     @Override
     public AccumT getAccum() {
-      return readInternal();
+      AccumT accum = readInternal();
+      return accum != null ? accum : combineFn.createAccumulator();
     }
 
     @Override
@@ -589,7 +590,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
     @Override
     public AccumT getAccum() {
       try {
-        return readInternal();
+        AccumT accum = readInternal();
+        return accum != null ? accum : combineFn.createAccumulator();
       } catch (Exception e) {
         throw new RuntimeException("Error reading state.", e);
       }
@@ -724,7 +726,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
     @Override
     public AccumT getAccum() {
       try {
-        return readInternal();
+        AccumT accum = readInternal();
+        return accum != null ? accum : combineFn.createAccumulator(context);
       } catch (Exception e) {
         throw new RuntimeException("Error reading state.", e);
       }
@@ -739,6 +742,9 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
     public OutputT read() {
       try {
         AccumT accum = readInternal();
+        if (accum == null) {
+          accum = combineFn.createAccumulator(context);
+        }
         return combineFn.extractOutput(accum, context);
       } catch (Exception e) {
         throw new RuntimeException("Error reading state.", e);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
index 3877fc5..4407d33 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
@@ -503,10 +503,12 @@ public class FlinkStateInternals<K> implements StateInternals {
     @Override
     public AccumT getAccum() {
       try {
-        return flinkStateBackend
-            .getPartitionedState(
-                namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
-            .value();
+        AccumT accum =
+            flinkStateBackend
+                .getPartitionedState(
+                    namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
+                .value();
+        return accum != null ? accum : combineFn.createAccumulator();
       } catch (Exception e) {
         throw new RuntimeException("Error reading state.", e);
       }
@@ -665,10 +667,12 @@ public class FlinkStateInternals<K> implements StateInternals {
     @Override
     public AccumT getAccum() {
       try {
-        return flinkStateBackend
-            .getPartitionedState(
-                namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
-            .value();
+        AccumT accum =
+            flinkStateBackend
+                .getPartitionedState(
+                    namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
+                .value();
+        return accum != null ? accum : combineFn.createAccumulator(context);
       } catch (Exception e) {
         throw new RuntimeException("Error reading state.", e);
       }
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
index e146095..eb27335 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -98,6 +98,14 @@ public class FlinkSplitStateInternalsTest extends StateInternalsTest {
 
   @Override
   @Ignore
+  public void testMergeCombiningWithContextValueIntoSource() {}
+
+  @Override
+  @Ignore
+  public void testMergeCombiningWithContextValueIntoNewNamespace() {}
+
+  @Override
+  @Ignore
   public void testWatermarkEarliestState() {}
 
   @Override
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index 41dbbfa..3ed403f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -297,8 +297,7 @@ class SparkStateInternals<K> implements StateInternals {
 
     @Override
     public void add(InputT input) {
-      AccumT accum = getAccum();
-      combineFn.addInput(accum, input);
+      AccumT accum = combineFn.addInput(getAccum(), input);
       writeValue(accum);
     }