Posted to commits@flink.apache.org by tr...@apache.org on 2018/07/16 21:06:33 UTC

[4/8] flink git commit: [FLINK-9489] Checkpoint timers as part of managed keyed state instead of raw keyed state
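
The core of this change: timer state is no longer written through the raw keyed state stream. Instead, timer priority queues are registered in the keyed state backend next to the key/value state tables, and both go through one common snapshot/restore contract. A rough sketch of that contract, reconstructed from the interfaces used in the diffs below (javadoc and exact signatures are illustrative, not the committed code):

    // Everything the heap backend has to checkpoint -- state tables and timer
    // priority queues alike -- implements this contract (see StateTable and
    // HeapPriorityQueueSnapshotRestoreWrapper below).
    public interface StateSnapshotRestore {

        // Creates a point-in-time snapshot that can be written asynchronously.
        StateSnapshot stateSnapshot();

        // Returns a reader that re-inserts deserialized state for one key group
        // during restore; the version selects the serialization format.
        StateSnapshotKeyGroupReader keyGroupReader(int readVersion);
    }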

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
index 3a348a9..72c70bc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.MathUtils;
 import org.apache.flink.util.Preconditions;
@@ -204,7 +204,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 * @param keyContext the key context.
 	 * @param metaInfo   the meta information, including the type serializer for state copy-on-write.
 	 */
-	CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo) {
+	CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
 		this(keyContext, metaInfo, 1024);
 	}
 
@@ -217,7 +217,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 * @throws IllegalArgumentException when the capacity is less than zero.
 	 */
 	@SuppressWarnings("unchecked")
-	private CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo, int capacity) {
+	private CopyOnWriteStateTable(InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo, int capacity) {
 		super(keyContext, metaInfo);
 
 		// initialized tables to EMPTY_TABLE.
@@ -547,12 +547,12 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	}
 
 	@Override
-	public RegisteredKeyedBackendStateMetaInfo<N, S> getMetaInfo() {
+	public RegisteredKeyValueStateBackendMetaInfo<N, S> getMetaInfo() {
 		return metaInfo;
 	}
 
 	@Override
-	public void setMetaInfo(RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo) {
+	public void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
 		this.metaInfo = metaInfo;
 	}
 
@@ -871,8 +871,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 *
 	 * @return a snapshot from this {@link CopyOnWriteStateTable}, for checkpointing.
 	 */
+	@Nonnull
 	@Override
-	public CopyOnWriteStateTableSnapshot<K, N, S> createSnapshot() {
+	public CopyOnWriteStateTableSnapshot<K, N, S> stateSnapshot() {
 		return new CopyOnWriteStateTableSnapshot<>(this);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index 4c0ab6f..f3f21dd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -24,7 +24,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.KeyGroupPartitioner;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
@@ -91,7 +91,7 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	 * to an output as part of checkpointing.
 	 */
 	@Nullable
-	private StateSnapshot.KeyGroupPartitionedSnapshot partitionedStateTableSnapshot;
+	private StateKeyGroupWriter partitionedStateTableSnapshot;
 
 	/**
 	 * Creates a new {@link CopyOnWriteStateTableSnapshot}.
@@ -135,7 +135,7 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	@Nonnull
 	@SuppressWarnings("unchecked")
 	@Override
-	public KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+	public StateKeyGroupWriter getKeyGroupWriter() {
 
 		if (partitionedStateTableSnapshot == null) {
 
@@ -160,6 +160,12 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		return partitionedStateTableSnapshot;
 	}
 
+	@Nonnull
+	@Override
+	public StateMetaInfoSnapshot getMetaInfoSnapshot() {
+		return owningStateTable.metaInfo.snapshot();
+	}
+
 	@Override
 	public void release() {
 		owningStateTable.releaseSnapshot(this);

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 495dfe0..2c6101e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -49,22 +49,25 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.Keyed;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.KeyedStateFunction;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.LocalRecoveryConfig;
-import org.apache.flink.runtime.state.PriorityComparator;
-import org.apache.flink.runtime.state.PriorityQueueSetFactory;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.PriorityComparable;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
-import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
@@ -108,19 +111,47 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			Tuple2.of(FoldingStateDescriptor.class, (StateFactory) HeapFoldingState::create)
 		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
+	@SuppressWarnings("unchecked")
 	@Nonnull
 	@Override
-	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create(
 		@Nonnull String stateName,
-		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
-		@Nonnull PriorityComparator<T> elementPriorityComparator,
-		@Nonnull KeyExtractorFunction<T> keyExtractor) {
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+
+		final StateSnapshotRestore snapshotRestore = registeredStates.get(stateName);
+
+		if (snapshotRestore instanceof HeapPriorityQueueSnapshotRestoreWrapper) {
+			//TODO Serializer upgrade story!?
+			return ((HeapPriorityQueueSnapshotRestoreWrapper<T>) snapshotRestore).getPriorityQueue();
+		} else if (snapshotRestore != null) {
+			throw new IllegalStateException("Already found a different state type registered under this name: " + snapshotRestore.getClass());
+		}
+
+		final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
+			new RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, byteOrderedElementSerializer);
 
-		return priorityQueueSetFactory.create(
+		return createInternal(metaInfo);
+	}
+
+	@Nonnull
+	private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
+		RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+
+		final String stateName = metaInfo.getName();
+		final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create(
 			stateName,
-			byteOrderedElementSerializer,
-			elementPriorityComparator,
-			keyExtractor);
+			metaInfo.getElementSerializer());
+
+		HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
+			new HeapPriorityQueueSnapshotRestoreWrapper<>(
+				priorityQueue,
+				metaInfo,
+				KeyExtractorFunction.forKeyedObjects(),
+				keyGroupRange,
+				numberOfKeyGroups);
+
+		registeredStates.put(stateName, wrapper);
+		return priorityQueue;
 	}
 
 	private interface StateFactory {
@@ -131,14 +162,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	/**
-	 * Map of state tables that stores all state of key/value states. We store it centrally so
-	 * that we can easily checkpoint/restore it.
-	 *
-	 * <p>The actual parameters of StateTable are {@code StateTable<NamespaceT, Map<KeyT, StateT>>}
-	 * but we can't put them here because different key/value states with different types and
-	 * namespace types share this central list of tables.
+	 * Map of registered states for snapshot/restore.
 	 */
-	private final Map<String, StateTable<K, ?, ?>> stateTables = new HashMap<>();
+	private final Map<String, StateSnapshotRestore> registeredStates = new HashMap<>();
 
 	/**
 	 * Map of state names to their corresponding restored state meta info.
@@ -161,7 +187,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	/**
 	 * Factory for state that is organized as priority queue.
 	 */
-	private final PriorityQueueSetFactory priorityQueueSetFactory;
+	private final HeapPriorityQueueSetFactory priorityQueueSetFactory;
 
 	public HeapKeyedStateBackend(
 			TaskKvStateRegistry kvStateRegistry,
@@ -172,7 +198,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			boolean asynchronousSnapshots,
 			ExecutionConfig executionConfig,
 			LocalRecoveryConfig localRecoveryConfig,
-			PriorityQueueSetFactory priorityQueueSetFactory,
+			HeapPriorityQueueSetFactory priorityQueueSetFactory,
 			TtlTimeProvider ttlTimeProvider) {
 
 		super(kvStateRegistry, keySerializer, userCodeClassLoader,
@@ -197,9 +223,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			TypeSerializer<N> namespaceSerializer, StateDescriptor<?, V> stateDesc) throws StateMigrationException {
 
 		@SuppressWarnings("unchecked")
-		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) stateTables.get(stateDesc.getName());
+		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredStates.get(stateDesc.getName());
 
-		RegisteredKeyedBackendStateMetaInfo<N, V> newMetaInfo;
+		RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo;
 		if (stateTable != null) {
 			@SuppressWarnings("unchecked")
 			StateMetaInfoSnapshot restoredMetaInfoSnapshot =
@@ -210,37 +236,43 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				"Requested to check compatibility of a restored RegisteredKeyedBackendStateMetaInfo," +
 					" but its corresponding restored snapshot cannot be found.");
 
-			newMetaInfo = RegisteredKeyedBackendStateMetaInfo.resolveKvStateCompatibility(
+			newMetaInfo = RegisteredKeyValueStateBackendMetaInfo.resolveKvStateCompatibility(
 				restoredMetaInfoSnapshot,
 				namespaceSerializer,
 				stateDesc);
 
 			stateTable.setMetaInfo(newMetaInfo);
 		} else {
-			newMetaInfo = new RegisteredKeyedBackendStateMetaInfo<>(
+			newMetaInfo = new RegisteredKeyValueStateBackendMetaInfo<>(
 				stateDesc.getType(),
 				stateDesc.getName(),
 				namespaceSerializer,
 				stateDesc.getSerializer());
 
 			stateTable = snapshotStrategy.newStateTable(newMetaInfo);
-			stateTables.put(stateDesc.getName(), stateTable);
+			registeredStates.put(stateDesc.getName(), stateTable);
 		}
 
 		return stateTable;
 	}
 
+	@SuppressWarnings("unchecked")
 	@Override
 	public <N> Stream<K> getKeys(String state, N namespace) {
-		if (!stateTables.containsKey(state)) {
+		if (!registeredStates.containsKey(state)) {
 			return Stream.empty();
 		}
-		StateTable<K, N, ?> table = (StateTable<K, N, ?>) stateTables.get(state);
+
+		final StateSnapshotRestore stateSnapshotRestore = registeredStates.get(state);
+		if (!(stateSnapshotRestore instanceof StateTable)) {
+			return Stream.empty();
+		}
+		StateTable<K, N, ?> table = (StateTable<K, N, ?>) stateSnapshotRestore;
 		return table.getKeys(namespace);
 	}
 
 	private boolean hasRegisteredState() {
-		return !stateTables.isEmpty();
+		return !registeredStates.isEmpty();
 	}
 
 	@Override
@@ -288,7 +320,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		final Map<Integer, String> kvStatesById = new HashMap<>();
 		int numRegisteredKvStates = 0;
-		stateTables.clear();
+		registeredStates.clear();
 
 		boolean keySerializerRestored = false;
 
@@ -342,16 +374,20 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				for (StateMetaInfoSnapshot restoredMetaInfo : restoredMetaInfos) {
 					restoredKvStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo);
 
-					StateTable<K, ?, ?> stateTable = stateTables.get(restoredMetaInfo.getName());
+					StateSnapshotRestore snapshotRestore = registeredStates.get(restoredMetaInfo.getName());
 
 					//important: only create a new table we did not already create it previously
-					if (null == stateTable) {
+					if (null == snapshotRestore) {
 
-						RegisteredKeyedBackendStateMetaInfo<?, ?> registeredKeyedBackendStateMetaInfo =
-								new RegisteredKeyedBackendStateMetaInfo<>(restoredMetaInfo);
+						if (restoredMetaInfo.getBackendStateType() == StateMetaInfoSnapshot.BackendStateType.KEY_VALUE) {
+							RegisteredKeyValueStateBackendMetaInfo<?, ?> registeredKeyedBackendStateMetaInfo =
+								new RegisteredKeyValueStateBackendMetaInfo<>(restoredMetaInfo);
 
-						stateTable = snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo);
-						stateTables.put(restoredMetaInfo.getName(), stateTable);
+							snapshotRestore = snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo);
+							registeredStates.put(restoredMetaInfo.getName(), snapshotRestore);
+						} else {
+							createInternal(new RegisteredPriorityQueueStateBackendMetaInfo<>(restoredMetaInfo));
+						}
 						kvStatesById.put(numRegisteredKvStates, restoredMetaInfo.getName());
 						++numRegisteredKvStates;
 					} else {
@@ -384,12 +420,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 						for (int i = 0; i < restoredMetaInfos.size(); i++) {
 							int kvStateId = kgCompressionInView.readShort();
-							StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
+							StateSnapshotRestore registeredState = registeredStates.get(kvStatesById.get(kvStateId));
 
-							StateTableByKeyGroupReader keyGroupReader =
-								StateTableByKeyGroupReaders.readerForVersion(
-									stateTable,
-									serializationProxy.getReadVersion());
+							StateSnapshotKeyGroupReader keyGroupReader =
+								registeredState.keyGroupReader(serializationProxy.getReadVersion());
 
 							keyGroupReader.readMappingsInKeyGroup(kgCompressionInView, keyGroupIndex);
 						}
@@ -446,8 +480,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@Override
 	public int numStateEntries() {
 		int sum = 0;
-		for (StateTable<K, ?, ?> stateTable : stateTables.values()) {
-			sum += stateTable.size();
+		for (StateSnapshotRestore state : registeredStates.values()) {
+			if (state instanceof StateTable) {
+				sum += ((StateTable<?, ?, ?>) state).size();
+			}
 		}
 		return sum;
 	}
@@ -458,8 +494,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@VisibleForTesting
 	public int numStateEntries(Object namespace) {
 		int sum = 0;
-		for (StateTable<K, ?, ?> stateTable : stateTables.values()) {
-			sum += stateTable.sizeOfNamespace(namespace);
+		for (StateSnapshotRestore state : registeredStates.values()) {
+			if (state instanceof StateTable) {
+				sum += ((StateTable<?, ?, ?>) state).sizeOfNamespace(namespace);
+			}
 		}
 		return sum;
 	}
@@ -486,7 +524,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		boolean isAsynchronous();
 
-		<N, V> StateTable<K, N, V> newStateTable(RegisteredKeyedBackendStateMetaInfo<N, V> newMetaInfo);
+		<N, V> StateTable<K, N, V> newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo);
 	}
 
 	private class AsyncSnapshotStrategySynchronicityBehavior implements SnapshotStrategySynchronicityBehavior<K> {
@@ -503,7 +541,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		@Override
-		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyedBackendStateMetaInfo<N, V> newMetaInfo) {
+		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
 			return new CopyOnWriteStateTable<>(HeapKeyedStateBackend.this, newMetaInfo);
 		}
 	}
@@ -522,7 +560,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		@Override
-		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyedBackendStateMetaInfo<N, V> newMetaInfo) {
+		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
 			return new NestedMapsStateTable<>(HeapKeyedStateBackend.this, newMetaInfo);
 		}
 	}
@@ -554,25 +592,26 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			long syncStartTime = System.currentTimeMillis();
 
-			Preconditions.checkState(stateTables.size() <= Short.MAX_VALUE,
-				"Too many KV-States: " + stateTables.size() +
+			Preconditions.checkState(registeredStates.size() <= Short.MAX_VALUE,
+				"Too many KV-States: " + registeredStates.size() +
 					". Currently at most " + Short.MAX_VALUE + " states are supported");
 
 			List<StateMetaInfoSnapshot> metaInfoSnapshots =
-				new ArrayList<>(stateTables.size());
+				new ArrayList<>(registeredStates.size());
 
-			final Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
+			final Map<String, Integer> kVStateToId = new HashMap<>(registeredStates.size());
 
 			final Map<String, StateSnapshot> cowStateStableSnapshots =
-				new HashMap<>(stateTables.size());
+				new HashMap<>(registeredStates.size());
 
-			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
+			for (Map.Entry<String, StateSnapshotRestore> kvState : registeredStates.entrySet()) {
 				String stateName = kvState.getKey();
 				kVStateToId.put(stateName, kVStateToId.size());
-				StateTable<K, ?, ?> stateTable = kvState.getValue();
-				if (null != stateTable) {
-					metaInfoSnapshots.add(stateTable.getMetaInfo().snapshot());
-					cowStateStableSnapshots.put(stateName, stateTable.createSnapshot());
+				StateSnapshotRestore state = kvState.getValue();
+				if (null != state) {
+					final StateSnapshot stateSnapshot = state.stateSnapshot();
+					metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
+					cowStateStableSnapshots.put(stateName, stateSnapshot);
 				}
 			}
 
@@ -654,13 +693,13 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 							outView.writeInt(keyGroupId);
 
 							for (Map.Entry<String, StateSnapshot> kvState : cowStateStableSnapshots.entrySet()) {
-								StateSnapshot.KeyGroupPartitionedSnapshot partitionedSnapshot =
-									kvState.getValue().partitionByKeyGroup();
+								StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
+									kvState.getValue().getKeyGroupWriter();
 								try (OutputStream kgCompressionOut = keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
 									String stateName = kvState.getKey();
 									DataOutputViewStreamWrapper kgCompressionView = new DataOutputViewStreamWrapper(kgCompressionOut);
 									kgCompressionView.writeShort(kVStateToId.get(stateName));
-									partitionedSnapshot.writeMappingsInKeyGroup(kgCompressionView, keyGroupId);
+									partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
 								} // this will just close the outer compression stream
 							}
 						}
@@ -705,7 +744,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 
 		@Override
-		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyedBackendStateMetaInfo<N, V> newMetaInfo) {
+		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
 			return snapshotStrategySynchronicityTrait.newStateTable(newMetaInfo);
 		}
 	}
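
With the single registeredStates map above, the synchronous part of a snapshot no longer special-cases timers: one loop produces both the meta info snapshots and the data snapshots for every registered state. Simplified from the snapshot strategy in this file:

    // Sketch, condensed from the hunk above: snapshot each registered state and
    // collect its meta information in one pass.
    for (Map.Entry<String, StateSnapshotRestore> entry : registeredStates.entrySet()) {
        StateSnapshot stateSnapshot = entry.getValue().stateSnapshot();
        metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
        cowStateStableSnapshots.put(entry.getKey(), stateSnapshot);
    }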

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
index e5f610e..22b2419 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueue.java
@@ -50,7 +50,8 @@ import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE;
  *
  * @param <T> type of the contained elements.
  */
-public class HeapPriorityQueue<T extends HeapPriorityQueueElement> implements InternalPriorityQueue<T> {
+public class HeapPriorityQueue<T extends HeapPriorityQueueElement>
+	implements InternalPriorityQueue<T> {
 
 	/**
 	 * The index of the head element in the array that represents the heap.

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
index ee6fda9..b0255d3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
@@ -21,7 +21,8 @@ package org.apache.flink.runtime.state.heap;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.Keyed;
+import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.runtime.state.PriorityQueueSetFactory;
 
@@ -29,7 +30,7 @@ import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
 
 /**
- *
+ * Factory for {@link HeapPriorityQueueSet}.
  */
 public class HeapPriorityQueueSetFactory implements PriorityQueueSetFactory {
 
@@ -54,14 +55,13 @@ public class HeapPriorityQueueSetFactory implements PriorityQueueSetFactory {
 
 	@Nonnull
 	@Override
-	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T> create(
+	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> HeapPriorityQueueSet<T> create(
 		@Nonnull String stateName,
-		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
-		@Nonnull PriorityComparator<T> elementPriorityComparator,
-		@Nonnull KeyExtractorFunction<T> keyExtractor) {
-		return new HeapPriorityQueueSet<>(
-			elementPriorityComparator,
-			keyExtractor,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+
+		return new HeapPriorityQueueSet<T>(
+			PriorityComparator.forPriorityComparableObjects(),
+			KeyExtractorFunction.forKeyedObjects(),
 			minimumCapacity,
 			keyGroupRange,
 			totalKeyGroups);

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
new file mode 100644
index 0000000..5fd67f0
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+/**
+ * This wrapper combines a HeapPriorityQueue with backend meta data.
+ *
+ * @param <T> type of the queue elements.
+ */
+public class HeapPriorityQueueSnapshotRestoreWrapper<T extends HeapPriorityQueueElement>
+	implements StateSnapshotRestore {
+
+	@Nonnull
+	private final HeapPriorityQueueSet<T> priorityQueue;
+	@Nonnull
+	private final KeyExtractorFunction<T> keyExtractorFunction;
+	@Nonnull
+	private final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo;
+	@Nonnull
+	private final KeyGroupRange localKeyGroupRange;
+	@Nonnegative
+	private final int totalKeyGroups;
+
+	public HeapPriorityQueueSnapshotRestoreWrapper(
+		@Nonnull HeapPriorityQueueSet<T> priorityQueue,
+		@Nonnull RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo,
+		@Nonnull KeyExtractorFunction<T> keyExtractorFunction,
+		@Nonnull KeyGroupRange localKeyGroupRange,
+		int totalKeyGroups) {
+
+		this.priorityQueue = priorityQueue;
+		this.keyExtractorFunction = keyExtractorFunction;
+		this.metaInfo = metaInfo;
+		this.localKeyGroupRange = localKeyGroupRange;
+		this.totalKeyGroups = totalKeyGroups;
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nonnull
+	@Override
+	public StateSnapshot stateSnapshot() {
+		final T[] queueDump = (T[]) priorityQueue.toArray(new HeapPriorityQueueElement[priorityQueue.size()]);
+
+		final TypeSerializer<T> elementSerializer = metaInfo.getElementSerializer();
+
+		// turn the flat copy into a deep copy if required.
+		if (!elementSerializer.isImmutableType()) {
+			for (int i = 0; i < queueDump.length; ++i) {
+				queueDump[i] = elementSerializer.copy(queueDump[i]);
+			}
+		}
+
+		return new HeapPriorityQueueStateSnapshot<>(
+			queueDump,
+			keyExtractorFunction,
+			metaInfo.deepCopy(),
+			localKeyGroupRange,
+			totalKeyGroups);
+	}
+
+	@Nonnull
+	@Override
+	public StateSnapshotKeyGroupReader keyGroupReader(int readVersionHint) {
+		final TypeSerializer<T> elementSerializer = metaInfo.getElementSerializer();
+		return KeyGroupPartitioner.createKeyGroupPartitionReader(
+			elementSerializer::deserialize, //we know that this does not deliver nulls, because we never write nulls
+			(element, keyGroupId) -> priorityQueue.add(element));
+	}
+
+	@Nonnull
+	public HeapPriorityQueueSet<T> getPriorityQueue() {
+		return priorityQueue;
+	}
+}
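
The wrapper above is what makes a heap timer queue look like any other registered state: stateSnapshot() dumps the queue into an array (deep-copying elements if the serializer is mutable), while keyGroupReader() feeds deserialized elements straight back into the live queue. A small usage sketch, assuming a DataInputView positioned at one key group (variable names are illustrative):

    // Restore side, per key group:
    StateSnapshotKeyGroupReader reader = wrapper.keyGroupReader(readVersion);
    reader.readMappingsInKeyGroup(inputView, keyGroupId); // re-adds each element via priorityQueue.add(...)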

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
new file mode 100644
index 0000000..18e7d54
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueStateSnapshot.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyExtractorFunction;
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.RegisteredPriorityQueueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.lang.reflect.Array;
+
+/**
+ * This class represents the snapshot of an {@link HeapPriorityQueueSet}.
+ *
+ * @param <T> type of the state elements.
+ */
+public class HeapPriorityQueueStateSnapshot<T> implements StateSnapshot {
+
+	/** Function that extracts keys from elements. */
+	@Nonnull
+	private final KeyExtractorFunction<T> keyExtractor;
+
+	/** Copy of the heap array containing all the (immutable or deeply copied) elements. */
+	@Nonnull
+	private final T[] heapArrayCopy;
+
+	/** The meta info of the state. */
+	@Nonnull
+	private final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo;
+
+	/** The key-group range covered by this snapshot. */
+	@Nonnull
+	private final KeyGroupRange keyGroupRange;
+
+	/** The total number of key-groups in the job. */
+	@Nonnegative
+	private final int totalKeyGroups;
+
+	/** Result of partitioning the snapshot by key-group. */
+	@Nullable
+	private StateKeyGroupWriter stateKeyGroupWriter;
+
+	HeapPriorityQueueStateSnapshot(
+		@Nonnull T[] heapArrayCopy,
+		@Nonnull KeyExtractorFunction<T> keyExtractor,
+		@Nonnull RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int totalKeyGroups) {
+
+		this.keyExtractor = keyExtractor;
+		this.heapArrayCopy = heapArrayCopy;
+		this.metaInfo = metaInfo;
+		this.keyGroupRange = keyGroupRange;
+		this.totalKeyGroups = totalKeyGroups;
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nonnull
+	@Override
+	public StateKeyGroupWriter getKeyGroupWriter() {
+
+		if (stateKeyGroupWriter == null) {
+
+			T[] partitioningOutput = (T[]) Array.newInstance(
+				heapArrayCopy.getClass().getComponentType(),
+				heapArrayCopy.length);
+
+			final TypeSerializer<T> elementSerializer = metaInfo.getElementSerializer();
+
+			KeyGroupPartitioner<T> keyGroupPartitioner =
+				new KeyGroupPartitioner<>(
+					heapArrayCopy,
+					heapArrayCopy.length,
+					partitioningOutput,
+					keyGroupRange,
+					totalKeyGroups,
+					keyExtractor,
+					elementSerializer::serialize);
+
+			stateKeyGroupWriter = keyGroupPartitioner.partitionByKeyGroup();
+		}
+
+		return stateKeyGroupWriter;
+	}
+
+	@Nonnull
+	@Override
+	public StateMetaInfoSnapshot getMetaInfoSnapshot() {
+		return metaInfo.snapshot();
+	}
+
+	@Override
+	public void release() {
+	}
+}
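
On the write side, the snapshot partitions its copied heap array by key group lazily, on the first getKeyGroupWriter() call, so that work happens on the asynchronous checkpoint thread. A condensed view of how the backend drives it (cf. the HeapKeyedStateBackend diff above):

    // Sketch: per key group, the async snapshot thread asks the writer to
    // serialize exactly the elements belonging to that key group.
    StateSnapshot.StateKeyGroupWriter writer = stateSnapshot.getKeyGroupWriter();
    writer.writeStateInKeyGroup(outputView, keyGroupId);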

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
index 6f4f911..d8b0a5a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/KeyGroupPartitionedPriorityQueue.java
@@ -49,6 +49,8 @@ import java.util.function.Predicate;
 public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueue<T> & HeapPriorityQueueElement>
 	implements InternalPriorityQueue<T>, KeyGroupedInternalPriorityQueue<T> {
 
+	static final boolean ENABLE_RELAXED_FIRING_ORDER_OPTIMIZATION = false;
+
 	/** A heap of heap sets. Each sub-heap represents the partition for a key-group.*/
 	@Nonnull
 	private final HeapPriorityQueue<PQ> heapOfkeyGroupedHeaps;
@@ -94,6 +96,22 @@ public class KeyGroupPartitionedPriorityQueue<T, PQ extends InternalPriorityQueu
 
 	@Override
 	public void bulkPoll(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer) {
+		if (ENABLE_RELAXED_FIRING_ORDER_OPTIMIZATION) {
+			bulkPollRelaxedOrder(canConsume, consumer);
+		} else {
+			bulkPollStrictOrder(canConsume, consumer);
+		}
+	}
+
+	private void bulkPollRelaxedOrder(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer) {
+		PQ headList = heapOfkeyGroupedHeaps.peek();
+		while (headList.peek() != null && canConsume.test(headList.peek())) {
+			headList.bulkPoll(canConsume, consumer);
+			heapOfkeyGroupedHeaps.adjustModifiedElement(headList);
+		}
+	}
+
+	private void bulkPollStrictOrder(@Nonnull Predicate<T> canConsume, @Nonnull Consumer<T> consumer) {
 		T element;
 		while ((element = peek()) != null && canConsume.test(element)) {
 			poll();
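
The bulkPoll() split above distinguishes two draining strategies: the relaxed variant empties the current head sub-queue completely before re-balancing the heap of heaps (fewer adjustModifiedElement calls, at the cost of firing elements from different key groups slightly out of global order), while the strict variant always consumes the globally smallest element; with ENABLE_RELAXED_FIRING_ORDER_OPTIMIZATION set to false, the strict path is the effective behavior. An illustrative completion of the strict loop, which the diff context cuts off here (not the committed code):

    // Strict order: peek the global head and only consume while the predicate holds.
    T element;
    while ((element = peek()) != null && canConsume.test(element)) {
        poll();
        consumer.accept(element);
    }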

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index 18551b5..efed1cc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -22,9 +22,10 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
@@ -34,6 +35,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 import java.util.stream.Stream;
 
 /**
@@ -69,7 +71,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	 * @param keyContext the key context.
 	 * @param metaInfo the meta information for this state table.
 	 */
-	public NestedMapsStateTable(InternalKeyContext<K> keyContext, RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo) {
+	public NestedMapsStateTable(InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
 		super(keyContext, metaInfo);
 		this.keyGroupOffset = keyContext.getKeyGroupRange().getStartKeyGroup();
 
@@ -175,7 +177,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	@Override
 	public Stream<K> getKeys(N namespace) {
 		return Arrays.stream(state)
-			.filter(namespaces -> namespaces != null)
+			.filter(Objects::nonNull)
 			.map(namespaces -> namespaces.getOrDefault(namespace, Collections.emptyMap()))
 			.flatMap(namespaceSate -> namespaceSate.keySet().stream());
 	}
@@ -232,12 +234,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 			setMapForKeyGroup(keyGroupIndex, namespaceMap);
 		}
 
-		Map<K, S> keyedMap = namespaceMap.get(namespace);
-
-		if (keyedMap == null) {
-			keyedMap = new HashMap<>();
-			namespaceMap.put(namespace, keyedMap);
-		}
+		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
 
 		return keyedMap.put(key, value);
 	}
@@ -302,13 +299,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 			setMapForKeyGroup(keyGroupIndex, namespaceMap);
 		}
 
-		Map<K, S> keyedMap = namespaceMap.get(namespace);
-
-		if (keyedMap == null) {
-			keyedMap = new HashMap<>();
-			namespaceMap.put(namespace, keyedMap);
-		}
-
+		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
 		keyedMap.put(key, transformation.apply(keyedMap.get(key), value));
 	}
 
@@ -323,8 +314,9 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		return count;
 	}
 
+	@Nonnull
 	@Override
-	public NestedMapsStateTableSnapshot<K, N, S> createSnapshot() {
+	public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
 		return new NestedMapsStateTableSnapshot<>(this);
 	}
 
@@ -337,7 +329,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	 */
 	static class NestedMapsStateTableSnapshot<K, N, S>
 			extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
-			implements StateSnapshot.KeyGroupPartitionedSnapshot {
+			implements StateSnapshot.StateKeyGroupWriter {
 
 		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable) {
 			super(owningTable);
@@ -345,10 +337,16 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 
 		@Nonnull
 		@Override
-		public KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+		public StateKeyGroupWriter getKeyGroupWriter() {
 			return this;
 		}
 
+		@Nonnull
+		@Override
+		public StateMetaInfoSnapshot getMetaInfoSnapshot() {
+			return owningStateTable.metaInfo.snapshot();
+		}
+
 		/**
 		 * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
 		 * {@link CopyOnWriteStateTable}.
@@ -359,7 +357,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		 * implementations).
 		 */
 		@Override
-		public void writeMappingsInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
+		public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
 			final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
 			if (null != keyGroupMap) {
 				TypeSerializer<K> keySerializer = owningStateTable.keyContext.getKeySerializer();

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index de2290a..58a2f97 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -20,11 +20,14 @@ package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
-import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+import org.apache.flink.runtime.state.StateSnapshotRestore;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.Preconditions;
 
+import javax.annotation.Nonnull;
+
 import java.util.stream.Stream;
 
 /**
@@ -35,7 +38,7 @@ import java.util.stream.Stream;
  * @param <N> type of namespace
  * @param <S> type of state
  */
-public abstract class StateTable<K, N, S> {
+public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 
 	/**
 	 * The key context view on the backend. This provides information, such as the currently active key.
@@ -45,14 +48,14 @@ public abstract class StateTable<K, N, S> {
 	/**
 	 * Combined meta information such as name and serializers for this state
 	 */
-	protected RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo;
+	protected RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo;
 
 	/**
 	 *
 	 * @param keyContext the key context provides the key scope for all put/get/delete operations.
 	 * @param metaInfo the meta information, including the type serializer for state copy-on-write.
 	 */
-	public StateTable(InternalKeyContext<K> keyContext, RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo) {
+	public StateTable(InternalKeyContext<K> keyContext, RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
 		this.keyContext = Preconditions.checkNotNull(keyContext);
 		this.metaInfo = Preconditions.checkNotNull(metaInfo);
 	}
@@ -173,22 +176,26 @@ public abstract class StateTable<K, N, S> {
 		return metaInfo.getNamespaceSerializer();
 	}
 
-	public RegisteredKeyedBackendStateMetaInfo<N, S> getMetaInfo() {
+	public RegisteredKeyValueStateBackendMetaInfo<N, S> getMetaInfo() {
 		return metaInfo;
 	}
 
-	public void setMetaInfo(RegisteredKeyedBackendStateMetaInfo<N, S> metaInfo) {
+	public void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
 		this.metaInfo = metaInfo;
 	}
 
 	// Snapshot / Restore -------------------------------------------------------------------------
 
-	abstract StateSnapshot createSnapshot();
-
 	public abstract void put(K key, int keyGroup, N namespace, S state);
 
 	// For testing --------------------------------------------------------------------------------
 
 	@VisibleForTesting
 	public abstract int sizeOfNamespace(Object namespace);
+
+	@Nonnull
+	@Override
+	public StateSnapshotKeyGroupReader keyGroupReader(int readVersion) {
+		return StateTableByKeyGroupReaders.readerForVersion(this, readVersion);
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
deleted file mode 100644
index 659c174..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReader.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state.heap;
-
-import org.apache.flink.core.memory.DataInputView;
-
-import java.io.IOException;
-
-/**
- * Interface for state de-serialization into {@link StateTable}s by key-group.
- */
-interface StateTableByKeyGroupReader {
-
-	/**
-	 * Read the data for the specified key-group from the input.
-	 *
-	 * @param div        the input
-	 * @param keyGroupId the key-group to write
-	 * @throws IOException on write related problems
-	 */
-	void readMappingsInKeyGroup(DataInputView div, int keyGroupId) throws IOException;
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
index e08e90e..2f83857 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableByKeyGroupReaders.java
@@ -19,12 +19,18 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
 
 import java.io.IOException;
 
 /**
- * This class provides a static factory method to create different implementations of {@link StateTableByKeyGroupReader}
+ * This class provides a static factory method to create different implementations of {@link StateSnapshotKeyGroupReader}
  * depending on the provided serialization format version.
  * <p>
  * The implementations are also located here as inner classes.
@@ -35,69 +41,58 @@ class StateTableByKeyGroupReaders {
 	 * Creates a new StateTableByKeyGroupReader that inserts de-serialized mappings into the given table, using the
 	 * de-serialization algorithm that matches the given version.
 	 *
-	 * @param table the {@link StateTable} into which de-serialized mappings are inserted.
+	 * @param stateTable the {@link StateTable} into which de-serialized mappings are inserted.
 	 * @param version version for the de-serialization algorithm.
 	 * @param <K> type of key.
 	 * @param <N> type of namespace.
 	 * @param <S> type of state.
 	 * @return the appropriate reader.
 	 */
-	static <K, N, S> StateTableByKeyGroupReader readerForVersion(StateTable<K, N, S> table, int version) {
+	static <K, N, S> StateSnapshotKeyGroupReader readerForVersion(StateTable<K, N, S> stateTable, int version) {
 		switch (version) {
 			case 1:
-				return new StateTableByKeyGroupReaderV1<>(table);
+				return new StateTableByKeyGroupReaderV1<>(stateTable);
 			case 2:
 			case 3:
 			case 4:
 			case 5:
-				return new StateTableByKeyGroupReaderV2V3<>(table);
+				return createV2PlusReader(stateTable);
 			default:
 				throw new IllegalArgumentException("Unknown version: " + version);
 		}
 	}
 
-	static abstract class AbstractStateTableByKeyGroupReader<K, N, S>
-			implements StateTableByKeyGroupReader {
-
-		protected final StateTable<K, N, S> stateTable;
-
-		AbstractStateTableByKeyGroupReader(StateTable<K, N, S> stateTable) {
-			this.stateTable = stateTable;
-		}
-
-		@Override
-		public abstract void readMappingsInKeyGroup(DataInputView div, int keyGroupId) throws IOException;
-
-		protected TypeSerializer<K> getKeySerializer() {
-			return stateTable.keyContext.getKeySerializer();
-		}
-
-		protected TypeSerializer<N> getNamespaceSerializer() {
-			return stateTable.getNamespaceSerializer();
-		}
-
-		protected TypeSerializer<S> getStateSerializer() {
-			return stateTable.getStateSerializer();
-		}
+	private static <K, N, S> StateSnapshotKeyGroupReader createV2PlusReader(StateTable<K, N, S> stateTable) {
+		final TypeSerializer<K> keySerializer = stateTable.keyContext.getKeySerializer();
+		final TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+		final TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
+		final Tuple3<N, K, S> buffer = new Tuple3<>();
+		return KeyGroupPartitioner.createKeyGroupPartitionReader((in) -> {
+			buffer.f0 = namespaceSerializer.deserialize(in);
+			buffer.f1 = keySerializer.deserialize(in);
+			buffer.f2 = stateSerializer.deserialize(in);
+			return buffer;
+		}, (element, keyGroupId1) -> stateTable.put(element.f1, keyGroupId1, element.f0, element.f2));
 	}
 
-	static final class StateTableByKeyGroupReaderV1<K, N, S>
-			extends AbstractStateTableByKeyGroupReader<K, N, S> {
+	static final class StateTableByKeyGroupReaderV1<K, N, S> implements StateSnapshotKeyGroupReader {
+
+		protected final StateTable<K, N, S> stateTable;
 
 		StateTableByKeyGroupReaderV1(StateTable<K, N, S> stateTable) {
-			super(stateTable);
+			this.stateTable = stateTable;
 		}
 
 		@Override
-		public void readMappingsInKeyGroup(DataInputView inView, int keyGroupId) throws IOException {
+		public void readMappingsInKeyGroup(@Nonnull DataInputView inView, @Nonnegative int keyGroupId) throws IOException {
 
 			if (inView.readByte() == 0) {
 				return;
 			}
 
-			final TypeSerializer<K> keySerializer = getKeySerializer();
-			final TypeSerializer<N> namespaceSerializer = getNamespaceSerializer();
-			final TypeSerializer<S> stateSerializer = getStateSerializer();
+			final TypeSerializer<K> keySerializer = stateTable.keyContext.getKeySerializer();
+			final TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+			final TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
 
 			// V1 uses kind of namespace compressing format
 			int numNamespaces = inView.readInt();
@@ -112,28 +107,4 @@ class StateTableByKeyGroupReaders {
 			}
 		}
 	}
-
-	private static final class StateTableByKeyGroupReaderV2V3<K, N, S>
-			extends AbstractStateTableByKeyGroupReader<K, N, S> {
-
-		StateTableByKeyGroupReaderV2V3(StateTable<K, N, S> stateTable) {
-			super(stateTable);
-		}
-
-		@Override
-		public void readMappingsInKeyGroup(DataInputView inView, int keyGroupId) throws IOException {
-
-			final TypeSerializer<K> keySerializer = getKeySerializer();
-			final TypeSerializer<N> namespaceSerializer = getNamespaceSerializer();
-			final TypeSerializer<S> stateSerializer = getStateSerializer();
-
-			int numKeys = inView.readInt();
-			for (int i = 0; i < numKeys; ++i) {
-				N namespace = namespaceSerializer.deserialize(inView);
-				K key = keySerializer.deserialize(inView);
-				S state = stateSerializer.deserialize(inView);
-				stateTable.put(key, keyGroupId, namespace, state);
-			}
-		}
-	}
 }
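
The V2+ reader is no longer a hand-written class: it is assembled from the generic key-group partition reader, with a lambda that deserializes (namespace, key, state) into a reused Tuple3 buffer and a consumer that puts each entry into the table. From the backend's point of view nothing changes; a usage sketch, assuming an input view positioned at one key group:

    // Restore side for a state table (cf. the restore loop in HeapKeyedStateBackend):
    StateSnapshotKeyGroupReader reader = stateTable.keyGroupReader(serializationProxy.getReadVersion());
    reader.readMappingsInKeyGroup(inputView, keyGroupId); // inserts via stateTable.put(key, keyGroup, namespace, state)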

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
index 9341a5a..5a3190c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.state.metainfo;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -30,7 +31,7 @@ import java.util.Map;
 
 /**
  * Generalized snapshot for meta information about one state in a state backend
- * (e.g. {@link org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo}).
+ * (e.g. {@link RegisteredKeyValueStateBackendMetaInfo}).
  */
 public class StateMetaInfoSnapshot {
 
@@ -41,7 +42,7 @@ public class StateMetaInfoSnapshot {
 		KEY_VALUE,
 		OPERATOR,
 		BROADCAST,
-		TIMER
+		PRIORITY_QUEUE
 	}
 
 	/**
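
Renaming TIMER to PRIORITY_QUEUE matters for restore: the backend state type recorded in the meta info snapshot is what tells the heap backend whether to recreate a key/value state table or a priority queue. Condensed from the restore path in the HeapKeyedStateBackend diff above (bodies elided):

    // Sketch of the dispatch on restore:
    if (restoredMetaInfo.getBackendStateType() == StateMetaInfoSnapshot.BackendStateType.KEY_VALUE) {
        // rebuild a StateTable and register it
    } else {
        // rebuild the priority queue state from RegisteredPriorityQueueStateBackendMetaInfo
    }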

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotReadersWriters.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotReadersWriters.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotReadersWriters.java
index ce535ef..926e75f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotReadersWriters.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotReadersWriters.java
@@ -39,6 +39,8 @@ import java.util.Map;
  */
 public class StateMetaInfoSnapshotReadersWriters {
 
+	private StateMetaInfoSnapshotReadersWriters() {}
+
 	/**
 	 * Current version for the serialization format of {@link StateMetaInfoSnapshotReadersWriters}.
 	 * - v5: Flink 1.6.x
@@ -74,23 +76,35 @@ public class StateMetaInfoSnapshotReadersWriters {
 	@Nonnull
 	public static StateMetaInfoReader getReader(int readVersion, @Nonnull StateTypeHint stateTypeHint) {
 
+		if (readVersion < CURRENT_STATE_META_INFO_SNAPSHOT_VERSION) {
+			switch (stateTypeHint) {
+				case KEYED_STATE:
+					return getLegacyKeyedStateMetaInfoReader(readVersion);
+				case OPERATOR_STATE:
+					return getLegacyOperatorStateMetaInfoReader(readVersion);
+				default:
+					throw new IllegalArgumentException("Unsupported state type hint: " + stateTypeHint +
+						" with version " + readVersion);
+			}
+		} else {
+			return getReader(readVersion);
+		}
+	}
+
+	/**
+	 * Returns a reader for {@link StateMetaInfoSnapshot} with the requested version number.
+	 *
+	 * @param readVersion the format version to read.
+	 * @return the requested reader.
+	 */
+	@Nonnull
+	public static StateMetaInfoReader getReader(int readVersion) {
 		if (readVersion == CURRENT_STATE_META_INFO_SNAPSHOT_VERSION) {
 			// latest version shortcut
 			return CurrentReaderImpl.INSTANCE;
-		}
-
-		if (readVersion > CURRENT_STATE_META_INFO_SNAPSHOT_VERSION) {
+		} else {
 			throw new IllegalArgumentException("Unsupported read version for state meta info: " + readVersion);
 		}
-
-		switch (stateTypeHint) {
-			case KEYED_STATE:
-				return getLegacyKeyedStateMetaInfoReader(readVersion);
-			case OPERATOR_STATE:
-				return getLegacyOperatorStateMetaInfoReader(readVersion);
-			default:
-				throw new IllegalArgumentException("Unsupported state type hint: " + stateTypeHint);
-		}
 	}
 
 	@Nonnull
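
With the split into a type-hinted and a version-only overload, only legacy formats still need a StateTypeHint; the current format goes straight to the shared reader. A hedged usage sketch (readVersion, in and classLoader are assumed to be in scope, and StateTypeHint is assumed to be the nested enum used by this class; readStateMetaInfoSnapshot is the reader call visible in the tests further down):

    // pick a reader for the stored format version and read one meta info snapshot
    StateMetaInfoReader reader = StateMetaInfoSnapshotReadersWriters.getReader(
        readVersion, StateMetaInfoSnapshotReadersWriters.StateTypeHint.KEYED_STATE);
    StateMetaInfoSnapshot restored =
        reader.readStateMetaInfoSnapshot(new DataInputViewStreamWrapper(in), classLoader);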

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
index e6b7739..3756187 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
@@ -85,11 +85,11 @@ public abstract class KeyGroupPartitionerTestBase<T> extends TestLogger {
 			new ValidatingElementWriterDummy<>(keyExtractorFunction, numberOfKeyGroups, allElementsIdentitySet);
 
 		final KeyGroupPartitioner<T> testInstance = createPartitioner(data, testSize, range, numberOfKeyGroups, validatingElementWriter);
-		final StateSnapshot.KeyGroupPartitionedSnapshot result = testInstance.partitionByKeyGroup();
+		final StateSnapshot.StateKeyGroupWriter result = testInstance.partitionByKeyGroup();
 
 		for (int keyGroup = 0; keyGroup < numberOfKeyGroups; ++keyGroup) {
 			validatingElementWriter.setCurrentKeyGroup(keyGroup);
-			result.writeMappingsInKeyGroup(DUMMY_OUT_VIEW, keyGroup);
+			result.writeStateInKeyGroup(DUMMY_OUT_VIEW, keyGroup);
 		}
 
 		validatingElementWriter.validateAllElementsSeen();

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
index 5241dd8..9c487a4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
@@ -55,11 +55,11 @@ public class SerializationProxiesTest {
 
 		List<StateMetaInfoSnapshot> stateMetaInfoList = new ArrayList<>();
 
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "a", namespaceSerializer, stateSerializer).snapshot());
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "b", namespaceSerializer, stateSerializer).snapshot());
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "c", namespaceSerializer, stateSerializer).snapshot());
 
 		KeyedBackendSerializationProxy<?> serializationProxy =
@@ -93,11 +93,11 @@ public class SerializationProxiesTest {
 
 		List<StateMetaInfoSnapshot> stateMetaInfoList = new ArrayList<>();
 
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "a", namespaceSerializer, stateSerializer).snapshot());
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "b", namespaceSerializer, stateSerializer).snapshot());
-		stateMetaInfoList.add(new RegisteredKeyedBackendStateMetaInfo<>(
+		stateMetaInfoList.add(new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, "c", namespaceSerializer, stateSerializer).snapshot());
 
 		KeyedBackendSerializationProxy<?> serializationProxy =
@@ -132,7 +132,7 @@ public class SerializationProxiesTest {
 		Assert.assertEquals(keySerializer.snapshotConfiguration(), serializationProxy.getKeySerializerConfigSnapshot());
 
 		for (StateMetaInfoSnapshot snapshot : serializationProxy.getStateMetaInfoSnapshots()) {
-			final RegisteredKeyedBackendStateMetaInfo<?, ?> restoredMetaInfo = new RegisteredKeyedBackendStateMetaInfo<>(snapshot);
+			final RegisteredKeyValueStateBackendMetaInfo<?, ?> restoredMetaInfo = new RegisteredKeyValueStateBackendMetaInfo<>(snapshot);
 			Assert.assertTrue(restoredMetaInfo.getNamespaceSerializer() instanceof UnloadableDummyTypeSerializer);
 			Assert.assertTrue(restoredMetaInfo.getStateSerializer() instanceof UnloadableDummyTypeSerializer);
 			Assert.assertEquals(namespaceSerializer.snapshotConfiguration(), snapshot.getTypeSerializerConfigSnapshot(StateMetaInfoSnapshot.CommonSerializerKeys.NAMESPACE_SERIALIZER));
@@ -147,7 +147,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
 		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
 
-		StateMetaInfoSnapshot metaInfo = new RegisteredKeyedBackendStateMetaInfo<>(
+		StateMetaInfoSnapshot metaInfo = new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, name, namespaceSerializer, stateSerializer).snapshot();
 
 		byte[] serialized;
@@ -173,7 +173,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
 		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
 
-		StateMetaInfoSnapshot snapshot = new RegisteredKeyedBackendStateMetaInfo<>(
+		StateMetaInfoSnapshot snapshot = new RegisteredKeyValueStateBackendMetaInfo<>(
 			StateDescriptor.Type.VALUE, name, namespaceSerializer, stateSerializer).snapshot();
 
 		byte[] serialized;
@@ -198,7 +198,7 @@ public class SerializationProxiesTest {
 				new DataInputViewStreamWrapper(in), classLoader);
 		}
 
-		RegisteredKeyedBackendStateMetaInfo<?, ?> restoredMetaInfo = new RegisteredKeyedBackendStateMetaInfo<>(snapshot);
+		RegisteredKeyValueStateBackendMetaInfo<?, ?> restoredMetaInfo = new RegisteredKeyValueStateBackendMetaInfo<>(snapshot);
 
 		Assert.assertEquals(name, restoredMetaInfo.getName());
 		Assert.assertTrue(restoredMetaInfo.getNamespaceSerializer() instanceof UnloadableDummyTypeSerializer);
@@ -216,18 +216,18 @@ public class SerializationProxiesTest {
 
 		List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = new ArrayList<>();
 
-		stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>(
+		stateMetaInfoSnapshots.add(new RegisteredOperatorStateBackendMetaInfo<>(
 			"a", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE).snapshot());
-		stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>(
+		stateMetaInfoSnapshots.add(new RegisteredOperatorStateBackendMetaInfo<>(
 			"b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE).snapshot());
-		stateMetaInfoSnapshots.add(new RegisteredOperatorBackendStateMetaInfo<>(
+		stateMetaInfoSnapshots.add(new RegisteredOperatorStateBackendMetaInfo<>(
 			"c", stateSerializer, OperatorStateHandle.Mode.UNION).snapshot());
 
 		List<StateMetaInfoSnapshot> broadcastStateMetaInfoSnapshots = new ArrayList<>();
 
-		broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>(
+		broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastStateBackendMetaInfo<>(
 				"d", OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot());
-		broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastBackendStateMetaInfo<>(
+		broadcastStateMetaInfoSnapshots.add(new RegisteredBroadcastStateBackendMetaInfo<>(
 				"e", OperatorStateHandle.Mode.BROADCAST, valueSerializer, keySerializer).snapshot());
 
 		OperatorBackendSerializationProxy serializationProxy =
@@ -257,7 +257,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
 
 		StateMetaInfoSnapshot snapshot =
-			new RegisteredOperatorBackendStateMetaInfo<>(
+			new RegisteredOperatorStateBackendMetaInfo<>(
 				name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot();
 
 		byte[] serialized;
@@ -274,8 +274,8 @@ public class SerializationProxiesTest {
 				new DataInputViewStreamWrapper(in), Thread.currentThread().getContextClassLoader());
 		}
 
-		RegisteredOperatorBackendStateMetaInfo<?> restoredMetaInfo =
-			new RegisteredOperatorBackendStateMetaInfo<>(snapshot);
+		RegisteredOperatorStateBackendMetaInfo<?> restoredMetaInfo =
+			new RegisteredOperatorStateBackendMetaInfo<>(snapshot);
 
 		Assert.assertEquals(name, restoredMetaInfo.getName());
 		Assert.assertEquals(OperatorStateHandle.Mode.UNION, restoredMetaInfo.getAssignmentMode());
@@ -290,7 +290,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE;
 
 		StateMetaInfoSnapshot snapshot =
-			new RegisteredBroadcastBackendStateMetaInfo<>(
+			new RegisteredBroadcastStateBackendMetaInfo<>(
 				name, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot();
 
 		byte[] serialized;
@@ -308,8 +308,8 @@ public class SerializationProxiesTest {
 				new DataInputViewStreamWrapper(in), Thread.currentThread().getContextClassLoader());
 		}
 
-		RegisteredBroadcastBackendStateMetaInfo<?, ?> restoredMetaInfo =
-			new RegisteredBroadcastBackendStateMetaInfo<>(snapshot);
+		RegisteredBroadcastStateBackendMetaInfo<?, ?> restoredMetaInfo =
+			new RegisteredBroadcastStateBackendMetaInfo<>(snapshot);
 
 		Assert.assertEquals(name, restoredMetaInfo.getName());
 		Assert.assertEquals(
@@ -325,7 +325,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
 
 		StateMetaInfoSnapshot snapshot =
-			new RegisteredOperatorBackendStateMetaInfo<>(
+			new RegisteredOperatorStateBackendMetaInfo<>(
 				name, stateSerializer, OperatorStateHandle.Mode.UNION).snapshot();
 
 		byte[] serialized;
@@ -348,8 +348,8 @@ public class SerializationProxiesTest {
 			snapshot = reader.readStateMetaInfoSnapshot(new DataInputViewStreamWrapper(in), classLoader);
 		}
 
-		RegisteredOperatorBackendStateMetaInfo<?> restoredMetaInfo =
-			new RegisteredOperatorBackendStateMetaInfo<>(snapshot);
+		RegisteredOperatorStateBackendMetaInfo<?> restoredMetaInfo =
+			new RegisteredOperatorStateBackendMetaInfo<>(snapshot);
 
 		Assert.assertEquals(name, restoredMetaInfo.getName());
 		Assert.assertTrue(restoredMetaInfo.getPartitionStateSerializer() instanceof UnloadableDummyTypeSerializer);
@@ -365,7 +365,7 @@ public class SerializationProxiesTest {
 		TypeSerializer<?> valueSerializer = StringSerializer.INSTANCE;
 
 		StateMetaInfoSnapshot snapshot =
-			new RegisteredBroadcastBackendStateMetaInfo<>(
+			new RegisteredBroadcastStateBackendMetaInfo<>(
 				broadcastName, OperatorStateHandle.Mode.BROADCAST, keySerializer, valueSerializer).snapshot();
 
 		byte[] serialized;
@@ -393,8 +393,8 @@ public class SerializationProxiesTest {
 			snapshot = reader.readStateMetaInfoSnapshot(new DataInputViewStreamWrapper(in), classLoader);
 		}
 
-		RegisteredBroadcastBackendStateMetaInfo<?, ?> restoredMetaInfo =
-			new RegisteredBroadcastBackendStateMetaInfo<>(snapshot);
+		RegisteredBroadcastStateBackendMetaInfo<?, ?> restoredMetaInfo =
+			new RegisteredBroadcastStateBackendMetaInfo<>(snapshot);
 
 		Assert.assertEquals(broadcastName, restoredMetaInfo.getName());
 		Assert.assertEquals(OperatorStateHandle.Mode.BROADCAST, restoredMetaInfo.getAssignmentMode());

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
index 558f629..355387d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotCompressionTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import org.apache.flink.runtime.state.internal.InternalValueState;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
@@ -55,7 +56,7 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			true,
 			executionConfig,
 			TestLocalRecoveryConfig.disabled(),
-			mock(PriorityQueueSetFactory.class),
+			mock(HeapPriorityQueueSetFactory.class),
 			TtlTimeProvider.DEFAULT);
 
 		try {
@@ -79,7 +80,7 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			true,
 			executionConfig,
 			TestLocalRecoveryConfig.disabled(),
-			mock(PriorityQueueSetFactory.class),
+			mock(HeapPriorityQueueSetFactory.class),
 			TtlTimeProvider.DEFAULT);
 
 		try {
@@ -121,7 +122,7 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			true,
 			executionConfig,
 			TestLocalRecoveryConfig.disabled(),
-			mock(PriorityQueueSetFactory.class),
+			mock(HeapPriorityQueueSetFactory.class),
 			TtlTimeProvider.DEFAULT);
 
 		try {
@@ -164,7 +165,7 @@ public class StateSnapshotCompressionTest extends TestLogger {
 			true,
 			executionConfig,
 			TestLocalRecoveryConfig.disabled(),
-			mock(PriorityQueueSetFactory.class),
+			mock(HeapPriorityQueueSetFactory.class),
 			TtlTimeProvider.DEFAULT);
 		try {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
index cf6bcc8..2c48e4b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
@@ -31,7 +31,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.TestLogger;
@@ -53,8 +53,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	@Test
 	public void testPutGetRemoveContainsTransform() throws Exception {
-		RegisteredKeyedBackendStateMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.UNKNOWN,
 				"test",
 				IntSerializer.INSTANCE,
@@ -125,8 +125,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	@Test
 	public void testIncrementalRehash() {
-		RegisteredKeyedBackendStateMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.UNKNOWN,
 				"test",
 				IntSerializer.INSTANCE,
@@ -170,8 +170,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	@Test
 	public void testRandomModificationsAndCopyOnWriteIsolation() throws Exception {
 
-		final RegisteredKeyedBackendStateMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		final RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.UNKNOWN,
 				"test",
 				IntSerializer.INSTANCE,
@@ -325,8 +325,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	@Test
 	public void testCopyOnWriteContracts() {
-		RegisteredKeyedBackendStateMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.UNKNOWN,
 				"test",
 				IntSerializer.INSTANCE,
@@ -356,7 +356,7 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 
 		// no snapshot taken, we get the original back
 		Assert.assertTrue(stateTable.get(1, 1) == originalState1);
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot1 = stateTable.createSnapshot();
+		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot1 = stateTable.stateSnapshot();
 		// after snapshot1 is taken, we get a copy...
 		final ArrayList<Integer> copyState = stateTable.get(1, 1);
 		Assert.assertFalse(copyState == originalState1);
@@ -370,7 +370,7 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		Assert.assertTrue(copyState == stateTable.get(1, 1));
 
 		// we take snapshot2
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot2 = stateTable.createSnapshot();
+		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot2 = stateTable.stateSnapshot();
 		// after the second snapshot, copy-on-write is active again for old entries
 		Assert.assertFalse(copyState == stateTable.get(1, 1));
 		// and equality still holds
@@ -400,8 +400,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		final TestDuplicateSerializer stateSerializer = new TestDuplicateSerializer();
 		final TestDuplicateSerializer keySerializer = new TestDuplicateSerializer();
 
-		RegisteredKeyedBackendStateMetaInfo<Integer, Integer> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		RegisteredKeyValueStateBackendMetaInfo<Integer, Integer> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.VALUE,
 				"test",
 				namespaceSerializer,
@@ -443,15 +443,15 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		table.put(2, 0, 1, 2);
 
 
-		final CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.createSnapshot();
+		final CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.stateSnapshot();
 
 		try {
-			final StateSnapshot.KeyGroupPartitionedSnapshot partitionedSnapshot = snapshot.partitionByKeyGroup();
+			final StateSnapshot.StateKeyGroupWriter partitionedSnapshot = snapshot.getKeyGroupWriter();
 			namespaceSerializer.disable();
 			keySerializer.disable();
 			stateSerializer.disable();
 
-			partitionedSnapshot.writeMappingsInKeyGroup(
+			partitionedSnapshot.writeStateInKeyGroup(
 				new DataOutputViewStreamWrapper(
 					new ByteArrayOutputStreamWithPos(1024)), 0);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
index 45c86c7..41ef0c9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
@@ -27,8 +27,9 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -46,8 +47,8 @@ public class StateTableSnapshotCompatibilityTest {
 	@Test
 	public void checkCompatibleSerializationFormats() throws IOException {
 		final Random r = new Random(42);
-		RegisteredKeyedBackendStateMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyedBackendStateMetaInfo<>(
+		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
+			new RegisteredKeyValueStateBackendMetaInfo<>(
 				StateDescriptor.Type.UNKNOWN,
 				"test",
 				IntSerializer.INSTANCE,
@@ -69,7 +70,7 @@ public class StateTableSnapshotCompatibilityTest {
 			cowStateTable.put(r.nextInt(10), r.nextInt(2), list);
 		}
 
-		StateSnapshot snapshot = cowStateTable.createSnapshot();
+		StateSnapshot snapshot = cowStateTable.stateSnapshot();
 
 		final NestedMapsStateTable<Integer, Integer, ArrayList<Integer>> nestedMapsStateTable =
 			new NestedMapsStateTable<>(keyContext, metaInfo);
@@ -83,7 +84,7 @@ public class StateTableSnapshotCompatibilityTest {
 			Assert.assertEquals(entry.getState(), nestedMapsStateTable.get(entry.getKey(), entry.getNamespace()));
 		}
 
-		snapshot = nestedMapsStateTable.createSnapshot();
+		snapshot = nestedMapsStateTable.stateSnapshot();
 		cowStateTable = new CopyOnWriteStateTable<>(keyContext, metaInfo);
 
 		restoreStateTableFromSnapshot(cowStateTable, snapshot, keyContext.getKeyGroupRange());
@@ -102,15 +103,15 @@ public class StateTableSnapshotCompatibilityTest {
 
 		final ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(1024 * 1024);
 		final DataOutputViewStreamWrapper dov = new DataOutputViewStreamWrapper(out);
-		final StateSnapshot.KeyGroupPartitionedSnapshot keyGroupPartitionedSnapshot = snapshot.partitionByKeyGroup();
+		final StateSnapshot.StateKeyGroupWriter keyGroupPartitionedSnapshot = snapshot.getKeyGroupWriter();
 		for (Integer keyGroup : keyGroupRange) {
-			keyGroupPartitionedSnapshot.writeMappingsInKeyGroup(dov, keyGroup);
+			keyGroupPartitionedSnapshot.writeStateInKeyGroup(dov, keyGroup);
 		}
 
 		final ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(out.getBuf());
 		final DataInputViewStreamWrapper div = new DataInputViewStreamWrapper(in);
 
-		final StateTableByKeyGroupReader keyGroupReader =
+		final StateSnapshotKeyGroupReader keyGroupReader =
 			StateTableByKeyGroupReaders.readerForVersion(stateTable, KeyedBackendSerializationProxy.VERSION);
 
 		for (Integer keyGroup : keyGroupRange) {
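
On the restore side, the reader counterpart now comes from the generalized StateSnapshotKeyGroupReader. A hedged sketch of the read loop, assuming the per-key-group method keeps the readMappingsInKeyGroup name of the old StateTableByKeyGroupReader (the loop body is cut off in this hunk):

    // read every key group back into the target state table; 'div' wraps the serialized bytes as above
    StateSnapshotKeyGroupReader keyGroupReader =
        StateTableByKeyGroupReaders.readerForVersion(stateTable, KeyedBackendSerializationProxy.VERSION);
    for (Integer keyGroup : keyGroupRange) {
        keyGroupReader.readMappingsInKeyGroup(div, keyGroup);
    }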

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
index 409c796..e196b1a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
@@ -32,11 +32,11 @@ public class StateMetaInfoSnapshotEnumConstantsTest {
 		Assert.assertEquals(0, StateMetaInfoSnapshot.BackendStateType.KEY_VALUE.ordinal());
 		Assert.assertEquals(1, StateMetaInfoSnapshot.BackendStateType.OPERATOR.ordinal());
 		Assert.assertEquals(2, StateMetaInfoSnapshot.BackendStateType.BROADCAST.ordinal());
-		Assert.assertEquals(3, StateMetaInfoSnapshot.BackendStateType.TIMER.ordinal());
+		Assert.assertEquals(3, StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE.ordinal());
 		Assert.assertEquals("KEY_VALUE", StateMetaInfoSnapshot.BackendStateType.KEY_VALUE.toString());
 		Assert.assertEquals("OPERATOR", StateMetaInfoSnapshot.BackendStateType.OPERATOR.toString());
 		Assert.assertEquals("BROADCAST", StateMetaInfoSnapshot.BackendStateType.BROADCAST.toString());
-		Assert.assertEquals("TIMER", StateMetaInfoSnapshot.BackendStateType.TIMER.toString());
+		Assert.assertEquals("PRIORITY_QUEUE", StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE.toString());
 	}
 
 	@Test

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 363ecf8..9e9328b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -36,8 +36,10 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
+import org.apache.flink.runtime.state.Keyed;
 import org.apache.flink.runtime.state.KeyedStateFactory;
 import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SnapshotResult;
@@ -118,6 +120,11 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
+	public boolean requiresLegacySynchronousTimerSnapshots() {
+		return false;
+	}
+
+	@Override
 	public void notifyCheckpointComplete(long checkpointId) {
 		// noop
 	}
@@ -167,15 +174,13 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@Nonnull
 	@Override
-	public <T extends HeapPriorityQueueElement> KeyGroupedInternalPriorityQueue<T>
+	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
 	create(
 		@Nonnull String stateName,
-		@Nonnull TypeSerializer<T> byteOrderedElementSerializer,
-		@Nonnull PriorityComparator<T> elementPriorityComparator,
-		@Nonnull KeyExtractorFunction<T> keyExtractor) {
-		return new HeapPriorityQueueSet<>(
-			elementPriorityComparator,
-			keyExtractor,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+		return new HeapPriorityQueueSet<T>(
+			PriorityComparator.forPriorityComparableObjects(),
+			KeyExtractorFunction.forKeyedObjects(),
 			0,
 			keyGroupRange,
 			0);
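
The new factory signature bounds the element type by Keyed and PriorityComparable in addition to HeapPriorityQueueElement, so the queue can obtain the key extractor and priority comparator from the elements themselves instead of taking them as arguments. A hedged sketch of an element satisfying those bounds; the method names (getKey, comparePriorityTo, getInternalIndex/setInternalIndex, NOT_CONTAINED) are assumptions about the imported interfaces, not confirmed by this diff:

    // hypothetical timer-like element for the new create(...) bound
    class DemoTimer implements Keyed<String>, PriorityComparable<DemoTimer>, HeapPriorityQueueElement {
        private final String key;
        private final long timestamp;
        private int internalIndex = HeapPriorityQueueElement.NOT_CONTAINED;

        DemoTimer(String key, long timestamp) {
            this.key = key;
            this.timestamp = timestamp;
        }

        @Override public String getKey() { return key; }
        @Override public int comparePriorityTo(DemoTimer other) { return Long.compare(timestamp, other.timestamp); }
        @Override public int getInternalIndex() { return internalIndex; }
        @Override public void setInternalIndex(int newIndex) { this.internalIndex = newIndex; }
    }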

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingState.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingState.java
index ceae3e1..209d18f 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingState.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBAggregatingState.java
@@ -27,7 +27,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
 import org.apache.flink.util.FlinkRuntimeException;
 
@@ -177,7 +177,7 @@ class RocksDBAggregatingState<K, N, T, ACC, R>
 	@SuppressWarnings("unchecked")
 	static <K, N, SV, S extends State, IS extends S> IS create(
 		StateDescriptor<S, SV> stateDesc,
-		Tuple2<ColumnFamilyHandle, RegisteredKeyedBackendStateMetaInfo<N, SV>> registerResult,
+		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult,
 		RocksDBKeyedStateBackend<K> backend) {
 		return (IS) new RocksDBAggregatingState<>(
 			registerResult.f0,

http://git-wip-us.apache.org/repos/asf/flink/blob/dbddf00b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java
----------------------------------------------------------------------
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java
index cf7974f..4d66357 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBFoldingState.java
@@ -25,7 +25,7 @@ import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.internal.InternalFoldingState;
 
 import org.rocksdb.ColumnFamilyHandle;
@@ -103,7 +103,7 @@ class RocksDBFoldingState<K, N, T, ACC>
 	@SuppressWarnings("unchecked")
 	static <K, N, SV, S extends State, IS extends S> IS create(
 		StateDescriptor<S, SV> stateDesc,
-		Tuple2<ColumnFamilyHandle, RegisteredKeyedBackendStateMetaInfo<N, SV>> registerResult,
+		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult,
 		RocksDBKeyedStateBackend<K> backend) {
 		return (IS) new RocksDBFoldingState<>(
 			registerResult.f0,