You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ie...@apache.org on 2019/03/29 14:41:14 UTC
[beam] branch master updated: [BEAM-6929] Prevent
NullPointerException in Flink's CombiningState
This is an automated email from the ASF dual-hosted git repository.
iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 7ee97c8 [BEAM-6929] Prevent NullPointerException in Flink's CombiningState
new fb33e22 Merge pull request #8162: [BEAM-6929] Prevent NullPointerException in Flink's CombiningState
7ee97c8 is described below
commit 7ee97c80291385cc5b688c00897de8b9ab9f89d8
Author: Maximilian Michels <mx...@apache.org>
AuthorDate: Thu Mar 28 20:48:24 2019 +0100
[BEAM-6929] Prevent NullPointerException in Flink's CombiningState
When the accumulator is retrieved and it has not been initialized yet,
Flink's StateInternals return null. This can lead to a NullPointerException for
Session windows with allowed lateness, where the state is cleared and the
accumulator is retrieved afterwards.
---
.../beam/runners/core/StateInternalsTest.java | 92 +++++++++++++++++++++-
.../state/FlinkBroadcastStateInternals.java | 12 ++-
.../streaming/state/FlinkStateInternals.java | 20 +++--
.../streaming/FlinkSplitStateInternalsTest.java | 8 ++
.../spark/stateful/SparkStateInternals.java | 3 +-
5 files changed, 121 insertions(+), 14 deletions(-)
diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
index f4093fa..fced746 100644
--- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
+++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java
@@ -17,14 +17,15 @@
*/
package org.apache.beam.runners.core;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems;
import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.notNullValue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import java.io.IOException;
@@ -46,6 +47,7 @@ import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -68,6 +70,10 @@ public abstract class StateInternalsTest {
StateTags.value("stringValue", StringUtf8Coder.of());
private static final StateTag<CombiningState<Integer, int[], Integer>> SUM_INTEGER_ADDR =
StateTags.combiningValueFromInputInternal("sumInteger", VarIntCoder.of(), Sum.ofIntegers());
+ private static final StateTag<CombiningState<Integer, Integer, Integer>>
+ SUM_INTEGER_CONTEXT_ADDR =
+ StateTags.combiningValueWithContext(
+ "sumIntegerWithContext", VarIntCoder.of(), new SummingContextFn());
private static final StateTag<BagState<String>> STRING_BAG_ADDR =
StateTags.bag("stringBag", StringUtf8Coder.of());
private static final StateTag<SetState<String>> STRING_SET_ADDR =
@@ -427,6 +433,9 @@ public abstract class StateInternalsTest {
CombiningState<Integer, int[], Integer> value1 = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR);
CombiningState<Integer, int[], Integer> value2 = underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
+ assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+
value1.add(5);
value2.add(10);
value1.add(6);
@@ -447,6 +456,59 @@ public abstract class StateInternalsTest {
CombiningState<Integer, int[], Integer> value2 = underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR);
CombiningState<Integer, int[], Integer> value3 = underTest.state(NAMESPACE_3, SUM_INTEGER_ADDR);
+ assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value3.getAccum(), Matchers.is(notNullValue()));
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
+
+ // Merging clears the old values and updates the result value.
+ assertThat(value1.read(), equalTo(0));
+ assertThat(value2.read(), equalTo(0));
+ assertThat(value3.read(), equalTo(21));
+ }
+
+ @Test
+ public void testMergeCombiningWithContextValueIntoSource() throws Exception {
+ CombiningState<Integer, Integer, Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_CONTEXT_ADDR);
+ CombiningState<Integer, Integer, Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_CONTEXT_ADDR);
+
+ assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+
+ value1.add(5);
+ value2.add(10);
+ value1.add(6);
+
+ assertThat(value1.read(), equalTo(11));
+ assertThat(value2.read(), equalTo(10));
+
+ // Merging clears the old values and updates the result value.
+ StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
+
+ assertThat(value1.read(), equalTo(21));
+ assertThat(value2.read(), equalTo(0));
+ }
+
+ @Test
+ public void testMergeCombiningWithContextValueIntoNewNamespace() throws Exception {
+ CombiningState<Integer, Integer, Integer> value1 =
+ underTest.state(NAMESPACE_1, SUM_INTEGER_CONTEXT_ADDR);
+ CombiningState<Integer, Integer, Integer> value2 =
+ underTest.state(NAMESPACE_2, SUM_INTEGER_CONTEXT_ADDR);
+ CombiningState<Integer, Integer, Integer> value3 =
+ underTest.state(NAMESPACE_3, SUM_INTEGER_CONTEXT_ADDR);
+
+ assertThat(value1.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value2.getAccum(), Matchers.is(notNullValue()));
+ assertThat(value3.getAccum(), Matchers.is(notNullValue()));
+
value1.add(5);
value2.add(10);
value1.add(6);
@@ -619,4 +681,32 @@ public abstract class StateInternalsTest {
return super.hashCode();
}
}
+
+ private static class SummingContextFn
+ extends CombineWithContext.CombineFnWithContext<Integer, Integer, Integer> {
+
+ @Override
+ public Integer createAccumulator(CombineWithContext.Context c) {
+ return 0;
+ }
+
+ @Override
+ public Integer addInput(Integer accumulator, Integer input, CombineWithContext.Context c) {
+ return accumulator + input;
+ }
+
+ @Override
+ public Integer mergeAccumulators(Iterable<Integer> accumulators, CombineWithContext.Context c) {
+ int sum = createAccumulator(c);
+ for (Integer accumulator : accumulators) {
+ sum += accumulator;
+ }
+ return sum;
+ }
+
+ @Override
+ public Integer extractOutput(Integer accumulator, CombineWithContext.Context c) {
+ return accumulator;
+ }
+ }
}
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
index 9dca8b2..69ff465 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java
@@ -465,7 +465,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
@Override
public AccumT getAccum() {
- return readInternal();
+ AccumT accum = readInternal();
+ return accum != null ? accum : combineFn.createAccumulator();
}
@Override
@@ -589,7 +590,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
@Override
public AccumT getAccum() {
try {
- return readInternal();
+ AccumT accum = readInternal();
+ return accum != null ? accum : combineFn.createAccumulator();
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
}
@@ -724,7 +726,8 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
@Override
public AccumT getAccum() {
try {
- return readInternal();
+ AccumT accum = readInternal();
+ return accum != null ? accum : combineFn.createAccumulator(context);
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
}
@@ -739,6 +742,9 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals {
public OutputT read() {
try {
AccumT accum = readInternal();
+ if (accum == null) {
+ accum = combineFn.createAccumulator(context);
+ }
return combineFn.extractOutput(accum, context);
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
index 3877fc5..4407d33 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
@@ -503,10 +503,12 @@ public class FlinkStateInternals<K> implements StateInternals {
@Override
public AccumT getAccum() {
try {
- return flinkStateBackend
- .getPartitionedState(
- namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
- .value();
+ AccumT accum =
+ flinkStateBackend
+ .getPartitionedState(
+ namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
+ .value();
+ return accum != null ? accum : combineFn.createAccumulator();
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
}
@@ -665,10 +667,12 @@ public class FlinkStateInternals<K> implements StateInternals {
@Override
public AccumT getAccum() {
try {
- return flinkStateBackend
- .getPartitionedState(
- namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
- .value();
+ AccumT accum =
+ flinkStateBackend
+ .getPartitionedState(
+ namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor)
+ .value();
+ return accum != null ? accum : combineFn.createAccumulator(context);
} catch (Exception e) {
throw new RuntimeException("Error reading state.", e);
}
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
index e146095..eb27335 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkSplitStateInternalsTest.java
@@ -98,6 +98,14 @@ public class FlinkSplitStateInternalsTest extends StateInternalsTest {
@Override
@Ignore
+ public void testMergeCombiningWithContextValueIntoSource() {}
+
+ @Override
+ @Ignore
+ public void testMergeCombiningWithContextValueIntoNewNamespace() {}
+
+ @Override
+ @Ignore
public void testWatermarkEarliestState() {}
@Override
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
index 41dbbfa..3ed403f 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java
@@ -297,8 +297,7 @@ class SparkStateInternals<K> implements StateInternals {
@Override
public void add(InputT input) {
- AccumT accum = getAccum();
- combineFn.addInput(accum, input);
+ AccumT accum = combineFn.addInput(getAccum(), input);
writeValue(accum);
}