You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by gy...@apache.org on 2015/09/14 18:33:46 UTC

flink git commit: [FLINK-2664] [streaming] Allow partitioned state removal

Repository: flink
Updated Branches:
  refs/heads/master ce68cbd91 -> 8a75025ab


[FLINK-2664] [streaming] Allow partitioned state removal

Closes #1126


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8a75025a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8a75025a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8a75025a

Branch: refs/heads/master
Commit: 8a75025abb229c8ebfa737525ed76c042379d7fd
Parents: ce68cbd
Author: Gyula Fora <gy...@apache.org>
Authored: Mon Sep 14 16:54:00 2015 +0200
Committer: Gyula Fora <gy...@apache.org>
Committed: Mon Sep 14 18:33:04 2015 +0200

----------------------------------------------------------------------
 .../flink/api/common/state/OperatorState.java   |   4 +-
 .../streaming/api/state/EagerStateStore.java    |  22 ++--
 .../api/state/PartitionedStateStore.java        |   9 +-
 .../state/PartitionedStreamOperatorState.java   |  39 ++++---
 .../api/state/StatefulOperatorTest.java         | 102 +++++++++++++++----
 5 files changed, 126 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/8a75025a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
index 955b35b..3f5e977 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorState.java
@@ -55,7 +55,9 @@ public interface OperatorState<T> {
 	/**
 	 * Updates the operator state accessible by {@link #value()} to the given
 	 * value. The next time {@link #value()} is called (for the same state
-	 * partition) the returned state will represent the updated value.
+	 * partition) the returned state will represent the updated value. When a
+	 * partitioned state is updated with {@code null}, the state for the current
+	 * key is removed and the default value is returned on the next access.
 	 * 
 	 * @param state
 	 *            The new value for the state.

http://git-wip-us.apache.org/repos/asf/flink/blob/8a75025a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
index 3277b3f..2091624 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/EagerStateStore.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.streaming.api.state;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.HashMap;
 import java.util.Map;
@@ -29,7 +30,7 @@ import org.apache.flink.runtime.state.StateHandleProvider;
 
 public class EagerStateStore<S, C extends Serializable> implements PartitionedStateStore<S, C> {
 
-	private StateCheckpointer<S,C> checkpointer;
+	private StateCheckpointer<S, C> checkpointer;
 	private final StateHandleProvider<Serializable> provider;
 
 	private Map<Serializable, S> fetchedState;
@@ -43,7 +44,7 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
 	}
 
 	@Override
-	public S getStateForKey(Serializable key) throws Exception {
+	public S getStateForKey(Serializable key) throws IOException {
 		return fetchedState.get(key);
 	}
 
@@ -53,7 +54,12 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
 	}
 
 	@Override
-	public Map<Serializable, S> getPartitionedState() throws Exception {
+	public void removeStateForKey(Serializable key) {
+		fetchedState.remove(key);
+	}
+
+	@Override
+	public Map<Serializable, S> getPartitionedState() throws IOException {
 		return fetchedState;
 	}
 
@@ -69,11 +75,12 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
 	}
 
 	@Override
-	public void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception {
-		
+	public void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader)
+			throws Exception {
+
 		@SuppressWarnings("unchecked")
 		Map<Serializable, C> checkpoints = (Map<Serializable, C>) snapshot.getState(userCodeClassLoader);
-		
+
 		// we map the values back to the state from the checkpoints
 		for (Entry<Serializable, C> snapshotEntry : checkpoints.entrySet()) {
 			fetchedState.put(snapshotEntry.getKey(), (S) checkpointer.restoreState(snapshotEntry.getValue()));
@@ -89,10 +96,9 @@ public class EagerStateStore<S, C extends Serializable> implements PartitionedSt
 	public void setCheckPointer(StateCheckpointer<S, C> checkpointer) {
 		this.checkpointer = checkpointer;
 	}
-	
+
 	@Override
 	public String toString() {
 		return fetchedState.toString();
 	}
-
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/8a75025a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
index e9a02c1..34bfde7 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStateStore.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.streaming.api.state;
 
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.Map;
 
@@ -35,13 +36,15 @@ import org.apache.flink.runtime.state.StateHandle;
  */
 public interface PartitionedStateStore<S, C extends Serializable> {
 
-	S getStateForKey(Serializable key) throws Exception;
+	S getStateForKey(Serializable key) throws IOException;
 
 	void setStateForKey(Serializable key, S state);
+	
+	void removeStateForKey(Serializable key);
 
-	Map<Serializable, S> getPartitionedState() throws Exception;
+	Map<Serializable, S> getPartitionedState() throws IOException;
 
-	StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws Exception;
+	StateHandle<Serializable> snapshotStates(long checkpointId, long checkpointTimestamp) throws IOException;
 
 	void restoreStates(StateHandle<Serializable> snapshot, ClassLoader userCodeClassLoader) throws Exception;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8a75025a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
index e9ebfb6..115a97c 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/state/PartitionedStreamOperatorState.java
@@ -42,8 +42,7 @@ import org.apache.flink.util.InstantiationUtil;
  * @param <C>
  *            Type of the state snapshot.
  */
-public class PartitionedStreamOperatorState<IN, S, C extends Serializable> extends
-		StreamOperatorState<S, C> {
+public class PartitionedStreamOperatorState<IN, S, C extends Serializable> extends StreamOperatorState<S, C> {
 
 	// KeySelector for getting the state partition key for each input
 	private final KeySelector<IN, Serializable> keySelector;
@@ -77,41 +76,50 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
 		if (currentInput == null) {
 			throw new IllegalStateException("Need a valid input for accessing the state.");
 		} else {
+			Serializable key;
 			try {
-				Serializable key = keySelector.getKey(currentInput);
-				if (stateStore.containsKey(key)) {
-					return stateStore.getStateForKey(key);
-				} else {
+				key = keySelector.getKey(currentInput);
+			} catch (Exception e) {
+				throw new RuntimeException("User-defined key selector threw an exception.", e);
+			}
+			if (stateStore.containsKey(key)) {
+				return stateStore.getStateForKey(key);
+			} else {
+				try {
 					return (S) checkpointer.restoreState((C) InstantiationUtil.deserializeObject(
 							defaultState, cl));
+				} catch (ClassNotFoundException e) {
+					throw new RuntimeException("Could not deserialize default state value.", e);
 				}
-			} catch (Exception e) {
-				throw new RuntimeException("User-defined key selector threw an exception.", e);
 			}
 		}
 	}
 
 	@Override
 	public void update(S state) throws IOException {
-		if (state == null) {
-			throw new RuntimeException("Cannot set state to null.");
-		}
 		if (currentInput == null) {
 			throw new IllegalStateException("Need a valid input for updating a state.");
 		} else {
+			Serializable key;
 			try {
-				stateStore.setStateForKey(keySelector.getKey(currentInput), state);
+				key = keySelector.getKey(currentInput);
 			} catch (Exception e) {
 				throw new RuntimeException("User-defined key selector threw an exception.");
 			}
+			
+			if (state == null) {
+				// Remove state if set to null
+				stateStore.removeStateForKey(key);
+			} else {
+				stateStore.setStateForKey(key, state);
+			}
 		}
 	}
 
 	@Override
 	public void setDefaultState(S defaultState) {
 		try {
-			this.defaultState = InstantiationUtil.serializeObject(checkpointer.snapshotState(
-					defaultState, 0, 0));
+			this.defaultState = InstantiationUtil.serializeObject(checkpointer.snapshotState(defaultState, 0, 0));
 		} catch (IOException e) {
 			throw new RuntimeException("Default state must be serializable.");
 		}
@@ -122,8 +130,7 @@ public class PartitionedStreamOperatorState<IN, S, C extends Serializable> exten
 	}
 
 	@Override
-	public StateHandle<Serializable> snapshotState(long checkpointId,
-			long checkpointTimestamp) throws Exception {
+	public StateHandle<Serializable> snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
 		return stateStore.snapshotStates(checkpointId, checkpointTimestamp);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8a75025a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
index c19c548..48bb1f3 100644
--- a/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
+++ b/flink-staging/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/state/StatefulOperatorTest.java
@@ -19,6 +19,8 @@
 package org.apache.flink.streaming.api.state;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.io.IOException;
@@ -101,25 +103,28 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 		assertEquals((Integer) 7, ((StatefulMapper) restoredMap.getUserFunction()).checkpointedCounter);
 
 	}
-	
+
 	@Test
 	public void apiTest() throws Exception {
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(3);
 
 		KeyedDataStream<Integer> keyedStream = env.fromCollection(Arrays.asList(0, 1, 2, 3, 4, 5, 6)).keyBy(new ModKey(4));
-		
+
 		keyedStream.map(new StatefulMapper()).addSink(new SinkFunction<String>() {
 			private static final long serialVersionUID = 1L;
+
 			public void invoke(String value) throws Exception {
 			}
 		});
-		
+
 		keyedStream.map(new StatefulMapper2()).setParallelism(1).addSink(new SinkFunction<String>() {
 			private static final long serialVersionUID = 1L;
-			public void invoke(String value) throws Exception {}
+
+			public void invoke(String value) throws Exception {
+			}
 		});
-		
+
 		try {
 			keyedStream.shuffle();
 			fail();
@@ -127,6 +132,21 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 
 		}
 		
+		env.fromElements(0, 1, 2, 2, 2, 3, 4, 3, 4).keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+
+		}).map(new PStateKeyRemovalTestMapper()).setParallelism(1).addSink(new SinkFunction<String>() {
+			private static final long serialVersionUID = 1L;
+
+			public void invoke(String value) throws Exception {
+			}
+		});
+
 		env.execute();
 	}
 
@@ -143,7 +163,7 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 		final List<String> outputList = output;
 
 		StreamingRuntimeContext context = new StreamingRuntimeContext(
-				new MockEnvironment("MockTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024), 
+				new MockEnvironment("MockTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024),
 				new ExecutionConfig(),
 				partitioner,
 				new LocalStateHandleProvider<Serializable>(),
@@ -181,11 +201,11 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 
 	public static class StatefulMapper extends RichMapFunction<Integer, String> implements
 			Checkpointed<Integer> {
-	private static final long serialVersionUID = -9007873655253339356L;
+		private static final long serialVersionUID = -9007873655253339356L;
 		OperatorState<Integer> counter;
 		OperatorState<MutableInt> groupCounter;
 		OperatorState<String> concat;
-		
+
 		Integer checkpointedCounter = 0;
 
 		@Override
@@ -199,7 +219,7 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 			try {
 				counter.update(null);
 				fail();
-			} catch (RuntimeException e){
+			} catch (RuntimeException e) {
 			}
 			return value.toString();
 		}
@@ -212,15 +232,15 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 			try {
 				getRuntimeContext().getOperatorState("test", null, true);
 				fail();
-			} catch (RuntimeException e){
+			} catch (RuntimeException e) {
 			}
 			try {
 				getRuntimeContext().getOperatorState("test", null, true, null);
 				fail();
-			} catch (RuntimeException e){
+			} catch (RuntimeException e) {
 			}
 		}
-		
+
 		@SuppressWarnings("unchecked")
 		@Override
 		public void close() throws Exception {
@@ -229,14 +249,13 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 			for (Entry<Serializable, Integer> count : groupCounter.getPartitionedState().entrySet()) {
 				Integer key = (Integer) count.getKey();
 				Integer expected = key < 3 ? 2 : 1;
-				
+
 				assertEquals(new MutableInt(expected), count.getValue());
 			}
 		}
 
 		@Override
-		public Integer snapshotState(long checkpointId, long checkpointTimestamp)
-				throws Exception {
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
 			return checkpointedCounter;
 		}
 
@@ -245,23 +264,23 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 			this.checkpointedCounter = (Integer) state;
 		}
 	}
-	
+
 	public static class StatefulMapper2 extends RichMapFunction<Integer, String> {
 		private static final long serialVersionUID = 1L;
 		OperatorState<Integer> groupCounter;
-		
+
 		@Override
 		public String map(Integer value) throws Exception {
 			groupCounter.update(groupCounter.value() + 1);
-			
+
 			return value.toString();
 		}
 
 		@Override
-		public void open(Configuration conf) throws IOException {		
+		public void open(Configuration conf) throws IOException {
 			groupCounter = getRuntimeContext().getOperatorState("groupCounter", 0, true);
 		}
-		
+
 		@SuppressWarnings("unchecked")
 		@Override
 		public void close() throws Exception {
@@ -274,9 +293,48 @@ public class StatefulOperatorTest extends StreamingMultipleProgramsTestBase {
 				assertEquals(expected, count.getValue());
 			}
 		}
-		
+
 	}
-	
+
+	public static class PStateKeyRemovalTestMapper extends RichMapFunction<Integer, String> {
+
+		private static final long serialVersionUID = 1L;
+		OperatorState<Boolean> seen;
+
+		@Override
+		public String map(Integer value) throws Exception {
+			if (value == 0) {
+				seen.update(null);
+			}else{
+				Boolean s = seen.value();
+				if (s) {
+					seen.update(null);
+				} else {
+					seen.update(true);
+				}
+			}
+
+			return value.toString();
+		}
+
+		public void open(Configuration c) throws IOException {
+			seen = getRuntimeContext().getOperatorState("seen", false, true);
+		}
+
+		@SuppressWarnings("unchecked")
+		@Override
+		public void close() throws Exception {
+			Map<String, StreamOperatorState<?, ?>> states = ((StreamingRuntimeContext) getRuntimeContext()).getOperatorStates();
+			PartitionedStreamOperatorState<Integer, Boolean, Boolean> seen = (PartitionedStreamOperatorState<Integer, Boolean, Boolean>) states.get("seen");
+			assertFalse(seen.getPartitionedState().containsKey(0));
+			assertEquals(2,seen.getPartitionedState().size());
+			for (Entry<Serializable, Boolean> s : seen.getPartitionedState().entrySet()) {
+					assertTrue(s.getValue());
+			}
+		}
+
+	}
+
 	public static class ModKey implements KeySelector<Integer, Serializable> {
 
 		private static final long serialVersionUID = 4193026742083046736L;