You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2017/01/22 22:12:10 UTC

[3/9] flink git commit: [hotfix] Cleanups of the AbstractKeyedStateBackend

[hotfix] Cleanups of the AbstractKeyedStateBackend


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/51c02eee
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/51c02eee
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/51c02eee

Branch: refs/heads/master
Commit: 51c02eee7811116f9879d28e79967685a6a037ac
Parents: 392b2e9
Author: Stephan Ewen <se...@apache.org>
Authored: Fri Jan 13 14:50:30 2017 +0100
Committer: Stephan Ewen <se...@apache.org>
Committed: Sun Jan 22 21:22:21 2017 +0100

----------------------------------------------------------------------
 .../state/AbstractKeyedStateBackend.java        | 116 +++++++++++++------
 1 file changed, 78 insertions(+), 38 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/51c02eee/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index 2daf896..cab2b4f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -291,48 +291,88 @@ public abstract class AbstractKeyedStateBackend<K>
 	}
 
 	@Override
-	@SuppressWarnings("unchecked,rawtypes")
-	public <N, S extends MergingState<?, ?>> void mergePartitionedStates(final N target, Collection<N> sources, final TypeSerializer<N> namespaceSerializer, final StateDescriptor<S, ?> stateDescriptor) throws Exception {
+	public <N, S extends MergingState<?, ?>> void mergePartitionedStates(
+			final N target,
+			Collection<N> sources,
+			final TypeSerializer<N> namespaceSerializer,
+			final StateDescriptor<S, ?> stateDescriptor) throws Exception {
+
 		if (stateDescriptor instanceof ReducingStateDescriptor) {
-			ReducingStateDescriptor reducingStateDescriptor = (ReducingStateDescriptor) stateDescriptor;
-			ReduceFunction reduceFn = reducingStateDescriptor.getReduceFunction();
-			ReducingState state = (ReducingState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			Object result = null;
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Object sourceValue = state.get();
-				if (result == null) {
-					result = state.get();
-				} else if (sourceValue != null) {
-					result = reduceFn.reduce(result, sourceValue);
-				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			if (result != null) {
-				state.add(result);
+			mergeReducingState((ReducingStateDescriptor<?>) stateDescriptor, namespaceSerializer, target,sources);
+		}
+		else if (stateDescriptor instanceof ListStateDescriptor) {
+			mergeListState((ListStateDescriptor<?>) stateDescriptor, namespaceSerializer, target,sources);
+		}
+		else {
+			throw new IllegalArgumentException("Cannot merge states for " + stateDescriptor);
+		}
+	}
+
+	private <N, T> void mergeReducingState(
+			final ReducingStateDescriptor<?> stateDescriptor,
+			final TypeSerializer<N> namespaceSerializer,
+			final N target,
+			final Collection<N> sources) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		final ReducingStateDescriptor<T> reducingStateDescriptor = (ReducingStateDescriptor<T>) stateDescriptor;
+
+		@SuppressWarnings("unchecked")
+		final ReducingState<T> state = (ReducingState<T>) getPartitionedState(target, namespaceSerializer, stateDescriptor);
+
+		@SuppressWarnings("unchecked")
+		final KvState<N> kvState = (KvState<N>) state;
+
+		final ReduceFunction<T> reduceFn = reducingStateDescriptor.getReduceFunction();
+
+		T result = null;
+		for (N source: sources) {
+			kvState.setCurrentNamespace(source);
+			T sourceValue = state.get();
+			if (result == null) {
+				result = state.get();
+			} else if (sourceValue != null) {
+				result = reduceFn.reduce(result, sourceValue);
 			}
-		} else if (stateDescriptor instanceof ListStateDescriptor) {
-			ListState<Object> state = (ListState) getPartitionedState(target, namespaceSerializer, stateDescriptor);
-			KvState kvState = (KvState) state;
-			List<Object> result = new ArrayList<>();
-			for (N source: sources) {
-				kvState.setCurrentNamespace(source);
-				Iterable<Object> sourceValue = state.get();
-				if (sourceValue != null) {
-					for (Object o : sourceValue) {
-						result.add(o);
-					}
+			state.clear();
+		}
+
+		// write result to the target
+		kvState.setCurrentNamespace(target);
+		if (result != null) {
+			state.add(result);
+		}
+	}
+
+	private <N, T> void mergeListState(
+			final ListStateDescriptor<?> listStateDescriptor,
+			final TypeSerializer<N> namespaceSerializer,
+			final N target,
+			final Collection<N> sources) throws Exception {
+
+		@SuppressWarnings("unchecked")
+		final ListState<T> state = (ListState<T>) getPartitionedState(target, namespaceSerializer, listStateDescriptor);
+
+		@SuppressWarnings("unchecked")
+		final KvState<N> kvState = (KvState<N>) state;
+
+		// merge the sources
+		final List<T> result = new ArrayList<>();
+		for (N source: sources) {
+			kvState.setCurrentNamespace(source);
+			Iterable<T> sourceValue = state.get();
+			if (sourceValue != null) {
+				for (T o : sourceValue) {
+					result.add(o);
 				}
-				state.clear();
-			}
-			kvState.setCurrentNamespace(target);
-			for (Object o : result) {
-				state.add(o);
 			}
-		} else {
-			throw new RuntimeException("Cannot merge states for " + stateDescriptor);
+			state.clear();
+		}
+
+		// write to the target
+		kvState.setCurrentNamespace(target);
+		for (T o : result) {
+			state.add(o);
 		}
 	}