You are viewing a plain-text version of this content. (The link to the canonical version, originally attached to the word "here", was not preserved in this plain-text copy.)
Posted to commits@flink.apache.org by sr...@apache.org on 2018/08/02 15:43:45 UTC

[flink] branch master updated (1fe5e4b -> 6aaf5ed)

This is an automated email from the ASF dual-hosted git repository.

srichter pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 1fe5e4b  [hotfix] [e2e] Remove explicit Maven plugin version
     new 9d273a3  [FLINK-9887][state] Integrate priority queue state with existing serializer upgrade mechanism
     new 6aaf5ed  [hotfix] Minor cleanups in LongSerializer

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../api/common/typeutils/base/LongSerializer.java  |  12 +-
 .../apache/flink/util/StateMigrationException.java |   6 +
 .../runtime/state/AbstractKeyedStateBackend.java   |   2 +-
 .../runtime/state/DefaultOperatorStateBackend.java |   4 +-
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  10 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 459 +++++++++++------
 .../state/heap/HeapPriorityQueueSetFactory.java    |   2 +-
 .../HeapPriorityQueueSnapshotRestoreWrapper.java   |  22 +
 .../state/InternalPriorityQueueTestBase.java       |  69 ++-
 .../runtime/state/MemoryStateBackendTest.java      |  16 +-
 .../flink/runtime/state/StateBackendTestBase.java  | 549 ++++++++++++++-------
 ...HeapKeyedStateBackendSnapshotMigrationTest.java |   2 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   2 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  51 +-
 .../streaming/state/RocksDBStateBackendTest.java   |   6 +
 flink-streaming-java/pom.xml                       |   2 +
 .../api/operators/InternalTimeServiceManager.java  |  18 +-
 .../operators/StreamTaskStateInitializerImpl.java  |   1 -
 .../streaming/api/operators/TimerSerializer.java   |  57 ++-
 .../operators/InternalTimeServiceManagerTest.java  |  23 +-
 .../api/operators/TimerSerializerTest.java         |  62 +++
 .../operators/windowing/TriggerTestHarness.java    |   4 +-
 .../KeyedOneInputStreamOperatorTestHarness.java    |   4 +-
 .../KeyedTwoInputStreamOperatorTestHarness.java    |   2 +-
 24 files changed, 988 insertions(+), 397 deletions(-)
 copy flink-runtime/src/test/java/org/apache/flink/runtime/rest/handler/legacy/SubtaskCurrentAttemptDetailsHandlerTest.java => flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java (57%)
 create mode 100644 flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java


[flink] 02/02: [hotfix] Minor cleanups in LongSerializer

Posted by sr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 6aaf5ed897d6b14241957b1fdc98009e5bec05d7
Author: Stefan Richter <s....@data-artisans.com>
AuthorDate: Mon Jul 30 19:06:48 2018 +0200

    [hotfix] Minor cleanups in LongSerializer
---
 .../flink/api/common/typeutils/base/LongSerializer.java      | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/LongSerializer.java b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/LongSerializer.java
index cbdc3db..2ed2cec 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/LongSerializer.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/typeutils/base/LongSerializer.java
@@ -18,12 +18,12 @@
 
 package org.apache.flink.api.common.typeutils.base;
 
-import java.io.IOException;
-
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 
+import java.io.IOException;
+
 @Internal
 public final class LongSerializer extends TypeSerializerSingleton<Long> {
 
@@ -31,7 +31,7 @@ public final class LongSerializer extends TypeSerializerSingleton<Long> {
 	
 	public static final LongSerializer INSTANCE = new LongSerializer();
 	
-	private static final Long ZERO = Long.valueOf(0);
+	private static final Long ZERO = 0L;
 
 	@Override
 	public boolean isImmutableType() {
@@ -55,17 +55,17 @@ public final class LongSerializer extends TypeSerializerSingleton<Long> {
 
 	@Override
 	public int getLength() {
-		return 8;
+		return Long.BYTES;
 	}
 
 	@Override
 	public void serialize(Long record, DataOutputView target) throws IOException {
-		target.writeLong(record.longValue());
+		target.writeLong(record);
 	}
 
 	@Override
 	public Long deserialize(DataInputView source) throws IOException {
-		return Long.valueOf(source.readLong());
+		return source.readLong();
 	}
 	
 	@Override


[flink] 01/02: [FLINK-9887][state] Integrate priority queue state with existing serializer upgrade mechanism

Posted by sr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 9d273a3fcda4033e1a385c0ff0c4a2b7ad640721
Author: Stefan Richter <s....@data-artisans.com>
AuthorDate: Fri Jul 27 14:49:27 2018 +0200

    [FLINK-9887][state] Integrate priority queue state with existing serializer upgrade mechanism
    
    This closes #6467.
---
 .../apache/flink/util/StateMigrationException.java |   6 +
 .../runtime/state/AbstractKeyedStateBackend.java   |   2 +-
 .../runtime/state/DefaultOperatorStateBackend.java |   4 +-
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  10 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  | 459 +++++++++++------
 .../state/heap/HeapPriorityQueueSetFactory.java    |   2 +-
 .../HeapPriorityQueueSnapshotRestoreWrapper.java   |  22 +
 .../state/InternalPriorityQueueTestBase.java       |  69 ++-
 .../runtime/state/MemoryStateBackendTest.java      |  16 +-
 .../flink/runtime/state/StateBackendTestBase.java  | 549 ++++++++++++++-------
 ...HeapKeyedStateBackendSnapshotMigrationTest.java |   2 +-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   2 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  51 +-
 .../streaming/state/RocksDBStateBackendTest.java   |   6 +
 flink-streaming-java/pom.xml                       |   2 +
 .../api/operators/InternalTimeServiceManager.java  |  18 +-
 .../operators/StreamTaskStateInitializerImpl.java  |   1 -
 .../streaming/api/operators/TimerSerializer.java   |  57 ++-
 .../operators/InternalTimeServiceManagerTest.java  |  31 +-
 .../api/operators/TimerSerializerTest.java         |  62 +++
 .../operators/windowing/TriggerTestHarness.java    |   4 +-
 .../KeyedOneInputStreamOperatorTestHarness.java    |   4 +-
 .../KeyedTwoInputStreamOperatorTestHarness.java    |   2 +-
 23 files changed, 987 insertions(+), 394 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java b/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
index 00e0e73..12f3ee4 100644
--- a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
+++ b/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
@@ -24,6 +24,8 @@ package org.apache.flink.util;
 public class StateMigrationException extends FlinkException {
 	private static final long serialVersionUID = 8268516412747670839L;
 
+	public static final String MIGRATION_NOT_SUPPORTED_MSG = "State migration is currently not supported.";
+
 	public StateMigrationException(String message) {
 		super(message);
 	}
@@ -35,4 +37,8 @@ public class StateMigrationException extends FlinkException {
 	public StateMigrationException(String message, Throwable cause) {
 		super(message, cause);
 	}
+
+	public static StateMigrationException notSupported() {
+		return new StateMigrationException(MIGRATION_NOT_SUPPORTED_MSG);
+	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
index 17d24f77..1c2d2a3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/AbstractKeyedStateBackend.java
@@ -313,7 +313,7 @@ public abstract class AbstractKeyedStateBackend<K> implements
 	 * Returns the total number of state entries across all keys/namespaces.
 	 */
 	@VisibleForTesting
-	public abstract int numStateEntries();
+	public abstract int numKeyValueStateEntries();
 
 	// TODO remove this once heap-based timers are working with RocksDB incremental snapshots!
 	public boolean requiresLegacySynchronousTimerSnapshots() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index f1d0b57..dfff50d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -260,7 +260,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				// the new serializer; we're deliberately failing here for now to have equal functionality with
 				// the RocksDB backend to avoid confusion for users.
 
-				throw new StateMigrationException("State migration isn't supported, yet.");
+				throw StateMigrationException.notSupported();
 			}
 		}
 
@@ -781,7 +781,7 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				// the new serializer; we're deliberately failing here for now to have equal functionality with
 				// the RocksDB backend to avoid confusion for users.
 
-				throw new StateMigrationException("State migration isn't supported, yet.");
+				throw StateMigrationException.notSupported();
 			}
 		}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index d49a05c..b0248fc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -144,6 +144,12 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		TypeSerializer<N> newNamespaceSerializer,
 		StateDescriptor<?, S> newStateDescriptor) throws StateMigrationException {
 
+		Preconditions.checkState(restoredStateMetaInfoSnapshot.getBackendStateType()
+				== StateMetaInfoSnapshot.BackendStateType.KEY_VALUE,
+			"Incompatible state types. " +
+				"Was [" + restoredStateMetaInfoSnapshot.getBackendStateType() + "], " +
+				"registered as [" + StateMetaInfoSnapshot.BackendStateType.KEY_VALUE + "].");
+
 		Preconditions.checkState(
 			Objects.equals(newStateDescriptor.getName(), restoredStateMetaInfoSnapshot.getName()),
 			"Incompatible state names. " +
@@ -160,7 +166,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 
 			Preconditions.checkState(
 				newStateDescriptor.getType() == restoredType,
-				"Incompatible state types. " +
+				"Incompatible key/value state types. " +
 					"Was [" + restoredType + "], " +
 					"registered with [" + newStateDescriptor.getType() + "].");
 		}
@@ -184,7 +190,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 
 		if (namespaceCompatibility.isRequiresMigration() || stateCompatibility.isRequiresMigration()) {
 			// TODO state migration currently isn't possible.
-			throw new StateMigrationException("State migration isn't supported, yet.");
+			throw StateMigrationException.notSupported();
 		} else {
 			return new RegisteredKeyValueStateBackendMetaInfo<>(
 				newStateDescriptor.getType(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 2c6101e..34c9698 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
@@ -79,6 +80,7 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
 
+import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.util.ArrayList;
@@ -111,60 +113,15 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			Tuple2.of(FoldingStateDescriptor.class, (StateFactory) HeapFoldingState::create)
 		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
-	@SuppressWarnings("unchecked")
-	@Nonnull
-	@Override
-	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create(
-		@Nonnull String stateName,
-		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-
-		final StateSnapshotRestore snapshotRestore = registeredStates.get(stateName);
-
-		if (snapshotRestore instanceof HeapPriorityQueueSnapshotRestoreWrapper) {
-			//TODO Serializer upgrade story!?
-			return ((HeapPriorityQueueSnapshotRestoreWrapper<T>) snapshotRestore).getPriorityQueue();
-		} else if (snapshotRestore != null) {
-			throw new IllegalStateException("Already found a different state type registered under this name: " + snapshotRestore.getClass());
-		}
-
-		final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
-			new RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, byteOrderedElementSerializer);
-
-		return createInternal(metaInfo);
-	}
-
-	@Nonnull
-	private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
-		RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
-
-		final String stateName = metaInfo.getName();
-		final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create(
-			stateName,
-			metaInfo.getElementSerializer());
-
-		HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
-			new HeapPriorityQueueSnapshotRestoreWrapper<>(
-				priorityQueue,
-				metaInfo,
-				KeyExtractorFunction.forKeyedObjects(),
-				keyGroupRange,
-				numberOfKeyGroups);
-
-		registeredStates.put(stateName, wrapper);
-		return priorityQueue;
-	}
-
-	private interface StateFactory {
-		<K, N, SV, S extends State, IS extends S> IS createState(
-			StateDescriptor<S, SV> stateDesc,
-			StateTable<K, N, SV> stateTable,
-			TypeSerializer<K> keySerializer) throws Exception;
-	}
+	/**
+	 * Map of registered Key/Value states.
+	 */
+	private final Map<String, StateTable<K, ?, ?>> registeredKVStates;
 
 	/**
-	 * Map of registered states for snapshot/restore.
+	 * Map of registered priority queue set states.
 	 */
-	private final Map<String, StateSnapshotRestore> registeredStates = new HashMap<>();
+	private final Map<String, HeapPriorityQueueSnapshotRestoreWrapper> registeredPQStates;
 
 	/**
 	 * Map of state names to their corresponding restored state meta info.
@@ -172,7 +129,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * <p>TODO this map can be removed when eager-state registration is in place.
 	 * TODO we currently need this cached to check state migration strategies when new serializers are registered.
 	 */
-	private final Map<String, StateMetaInfoSnapshot> restoredKvStateMetaInfos;
+	private final Map<StateUID, StateMetaInfoSnapshot> restoredStateMetaInfo;
 
 	/**
 	 * The configuration for local recovery.
@@ -203,6 +160,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		super(kvStateRegistry, keySerializer, userCodeClassLoader,
 			numberOfKeyGroups, keyGroupRange, executionConfig, ttlTimeProvider);
+
+		this.registeredKVStates = new HashMap<>();
+		this.registeredPQStates = new HashMap<>();
 		this.localRecoveryConfig = Preconditions.checkNotNull(localRecoveryConfig);
 
 		SnapshotStrategySynchronicityBehavior<K> synchronicityTrait = asynchronousSnapshots ?
@@ -211,7 +171,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		this.snapshotStrategy = new HeapSnapshotStrategy(synchronicityTrait);
 		LOG.info("Initializing heap keyed state backend with stream factory.");
-		this.restoredKvStateMetaInfos = new HashMap<>();
+		this.restoredStateMetaInfo = new HashMap<>();
 		this.priorityQueueSetFactory = priorityQueueSetFactory;
 	}
 
@@ -219,17 +179,85 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	//  state backend operations
 	// ------------------------------------------------------------------------
 
+	@SuppressWarnings("unchecked")
+	@Nonnull
+	@Override
+	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> create(
+		@Nonnull String stateName,
+		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
+
+		final HeapPriorityQueueSnapshotRestoreWrapper existingState = registeredPQStates.get(stateName);
+
+		if (existingState != null) {
+			// TODO we implement the simple way of supporting the current functionality, mimicking keyed state
+			// because this should be reworked in FLINK-9376 and then we should have a common algorithm over
+			// StateMetaInfoSnapshot that avoids this code duplication.
+			StateMetaInfoSnapshot restoredMetaInfoSnapshot =
+				restoredStateMetaInfo.get(StateUID.of(stateName, StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE));
+
+			Preconditions.checkState(
+				restoredMetaInfoSnapshot != null,
+				"Requested to check compatibility of a restored RegisteredKeyedBackendStateMetaInfo," +
+					" but its corresponding restored snapshot cannot be found.");
+
+			StateMetaInfoSnapshot.CommonSerializerKeys serializerKey =
+				StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER;
+
+			CompatibilityResult<T> compatibilityResult = CompatibilityUtil.resolveCompatibilityResult(
+				restoredMetaInfoSnapshot.getTypeSerializer(serializerKey),
+				null,
+				restoredMetaInfoSnapshot.getTypeSerializerConfigSnapshot(serializerKey),
+				byteOrderedElementSerializer);
+
+			if (compatibilityResult.isRequiresMigration()) {
+				throw new FlinkRuntimeException(StateMigrationException.notSupported());
+			} else {
+				registeredPQStates.put(
+					stateName,
+					existingState.forUpdatedSerializer(byteOrderedElementSerializer));
+			}
+
+			return existingState.getPriorityQueue();
+		} else {
+			final RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
+				new RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, byteOrderedElementSerializer);
+			return createInternal(metaInfo);
+		}
+	}
+
+	@Nonnull
+	private <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T> createInternal(
+		RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo) {
+
+		final String stateName = metaInfo.getName();
+		final HeapPriorityQueueSet<T> priorityQueue = priorityQueueSetFactory.create(
+			stateName,
+			metaInfo.getElementSerializer());
+
+		HeapPriorityQueueSnapshotRestoreWrapper<T> wrapper =
+			new HeapPriorityQueueSnapshotRestoreWrapper<>(
+				priorityQueue,
+				metaInfo,
+				KeyExtractorFunction.forKeyedObjects(),
+				keyGroupRange,
+				numberOfKeyGroups);
+
+		registeredPQStates.put(stateName, wrapper);
+		return priorityQueue;
+	}
+
 	private <N, V> StateTable<K, N, V> tryRegisterStateTable(
 			TypeSerializer<N> namespaceSerializer, StateDescriptor<?, V> stateDesc) throws StateMigrationException {
 
 		@SuppressWarnings("unchecked")
-		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredStates.get(stateDesc.getName());
+		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
 
 		RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo;
 		if (stateTable != null) {
 			@SuppressWarnings("unchecked")
 			StateMetaInfoSnapshot restoredMetaInfoSnapshot =
-				restoredKvStateMetaInfos.get(stateDesc.getName());
+				restoredStateMetaInfo.get(
+					StateUID.of(stateDesc.getName(), StateMetaInfoSnapshot.BackendStateType.KEY_VALUE));
 
 			Preconditions.checkState(
 				restoredMetaInfoSnapshot != null,
@@ -250,7 +278,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getSerializer());
 
 			stateTable = snapshotStrategy.newStateTable(newMetaInfo);
-			registeredStates.put(stateDesc.getName(), stateTable);
+			registeredKVStates.put(stateDesc.getName(), stateTable);
 		}
 
 		return stateTable;
@@ -259,20 +287,17 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@SuppressWarnings("unchecked")
 	@Override
 	public <N> Stream<K> getKeys(String state, N namespace) {
-		if (!registeredStates.containsKey(state)) {
+		if (!registeredKVStates.containsKey(state)) {
 			return Stream.empty();
 		}
 
-		final StateSnapshotRestore stateSnapshotRestore = registeredStates.get(state);
-		if (!(stateSnapshotRestore instanceof StateTable)) {
-			return Stream.empty();
-		}
+		final StateSnapshotRestore stateSnapshotRestore = registeredKVStates.get(state);
 		StateTable<K, N, ?> table = (StateTable<K, N, ?>) stateSnapshotRestore;
 		return table.getKeys(namespace);
 	}
 
 	private boolean hasRegisteredState() {
-		return !registeredStates.isEmpty();
+		return !(registeredKVStates.isEmpty() && registeredPQStates.isEmpty());
 	}
 
 	@Override
@@ -318,9 +343,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@SuppressWarnings({"unchecked"})
 	private void restorePartitionedState(Collection<KeyedStateHandle> state) throws Exception {
 
-		final Map<Integer, String> kvStatesById = new HashMap<>();
-		int numRegisteredKvStates = 0;
-		registeredStates.clear();
+		final Map<Integer, StateMetaInfoSnapshot> kvStatesById = new HashMap<>();
+		registeredKVStates.clear();
+		registeredPQStates.clear();
 
 		boolean keySerializerRestored = false;
 
@@ -369,70 +394,131 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				}
 
 				List<StateMetaInfoSnapshot> restoredMetaInfos =
-						serializationProxy.getStateMetaInfoSnapshots();
-
-				for (StateMetaInfoSnapshot restoredMetaInfo : restoredMetaInfos) {
-					restoredKvStateMetaInfos.put(restoredMetaInfo.getName(), restoredMetaInfo);
-
-					StateSnapshotRestore snapshotRestore = registeredStates.get(restoredMetaInfo.getName());
+					serializationProxy.getStateMetaInfoSnapshots();
 
-					//important: only create a new table we did not already create it previously
-					if (null == snapshotRestore) {
+				createOrCheckStateForMetaInfo(restoredMetaInfos, kvStatesById);
 
-						if (restoredMetaInfo.getBackendStateType() == StateMetaInfoSnapshot.BackendStateType.KEY_VALUE) {
-							RegisteredKeyValueStateBackendMetaInfo<?, ?> registeredKeyedBackendStateMetaInfo =
-								new RegisteredKeyValueStateBackendMetaInfo<>(restoredMetaInfo);
-
-							snapshotRestore = snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo);
-							registeredStates.put(restoredMetaInfo.getName(), snapshotRestore);
-						} else {
-							createInternal(new RegisteredPriorityQueueStateBackendMetaInfo<>(restoredMetaInfo));
-						}
-						kvStatesById.put(numRegisteredKvStates, restoredMetaInfo.getName());
-						++numRegisteredKvStates;
-					} else {
-						// TODO with eager state registration in place, check here for serializer migration strategies
-					}
+				readStateHandleStateData(
+					fsDataInputStream,
+					inView,
+					keyGroupsStateHandle.getGroupRangeOffsets(),
+					kvStatesById, restoredMetaInfos.size(),
+					serializationProxy.getReadVersion(),
+					serializationProxy.isUsingKeyGroupCompression());
+			} finally {
+				if (cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
+					IOUtils.closeQuietly(fsDataInputStream);
 				}
+			}
+		}
+	}
 
-				final StreamCompressionDecorator streamCompressionDecorator = serializationProxy.isUsingKeyGroupCompression() ?
-					SnappyStreamCompressionDecorator.INSTANCE : UncompressedStreamCompressionDecorator.INSTANCE;
+	private void readStateHandleStateData(
+		FSDataInputStream fsDataInputStream,
+		DataInputViewStreamWrapper inView,
+		KeyGroupRangeOffsets keyGroupOffsets,
+		Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+		int numStates,
+		int readVersion,
+		boolean isCompressed) throws IOException {
 
-				for (Tuple2<Integer, Long> groupOffset : keyGroupsStateHandle.getGroupRangeOffsets()) {
-					int keyGroupIndex = groupOffset.f0;
-					long offset = groupOffset.f1;
+		final StreamCompressionDecorator streamCompressionDecorator = isCompressed ?
+			SnappyStreamCompressionDecorator.INSTANCE : UncompressedStreamCompressionDecorator.INSTANCE;
 
-					// Check that restored key groups all belong to the backend.
-					Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group must belong to the backend.");
+		for (Tuple2<Integer, Long> groupOffset : keyGroupOffsets) {
+			int keyGroupIndex = groupOffset.f0;
+			long offset = groupOffset.f1;
 
-					fsDataInputStream.seek(offset);
+			// Check that restored key groups all belong to the backend.
+			Preconditions.checkState(keyGroupRange.contains(keyGroupIndex), "The key group must belong to the backend.");
 
-					int writtenKeyGroupIndex = inView.readInt();
+			fsDataInputStream.seek(offset);
 
-					try (InputStream kgCompressionInStream =
-							streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
+			int writtenKeyGroupIndex = inView.readInt();
+			Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex,
+				"Unexpected key-group in restore.");
 
-						DataInputViewStreamWrapper kgCompressionInView =
-							new DataInputViewStreamWrapper(kgCompressionInStream);
+			try (InputStream kgCompressionInStream =
+					 streamCompressionDecorator.decorateWithCompression(fsDataInputStream)) {
 
-						Preconditions.checkState(writtenKeyGroupIndex == keyGroupIndex,
-							"Unexpected key-group in restore.");
+				readKeyGroupStateData(
+					kgCompressionInStream,
+					kvStatesById,
+					keyGroupIndex,
+					numStates,
+					readVersion);
+			}
+		}
+	}
 
-						for (int i = 0; i < restoredMetaInfos.size(); i++) {
-							int kvStateId = kgCompressionInView.readShort();
-							StateSnapshotRestore registeredState = registeredStates.get(kvStatesById.get(kvStateId));
+	private void readKeyGroupStateData(
+		InputStream inputStream,
+		Map<Integer, StateMetaInfoSnapshot> kvStatesById,
+		int keyGroupIndex,
+		int numStates,
+		int readVersion) throws IOException {
+
+		DataInputViewStreamWrapper inView =
+			new DataInputViewStreamWrapper(inputStream);
+
+		for (int i = 0; i < numStates; i++) {
+
+			final int kvStateId = inView.readShort();
+			final StateMetaInfoSnapshot stateMetaInfoSnapshot = kvStatesById.get(kvStateId);
+			final StateSnapshotRestore registeredState;
+
+			switch (stateMetaInfoSnapshot.getBackendStateType()) {
+				case KEY_VALUE:
+					registeredState = registeredKVStates.get(stateMetaInfoSnapshot.getName());
+					break;
+				case PRIORITY_QUEUE:
+					registeredState = registeredPQStates.get(stateMetaInfoSnapshot.getName());
+					break;
+				default:
+					throw new IllegalStateException("Unexpected state type: " +
+						stateMetaInfoSnapshot.getBackendStateType() + ".");
+			}
 
-							StateSnapshotKeyGroupReader keyGroupReader =
-								registeredState.keyGroupReader(serializationProxy.getReadVersion());
+			StateSnapshotKeyGroupReader keyGroupReader = registeredState.keyGroupReader(readVersion);
+			keyGroupReader.readMappingsInKeyGroup(inView, keyGroupIndex);
+		}
+	}
 
-							keyGroupReader.readMappingsInKeyGroup(kgCompressionInView, keyGroupIndex);
-						}
+	private void createOrCheckStateForMetaInfo(
+		List<StateMetaInfoSnapshot> restoredMetaInfo,
+		Map<Integer, StateMetaInfoSnapshot> kvStatesById) {
+
+		for (StateMetaInfoSnapshot metaInfoSnapshot : restoredMetaInfo) {
+			restoredStateMetaInfo.put(
+				StateUID.of(metaInfoSnapshot.getName(), metaInfoSnapshot.getBackendStateType()),
+				metaInfoSnapshot);
+
+			final StateSnapshotRestore registeredState;
+
+			switch (metaInfoSnapshot.getBackendStateType()) {
+				case KEY_VALUE:
+					registeredState = registeredKVStates.get(metaInfoSnapshot.getName());
+					if (registeredState == null) {
+						RegisteredKeyValueStateBackendMetaInfo<?, ?> registeredKeyedBackendStateMetaInfo =
+							new RegisteredKeyValueStateBackendMetaInfo<>(metaInfoSnapshot);
+						registeredKVStates.put(
+							metaInfoSnapshot.getName(),
+							snapshotStrategy.newStateTable(registeredKeyedBackendStateMetaInfo));
 					}
-				}
-			} finally {
-				if (cancelStreamRegistry.unregisterCloseable(fsDataInputStream)) {
-					IOUtils.closeQuietly(fsDataInputStream);
-				}
+					break;
+				case PRIORITY_QUEUE:
+					registeredState = registeredPQStates.get(metaInfoSnapshot.getName());
+					if (registeredState == null) {
+						createInternal(new RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfoSnapshot));
+					}
+					break;
+				default:
+					throw new IllegalStateException("Unexpected state type: " +
+						metaInfoSnapshot.getBackendStateType() + ".");
+			}
+
+			if (registeredState == null) {
+				kvStatesById.put(kvStatesById.size(), metaInfoSnapshot);
 			}
 		}
 	}
@@ -478,12 +564,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@VisibleForTesting
 	@SuppressWarnings("unchecked")
 	@Override
-	public int numStateEntries() {
+	public int numKeyValueStateEntries() {
 		int sum = 0;
-		for (StateSnapshotRestore state : registeredStates.values()) {
-			if (state instanceof StateTable) {
-				sum += ((StateTable<?, ?, ?>) state).size();
-			}
+		for (StateSnapshotRestore state : registeredKVStates.values()) {
+			sum += ((StateTable<?, ?, ?>) state).size();
 		}
 		return sum;
 	}
@@ -492,12 +576,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * Returns the total number of state entries across all keys for the given namespace.
 	 */
 	@VisibleForTesting
-	public int numStateEntries(Object namespace) {
+	public int numKeyValueStateEntries(Object namespace) {
 		int sum = 0;
-		for (StateSnapshotRestore state : registeredStates.values()) {
-			if (state instanceof StateTable) {
-				sum += ((StateTable<?, ?, ?>) state).sizeOfNamespace(namespace);
-			}
+		for (StateTable<?, ?, ?> state : registeredKVStates.values()) {
+			sum += state.sizeOfNamespace(namespace);
 		}
 		return sum;
 	}
@@ -574,7 +656,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		private final SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait;
 
-		public HeapSnapshotStrategy(
+		HeapSnapshotStrategy(
 			SnapshotStrategySynchronicityBehavior<K> snapshotStrategySynchronicityTrait) {
 			this.snapshotStrategySynchronicityTrait = snapshotStrategySynchronicityTrait;
 		}
@@ -592,28 +674,31 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			long syncStartTime = System.currentTimeMillis();
 
-			Preconditions.checkState(registeredStates.size() <= Short.MAX_VALUE,
-				"Too many KV-States: " + registeredStates.size() +
-					". Currently at most " + Short.MAX_VALUE + " states are supported");
-
-			List<StateMetaInfoSnapshot> metaInfoSnapshots =
-				new ArrayList<>(registeredStates.size());
+			int numStates = registeredKVStates.size() + registeredPQStates.size();
 
-			final Map<String, Integer> kVStateToId = new HashMap<>(registeredStates.size());
+			Preconditions.checkState(numStates <= Short.MAX_VALUE,
+				"Too many states: " + numStates +
+					". Currently at most " + Short.MAX_VALUE + " states are supported");
 
-			final Map<String, StateSnapshot> cowStateStableSnapshots =
-				new HashMap<>(registeredStates.size());
-
-			for (Map.Entry<String, StateSnapshotRestore> kvState : registeredStates.entrySet()) {
-				String stateName = kvState.getKey();
-				kVStateToId.put(stateName, kVStateToId.size());
-				StateSnapshotRestore state = kvState.getValue();
-				if (null != state) {
-					final StateSnapshot stateSnapshot = state.stateSnapshot();
-					metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
-					cowStateStableSnapshots.put(stateName, stateSnapshot);
-				}
-			}
+			final List<StateMetaInfoSnapshot> metaInfoSnapshots = new ArrayList<>(numStates);
+			final Map<StateUID, Integer> stateNamesToId =
+				new HashMap<>(numStates);
+			final Map<StateUID, StateSnapshot> cowStateStableSnapshots =
+				new HashMap<>(numStates);
+
+			processSnapshotMetaInfoForAllStates(
+				metaInfoSnapshots,
+				cowStateStableSnapshots,
+				stateNamesToId,
+				registeredKVStates,
+				StateMetaInfoSnapshot.BackendStateType.KEY_VALUE);
+
+			processSnapshotMetaInfoForAllStates(
+				metaInfoSnapshots,
+				cowStateStableSnapshots,
+				stateNamesToId,
+				registeredPQStates,
+				StateMetaInfoSnapshot.BackendStateType.PRIORITY_QUEUE);
 
 			final KeyedBackendSerializationProxy<K> serializationProxy =
 				new KeyedBackendSerializationProxy<>(
@@ -692,13 +777,14 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 							keyGroupRangeOffsets[keyGroupPos] = localStream.getPos();
 							outView.writeInt(keyGroupId);
 
-							for (Map.Entry<String, StateSnapshot> kvState : cowStateStableSnapshots.entrySet()) {
+							for (Map.Entry<StateUID, StateSnapshot> stateSnapshot :
+								cowStateStableSnapshots.entrySet()) {
 								StateSnapshot.StateKeyGroupWriter partitionedSnapshot =
-									kvState.getValue().getKeyGroupWriter();
+
+									stateSnapshot.getValue().getKeyGroupWriter();
 								try (OutputStream kgCompressionOut = keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
-									String stateName = kvState.getKey();
 									DataOutputViewStreamWrapper kgCompressionView = new DataOutputViewStreamWrapper(kgCompressionOut);
-									kgCompressionView.writeShort(kVStateToId.get(stateName));
+									kgCompressionView.writeShort(stateNamesToId.get(stateSnapshot.getKey()));
 									partitionedSnapshot.writeStateInKeyGroup(kgCompressionView, keyGroupId);
 								} // this will just close the outer compression stream
 							}
@@ -747,5 +833,80 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		public <N, V> StateTable<K, N, V> newStateTable(RegisteredKeyValueStateBackendMetaInfo<N, V> newMetaInfo) {
+			// The concrete StateTable implementation is decided by the configured synchronicity behavior.
 			return snapshotStrategySynchronicityTrait.newStateTable(newMetaInfo);
 		}
+
+		/**
+		 * Collects the meta info snapshot and the stable state snapshot of every given registered state and
+		 * assigns each snapshotted state a consecutive id.
+		 *
+		 * <p>Ids are only handed out to states that actually contribute a meta info snapshot. This keeps the id
+		 * written in front of each state's key-group data equal to the position of that state's meta info in
+		 * {@code metaInfoSnapshots}, which is how ids are re-derived on restore.
+		 *
+		 * @param metaInfoSnapshots output list receiving the meta info snapshot of each state, in id order.
+		 * @param cowStateStableSnapshots output map from state uid to its stable state snapshot.
+		 * @param stateNamesToId output map from state uid to the id written to the checkpoint stream.
+		 * @param registeredStates the registered states, keyed by state name, to process.
+		 * @param stateType the backend state type shared by all states in {@code registeredStates}.
+		 */
+		private void processSnapshotMetaInfoForAllStates(
+			List<StateMetaInfoSnapshot> metaInfoSnapshots,
+			Map<StateUID, StateSnapshot> cowStateStableSnapshots,
+			Map<StateUID, Integer> stateNamesToId,
+			Map<String, ? extends StateSnapshotRestore> registeredStates,
+			StateMetaInfoSnapshot.BackendStateType stateType) {
+
+			for (Map.Entry<String, ? extends StateSnapshotRestore> registeredState : registeredStates.entrySet()) {
+				final StateSnapshotRestore state = registeredState.getValue();
+				if (state == null) {
+					// Nothing is written for this state, so it must not consume an id either; otherwise the ids
+					// of all following states would no longer match their meta info position.
+					continue;
+				}
+				final StateUID stateUid = StateUID.of(registeredState.getKey(), stateType);
+				stateNamesToId.put(stateUid, stateNamesToId.size());
+				final StateSnapshot stateSnapshot = state.stateSnapshot();
+				metaInfoSnapshots.add(stateSnapshot.getMetaInfoSnapshot());
+				cowStateStableSnapshots.put(stateUid, stateSnapshot);
+			}
+		}
+	}
+
+	/**
+	 * Creates a state instance of type {@code IS} from the given state descriptor, state table and key
+	 * serializer.
+	 */
+	private interface StateFactory {
+		<K, N, SV, S extends State, IS extends S> IS createState(
+			StateDescriptor<S, SV> stateDesc,
+			StateTable<K, N, SV> stateTable,
+			TypeSerializer<K> keySerializer) throws Exception;
+	}
+
+	/**
+	 * Identifies a registered state in this backend by its name together with its backend state type, so that
+	 * a key/value state and a priority queue state registered under the same name are distinct.
+	 */
+	private static final class StateUID {
+
+		/** Name under which the state is registered. */
+		@Nonnull
+		private final String stateName;
+
+		/** Kind of backend state, e.g. key/value or priority queue. */
+		@Nonnull
+		private final StateMetaInfoSnapshot.BackendStateType stateType;
+
+		StateUID(@Nonnull String stateName, @Nonnull StateMetaInfoSnapshot.BackendStateType stateType) {
+			this.stateName = stateName;
+			this.stateType = stateType;
+		}
+
+		/** Static factory; equivalent to invoking the constructor. */
+		public static StateUID of(@Nonnull String stateName, @Nonnull StateMetaInfoSnapshot.BackendStateType stateType) {
+			return new StateUID(stateName, stateType);
+		}
+
+		@Nonnull
+		public String getStateName() {
+			return stateName;
+		}
+
+		@Nonnull
+		public StateMetaInfoSnapshot.BackendStateType getStateType() {
+			return stateType;
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (o == this) {
+				return true;
+			}
+			if (o == null || o.getClass() != getClass()) {
+				return false;
+			}
+			final StateUID other = (StateUID) o;
+			return stateType == other.stateType && Objects.equals(stateName, other.stateName);
+		}
+
+		@Override
+		public int hashCode() {
+			return Objects.hash(stateName, stateType);
+		}
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
index b0255d3..80d79ac 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSetFactory.java
@@ -59,7 +59,7 @@ public class HeapPriorityQueueSetFactory implements PriorityQueueSetFactory {
 		@Nonnull String stateName,
 		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
 
-		return new HeapPriorityQueueSet<T>(
+		return new HeapPriorityQueueSet<>(
 			PriorityComparator.forPriorityComparableObjects(),
 			KeyExtractorFunction.forKeyedObjects(),
 			minimumCapacity,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
index b2b2843..fc1e0db 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapPriorityQueueSnapshotRestoreWrapper.java
@@ -89,4 +89,26 @@ public class HeapPriorityQueueSnapshotRestoreWrapper<T extends HeapPriorityQueue
 	public HeapPriorityQueueSet<T> getPriorityQueue() {
+		// Returns the wrapped queue instance itself, not a copy.
 		return priorityQueue;
 	}
+
+	/** Returns the meta info (e.g. state name and serializer) of the wrapped priority queue state. */
+	@Nonnull
+	public RegisteredPriorityQueueStateBackendMetaInfo<T> getMetaInfo() {
+		return this.metaInfo;
+	}
+
+	/**
+	 * Returns a new wrapper whose meta info carries the given serializer.
+	 *
+	 * <p>This is a shallow copy: the returned wrapper shares the very same {@link HeapPriorityQueueSet},
+	 * key extractor function, local key-group range and total number of key groups with this instance.
+	 * Only the meta info is replaced (same state name, updated serializer).
+	 *
+	 * @param updatedSerializer the serializer to record in the returned wrapper's meta info.
+	 * @return a wrapper around the same priority queue, with updated meta info.
+	 */
+	public HeapPriorityQueueSnapshotRestoreWrapper<T> forUpdatedSerializer(
+		@Nonnull TypeSerializer<T> updatedSerializer) {
+
+		RegisteredPriorityQueueStateBackendMetaInfo<T> updatedMetaInfo =
+			new RegisteredPriorityQueueStateBackendMetaInfo<>(metaInfo.getName(), updatedSerializer);
+
+		// Intentionally reuses the existing queue: only the serializer recorded in the meta info changes.
+		return new HeapPriorityQueueSnapshotRestoreWrapper<>(
+			priorityQueue,
+			updatedMetaInfo,
+			keyExtractorFunction,
+			localKeyGroupRange,
+			totalKeyGroups);
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
index 510d277..935ebb6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/InternalPriorityQueueTestBase.java
@@ -347,7 +347,7 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 	/**
 	 * Payload for usage in the test.
 	 */
-	protected static class TestElement implements HeapPriorityQueueElement {
+	protected static class TestElement implements HeapPriorityQueueElement, Keyed<Long>, PriorityComparable<TestElement> {
 
 		private final long key;
 		private final long priority;
@@ -359,7 +359,12 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 			this.internalIndex = NOT_CONTAINED;
 		}
 
-		public long getKey() {
+		/** Compares elements by their {@code priority} value using {@link Long#compare(long, long)}. */
+		@Override
+		public int comparePriorityTo(@Nonnull TestElement other) {
+			return Long.compare(priority, other.priority);
+		}
+
+		/** Returns the (boxed) key of this element, as required by {@code Keyed<Long>}. */
+		@Override
+		public Long getKey() {
 			return key;
 		}
 
@@ -386,8 +391,8 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 				return false;
 			}
 			TestElement that = (TestElement) o;
-			return getKey() == that.getKey() &&
-				getPriority() == that.getPriority();
+			return key == that.key &&
+				priority == that.priority;
 		}
 
 		@Override
@@ -414,9 +419,11 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 	 */
 	protected static class TestElementSerializer extends TypeSerializer<TestElement> {
 
+		/** Default revision returned by {@link #getRevision()}. */
+		private static final int REVISION = 1;
+
 		public static final TestElementSerializer INSTANCE = new TestElementSerializer();
 
-		private TestElementSerializer() {
+		// Protected (instead of private) so that subclasses can be created, e.g. to override getRevision().
+		protected TestElementSerializer() {
 		}
 
 		@Override
@@ -489,14 +496,62 @@ public abstract class InternalPriorityQueueTestBase extends TestLogger {
 			return 4711;
 		}
 
+		/**
+		 * Returns the revision that is written into config snapshots and compared against on restore.
+		 * Returns {@code REVISION} unless overridden by a subclass.
+		 */
+		protected int getRevision() {
+			return TestElementSerializer.REVISION;
+		}
+
 		@Override
 		public TypeSerializerConfigSnapshot snapshotConfiguration() {
-			throw new UnsupportedOperationException();
+			// Persist the current revision so that ensureCompatibility can compare against it on restore.
+			return new Snapshot(getRevision());
 		}
 
 		@Override
 		public CompatibilityResult<TestElement> ensureCompatibility(TypeSerializerConfigSnapshot configSnapshot) {
-			throw new UnsupportedOperationException();
+			// Only our own Snapshot type with a revision that is not newer than ours counts as compatible;
+			// anything else requires migration.
+			if (!(configSnapshot instanceof Snapshot)) {
+				return CompatibilityResult.requiresMigration();
+			}
+			final int restoredRevision = ((Snapshot) configSnapshot).revision;
+			return restoredRevision <= getRevision()
+				? CompatibilityResult.compatible()
+				: CompatibilityResult.requiresMigration();
+		}
+
+		/**
+		 * Config snapshot of {@link TestElementSerializer} that records the serializer revision.
+		 */
+		public static class Snapshot extends TypeSerializerConfigSnapshot {
+
+			/** Serializer revision; set by the constructor or by {@link #read(DataInputView)}. */
+			private int revision;
+
+			// NOTE(review): presumably required for reflective instantiation when the snapshot is read back.
+			public Snapshot() {
+			}
+
+			public Snapshot(int revision) {
+				this.revision = revision;
+			}
+
+			public int getRevision() {
+				return revision;
+			}
+
+			@Override
+			public int getVersion() {
+				return 0;
+			}
+
+			@Override
+			public void write(DataOutputView out) throws IOException {
+				// Superclass data first; read() consumes in the same order.
+				super.write(out);
+				out.writeInt(revision);
+			}
+
+			@Override
+			public void read(DataInputView in) throws IOException {
+				super.read(in);
+				revision = in.readInt();
+			}
+
+			@Override
+			public boolean equals(Object obj) {
+				if (!(obj instanceof Snapshot)) {
+					return false;
+				}
+				return ((Snapshot) obj).revision == revision;
+			}
+
+			@Override
+			public int hashCode() {
+				return revision;
+			}
 		}
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 0ba4c33..215d7d3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.testutils.ArtificialCNFExceptionThrowingClassLoader;
 import org.apache.flink.util.FutureUtil;
+
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -154,6 +155,7 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 	@Test
 	public void testKeyedStateRestoreFailsIfSerializerDeserializationFails() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		KeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class, null);
@@ -161,7 +163,7 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 		HeapKeyedStateBackend<Integer> heapBackend = (HeapKeyedStateBackend<Integer>) backend;
 
-		assertEquals(0, heapBackend.numStateEntries());
+		assertEquals(0, heapBackend.numKeyValueStateEntries());
 
 		ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
@@ -170,11 +172,13 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 		state.update("hello");
 		state.update("ciao");
 
-		KeyedStateHandle snapshot = runSnapshot(((HeapKeyedStateBackend<Integer>) backend).snapshot(
-			682375462378L,
-			2,
-			streamFactory,
-			CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot = runSnapshot(
+			((HeapKeyedStateBackend<Integer>) backend).snapshot(
+				682375462378L,
+				2,
+				streamFactory,
+				CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		backend.dispose();
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index bfdc05d..059a706 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -50,7 +50,10 @@ import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.api.java.typeutils.runtime.PojoSerializer;
 import org.apache.flink.api.java.typeutils.runtime.kryo.JavaSerializer;
 import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.core.testutils.CheckedThread;
@@ -77,6 +80,7 @@ import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
 import org.apache.flink.testutils.ArtificialCNFExceptionThrowingClassLoader;
 import org.apache.flink.types.IntValue;
+import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.StateMigrationException;
 import org.apache.flink.util.TestLogger;
@@ -87,6 +91,7 @@ import com.esotericsoftware.kryo.Kryo;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
 import org.apache.commons.io.output.ByteArrayOutputStream;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -295,6 +300,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testBackendUsesRegisteredKryoDefaultSerializer() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
 		// cast because our test serializer is not typed to TestPojo
@@ -330,7 +336,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		try {
 			// backends that lazily serializes (such as memory state backend) will fail here
-			runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 		} catch (ExpectedKryoTestException e) {
 			numExceptions++;
 		} catch (Exception e) {
@@ -350,6 +358,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testBackendUsesRegisteredKryoDefaultSerializerUsingGetOrCreate() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
 		// cast because our test serializer is not typed to TestPojo
@@ -390,7 +399,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		try {
 			// backends that lazily serializes (such as memory state backend) will fail here
-			runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 		} catch (ExpectedKryoTestException e) {
 			numExceptions++;
 		} catch (Exception e) {
@@ -409,8 +420,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testBackendUsesRegisteredKryoSerializer() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
-
 		env.getExecutionConfig()
 				.registerTypeWithKryoSerializer(TestPojo.class, ExceptionThrowingTestSerializer.class);
 
@@ -444,7 +455,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		try {
 			// backends that lazily serializes (such as memory state backend) will fail here
-			runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 		} catch (ExpectedKryoTestException e) {
 			numExceptions++;
 		} catch (Exception e) {
@@ -464,6 +477,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testBackendUsesRegisteredKryoSerializerUsingGetOrCreate() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
 		env.getExecutionConfig().registerTypeWithKryoSerializer(TestPojo.class, ExceptionThrowingTestSerializer.class);
@@ -500,7 +514,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		try {
 			// backends that lazily serializes (such as memory state backend) will fail here
-			runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 		} catch (ExpectedKryoTestException e) {
 			numExceptions++;
 		} catch (Exception e) {
@@ -528,6 +544,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testKryoRegisteringRestoreResilienceWithRegisteredType() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
 		TypeInformation<TestPojo> pojoType = new GenericTypeInfo<>(TestPojo.class);
@@ -548,11 +565,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		backend.setCurrentKey(2);
 		state.update(new TestPojo("u2", 2));
 
-		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
+		KeyedStateHandle snapshot = runSnapshot(
+			backend.snapshot(
 				682375462378L,
 				2,
 				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+				CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		backend.dispose();
 
@@ -617,9 +636,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				682375462378L,
 				2,
 				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+				CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot.registerSharedStates(sharedStateRegistry);
 			backend.dispose();
 
 			// ========== restore snapshot - should use default serializer (ONLY SERIALIZATION) ==========
@@ -639,13 +658,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			// update to test state backends that eagerly serialize, such as RocksDB
 			state.update(new TestPojo("u1", 11));
 
-			KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
-				682375462378L,
-				2,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot2 = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot2.registerSharedStates(sharedStateRegistry);
 			snapshot.discardState();
 
 			backend.dispose();
@@ -715,13 +735,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			backend.setCurrentKey(2);
 			state.update(new TestPojo("u2", 2));
 
-			KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
-				682375462378L,
-				2,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot.registerSharedStates(sharedStateRegistry);
 			backend.dispose();
 
 			// ========== restore snapshot - should use specific serializer (ONLY SERIALIZATION) ==========
@@ -740,13 +761,13 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			// update to test state backends that eagerly serialize, such as RocksDB
 			state.update(new TestPojo("u1", 11));
 
-			KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
-				682375462378L,
-				2,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
-
-			snapshot2.registerSharedStates(sharedStateRegistry);
+			KeyedStateHandle snapshot2 = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 			snapshot.discardState();
 
@@ -783,6 +804,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testKryoRestoreResilienceWithDifferentRegistrationOrder() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		// register A first then B
 		env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
@@ -790,83 +812,91 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
-		TypeInformation<TestPojo> pojoType = new GenericTypeInfo<>(TestPojo.class);
+		try {
 
-		// make sure that we are in fact using the KryoSerializer
-		assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer);
+			TypeInformation<TestPojo> pojoType = new GenericTypeInfo<>(TestPojo.class);
 
-		ValueStateDescriptor<TestPojo> kvId = new ValueStateDescriptor<>("id", pojoType);
-		ValueState<TestPojo> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+			// make sure that we are in fact using the KryoSerializer
+			assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof KryoSerializer);
 
-		// access the internal state representation to retrieve the original Kryo registration ids;
-		// these will be later used to check that on restore, the new Kryo serializer has reconfigured itself to
-		// have identical mappings
-		InternalKvState internalKvState = (InternalKvState) state;
-		KryoSerializer<TestPojo> kryoSerializer = (KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
-		int mainPojoClassRegistrationId = kryoSerializer.getKryo().getRegistration(TestPojo.class).getId();
-		int nestedPojoClassARegistrationId = kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId();
-		int nestedPojoClassBRegistrationId = kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId();
+			ValueStateDescriptor<TestPojo> kvId = new ValueStateDescriptor<>("id", pojoType);
+			ValueState<TestPojo> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		// ============== create snapshot of current configuration ==============
+			// access the internal state representation to retrieve the original Kryo registration ids;
+			// these will be later used to check that on restore, the new Kryo serializer has reconfigured itself to
+			// have identical mappings
+			InternalKvState internalKvState = (InternalKvState) state;
+			KryoSerializer<TestPojo> kryoSerializer = (KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
+			int mainPojoClassRegistrationId = kryoSerializer.getKryo().getRegistration(TestPojo.class).getId();
+			int nestedPojoClassARegistrationId = kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId();
+			int nestedPojoClassBRegistrationId = kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId();
 
-		// make some more modifications
-		backend.setCurrentKey(1);
-		state.update(new TestPojo("u1", 1, new TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
+			// ============== create snapshot of current configuration ==============
 
-		backend.setCurrentKey(2);
-		state.update(new TestPojo("u2", 2, new TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
+			// make some more modifications
+			backend.setCurrentKey(1);
+			state.update(new TestPojo("u1", 1, new TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
 
-		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
-			682375462378L,
-			2,
-			streamFactory,
-			CheckpointOptions.forCheckpointWithDefaultLocation()));
+			backend.setCurrentKey(2);
+			state.update(new TestPojo("u2", 2, new TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
 
-		backend.dispose();
+			KeyedStateHandle snapshot = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-		// ========== restore snapshot, with a different registration order in the configuration ==========
+			backend.dispose();
 
-		env = new DummyEnvironment();
+			// ========== restore snapshot, with a different registration order in the configuration ==========
 
-		env.getExecutionConfig().registerKryoType(TestNestedPojoClassB.class); // this time register B first
-		env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
+			env = new DummyEnvironment();
 
-		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
+			env.getExecutionConfig().registerKryoType(TestNestedPojoClassB.class); // this time register B first
+			env.getExecutionConfig().registerKryoType(TestNestedPojoClassA.class);
 
-		snapshot.discardState();
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
 
-		// re-initialize to ensure that we create the KryoSerializer from scratch, otherwise
-		// initializeSerializerUnlessSet would not pick up our new config
-		kvId = new ValueStateDescriptor<>("id", pojoType);
-		state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+			// re-initialize to ensure that we create the KryoSerializer from scratch, otherwise
+			// initializeSerializerUnlessSet would not pick up our new config
+			kvId = new ValueStateDescriptor<>("id", pojoType);
+			state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		// verify that on restore, the serializer that the state handle uses has reconfigured itself to have
-		// identical Kryo registration ids compared to the previous execution
-		internalKvState = (InternalKvState) state;
-		kryoSerializer = (KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
-		assertEquals(mainPojoClassRegistrationId, kryoSerializer.getKryo().getRegistration(TestPojo.class).getId());
-		assertEquals(nestedPojoClassARegistrationId, kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId());
-		assertEquals(nestedPojoClassBRegistrationId, kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId());
+			// verify that on restore, the serializer that the state handle uses has reconfigured itself to have
+			// identical Kryo registration ids compared to the previous execution
+			internalKvState = (InternalKvState) state;
+			kryoSerializer = (KryoSerializer<TestPojo>) internalKvState.getValueSerializer();
+			assertEquals(mainPojoClassRegistrationId, kryoSerializer.getKryo().getRegistration(TestPojo.class).getId());
+			assertEquals(nestedPojoClassARegistrationId, kryoSerializer.getKryo().getRegistration(TestNestedPojoClassA.class).getId());
+			assertEquals(nestedPojoClassBRegistrationId, kryoSerializer.getKryo().getRegistration(TestNestedPojoClassB.class).getId());
 
-		backend.setCurrentKey(1);
+			backend.setCurrentKey(1);
 
-		// update to test state backends that eagerly serialize, such as RocksDB
-		state.update(new TestPojo("u1", 11, new TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
+			// update to test state backends that eagerly serialize, such as RocksDB
+			state.update(new TestPojo("u1", 11, new TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
 
-		// this tests backends that lazily serialize, such as memory state backend
-		runSnapshot(backend.snapshot(
-			682375462378L,
-			2,
-			streamFactory,
-			CheckpointOptions.forCheckpointWithDefaultLocation()));
+			// this tests backends that lazily serialize, such as memory state backend
+			runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-		backend.dispose();
+			snapshot.discardState();
+		} finally {
+			backend.dispose();
+		}
 	}
 
 	@Test
 	public void testPojoRestoreResilienceWithDifferentRegistrationOrder() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 		Environment env = new DummyEnvironment();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		// register A first then B
 		env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
@@ -874,60 +904,66 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 
-		TypeInformation<TestPojo> pojoType = TypeExtractor.getForClass(TestPojo.class);
+		try {
 
-		// make sure that we are in fact using the PojoSerializer
-		assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof PojoSerializer);
+			TypeInformation<TestPojo> pojoType = TypeExtractor.getForClass(TestPojo.class);
 
-		ValueStateDescriptor<TestPojo> kvId = new ValueStateDescriptor<>("id", pojoType);
-		ValueState<TestPojo> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+			// make sure that we are in fact using the PojoSerializer
+			assertTrue(pojoType.createSerializer(env.getExecutionConfig()) instanceof PojoSerializer);
 
-		// ============== create snapshot of current configuration ==============
+			ValueStateDescriptor<TestPojo> kvId = new ValueStateDescriptor<>("id", pojoType);
+			ValueState<TestPojo> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		// make some more modifications
-		backend.setCurrentKey(1);
-		state.update(new TestPojo("u1", 1, new TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
+			// ============== create snapshot of current configuration ==============
 
-		backend.setCurrentKey(2);
-		state.update(new TestPojo("u2", 2, new TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
+			// make some more modifications
+			backend.setCurrentKey(1);
+			state.update(new TestPojo("u1", 1, new TestNestedPojoClassA(1.0, 2), new TestNestedPojoClassB(2.3, "foo")));
 
-		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(
-			682375462378L,
-			2,
-			streamFactory,
-			CheckpointOptions.forCheckpointWithDefaultLocation()));
+			backend.setCurrentKey(2);
+			state.update(new TestPojo("u2", 2, new TestNestedPojoClassA(2.0, 5), new TestNestedPojoClassB(3.1, "bar")));
 
-		backend.dispose();
+			KeyedStateHandle snapshot = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-		// ========== restore snapshot, with a different registration order in the configuration ==========
+			backend.dispose();
 
-		env = new DummyEnvironment();
+			// ========== restore snapshot, with a different registration order in the configuration ==========
 
-		env.getExecutionConfig().registerPojoType(TestNestedPojoClassB.class); // this time register B first
-		env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
+			env = new DummyEnvironment();
 
-		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
+			env.getExecutionConfig().registerPojoType(TestNestedPojoClassB.class); // this time register B first
+			env.getExecutionConfig().registerPojoType(TestNestedPojoClassA.class);
 
-		snapshot.discardState();
+			backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
 
-		// re-initialize to ensure that we create the PojoSerializer from scratch, otherwise
-		// initializeSerializerUnlessSet would not pick up our new config
-		kvId = new ValueStateDescriptor<>("id", pojoType);
-		state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
+			// re-initialize to ensure that we create the PojoSerializer from scratch, otherwise
+			// initializeSerializerUnlessSet would not pick up our new config
+			kvId = new ValueStateDescriptor<>("id", pojoType);
+			state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
-		backend.setCurrentKey(1);
+			backend.setCurrentKey(1);
 
-		// update to test state backends that eagerly serialize, such as RocksDB
-		state.update(new TestPojo("u1", 11, new TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
+			// update to test state backends that eagerly serialize, such as RocksDB
+			state.update(new TestPojo("u1", 11, new TestNestedPojoClassA(22.1, 12), new TestNestedPojoClassB(1.23, "foobar")));
 
-		// this tests backends that lazily serialize, such as memory state backend
-		runSnapshot(backend.snapshot(
-			682375462378L,
-			2,
-			streamFactory,
-			CheckpointOptions.forCheckpointWithDefaultLocation()));
+			// this tests backends that lazily serialize, such as memory state backend
+			runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()), sharedStateRegistry);
 
-		backend.dispose();
+			snapshot.discardState();
+		} finally {
+			backend.dispose();
+		}
 	}
 
 	@Test
@@ -957,13 +993,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			assertTrue(internal.getValueSerializer() instanceof TestReconfigurableCustomTypeSerializer);
 			assertFalse(((TestReconfigurableCustomTypeSerializer) internal.getValueSerializer()).isReconfigured());
 
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(
-				682375462378L,
-				2,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot1.registerSharedStates(sharedStateRegistry);
 			backend.dispose();
 
 			// ========== restore snapshot, which should reconfigure the serializer, and then create a snapshot again ==========
@@ -995,13 +1032,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 			state.update(new TestCustomStateClass("new-test-message-2", "extra-message-2"));
 
-			KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
-				682375462379L,
-				3,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot2 = runSnapshot(
+				backend.snapshot(
+					682375462379L,
+					3,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot2.registerSharedStates(sharedStateRegistry);
 			snapshot1.discardState();
 			backend.dispose();
 
@@ -1055,13 +1093,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			backend.setCurrentKey(2);
 			state.update(new TestCustomStateClass("test-message-2", "this-should-be-ignored"));
 
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(
-				682375462378L,
-				2,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(
+					682375462378L,
+					2,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot1.registerSharedStates(sharedStateRegistry);
 			backend.dispose();
 
 			// ========== restore snapshot, using the new serializer (that has different classname) ==========
@@ -1093,13 +1132,14 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			backend.setCurrentKey(2);
 			state.update(new TestCustomStateClass("new-test-message-2", "extra-message-2"));
 
-			KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(
-				682375462379L,
-				3,
-				streamFactory,
-				CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot2 = runSnapshot(
+				backend.snapshot(
+					682375462379L,
+					3,
+					streamFactory,
+					CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
-			snapshot2.registerSharedStates(sharedStateRegistry);
 			snapshot1.discardState();
 		} finally {
 			backend.dispose();
@@ -1107,9 +1147,105 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	}
 
 	@Test
+	public void testPriorityQueueSerializerUpdates() throws Exception {
+		// Verifies that priority queue state accepts a compatible serializer change and rejects an incompatible one.
+		final String stateName = "test";
+		final CheckpointStreamFactory streamFactory = createStreamFactory();
+		final SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		AbstractKeyedStateBackend<Integer> keyedBackend = createKeyedBackend(IntSerializer.INSTANCE);
+
+		try {
+			TypeSerializer<InternalPriorityQueueTestBase.TestElement> serializer =
+				InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE;
+
+			KeyGroupedInternalPriorityQueue<InternalPriorityQueueTestBase.TestElement> priorityQueue =
+				keyedBackend.create(stateName, serializer);
+
+			priorityQueue.add(new InternalPriorityQueueTestBase.TestElement(42L, 0L));
+
+			RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot =
+				keyedBackend.snapshot(0L, 0L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+			KeyedStateHandle keyedStateHandle = runSnapshot(snapshot, sharedStateRegistry);
+
+			keyedBackend.dispose();
+
+			// test restore with a modified but compatible serializer ---------------------------
+
+			keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+
+			serializer = new ModifiedTestElementSerializer();
+
+			priorityQueue = keyedBackend.create(stateName, serializer);
+
+			final InternalPriorityQueueTestBase.TestElement checkElement =
+				new InternalPriorityQueueTestBase.TestElement(4711L, 1L);
+			priorityQueue.add(checkElement);
+
+			snapshot = keyedBackend.snapshot(1L, 1L, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
+
+			keyedStateHandle = runSnapshot(snapshot, sharedStateRegistry);
+
+			keyedBackend.dispose();
+
+			// test that the modified serializer was actually used ---------------------------
+
+			keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+			priorityQueue = keyedBackend.create(stateName, serializer);
+			// two elements were snapshotted and checkElement is expected second, so the first poll must yield the other one
+			Assert.assertNotNull("Element from the first snapshot was lost on restore.", priorityQueue.poll());
+
+			ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos();
+			DataOutputViewStreamWrapper outWrapper = new DataOutputViewStreamWrapper(out);
+			serializer.serialize(checkElement, outWrapper);
+			InternalPriorityQueueTestBase.TestElement expected =
+				serializer.deserialize(new DataInputViewStreamWrapper(new ByteArrayInputStreamWithPos(out.toByteArray())));
+
+			Assert.assertEquals(
+				expected,
+				priorityQueue.poll());
+			Assert.assertTrue(priorityQueue.isEmpty());
+
+			keyedBackend.dispose();
+
+			// test that incompatible serializer is rejected ---------------------------
+
+			serializer = InternalPriorityQueueTestBase.TestElementSerializer.INSTANCE;
+			keyedBackend = restoreKeyedBackend(IntSerializer.INSTANCE, keyedStateHandle);
+
+			try {
+				// this is expected to fail, because the old and new serializer should be incompatible through
+				// different revision numbers.
+				keyedBackend.create(stateName, serializer);
+				Assert.fail("Expected exception from incompatible serializer.");
+			} catch (Exception e) {
+				Assert.assertTrue("Exception was not caused by state migration: " + e,
+					ExceptionUtils.findThrowable(e, StateMigrationException.class).isPresent());
+			}
+		} finally {
+			keyedBackend.dispose();
+		}
+	}
+
+	public static class ModifiedTestElementSerializer extends InternalPriorityQueueTestBase.TestElementSerializer {
+		// Test variant of TestElementSerializer: writes key and priority each shifted by +1 and reports revision + 1.
+		@Override
+		public void serialize(InternalPriorityQueueTestBase.TestElement record, DataOutputView target) throws IOException {
+			super.serialize(new InternalPriorityQueueTestBase.TestElement(1 + record.getKey(), 1 + record.getPriority()), target);
+		}
+
+		@Override
+		protected int getRevision() {
+			return 1 + super.getRevision();
+		}
+	}
+
+	@Test
 	@SuppressWarnings("unchecked")
 	public void testValueState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class);
@@ -1138,7 +1274,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1149,7 +1287,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("u3");
 
 		// draw another snapshot
-		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot2 = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1320,7 +1460,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testMultipleValueStates() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
-
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				1,
@@ -1350,7 +1490,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		// draw a snapshot
 		KeyedStateHandle snapshot1 =
-			runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 		backend.dispose();
 		backend = restoreKeyedBackend(
@@ -1394,6 +1536,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		}
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ValueStateDescriptor<Long> kvId = new ValueStateDescriptor<>("id", LongSerializer.INSTANCE, 42L);
@@ -1422,7 +1565,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals(42L, (long) state.value());
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		backend.dispose();
 		backend = restoreKeyedBackend(IntSerializer.INSTANCE, snapshot1);
@@ -1438,6 +1583,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked,rawtypes")
 	public void testListState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
@@ -1470,7 +1616,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", joiner.join(getSerializedList(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer)));
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1483,7 +1631,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot2 = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -1756,7 +1906,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.clear();
 
 			// make sure all lists / maps are cleared
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		} finally {
 			keyedBackend.close();
 			keyedBackend.dispose();
@@ -1870,7 +2020,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.setCurrentNamespace(namespace1);
 			state.clear();
 
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -1882,6 +2032,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testReducingState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		ReducingStateDescriptor<String> kvId = new ReducingStateDescriptor<>("id", new AppendingReduce(), String.class);
@@ -1910,7 +2061,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -1921,7 +2074,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add("u3");
 
 		// draw another snapshot
-		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot2 = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -2019,7 +2174,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.clear();
 
 			// make sure all lists / maps are cleared
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2137,7 +2292,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.setCurrentNamespace(namespace1);
 			state.clear();
 
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2192,7 +2347,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.clear();
 
 			// make sure all lists / maps are cleared
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2310,7 +2465,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.setCurrentNamespace(namespace1);
 			state.clear();
 
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2365,7 +2520,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.clear();
 
 			// make sure all lists / maps are cleared
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2483,7 +2638,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.setCurrentNamespace(namespace1);
 			state.clear();
 
-			assertThat("State backend is not empty.", keyedBackend.numStateEntries(), is(0));
+			assertThat("State backend is not empty.", keyedBackend.numKeyValueStateEntries(), is(0));
 		}
 		finally {
 			keyedBackend.close();
@@ -2495,6 +2650,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked,rawtypes")
 	public void testFoldingState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		FoldingStateDescriptor<Integer, String> kvId = new FoldingStateDescriptor<>("id",
@@ -2526,7 +2682,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		assertEquals("Fold-Initial:,1", getSerializedValue(kvState, 1, keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, valueSerializer));
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// make some more modifications
 		backend.setCurrentKey(1);
@@ -2538,7 +2696,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.add(103);
 
 		// draw another snapshot
-		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot2 = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// validate the original state
 		backend.setCurrentKey(1);
@@ -2594,6 +2754,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked,rawtypes")
 	public void testMapState() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<String> backend = createKeyedBackend(StringSerializer.INSTANCE);
 
 		MapStateDescriptor<Integer, String> kvId = new MapStateDescriptor<>("id", Integer.class, String.class);
@@ -2633,7 +2794,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			getSerializedMap(kvState, "11", keySerializer, VoidNamespace.INSTANCE, namespaceSerializer, userKeySerializer, userValueSerializer));
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// make some more modifications
 		backend.setCurrentKey("1");
@@ -2645,7 +2808,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.putAll(new HashMap<Integer, String>() {{ put(1031, "1031"); put(1032, "1032"); }});
 
 		// draw another snapshot
-		KeyedStateHandle snapshot2 = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot2 = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		// validate the original state
 		backend.setCurrentKey("1");
@@ -2920,6 +3085,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		final int MAX_PARALLELISM = 10;
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		final AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(
 				IntSerializer.INSTANCE,
 				MAX_PARALLELISM,
@@ -2951,7 +3117,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("ShouldBeInSecondHalf");
 
 
-		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(0, 0, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot = runSnapshot(
+			backend.snapshot(0, 0, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		List<KeyedStateHandle> firstHalfKeyGroupStates = StateAssignmentOperation.getKeyedStateHandles(
 				Collections.singletonList(snapshot),
@@ -3004,6 +3172,8 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	public void testRestoreWithWrongKeySerializer() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
 
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
 		// use an IntSerializer at first
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
@@ -3018,7 +3188,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("2");
 
 		// draw a snapshot
-		KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot1 = runSnapshot(
+			backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		backend.dispose();
 
@@ -3036,6 +3208,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testValueStateRestoreWithWrongSerializers() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		try {
@@ -3049,7 +3222,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.update("2");
 
 			// draw a snapshot
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -3080,6 +3255,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testListStateRestoreWithWrongSerializers() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		try {
@@ -3092,7 +3268,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -3123,6 +3301,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testReducingStateRestoreWithWrongSerializers() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		try {
@@ -3137,7 +3316,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.add("2");
 
 			// draw a snapshot
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -3168,6 +3349,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	@SuppressWarnings("unchecked")
 	public void testMapStateRestoreWithWrongSerializers() throws Exception {
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 		try {
@@ -3180,7 +3362,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 			state.put("2", "Second");
 
 			// draw a snapshot
-			KeyedStateHandle snapshot1 = runSnapshot(backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+			KeyedStateHandle snapshot1 = runSnapshot(
+				backend.snapshot(682375462378L, 2, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+				sharedStateRegistry);
 
 			backend.dispose();
 			// restore the first snapshot and validate it
@@ -3421,6 +3605,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		KvStateRegistry registry = env.getKvStateRegistry();
 
 		CheckpointStreamFactory streamFactory = createStreamFactory();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE, env);
 		KeyGroupRange expectedKeyGroupRange = backend.getKeyGroupRange();
 
@@ -3439,7 +3624,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 				eq(env.getJobID()), eq(env.getJobVertexId()), eq(expectedKeyGroupRange), eq("banana"), any(KvStateID.class));
 
 
-		KeyedStateHandle snapshot = runSnapshot(backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+		KeyedStateHandle snapshot = runSnapshot(
+			backend.snapshot(682375462379L, 4, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+			sharedStateRegistry);
 
 		backend.dispose();
 
@@ -3465,13 +3652,16 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		try {
 			CheckpointStreamFactory streamFactory = createStreamFactory();
+			SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 			AbstractKeyedStateBackend<Integer> backend = createKeyedBackend(IntSerializer.INSTANCE);
 
 			ListStateDescriptor<String> kvId = new ListStateDescriptor<>("id", String.class);
 
 			// draw a snapshot
 			KeyedStateHandle snapshot =
-				runSnapshot(backend.snapshot(682375462379L, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()));
+				runSnapshot(
+					backend.snapshot(682375462379L, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation()),
+					sharedStateRegistry);
 			assertNull(snapshot);
 			backend.dispose();
 
@@ -3491,7 +3681,7 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 
 		ValueStateDescriptor<String> kvId = new ValueStateDescriptor<>("id", String.class);
 
-		assertEquals(0, backend.numStateEntries());
+		assertEquals(0, backend.numKeyValueStateEntries());
 
 		ValueState<String> state = backend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, kvId);
 
@@ -3499,22 +3689,22 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		state.update("hello");
 		state.update("ciao");
 
-		assertEquals(1, backend.numStateEntries());
+		assertEquals(1, backend.numKeyValueStateEntries());
 
 		backend.setCurrentKey(42);
 		state.update("foo");
 
-		assertEquals(2, backend.numStateEntries());
+		assertEquals(2, backend.numKeyValueStateEntries());
 
 		backend.setCurrentKey(0);
 		state.clear();
 
-		assertEquals(1, backend.numStateEntries());
+		assertEquals(1, backend.numKeyValueStateEntries());
 
 		backend.setCurrentKey(42);
 		state.clear();
 
-		assertEquals(0, backend.numStateEntries());
+		assertEquals(0, backend.numKeyValueStateEntries());
 
 		backend.dispose();
 	}
@@ -4048,14 +4238,19 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	}
 
 	protected KeyedStateHandle runSnapshot(
-		RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunnableFuture) throws Exception {
+		RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshotRunnableFuture,
+		SharedStateRegistry sharedStateRegistry) throws Exception {
 
 		if (!snapshotRunnableFuture.isDone()) {
 			snapshotRunnableFuture.run();
 		}
 
 		SnapshotResult<KeyedStateHandle> snapshotResult = snapshotRunnableFuture.get();
-		return snapshotResult.getJobManagerOwnedSnapshot();
+		KeyedStateHandle jobManagerOwnedSnapshot = snapshotResult.getJobManagerOwnedSnapshot();
+		if (jobManagerOwnedSnapshot != null) { // a backend without any state yields no handle, so there is nothing to register
+			jobManagerOwnedSnapshot.registerSharedStates(sharedStateRegistry);
+		}
+		return jobManagerOwnedSnapshot;
 	}
 
 	public static class TestPojo implements Serializable {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
index 7b8d69f..1ca3e80 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackendSnapshotMigrationTest.java
@@ -235,7 +235,7 @@ public class HeapKeyedStateBackendSnapshotMigrationTest extends HeapStateBackend
 
 			InternalListState<String, Integer, Long> state = keyedBackend.createInternalState(IntSerializer.INSTANCE, stateDescr);
 
-			assertEquals(7, keyedBackend.numStateEntries());
+			assertEquals(7, keyedBackend.numKeyValueStateEntries());
 
 			keyedBackend.setCurrentKey("abc");
 			state.setCurrentNamespace(namespace1);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 9e9328b..805ae1c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -109,7 +109,7 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	public int numStateEntries() {
+	public int numKeyValueStateEntries() {
 		int count = 0;
 		for (String state : stateValues.keySet()) {
 			for (K key : stateValues.get(state).keySet()) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 4af5a27..7ead620 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
 import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
@@ -1319,7 +1320,6 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		RegisteredKeyValueStateBackendMetaInfo<N, S> newMetaInfo;
 		if (stateInfo != null) {
 
-			@SuppressWarnings("unchecked")
 			StateMetaInfoSnapshot restoredMetaInfoSnapshot = restoredKvStateMetaInfos.get(stateDesc.getName());
 
 			Preconditions.checkState(
@@ -1398,7 +1398,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	@VisibleForTesting
 	@SuppressWarnings("unchecked")
 	@Override
-	public int numStateEntries() {
+	public int numKeyValueStateEntries() {
 		int count = 0;
 
 		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> column : kvStateInformation.values()) {
@@ -2668,10 +2668,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
 		create(@Nonnull String stateName, @Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
 
-			final Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> entry =
+			final Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfoTuple =
 				tryRegisterPriorityQueueMetaInfo(stateName, byteOrderedElementSerializer);
 
-			final ColumnFamilyHandle columnFamilyHandle = entry.f0;
+			final ColumnFamilyHandle columnFamilyHandle = metaInfoTuple.f0;
 
 			return new KeyGroupPartitionedPriorityQueue<>(
 				KeyExtractorFunction.forKeyedObjects(),
@@ -2708,20 +2708,51 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		@Nonnull String stateName,
 		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
 
-		Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> entry =
+		Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfoTuple =
 			kvStateInformation.get(stateName);
 
-		if (entry == null) {
+		if (metaInfoTuple == null) {
+			final ColumnFamilyHandle columnFamilyHandle = createColumnFamily(stateName);
+
 			RegisteredPriorityQueueStateBackendMetaInfo<T> metaInfo =
 				new RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, byteOrderedElementSerializer);
 
-			final ColumnFamilyHandle columnFamilyHandle = createColumnFamily(stateName);
+			metaInfoTuple = new Tuple2<>(columnFamilyHandle, metaInfo);
+			kvStateInformation.put(stateName, metaInfoTuple);
+		} else {
+			// TODO we implement the simple way of supporting the current functionality, mimicking keyed state
+			// because this should be reworked in FLINK-9376 and then we should have a common algorithm over
+			// StateMetaInfoSnapshot that avoids this code duplication.
+			StateMetaInfoSnapshot restoredMetaInfoSnapshot = restoredKvStateMetaInfos.get(stateName);
 
-			entry = new Tuple2<>(columnFamilyHandle, metaInfo);
-			kvStateInformation.put(stateName, entry);
+			Preconditions.checkState(
+				restoredMetaInfoSnapshot != null,
+				"Requested to check compatibility of a restored RegisteredPriorityQueueStateBackendMetaInfo," +
+					" but its corresponding restored snapshot cannot be found.");
+			// the restored element serializer and its config snapshot are looked up under VALUE_SERIALIZER
+			StateMetaInfoSnapshot.CommonSerializerKeys serializerKey =
+				StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER;
+			// NOTE(review): assumes an existing entry always stems from a restore; a queue registered twice in one run would fail the check above - confirm
+			TypeSerializer<?> metaInfoTypeSerializer = restoredMetaInfoSnapshot.getTypeSerializer(serializerKey);
+			// the very same serializer instance needs no check; any other instance is resolved against the restored config snapshot
+			if (metaInfoTypeSerializer != byteOrderedElementSerializer) {
+				CompatibilityResult<T> compatibilityResult = CompatibilityUtil.resolveCompatibilityResult(
+					metaInfoTypeSerializer,
+					null,
+					restoredMetaInfoSnapshot.getTypeSerializerConfigSnapshot(serializerKey),
+					byteOrderedElementSerializer);
+				// priority queue state cannot be migrated, so a serializer that requires migration is a hard failure
+				if (compatibilityResult.isRequiresMigration()) {
+					throw new FlinkRuntimeException(StateMigrationException.notSupported());
+				}
+
+				// update meta info with new serializer
+				metaInfoTuple.f1 =
+					new RegisteredPriorityQueueStateBackendMetaInfo<>(stateName, byteOrderedElementSerializer);
+			}
 		}
 
-		return entry;
+		return metaInfoTuple;
 	}
 
 	@Override
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
index 6b254ce..0ea0d3f 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackendTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.execution.Environment;
@@ -124,6 +125,11 @@ public class RocksDBStateBackendTest extends StateBackendTestBase<RocksDBStateBa
 		dbPath = tempFolder.newFolder().getAbsolutePath();
 		String checkpointPath = tempFolder.newFolder().toURI().toString();
 		RocksDBStateBackend backend = new RocksDBStateBackend(new FsStateBackend(checkpointPath), enableIncrementalCheckpointing);
+		Configuration configuration = new Configuration();
+		configuration.setString(
+			RocksDBOptions.TIMER_SERVICE_FACTORY,
+			RocksDBStateBackend.PriorityQueueStateType.ROCKSDB.toString());
+		backend = backend.configure(configuration);
 		backend.setDbStoragePath(dbPath);
 		return backend;
 	}
diff --git a/flink-streaming-java/pom.xml b/flink-streaming-java/pom.xml
index 02da827..e64ed48 100644
--- a/flink-streaming-java/pom.xml
+++ b/flink-streaming-java/pom.xml
@@ -42,6 +42,8 @@ under the License.
 			<groupId>org.apache.flink</groupId>
 			<artifactId>flink-core</artifactId>
 			<version>${project.version}</version>
+			<scope>test</scope>
+			<type>test-jar</type>
 		</dependency>
 
 		<dependency>
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
index b54a1a9..ff48c3f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManager.java
@@ -46,12 +46,13 @@ import java.util.Map;
 @Internal
 public class InternalTimeServiceManager<K> {
 
-	//TODO guard these constants with a test
-	private static final String TIMER_STATE_PREFIX = "_timer_state";
-	private static final String PROCESSING_TIMER_PREFIX = TIMER_STATE_PREFIX + "/processing_";
-	private static final String EVENT_TIMER_PREFIX = TIMER_STATE_PREFIX + "/event_";
+	@VisibleForTesting
+	static final String TIMER_STATE_PREFIX = "_timer_state";
+	@VisibleForTesting
+	static final String PROCESSING_TIMER_PREFIX = TIMER_STATE_PREFIX + "/processing_";
+	@VisibleForTesting
+	static final String EVENT_TIMER_PREFIX = TIMER_STATE_PREFIX + "/event_";
 
-	private final int totalKeyGroups;
 	private final KeyGroupRange localKeyGroupRange;
 	private final KeyContext keyContext;
 
@@ -63,14 +64,11 @@ public class InternalTimeServiceManager<K> {
 	private final boolean useLegacySynchronousSnapshots;
 
 	InternalTimeServiceManager(
-		int totalKeyGroups,
 		KeyGroupRange localKeyGroupRange,
 		KeyContext keyContext,
 		PriorityQueueSetFactory priorityQueueSetFactory,
 		ProcessingTimeService processingTimeService, boolean useLegacySynchronousSnapshots) {
 
-		Preconditions.checkArgument(totalKeyGroups > 0);
-		this.totalKeyGroups = totalKeyGroups;
 		this.localKeyGroupRange = Preconditions.checkNotNull(localKeyGroupRange);
 		this.priorityQueueSetFactory = Preconditions.checkNotNull(priorityQueueSetFactory);
 		this.keyContext = Preconditions.checkNotNull(keyContext);
@@ -155,10 +153,6 @@ public class InternalTimeServiceManager<K> {
 		serializationProxy.read(stream);
 	}
 
-	public boolean isUseLegacySynchronousSnapshots() {
-		return useLegacySynchronousSnapshots;
-	}
-
 	////////////////////			Methods used ONLY IN TESTS				////////////////////
 
 	@VisibleForTesting
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
index a6bee4c..64af993 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/StreamTaskStateInitializerImpl.java
@@ -205,7 +205,6 @@ public class StreamTaskStateInitializerImpl implements StreamTaskStateInitialize
 		final KeyGroupRange keyGroupRange = keyedStatedBackend.getKeyGroupRange();
 
 		final InternalTimeServiceManager<K> timeServiceManager = new InternalTimeServiceManager<>(
-			keyedStatedBackend.getNumberOfKeyGroups(),
 			keyGroupRange,
 			keyContext,
 			keyedStatedBackend,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
index 73f42ef..a83cc3a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/TimerSerializer.java
@@ -19,9 +19,12 @@
 package org.apache.flink.streaming.api.operators;
 
 import org.apache.flink.api.common.typeutils.CompatibilityResult;
+import org.apache.flink.api.common.typeutils.CompatibilityUtil;
 import org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
+import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.util.MathUtils;
@@ -29,6 +32,7 @@ import org.apache.flink.util.MathUtils;
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Objects;
 
 /**
@@ -42,6 +46,9 @@ public class TimerSerializer<K, N> extends TypeSerializer<TimerHeapInternalTimer
 
 	private static final long serialVersionUID = 1L;
 
+	private static final int KEY_SERIALIZER_SNAPSHOT_INDEX = 0;
+	private static final int NAMESPACE_SERIALIZER_SNAPSHOT_INDEX = 1;
+
 	/** Serializer for the key. */
 	@Nonnull
 	private final TypeSerializer<K> keySerializer;
@@ -208,8 +215,35 @@ public class TimerSerializer<K, N> extends TypeSerializer<TimerHeapInternalTimer
 	@Override
 	public CompatibilityResult<TimerHeapInternalTimer<K, N>> ensureCompatibility(
 		TypeSerializerConfigSnapshot configSnapshot) {
-		//TODO this is just a mock (assuming no serializer updates) for now and needs a proper implementation! change this before release.
-		return CompatibilityResult.compatible();
+		// compatible only if neither the nested key serializer nor the nested namespace serializer requires migration
+		if (configSnapshot instanceof TimerSerializerConfigSnapshot) {
+			List<Tuple2<TypeSerializer<?>, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs =
+				((TimerSerializerConfigSnapshot<?, ?>) configSnapshot).getNestedSerializersAndConfigs();
+			// a timer config snapshot nests exactly the key and the namespace serializer, stored at their *_SNAPSHOT_INDEX positions
+			if (previousSerializersAndConfigs.size() == 2) {
+				Tuple2<TypeSerializer<?>, TypeSerializerConfigSnapshot> keySerializerAndSnapshot =
+					previousSerializersAndConfigs.get(KEY_SERIALIZER_SNAPSHOT_INDEX);
+				Tuple2<TypeSerializer<?>, TypeSerializerConfigSnapshot> namespaceSerializerAndSnapshot =
+					previousSerializersAndConfigs.get(NAMESPACE_SERIALIZER_SNAPSHOT_INDEX);
+				CompatibilityResult<K> keyCompatibilityResult = CompatibilityUtil.resolveCompatibilityResult(
+					keySerializerAndSnapshot.f0,
+					UnloadableDummyTypeSerializer.class,
+					keySerializerAndSnapshot.f1,
+					keySerializer);
+
+				CompatibilityResult<N> namespaceCompatibilityResult = CompatibilityUtil.resolveCompatibilityResult(
+					namespaceSerializerAndSnapshot.f0,
+					UnloadableDummyTypeSerializer.class,
+					namespaceSerializerAndSnapshot.f1,
+					namespaceSerializer);
+
+				if (!keyCompatibilityResult.isRequiresMigration()
+					&& !namespaceCompatibilityResult.isRequiresMigration()) {
+					return CompatibilityResult.compatible();
+				}
+			}
+		}
+		return CompatibilityResult.requiresMigration();
 	}
 
 	@Nonnull
@@ -230,16 +264,29 @@ public class TimerSerializer<K, N> extends TypeSerializer<TimerHeapInternalTimer
 	 */
 	public static class TimerSerializerConfigSnapshot<K, N> extends CompositeTypeSerializerConfigSnapshot {
 
+		private static final int VERSION = 1;
+
 		public TimerSerializerConfigSnapshot() {
 		}
 
-		public TimerSerializerConfigSnapshot(TypeSerializer<K> keySerializer, TypeSerializer<N> namespaceSerializer) {
-			super(keySerializer, namespaceSerializer);
+		public TimerSerializerConfigSnapshot(
+			@Nonnull TypeSerializer<K> keySerializer,
+			@Nonnull TypeSerializer<N> namespaceSerializer) {
+			super(init(keySerializer, namespaceSerializer));
+		}
+		/** Builds the nested serializer array with each serializer at its KEY_/NAMESPACE_SERIALIZER_SNAPSHOT_INDEX position. */
+		private static TypeSerializer<?>[] init(
+			@Nonnull TypeSerializer<?> keySerializer,
+			@Nonnull TypeSerializer<?> namespaceSerializer) {
+			TypeSerializer<?>[] timerSerializers = new TypeSerializer[2];
+			timerSerializers[KEY_SERIALIZER_SNAPSHOT_INDEX] = keySerializer;
+			timerSerializers[NAMESPACE_SERIALIZER_SNAPSHOT_INDEX] = namespaceSerializer;
+			return timerSerializers;
 		}
 
 		@Override
 		public int getVersion() {
-			return 0;
+			return VERSION;
 		}
 	}
 }
diff --git a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
similarity index 51%
copy from flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
copy to flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
index 00e0e73..905e8d7 100644
--- a/flink-core/src/main/java/org/apache/flink/util/StateMigrationException.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/InternalTimeServiceManagerTest.java
@@ -16,23 +16,26 @@
  * limitations under the License.
  */
 
-package org.apache.flink.util;
+package org.apache.flink.streaming.api.operators;
 
-/**
- * Base class for state migration related exceptions.
- */
-public class StateMigrationException extends FlinkException {
-	private static final long serialVersionUID = 8268516412747670839L;
+import org.apache.flink.util.TestLogger;
 
-	public StateMigrationException(String message) {
-		super(message);
-	}
+import org.junit.Assert;
+import org.junit.Test;
 
-	public StateMigrationException(Throwable cause) {
-		super(cause);
-	}
+/**
+ * Tests for {@link InternalTimeServiceManager}.
+ */
+public class InternalTimeServiceManagerTest extends TestLogger {
 
-	public StateMigrationException(String message, Throwable cause) {
-		super(message, cause);
+	/**
+	 * Pins the timer state name prefixes to their current values, because changing them can harm backwards compatibility.
+	 */
+	@Test
+	public void fixConstants() {
+		String expectedTimerStatePrefix = "_timer_state";
+		Assert.assertEquals(expectedTimerStatePrefix, InternalTimeServiceManager.TIMER_STATE_PREFIX);
+		Assert.assertEquals(expectedTimerStatePrefix + "/processing_", InternalTimeServiceManager.PROCESSING_TIMER_PREFIX);
+		Assert.assertEquals(expectedTimerStatePrefix + "/event_", InternalTimeServiceManager.EVENT_TIMER_PREFIX);
 	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java
new file mode 100644
index 0000000..9fe4ffc
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/TimerSerializerTest.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+
+/**
+ * Test for {@link TimerSerializer}.
+ */
+public class TimerSerializerTest extends SerializerTestBase<TimerHeapInternalTimer<Long, TimeWindow>> {
+
+	private static final TypeSerializer<Long> KEY_SERIALIZER = LongSerializer.INSTANCE;
+	private static final TypeSerializer<TimeWindow> NAMESPACE_SERIALIZER = new TimeWindow.Serializer();
+
+	@Override
+	protected TypeSerializer<TimerHeapInternalTimer<Long, TimeWindow>> createSerializer() {
+		return new TimerSerializer<>(KEY_SERIALIZER, NAMESPACE_SERIALIZER);
+	}
+	// expected length: the timestamp plus the declared lengths of the key and namespace serializers
+	@Override
+	protected int getLength() {
+		return Long.BYTES + KEY_SERIALIZER.getLength() + NAMESPACE_SERIALIZER.getLength();
+	}
+	// TimerHeapInternalTimer is generic, so its class literal is cast through Class<?> to the parameterized type
+	@SuppressWarnings("unchecked")
+	@Override
+	protected Class<TimerHeapInternalTimer<Long, TimeWindow>> getTypeClass() {
+		return (Class<TimerHeapInternalTimer<Long, TimeWindow>>) (Class<?>) TimerHeapInternalTimer.class;
+	}
+	// one ordinary timer plus zero, +/-1 and Long.MIN_VALUE/MAX_VALUE in every field
+	@SuppressWarnings("unchecked")
+	@Override
+	protected TimerHeapInternalTimer<Long, TimeWindow>[] getTestData() {
+		return (TimerHeapInternalTimer<Long, TimeWindow>[]) new TimerHeapInternalTimer[]{
+			new TimerHeapInternalTimer<>(42L, 4711L, new TimeWindow(1000L, 2000L)),
+			new TimerHeapInternalTimer<>(0L, 0L, new TimeWindow(0L, 0L)),
+			new TimerHeapInternalTimer<>(1L, -1L, new TimeWindow(1L, -1L)),
+			new TimerHeapInternalTimer<>(-1L, 1L, new TimeWindow(-1L, 1L)),
+			new TimerHeapInternalTimer<>(Long.MAX_VALUE, Long.MIN_VALUE, new TimeWindow(Long.MAX_VALUE, Long.MIN_VALUE)),
+			new TimerHeapInternalTimer<>(Long.MIN_VALUE, Long.MAX_VALUE, new TimeWindow(Long.MIN_VALUE, Long.MAX_VALUE))
+		};
+	}
+}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
index 1536956..bc5bb1b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/TriggerTestHarness.java
@@ -115,11 +115,11 @@ public class TriggerTestHarness<T, W extends Window> {
 	}
 
 	public int numStateEntries() {
-		return stateBackend.numStateEntries();
+		return stateBackend.numKeyValueStateEntries(); // harness keeps its old name; only the backend method was renamed
 	}
 
 	public int numStateEntries(W window) {
-		return stateBackend.numStateEntries(window);
+		return stateBackend.numKeyValueStateEntries(window); // harness keeps its old name; only the backend method was renamed
 	}
 
 	/**
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index 2035c46..caf846f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -72,7 +72,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
 		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries();
+			return ((HeapKeyedStateBackend) keyedStateBackend).numKeyValueStateEntries();
 		} else {
 			throw new UnsupportedOperationException();
 		}
@@ -82,7 +82,7 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
 		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(namespace);
+			return ((HeapKeyedStateBackend) keyedStateBackend).numKeyValueStateEntries(namespace);
 		} else {
 			throw new UnsupportedOperationException();
 		}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index 607eee0..c00e59a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -62,7 +62,7 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
 		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
-			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries();
+			return ((HeapKeyedStateBackend) keyedStateBackend).numKeyValueStateEntries();
 		} else {
 			throw new UnsupportedOperationException();
 		}