Posted to commits@flink.apache.org by sr...@apache.org on 2019/07/09 07:21:30 UTC

[flink] branch master updated: [FLINK-12693][state] Store state per key-group in CopyOnWriteStateTable

This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 8f47b38  [FLINK-12693][state] Store state per key-group in CopyOnWriteStateTable
8f47b38 is described below

commit 8f47b38b5273c09603b53af1c6579c42172fe634
Author: PengFei Li <lp...@gmail.com>
AuthorDate: Tue Jul 9 15:21:19 2019 +0800

    [FLINK-12693][state] Store state per key-group in CopyOnWriteStateTable
    
    This closes #8611.
---
 .../state/heap/AbstractStateTableSnapshot.java     |   78 +-
 ...iteStateTable.java => CopyOnWriteStateMap.java} |  631 +++++------
 .../state/heap/CopyOnWriteStateMapSnapshot.java    |  317 ++++++
 .../runtime/state/heap/CopyOnWriteStateTable.java  | 1161 +-------------------
 .../state/heap/CopyOnWriteStateTableSnapshot.java  |  271 +----
 .../runtime/state/heap/NestedMapsStateTable.java   |  486 +-------
 .../flink/runtime/state/heap/NestedStateMap.java   |  290 +++++
 .../runtime/state/heap/NestedStateMapSnapshot.java |  106 ++
 .../apache/flink/runtime/state/heap/StateMap.java  |  165 +++
 .../flink/runtime/state/heap/StateMapSnapshot.java |   79 ++
 .../flink/runtime/state/heap/StateTable.java       |  269 ++++-
 .../flink/runtime/state/StateBackendTestBase.java  |    6 +-
 ...TableTest.java => CopyOnWriteStateMapTest.java} |  274 ++---
 .../state/heap/CopyOnWriteStateTableTest.java      |  503 +--------
 .../runtime/state/heap/MockInternalKeyContext.java |   36 +
 .../heap/StateTableKeyGroupPartitionerTest.java    |  102 --
 .../heap/StateTableSnapshotCompatibilityTest.java  |   10 +-
 .../flink/runtime/state/ttl/TtlStateTestBase.java  |   15 +-
 18 files changed, 1759 insertions(+), 3040 deletions(-)
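
For orientation before the diff: the core of this change is that a StateTable no longer keeps one flat structure for all entries, but holds one StateMap per key-group and routes every operation to the map that owns the key's group. Below is a minimal, simplified sketch of that layout; the class and method names (other than the StateTable/StateMap roles they mirror) are illustrative and not the actual Flink API.

abstract class KeyGroupedStateTableSketch<K, N, S> {

	// one state map per key-group in the backend's local key-group range (assumed layout)
	private final StateMapSketch<K, N, S>[] keyGroupedStateMaps;
	private final int keyGroupOffset;

	@SuppressWarnings("unchecked")
	protected KeyGroupedStateTableSketch(int firstKeyGroup, int numberOfKeyGroups) {
		this.keyGroupOffset = firstKeyGroup;
		this.keyGroupedStateMaps = new StateMapSketch[numberOfKeyGroups];
		for (int i = 0; i < numberOfKeyGroups; i++) {
			this.keyGroupedStateMaps[i] = createStateMap();
		}
	}

	// subclasses decide which map implementation backs the table (copy-on-write or nested maps)
	protected abstract StateMapSketch<K, N, S> createStateMap();

	public void put(K key, int keyGroup, N namespace, S state) {
		getMapForKeyGroup(keyGroup).put(key, namespace, state);
	}

	public S get(K key, int keyGroup, N namespace) {
		return getMapForKeyGroup(keyGroup).get(key, namespace);
	}

	private StateMapSketch<K, N, S> getMapForKeyGroup(int keyGroup) {
		// the array index is the key-group id minus the first locally owned key-group
		return keyGroupedStateMaps[keyGroup - keyGroupOffset];
	}
}

abstract class StateMapSketch<K, N, S> {
	abstract void put(K key, N namespace, S state);
	abstract S get(K key, N namespace);
}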

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
index 03f253b..41e522f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
@@ -19,34 +19,106 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+
 /**
  * Abstract base class for snapshots of a {@link StateTable}. Offers a way to serialize the snapshot (by key-group).
  * All snapshots should be released after usage.
  */
 @Internal
-abstract class AbstractStateTableSnapshot<K, N, S, T extends StateTable<K, N, S>> implements StateSnapshot {
+abstract class AbstractStateTableSnapshot<K, N, S>
+	implements StateSnapshot, StateSnapshot.StateKeyGroupWriter {
 
 	/**
 	 * The {@link StateTable} from which this snapshot was created.
 	 */
-	final T owningStateTable;
+	protected final StateTable<K, N, S> owningStateTable;
+
+	/**
+	 * A local duplicate of the table's key serializer.
+	 */
+	@Nonnull
+	protected final TypeSerializer<K> localKeySerializer;
+
+	/**
+	 * A local duplicate of the table's namespace serializer.
+	 */
+	@Nonnull
+	protected final TypeSerializer<N> localNamespaceSerializer;
+
+	/**
+	 * A local duplicate of the table's state serializer.
+	 */
+	@Nonnull
+	protected final TypeSerializer<S> localStateSerializer;
+
+	@Nullable
+	protected final StateSnapshotTransformer<S> stateSnapshotTransformer;
 
 	/**
 	 * Creates a new {@link AbstractStateTableSnapshot} for and owned by the given table.
 	 *
 	 * @param owningStateTable the {@link StateTable} for which this object represents a snapshot.
 	 */
-	AbstractStateTableSnapshot(T owningStateTable) {
+	AbstractStateTableSnapshot(
+		StateTable<K, N, S> owningStateTable,
+		TypeSerializer<K> localKeySerializer,
+		TypeSerializer<N> localNamespaceSerializer,
+		TypeSerializer<S> localStateSerializer,
+		@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
 		this.owningStateTable = Preconditions.checkNotNull(owningStateTable);
+		this.localKeySerializer = Preconditions.checkNotNull(localKeySerializer);
+		this.localNamespaceSerializer = Preconditions.checkNotNull(localNamespaceSerializer);
+		this.localStateSerializer = Preconditions.checkNotNull(localStateSerializer);
+		this.stateSnapshotTransformer = stateSnapshotTransformer;
 	}
 
 	/**
+	 * Return the state map snapshot for the key group. If the snapshot does not exist, return null.
+	 */
+	protected abstract StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> getStateMapSnapshotForKeyGroup(int keyGroup);
+
+	/**
 	 * Optional hook to release resources for this snapshot at the end of its lifecycle.
 	 */
 	@Override
 	public void release() {
 	}
+
+	@Nonnull
+	@Override
+	public StateMetaInfoSnapshot getMetaInfoSnapshot() {
+		return owningStateTable.getMetaInfo().snapshot();
+	}
+
+	@Override
+	public StateKeyGroupWriter getKeyGroupWriter() {
+		return this;
+	}
+
+	/**
+	 * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
+	 * {@link CopyOnWriteStateTable}.
+	 *
+	 * <p>{@link NestedMapsStateTable} could naturally support a kind of
+	 * prefix-compressed format (grouping by namespace, writing the namespace only once per group instead of for each
+	 * mapping). We might implement support for different formats later (tailored towards different state table
+	 * implementations).
+	 */
+	@Override
+	public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
+		StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> stateMapSnapshot = getStateMapSnapshotForKeyGroup(keyGroupId);
+		stateMapSnapshot.writeState(localKeySerializer, localNamespaceSerializer, localStateSerializer, dov, stateSnapshotTransformer);
+		stateMapSnapshot.release();
+	}
 }
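
The writeStateInKeyGroup hook above is what the checkpointing code drives once per key-group. The following caller-side sketch is hypothetical (the surrounding loop is illustrative, not the actual Flink checkpoint code), but it uses the real StateSnapshot.StateKeyGroupWriter and DataOutputView types seen in the diff.

import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.StateSnapshot;

import java.io.IOException;

final class KeyGroupWriterLoopSketch {

	// drives one StateKeyGroupWriter over a contiguous range of key-groups
	static void writeAllKeyGroups(
			StateSnapshot.StateKeyGroupWriter writer,
			DataOutputView out,
			int firstKeyGroup,
			int numberOfKeyGroups) throws IOException {

		for (int keyGroup = firstKeyGroup; keyGroup < firstKeyGroup + numberOfKeyGroups; keyGroup++) {
			// each key-group is serialized independently, so re-scaling can redistribute whole groups
			writer.writeStateInKeyGroup(out, keyGroup);
		}
	}
}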
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMap.java
similarity index 56%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMap.java
index 5852bc2..7e71c0b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMap.java
@@ -20,10 +20,9 @@ package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateEntry;
 import org.apache.flink.runtime.state.StateTransformationFunction;
-import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
+import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.util.MathUtils;
 import org.apache.flink.util.Preconditions;
 
@@ -46,54 +45,50 @@ import java.util.stream.StreamSupport;
 import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE;
 
 /**
- * Implementation of Flink's in-memory state tables with copy-on-write support. This map does not support null values
+ * Implementation of Flink's in-memory state maps with copy-on-write support. This map does not support null values
  * for key or namespace.
- * <p>
- * {@link CopyOnWriteStateTable} sacrifices some peak performance and memory efficiency for features like incremental
+ *
+ * <p>{@link CopyOnWriteStateMap} sacrifices some peak performance and memory efficiency for features like incremental
  * rehashing and asynchronous snapshots through copy-on-write. Copy-on-write tries to minimize the amount of copying by
  * maintaining version meta data for both, the map structure and the state objects. However, we must often proactively
  * copy state objects when we hand them to the user.
- * <p>
- * As for any state backend, user should not keep references on state objects that they obtained from state backends
+ *
+ * <p>As for any state backend, users should not keep references to state objects that they obtained from state backends
  * outside the scope of the user function calls.
- * <p>
- * Some brief maintenance notes:
- * <p>
- * 1) Flattening the underlying data structure from nested maps (namespace) -> (key) -> (state) to one flat map
+ *
+ * <p>Some brief maintenance notes:
+ *
+ * <p>1) Flattening the underlying data structure from nested maps (namespace) -> (key) -> (state) to one flat map
  * (key, namespace) -> (state) brings certain performance trade-offs. In theory, the flat map has one less level of
  * indirection compared to the nested map. However, the nested map naturally de-duplicates namespace objects for which
  * #equals() is true. This leads to potentially a lot of redundant namespace objects for the flattened version. Those,
  * in turn, can again introduce more cache misses because we need to follow the namespace object on all operations to
  * ensure entry identities. Obviously, copy-on-write can also add memory overhead. So does the meta data to track
- * copy-on-write requirement (state and entry versions on {@link StateTableEntry}).
- * <p>
- * 2) A flat map structure is a lot easier when it comes to tracking copy-on-write of the map structure.
- * <p>
- * 3) Nested structure had the (never used) advantage that we can easily drop and iterate whole namespaces. This could
+ * copy-on-write requirement (state and entry versions on {@link StateMapEntry}).
+ *
+ * <p>2) A flat map structure is a lot easier when it comes to tracking copy-on-write of the map structure.
+ *
+ * <p>3) Nested structure had the (never used) advantage that we can easily drop and iterate whole namespaces. This could
  * give locality advantages for certain access pattern, e.g. iterating a namespace.
- * <p>
- * 4) Serialization format is changed from namespace-prefix compressed (as naturally provided from the old nested
+ *
+ * <p>4) Serialization format is changed from namespace-prefix compressed (as naturally provided from the old nested
  * structure) to making all entries self contained as (key, namespace, state).
- * <p>
- * 5) We got rid of having multiple nested tables, one for each key-group. Instead, we partition state into key-groups
- * on-the-fly, during the asynchronous part of a snapshot.
- * <p>
- * 6) Currently, a state table can only grow, but never shrinks on low load. We could easily add this if required.
- * <p>
- * 7) Heap based state backends like this can easily cause a lot of GC activity. Besides using G1 as garbage collector,
+ *
+ * <p>5) Currently, a state map can only grow, but never shrinks on low load. We could easily add this if required.
+ *
+ * <p>6) Heap based state backends like this can easily cause a lot of GC activity. Besides using G1 as garbage collector,
  * we should provide an additional state backend that operates on off-heap memory. This would sacrifice peak performance
  * (due to de/serialization of objects) for a lower, but more constant throughput and potentially huge simplifications
  * w.r.t. copy-on-write.
- * <p>
- * 8) We could try a hybrid of a serialized and object based backends, where key and namespace of the entries are both
+ *
+ * <p>7) We could try a hybrid of a serialized and object based backends, where key and namespace of the entries are both
  * serialized in one byte-array.
- * <p>
- * 9) We could consider smaller types (e.g. short) for the version counting and think about some reset strategy before
+ *
+ * <p>8) We could consider smaller types (e.g. short) for the version counting and think about some reset strategy before
  * overflows, when there is no snapshot running. However, this would have to touch all entries in the map.
- * <p>
- * This class was initially based on the {@link java.util.HashMap} implementation of the Android JDK, but is now heavily
- * customized towards the use case of table for state entries.
  *
+ * <p>This class was initially based on the {@link java.util.HashMap} implementation of the Android JDK, but is now heavily
+ * customized towards the use case of map for state entries.
  * IMPORTANT: the contracts for this class rely on the user not holding any references to objects returned by this map
  * beyond the life cycle of per-element operations. Or phrased differently, all get-update-put operations on a mapping
  * should be within one call of processElement. Otherwise, the user must take care of taking deep copies, e.g. for
@@ -103,7 +98,7 @@ import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE;
  * @param <N> type of namespace.
  * @param <S> type of value.
  */
-public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implements Iterable<StateEntry<K, N, S>> {
+public class CopyOnWriteStateMap<K, N, S> extends StateMap<K, N, S> {
 
 	/**
 	 * The logger.
@@ -111,40 +106,45 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class);
 
 	/**
-	 * Min capacity (other than zero) for a {@link CopyOnWriteStateTable}. Must be a power of two
+	 * Min capacity (other than zero) for a {@link CopyOnWriteStateMap}. Must be a power of two
 	 * greater than 1 (and less than 1 << 30).
 	 */
 	private static final int MINIMUM_CAPACITY = 4;
 
 	/**
-	 * Max capacity for a {@link CopyOnWriteStateTable}. Must be a power of two >= MINIMUM_CAPACITY.
+	 * Max capacity for a {@link CopyOnWriteStateMap}. Must be a power of two >= MINIMUM_CAPACITY.
 	 */
 	private static final int MAXIMUM_CAPACITY = 1 << 30;
 
 	/**
-	 * Default capacity for a {@link CopyOnWriteStateTable}. Must be a power of two,
+	 * Default capacity for a {@link CopyOnWriteStateMap}. Must be a power of two,
 	 * greater than {@code MINIMUM_CAPACITY} and less than {@code MAXIMUM_CAPACITY}.
 	 */
-	public static final int DEFAULT_CAPACITY = 1024;
+	public static final int DEFAULT_CAPACITY = 128;
 
 	/**
-	 * Minimum number of entries that one step of incremental rehashing migrates from the old to the new sub-table.
+	 * Minimum number of entries that one step of incremental rehashing migrates from the old to the new sub-map.
 	 */
 	private static final int MIN_TRANSFERRED_PER_INCREMENTAL_REHASH = 4;
 
 	/**
-	 * An empty table shared by all zero-capacity maps (typically from default
+	 * The serializer of the state.
+	 */
+	protected final TypeSerializer<S> stateSerializer;
+
+	/**
+	 * An empty map shared by all zero-capacity maps (typically from default
 	 * constructor). It is never written to, and replaced on first put. Its size
 	 * is set to half the minimum, so that the first resize will create a
-	 * minimum-sized table.
+	 * minimum-sized map.
 	 */
-	private static final StateTableEntry<?, ?, ?>[] EMPTY_TABLE = new StateTableEntry[MINIMUM_CAPACITY >>> 1];
+	private static final StateMapEntry<?, ?, ?>[] EMPTY_TABLE = new StateMapEntry[MINIMUM_CAPACITY >>> 1];
 
 	/**
-	 * Empty entry that we use to bootstrap our {@link CopyOnWriteStateTable.StateEntryIterator}.
+	 * Empty entry that we use to bootstrap our {@link CopyOnWriteStateMap.StateEntryIterator}.
 	 */
-	private static final StateTableEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY =
-		new StateTableEntry<>(new Object(), new Object(), new Object(), 0, null, 0, 0);
+	private static final StateMapEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY =
+		new StateMapEntry<>(new Object(), new Object(), new Object(), 0, null, 0, 0);
 
 	/**
 	 * Maintains an ordered set of version ids that are still in use by unreleased snapshots.
@@ -152,20 +152,20 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	private final TreeSet<Integer> snapshotVersions;
 
 	/**
-	 * This is the primary entry array (hash directory) of the state table. If no incremental rehash is ongoing, this
+	 * This is the primary entry array (hash directory) of the state map. If no incremental rehash is ongoing, this
 	 * is the only used table.
 	 **/
-	private StateTableEntry<K, N, S>[] primaryTable;
+	private StateMapEntry<K, N, S>[] primaryTable;
 
 	/**
 	 * We maintain a secondary entry array while performing an incremental rehash. The purpose is to slowly migrate
 	 * entries from the primary table to this resized table array. When all entries are migrated, this becomes the new
 	 * primary table.
 	 */
-	private StateTableEntry<K, N, S>[] incrementalRehashTable;
+	private StateMapEntry<K, N, S>[] incrementalRehashTable;
 
 	/**
-	 * The current number of mappings in the primary table.
+	 * The current number of mappings in the primary table.
 	 */
 	private int primaryTableSize;
 
@@ -182,7 +182,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	/**
 	 * The current version of this map. Used for copy-on-write mechanics.
 	 */
-	private int stateTableVersion;
+	private int stateMapVersion;
 
 	/**
 	 * The highest version of this map that is still required by any unreleased snapshot.
@@ -195,7 +195,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	private N lastNamespace;
 
 	/**
-	 * The {@link CopyOnWriteStateTable} is rehashed when its size exceeds this threshold.
+	 * The {@link CopyOnWriteStateMap} is rehashed when its size exceeds this threshold.
 	 * The value of this field is generally .75 * capacity, except when
 	 * the capacity is zero, as described in the EMPTY_TABLE declaration
 	 * above.
@@ -209,46 +209,36 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	private int modCount;
 
 	/**
-	 * Constructs a new {@code StateTable} with default capacity of {@code DEFAULT_CAPACITY}.
+	 * Constructs a new {@code StateMap} with default capacity of {@code DEFAULT_CAPACITY}.
 	 *
-	 * @param keyContext    the key context.
-	 * @param metaInfo      the meta information, including the type serializer for state copy-on-write.
-	 * @param keySerializer the serializer of the key.
+	 * @param stateSerializer the serializer of the state.
 	 */
-	CopyOnWriteStateTable(
-		InternalKeyContext<K> keyContext,
-		RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo,
-		TypeSerializer<K> keySerializer) {
-		this(keyContext, metaInfo, DEFAULT_CAPACITY, keySerializer);
+	CopyOnWriteStateMap(TypeSerializer<S> stateSerializer) {
+		this(DEFAULT_CAPACITY, stateSerializer);
 	}
 
 	/**
-	 * Constructs a new {@code StateTable} instance with the specified capacity.
+	 * Constructs a new {@code StateMap} instance with the specified capacity.
 	 *
-	 * @param keyContext    the key context.
-	 * @param metaInfo      the meta information, including the type serializer for state copy-on-write.
 	 * @param capacity      the initial capacity of this hash map.
-	 * @param keySerializer the serializer of the key.
+	 * @param stateSerializer the serializer of the state.
 	 * @throws IllegalArgumentException when the capacity is less than zero.
 	 */
 	@SuppressWarnings("unchecked")
-	private CopyOnWriteStateTable(
-		InternalKeyContext<K> keyContext,
-		RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo,
-		int capacity,
-		TypeSerializer<K> keySerializer) {
-		super(keyContext, metaInfo, keySerializer);
+	private CopyOnWriteStateMap(
+		int capacity, TypeSerializer<S> stateSerializer) {
+		this.stateSerializer = Preconditions.checkNotNull(stateSerializer);
 
-		// initialized tables to EMPTY_TABLE.
-		this.primaryTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
-		this.incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
+		// initialized maps to EMPTY_TABLE.
+		this.primaryTable = (StateMapEntry<K, N, S>[]) EMPTY_TABLE;
+		this.incrementalRehashTable = (StateMapEntry<K, N, S>[]) EMPTY_TABLE;
 
 		// initialize sizes to 0.
 		this.primaryTableSize = 0;
 		this.incrementalRehashTableSize = 0;
 
 		this.rehashIndex = 0;
-		this.stateTableVersion = 0;
+		this.stateMapVersion = 0;
 		this.highestRequiredSnapshotVersion = 0;
 		this.snapshotVersions = new TreeSet<>();
 
@@ -271,12 +261,12 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		primaryTable = makeTable(capacity);
 	}
 
-	// Public API from AbstractStateTable ------------------------------------------------------------------------------
+	// Public API from StateMap ------------------------------------------------------------------------------
 
 	/**
-	 * Returns the total number of entries in this {@link CopyOnWriteStateTable}. This is the sum of both sub-tables.
+	 * Returns the total number of entries in this {@link CopyOnWriteStateMap}. This is the sum of both sub-maps.
 	 *
-	 * @return the number of entries in this {@link CopyOnWriteStateTable}.
+	 * @return the number of entries in this {@link CopyOnWriteStateMap}.
 	 */
 	@Override
 	public int size() {
@@ -288,10 +278,10 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
 		final int requiredVersion = highestRequiredSnapshotVersion;
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
+		final StateMapEntry<K, N, S>[] tab = selectActiveTable(hash);
 		int index = hash & (tab.length - 1);
 
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
+		for (StateMapEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
 			final K eKey = e.key;
 			final N eNamespace = e.namespace;
 			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
@@ -302,7 +292,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 					if (e.entryVersion < requiredVersion) {
 						e = handleChainedEntryCopyOnWrite(tab, hash & (tab.length - 1), e);
 					}
-					e.stateVersion = stateTableVersion;
+					e.stateVersion = stateMapVersion;
 					e.state = getStateSerializer().copy(e.state);
 				}
 
@@ -314,69 +304,12 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	}
 
 	@Override
-	public Stream<K> getKeys(N namespace) {
-		return StreamSupport.stream(spliterator(), false)
-			.filter(entry -> entry.getNamespace().equals(namespace))
-			.map(StateEntry::getKey);
-	}
-
-	@Override
-	public void put(K key, int keyGroup, N namespace, S state) {
-		put(key, namespace, state);
-	}
-
-	@Override
-	public S get(N namespace) {
-		return get(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public boolean containsKey(N namespace) {
-		return containsKey(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public void put(N namespace, S state) {
-		put(keyContext.getCurrentKey(), namespace, state);
-	}
-
-	@Override
-	public S putAndGetOld(N namespace, S state) {
-		return putAndGetOld(keyContext.getCurrentKey(), namespace, state);
-	}
-
-	@Override
-	public void remove(N namespace) {
-		remove(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public S removeAndGetOld(N namespace) {
-		return removeAndGetOld(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public <T> void transform(N namespace, T value, StateTransformationFunction<S, T> transformation) throws Exception {
-		transform(keyContext.getCurrentKey(), namespace, value, transformation);
-	}
-
-	// Private implementation details of the API methods ---------------------------------------------------------------
-
-	/**
-	 * Returns whether this table contains the specified key/namespace composite key.
-	 *
-	 * @param key       the key in the composite key to search for. Not null.
-	 * @param namespace the namespace in the composite key to search for. Not null.
-	 * @return {@code true} if this map contains the specified key/namespace composite key,
-	 * {@code false} otherwise.
-	 */
-	boolean containsKey(K key, N namespace) {
-
+	public boolean containsKey(K key, N namespace) {
 		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
+		final StateMapEntry<K, N, S>[] tab = selectActiveTable(hash);
 		int index = hash & (tab.length - 1);
 
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
+		for (StateMapEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
 			final K eKey = e.key;
 			final N eNamespace = e.namespace;
 
@@ -387,116 +320,84 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		return false;
 	}
 
-	/**
-	 * Maps the specified key/namespace composite key to the specified value. This method should be preferred
-	 * over {@link #putAndGetOld(Object, Object, Object)} (Object, Object)} when the caller is not interested
-	 * in the old value, because this can potentially reduce copy-on-write activity.
-	 *
-	 * @param key       the key. Not null.
-	 * @param namespace the namespace. Not null.
-	 * @param value     the value. Can be null.
-	 */
-	void put(K key, N namespace, S value) {
-		final StateTableEntry<K, N, S> e = putEntry(key, namespace);
+	@Override
+	public void put(K key, N namespace, S value) {
+		final StateMapEntry<K, N, S> e = putEntry(key, namespace);
 
 		e.state = value;
-		e.stateVersion = stateTableVersion;
+		e.stateVersion = stateMapVersion;
 	}
 
-	/**
-	 * Maps the specified key/namespace composite key to the specified value. Returns the previous state that was
-	 * registered under the composite key.
-	 *
-	 * @param key       the key. Not null.
-	 * @param namespace the namespace. Not null.
-	 * @param value     the value. Can be null.
-	 * @return the value of any previous mapping with the specified key or
-	 * {@code null} if there was no such mapping.
-	 */
-	S putAndGetOld(K key, N namespace, S value) {
-
-		final StateTableEntry<K, N, S> e = putEntry(key, namespace);
+	@Override
+	public S putAndGetOld(K key, N namespace, S state) {
+		final StateMapEntry<K, N, S> e = putEntry(key, namespace);
 
 		// copy-on-write check for state
 		S oldState = (e.stateVersion < highestRequiredSnapshotVersion) ?
-				getStateSerializer().copy(e.state) :
-				e.state;
+			getStateSerializer().copy(e.state) :
+			e.state;
 
-		e.state = value;
-		e.stateVersion = stateTableVersion;
+		e.state = state;
+		e.stateVersion = stateMapVersion;
 
 		return oldState;
 	}
 
-	/**
-	 * Removes the mapping with the specified key/namespace composite key from this map. This method should be preferred
-	 * over {@link #removeAndGetOld(Object, Object)} when the caller is not interested in the old value, because this
-	 * can potentially reduce copy-on-write activity.
-	 *
-	 * @param key       the key of the mapping to remove. Not null.
-	 * @param namespace the namespace of the mapping to remove. Not null.
-	 */
-	void remove(K key, N namespace) {
+	@Override
+	public void remove(K key, N namespace) {
 		removeEntry(key, namespace);
 	}
 
-	/**
-	 * Removes the mapping with the specified key/namespace composite key from this map, returning the state that was
-	 * found under the entry.
-	 *
-	 * @param key       the key of the mapping to remove. Not null.
-	 * @param namespace the namespace of the mapping to remove. Not null.
-	 * @return the value of the removed mapping or {@code null} if no mapping
-	 * for the specified key was found.
-	 */
-	S removeAndGetOld(K key, N namespace) {
+	@Override
+	public S removeAndGetOld(K key, N namespace) {
 
-		final StateTableEntry<K, N, S> e = removeEntry(key, namespace);
+		final StateMapEntry<K, N, S> e = removeEntry(key, namespace);
 
 		return e != null ?
-				// copy-on-write check for state
-				(e.stateVersion < highestRequiredSnapshotVersion ?
-						getStateSerializer().copy(e.state) :
-						e.state) :
-				null;
+			// copy-on-write check for state
+			(e.stateVersion < highestRequiredSnapshotVersion ?
+				getStateSerializer().copy(e.state) :
+				e.state) :
+			null;
 	}
 
-	/**
-	 * @param key            the key of the mapping to remove. Not null.
-	 * @param namespace      the namespace of the mapping to remove. Not null.
-	 * @param value          the value that is the second input for the transformation.
-	 * @param transformation the transformation function to apply on the old state and the given value.
-	 * @param <T>            type of the value that is the second input to the {@link StateTransformationFunction}.
-	 * @throws Exception exception that happen on applying the function.
-	 * @see #transform(Object, Object, StateTransformationFunction).
-	 */
-	<T> void transform(
-			K key,
-			N namespace,
-			T value,
-			StateTransformationFunction<S, T> transformation) throws Exception {
+	@Override
+	public Stream<K> getKeys(N namespace) {
+		return StreamSupport.stream(spliterator(), false)
+			.filter(entry -> entry.getNamespace().equals(namespace))
+			.map(StateEntry::getKey);
+	}
+
+	@Override
+	public <T> void transform(
+		K key,
+		N namespace,
+		T value,
+		StateTransformationFunction<S, T> transformation) throws Exception {
 
-		final StateTableEntry<K, N, S> entry = putEntry(key, namespace);
+		final StateMapEntry<K, N, S> entry = putEntry(key, namespace);
 
 		// copy-on-write check for state
 		entry.state = transformation.apply(
-				(entry.stateVersion < highestRequiredSnapshotVersion) ?
-						getStateSerializer().copy(entry.state) :
-						entry.state,
-				value);
-		entry.stateVersion = stateTableVersion;
+			(entry.stateVersion < highestRequiredSnapshotVersion) ?
+				getStateSerializer().copy(entry.state) :
+				entry.state,
+			value);
+		entry.stateVersion = stateMapVersion;
 	}
 
+	// Private implementation details of the API methods ---------------------------------------------------------------
+
 	/**
 	 * Helper method that is the basis for operations that add mappings.
 	 */
-	private StateTableEntry<K, N, S> putEntry(K key, N namespace) {
+	private StateMapEntry<K, N, S> putEntry(K key, N namespace) {
 
 		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
+		final StateMapEntry<K, N, S>[] tab = selectActiveTable(hash);
 		int index = hash & (tab.length - 1);
 
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
+		for (StateMapEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
 			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
 
 				// copy-on-write check for entry
@@ -513,19 +414,19 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 			doubleCapacity();
 		}
 
-		return addNewStateTableEntry(tab, key, namespace, hash);
+		return addNewStateMapEntry(tab, key, namespace, hash);
 	}
 
 	/**
 	 * Helper method that is the basis for operations that remove mappings.
 	 */
-	private StateTableEntry<K, N, S> removeEntry(K key, N namespace) {
+	private StateMapEntry<K, N, S> removeEntry(K key, N namespace) {
 
 		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
+		final StateMapEntry<K, N, S>[] tab = selectActiveTable(hash);
 		int index = hash & (tab.length - 1);
 
-		for (StateTableEntry<K, N, S> e = tab[index], prev = null; e != null; prev = e, e = e.next) {
+		for (StateMapEntry<K, N, S> e = tab[index], prev = null; e != null; prev = e, e = e.next) {
 			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
 				if (prev == null) {
 					tab[index] = e.next;
@@ -548,33 +449,6 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		return null;
 	}
 
-	private void checkKeyNamespacePreconditions(K key, N namespace) {
-		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
-		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
-	}
-
-	// Meta data setter / getter and toString --------------------------------------------------------------------------
-
-	@Override
-	public TypeSerializer<S> getStateSerializer() {
-		return metaInfo.getStateSerializer();
-	}
-
-	@Override
-	public TypeSerializer<N> getNamespaceSerializer() {
-		return metaInfo.getNamespaceSerializer();
-	}
-
-	@Override
-	public RegisteredKeyValueStateBackendMetaInfo<N, S> getMetaInfo() {
-		return metaInfo;
-	}
-
-	@Override
-	public void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
-		this.metaInfo = metaInfo;
-	}
-
 	// Iteration  ------------------------------------------------------------------------------------------------------
 
 	@Nonnull
@@ -583,10 +457,10 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		return new StateEntryIterator();
 	}
 
-	// Private utility functions for StateTable management -------------------------------------------------------------
+	// Private utility functions for StateMap management -------------------------------------------------------------
 
 	/**
-	 * @see #releaseSnapshot(CopyOnWriteStateTableSnapshot)
+	 * @see #releaseSnapshot(StateMapSnapshot)
 	 */
 	@VisibleForTesting
 	void releaseSnapshot(int snapshotVersion) {
@@ -600,40 +474,40 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 	/**
 	 * Creates (combined) copy of the table arrays for a snapshot. This method must be called by the same Thread that
-	 * does modifications to the {@link CopyOnWriteStateTable}.
+	 * does modifications to the {@link CopyOnWriteStateMap}.
 	 */
 	@VisibleForTesting
 	@SuppressWarnings("unchecked")
-	StateTableEntry<K, N, S>[] snapshotTableArrays() {
+	StateMapEntry<K, N, S>[] snapshotMapArrays() {
 
 		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
 		// Only stale reads of from the result of #releaseSnapshot calls are ok. This is why we must call this method
-		// from the same thread that does all the modifications to the table.
+		// from the same thread that does all the modifications to the map.
 		synchronized (snapshotVersions) {
 
-			// increase the table version for copy-on-write and register the snapshot
-			if (++stateTableVersion < 0) {
+			// increase the map version for copy-on-write and register the snapshot
+			if (++stateMapVersion < 0) {
 				// this is just a safety net against overflows, but should never happen in practice (i.e., only after 2^31 snapshots)
-				throw new IllegalStateException("Version count overflow in CopyOnWriteStateTable. Enforcing restart.");
+				throw new IllegalStateException("Version count overflow in CopyOnWriteStateMap. Enforcing restart.");
 			}
 
-			highestRequiredSnapshotVersion = stateTableVersion;
+			highestRequiredSnapshotVersion = stateMapVersion;
 			snapshotVersions.add(highestRequiredSnapshotVersion);
 		}
 
-		StateTableEntry<K, N, S>[] table = primaryTable;
+		StateMapEntry<K, N, S>[] table = primaryTable;
 
 		// In order to reuse the copied array as the destination array for the partitioned records in
-		// CopyOnWriteStateTableSnapshot#partitionByKeyGroup(), we need to make sure that the copied array
+		// CopyOnWriteStateMapSnapshot.TransformedSnapshotIterator, we need to make sure that the copied array
 		// is big enough to hold the flattened entries. In fact, given the current rehashing algorithm, we only
 		// need to do this check when isRehashing() is false, but in order to get a more robust code(in case that
 		// the rehashing algorithm may changed in the future), we do this check for all the case.
-		final int totalTableIndexSize = rehashIndex + table.length;
-		final int copiedArraySize = Math.max(totalTableIndexSize, size());
-		final StateTableEntry<K, N, S>[] copy = new StateTableEntry[copiedArraySize];
+		final int totalMapIndexSize = rehashIndex + table.length;
+		final int copiedArraySize = Math.max(totalMapIndexSize, size());
+		final StateMapEntry<K, N, S>[] copy = new StateMapEntry[copiedArraySize];
 
 		if (isRehashing()) {
-			// consider both tables for the snapshot, the rehash index tells us which part of the two tables we need
+			// consider both maps for the snapshot, the rehash index tells us which part of the two maps we need
 			final int localRehashIndex = rehashIndex;
 			final int localCopyLength = table.length - localRehashIndex;
 			// for the primary table, take every index >= rhIdx.
@@ -652,43 +526,47 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		return copy;
 	}
 
+	int getStateMapVersion() {
+		return stateMapVersion;
+	}
+
 	/**
 	 * Allocate a table of the given capacity and set the threshold accordingly.
 	 *
 	 * @param newCapacity must be a power of two
 	 */
-	private StateTableEntry<K, N, S>[] makeTable(int newCapacity) {
+	private StateMapEntry<K, N, S>[] makeTable(int newCapacity) {
 
 		if (newCapacity < MAXIMUM_CAPACITY) {
 			threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
 		} else {
 			if (size() > MAX_ARRAY_SIZE) {
 
-				throw new IllegalStateException("Maximum capacity of CopyOnWriteStateTable is reached and the job " +
+				throw new IllegalStateException("Maximum capacity of CopyOnWriteStateMap is reached and the job " +
 					"cannot continue. Please consider scaling-out your job or using a different keyed state backend " +
 					"implementation!");
 			} else {
 
-				LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can " +
+				LOG.warn("Maximum capacity of 2^30 in StateMap reached. Cannot increase hash map size. This can " +
 					"lead to more collisions and lower performance. Please consider scaling-out your job or using a " +
 					"different keyed state backend implementation!");
 				threshold = MAX_ARRAY_SIZE;
 			}
 		}
 
-		@SuppressWarnings("unchecked") StateTableEntry<K, N, S>[] newTable
-				= (StateTableEntry<K, N, S>[]) new StateTableEntry[newCapacity];
-		return newTable;
+		@SuppressWarnings("unchecked") StateMapEntry<K, N, S>[] newMap =
+			(StateMapEntry<K, N, S>[]) new StateMapEntry[newCapacity];
+		return newMap;
 	}
 
 	/**
-	 * Creates and inserts a new {@link StateTableEntry}.
+	 * Creates and inserts a new {@link StateMapEntry}.
 	 */
-	private StateTableEntry<K, N, S> addNewStateTableEntry(
-			StateTableEntry<K, N, S>[] table,
-			K key,
-			N namespace,
-			int hash) {
+	private StateMapEntry<K, N, S> addNewStateMapEntry(
+		StateMapEntry<K, N, S>[] table,
+		K key,
+		N namespace,
+		int hash) {
 
 		// small optimization that aims to avoid holding references on duplicate namespace objects
 		if (namespace.equals(lastNamespace)) {
@@ -698,14 +576,14 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		}
 
 		int index = hash & (table.length - 1);
-		StateTableEntry<K, N, S> newEntry = new StateTableEntry<>(
-				key,
-				namespace,
-				null,
-				hash,
-				table[index],
-				stateTableVersion,
-				stateTableVersion);
+		StateMapEntry<K, N, S> newEntry = new StateMapEntry<>(
+			key,
+			namespace,
+			null,
+			hash,
+			table[index],
+			stateMapVersion,
+			stateMapVersion);
 		table[index] = newEntry;
 
 		if (table == primaryTable) {
@@ -722,7 +600,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 * @param hashCode the hash code which we use to decide about the table that is responsible.
 	 * @return the index of the sub-table that is responsible for the entry with the given hash code.
 	 */
-	private StateTableEntry<K, N, S>[] selectActiveTable(int hashCode) {
+	private StateMapEntry<K, N, S>[] selectActiveTable(int hashCode) {
 		return (hashCode & (primaryTable.length - 1)) >= rehashIndex ? primaryTable : incrementalRehashTable;
 	}
 
@@ -737,9 +615,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		// There can only be one rehash in flight. From the amount of incremental rehash steps we take, this should always hold.
 		Preconditions.checkState(!isRehashing(), "There is already a rehash in progress.");
 
-		StateTableEntry<K, N, S>[] oldTable = primaryTable;
+		StateMapEntry<K, N, S>[] oldMap = primaryTable;
 
-		int oldCapacity = oldTable.length;
+		int oldCapacity = oldMap.length;
 
 		if (oldCapacity == MAXIMUM_CAPACITY) {
 			return;
@@ -763,8 +641,6 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 */
 	private int computeHashForOperationAndDoIncrementalRehash(K key, N namespace) {
 
-		checkKeyNamespacePreconditions(key, namespace);
-
 		if (isRehashing()) {
 			incrementalRehash();
 		}
@@ -778,11 +654,11 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	@SuppressWarnings("unchecked")
 	private void incrementalRehash() {
 
-		StateTableEntry<K, N, S>[] oldTable = primaryTable;
-		StateTableEntry<K, N, S>[] newTable = incrementalRehashTable;
+		StateMapEntry<K, N, S>[] oldMap = primaryTable;
+		StateMapEntry<K, N, S>[] newMap = incrementalRehashTable;
 
-		int oldCapacity = oldTable.length;
-		int newMask = newTable.length - 1;
+		int oldCapacity = oldMap.length;
+		int newMask = newMap.length - 1;
 		int requiredVersion = highestRequiredSnapshotVersion;
 		int rhIdx = rehashIndex;
 		int transferred = 0;
@@ -790,26 +666,26 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		// we migrate a certain minimum amount of entries from the old to the new table
 		while (transferred < MIN_TRANSFERRED_PER_INCREMENTAL_REHASH) {
 
-			StateTableEntry<K, N, S> e = oldTable[rhIdx];
+			StateMapEntry<K, N, S> e = oldMap[rhIdx];
 
 			while (e != null) {
 				// copy-on-write check for entry
 				if (e.entryVersion < requiredVersion) {
-					e = new StateTableEntry<>(e, stateTableVersion);
+					e = new StateMapEntry<>(e, stateMapVersion);
 				}
-				StateTableEntry<K, N, S> n = e.next;
+				StateMapEntry<K, N, S> n = e.next;
 				int pos = e.hash & newMask;
-				e.next = newTable[pos];
-				newTable[pos] = e;
+				e.next = newMap[pos];
+				newMap[pos] = e;
 				e = n;
 				++transferred;
 			}
 
-			oldTable[rhIdx] = null;
+			oldMap[rhIdx] = null;
 			if (++rhIdx == oldCapacity) {
 				//here, the rehash is complete and we release resources and reset fields
-				primaryTable = newTable;
-				incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
+				primaryTable = newMap;
+				incrementalRehashTable = (StateMapEntry<K, N, S>[]) EMPTY_TABLE;
 				primaryTableSize += incrementalRehashTableSize;
 				incrementalRehashTableSize = 0;
 				rehashIndex = 0;
@@ -827,19 +703,19 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 * Perform copy-on-write for entry chains. We iterate the (hopefully and probably) still cached chain, replace
 	 * all links up to the 'untilEntry', which we actually wanted to modify.
 	 */
-	private StateTableEntry<K, N, S> handleChainedEntryCopyOnWrite(
-			StateTableEntry<K, N, S>[] tab,
-			int tableIdx,
-			StateTableEntry<K, N, S> untilEntry) {
+	private StateMapEntry<K, N, S> handleChainedEntryCopyOnWrite(
+		StateMapEntry<K, N, S>[] tab,
+		int mapIdx,
+		StateMapEntry<K, N, S> untilEntry) {
 
 		final int required = highestRequiredSnapshotVersion;
 
-		StateTableEntry<K, N, S> current = tab[tableIdx];
-		StateTableEntry<K, N, S> copy;
+		StateMapEntry<K, N, S> current = tab[mapIdx];
+		StateMapEntry<K, N, S> copy;
 
 		if (current.entryVersion < required) {
-			copy = new StateTableEntry<>(current, stateTableVersion);
-			tab[tableIdx] = copy;
+			copy = new StateMapEntry<>(current, stateMapVersion);
+			tab[mapIdx] = copy;
 		} else {
 			// nothing to do, just advance copy to current
 			copy = current;
@@ -853,7 +729,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 			if (current.entryVersion < required) {
 				// copy and advance the current's copy
-				copy.next = new StateTableEntry<>(current, stateTableVersion);
+				copy.next = new StateMapEntry<>(current, stateMapVersion);
 				copy = copy.next;
 			} else {
 				// nothing to do, just advance copy to current
@@ -865,8 +741,8 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	}
 
 	@SuppressWarnings("unchecked")
-	private static <K, N, S> StateTableEntry<K, N, S> getBootstrapEntry() {
-		return (StateTableEntry<K, N, S>) ITERATOR_BOOTSTRAP_ENTRY;
+	private static <K, N, S> StateMapEntry<K, N, S> getBootstrapEntry() {
+		return (StateMapEntry<K, N, S>) ITERATOR_BOOTSTRAP_ENTRY;
 	}
 
 	/**
@@ -877,62 +753,65 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		return MathUtils.bitMix(key.hashCode() ^ namespace.hashCode());
 	}
 
-	// Snapshotting ----------------------------------------------------------------------------------------------------
-
-	int getStateTableVersion() {
-		return stateTableVersion;
-	}
-
 	/**
-	 * Creates a snapshot of this {@link CopyOnWriteStateTable}, to be written in checkpointing. The snapshot integrity
-	 * is protected through copy-on-write from the {@link CopyOnWriteStateTable}. Users should call
-	 * {@link #releaseSnapshot(CopyOnWriteStateTableSnapshot)} after using the returned object.
+	 * Creates a snapshot of this {@link CopyOnWriteStateMap}, to be written in checkpointing. The snapshot integrity
+	 * is protected through copy-on-write from the {@link CopyOnWriteStateMap}. Users should call
+	 * {@link #releaseSnapshot(StateMapSnapshot)} after using the returned object.
 	 *
-	 * @return a snapshot from this {@link CopyOnWriteStateTable}, for checkpointing.
+	 * @return a snapshot from this {@link CopyOnWriteStateMap}, for checkpointing.
 	 */
 	@Nonnull
 	@Override
-	public CopyOnWriteStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new CopyOnWriteStateTableSnapshot<>(this);
+	public CopyOnWriteStateMapSnapshot<K, N, S> stateSnapshot() {
+		return new CopyOnWriteStateMapSnapshot<>(this);
 	}
 
 	/**
-	 * Releases a snapshot for this {@link CopyOnWriteStateTable}. This method should be called once a snapshot is no more needed,
-	 * so that the {@link CopyOnWriteStateTable} can stop considering this snapshot for copy-on-write, thus avoiding unnecessary
+	 * Releases a snapshot for this {@link CopyOnWriteStateMap}. This method should be called once a snapshot is no more needed,
+	 * so that the {@link CopyOnWriteStateMap} can stop considering this snapshot for copy-on-write, thus avoiding unnecessary
 	 * object creation.
 	 *
-	 * @param snapshotToRelease the snapshot to release, which was previously created by this state table.
+	 * @param snapshotToRelease the snapshot to release, which was previously created by this state map.
 	 */
-	void releaseSnapshot(CopyOnWriteStateTableSnapshot<K, N, S> snapshotToRelease) {
+	@Override
+	public void releaseSnapshot(StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> snapshotToRelease) {
+
+		CopyOnWriteStateMapSnapshot<K, N, S> copyOnWriteStateMapSnapshot = (CopyOnWriteStateMapSnapshot<K, N, S>) snapshotToRelease;
+
+		Preconditions.checkArgument(copyOnWriteStateMapSnapshot.isOwner(this),
+			"Cannot release snapshot which is owned by a different state map.");
 
-		Preconditions.checkArgument(snapshotToRelease.isOwner(this),
-				"Cannot release snapshot which is owned by a different state table.");
+		releaseSnapshot(copyOnWriteStateMapSnapshot.getSnapshotVersion());
+	}
+
+	// Meta data setter / getter and toString -----------------------------------------------------
 
-		releaseSnapshot(snapshotToRelease.getSnapshotVersion());
+	public TypeSerializer<S> getStateSerializer() {
+		return stateSerializer;
 	}
 
-	// StateTableEntry -------------------------------------------------------------------------------------------------
+	// StateMapEntry -------------------------------------------------------------------------------------------------
 
 	/**
-	 * One entry in the {@link CopyOnWriteStateTable}. This is a triplet of key, namespace, and state. Thereby, key and
+	 * One entry in the {@link CopyOnWriteStateMap}. This is a triplet of key, namespace, and state. Thereby, key and
 	 * namespace together serve as a composite key for the state. This class also contains some management meta data for
-	 * copy-on-write, a pointer to link other {@link StateTableEntry}s to a list, and cached hash code.
+	 * copy-on-write, a pointer to link other {@link StateMapEntry}s to a list, and cached hash code.
 	 *
 	 * @param <K> type of key.
 	 * @param <N> type of namespace.
 	 * @param <S> type of state.
 	 */
 	@VisibleForTesting
-	protected static class StateTableEntry<K, N, S> implements StateEntry<K, N, S> {
+	protected static class StateMapEntry<K, N, S> implements StateEntry<K, N, S> {
 
 		/**
-		 * The key. Assumed to be immutable and not null.
+		 * The key. Assumed to be immutable and not null.
 		 */
 		@Nonnull
 		final K key;
 
 		/**
-		 * The namespace. Assumed to be immutable and not null.
+		 * The namespace. Assumed to be immutable and not null.
 		 */
 		@Nonnull
 		final N namespace;
@@ -944,14 +823,14 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		S state;
 
 		/**
-		 * Link to another {@link StateTableEntry}. This is used to resolve collisions in the
-		 * {@link CopyOnWriteStateTable} through chaining.
+		 * Link to another {@link StateMapEntry}. This is used to resolve collisions in the
+		 * {@link CopyOnWriteStateMap} through chaining.
 		 */
 		@Nullable
-		StateTableEntry<K, N, S> next;
+		StateMapEntry<K, N, S> next;
 
 		/**
-		 * The version of this {@link StateTableEntry}. This is meta data for copy-on-write of the table structure.
+		 * The version of this {@link StateMapEntry}. This is meta data for copy-on-write of the map structure.
 		 */
 		int entryVersion;
 
@@ -965,16 +844,16 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		 */
 		final int hash;
 
-		StateTableEntry(StateTableEntry<K, N, S> other, int entryVersion) {
+		StateMapEntry(StateMapEntry<K, N, S> other, int entryVersion) {
 			this(other.key, other.namespace, other.state, other.hash, other.next, entryVersion, other.stateVersion);
 		}
 
-		StateTableEntry(
+		StateMapEntry(
 			@Nonnull K key,
 			@Nonnull N namespace,
 			@Nullable S state,
 			int hash,
-			@Nullable StateTableEntry<K, N, S> next,
+			@Nullable StateMapEntry<K, N, S> next,
 			int entryVersion,
 			int stateVersion) {
 			this.key = key;
@@ -1014,14 +893,14 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 		@Override
 		public final boolean equals(Object o) {
-			if (!(o instanceof CopyOnWriteStateTable.StateTableEntry)) {
+			if (!(o instanceof CopyOnWriteStateMap.StateMapEntry)) {
 				return false;
 			}
 
 			StateEntry<?, ?, ?> e = (StateEntry<?, ?, ?>) o;
 			return e.getKey().equals(key)
-					&& e.getNamespace().equals(namespace)
-					&& Objects.equals(e.getState(), state);
+				&& e.getNamespace().equals(namespace)
+				&& Objects.equals(e.getState(), state);
 		}
 
 		@Override
@@ -1052,60 +931,60 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	// StateEntryIterator  ---------------------------------------------------------------------------------------------
 
 	@Override
-	public StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords) {
+	public InternalKvState.StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords) {
 		return new StateIncrementalVisitorImpl(recommendedMaxNumberOfReturnedRecords);
 	}
 
 	/**
-	 * Iterator over state entry chains in a {@link CopyOnWriteStateTable}.
+	 * Iterator over state entry chains in a {@link CopyOnWriteStateMap}.
 	 */
-	class StateEntryChainIterator implements Iterator<StateTableEntry<K, N, S>> {
-		StateTableEntry<K, N, S>[] activeTable;
-		private int nextTablePosition;
-		private final int maxTraversedTablePositions;
+	class StateEntryChainIterator implements Iterator<StateMapEntry<K, N, S>> {
+		StateMapEntry<K, N, S>[] activeTable;
+		private int nextMapPosition;
+		private final int maxTraversedMapPositions;
 
 		StateEntryChainIterator() {
 			this(Integer.MAX_VALUE);
 		}
 
-		StateEntryChainIterator(int maxTraversedTablePositions) {
-			this.maxTraversedTablePositions = maxTraversedTablePositions;
+		StateEntryChainIterator(int maxTraversedMapPositions) {
+			this.maxTraversedMapPositions = maxTraversedMapPositions;
 			this.activeTable = primaryTable;
-			this.nextTablePosition = 0;
+			this.nextMapPosition = 0;
 		}
 
 		@Override
 		public boolean hasNext() {
-			return size() > 0 && (nextTablePosition < activeTable.length || activeTable == primaryTable);
+			return size() > 0 && (nextMapPosition < activeTable.length || activeTable == primaryTable);
 		}
 
 		@Override
-		public StateTableEntry<K, N, S> next() {
-			StateTableEntry<K, N, S> next;
+		public StateMapEntry<K, N, S> next() {
+			StateMapEntry<K, N, S> next;
 			// consider both sub-tables to cover the case of rehash
 			while (true) { // current is empty
 				// try get next in active table or
 				// iteration is done over primary and rehash table
 				// or primary was swapped with rehash when rehash is done
-				next = nextActiveTablePosition();
+				next = nextActiveMapPosition();
 				if (next != null ||
-					nextTablePosition < activeTable.length ||
+					nextMapPosition < activeTable.length ||
 					activeTable == incrementalRehashTable ||
 					activeTable != primaryTable) {
 					return next;
 				} else {
 					// switch to rehash (empty if no rehash)
 					activeTable = incrementalRehashTable;
-					nextTablePosition = 0;
+					nextMapPosition = 0;
 				}
 			}
 		}
 
-		private StateTableEntry<K, N, S> nextActiveTablePosition() {
-			StateTableEntry<K, N, S>[] tab = activeTable;
+		private StateMapEntry<K, N, S> nextActiveMapPosition() {
+			StateMapEntry<K, N, S>[] tab = activeTable;
 			int traversedPositions = 0;
-			while (nextTablePosition < tab.length && traversedPositions < maxTraversedTablePositions) {
-				StateTableEntry<K, N, S> next = tab[nextTablePosition++];
+			while (nextMapPosition < tab.length && traversedPositions < maxTraversedMapPositions) {
+				StateMapEntry<K, N, S> next = tab[nextMapPosition++];
 				if (next != null) {
 					return next;
 				}
@@ -1116,12 +995,12 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	}
 
 	/**
-	 * Iterator over state entries in a {@link CopyOnWriteStateTable} which does not tolerate concurrent modifications.
+	 * Iterator over state entries in a {@link CopyOnWriteStateMap} which does not tolerate concurrent modifications.
 	 */
 	class StateEntryIterator implements Iterator<StateEntry<K, N, S>> {
 
 		private final StateEntryChainIterator chainIterator;
-		private StateTableEntry<K, N, S> nextEntry;
+		private StateMapEntry<K, N, S> nextEntry;
 		private final int expectedModCount;
 
 		StateEntryIterator() {
@@ -1147,9 +1026,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 			return advanceIterator();
 		}
 
-		StateTableEntry<K, N, S> advanceIterator() {
-			StateTableEntry<K, N, S> entryToReturn = nextEntry;
-			StateTableEntry<K, N, S> next = nextEntry.next;
+		StateMapEntry<K, N, S> advanceIterator() {
+			StateMapEntry<K, N, S> entryToReturn = nextEntry;
+			StateMapEntry<K, N, S> next = nextEntry.next;
 			if (next == null) {
 				next = chainIterator.next();
 			}
@@ -1159,9 +1038,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	}
 
 	/**
-	 * Incremental visitor over state entries in a {@link CopyOnWriteStateTable}.
+	 * Incremental visitor over state entries in a {@link CopyOnWriteStateMap}.
 	 */
-	class StateIncrementalVisitorImpl implements StateIncrementalVisitor<K, N, S> {
+	class StateIncrementalVisitorImpl implements InternalKvState.StateIncrementalVisitor<K, N, S> {
 
 		private final StateEntryChainIterator chainIterator;
 		private final Collection<StateEntry<K, N, S>> chainToReturn = new ArrayList<>(5);
@@ -1182,9 +1061,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 			}
 
 			chainToReturn.clear();
-			for (StateTableEntry<K, N, S> nextEntry = chainIterator.next();
-				 nextEntry != null;
-				 nextEntry = nextEntry.next) {
+			for (StateMapEntry<K, N, S> nextEntry = chainIterator.next();
+					nextEntry != null;
+					nextEntry = nextEntry.next) {
 				chainToReturn.add(nextEntry);
 			}
 			return chainToReturn;
@@ -1192,12 +1071,12 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 		@Override
 		public void remove(StateEntry<K, N, S> stateEntry) {
-			CopyOnWriteStateTable.this.remove(stateEntry.getKey(), stateEntry.getNamespace());
+			CopyOnWriteStateMap.this.remove(stateEntry.getKey(), stateEntry.getNamespace());
 		}
 
 		@Override
 		public void update(StateEntry<K, N, S> stateEntry, S newValue) {
-			CopyOnWriteStateTable.this.put(stateEntry.getKey(), stateEntry.getNamespace(), newValue);
+			CopyOnWriteStateMap.this.put(stateEntry.getKey(), stateEntry.getNamespace(), newValue);
 		}
 	}
 }
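
The incremental visitor above is designed to be consumed in bounded batches, so that callers such as incremental TTL cleanup can process a few hash buckets per invocation instead of scanning the whole map at once. The following sketch shows one plausible consumption pattern; it assumes a CopyOnWriteStateMap<String, Integer, Long> named stateMap is already available (obtaining it goes through the keyed state backend and is not shown), that the map exposes its visitor via getStateIncrementalVisitor(int) as the old table did, and that the batch size of 5 is arbitrary.

    import org.apache.flink.runtime.state.StateEntry;
    import org.apache.flink.runtime.state.internal.InternalKvState;

    // Hedged sketch: drain a state map in small batches through its incremental visitor.
    final class IncrementalVisitSketch {

        static void bumpOrDrop(CopyOnWriteStateMap<String, Integer, Long> stateMap) {
            InternalKvState.StateIncrementalVisitor<String, Integer, Long> visitor =
                stateMap.getStateIncrementalVisitor(5);
            while (visitor.hasNext()) {
                for (StateEntry<String, Integer, Long> entry : visitor.nextEntries()) {
                    Long value = entry.getState();
                    if (value == null || value == 0L) {
                        visitor.remove(entry);            // drop empty state
                    } else {
                        visitor.update(entry, value + 1); // rewrite the value in place
                    }
                }
            }
        }
    }
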
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java
new file mode 100644
index 0000000..a39a9b0
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapSnapshot.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.StateEntry;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+
+/**
+ * This class represents the snapshot of a {@link CopyOnWriteStateMap}.
+ *
+ * <p>IMPORTANT: Please note that snapshot integrity of entries in this class relies on proper copy-on-write semantics
+ * through the {@link CopyOnWriteStateMap} that created the snapshot object, but all objects in this snapshot must be considered
+ * as READ-ONLY! The reason is that the objects held by this class may or may not be deep copies of original objects
+ * that may still be used in the {@link CopyOnWriteStateMap}. This depends, for each entry, on whether or not it was subject to
+ * copy-on-write operations by the {@link CopyOnWriteStateMap}. Phrased differently: the {@link CopyOnWriteStateMap} provides
+ * copy-on-write isolation for this snapshot, but this snapshot does not isolate modifications from the
+ * {@link CopyOnWriteStateMap}!
+ *
+ * @param <K> type of key
+ * @param <N> type of namespace
+ * @param <S> type of state
+ */
+public class CopyOnWriteStateMapSnapshot<K, N, S>
+	extends StateMapSnapshot<K, N, S, CopyOnWriteStateMap<K, N, S>> {
+
+	/**
+	 * Version of the {@link CopyOnWriteStateMap} when this snapshot was created. This can be used to release the snapshot.
+	 */
+	private final int snapshotVersion;
+
+	/**
+	 * The state map entries, as of the time this snapshot was created. Objects in this array may or may not be deep
+	 * copies of the current entries in the {@link CopyOnWriteStateMap} that created this snapshot. This depends for each entry
+	 * on whether or not it was subject to copy-on-write operations by the {@link CopyOnWriteStateMap}.
+	 */
+	@Nonnull
+	private final CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshotData;
+
+	/** The number of (non-null) entries in snapshotData. */
+	@Nonnegative
+	private final int numberOfEntriesInSnapshotData;
+
+	/**
+	 * Creates a new {@link CopyOnWriteStateMapSnapshot}.
+	 *
+	 * @param owningStateMap the {@link CopyOnWriteStateMap} for which this object represents a snapshot.
+	 */
+	CopyOnWriteStateMapSnapshot(CopyOnWriteStateMap<K, N, S> owningStateMap) {
+		super(owningStateMap);
+
+		this.snapshotData = owningStateMap.snapshotMapArrays();
+		this.snapshotVersion = owningStateMap.getStateMapVersion();
+		this.numberOfEntriesInSnapshotData = owningStateMap.size();
+	}
+
+	@Override
+	public void release() {
+		owningStateMap.releaseSnapshot(this);
+	}
+
+	/**
+	 * Returns the internal version of the {@link CopyOnWriteStateMap} when this snapshot was created. This value must be used to
+	 * tell the {@link CopyOnWriteStateMap} when to release this snapshot.
+	 */
+	int getSnapshotVersion() {
+		return snapshotVersion;
+	}
+
+	@Override
+	public void writeState(
+		TypeSerializer<K> keySerializer,
+		TypeSerializer<N> namespaceSerializer,
+		TypeSerializer<S> stateSerializer,
+		@Nonnull DataOutputView dov,
+		@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) throws IOException {
+		SnapshotIterator<K, N, S> snapshotIterator = stateSnapshotTransformer == null ?
+			new NonTransformSnapshotIterator<>(numberOfEntriesInSnapshotData, snapshotData) :
+			new TransformedSnapshotIterator<>(numberOfEntriesInSnapshotData, snapshotData, stateSnapshotTransformer);
+
+		int size = snapshotIterator.size();
+		dov.writeInt(size);
+		while (snapshotIterator.hasNext()) {
+			StateEntry<K, N, S> stateEntry = snapshotIterator.next();
+			namespaceSerializer.serialize(stateEntry.getNamespace(), dov);
+			keySerializer.serialize(stateEntry.getKey(), dov);
+			stateSerializer.serialize(stateEntry.getState(), dov);
+		}
+	}
+
+	/**
+	 * Iterator over state entries in a {@link CopyOnWriteStateMapSnapshot}.
+	 */
+	abstract static class SnapshotIterator<K, N, S> implements Iterator<StateEntry<K, N, S>> {
+
+		int numberOfEntriesInSnapshotData;
+
+		CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshotData;
+
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> chainIterator;
+
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> entryIterator;
+
+		SnapshotIterator(
+			int numberOfEntriesInSnapshotData,
+			CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshotData,
+			@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
+			this.numberOfEntriesInSnapshotData = numberOfEntriesInSnapshotData;
+			this.snapshotData = snapshotData;
+
+			transform(stateSnapshotTransformer);
+			this.chainIterator = getChainIterator();
+			this.entryIterator = Collections.emptyIterator();
+		}
+
+		/**
+		 * Return the number of state entries in this snapshot.
+		 */
+		abstract int size();
+
+		/**
+		 * Transform the state in the snapshot before iterating the state.
+		 */
+		abstract void transform(@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer);
+
+		/**
+		 * Return an iterator over the chains of entries in snapshotData.
+		 */
+		abstract Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getChainIterator();
+
+		/**
+		 * Return an iterator over the entries in the chain.
+		 *
+		 * @param stateMapEntry The head entry of the chain.
+		 */
+		abstract Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getEntryIterator(
+			CopyOnWriteStateMap.StateMapEntry<K, N, S> stateMapEntry);
+
+		@Override
+		public boolean hasNext() {
+			return entryIterator.hasNext() || chainIterator.hasNext();
+		}
+
+		@Override
+		public CopyOnWriteStateMap.StateMapEntry<K, N, S> next() {
+			if (entryIterator.hasNext()) {
+				return entryIterator.next();
+			}
+
+			CopyOnWriteStateMap.StateMapEntry<K, N, S> stateMapEntry = chainIterator.next();
+			entryIterator = getEntryIterator(stateMapEntry);
+			return entryIterator.next();
+		}
+	}
+
+	/**
+	 * Implementation of {@link SnapshotIterator} with no transform.
+	 */
+	static class NonTransformSnapshotIterator<K, N, S> extends SnapshotIterator<K, N, S> {
+
+		NonTransformSnapshotIterator(
+			int numberOfEntriesInSnapshotData,
+			CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshotData) {
+			super(numberOfEntriesInSnapshotData, snapshotData, null);
+		}
+
+		@Override
+		void transform(@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
+		}
+
+		@Override
+		public int size() {
+			return numberOfEntriesInSnapshotData;
+		}
+
+		@Override
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getChainIterator() {
+			return Arrays.stream(snapshotData).filter(Objects::nonNull).iterator();
+		}
+
+		@Override
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getEntryIterator(
+			final CopyOnWriteStateMap.StateMapEntry<K, N, S> stateMapEntry) {
+			return new Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>>() {
+
+				CopyOnWriteStateMap.StateMapEntry<K, N, S> nextEntry = stateMapEntry;
+
+				@Override
+				public boolean hasNext() {
+					return nextEntry != null;
+				}
+
+				@Override
+				public CopyOnWriteStateMap.StateMapEntry<K, N, S> next() {
+					if (nextEntry == null) {
+						throw new NoSuchElementException();
+					}
+					CopyOnWriteStateMap.StateMapEntry<K, N, S> entry = nextEntry;
+					nextEntry = nextEntry.next;
+					return entry;
+				}
+			};
+		}
+	}
+
+	/**
+	 * Implementation of {@link SnapshotIterator} with a {@link StateSnapshotTransformer}.
+	 */
+	static class TransformedSnapshotIterator<K, N, S> extends SnapshotIterator<K, N, S> {
+
+		TransformedSnapshotIterator(
+			int numberOfEntriesInSnapshotData,
+			CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshotData,
+			@Nonnull StateSnapshotTransformer<S> stateSnapshotTransformer) {
+			super(numberOfEntriesInSnapshotData, snapshotData, stateSnapshotTransformer);
+		}
+
+		/**
+		 * Move the chains in snapshotData to the back of the array, and return the
+		 * index of the first chain from the front.
+		 */
+		int moveChainsToBackOfArray() {
+			int index = snapshotData.length - 1;
+			// find the first null chain from the back
+			while (index >= 0) {
+				if (snapshotData[index] == null) {
+					break;
+				}
+				index--;
+			}
+
+			int lastNullIndex = index;
+			index--;
+			// move the chains to the back
+			while (index >= 0) {
+				CopyOnWriteStateMap.StateMapEntry<K, N, S> entry = snapshotData[index];
+				if (entry != null) {
+					snapshotData[lastNullIndex] = entry;
+					snapshotData[index] = null;
+					lastNullIndex--;
+				}
+				index--;
+			}
+			// return the index of the first chain from the front
+			return lastNullIndex + 1;
+		}
+
+		@Override
+		void transform(@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) {
+			Preconditions.checkNotNull(stateSnapshotTransformer);
+			int indexOfFirstChain = moveChainsToBackOfArray();
+			int count = 0;
+			// reuse the snapshotData to transform and flatten the entries.
+			for (int i = indexOfFirstChain; i < snapshotData.length; i++) {
+				CopyOnWriteStateMap.StateMapEntry<K, N, S> entry = snapshotData[i];
+				while (entry != null) {
+					S transformedValue = stateSnapshotTransformer.filterOrTransform(entry.state);
+					if (transformedValue != null) {
+						CopyOnWriteStateMap.StateMapEntry<K, N, S> filteredEntry = entry;
+						if (transformedValue != entry.state) {
+							filteredEntry = new CopyOnWriteStateMap.StateMapEntry<>(entry, entry.entryVersion);
+							filteredEntry.state = transformedValue;
+						}
+						snapshotData[count++] = filteredEntry;
+					}
+					entry = entry.next;
+				}
+			}
+			numberOfEntriesInSnapshotData = count;
+		}
+
+		@Override
+		public int size() {
+			return numberOfEntriesInSnapshotData;
+		}
+
+		@Override
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getChainIterator() {
+			return Arrays.stream(snapshotData, 0, numberOfEntriesInSnapshotData).iterator();
+		}
+
+		@Override
+		Iterator<CopyOnWriteStateMap.StateMapEntry<K, N, S>> getEntryIterator(
+			CopyOnWriteStateMap.StateMapEntry<K, N, S> stateMapEntry) {
+			return Collections.singleton(stateMapEntry).iterator();
+		}
+	}
+}
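
The transformed iterator above reuses the snapshot array for two passes: moveChainsToBackOfArray() first compacts the non-null chain heads towards the end of the array, which frees the front of the very same array as the destination into which transform() writes the flattened, filtered entries. Reusing the array this way avoids allocating a second buffer during the transformation. The standalone snippet below illustrates only the back-compaction step on a plain String[]; the class and method names are made up for the illustration and are not Flink API.

    import java.util.Arrays;

    // Minimal sketch of the back-compaction used by TransformedSnapshotIterator.
    public class CompactToBackDemo {

        static int moveToBack(String[] data) {
            int index = data.length - 1;
            // find the first null slot from the back
            while (index >= 0 && data[index] != null) {
                index--;
            }
            int lastNullIndex = index;
            index--;
            // shift the remaining non-null slots towards the back
            while (index >= 0) {
                if (data[index] != null) {
                    data[lastNullIndex] = data[index];
                    data[index] = null;
                    lastNullIndex--;
                }
                index--;
            }
            // index of the first non-null slot, mirroring indexOfFirstChain
            return lastNullIndex + 1;
        }

        public static void main(String[] args) {
            String[] chains = {"A", null, "B", null};
            int first = moveToBack(chains);
            // prints: first=2, chains=[null, null, A, B]
            System.out.println("first=" + first + ", chains=" + Arrays.toString(chains));
        }
    }
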
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
index 5852bc2..dd1eb0b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
@@ -18,198 +18,25 @@
 
 package org.apache.flink.runtime.state.heap;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.StateEntry;
-import org.apache.flink.runtime.state.StateTransformationFunction;
-import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
-import org.apache.flink.util.MathUtils;
-import org.apache.flink.util.Preconditions;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
 
 import java.util.ArrayList;
-import java.util.Collection;
-import java.util.ConcurrentModificationException;
-import java.util.Iterator;
-import java.util.NoSuchElementException;
-import java.util.Objects;
-import java.util.TreeSet;
-import java.util.stream.Stream;
-import java.util.stream.StreamSupport;
-
-import static org.apache.flink.util.CollectionUtil.MAX_ARRAY_SIZE;
+import java.util.List;
 
 /**
- * Implementation of Flink's in-memory state tables with copy-on-write support. This map does not support null values
- * for key or namespace.
- * <p>
- * {@link CopyOnWriteStateTable} sacrifices some peak performance and memory efficiency for features like incremental
- * rehashing and asynchronous snapshots through copy-on-write. Copy-on-write tries to minimize the amount of copying by
- * maintaining version meta data for both, the map structure and the state objects. However, we must often proactively
- * copy state objects when we hand them to the user.
- * <p>
- * As for any state backend, user should not keep references on state objects that they obtained from state backends
- * outside the scope of the user function calls.
- * <p>
- * Some brief maintenance notes:
- * <p>
- * 1) Flattening the underlying data structure from nested maps (namespace) -> (key) -> (state) to one flat map
- * (key, namespace) -> (state) brings certain performance trade-offs. In theory, the flat map has one less level of
- * indirection compared to the nested map. However, the nested map naturally de-duplicates namespace objects for which
- * #equals() is true. This leads to potentially a lot of redundant namespace objects for the flattened version. Those,
- * in turn, can again introduce more cache misses because we need to follow the namespace object on all operations to
- * ensure entry identities. Obviously, copy-on-write can also add memory overhead. So does the meta data to track
- * copy-on-write requirement (state and entry versions on {@link StateTableEntry}).
- * <p>
- * 2) A flat map structure is a lot easier when it comes to tracking copy-on-write of the map structure.
- * <p>
- * 3) Nested structure had the (never used) advantage that we can easily drop and iterate whole namespaces. This could
- * give locality advantages for certain access pattern, e.g. iterating a namespace.
- * <p>
- * 4) Serialization format is changed from namespace-prefix compressed (as naturally provided from the old nested
- * structure) to making all entries self contained as (key, namespace, state).
- * <p>
- * 5) We got rid of having multiple nested tables, one for each key-group. Instead, we partition state into key-groups
- * on-the-fly, during the asynchronous part of a snapshot.
- * <p>
- * 6) Currently, a state table can only grow, but never shrinks on low load. We could easily add this if required.
- * <p>
- * 7) Heap based state backends like this can easily cause a lot of GC activity. Besides using G1 as garbage collector,
- * we should provide an additional state backend that operates on off-heap memory. This would sacrifice peak performance
- * (due to de/serialization of objects) for a lower, but more constant throughput and potentially huge simplifications
- * w.r.t. copy-on-write.
- * <p>
- * 8) We could try a hybrid of a serialized and object based backends, where key and namespace of the entries are both
- * serialized in one byte-array.
- * <p>
- * 9) We could consider smaller types (e.g. short) for the version counting and think about some reset strategy before
- * overflows, when there is no snapshot running. However, this would have to touch all entries in the map.
- * <p>
- * This class was initially based on the {@link java.util.HashMap} implementation of the Android JDK, but is now heavily
- * customized towards the use case of table for state entries.
- *
- * IMPORTANT: the contracts for this class rely on the user not holding any references to objects returned by this map
- * beyond the life cycle of per-element operations. Or phrased differently, all get-update-put operations on a mapping
- * should be within one call of processElement. Otherwise, the user must take care of taking deep copies, e.g. for
- * caching purposes.
+ * This implementation of {@link StateTable} uses {@link CopyOnWriteStateMap} and supports asynchronous snapshots.
  *
  * @param <K> type of key.
  * @param <N> type of namespace.
- * @param <S> type of value.
+ * @param <S> type of state.
  */
-public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implements Iterable<StateEntry<K, N, S>> {
-
-	/**
-	 * The logger.
-	 */
-	private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class);
-
-	/**
-	 * Min capacity (other than zero) for a {@link CopyOnWriteStateTable}. Must be a power of two
-	 * greater than 1 (and less than 1 << 30).
-	 */
-	private static final int MINIMUM_CAPACITY = 4;
-
-	/**
-	 * Max capacity for a {@link CopyOnWriteStateTable}. Must be a power of two >= MINIMUM_CAPACITY.
-	 */
-	private static final int MAXIMUM_CAPACITY = 1 << 30;
-
-	/**
-	 * Default capacity for a {@link CopyOnWriteStateTable}. Must be a power of two,
-	 * greater than {@code MINIMUM_CAPACITY} and less than {@code MAXIMUM_CAPACITY}.
-	 */
-	public static final int DEFAULT_CAPACITY = 1024;
-
-	/**
-	 * Minimum number of entries that one step of incremental rehashing migrates from the old to the new sub-table.
-	 */
-	private static final int MIN_TRANSFERRED_PER_INCREMENTAL_REHASH = 4;
-
-	/**
-	 * An empty table shared by all zero-capacity maps (typically from default
-	 * constructor). It is never written to, and replaced on first put. Its size
-	 * is set to half the minimum, so that the first resize will create a
-	 * minimum-sized table.
-	 */
-	private static final StateTableEntry<?, ?, ?>[] EMPTY_TABLE = new StateTableEntry[MINIMUM_CAPACITY >>> 1];
-
-	/**
-	 * Empty entry that we use to bootstrap our {@link CopyOnWriteStateTable.StateEntryIterator}.
-	 */
-	private static final StateTableEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY =
-		new StateTableEntry<>(new Object(), new Object(), new Object(), 0, null, 0, 0);
-
-	/**
-	 * Maintains an ordered set of version ids that are still in use by unreleased snapshots.
-	 */
-	private final TreeSet<Integer> snapshotVersions;
-
-	/**
-	 * This is the primary entry array (hash directory) of the state table. If no incremental rehash is ongoing, this
-	 * is the only used table.
-	 **/
-	private StateTableEntry<K, N, S>[] primaryTable;
-
-	/**
-	 * We maintain a secondary entry array while performing an incremental rehash. The purpose is to slowly migrate
-	 * entries from the primary table to this resized table array. When all entries are migrated, this becomes the new
-	 * primary table.
-	 */
-	private StateTableEntry<K, N, S>[] incrementalRehashTable;
-
-	/**
-	 * The current number of mappings in the primary table.
-	 */
-	private int primaryTableSize;
-
-	/**
-	 * The current number of mappings in the rehash table.
-	 */
-	private int incrementalRehashTableSize;
-
-	/**
-	 * The next index for a step of incremental rehashing in the primary table.
-	 */
-	private int rehashIndex;
-
-	/**
-	 * The current version of this map. Used for copy-on-write mechanics.
-	 */
-	private int stateTableVersion;
-
-	/**
-	 * The highest version of this map that is still required by any unreleased snapshot.
-	 */
-	private int highestRequiredSnapshotVersion;
-
-	/**
-	 * The last namespace that was actually inserted. This is a small optimization to reduce duplicate namespace objects.
-	 */
-	private N lastNamespace;
-
-	/**
-	 * The {@link CopyOnWriteStateTable} is rehashed when its size exceeds this threshold.
-	 * The value of this field is generally .75 * capacity, except when
-	 * the capacity is zero, as described in the EMPTY_TABLE declaration
-	 * above.
-	 */
-	private int threshold;
-
-	/**
-	 * Incremented by "structural modifications" to allow (best effort)
-	 * detection of concurrent modification.
-	 */
-	private int modCount;
+public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> {
 
 	/**
-	 * Constructs a new {@code StateTable} with default capacity of {@code DEFAULT_CAPACITY}.
+	 * Constructs a new {@code CopyOnWriteStateTable}.
 	 *
 	 * @param keyContext    the key context.
 	 * @param metaInfo      the meta information, including the type serializer for state copy-on-write.
@@ -219,985 +46,39 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		InternalKeyContext<K> keyContext,
 		RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo,
 		TypeSerializer<K> keySerializer) {
-		this(keyContext, metaInfo, DEFAULT_CAPACITY, keySerializer);
-	}
-
-	/**
-	 * Constructs a new {@code StateTable} instance with the specified capacity.
-	 *
-	 * @param keyContext    the key context.
-	 * @param metaInfo      the meta information, including the type serializer for state copy-on-write.
-	 * @param capacity      the initial capacity of this hash map.
-	 * @param keySerializer the serializer of the key.
-	 * @throws IllegalArgumentException when the capacity is less than zero.
-	 */
-	@SuppressWarnings("unchecked")
-	private CopyOnWriteStateTable(
-		InternalKeyContext<K> keyContext,
-		RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo,
-		int capacity,
-		TypeSerializer<K> keySerializer) {
 		super(keyContext, metaInfo, keySerializer);
-
-		// initialized tables to EMPTY_TABLE.
-		this.primaryTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
-		this.incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
-
-		// initialize sizes to 0.
-		this.primaryTableSize = 0;
-		this.incrementalRehashTableSize = 0;
-
-		this.rehashIndex = 0;
-		this.stateTableVersion = 0;
-		this.highestRequiredSnapshotVersion = 0;
-		this.snapshotVersions = new TreeSet<>();
-
-		if (capacity < 0) {
-			throw new IllegalArgumentException("Capacity: " + capacity);
-		}
-
-		if (capacity == 0) {
-			threshold = -1;
-			return;
-		}
-
-		if (capacity < MINIMUM_CAPACITY) {
-			capacity = MINIMUM_CAPACITY;
-		} else if (capacity > MAXIMUM_CAPACITY) {
-			capacity = MAXIMUM_CAPACITY;
-		} else {
-			capacity = MathUtils.roundUpToPowerOfTwo(capacity);
-		}
-		primaryTable = makeTable(capacity);
-	}
-
-	// Public API from AbstractStateTable ------------------------------------------------------------------------------
-
-	/**
-	 * Returns the total number of entries in this {@link CopyOnWriteStateTable}. This is the sum of both sub-tables.
-	 *
-	 * @return the number of entries in this {@link CopyOnWriteStateTable}.
-	 */
-	@Override
-	public int size() {
-		return primaryTableSize + incrementalRehashTableSize;
-	}
-
-	@Override
-	public S get(K key, N namespace) {
-
-		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final int requiredVersion = highestRequiredSnapshotVersion;
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
-		int index = hash & (tab.length - 1);
-
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
-			final K eKey = e.key;
-			final N eNamespace = e.namespace;
-			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
-
-				// copy-on-write check for state
-				if (e.stateVersion < requiredVersion) {
-					// copy-on-write check for entry
-					if (e.entryVersion < requiredVersion) {
-						e = handleChainedEntryCopyOnWrite(tab, hash & (tab.length - 1), e);
-					}
-					e.stateVersion = stateTableVersion;
-					e.state = getStateSerializer().copy(e.state);
-				}
-
-				return e.state;
-			}
-		}
-
-		return null;
-	}
-
-	@Override
-	public Stream<K> getKeys(N namespace) {
-		return StreamSupport.stream(spliterator(), false)
-			.filter(entry -> entry.getNamespace().equals(namespace))
-			.map(StateEntry::getKey);
-	}
-
-	@Override
-	public void put(K key, int keyGroup, N namespace, S state) {
-		put(key, namespace, state);
-	}
-
-	@Override
-	public S get(N namespace) {
-		return get(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public boolean containsKey(N namespace) {
-		return containsKey(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public void put(N namespace, S state) {
-		put(keyContext.getCurrentKey(), namespace, state);
-	}
-
-	@Override
-	public S putAndGetOld(N namespace, S state) {
-		return putAndGetOld(keyContext.getCurrentKey(), namespace, state);
-	}
-
-	@Override
-	public void remove(N namespace) {
-		remove(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public S removeAndGetOld(N namespace) {
-		return removeAndGetOld(keyContext.getCurrentKey(), namespace);
-	}
-
-	@Override
-	public <T> void transform(N namespace, T value, StateTransformationFunction<S, T> transformation) throws Exception {
-		transform(keyContext.getCurrentKey(), namespace, value, transformation);
-	}
-
-	// Private implementation details of the API methods ---------------------------------------------------------------
-
-	/**
-	 * Returns whether this table contains the specified key/namespace composite key.
-	 *
-	 * @param key       the key in the composite key to search for. Not null.
-	 * @param namespace the namespace in the composite key to search for. Not null.
-	 * @return {@code true} if this map contains the specified key/namespace composite key,
-	 * {@code false} otherwise.
-	 */
-	boolean containsKey(K key, N namespace) {
-
-		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
-		int index = hash & (tab.length - 1);
-
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
-			final K eKey = e.key;
-			final N eNamespace = e.namespace;
-
-			if ((e.hash == hash && key.equals(eKey) && namespace.equals(eNamespace))) {
-				return true;
-			}
-		}
-		return false;
-	}
-
-	/**
-	 * Maps the specified key/namespace composite key to the specified value. This method should be preferred
-	 * over {@link #putAndGetOld(Object, Object, Object)} (Object, Object)} when the caller is not interested
-	 * in the old value, because this can potentially reduce copy-on-write activity.
-	 *
-	 * @param key       the key. Not null.
-	 * @param namespace the namespace. Not null.
-	 * @param value     the value. Can be null.
-	 */
-	void put(K key, N namespace, S value) {
-		final StateTableEntry<K, N, S> e = putEntry(key, namespace);
-
-		e.state = value;
-		e.stateVersion = stateTableVersion;
-	}
-
-	/**
-	 * Maps the specified key/namespace composite key to the specified value. Returns the previous state that was
-	 * registered under the composite key.
-	 *
-	 * @param key       the key. Not null.
-	 * @param namespace the namespace. Not null.
-	 * @param value     the value. Can be null.
-	 * @return the value of any previous mapping with the specified key or
-	 * {@code null} if there was no such mapping.
-	 */
-	S putAndGetOld(K key, N namespace, S value) {
-
-		final StateTableEntry<K, N, S> e = putEntry(key, namespace);
-
-		// copy-on-write check for state
-		S oldState = (e.stateVersion < highestRequiredSnapshotVersion) ?
-				getStateSerializer().copy(e.state) :
-				e.state;
-
-		e.state = value;
-		e.stateVersion = stateTableVersion;
-
-		return oldState;
-	}
-
-	/**
-	 * Removes the mapping with the specified key/namespace composite key from this map. This method should be preferred
-	 * over {@link #removeAndGetOld(Object, Object)} when the caller is not interested in the old value, because this
-	 * can potentially reduce copy-on-write activity.
-	 *
-	 * @param key       the key of the mapping to remove. Not null.
-	 * @param namespace the namespace of the mapping to remove. Not null.
-	 */
-	void remove(K key, N namespace) {
-		removeEntry(key, namespace);
-	}
-
-	/**
-	 * Removes the mapping with the specified key/namespace composite key from this map, returning the state that was
-	 * found under the entry.
-	 *
-	 * @param key       the key of the mapping to remove. Not null.
-	 * @param namespace the namespace of the mapping to remove. Not null.
-	 * @return the value of the removed mapping or {@code null} if no mapping
-	 * for the specified key was found.
-	 */
-	S removeAndGetOld(K key, N namespace) {
-
-		final StateTableEntry<K, N, S> e = removeEntry(key, namespace);
-
-		return e != null ?
-				// copy-on-write check for state
-				(e.stateVersion < highestRequiredSnapshotVersion ?
-						getStateSerializer().copy(e.state) :
-						e.state) :
-				null;
-	}
-
-	/**
-	 * @param key            the key of the mapping to remove. Not null.
-	 * @param namespace      the namespace of the mapping to remove. Not null.
-	 * @param value          the value that is the second input for the transformation.
-	 * @param transformation the transformation function to apply on the old state and the given value.
-	 * @param <T>            type of the value that is the second input to the {@link StateTransformationFunction}.
-	 * @throws Exception exception that happen on applying the function.
-	 * @see #transform(Object, Object, StateTransformationFunction).
-	 */
-	<T> void transform(
-			K key,
-			N namespace,
-			T value,
-			StateTransformationFunction<S, T> transformation) throws Exception {
-
-		final StateTableEntry<K, N, S> entry = putEntry(key, namespace);
-
-		// copy-on-write check for state
-		entry.state = transformation.apply(
-				(entry.stateVersion < highestRequiredSnapshotVersion) ?
-						getStateSerializer().copy(entry.state) :
-						entry.state,
-				value);
-		entry.stateVersion = stateTableVersion;
-	}
-
-	/**
-	 * Helper method that is the basis for operations that add mappings.
-	 */
-	private StateTableEntry<K, N, S> putEntry(K key, N namespace) {
-
-		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
-		int index = hash & (tab.length - 1);
-
-		for (StateTableEntry<K, N, S> e = tab[index]; e != null; e = e.next) {
-			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
-
-				// copy-on-write check for entry
-				if (e.entryVersion < highestRequiredSnapshotVersion) {
-					e = handleChainedEntryCopyOnWrite(tab, index, e);
-				}
-
-				return e;
-			}
-		}
-
-		++modCount;
-		if (size() > threshold) {
-			doubleCapacity();
-		}
-
-		return addNewStateTableEntry(tab, key, namespace, hash);
-	}
-
-	/**
-	 * Helper method that is the basis for operations that remove mappings.
-	 */
-	private StateTableEntry<K, N, S> removeEntry(K key, N namespace) {
-
-		final int hash = computeHashForOperationAndDoIncrementalRehash(key, namespace);
-		final StateTableEntry<K, N, S>[] tab = selectActiveTable(hash);
-		int index = hash & (tab.length - 1);
-
-		for (StateTableEntry<K, N, S> e = tab[index], prev = null; e != null; prev = e, e = e.next) {
-			if (e.hash == hash && key.equals(e.key) && namespace.equals(e.namespace)) {
-				if (prev == null) {
-					tab[index] = e.next;
-				} else {
-					// copy-on-write check for entry
-					if (prev.entryVersion < highestRequiredSnapshotVersion) {
-						prev = handleChainedEntryCopyOnWrite(tab, index, prev);
-					}
-					prev.next = e.next;
-				}
-				++modCount;
-				if (tab == primaryTable) {
-					--primaryTableSize;
-				} else {
-					--incrementalRehashTableSize;
-				}
-				return e;
-			}
-		}
-		return null;
-	}
-
-	private void checkKeyNamespacePreconditions(K key, N namespace) {
-		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
-		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
-	}
-
-	// Meta data setter / getter and toString --------------------------------------------------------------------------
-
-	@Override
-	public TypeSerializer<S> getStateSerializer() {
-		return metaInfo.getStateSerializer();
-	}
-
-	@Override
-	public TypeSerializer<N> getNamespaceSerializer() {
-		return metaInfo.getNamespaceSerializer();
-	}
-
-	@Override
-	public RegisteredKeyValueStateBackendMetaInfo<N, S> getMetaInfo() {
-		return metaInfo;
 	}
 
 	@Override
-	public void setMetaInfo(RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo) {
-		this.metaInfo = metaInfo;
-	}
-
-	// Iteration  ------------------------------------------------------------------------------------------------------
-
-	@Nonnull
-	@Override
-	public Iterator<StateEntry<K, N, S>> iterator() {
-		return new StateEntryIterator();
-	}
-
-	// Private utility functions for StateTable management -------------------------------------------------------------
-
-	/**
-	 * @see #releaseSnapshot(CopyOnWriteStateTableSnapshot)
-	 */
-	@VisibleForTesting
-	void releaseSnapshot(int snapshotVersion) {
-		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
-		// Only stale reads of from the result of #releaseSnapshot calls are ok.
-		synchronized (snapshotVersions) {
-			Preconditions.checkState(snapshotVersions.remove(snapshotVersion), "Attempt to release unknown snapshot version");
-			highestRequiredSnapshotVersion = snapshotVersions.isEmpty() ? 0 : snapshotVersions.last();
-		}
-	}
-
-	/**
-	 * Creates (combined) copy of the table arrays for a snapshot. This method must be called by the same Thread that
-	 * does modifications to the {@link CopyOnWriteStateTable}.
-	 */
-	@VisibleForTesting
-	@SuppressWarnings("unchecked")
-	StateTableEntry<K, N, S>[] snapshotTableArrays() {
-
-		// we guard against concurrent modifications of highestRequiredSnapshotVersion between snapshot and release.
-		// Only stale reads of from the result of #releaseSnapshot calls are ok. This is why we must call this method
-		// from the same thread that does all the modifications to the table.
-		synchronized (snapshotVersions) {
-
-			// increase the table version for copy-on-write and register the snapshot
-			if (++stateTableVersion < 0) {
-				// this is just a safety net against overflows, but should never happen in practice (i.e., only after 2^31 snapshots)
-				throw new IllegalStateException("Version count overflow in CopyOnWriteStateTable. Enforcing restart.");
-			}
-
-			highestRequiredSnapshotVersion = stateTableVersion;
-			snapshotVersions.add(highestRequiredSnapshotVersion);
-		}
-
-		StateTableEntry<K, N, S>[] table = primaryTable;
-
-		// In order to reuse the copied array as the destination array for the partitioned records in
-		// CopyOnWriteStateTableSnapshot#partitionByKeyGroup(), we need to make sure that the copied array
-		// is big enough to hold the flattened entries. In fact, given the current rehashing algorithm, we only
-		// need to do this check when isRehashing() is false, but in order to get a more robust code(in case that
-		// the rehashing algorithm may changed in the future), we do this check for all the case.
-		final int totalTableIndexSize = rehashIndex + table.length;
-		final int copiedArraySize = Math.max(totalTableIndexSize, size());
-		final StateTableEntry<K, N, S>[] copy = new StateTableEntry[copiedArraySize];
-
-		if (isRehashing()) {
-			// consider both tables for the snapshot, the rehash index tells us which part of the two tables we need
-			final int localRehashIndex = rehashIndex;
-			final int localCopyLength = table.length - localRehashIndex;
-			// for the primary table, take every index >= rhIdx.
-			System.arraycopy(table, localRehashIndex, copy, 0, localCopyLength);
-
-			// for the new table, we are sure that two regions contain all the entries:
-			// [0, rhIdx[ AND [table.length / 2, table.length / 2 + rhIdx[
-			table = incrementalRehashTable;
-			System.arraycopy(table, 0, copy, localCopyLength, localRehashIndex);
-			System.arraycopy(table, table.length >>> 1, copy, localCopyLength + localRehashIndex, localRehashIndex);
-		} else {
-			// we only need to copy the primary table
-			System.arraycopy(table, 0, copy, 0, table.length);
-		}
-
-		return copy;
-	}
-
-	/**
-	 * Allocate a table of the given capacity and set the threshold accordingly.
-	 *
-	 * @param newCapacity must be a power of two
-	 */
-	private StateTableEntry<K, N, S>[] makeTable(int newCapacity) {
-
-		if (newCapacity < MAXIMUM_CAPACITY) {
-			threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
-		} else {
-			if (size() > MAX_ARRAY_SIZE) {
-
-				throw new IllegalStateException("Maximum capacity of CopyOnWriteStateTable is reached and the job " +
-					"cannot continue. Please consider scaling-out your job or using a different keyed state backend " +
-					"implementation!");
-			} else {
-
-				LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can " +
-					"lead to more collisions and lower performance. Please consider scaling-out your job or using a " +
-					"different keyed state backend implementation!");
-				threshold = MAX_ARRAY_SIZE;
-			}
-		}
-
-		@SuppressWarnings("unchecked") StateTableEntry<K, N, S>[] newTable
-				= (StateTableEntry<K, N, S>[]) new StateTableEntry[newCapacity];
-		return newTable;
-	}
-
-	/**
-	 * Creates and inserts a new {@link StateTableEntry}.
-	 */
-	private StateTableEntry<K, N, S> addNewStateTableEntry(
-			StateTableEntry<K, N, S>[] table,
-			K key,
-			N namespace,
-			int hash) {
-
-		// small optimization that aims to avoid holding references on duplicate namespace objects
-		if (namespace.equals(lastNamespace)) {
-			namespace = lastNamespace;
-		} else {
-			lastNamespace = namespace;
-		}
-
-		int index = hash & (table.length - 1);
-		StateTableEntry<K, N, S> newEntry = new StateTableEntry<>(
-				key,
-				namespace,
-				null,
-				hash,
-				table[index],
-				stateTableVersion,
-				stateTableVersion);
-		table[index] = newEntry;
-
-		if (table == primaryTable) {
-			++primaryTableSize;
-		} else {
-			++incrementalRehashTableSize;
-		}
-		return newEntry;
-	}
-
-	/**
-	 * Select the sub-table which is responsible for entries with the given hash code.
-	 *
-	 * @param hashCode the hash code which we use to decide about the table that is responsible.
-	 * @return the index of the sub-table that is responsible for the entry with the given hash code.
-	 */
-	private StateTableEntry<K, N, S>[] selectActiveTable(int hashCode) {
-		return (hashCode & (primaryTable.length - 1)) >= rehashIndex ? primaryTable : incrementalRehashTable;
-	}
-
-	/**
-	 * Doubles the capacity of the hash table. Existing entries are placed in
-	 * the correct bucket on the enlarged table. If the current capacity is,
-	 * MAXIMUM_CAPACITY, this method is a no-op. Returns the table, which
-	 * will be new unless we were already at MAXIMUM_CAPACITY.
-	 */
-	private void doubleCapacity() {
-
-		// There can only be one rehash in flight. From the amount of incremental rehash steps we take, this should always hold.
-		Preconditions.checkState(!isRehashing(), "There is already a rehash in progress.");
-
-		StateTableEntry<K, N, S>[] oldTable = primaryTable;
-
-		int oldCapacity = oldTable.length;
-
-		if (oldCapacity == MAXIMUM_CAPACITY) {
-			return;
-		}
-
-		incrementalRehashTable = makeTable(oldCapacity * 2);
-	}
-
-	/**
-	 * Returns true, if an incremental rehash is in progress.
-	 */
-	@VisibleForTesting
-	boolean isRehashing() {
-		// if we rehash, the secondary table is not empty
-		return EMPTY_TABLE != incrementalRehashTable;
-	}
-
-	/**
-	 * Computes the hash for the composite of key and namespace and performs some steps of incremental rehash if
-	 * incremental rehashing is in progress.
-	 */
-	private int computeHashForOperationAndDoIncrementalRehash(K key, N namespace) {
-
-		checkKeyNamespacePreconditions(key, namespace);
-
-		if (isRehashing()) {
-			incrementalRehash();
-		}
-
-		return compositeHash(key, namespace);
-	}
-
-	/**
-	 * Runs a number of steps for incremental rehashing.
-	 */
-	@SuppressWarnings("unchecked")
-	private void incrementalRehash() {
-
-		StateTableEntry<K, N, S>[] oldTable = primaryTable;
-		StateTableEntry<K, N, S>[] newTable = incrementalRehashTable;
-
-		int oldCapacity = oldTable.length;
-		int newMask = newTable.length - 1;
-		int requiredVersion = highestRequiredSnapshotVersion;
-		int rhIdx = rehashIndex;
-		int transferred = 0;
-
-		// we migrate a certain minimum amount of entries from the old to the new table
-		while (transferred < MIN_TRANSFERRED_PER_INCREMENTAL_REHASH) {
-
-			StateTableEntry<K, N, S> e = oldTable[rhIdx];
-
-			while (e != null) {
-				// copy-on-write check for entry
-				if (e.entryVersion < requiredVersion) {
-					e = new StateTableEntry<>(e, stateTableVersion);
-				}
-				StateTableEntry<K, N, S> n = e.next;
-				int pos = e.hash & newMask;
-				e.next = newTable[pos];
-				newTable[pos] = e;
-				e = n;
-				++transferred;
-			}
-
-			oldTable[rhIdx] = null;
-			if (++rhIdx == oldCapacity) {
-				//here, the rehash is complete and we release resources and reset fields
-				primaryTable = newTable;
-				incrementalRehashTable = (StateTableEntry<K, N, S>[]) EMPTY_TABLE;
-				primaryTableSize += incrementalRehashTableSize;
-				incrementalRehashTableSize = 0;
-				rehashIndex = 0;
-				return;
-			}
-		}
-
-		// sync our local bookkeeping the with official bookkeeping fields
-		primaryTableSize -= transferred;
-		incrementalRehashTableSize += transferred;
-		rehashIndex = rhIdx;
-	}
-
-	/**
-	 * Perform copy-on-write for entry chains. We iterate the (hopefully and probably) still cached chain, replace
-	 * all links up to the 'untilEntry', which we actually wanted to modify.
-	 */
-	private StateTableEntry<K, N, S> handleChainedEntryCopyOnWrite(
-			StateTableEntry<K, N, S>[] tab,
-			int tableIdx,
-			StateTableEntry<K, N, S> untilEntry) {
-
-		final int required = highestRequiredSnapshotVersion;
-
-		StateTableEntry<K, N, S> current = tab[tableIdx];
-		StateTableEntry<K, N, S> copy;
-
-		if (current.entryVersion < required) {
-			copy = new StateTableEntry<>(current, stateTableVersion);
-			tab[tableIdx] = copy;
-		} else {
-			// nothing to do, just advance copy to current
-			copy = current;
-		}
-
-		// we iterate the chain up to 'until entry'
-		while (current != untilEntry) {
-
-			//advance current
-			current = current.next;
-
-			if (current.entryVersion < required) {
-				// copy and advance the current's copy
-				copy.next = new StateTableEntry<>(current, stateTableVersion);
-				copy = copy.next;
-			} else {
-				// nothing to do, just advance copy to current
-				copy = current;
-			}
-		}
-
-		return copy;
-	}
-
-	@SuppressWarnings("unchecked")
-	private static <K, N, S> StateTableEntry<K, N, S> getBootstrapEntry() {
-		return (StateTableEntry<K, N, S>) ITERATOR_BOOTSTRAP_ENTRY;
-	}
-
-	/**
-	 * Helper function that creates and scrambles a composite hash for key and namespace.
-	 */
-	private static int compositeHash(Object key, Object namespace) {
-		// create composite key through XOR, then apply some bit-mixing for better distribution of skewed keys.
-		return MathUtils.bitMix(key.hashCode() ^ namespace.hashCode());
+	protected CopyOnWriteStateMap<K, N, S> createStateMap() {
+		return new CopyOnWriteStateMap<>(getStateSerializer());
 	}
 
 	// Snapshotting ----------------------------------------------------------------------------------------------------
 
-	int getStateTableVersion() {
-		return stateTableVersion;
-	}
-
 	/**
-	 * Creates a snapshot of this {@link CopyOnWriteStateTable}, to be written in checkpointing. The snapshot integrity
-	 * is protected through copy-on-write from the {@link CopyOnWriteStateTable}. Users should call
-	 * {@link #releaseSnapshot(CopyOnWriteStateTableSnapshot)} after using the returned object.
+	 * Creates a snapshot of this {@link CopyOnWriteStateTable}, to be written in checkpointing.
 	 *
 	 * @return a snapshot from this {@link CopyOnWriteStateTable}, for checkpointing.
 	 */
 	@Nonnull
 	@Override
 	public CopyOnWriteStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new CopyOnWriteStateTableSnapshot<>(this);
+		return new CopyOnWriteStateTableSnapshot<>(
+			this,
+			getKeySerializer().duplicate(),
+			getNamespaceSerializer().duplicate(),
+			getStateSerializer().duplicate(),
+			getMetaInfo().getStateSnapshotTransformFactory().createForDeserializedState().orElse(null));
 	}
 
-	/**
-	 * Releases a snapshot for this {@link CopyOnWriteStateTable}. This method should be called once a snapshot is no more needed,
-	 * so that the {@link CopyOnWriteStateTable} can stop considering this snapshot for copy-on-write, thus avoiding unnecessary
-	 * object creation.
-	 *
-	 * @param snapshotToRelease the snapshot to release, which was previously created by this state table.
-	 */
-	void releaseSnapshot(CopyOnWriteStateTableSnapshot<K, N, S> snapshotToRelease) {
-
-		Preconditions.checkArgument(snapshotToRelease.isOwner(this),
-				"Cannot release snapshot which is owned by a different state table.");
-
-		releaseSnapshot(snapshotToRelease.getSnapshotVersion());
-	}
-
-	// StateTableEntry -------------------------------------------------------------------------------------------------
-
-	/**
-	 * One entry in the {@link CopyOnWriteStateTable}. This is a triplet of key, namespace, and state. Thereby, key and
-	 * namespace together serve as a composite key for the state. This class also contains some management meta data for
-	 * copy-on-write, a pointer to link other {@link StateTableEntry}s to a list, and cached hash code.
-	 *
-	 * @param <K> type of key.
-	 * @param <N> type of namespace.
-	 * @param <S> type of state.
-	 */
-	@VisibleForTesting
-	protected static class StateTableEntry<K, N, S> implements StateEntry<K, N, S> {
-
-		/**
-		 * The key. Assumed to be immutable and not null.
-		 */
-		@Nonnull
-		final K key;
-
-		/**
-		 * The namespace. Assumed to be immutable and not null.
-		 */
-		@Nonnull
-		final N namespace;
-
-		/**
-		 * The state. This is not final to allow exchanging the object for copy-on-write. Can be null.
-		 */
-		@Nullable
-		S state;
-
-		/**
-		 * Link to another {@link StateTableEntry}. This is used to resolve collisions in the
-		 * {@link CopyOnWriteStateTable} through chaining.
-		 */
-		@Nullable
-		StateTableEntry<K, N, S> next;
-
-		/**
-		 * The version of this {@link StateTableEntry}. This is meta data for copy-on-write of the table structure.
-		 */
-		int entryVersion;
-
-		/**
-		 * The version of the state object in this entry. This is meta data for copy-on-write of the state object itself.
-		 */
-		int stateVersion;
-
-		/**
-		 * The computed secondary hash for the composite of key and namespace.
-		 */
-		final int hash;
-
-		StateTableEntry(StateTableEntry<K, N, S> other, int entryVersion) {
-			this(other.key, other.namespace, other.state, other.hash, other.next, entryVersion, other.stateVersion);
-		}
-
-		StateTableEntry(
-			@Nonnull K key,
-			@Nonnull N namespace,
-			@Nullable S state,
-			int hash,
-			@Nullable StateTableEntry<K, N, S> next,
-			int entryVersion,
-			int stateVersion) {
-			this.key = key;
-			this.namespace = namespace;
-			this.hash = hash;
-			this.next = next;
-			this.entryVersion = entryVersion;
-			this.state = state;
-			this.stateVersion = stateVersion;
-		}
-
-		public final void setState(@Nullable S value, int mapVersion) {
-			// naturally, we can update the state version every time we replace the old state with a different object
-			if (value != state) {
-				this.state = value;
-				this.stateVersion = mapVersion;
-			}
-		}
-
-		@Nonnull
-		@Override
-		public K getKey() {
-			return key;
-		}
-
-		@Nonnull
-		@Override
-		public N getNamespace() {
-			return namespace;
-		}
-
-		@Nullable
-		@Override
-		public S getState() {
-			return state;
-		}
-
-		@Override
-		public final boolean equals(Object o) {
-			if (!(o instanceof CopyOnWriteStateTable.StateTableEntry)) {
-				return false;
-			}
-
-			StateEntry<?, ?, ?> e = (StateEntry<?, ?, ?>) o;
-			return e.getKey().equals(key)
-					&& e.getNamespace().equals(namespace)
-					&& Objects.equals(e.getState(), state);
-		}
-
-		@Override
-		public final int hashCode() {
-			return (key.hashCode() ^ namespace.hashCode()) ^ Objects.hashCode(state);
-		}
-
-		@Override
-		public final String toString() {
-			return "(" + key + "|" + namespace + ")=" + state;
-		}
-	}
-
-	// For testing  ----------------------------------------------------------------------------------------------------
-
-	@Override
-	public int sizeOfNamespace(Object namespace) {
-		int count = 0;
-		for (StateEntry<K, N, S> entry : this) {
-			if (null != entry && namespace.equals(entry.getNamespace())) {
-				++count;
-			}
-		}
-		return count;
-	}
-
-
-	// StateEntryIterator  ---------------------------------------------------------------------------------------------
-
-	@Override
-	public StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords) {
-		return new StateIncrementalVisitorImpl(recommendedMaxNumberOfReturnedRecords);
-	}
-
-	/**
-	 * Iterator over state entry chains in a {@link CopyOnWriteStateTable}.
-	 */
-	class StateEntryChainIterator implements Iterator<StateTableEntry<K, N, S>> {
-		StateTableEntry<K, N, S>[] activeTable;
-		private int nextTablePosition;
-		private final int maxTraversedTablePositions;
-
-		StateEntryChainIterator() {
-			this(Integer.MAX_VALUE);
-		}
-
-		StateEntryChainIterator(int maxTraversedTablePositions) {
-			this.maxTraversedTablePositions = maxTraversedTablePositions;
-			this.activeTable = primaryTable;
-			this.nextTablePosition = 0;
-		}
-
-		@Override
-		public boolean hasNext() {
-			return size() > 0 && (nextTablePosition < activeTable.length || activeTable == primaryTable);
-		}
-
-		@Override
-		public StateTableEntry<K, N, S> next() {
-			StateTableEntry<K, N, S> next;
-			// consider both sub-tables to cover the case of rehash
-			while (true) { // current is empty
-				// try get next in active table or
-				// iteration is done over primary and rehash table
-				// or primary was swapped with rehash when rehash is done
-				next = nextActiveTablePosition();
-				if (next != null ||
-					nextTablePosition < activeTable.length ||
-					activeTable == incrementalRehashTable ||
-					activeTable != primaryTable) {
-					return next;
-				} else {
-					// switch to rehash (empty if no rehash)
-					activeTable = incrementalRehashTable;
-					nextTablePosition = 0;
-				}
-			}
-		}
-
-		private StateTableEntry<K, N, S> nextActiveTablePosition() {
-			StateTableEntry<K, N, S>[] tab = activeTable;
-			int traversedPositions = 0;
-			while (nextTablePosition < tab.length && traversedPositions < maxTraversedTablePositions) {
-				StateTableEntry<K, N, S> next = tab[nextTablePosition++];
-				if (next != null) {
-					return next;
-				}
-				traversedPositions++;
-			}
-			return null;
-		}
-	}
-
-	/**
-	 * Iterator over state entries in a {@link CopyOnWriteStateTable} which does not tolerate concurrent modifications.
-	 */
-	class StateEntryIterator implements Iterator<StateEntry<K, N, S>> {
-
-		private final StateEntryChainIterator chainIterator;
-		private StateTableEntry<K, N, S> nextEntry;
-		private final int expectedModCount;
-
-		StateEntryIterator() {
-			this.chainIterator = new StateEntryChainIterator();
-			this.expectedModCount = modCount;
-			this.nextEntry = getBootstrapEntry();
-			advanceIterator();
-		}
-
-		@Override
-		public boolean hasNext() {
-			return nextEntry != null;
-		}
-
-		@Override
-		public StateEntry<K, N, S> next() {
-			if (modCount != expectedModCount) {
-				throw new ConcurrentModificationException();
-			}
-			if (!hasNext()) {
-				throw new NoSuchElementException();
-			}
-			return advanceIterator();
-		}
-
-		StateTableEntry<K, N, S> advanceIterator() {
-			StateTableEntry<K, N, S> entryToReturn = nextEntry;
-			StateTableEntry<K, N, S> next = nextEntry.next;
-			if (next == null) {
-				next = chainIterator.next();
-			}
-			nextEntry = next;
-			return entryToReturn;
-		}
-	}
-
-	/**
-	 * Incremental visitor over state entries in a {@link CopyOnWriteStateTable}.
-	 */
-	class StateIncrementalVisitorImpl implements StateIncrementalVisitor<K, N, S> {
-
-		private final StateEntryChainIterator chainIterator;
-		private final Collection<StateEntry<K, N, S>> chainToReturn = new ArrayList<>(5);
-
-		StateIncrementalVisitorImpl(int recommendedMaxNumberOfReturnedRecords) {
-			chainIterator = new StateEntryChainIterator(recommendedMaxNumberOfReturnedRecords);
-		}
-
-		@Override
-		public boolean hasNext() {
-			return chainIterator.hasNext();
-		}
-
-		@Override
-		public Collection<StateEntry<K, N, S>> nextEntries() {
-			if (!hasNext()) {
-				return null;
-			}
-
-			chainToReturn.clear();
-			for (StateTableEntry<K, N, S> nextEntry = chainIterator.next();
-				 nextEntry != null;
-				 nextEntry = nextEntry.next) {
-				chainToReturn.add(nextEntry);
-			}
-			return chainToReturn;
-		}
-
-		@Override
-		public void remove(StateEntry<K, N, S> stateEntry) {
-			CopyOnWriteStateTable.this.remove(stateEntry.getKey(), stateEntry.getNamespace());
-		}
-
-		@Override
-		public void update(StateEntry<K, N, S> stateEntry, S newValue) {
-			CopyOnWriteStateTable.this.put(stateEntry.getKey(), stateEntry.getNamespace(), newValue);
+	@SuppressWarnings("unchecked")
+	List<CopyOnWriteStateMapSnapshot<K, N, S>> getStateMapSnapshotList() {
+		List<CopyOnWriteStateMapSnapshot<K, N, S>> snapshotList = new ArrayList<>(keyGroupedStateMaps.length);
+		for (int i = 0; i < keyGroupedStateMaps.length; i++) {
+			CopyOnWriteStateMap<K, N, S> stateMap = (CopyOnWriteStateMap<K, N, S>) keyGroupedStateMaps[i];
+			snapshotList.add(stateMap.stateSnapshot());
 		}
+		return snapshotList;
 	}
 }
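
After this change the table itself no longer stores entries: operations are routed to one state map per key group, and a snapshot is taken by collecting one CopyOnWriteStateMapSnapshot per map, as getStateMapSnapshotList() above shows. The simplified, self-contained sketch below mirrors that per-key-group layout with plain HashMaps; the class name MiniStateTable and the eager copy in snapshotPerKeyGroup() are illustrative assumptions only (the real per-map snapshots are copy-on-write and avoid the eager copy).

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    // Illustration of the "one map per key-group" layout; not Flink API.
    public class MiniStateTable<K, S> {

        private final int keyGroupOffset;
        private final Map<K, S>[] keyGroupedMaps;

        @SuppressWarnings("unchecked")
        MiniStateTable(int keyGroupOffset, int numberOfKeyGroups) {
            this.keyGroupOffset = keyGroupOffset;
            this.keyGroupedMaps = (Map<K, S>[]) new Map[numberOfKeyGroups];
            for (int i = 0; i < numberOfKeyGroups; i++) {
                keyGroupedMaps[i] = new HashMap<>();
            }
        }

        void put(K key, int keyGroup, S state) {
            // the key group decides which map holds the entry
            keyGroupedMaps[keyGroup - keyGroupOffset].put(key, state);
        }

        S get(K key, int keyGroup) {
            return keyGroupedMaps[keyGroup - keyGroupOffset].get(key);
        }

        // One "snapshot" per key group, mirroring getStateMapSnapshotList() above.
        List<Map<K, S>> snapshotPerKeyGroup() {
            List<Map<K, S>> snapshots = new ArrayList<>(keyGroupedMaps.length);
            for (Map<K, S> map : keyGroupedMaps) {
                snapshots.add(new HashMap<>(map));
            }
            return snapshots;
        }
    }
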
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index 86cd263..32fbdac 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -19,277 +19,64 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.state.KeyGroupPartitioner;
-import org.apache.flink.runtime.state.KeyGroupPartitioner.ElementWriterFunction;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
-import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 
-import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
+
+import java.util.List;
 
 /**
- * This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. Besides
- * holding the {@link CopyOnWriteStateTable}s internal entries at the time of the snapshot, this class is also responsible for
- * preparing and writing the state in the process of checkpointing.
- *
- * <p>IMPORTANT: Please notice that snapshot integrity of entries in this class rely on proper copy-on-write semantics
- * through the {@link CopyOnWriteStateTable} that created the snapshot object, but all objects in this snapshot must be considered
- * as READ-ONLY!. The reason is that the objects held by this class may or may not be deep copies of original objects
- * that may still used in the {@link CopyOnWriteStateTable}. This depends for each entry on whether or not it was subject to
- * copy-on-write operations by the {@link CopyOnWriteStateTable}. Phrased differently: the {@link CopyOnWriteStateTable} provides
- * copy-on-write isolation for this snapshot, but this snapshot does not isolate modifications from the
- * {@link CopyOnWriteStateTable}!
+ * This class represents the snapshot of a {@link CopyOnWriteStateTable} and plays a role in operator state checkpointing.
+ * It is also responsible for writing the state in the process of checkpointing.
  *
  * @param <K> type of key
  * @param <N> type of namespace
  * @param <S> type of state
  */
 @Internal
-public class CopyOnWriteStateTableSnapshot<K, N, S>
-		extends AbstractStateTableSnapshot<K, N, S, CopyOnWriteStateTable<K, N, S>> {
-
-	/**
-	 * Version of the {@link CopyOnWriteStateTable} when this snapshot was created. This can be used to release the snapshot.
-	 */
-	private final int snapshotVersion;
-
-	/**
-	 * The state table entries, as by the time this snapshot was created. Objects in this array may or may not be deep
-	 * copies of the current entries in the {@link CopyOnWriteStateTable} that created this snapshot. This depends for each entry
-	 * on whether or not it was subject to copy-on-write operations by the {@link CopyOnWriteStateTable}.
-	 */
-	@Nonnull
-	private final CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData;
-
-	/** The number of (non-null) entries in snapshotData. */
-	@Nonnegative
-	private final int numberOfEntriesInSnapshotData;
-
-	/**
-	 * A local duplicate of the table's key serializer.
-	 */
-	@Nonnull
-	private final TypeSerializer<K> localKeySerializer;
+public class CopyOnWriteStateTableSnapshot<K, N, S> extends AbstractStateTableSnapshot<K, N, S> {
 
 	/**
-	 * A local duplicate of the table's namespace serializer.
+	 * The offset to the contiguous key groups.
 	 */
-	@Nonnull
-	private final TypeSerializer<N> localNamespaceSerializer;
+	private final int keyGroupOffset;
 
 	/**
-	 * A local duplicate of the table's state serializer.
+	 * Snapshots of state partitioned by key-group.
 	 */
 	@Nonnull
-	private final TypeSerializer<S> localStateSerializer;
-
-	@Nullable
-	private final StateSnapshotTransformer<S> stateSnapshotTransformer;
-
-	/**
-	 * Result of partitioning the snapshot by key-group. This is lazily created in the process of writing this snapshot
-	 * to an output as part of checkpointing.
-	 */
-	@Nullable
-	private StateKeyGroupWriter partitionedStateTableSnapshot;
+	private final List<CopyOnWriteStateMapSnapshot<K, N, S>> stateMapSnapshots;
 
 	/**
 	 * Creates a new {@link CopyOnWriteStateTableSnapshot}.
 	 *
 	 * @param owningStateTable the {@link CopyOnWriteStateTable} for which this object represents a snapshot.
 	 */
-	CopyOnWriteStateTableSnapshot(CopyOnWriteStateTable<K, N, S> owningStateTable) {
-
-		super(owningStateTable);
-		this.snapshotData = owningStateTable.snapshotTableArrays();
-		this.snapshotVersion = owningStateTable.getStateTableVersion();
-		this.numberOfEntriesInSnapshotData = owningStateTable.size();
-
-		// We create duplicates of the serializers for the async snapshot, because TypeSerializer
-		// might be stateful and shared with the event processing thread.
-		this.localKeySerializer = owningStateTable.keySerializer.duplicate();
-		this.localNamespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer().duplicate();
-		this.localStateSerializer = owningStateTable.metaInfo.getStateSerializer().duplicate();
-
-		this.partitionedStateTableSnapshot = null;
-
-		this.stateSnapshotTransformer = owningStateTable.metaInfo.
-			getStateSnapshotTransformFactory().createForDeserializedState().orElse(null);
-	}
-
-	/**
-	 * Returns the internal version of the {@link CopyOnWriteStateTable} when this snapshot was created. This value must be used to
-	 * tell the {@link CopyOnWriteStateTable} when to release this snapshot.
-	 */
-	int getSnapshotVersion() {
-		return snapshotVersion;
-	}
-
-	/**
-	 * Partitions the snapshot data by key-group. The algorithm first builds a histogram for the distribution of keys
-	 * into key-groups. Then, the histogram is accumulated to obtain the boundaries of each key-group in an array.
-	 * Last, we use the accumulated counts as write position pointers for the key-group's bins when reordering the
-	 * entries by key-group. This operation is lazily performed before the first writing of a key-group.
-	 *
-	 * <p>As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
-	 * cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint.
-	 */
-	@Nonnull
-	@SuppressWarnings("unchecked")
-	@Override
-	public StateKeyGroupWriter getKeyGroupWriter() {
-		if (partitionedStateTableSnapshot == null) {
-			final InternalKeyContext<K> keyContext = owningStateTable.keyContext;
-			final int numberOfKeyGroups = keyContext.getNumberOfKeyGroups();
-			final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
-			ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction =
-				(element, dov) -> {
-					localNamespaceSerializer.serialize(element.namespace, dov);
-					localKeySerializer.serialize(element.key, dov);
-					localStateSerializer.serialize(element.state, dov);
-				};
-			StateTableKeyGroupPartitioner<K, N, S> stateTableKeyGroupPartitioner = stateSnapshotTransformer != null ?
-				new TransformingStateTableKeyGroupPartitioner<>(
-					snapshotData,
-					numberOfEntriesInSnapshotData,
-					keyGroupRange,
-					numberOfKeyGroups,
-					elementWriterFunction,
-					stateSnapshotTransformer) :
-				new StateTableKeyGroupPartitioner<>(
-					snapshotData,
-					numberOfEntriesInSnapshotData,
-					keyGroupRange,
-					numberOfKeyGroups,
-					elementWriterFunction);
-			partitionedStateTableSnapshot = stateTableKeyGroupPartitioner.partitionByKeyGroup();
-		}
-		return partitionedStateTableSnapshot;
-	}
-
-	@Nonnull
-	@Override
-	public StateMetaInfoSnapshot getMetaInfoSnapshot() {
-		return owningStateTable.metaInfo.snapshot();
+	CopyOnWriteStateTableSnapshot(
+		CopyOnWriteStateTable<K, N, S> owningStateTable,
+		TypeSerializer<K> localKeySerializer,
+		TypeSerializer<N> localNamespaceSerializer,
+		TypeSerializer<S> localStateSerializer,
+		StateSnapshotTransformer<S> stateSnapshotTransformer) {
+		super(owningStateTable,
+			localKeySerializer,
+			localNamespaceSerializer,
+			localStateSerializer,
+			stateSnapshotTransformer);
+
+		this.keyGroupOffset = owningStateTable.getKeyGroupOffset();
+		this.stateMapSnapshots = owningStateTable.getStateMapSnapshotList();
 	}
 
 	@Override
-	public void release() {
-		owningStateTable.releaseSnapshot(this);
-	}
-
-	/**
-	 * Returns true iff the given state table is the owner of this snapshot object.
-	 */
-	boolean isOwner(CopyOnWriteStateTable<K, N, S> stateTable) {
-		return stateTable == owningStateTable;
-	}
-
-	/**
-	 * This class is the implementation of {@link KeyGroupPartitioner} for {@link CopyOnWriteStateTable}. This class
-	 * swaps input and output in {@link #reportAllElementKeyGroups()} for performance reasons, so that we can reuse
-	 * the non-flattened original snapshot array as partitioning output.
-	 *
-	 * @param <K> type of key.
-	 * @param <N> type of namespace.
-	 * @param <S> type of state value.
-	 */
-	@VisibleForTesting
-	protected static class StateTableKeyGroupPartitioner<K, N, S>
-		extends KeyGroupPartitioner<CopyOnWriteStateTable.StateTableEntry<K, N, S>> {
-
-		@SuppressWarnings("unchecked")
-		StateTableKeyGroupPartitioner(
-			@Nonnull CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData,
-			@Nonnegative int stateTableSize,
-			@Nonnull KeyGroupRange keyGroupRange,
-			@Nonnegative int totalKeyGroups,
-			@Nonnull ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction) {
-
-			super(
-				new CopyOnWriteStateTable.StateTableEntry[stateTableSize],
-				stateTableSize,
-				// We have made sure that the snapshotData is big enough to hold the flattened entries in
-				// CopyOnWriteStateTable#snapshotTableArrays(), we can safely reuse it as the destination array here.
-				snapshotData,
-				keyGroupRange,
-				totalKeyGroups,
-				CopyOnWriteStateTable.StateTableEntry::getKey,
-				elementWriterFunction);
-		}
-
-		@Override
-		protected void reportAllElementKeyGroups() {
-			// In this step we i) 'flatten' the linked list of entries to a second array and ii) report key-groups.
-			int flattenIndex = 0;
-			for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : partitioningDestination) {
-				while (null != entry) {
-					flattenIndex = tryAddToSource(flattenIndex, entry);
-					entry = entry.next;
-				}
-			}
-		}
-
-		/** Tries to append next entry to {@code partitioningSource} array snapshot and returns next index.*/
-		int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
-			final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
-			reportKeyGroupOfElementAtIndex(currentIndex, keyGroup);
-			partitioningSource[currentIndex] = entry;
-			return currentIndex + 1;
-		}
-	}
-
-	/**
-	 * Extended state snapshot transforming {@link StateTableKeyGroupPartitioner}.
-	 *
-	 * <p>This partitioner can additionally transform state before including or not into the snapshot.
-	 */
-	protected static final class TransformingStateTableKeyGroupPartitioner<K, N, S>
-		extends StateTableKeyGroupPartitioner<K, N, S> {
-		private final StateSnapshotTransformer<S> stateSnapshotTransformer;
-
-		TransformingStateTableKeyGroupPartitioner(
-			@Nonnull CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData,
-			int stateTableSize,
-			@Nonnull KeyGroupRange keyGroupRange,
-			int totalKeyGroups,
-			@Nonnull ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction,
-			@Nonnull StateSnapshotTransformer<S> stateSnapshotTransformer) {
-			super(
-				snapshotData,
-				stateTableSize,
-				keyGroupRange,
-				totalKeyGroups,
-				elementWriterFunction);
-			this.stateSnapshotTransformer = stateSnapshotTransformer;
-		}
-
-		@Override
-		int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
-			CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = filterEntry(entry);
-			if (filteredEntry != null) {
-				return super.tryAddToSource(currentIndex, filteredEntry);
-			}
-			return currentIndex;
+	protected StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> getStateMapSnapshotForKeyGroup(int keyGroup) {
+		int indexOffset = keyGroup - keyGroupOffset;
+		CopyOnWriteStateMapSnapshot<K, N, S> stateMapSnapshot = null;
+		if (indexOffset >= 0 && indexOffset < stateMapSnapshots.size()) {
+			stateMapSnapshot = stateMapSnapshots.get(indexOffset);
 		}
 
-		private CopyOnWriteStateTable.StateTableEntry<K, N, S> filterEntry(
-			CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
-			S transformedValue = stateSnapshotTransformer.filterOrTransform(entry.state);
-			if (transformedValue != null) {
-				CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = entry;
-				if (transformedValue != entry.state) {
-					filteredEntry = new CopyOnWriteStateTable.StateTableEntry<>(entry, entry.entryVersion);
-					filteredEntry.state = transformedValue;
-				}
-				return filteredEntry;
-			}
-			return null;
-		}
+		return stateMapSnapshot;
 	}
 }
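
The getStateMapSnapshotForKeyGroup(int) method above resolves a key-group id to its snapshot by subtracting the start of the backend's contiguous key-group range. A minimal sketch of that index arithmetic, with hypothetical names and plain Java collections rather than Flink code:

    import java.util.Arrays;
    import java.util.List;

    // Hypothetical sketch, not Flink code: map a key-group id to the snapshot that
    // belongs to it by subtracting the key-group offset.
    public class KeyGroupOffsetSketch {

        public static void main(String[] args) {
            int keyGroupOffset = 32; // start of the key-group range assigned to this backend
            List<String> stateMapSnapshots = Arrays.asList("snapshot-32", "snapshot-33", "snapshot-34");

            int requestedKeyGroup = 33;
            int indexOffset = requestedKeyGroup - keyGroupOffset;
            String snapshot = (indexOffset >= 0 && indexOffset < stateMapSnapshots.size())
                ? stateMapSnapshots.get(indexOffset)
                : null;

            System.out.println(snapshot); // snapshot-33
        }
    }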
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index e5e4998..b5589ba 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -19,40 +19,14 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.StateEntry;
-import org.apache.flink.runtime.state.StateEntry.SimpleStateEntry;
-import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
-import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
-import org.apache.flink.runtime.state.StateTransformationFunction;
-import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
-import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
-import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
 
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Objects;
-import java.util.stream.Stream;
-
 /**
- * This implementation of {@link StateTable} uses nested {@link HashMap} objects. It is also maintaining a partitioning
- * by key-group.
- *
- * <p>In contrast to {@link CopyOnWriteStateTable}, this implementation does not support asynchronous snapshots. However,
- * it might have a better memory footprint for some use-cases, e.g. it is naturally de-duplicating namespace objects.
+ * This implementation of {@link StateTable} uses {@link NestedStateMap}.
  *
  * @param <K> type of key.
  * @param <N> type of namespace.
@@ -62,21 +36,8 @@ import java.util.stream.Stream;
 public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 
 	/**
-	 * Map for holding the actual state objects. The outer array represents the key-groups. The nested maps provide
-	 * an outer scope by namespace and an inner scope by key.
-	 */
-	private final Map<N, Map<K, S>>[] state;
-
-	/**
-	 * The offset to the contiguous key groups.
-	 */
-	private final int keyGroupOffset;
-
-	// ------------------------------------------------------------------------
-
-	/**
 	 * Creates a new {@link NestedMapsStateTable} for the given key context and meta info.
-	 *  @param keyContext the key context.
+	 * @param keyContext the key context.
 	 * @param metaInfo the meta information for this state table.
 	 * @param keySerializer the serializer of the key.
 	 */
@@ -85,256 +46,24 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo,
 		TypeSerializer<K> keySerializer) {
 		super(keyContext, metaInfo, keySerializer);
-		this.keyGroupOffset = keyContext.getKeyGroupRange().getStartKeyGroup();
-
-		@SuppressWarnings("unchecked")
-		Map<N, Map<K, S>>[] state = (Map<N, Map<K, S>>[]) new Map[keyContext.getKeyGroupRange().getNumberOfKeyGroups()];
-		this.state = state;
-	}
-
-	// ------------------------------------------------------------------------
-	//  access to maps
-	// ------------------------------------------------------------------------
-
-	/**
-	 * Returns the internal data structure.
-	 */
-	@VisibleForTesting
-	public Map<N, Map<K, S>>[] getState() {
-		return state;
-	}
-
-	@VisibleForTesting
-	Map<N, Map<K, S>> getMapForKeyGroup(int keyGroupIndex) {
-		final int pos = indexToOffset(keyGroupIndex);
-		if (pos >= 0 && pos < state.length) {
-			return state[pos];
-		} else {
-			return null;
-		}
-	}
-
-	/**
-	 * Sets the given map for the given key-group.
-	 */
-	private void setMapForKeyGroup(int keyGroupId, Map<N, Map<K, S>> map) {
-		try {
-			state[indexToOffset(keyGroupId)] = map;
-		} catch (ArrayIndexOutOfBoundsException e) {
-			throw new IllegalArgumentException("Key group index " + keyGroupId + " is out of range of key group " +
-				"range [" + keyGroupOffset + ", " + (keyGroupOffset + state.length) + ").");
-		}
-	}
-
-	/**
-	 * Translates a key-group id to the internal array offset.
-	 */
-	private int indexToOffset(int index) {
-		return index - keyGroupOffset;
-	}
-
-	// ------------------------------------------------------------------------
-
-	@Override
-	public int size() {
-		int count = 0;
-		for (Map<N, Map<K, S>> namespaceMap : state) {
-			if (null != namespaceMap) {
-				for (Map<K, S> keyMap : namespaceMap.values()) {
-					if (null != keyMap) {
-						count += keyMap.size();
-					}
-				}
-			}
-		}
-		return count;
 	}
 
 	@Override
-	public S get(N namespace) {
-		return get(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
+	protected NestedStateMap<K, N, S> createStateMap() {
+		return new NestedStateMap<>();
 	}
 
-	@Override
-	public boolean containsKey(N namespace) {
-		return containsKey(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
-	}
-
-	@Override
-	public void put(N namespace, S state) {
-		put(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, state);
-	}
-
-	@Override
-	public S putAndGetOld(N namespace, S state) {
-		return putAndGetOld(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, state);
-	}
-
-	@Override
-	public void remove(N namespace) {
-		remove(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
-	}
-
-	@Override
-	public S removeAndGetOld(N namespace) {
-		return removeAndGetOld(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
-	}
-
-	@Override
-	public S get(K key, N namespace) {
-		int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, keyContext.getNumberOfKeyGroups());
-		return get(key, keyGroup, namespace);
-	}
-
-	@Override
-	public Stream<K> getKeys(N namespace) {
-		return Arrays.stream(state)
-			.filter(Objects::nonNull)
-			.map(namespaces -> namespaces.getOrDefault(namespace, Collections.emptyMap()))
-			.flatMap(namespaceSate -> namespaceSate.keySet().stream());
-	}
-
-	@Override
-	public StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords) {
-		return new StateEntryIterator();
-	}
-
-	// ------------------------------------------------------------------------
-
-	private boolean containsKey(K key, int keyGroupIndex, N namespace) {
-
-		checkKeyNamespacePreconditions(key, namespace);
-
-		Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
-
-		if (namespaceMap == null) {
-			return false;
-		}
-
-		Map<K, S> keyedMap = namespaceMap.get(namespace);
-
-		return keyedMap != null && keyedMap.containsKey(key);
-	}
-
-	S get(K key, int keyGroupIndex, N namespace) {
-
-		checkKeyNamespacePreconditions(key, namespace);
-
-		Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
-
-		if (namespaceMap == null) {
-			return null;
-		}
-
-		Map<K, S> keyedMap = namespaceMap.get(namespace);
-
-		if (keyedMap == null) {
-			return null;
-		}
-
-		return keyedMap.get(key);
-	}
-
-	@Override
-	public void put(K key, int keyGroupIndex, N namespace, S value) {
-		putAndGetOld(key, keyGroupIndex, namespace, value);
-	}
-
-	private S putAndGetOld(K key, int keyGroupIndex, N namespace, S value) {
-
-		checkKeyNamespacePreconditions(key, namespace);
-
-		Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
-
-		if (namespaceMap == null) {
-			namespaceMap = new HashMap<>();
-			setMapForKeyGroup(keyGroupIndex, namespaceMap);
-		}
-
-		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
-
-		return keyedMap.put(key, value);
-	}
-
-	private void remove(K key, int keyGroupIndex, N namespace) {
-		removeAndGetOld(key, keyGroupIndex, namespace);
-	}
-
-	private S removeAndGetOld(K key, int keyGroupIndex, N namespace) {
-
-		checkKeyNamespacePreconditions(key, namespace);
-
-		Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
-
-		if (namespaceMap == null) {
-			return null;
-		}
-
-		Map<K, S> keyedMap = namespaceMap.get(namespace);
-
-		if (keyedMap == null) {
-			return null;
-		}
-
-		S removed = keyedMap.remove(key);
-
-		if (keyedMap.isEmpty()) {
-			namespaceMap.remove(namespace);
-		}
-
-		return removed;
-	}
-
-	private void checkKeyNamespacePreconditions(K key, N namespace) {
-		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
-		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
-	}
-
-	@Override
-	public int sizeOfNamespace(Object namespace) {
-		int count = 0;
-		for (Map<N, Map<K, S>> namespaceMap : state) {
-			if (null != namespaceMap) {
-				Map<K, S> keyMap = namespaceMap.get(namespace);
-				count += keyMap != null ? keyMap.size() : 0;
-			}
-		}
-
-		return count;
-	}
-
-	@Override
-	public <T> void transform(N namespace, T value, StateTransformationFunction<S, T> transformation) throws Exception {
-		final K key = keyContext.getCurrentKey();
-		checkKeyNamespacePreconditions(key, namespace);
-		final int keyGroupIndex = keyContext.getCurrentKeyGroupIndex();
-
-		Map<N, Map<K, S>> namespaceMap = getMapForKeyGroup(keyGroupIndex);
-
-		if (namespaceMap == null) {
-			namespaceMap = new HashMap<>();
-			setMapForKeyGroup(keyGroupIndex, namespaceMap);
-		}
-
-		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
-		keyedMap.put(key, transformation.apply(keyedMap.get(key), value));
-	}
-
-	// snapshots ---------------------------------------------------------------------------------------------------
-
-	private static <K, N, S> int countMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
-		int count = 0;
-		for (Map<K, S> namespaceMap : keyGroupMap.values()) {
-			count += namespaceMap.size();
-		}
-
-		return count;
-	}
+	// Snapshotting ----------------------------------------------------------------------------------------------------
 
 	@Nonnull
 	@Override
 	public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getStateSnapshotTransformFactory());
+		return new NestedMapsStateTableSnapshot<>(
+			this,
+			getKeySerializer(),
+			getNamespaceSerializer(),
+			getStateSerializer(),
+			getMetaInfo().getStateSnapshotTransformFactory().createForDeserializedState().orElse(null));
 	}
 
 	/**
@@ -345,193 +74,26 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	 * @param <S> type of state.
 	 */
 	static class NestedMapsStateTableSnapshot<K, N, S>
-			extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
-			implements StateSnapshot.StateKeyGroupWriter {
-		private final TypeSerializer<K> keySerializer;
-		private final TypeSerializer<N> namespaceSerializer;
-		private final TypeSerializer<S> stateSerializer;
-		private final StateSnapshotTransformer<S> snapshotFilter;
+			extends AbstractStateTableSnapshot<K, N, S> {
 
 		NestedMapsStateTableSnapshot(
 			NestedMapsStateTable<K, N, S> owningTable,
-			StateSnapshotTransformFactory<S> snapshotTransformFactory) {
-
-			super(owningTable);
-			this.snapshotFilter = snapshotTransformFactory.createForDeserializedState().orElse(null);
-			this.keySerializer = owningStateTable.keySerializer;
-			this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
-			this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
-		}
-
-		@Nonnull
-		@Override
-		public StateKeyGroupWriter getKeyGroupWriter() {
-			return this;
-		}
-
-		@Nonnull
-		@Override
-		public StateMetaInfoSnapshot getMetaInfoSnapshot() {
-			return owningStateTable.metaInfo.snapshot();
+			TypeSerializer<K> localKeySerializer,
+			TypeSerializer<N> localNamespaceSerializer,
+			TypeSerializer<S> localStateSerializer,
+			StateSnapshotTransformer<S> stateSnapshotTransformer) {
+			super(owningTable,
+				localKeySerializer,
+				localNamespaceSerializer,
+				localStateSerializer,
+				stateSnapshotTransformer);
 		}
 
-		/**
-		 * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
-		 * {@link CopyOnWriteStateTable}.
-		 *
-		 * <p>{@link NestedMapsStateTable} could naturally support a kind of
-		 * prefix-compressed format (grouping by namespace, writing the namespace only once per group instead for each
-		 * mapping). We might implement support for different formats later (tailored towards different state table
-		 * implementations).
-		 */
 		@Override
-		public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
-			final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
-			if (null != keyGroupMap) {
-				Map<N, Map<K, S>> filteredMappings = filterMappingsInKeyGroupIfNeeded(keyGroupMap);
-				dov.writeInt(countMappingsInKeyGroup(filteredMappings));
-				for (Map.Entry<N, Map<K, S>> namespaceEntry : filteredMappings.entrySet()) {
-					final N namespace = namespaceEntry.getKey();
-					final Map<K, S> namespaceMap = namespaceEntry.getValue();
-					for (Map.Entry<K, S> keyEntry : namespaceMap.entrySet()) {
-						writeElement(namespace, keyEntry, dov);
-					}
-				}
-			} else {
-				dov.writeInt(0);
-			}
-		}
+		protected StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> getStateMapSnapshotForKeyGroup(int keyGroup) {
+			NestedStateMap<K, N, S> stateMap = (NestedStateMap<K, N, S>) owningStateTable.getMapForKeyGroup(keyGroup);
 
-		private void writeElement(N namespace, Map.Entry<K, S> keyEntry, DataOutputView dov) throws IOException {
-			namespaceSerializer.serialize(namespace, dov);
-			keySerializer.serialize(keyEntry.getKey(), dov);
-			stateSerializer.serialize(keyEntry.getValue(), dov);
-		}
-
-		private Map<N, Map<K, S>> filterMappingsInKeyGroupIfNeeded(final Map<N, Map<K, S>> keyGroupMap) {
-			return snapshotFilter == null ?
-				keyGroupMap : filterMappingsInKeyGroup(keyGroupMap);
-		}
-
-		private Map<N, Map<K, S>> filterMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
-			Map<N, Map<K, S>> filtered = new HashMap<>();
-			for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
-				N namespace = namespaceEntry.getKey();
-				Map<K, S> filteredNamespaceMap = filtered.computeIfAbsent(namespace, n -> new HashMap<>());
-				for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
-					K key = keyEntry.getKey();
-					S transformedvalue = snapshotFilter.filterOrTransform(keyEntry.getValue());
-					if (transformedvalue != null) {
-						filteredNamespaceMap.put(key, transformedvalue);
-					}
-				}
-			}
-			return filtered;
-		}
-	}
-
-	/**
-	 * Iterator over state entries in a {@link NestedMapsStateTable}.
-	 *
-	 * <p>The iterator keeps a snapshotted copy of key/namespace sets, available at the beginning of iteration.
-	 * While further iterating the copy, the iterator returns the actual state value from primary maps
-	 * if exists at that moment.
-	 *
-	 * <p>Note: Usage of this iterator can have a heap memory consumption impact.
-	 */
-	class StateEntryIterator implements StateIncrementalVisitor<K, N, S>, Iterator<StateEntry<K, N, S>> {
-		private int keyGropuIndex;
-		private Iterator<Map.Entry<N, Map<K, S>>> namespaceIterator;
-		private Map.Entry<N, Map<K, S>> namespace;
-		private Iterator<Map.Entry<K, S>> keyValueIterator;
-		private StateEntry<K, N, S> nextEntry;
-		private StateEntry<K, N, S> lastReturnedEntry;
-
-		StateEntryIterator() {
-			keyGropuIndex = 0;
-			namespace = null;
-			keyValueIterator = null;
-			nextKeyIterator();
-		}
-
-		@Override
-		public boolean hasNext() {
-			nextKeyIterator();
-			return keyIteratorHasNext();
-		}
-
-		@Override
-		public Collection<StateEntry<K, N, S>> nextEntries() {
-			StateEntry<K, N, S> nextEntry = next();
-			return nextEntry == null ? Collections.emptyList() : Collections.singletonList(nextEntry);
-		}
-
-		@Override
-		public StateEntry<K, N, S> next() {
-			StateEntry<K, N, S> next = null;
-			if (hasNext()) {
-				next = nextEntry;
-			}
-			nextEntry = null;
-			lastReturnedEntry = next;
-			return next;
-		}
-
-		private void nextKeyIterator() {
-			while (!keyIteratorHasNext()) {
-				nextNamespaceIterator();
-				if (namespaceIteratorHasNext()) {
-					namespace = namespaceIterator.next();
-					keyValueIterator = new HashSet<>(namespace.getValue().entrySet()).iterator();
-				} else {
-					break;
-				}
-			}
-		}
-
-		private void nextNamespaceIterator() {
-			while (!namespaceIteratorHasNext()) {
-				while (keyGropuIndex < state.length && state[keyGropuIndex] == null) {
-					keyGropuIndex++;
-				}
-				if (keyGropuIndex < state.length && state[keyGropuIndex] != null) {
-					namespaceIterator = new HashSet<>(state[keyGropuIndex++].entrySet()).iterator();
-				} else {
-					break;
-				}
-			}
-		}
-
-		private boolean keyIteratorHasNext() {
-			while (nextEntry == null && keyValueIterator != null && keyValueIterator.hasNext()) {
-				Map.Entry<K, S> next = keyValueIterator.next();
-				Map<K, S> ns = state[keyGropuIndex - 1] == null ? null :
-					state[keyGropuIndex - 1].getOrDefault(namespace.getKey(), null);
-				S upToDateValue = ns == null ? null : ns.getOrDefault(next.getKey(), null);
-				if (upToDateValue != null) {
-					nextEntry = new SimpleStateEntry<>(next.getKey(), namespace.getKey(), upToDateValue);
-				}
-			}
-			return nextEntry != null;
-		}
-
-		private boolean namespaceIteratorHasNext() {
-			return namespaceIterator != null && namespaceIterator.hasNext();
-		}
-
-		@Override
-		public void remove() {
-			remove(lastReturnedEntry);
-		}
-
-		@Override
-		public void remove(StateEntry<K, N, S> stateEntry) {
-			state[keyGropuIndex - 1].get(stateEntry.getNamespace()).remove(stateEntry.getKey());
-		}
-
-		@Override
-		public void update(StateEntry<K, N, S> stateEntry, S newValue) {
-			state[keyGropuIndex - 1].get(stateEntry.getNamespace()).put(stateEntry.getKey(), newValue);
+			return stateMap.stateSnapshot();
 		}
 	}
 }
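
With this change, NestedMapsStateTable only has to supply the concrete map type via createStateMap(); the per-key-group bookkeeping lives in the base class. A minimal sketch of that factory-method shape, using hypothetical names and plain HashMaps rather than Flink code:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical sketch, not Flink code: the abstract table allocates one map per
    // key-group through createMap(), and the subclass only picks the map implementation.
    public class CreateStateMapSketch {

        abstract static class TableSketch<K, V> {
            final Map<K, V>[] keyGroupedMaps;

            @SuppressWarnings("unchecked")
            TableSketch(int numberOfKeyGroups) {
                keyGroupedMaps = new Map[numberOfKeyGroups];
                for (int i = 0; i < keyGroupedMaps.length; i++) {
                    // note: createMap() runs during the super constructor, so it must not
                    // rely on subclass state
                    keyGroupedMaps[i] = createMap();
                }
            }

            abstract Map<K, V> createMap();
        }

        static class NestedTableSketch<K, V> extends TableSketch<K, V> {
            NestedTableSketch(int numberOfKeyGroups) {
                super(numberOfKeyGroups);
            }

            @Override
            Map<K, V> createMap() {
                return new HashMap<>(); // a copy-on-write flavour would return a different map here
            }
        }

        public static void main(String[] args) {
            TableSketch<String, Long> table = new NestedTableSketch<>(4);
            System.out.println(table.keyGroupedMaps.length); // 4
        }
    }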
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMap.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMap.java
new file mode 100644
index 0000000..34540de
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMap.java
@@ -0,0 +1,290 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.StateEntry;
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.runtime.state.internal.InternalKvState;
+
+import javax.annotation.Nonnull;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.stream.Stream;
+
+/**
+ * This implementation of {@link StateMap} uses nested {@link HashMap} objects.
+ *
+ * @param <K> type of key.
+ * @param <N> type of namespace.
+ * @param <S> type of value.
+ */
+public class NestedStateMap<K, N, S> extends StateMap<K, N, S> {
+
+	/**
+	 * Map for holding the actual state objects. The nested map provides
+	 * an outer scope by namespace and an inner scope by key.
+	 */
+	private final Map<N, Map<K, S>> namespaceMap;
+
+	/**
+	 * Constructs a new {@code NestedStateMap}.
+	 */
+	public NestedStateMap() {
+		this.namespaceMap = new HashMap<>();
+	}
+
+	// Public API from StateMap ------------------------------------------------------------------------------
+
+	@Override
+	public int size() {
+		int count = 0;
+		for (Map<K, S> keyMap : namespaceMap.values()) {
+			if (null != keyMap) {
+				count += keyMap.size();
+			}
+		}
+
+		return count;
+	}
+
+	@Override
+	public S get(K key, N namespace) {
+		Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+		if (keyedMap == null) {
+			return null;
+		}
+
+		return keyedMap.get(key);
+	}
+
+	@Override
+	public boolean containsKey(K key, N namespace) {
+		Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+		return keyedMap != null && keyedMap.containsKey(key);
+	}
+
+	@Override
+	public void put(K key, N namespace, S state) {
+		putAndGetOld(key, namespace, state);
+	}
+
+	@Override
+	public S putAndGetOld(K key, N namespace, S state) {
+		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
+
+		return keyedMap.put(key, state);
+	}
+
+	@Override
+	public void remove(K key, N namespace) {
+		removeAndGetOld(key, namespace);
+	}
+
+	@Override
+	public S removeAndGetOld(K key, N namespace) {
+		Map<K, S> keyedMap = namespaceMap.get(namespace);
+
+		if (keyedMap == null) {
+			return null;
+		}
+
+		S removed = keyedMap.remove(key);
+
+		if (keyedMap.isEmpty()) {
+			namespaceMap.remove(namespace);
+		}
+
+		return removed;
+	}
+
+	@Override
+	public <T> void transform(
+		K key, N namespace, T value, StateTransformationFunction<S, T> transformation) throws Exception {
+		Map<K, S> keyedMap = namespaceMap.computeIfAbsent(namespace, k -> new HashMap<>());
+		keyedMap.put(key, transformation.apply(keyedMap.get(key), value));
+	}
+
+	@Override
+	public Iterator<StateEntry<K, N, S>> iterator() {
+		return new StateEntryIterator();
+	}
+
+	@Override
+	public Stream<K> getKeys(N namespace) {
+		return namespaceMap.getOrDefault(namespace, Collections.emptyMap()).keySet().stream();
+	}
+
+	@Override
+	public InternalKvState.StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(
+		int recommendedMaxNumberOfReturnedRecords) {
+		return new StateEntryVisitor();
+	}
+
+	@Override
+	public int sizeOfNamespace(Object namespace) {
+		Map<K, S> keyMap = namespaceMap.get(namespace);
+		return keyMap != null ? keyMap.size() : 0;
+	}
+
+	@Nonnull
+	@Override
+	public StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> stateSnapshot() {
+		return new NestedStateMapSnapshot<>(this);
+	}
+
+	public Map<N, Map<K, S>> getNamespaceMap() {
+		return namespaceMap;
+	}
+
+	/**
+	 * Iterator over state entries in a {@link NestedStateMap}.
+	 */
+	class StateEntryIterator implements Iterator<StateEntry<K, N, S>> {
+		private Iterator<Map.Entry<N, Map<K, S>>> namespaceIterator;
+		private Map.Entry<N, Map<K, S>> namespace;
+		private Iterator<Map.Entry<K, S>> keyValueIterator;
+
+		StateEntryIterator() {
+			namespaceIterator = namespaceMap.entrySet().iterator();
+			namespace = null;
+			keyValueIterator = Collections.emptyIterator();
+		}
+
+		@Override
+		public boolean hasNext() {
+			return keyValueIterator.hasNext() || namespaceIterator.hasNext();
+		}
+
+		@Override
+		public StateEntry<K, N, S> next() {
+			if (!hasNext()) {
+				throw new NoSuchElementException();
+			}
+
+			if (!keyValueIterator.hasNext()) {
+				namespace = namespaceIterator.next();
+				keyValueIterator = namespace.getValue().entrySet().iterator();
+			}
+
+			Map.Entry<K, S> entry = keyValueIterator.next();
+
+			return new StateEntry.SimpleStateEntry<>(
+				entry.getKey(), namespace.getKey(), entry.getValue());
+		}
+	}
+
+	/**
+	 * Incremental visitor over state entries in a {@link NestedStateMap}.
+	 *
+	 * <p>The visitor keeps a snapshotted copy of the key/namespace sets, taken at the beginning of iteration.
+	 * While further iterating over the copy, it returns the actual state value from the primary maps
+	 * if it still exists at that moment.
+	 *
+	 * <p>Note: Usage of this visitor can have a heap memory consumption impact.
+	 */
+	class StateEntryVisitor implements InternalKvState.StateIncrementalVisitor<K, N, S>, Iterator<StateEntry<K, N, S>> {
+		private Iterator<Map.Entry<N, Map<K, S>>> namespaceIterator;
+		private Map.Entry<N, Map<K, S>> namespace;
+		private Iterator<Map.Entry<K, S>> keyValueIterator;
+		private StateEntry<K, N, S> nextEntry;
+		private StateEntry<K, N, S> lastReturnedEntry;
+
+		StateEntryVisitor() {
+			namespaceIterator = new HashSet<>(namespaceMap.entrySet()).iterator();
+			namespace = null;
+			keyValueIterator = null;
+			nextKeyIterator();
+		}
+
+		@Override
+		public boolean hasNext() {
+			nextKeyIterator();
+			return keyIteratorHasNext();
+		}
+
+		@Override
+		public Collection<StateEntry<K, N, S>> nextEntries() {
+			StateEntry<K, N, S> nextEntry = next();
+			return nextEntry == null ? Collections.emptyList() : Collections.singletonList(nextEntry);
+		}
+
+		@Override
+		public StateEntry<K, N, S> next() {
+			StateEntry<K, N, S> next = null;
+			if (hasNext()) {
+				next = nextEntry;
+			}
+			nextEntry = null;
+			lastReturnedEntry = next;
+			return next;
+		}
+
+		private void nextKeyIterator() {
+			while (!keyIteratorHasNext()) {
+				if (namespaceIteratorHasNext()) {
+					namespace = namespaceIterator.next();
+					keyValueIterator = new HashSet<>(namespace.getValue().entrySet()).iterator();
+				} else {
+					break;
+				}
+			}
+		}
+
+		private boolean keyIteratorHasNext() {
+			while (nextEntry == null && keyValueIterator != null && keyValueIterator.hasNext()) {
+				Map.Entry<K, S> next = keyValueIterator.next();
+				Map<K, S> ns = namespaceMap.getOrDefault(namespace.getKey(), null);
+				S upToDateValue = ns == null ? null : ns.getOrDefault(next.getKey(), null);
+				if (upToDateValue != null) {
+					nextEntry = new StateEntry.SimpleStateEntry<>(next.getKey(), namespace.getKey(), upToDateValue);
+				}
+			}
+			return nextEntry != null;
+		}
+
+		private boolean namespaceIteratorHasNext() {
+			return namespaceIterator.hasNext();
+		}
+
+		@Override
+		public void remove() {
+			remove(lastReturnedEntry);
+		}
+
+		@Override
+		public void remove(StateEntry<K, N, S> stateEntry) {
+			namespaceMap.get(stateEntry.getNamespace()).remove(stateEntry.getKey());
+		}
+
+		@Override
+		public void update(StateEntry<K, N, S> stateEntry, S newValue) {
+			namespaceMap.get(stateEntry.getNamespace()).put(stateEntry.getKey(), newValue);
+		}
+	}
+}
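
NestedStateMap keeps state in a namespace -> key -> state structure and drops a namespace entry once its last key is removed. A minimal usage sketch of that layout with plain HashMaps; the names and values below are hypothetical, not Flink code:

    import java.util.HashMap;
    import java.util.Map;

    // Hypothetical sketch, not Flink code: the nested namespace -> key -> state layout.
    public class NestedLayoutSketch {

        public static void main(String[] args) {
            Map<String, Map<Integer, Long>> namespaceMap = new HashMap<>();

            // put(key=7, namespace="window-1", state=42L)
            namespaceMap.computeIfAbsent("window-1", n -> new HashMap<>()).put(7, 42L);

            // get(key=7, namespace="window-1")
            Map<Integer, Long> keyedMap = namespaceMap.get("window-1");
            System.out.println(keyedMap == null ? null : keyedMap.get(7)); // 42

            // removeAndGetOld(key=7, namespace="window-1") also drops the now-empty namespace
            Long removed = keyedMap.remove(7);
            if (keyedMap.isEmpty()) {
                namespaceMap.remove("window-1");
            }
            System.out.println(removed + " / " + namespaceMap.isEmpty()); // 42 / true
        }
    }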
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java
new file mode 100644
index 0000000..29342e6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedStateMapSnapshot.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * This class represents the snapshot of a {@link NestedStateMap}.
+ *
+ * @param <K> type of key
+ * @param <N> type of namespace
+ * @param <S> type of state
+ */
+public class NestedStateMapSnapshot<K, N, S>
+	extends StateMapSnapshot<K, N, S, NestedStateMap<K, N, S>> {
+
+	/**
+	 * Creates a new {@link NestedStateMapSnapshot}.
+	 *
+	 * @param owningStateMap the {@link NestedStateMap} for which this object represents a snapshot.
+	 */
+	public NestedStateMapSnapshot(NestedStateMap<K, N, S> owningStateMap) {
+		super(owningStateMap);
+	}
+
+	@Override
+	public void writeState(
+		TypeSerializer<K> keySerializer,
+		TypeSerializer<N> namespaceSerializer,
+		TypeSerializer<S> stateSerializer,
+		@Nonnull DataOutputView dov,
+		@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) throws IOException {
+		Map<N, Map<K, S>> mappings = filterMappingsIfNeeded(owningStateMap.getNamespaceMap(), stateSnapshotTransformer);
+		int numberOfEntries = countMappingsInKeyGroup(mappings);
+
+		dov.writeInt(numberOfEntries);
+		for (Map.Entry<N, Map<K, S>> namespaceEntry : mappings.entrySet()) {
+			N namespace = namespaceEntry.getKey();
+			for (Map.Entry<K, S> entry : namespaceEntry.getValue().entrySet()) {
+				namespaceSerializer.serialize(namespace, dov);
+				keySerializer.serialize(entry.getKey(), dov);
+				stateSerializer.serialize(entry.getValue(), dov);
+			}
+		}
+	}
+
+	private Map<N, Map<K, S>> filterMappingsIfNeeded(
+		final Map<N, Map<K, S>> keyGroupMap,
+		StateSnapshotTransformer<S> stateSnapshotTransformer) {
+		if (stateSnapshotTransformer == null) {
+			return keyGroupMap;
+		}
+
+		Map<N, Map<K, S>> filtered = new HashMap<>();
+		for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
+			N namespace = namespaceEntry.getKey();
+			Map<K, S> filteredNamespaceMap = filtered.computeIfAbsent(namespace, n -> new HashMap<>());
+			for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
+				K key = keyEntry.getKey();
+				S transformedValue = stateSnapshotTransformer.filterOrTransform(keyEntry.getValue());
+				if (transformedValue != null) {
+					filteredNamespaceMap.put(key, transformedValue);
+				}
+			}
+			if (filteredNamespaceMap.isEmpty()) {
+				filtered.remove(namespace);
+			}
+		}
+
+		return filtered;
+	}
+
+	private int countMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
+		int count = 0;
+		for (Map<K, S> namespaceMap : keyGroupMap.values()) {
+			count += namespaceMap.size();
+		}
+
+		return count;
+	}
+}
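
The writeState(...) implementation above emits the number of mappings followed by one (namespace, key, state) triple per mapping. A minimal sketch of that on-wire layout, with plain java.io streams standing in for Flink's DataOutputView and TypeSerializers; this is an assumed, simplified illustration, not Flink code:

    import java.io.ByteArrayOutputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.Map;
    import java.util.TreeMap;

    // Hypothetical sketch, not Flink code: write an entry count, then one
    // (namespace, key, state) triple per mapping in the key-group.
    public class KeyGroupWriteSketch {

        public static void main(String[] args) throws IOException {
            Map<String, Map<Integer, Long>> namespaceMap = new TreeMap<>();
            Map<Integer, Long> keyedMap = new TreeMap<>();
            keyedMap.put(1, 10L);
            keyedMap.put(2, 20L);
            namespaceMap.put("ns", keyedMap);

            int numberOfEntries = 0;
            for (Map<Integer, Long> perNamespace : namespaceMap.values()) {
                numberOfEntries += perNamespace.size();
            }

            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                out.writeInt(numberOfEntries);                 // entry count for the key-group
                for (Map.Entry<String, Map<Integer, Long>> ns : namespaceMap.entrySet()) {
                    for (Map.Entry<Integer, Long> kv : ns.getValue().entrySet()) {
                        out.writeUTF(ns.getKey());             // namespace
                        out.writeInt(kv.getKey());             // key
                        out.writeLong(kv.getValue());          // state
                    }
                }
            }
            System.out.println(bytes.size() + " bytes for " + numberOfEntries + " mappings");
        }
    }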
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMap.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMap.java
new file mode 100644
index 0000000..1ee16b8
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMap.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.state.StateEntry;
+import org.apache.flink.runtime.state.StateTransformationFunction;
+import org.apache.flink.runtime.state.internal.InternalKvState;
+
+import javax.annotation.Nonnull;
+
+import java.util.stream.Stream;
+
+/**
+ * Base class for state maps.
+ *
+ * @param <K> type of key
+ * @param <N> type of namespace
+ * @param <S> type of state
+ */
+public abstract class StateMap<K, N, S> implements Iterable<StateEntry<K, N, S>> {
+
+	// Main interface methods of StateMap -------------------------------------------------------
+
+	/**
+	 * Returns whether this {@link StateMap} is empty.
+	 *
+	 * @return {@code true} if this {@link StateMap} has no elements, {@code false}
+	 * otherwise.
+	 * @see #size()
+	 */
+	public boolean isEmpty() {
+		return size() == 0;
+	}
+
+	/**
+	 * Returns the total number of entries in this {@link StateMap}.
+	 *
+	 * @return the number of entries in this {@link StateMap}.
+	 */
+	public abstract int size();
+
+	/**
+	 * Returns the state for the composite of the given key and namespace.
+	 *
+	 * @param key       the key. Not null.
+	 * @param namespace the namespace. Not null.
+	 * @return the state of the mapping with the specified key/namespace composite key, or {@code null}
+	 * if no mapping for the specified key is found.
+	 */
+	public abstract S get(K key, N namespace);
+
+	/**
+	 * Returns whether this map contains the specified key/namespace composite key.
+	 *
+	 * @param key       the key in the composite key to search for. Not null.
+	 * @param namespace the namespace in the composite key to search for. Not null.
+	 * @return {@code true} if this map contains the specified key/namespace composite key,
+	 * {@code false} otherwise.
+	 */
+	public abstract boolean containsKey(K key, N namespace);
+
+	/**
+	 * Maps the specified key/namespace composite key to the specified value. This method should be preferred
+	 * over {@link #putAndGetOld(K, N, S)} when the caller is not interested in the old state.
+	 *
+	 * @param key       the key. Not null.
+	 * @param namespace the namespace. Not null.
+	 * @param state     the state. Can be null.
+	 */
+	public abstract void put(K key, N namespace, S state);
+
+	/**
+	 * Maps the specified key/namespace composite key to the specified state. Returns the previous state that
+	 * was registered under the composite key.
+	 *
+	 * @param key       the key. Not null.
+	 * @param namespace the namespace. Not null.
+	 * @param state     the state. Can be null.
+	 * @return the state of any previous mapping with the specified key or
+	 * {@code null} if there was no such mapping.
+	 */
+	public abstract S putAndGetOld(K key, N namespace, S state);
+
+	/**
+	 * Removes the mapping for the specified key/namespace composite key. This method should be preferred
+	 * over {@link #removeAndGetOld(K, N)} when the caller is not interested in the old state.
+	 *
+	 * @param key       the key of the mapping to remove. Not null.
+	 * @param namespace the namespace of the mapping to remove. Not null.
+	 */
+	public abstract void remove(K key, N namespace);
+
+	/**
+	 * Removes the mapping for the specified key/namespace composite key, returning the state that was
+	 * found under the entry.
+	 *
+	 * @param key       the key of the mapping to remove. Not null.
+	 * @param namespace the namespace of the mapping to remove. Not null.
+	 * @return the state of the removed mapping or {@code null} if no mapping
+	 * for the specified key was found.
+	 */
+	public abstract S removeAndGetOld(K key, N namespace);
+
+	/**
+	 * Applies the given {@link StateTransformationFunction} to the state (1st input argument), using the given value as
+	 * the second input argument. The result of {@link StateTransformationFunction#apply(Object, Object)} is then stored as
+	 * the new state. This method is basically an optimization for the get-update-put pattern.
+	 *
+	 * @param key            the key. Not null.
+	 * @param namespace      the namespace. Not null.
+	 * @param value          the value to use in transforming the state. Can be null.
+	 * @param transformation the transformation function.
+	 * @throws Exception if some exception happens in the transformation function.
+	 */
+	public abstract <T> void transform(
+		K key,
+		N namespace,
+		T value,
+		StateTransformationFunction<S, T> transformation) throws Exception;
+
+	// For queryable state ------------------------------------------------------------------------
+
+	public abstract Stream<K> getKeys(N namespace);
+
+	public abstract InternalKvState.StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords);
+
+	/**
+	 * Creates a snapshot of this {@link StateMap} to be written during checkpointing. Users should call
+	 * {@link #releaseSnapshot(StateMapSnapshot)} after using the returned object.
+	 *
+	 * @return a snapshot from this {@link StateMap}, for checkpointing.
+	 */
+	@Nonnull
+	public abstract StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> stateSnapshot();
+
+	/**
+	 * Releases a snapshot for this {@link StateMap}. This method should be called once a snapshot is no longer needed.
+	 *
+	 * @param snapshotToRelease the snapshot to release, which was previously created by this state map.
+	 */
+	public void releaseSnapshot(StateMapSnapshot<K, N, S, ? extends StateMap<K, N, S>> snapshotToRelease) {
+	}
+
+	// For testing --------------------------------------------------------------------------------
+
+	@VisibleForTesting
+	public abstract int sizeOfNamespace(Object namespace);
+}
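
StateMap#transform(...) is documented as an optimization for the get-update-put pattern. A minimal sketch of that pattern with a plain HashMap and a BiFunction standing in for StateTransformationFunction; the names below are hypothetical, not Flink code:

    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.BiFunction;

    // Hypothetical sketch, not Flink code: transform(key, value, fn) behaves like
    // put(key, fn.apply(get(key), value)).
    public class TransformSketch {

        public static void main(String[] args) {
            Map<String, Long> state = new HashMap<>();
            BiFunction<Long, Long, Long> addToSum = (previous, value) ->
                previous == null ? value : previous + value;

            state.put("key", addToSum.apply(state.get("key"), 5L));
            state.put("key", addToSum.apply(state.get("key"), 7L));

            System.out.println(state.get("key")); // 12
        }
    }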
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java
new file mode 100644
index 0000000..39d6fa0
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateMapSnapshot.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+
+/**
+ * Base class for snapshots of a {@link StateMap}.
+ *
+ * @param <K> type of key
+ * @param <N> type of namespace
+ * @param <S> type of state
+ */
+public abstract class StateMapSnapshot<K, N, S, T extends StateMap<K, N, S>> {
+
+	/**
+	 * The {@link StateMap} from which this snapshot was created.
+	 */
+	protected final T owningStateMap;
+
+	public StateMapSnapshot(T stateMap) {
+		this.owningStateMap = Preconditions.checkNotNull(stateMap);
+	}
+
+	/**
+	 * Returns true iff the given state map is the owner of this snapshot object.
+	 */
+	public boolean isOwner(T stateMap) {
+		return owningStateMap == stateMap;
+	}
+
+	/**
+	 * Releases the snapshot.
+	 */
+	public void release() {
+	}
+
+	/**
+	 * Writes the state in this snapshot to the given output. The state is transformed
+	 * with the given transformer if the transformer is non-null.
+	 *
+	 * @param keySerializer the key serializer.
+	 * @param namespaceSerializer the namespace serializer.
+	 * @param stateSerializer the state serializer.
+	 * @param dov the output.
+	 * @param stateSnapshotTransformer the state snapshot transformer; can be null.
+	 * @throws IOException on write-related problems.
+	 */
+	public abstract void writeState(
+		TypeSerializer<K> keySerializer,
+		TypeSerializer<N> namespaceSerializer,
+		TypeSerializer<S> stateSerializer,
+		@Nonnull DataOutputView dov,
+		@Nullable StateSnapshotTransformer<S> stateSnapshotTransformer) throws IOException;
+}
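
A StateMapSnapshot remembers the map that created it, can be checked for ownership, and should be released once writing has finished. A minimal sketch of that lifecycle, with hypothetical names rather than Flink code:

    // Hypothetical sketch, not Flink code: snapshot ownership and release.
    public class SnapshotLifecycleSketch {

        static class MapSketch {
            SnapshotSketch snapshot() {
                return new SnapshotSketch(this);
            }
        }

        static class SnapshotSketch {
            final MapSketch owner;

            SnapshotSketch(MapSketch owner) {
                this.owner = owner;
            }

            boolean isOwner(MapSketch map) {
                return map == owner;
            }

            void release() {
                // no-op here; copy-on-write implementations would free retained versions
            }
        }

        public static void main(String[] args) {
            MapSketch map = new MapSketch();
            SnapshotSketch snapshot = map.snapshot();
            try {
                System.out.println(snapshot.isOwner(map)); // true
            } finally {
                snapshot.release();                        // always release after writing
            }
        }
    }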
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index 101829c..758ff6d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -20,7 +20,9 @@ package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateEntry;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
 import org.apache.flink.runtime.state.StateTransformationFunction;
@@ -29,7 +31,13 @@ import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
 
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Objects;
+import java.util.Spliterators;
 import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
 
 /**
  * Base class for state tables. Accesses to state are typically scoped by the currently active key, as provided
@@ -39,7 +47,8 @@ import java.util.stream.Stream;
  * @param <N> type of namespace
  * @param <S> type of state
  */
-public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
+public abstract class StateTable<K, N, S>
+	implements StateSnapshotRestore, Iterable<StateEntry<K, N, S>> {
 
 	/**
 	 * The key context view on the backend. This provides information, such as the currently active key.
@@ -57,6 +66,17 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	protected final TypeSerializer<K> keySerializer;
 
 	/**
+	 * The offset to the contiguous key groups.
+	 */
+	protected final int keyGroupOffset;
+
+	/**
+	 * Array of state maps holding the actual state objects. The array is indexed by key-group.
+	 * All array positions will be initialized with an empty state map.
+	 */
+	protected final StateMap<K, N, S>[] keyGroupedStateMaps;
+
+	/**
 	 * @param keyContext    the key context provides the key scope for all put/get/delete operations.
 	 * @param metaInfo      the meta information, including the type serializer for state copy-on-write.
 	 * @param keySerializer the serializer of the key.
@@ -68,8 +88,19 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 		this.keyContext = Preconditions.checkNotNull(keyContext);
 		this.metaInfo = Preconditions.checkNotNull(metaInfo);
 		this.keySerializer = Preconditions.checkNotNull(keySerializer);
+
+		this.keyGroupOffset = keyContext.getKeyGroupRange().getStartKeyGroup();
+
+		@SuppressWarnings("unchecked")
+		StateMap<K, N, S>[] state = (StateMap<K, N, S>[]) new StateMap[keyContext.getKeyGroupRange().getNumberOfKeyGroups()];
+		this.keyGroupedStateMaps = state;
+		for (int i = 0; i < this.keyGroupedStateMaps.length; i++) {
+			this.keyGroupedStateMaps[i] = createStateMap();
+		}
 	}
 
+	protected abstract StateMap<K, N, S> createStateMap();
+
 	// Main interface methods of StateTable -------------------------------------------------------
 
 	/**
@@ -88,7 +119,13 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 *
 	 * @return the number of entries in this {@link StateTable}.
 	 */
-	public abstract int size();
+	public int size() {
+		int count = 0;
+		for (StateMap<K, N, S> stateMap : keyGroupedStateMaps) {
+			count += stateMap.size();
+		}
+		return count;
+	}
 
 	/**
 	 * Returns the state of the mapping for the composite of active key and given namespace.
@@ -97,7 +134,9 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 * @return the states of the mapping with the specified key/namespace composite key, or {@code null}
 	 * if no mapping for the specified key is found.
 	 */
-	public abstract S get(N namespace);
+	public S get(N namespace) {
+		return get(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
+	}
 
 	/**
 	 * Returns whether this table contains a mapping for the composite of active key and given namespace.
@@ -106,27 +145,19 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 * @return {@code true} if this map contains the specified key/namespace composite key,
 	 * {@code false} otherwise.
 	 */
-	public abstract boolean containsKey(N namespace);
-
-	/**
-	 * Maps the composite of active key and given namespace to the specified state. This method should be preferred
-	 * over {@link #putAndGetOld(N, S)} (Namespace, State)} when the caller is not interested in the old state.
-	 *
-	 * @param namespace the namespace. Not null.
-	 * @param state     the state. Can be null.
-	 */
-	public abstract void put(N namespace, S state);
+	public boolean containsKey(N namespace) {
+		return containsKey(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
+	}
 
 	/**
-	 * Maps the composite of active key and given namespace to the specified state. Returns the previous state that
-	 * was registered under the composite key.
+	 * Maps the composite of active key and given namespace to the specified state.
 	 *
 	 * @param namespace the namespace. Not null.
 	 * @param state     the state. Can be null.
-	 * @return the state of any previous mapping with the specified key or
-	 * {@code null} if there was no such mapping.
 	 */
-	public abstract S putAndGetOld(N namespace, S state);
+	public void put(N namespace, S state) {
+		put(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace, state);
+	}
 
 	/**
 	 * Removes the mapping for the composite of active key and given namespace. This method should be preferred
@@ -134,7 +165,9 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 *
 	 * @param namespace the namespace of the mapping to remove. Not null.
 	 */
-	public abstract void remove(N namespace);
+	public void remove(N namespace) {
+		remove(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
+	}
 
 	/**
 	 * Removes the mapping for the composite of active key and given namespace, returning the state that was
@@ -144,7 +177,9 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 * @return the state of the removed mapping or {@code null} if no mapping
 	 * for the specified key was found.
 	 */
-	public abstract S removeAndGetOld(N namespace);
+	public S removeAndGetOld(N namespace) {
+		return removeAndGetOld(keyContext.getCurrentKey(), keyContext.getCurrentKeyGroupIndex(), namespace);
+	}
 
 	/**
 	 * Applies the given {@link StateTransformationFunction} to the state (1st input argument), using the given value as
@@ -156,10 +191,17 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 * @param transformation the transformation function.
 	 * @throws Exception if some exception happens in the transformation function.
 	 */
-	public abstract <T> void transform(
+	public <T> void transform(
 			N namespace,
 			T value,
-			StateTransformationFunction<S, T> transformation) throws Exception;
+			StateTransformationFunction<S, T> transformation) throws Exception {
+		K key = keyContext.getCurrentKey();
+		checkKeyNamespacePreconditions(key, namespace);
+
+		int keyGroup = keyContext.getCurrentKeyGroupIndex();
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroup);
+		stateMap.transform(key, namespace, value, transformation);
+	}
 
 	// For queryable state ------------------------------------------------------------------------
 
@@ -172,14 +214,103 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	 * @return the state of the mapping with the specified key/namespace composite key, or {@code null}
 	 * if no mapping for the specified key is found.
 	 */
-	public abstract S get(K key, N namespace);
+	public S get(K key, N namespace) {
+		int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, keyContext.getNumberOfKeyGroups());
+		return get(key, keyGroup, namespace);
+	}
+
+	public Stream<K> getKeys(N namespace) {
+		return Arrays.stream(keyGroupedStateMaps)
+			.flatMap(stateMap -> StreamSupport.stream(Spliterators.spliteratorUnknownSize(stateMap.iterator(), 0), false))
+			.filter(entry -> entry.getNamespace().equals(namespace))
+			.map(StateEntry::getKey);
+	}
+
+	public StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords) {
+		return new StateEntryIterator(recommendedMaxNumberOfReturnedRecords);
+	}
+
+	// ------------------------------------------------------------------------
+
+	private S get(K key, int keyGroupIndex, N namespace) {
+		checkKeyNamespacePreconditions(key, namespace);
 
-	public abstract Stream<K> getKeys(N namespace);
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroupIndex);
 
-	public abstract StateIncrementalVisitor<K, N, S> getStateIncrementalVisitor(int recommendedMaxNumberOfReturnedRecords);
+		if (stateMap == null) {
+			return null;
+		}
+
+		return stateMap.get(key, namespace);
+	}
+
+	private boolean containsKey(K key, int keyGroupIndex, N namespace) {
+		checkKeyNamespacePreconditions(key, namespace);
+
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroupIndex);
+
+		return stateMap != null && stateMap.containsKey(key, namespace);
+	}
+
+	private void checkKeyNamespacePreconditions(K key, N namespace) {
+		Preconditions.checkNotNull(key, "No key set. This method should not be called outside of a keyed context.");
+		Preconditions.checkNotNull(namespace, "Provided namespace is null.");
+	}
+
+	private void remove(K key, int keyGroupIndex, N namespace) {
+		checkKeyNamespacePreconditions(key, namespace);
+
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroupIndex);
+		stateMap.remove(key, namespace);
+	}
+
+	private S removeAndGetOld(K key, int keyGroupIndex, N namespace) {
+		checkKeyNamespacePreconditions(key, namespace);
+
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroupIndex);
+
+		return stateMap.removeAndGetOld(key, namespace);
+	}
+
+	// ------------------------------------------------------------------------
+	//  access to maps
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Returns the internal data structure.
+	 */
+	@VisibleForTesting
+	public StateMap<K, N, S>[] getState() {
+		return keyGroupedStateMaps;
+	}
+
+	public int getKeyGroupOffset() {
+		return keyGroupOffset;
+	}
+
+	@VisibleForTesting
+	protected StateMap<K, N, S> getMapForKeyGroup(int keyGroupIndex) {
+		final int pos = indexToOffset(keyGroupIndex);
+		if (pos >= 0 && pos < keyGroupedStateMaps.length) {
+			return keyGroupedStateMaps[pos];
+		} else {
+			return null;
+		}
+	}
+
+	/**
+	 * Translates a key-group id to the internal array offset.
+	 */
+	private int indexToOffset(int index) {
+		return index - keyGroupOffset;
+	}
 
 	// Meta data setter / getter and toString -----------------------------------------------------
 
+	public TypeSerializer<K> getKeySerializer() {
+		return keySerializer;
+	}
+
 	public TypeSerializer<S> getStateSerializer() {
 		return metaInfo.getStateSerializer();
 	}
@@ -198,16 +329,100 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 
 	// Snapshot / Restore -------------------------------------------------------------------------
 
-	public abstract void put(K key, int keyGroup, N namespace, S state);
+	public void put(K key, int keyGroup, N namespace, S state) {
+		checkKeyNamespacePreconditions(key, namespace);
+
+		StateMap<K, N, S> stateMap = getMapForKeyGroup(keyGroup);
+		stateMap.put(key, namespace, state);
+	}
+
+	@Override
+	public Iterator<StateEntry<K, N, S>> iterator() {
+		return Arrays.stream(keyGroupedStateMaps)
+			.filter(Objects::nonNull)
+			.flatMap(stateMap -> StreamSupport.stream(Spliterators.spliteratorUnknownSize(stateMap.iterator(), 0), false))
+			.iterator();
+	}
 
 	// For testing --------------------------------------------------------------------------------
 
 	@VisibleForTesting
-	public abstract int sizeOfNamespace(Object namespace);
+	public int sizeOfNamespace(Object namespace) {
+		int count = 0;
+		for (StateMap<K, N, S> stateMap : keyGroupedStateMaps) {
+			count += stateMap.sizeOfNamespace(namespace);
+		}
+
+		return count;
+	}
 
 	@Nonnull
 	@Override
 	public StateSnapshotKeyGroupReader keyGroupReader(int readVersion) {
 		return StateTableByKeyGroupReaders.readerForVersion(this, readVersion);
 	}
+
+	// StateEntryIterator  ---------------------------------------------------------------------------------------------
+
+	class StateEntryIterator implements StateIncrementalVisitor<K, N, S> {
+
+		final int recommendedMaxNumberOfReturnedRecords;
+
+		int keyGroupIndex;
+
+		StateIncrementalVisitor<K, N, S> stateIncrementalVisitor;
+
+		StateEntryIterator(int recommendedMaxNumberOfReturnedRecords) {
+			this.recommendedMaxNumberOfReturnedRecords = recommendedMaxNumberOfReturnedRecords;
+			this.keyGroupIndex = 0;
+			next();
+		}
+
+		private void next() {
+			while (keyGroupIndex < keyGroupedStateMaps.length) {
+				StateMap<K, N, S> stateMap = keyGroupedStateMaps[keyGroupIndex++];
+				StateIncrementalVisitor<K, N, S> visitor =
+					stateMap.getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);
+				if (visitor.hasNext()) {
+					stateIncrementalVisitor = visitor;
+					return;
+				}
+			}
+		}
+
+		@Override
+		public boolean hasNext() {
+			while (stateIncrementalVisitor == null || !stateIncrementalVisitor.hasNext()) {
+				if (keyGroupIndex == keyGroupedStateMaps.length) {
+					return false;
+				}
+				StateIncrementalVisitor<K, N, S> visitor =
+					keyGroupedStateMaps[keyGroupIndex++].getStateIncrementalVisitor(recommendedMaxNumberOfReturnedRecords);
+				if (visitor.hasNext()) {
+					stateIncrementalVisitor = visitor;
+					break;
+				}
+			}
+			return true;
+		}
+
+		@Override
+		public Collection<StateEntry<K, N, S>> nextEntries() {
+			if (!hasNext()) {
+				return null;
+			}
+
+			return stateIncrementalVisitor.nextEntries();
+		}
+
+		@Override
+		public void remove(StateEntry<K, N, S> stateEntry) {
+			keyGroupedStateMaps[keyGroupIndex - 1].remove(stateEntry.getKey(), stateEntry.getNamespace());
+		}
+
+		@Override
+		public void update(StateEntry<K, N, S> stateEntry, S newValue) {
+			keyGroupedStateMaps[keyGroupIndex - 1].put(stateEntry.getKey(), stateEntry.getNamespace(), newValue);
+		}
+	}
 }
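
The core of the StateTable change above is the translation from a global key-group id to a position in the per-backend array of state maps (see indexToOffset() and getMapForKeyGroup()). The following is a minimal, self-contained sketch of that layout; the class and field names are invented for illustration and only the arithmetic mirrors the patch.

    import java.util.HashMap;
    import java.util.Map;

    public class KeyGroupedLayoutSketch {

        // One map per key-group owned by this backend, as in StateTable#keyGroupedStateMaps.
        private final Map<String, String>[] maps;
        private final int keyGroupOffset;

        @SuppressWarnings("unchecked")
        KeyGroupedLayoutSketch(int startKeyGroup, int numberOfKeyGroups) {
            this.keyGroupOffset = startKeyGroup;
            this.maps = new HashMap[numberOfKeyGroups];
            for (int i = 0; i < numberOfKeyGroups; i++) {
                maps[i] = new HashMap<>(); // every position starts out with an empty map
            }
        }

        // Mirrors StateTable#indexToOffset: key-group id -> position in the local array.
        private int indexToOffset(int keyGroupIndex) {
            return keyGroupIndex - keyGroupOffset;
        }

        // Mirrors StateTable#getMapForKeyGroup: key-groups outside the owned range yield null.
        Map<String, String> mapForKeyGroup(int keyGroupIndex) {
            int pos = indexToOffset(keyGroupIndex);
            return (pos >= 0 && pos < maps.length) ? maps[pos] : null;
        }

        public static void main(String[] args) {
            // A backend responsible for key-groups 32..63: offset 32, 32 maps.
            KeyGroupedLayoutSketch table = new KeyGroupedLayoutSketch(32, 32);
            table.mapForKeyGroup(32).put("key", "state");      // lands in maps[0]
            System.out.println(table.mapForKeyGroup(32).get("key"));
            System.out.println(table.mapForKeyGroup(10));      // null: not owned here
        }
    }
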
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index 462df70..132fd01 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -69,6 +69,7 @@ import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.KvStateRegistryListener;
 import org.apache.flink.runtime.state.heap.AbstractHeapState;
 import org.apache.flink.runtime.state.heap.NestedMapsStateTable;
+import org.apache.flink.runtime.state.heap.NestedStateMap;
 import org.apache.flink.runtime.state.heap.StateTable;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
 import org.apache.flink.runtime.state.internal.InternalKvState;
@@ -3428,8 +3429,9 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 		if (stateTable instanceof NestedMapsStateTable) {
 			int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(1, numberOfKeyGroups);
 			NestedMapsStateTable<?, ?, ?> nestedMapsStateTable = (NestedMapsStateTable<?, ?, ?>) stateTable;
-			assertTrue(nestedMapsStateTable.getState()[keyGroupIndex] instanceof ConcurrentHashMap);
-			assertTrue(nestedMapsStateTable.getState()[keyGroupIndex].get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
+			NestedStateMap<?, ?, ?>[] nestedStateMaps = (NestedStateMap<?, ?, ?>[]) nestedMapsStateTable.getState();
+			assertTrue(nestedStateMaps[keyGroupIndex].getNamespaceMap() instanceof ConcurrentHashMap);
+			assertTrue(nestedStateMaps[keyGroupIndex].getNamespaceMap().get(VoidNamespace.INSTANCE) instanceof ConcurrentHashMap);
 		}
 	}
 
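
The adjusted assertion above first resolves the key-group of key 1 via KeyGroupRangeAssignment.assignToKeyGroup before indexing into the array of state maps. Conceptually this is a hash of the key reduced modulo the maximum parallelism; the sketch below is a simplification (Flink additionally spreads key.hashCode() with a murmur hash before the modulo) and the method name is made up.

    // Simplified illustration of the key -> key-group assignment, not the exact Flink code.
    static int assignToKeyGroupSketch(Object key, int maxParallelism) {
        return Math.floorMod(key.hashCode(), maxParallelism);
    }

With a fixed number of key-groups, the same key always maps to the same key-group, which is why the test can locate the NestedStateMap for key 1 deterministically.
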
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java
similarity index 55%
copy from flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
copy to flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java
index 089dff6..2709cb5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateMapTest.java
@@ -18,25 +18,18 @@
 
 package org.apache.flink.runtime.state.heap;
 
-import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
-import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
-import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.ArrayListSerializer;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateEntry;
-import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
 import org.apache.flink.util.TestLogger;
+
 import org.junit.Assert;
 import org.junit.Test;
 
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -44,25 +37,18 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Random;
 
-public class CopyOnWriteStateTableTest extends TestLogger {
-	private final TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
+/**
+ * Test for {@link CopyOnWriteStateMap}.
+ */
+public class CopyOnWriteStateMapTest extends TestLogger {
 
 	/**
 	 * Testing the basic map operations.
 	 */
 	@Test
 	public void testPutGetRemoveContainsTransform() throws Exception {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
+		final CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> stateMap =
+			new CopyOnWriteStateMap<>(new ArrayListSerializer<>(IntSerializer.INSTANCE));
 
 		ArrayList<Integer> state_1_1 = new ArrayList<>();
 		state_1_1.add(41);
@@ -71,38 +57,38 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		ArrayList<Integer> state_1_2 = new ArrayList<>();
 		state_1_2.add(43);
 
-		Assert.assertNull(stateTable.putAndGetOld(1, 1, state_1_1));
-		Assert.assertEquals(state_1_1, stateTable.get(1, 1));
-		Assert.assertEquals(1, stateTable.size());
+		Assert.assertNull(stateMap.putAndGetOld(1, 1, state_1_1));
+		Assert.assertEquals(state_1_1, stateMap.get(1, 1));
+		Assert.assertEquals(1, stateMap.size());
 
-		Assert.assertNull(stateTable.putAndGetOld(2, 1, state_2_1));
-		Assert.assertEquals(state_2_1, stateTable.get(2, 1));
-		Assert.assertEquals(2, stateTable.size());
+		Assert.assertNull(stateMap.putAndGetOld(2, 1, state_2_1));
+		Assert.assertEquals(state_2_1, stateMap.get(2, 1));
+		Assert.assertEquals(2, stateMap.size());
 
-		Assert.assertNull(stateTable.putAndGetOld(1, 2, state_1_2));
-		Assert.assertEquals(state_1_2, stateTable.get(1, 2));
-		Assert.assertEquals(3, stateTable.size());
+		Assert.assertNull(stateMap.putAndGetOld(1, 2, state_1_2));
+		Assert.assertEquals(state_1_2, stateMap.get(1, 2));
+		Assert.assertEquals(3, stateMap.size());
 
-		Assert.assertTrue(stateTable.containsKey(2, 1));
-		Assert.assertFalse(stateTable.containsKey(3, 1));
-		Assert.assertFalse(stateTable.containsKey(2, 3));
-		stateTable.put(2, 1, null);
-		Assert.assertTrue(stateTable.containsKey(2, 1));
-		Assert.assertEquals(3, stateTable.size());
-		Assert.assertNull(stateTable.get(2, 1));
-		stateTable.put(2, 1, state_2_1);
-		Assert.assertEquals(3, stateTable.size());
+		Assert.assertTrue(stateMap.containsKey(2, 1));
+		Assert.assertFalse(stateMap.containsKey(3, 1));
+		Assert.assertFalse(stateMap.containsKey(2, 3));
+		stateMap.put(2, 1, null);
+		Assert.assertTrue(stateMap.containsKey(2, 1));
+		Assert.assertEquals(3, stateMap.size());
+		Assert.assertNull(stateMap.get(2, 1));
+		stateMap.put(2, 1, state_2_1);
+		Assert.assertEquals(3, stateMap.size());
 
-		Assert.assertEquals(state_2_1, stateTable.removeAndGetOld(2, 1));
-		Assert.assertFalse(stateTable.containsKey(2, 1));
-		Assert.assertEquals(2, stateTable.size());
+		Assert.assertEquals(state_2_1, stateMap.removeAndGetOld(2, 1));
+		Assert.assertFalse(stateMap.containsKey(2, 1));
+		Assert.assertEquals(2, stateMap.size());
 
-		stateTable.remove(1, 2);
-		Assert.assertFalse(stateTable.containsKey(1, 2));
-		Assert.assertEquals(1, stateTable.size());
+		stateMap.remove(1, 2);
+		Assert.assertFalse(stateMap.containsKey(1, 2));
+		Assert.assertEquals(1, stateMap.size());
 
-		Assert.assertNull(stateTable.removeAndGetOld(4, 2));
-		Assert.assertEquals(1, stateTable.size());
+		Assert.assertNull(stateMap.removeAndGetOld(4, 2));
+		Assert.assertEquals(1, stateMap.size());
 
 		StateTransformationFunction<ArrayList<Integer>, Integer> function =
 			new StateTransformationFunction<ArrayList<Integer>, Integer>() {
@@ -114,9 +100,9 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 			};
 
 		final int value = 4711;
-		stateTable.transform(1, 1, value, function);
+		stateMap.transform(1, 1, value, function);
 		state_1_1 = function.apply(state_1_1, value);
-		Assert.assertEquals(state_1_1, stateTable.get(1, 1));
+		Assert.assertEquals(state_1_1, stateMap.get(1, 1));
 	}
 
 	/**
@@ -124,69 +110,50 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	@Test
 	public void testIncrementalRehash() {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
+		final CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> stateMap =
+			new CopyOnWriteStateMap<>(new ArrayListSerializer<>(IntSerializer.INSTANCE));
 
 		int insert = 0;
 		int remove = 0;
-		while (!stateTable.isRehashing()) {
-			stateTable.put(insert++, 0, new ArrayList<Integer>());
+		while (!stateMap.isRehashing()) {
+			stateMap.put(insert++, 0, new ArrayList<Integer>());
 			if (insert % 8 == 0) {
-				stateTable.remove(remove++, 0);
+				stateMap.remove(remove++, 0);
 			}
 		}
-		Assert.assertEquals(insert - remove, stateTable.size());
-		while (stateTable.isRehashing()) {
-			stateTable.put(insert++, 0, new ArrayList<Integer>());
+		Assert.assertEquals(insert - remove, stateMap.size());
+		while (stateMap.isRehashing()) {
+			stateMap.put(insert++, 0, new ArrayList<Integer>());
 			if (insert % 8 == 0) {
-				stateTable.remove(remove++, 0);
+				stateMap.remove(remove++, 0);
 			}
 		}
-		Assert.assertEquals(insert - remove, stateTable.size());
+		Assert.assertEquals(insert - remove, stateMap.size());
 
 		for (int i = 0; i < insert; ++i) {
 			if (i < remove) {
-				Assert.assertFalse(stateTable.containsKey(i, 0));
+				Assert.assertFalse(stateMap.containsKey(i, 0));
 			} else {
-				Assert.assertTrue(stateTable.containsKey(i, 0));
+				Assert.assertTrue(stateMap.containsKey(i, 0));
 			}
 		}
 	}
 
 	/**
-	 * This test does some random modifications to a state table and a reference (hash map). Then draws snapshots,
+	 * This test does some random modifications to a state map and a reference (hash map). Then draws snapshots,
 	 * performs more modifications and checks snapshot integrity.
 	 */
 	@Test
 	public void testRandomModificationsAndCopyOnWriteIsolation() throws Exception {
-
-		final RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
+		final CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> stateMap =
+			new CopyOnWriteStateMap<>(new ArrayListSerializer<>(IntSerializer.INSTANCE));
 
 		final HashMap<Tuple2<Integer, Integer>, ArrayList<Integer>> referenceMap = new HashMap<>();
 
 		final Random random = new Random(42);
 
 		// holds snapshots from the map under test
-		CopyOnWriteStateTable.StateTableEntry<Integer, Integer, ArrayList<Integer>>[] snapshot = null;
+		CopyOnWriteStateMap.StateMapEntry<Integer, Integer, ArrayList<Integer>>[] snapshot = null;
 		int snapshotSize = 0;
 
 		// holds a reference snapshot from our reference map that we compare against
@@ -194,7 +161,6 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 
 		int val = 0;
 
-
 		int snapshotCounter = 0;
 		int referencedSnapshotId = 0;
 
@@ -212,7 +178,7 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 			};
 
 		StateIncrementalVisitor<Integer, Integer, ArrayList<Integer>> updatingIterator =
-			stateTable.getStateIncrementalVisitor(5);
+			stateMap.getStateIncrementalVisitor(5);
 
 		// the main loop for modifications
 		for (int i = 0; i < 10_000_000; ++i) {
@@ -229,39 +195,39 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 			switch (op) {
 				case 0:
 				case 1: {
-					state = stateTable.get(key, namespace);
+					state = stateMap.get(key, namespace);
 					referenceState = referenceMap.get(compositeKey);
 					if (null == state) {
 						state = new ArrayList<>();
-						stateTable.put(key, namespace, state);
+						stateMap.put(key, namespace, state);
 						referenceState = new ArrayList<>();
 						referenceMap.put(compositeKey, referenceState);
 					}
 					break;
 				}
 				case 2: {
-					stateTable.put(key, namespace, new ArrayList<Integer>());
+					stateMap.put(key, namespace, new ArrayList<Integer>());
 					referenceMap.put(compositeKey, new ArrayList<Integer>());
 					break;
 				}
 				case 3: {
-					state = stateTable.putAndGetOld(key, namespace, new ArrayList<Integer>());
+					state = stateMap.putAndGetOld(key, namespace, new ArrayList<Integer>());
 					referenceState = referenceMap.put(compositeKey, new ArrayList<Integer>());
 					break;
 				}
 				case 4: {
-					stateTable.remove(key, namespace);
+					stateMap.remove(key, namespace);
 					referenceMap.remove(compositeKey);
 					break;
 				}
 				case 5: {
-					state = stateTable.removeAndGetOld(key, namespace);
+					state = stateMap.removeAndGetOld(key, namespace);
 					referenceState = referenceMap.remove(compositeKey);
 					break;
 				}
 				case 6: {
 					final int updateValue = random.nextInt(1000);
-					stateTable.transform(key, namespace, updateValue, transformationFunction);
+					stateMap.transform(key, namespace, updateValue, transformationFunction);
 					referenceMap.put(compositeKey, transformationFunction.apply(
 						referenceMap.remove(compositeKey), updateValue));
 					break;
@@ -270,20 +236,20 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 				case 8:
 				case 9:
 					if (!updatingIterator.hasNext()) {
-						updatingIterator = stateTable.getStateIncrementalVisitor(5);
+						updatingIterator = stateMap.getStateIncrementalVisitor(5);
 						if (!updatingIterator.hasNext()) {
 							break;
 						}
 					}
 					testStateIteratorWithUpdate(
-						updatingIterator, stateTable, referenceMap, op == 8, op == 9);
+						updatingIterator, stateMap, referenceMap, op == 8, op == 9);
 					break;
 				default: {
 					Assert.fail("Unknown op-code " + op);
 				}
 			}
 
-			Assert.assertEquals(referenceMap.size(), stateTable.size());
+			Assert.assertEquals(referenceMap.size(), stateMap.size());
 
 			if (state != null) {
 				// mutate the states a bit...
@@ -309,8 +275,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 					if (i % 1_000 == 0) {
 						// draw and release some other snapshot while holding on the old snapshot
 						++snapshotCounter;
-						stateTable.snapshotTableArrays();
-						stateTable.releaseSnapshot(snapshotCounter);
+						stateMap.snapshotMapArrays();
+						stateMap.releaseSnapshot(snapshotCounter);
 					}
 
 					//release the snapshot after some time
@@ -318,15 +284,15 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 						snapshot = null;
 						reference = null;
 						snapshotSize = 0;
-						stateTable.releaseSnapshot(referencedSnapshotId);
+						stateMap.releaseSnapshot(referencedSnapshotId);
 					}
 
 				} else {
 					// if there is no more referenced snapshot, we create one
 					++snapshotCounter;
 					referencedSnapshotId = snapshotCounter;
-					snapshot = stateTable.snapshotTableArrays();
-					snapshotSize = stateTable.size();
+					snapshot = stateMap.snapshotMapArrays();
+					snapshotSize = stateMap.size();
 					reference = manualDeepDump(referenceMap);
 				}
 			}
@@ -340,7 +306,7 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	private static void testStateIteratorWithUpdate(
 		StateIncrementalVisitor<Integer, Integer, ArrayList<Integer>> updatingIterator,
-		CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable,
+		CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> stateMap,
 		HashMap<Tuple2<Integer, Integer>, ArrayList<Integer>> referenceMap,
 		boolean update, boolean remove) {
 
@@ -357,7 +323,7 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 				}
 				updatingIterator.update(stateEntry, newState);
 				referenceMap.put(compositeKey, new ArrayList<>(newState));
-				Assert.assertEquals(newState, stateTable.get(key, namespace));
+				Assert.assertEquals(newState, stateMap.get(key, namespace));
 			}
 
 			if (remove) {
@@ -373,17 +339,8 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 	 */
 	@Test
 	public void testCopyOnWriteContracts() {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
+		final CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> stateMap =
+			new CopyOnWriteStateMap<>(new ArrayListSerializer<>(IntSerializer.INSTANCE));
 
 		ArrayList<Integer> originalState1 = new ArrayList<>(1);
 		ArrayList<Integer> originalState2 = new ArrayList<>(1);
@@ -397,96 +354,52 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		originalState4.add(4);
 		originalState5.add(5);
 
-		stateTable.put(1, 1, originalState1);
-		stateTable.put(2, 1, originalState2);
-		stateTable.put(4, 1, originalState4);
-		stateTable.put(5, 1, originalState5);
+		stateMap.put(1, 1, originalState1);
+		stateMap.put(2, 1, originalState2);
+		stateMap.put(4, 1, originalState4);
+		stateMap.put(5, 1, originalState5);
 
 		// no snapshot taken, we get the original back
-		Assert.assertTrue(stateTable.get(1, 1) == originalState1);
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot1 = stateTable.stateSnapshot();
+		Assert.assertTrue(stateMap.get(1, 1) == originalState1);
+		CopyOnWriteStateMapSnapshot<Integer, Integer, ArrayList<Integer>> snapshot1 = stateMap.stateSnapshot();
 		// after snapshot1 is taken, we get a copy...
-		final ArrayList<Integer> copyState = stateTable.get(1, 1);
+		final ArrayList<Integer> copyState = stateMap.get(1, 1);
 		Assert.assertFalse(copyState == originalState1);
 		// ...and the copy is equal
 		Assert.assertEquals(originalState1, copyState);
 
 		// we make an insert AFTER snapshot1
-		stateTable.put(3, 1, originalState3);
+		stateMap.put(3, 1, originalState3);
 
 		// on repeated lookups, we get the same copy because no further snapshot was taken
-		Assert.assertTrue(copyState == stateTable.get(1, 1));
+		Assert.assertTrue(copyState == stateMap.get(1, 1));
 
 		// we take snapshot2
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot2 = stateTable.stateSnapshot();
+		CopyOnWriteStateMapSnapshot<Integer, Integer, ArrayList<Integer>> snapshot2 = stateMap.stateSnapshot();
 		// after the second snapshot, copy-on-write is active again for old entries
-		Assert.assertFalse(copyState == stateTable.get(1, 1));
+		Assert.assertFalse(copyState == stateMap.get(1, 1));
 		// and equality still holds
-		Assert.assertEquals(copyState, stateTable.get(1, 1));
+		Assert.assertEquals(copyState, stateMap.get(1, 1));
 
 		// after releasing snapshot2
-		stateTable.releaseSnapshot(snapshot2);
+		stateMap.releaseSnapshot(snapshot2);
 		// we still get the original of the untouched late insert (after snapshot1)
-		Assert.assertTrue(originalState3 == stateTable.get(3, 1));
+		Assert.assertTrue(originalState3 == stateMap.get(3, 1));
 		// but copy-on-write is still active for older inserts (before snapshot1)
-		Assert.assertFalse(originalState4 == stateTable.get(4, 1));
+		Assert.assertFalse(originalState4 == stateMap.get(4, 1));
 
 		// after releasing snapshot1
-		stateTable.releaseSnapshot(snapshot1);
+		stateMap.releaseSnapshot(snapshot1);
 		// no copy-on-write is active
-		Assert.assertTrue(originalState5 == stateTable.get(5, 1));
-	}
-
-	/**
-	 * This tests that serializers used for snapshots are duplicates of the ones used in
-	 * processing to avoid race conditions in stateful serializers.
-	 */
-	@Test
-	public void testSerializerDuplicationInSnapshot() throws IOException {
-
-		final TestDuplicateSerializer namespaceSerializer = new TestDuplicateSerializer();
-		final TestDuplicateSerializer stateSerializer = new TestDuplicateSerializer();
-		final TestDuplicateSerializer keySerializer = new TestDuplicateSerializer();
-
-		RegisteredKeyValueStateBackendMetaInfo<Integer, Integer> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.VALUE,
-				"test",
-				namespaceSerializer,
-				stateSerializer);
-
-		InternalKeyContext<Integer> mockKeyContext = new MockInternalKeyContext<>();
-		CopyOnWriteStateTable<Integer, Integer, Integer> table =
-			new CopyOnWriteStateTable<>(mockKeyContext, metaInfo, keySerializer);
-
-		table.put(0, 0, 0, 0);
-		table.put(1, 0, 0, 1);
-		table.put(2, 0, 1, 2);
-
-
-		final CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.stateSnapshot();
-
-		try {
-			final StateSnapshot.StateKeyGroupWriter partitionedSnapshot = snapshot.getKeyGroupWriter();
-			namespaceSerializer.disable();
-			keySerializer.disable();
-			stateSerializer.disable();
-
-			partitionedSnapshot.writeStateInKeyGroup(
-				new DataOutputViewStreamWrapper(
-					new ByteArrayOutputStreamWithPos(1024)), 0);
-
-		} finally {
-			table.releaseSnapshot(snapshot);
-		}
+		Assert.assertTrue(originalState5 == stateMap.get(5, 1));
 	}
 
 	@SuppressWarnings("unchecked")
-	private static <K, N, S> Tuple3<K, N, S>[] convert(CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshot, int mapSize) {
+	private static <K, N, S> Tuple3<K, N, S>[] convert(CopyOnWriteStateMap.StateMapEntry<K, N, S>[] snapshot, int mapSize) {
 
 		Tuple3<K, N, S>[] result = new Tuple3[mapSize];
 		int pos = 0;
-		for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : snapshot) {
+		for (CopyOnWriteStateMap.StateMapEntry<K, N, S> entry : snapshot) {
 			while (null != entry) {
 				result[pos++] = new Tuple3<>(entry.getKey(), entry.getNamespace(), entry.getState());
 				entry = entry.next;
@@ -543,11 +456,4 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 			Assert.assertEquals(av.f2, bv.f2);
 		}
 	}
-
-	static class MockInternalKeyContext<T> extends InternalKeyContextImpl<T> {
-		MockInternalKeyContext() {
-			super(new KeyGroupRange(0,0),1);
-		}
-	}
-
 }
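
Condensed, the copy-on-write contract that testCopyOnWriteContracts exercises above looks as follows. This is a sketch only; it reuses the constructor and methods already shown in the test and introduces nothing new.

    CopyOnWriteStateMap<Integer, Integer, ArrayList<Integer>> map =
        new CopyOnWriteStateMap<>(new ArrayListSerializer<>(IntSerializer.INSTANCE));

    ArrayList<Integer> original = new ArrayList<>();
    original.add(1);
    map.put(1, 1, original);

    // Without a live snapshot, lookups return the original object.
    assert map.get(1, 1) == original;

    CopyOnWriteStateMapSnapshot<Integer, Integer, ArrayList<Integer>> snapshot = map.stateSnapshot();
    try {
        // While the snapshot is live, lookups hand out a copy so the snapshot stays isolated.
        assert map.get(1, 1) != original;
    } finally {
        // Releasing the snapshot ends copy-on-write for the entries it covered.
        map.releaseSnapshot(snapshot);
    }
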
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
index 089dff6..0a3d272 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
@@ -19,423 +19,19 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
-import org.apache.flink.runtime.state.ArrayListSerializer;
-import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
-import org.apache.flink.runtime.state.StateEntry;
 import org.apache.flink.runtime.state.StateSnapshot;
-import org.apache.flink.runtime.state.StateTransformationFunction;
-import org.apache.flink.runtime.state.internal.InternalKvState.StateIncrementalVisitor;
-import org.apache.flink.util.TestLogger;
-import org.junit.Assert;
+
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Random;
-
-public class CopyOnWriteStateTableTest extends TestLogger {
-	private final TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
-
-	/**
-	 * Testing the basic map operations.
-	 */
-	@Test
-	public void testPutGetRemoveContainsTransform() throws Exception {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
-
-		ArrayList<Integer> state_1_1 = new ArrayList<>();
-		state_1_1.add(41);
-		ArrayList<Integer> state_2_1 = new ArrayList<>();
-		state_2_1.add(42);
-		ArrayList<Integer> state_1_2 = new ArrayList<>();
-		state_1_2.add(43);
-
-		Assert.assertNull(stateTable.putAndGetOld(1, 1, state_1_1));
-		Assert.assertEquals(state_1_1, stateTable.get(1, 1));
-		Assert.assertEquals(1, stateTable.size());
-
-		Assert.assertNull(stateTable.putAndGetOld(2, 1, state_2_1));
-		Assert.assertEquals(state_2_1, stateTable.get(2, 1));
-		Assert.assertEquals(2, stateTable.size());
-
-		Assert.assertNull(stateTable.putAndGetOld(1, 2, state_1_2));
-		Assert.assertEquals(state_1_2, stateTable.get(1, 2));
-		Assert.assertEquals(3, stateTable.size());
-
-		Assert.assertTrue(stateTable.containsKey(2, 1));
-		Assert.assertFalse(stateTable.containsKey(3, 1));
-		Assert.assertFalse(stateTable.containsKey(2, 3));
-		stateTable.put(2, 1, null);
-		Assert.assertTrue(stateTable.containsKey(2, 1));
-		Assert.assertEquals(3, stateTable.size());
-		Assert.assertNull(stateTable.get(2, 1));
-		stateTable.put(2, 1, state_2_1);
-		Assert.assertEquals(3, stateTable.size());
-
-		Assert.assertEquals(state_2_1, stateTable.removeAndGetOld(2, 1));
-		Assert.assertFalse(stateTable.containsKey(2, 1));
-		Assert.assertEquals(2, stateTable.size());
-
-		stateTable.remove(1, 2);
-		Assert.assertFalse(stateTable.containsKey(1, 2));
-		Assert.assertEquals(1, stateTable.size());
-
-		Assert.assertNull(stateTable.removeAndGetOld(4, 2));
-		Assert.assertEquals(1, stateTable.size());
-
-		StateTransformationFunction<ArrayList<Integer>, Integer> function =
-			new StateTransformationFunction<ArrayList<Integer>, Integer>() {
-				@Override
-				public ArrayList<Integer> apply(ArrayList<Integer> previousState, Integer value) throws Exception {
-					previousState.add(value);
-					return previousState;
-				}
-			};
-
-		final int value = 4711;
-		stateTable.transform(1, 1, value, function);
-		state_1_1 = function.apply(state_1_1, value);
-		Assert.assertEquals(state_1_1, stateTable.get(1, 1));
-	}
-
-	/**
-	 * This test triggers incremental rehash and tests for corruptions.
-	 */
-	@Test
-	public void testIncrementalRehash() {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
-
-		int insert = 0;
-		int remove = 0;
-		while (!stateTable.isRehashing()) {
-			stateTable.put(insert++, 0, new ArrayList<Integer>());
-			if (insert % 8 == 0) {
-				stateTable.remove(remove++, 0);
-			}
-		}
-		Assert.assertEquals(insert - remove, stateTable.size());
-		while (stateTable.isRehashing()) {
-			stateTable.put(insert++, 0, new ArrayList<Integer>());
-			if (insert % 8 == 0) {
-				stateTable.remove(remove++, 0);
-			}
-		}
-		Assert.assertEquals(insert - remove, stateTable.size());
-
-		for (int i = 0; i < insert; ++i) {
-			if (i < remove) {
-				Assert.assertFalse(stateTable.containsKey(i, 0));
-			} else {
-				Assert.assertTrue(stateTable.containsKey(i, 0));
-			}
-		}
-	}
-
-	/**
-	 * This test does some random modifications to a state table and a reference (hash map). Then draws snapshots,
-	 * performs more modifications and checks snapshot integrity.
-	 */
-	@Test
-	public void testRandomModificationsAndCopyOnWriteIsolation() throws Exception {
-
-		final RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
-
-		final HashMap<Tuple2<Integer, Integer>, ArrayList<Integer>> referenceMap = new HashMap<>();
-
-		final Random random = new Random(42);
-
-		// holds snapshots from the map under test
-		CopyOnWriteStateTable.StateTableEntry<Integer, Integer, ArrayList<Integer>>[] snapshot = null;
-		int snapshotSize = 0;
-
-		// holds a reference snapshot from our reference map that we compare against
-		Tuple3<Integer, Integer, ArrayList<Integer>>[] reference = null;
-
-		int val = 0;
-
-
-		int snapshotCounter = 0;
-		int referencedSnapshotId = 0;
-
-		final StateTransformationFunction<ArrayList<Integer>, Integer> transformationFunction =
-			new StateTransformationFunction<ArrayList<Integer>, Integer>() {
-				@Override
-				public ArrayList<Integer> apply(ArrayList<Integer> previousState, Integer value) throws Exception {
-					if (previousState == null) {
-						previousState = new ArrayList<>();
-					}
-					previousState.add(value);
-					// we give back the original, attempting to spot errors in to copy-on-write
-					return previousState;
-				}
-			};
-
-		StateIncrementalVisitor<Integer, Integer, ArrayList<Integer>> updatingIterator =
-			stateTable.getStateIncrementalVisitor(5);
-
-		// the main loop for modifications
-		for (int i = 0; i < 10_000_000; ++i) {
-
-			int key = random.nextInt(20);
-			int namespace = random.nextInt(4);
-			Tuple2<Integer, Integer> compositeKey = new Tuple2<>(key, namespace);
-
-			int op = random.nextInt(10);
-
-			ArrayList<Integer> state = null;
-			ArrayList<Integer> referenceState = null;
 
-			switch (op) {
-				case 0:
-				case 1: {
-					state = stateTable.get(key, namespace);
-					referenceState = referenceMap.get(compositeKey);
-					if (null == state) {
-						state = new ArrayList<>();
-						stateTable.put(key, namespace, state);
-						referenceState = new ArrayList<>();
-						referenceMap.put(compositeKey, referenceState);
-					}
-					break;
-				}
-				case 2: {
-					stateTable.put(key, namespace, new ArrayList<Integer>());
-					referenceMap.put(compositeKey, new ArrayList<Integer>());
-					break;
-				}
-				case 3: {
-					state = stateTable.putAndGetOld(key, namespace, new ArrayList<Integer>());
-					referenceState = referenceMap.put(compositeKey, new ArrayList<Integer>());
-					break;
-				}
-				case 4: {
-					stateTable.remove(key, namespace);
-					referenceMap.remove(compositeKey);
-					break;
-				}
-				case 5: {
-					state = stateTable.removeAndGetOld(key, namespace);
-					referenceState = referenceMap.remove(compositeKey);
-					break;
-				}
-				case 6: {
-					final int updateValue = random.nextInt(1000);
-					stateTable.transform(key, namespace, updateValue, transformationFunction);
-					referenceMap.put(compositeKey, transformationFunction.apply(
-						referenceMap.remove(compositeKey), updateValue));
-					break;
-				}
-				case 7:
-				case 8:
-				case 9:
-					if (!updatingIterator.hasNext()) {
-						updatingIterator = stateTable.getStateIncrementalVisitor(5);
-						if (!updatingIterator.hasNext()) {
-							break;
-						}
-					}
-					testStateIteratorWithUpdate(
-						updatingIterator, stateTable, referenceMap, op == 8, op == 9);
-					break;
-				default: {
-					Assert.fail("Unknown op-code " + op);
-				}
-			}
-
-			Assert.assertEquals(referenceMap.size(), stateTable.size());
-
-			if (state != null) {
-				// mutate the states a bit...
-				if (random.nextBoolean() && !state.isEmpty()) {
-					state.remove(state.size() - 1);
-					referenceState.remove(referenceState.size() - 1);
-				} else {
-					state.add(val);
-					referenceState.add(val);
-					++val;
-				}
-			}
-
-			Assert.assertEquals(referenceState, state);
-
-			// snapshot triggering / comparison / release
-			if (i > 0 && i % 500 == 0) {
-
-				if (snapshot != null) {
-					// check our referenced snapshot
-					deepCheck(reference, convert(snapshot, snapshotSize));
-
-					if (i % 1_000 == 0) {
-						// draw and release some other snapshot while holding on the old snapshot
-						++snapshotCounter;
-						stateTable.snapshotTableArrays();
-						stateTable.releaseSnapshot(snapshotCounter);
-					}
-
-					//release the snapshot after some time
-					if (i % 5_000 == 0) {
-						snapshot = null;
-						reference = null;
-						snapshotSize = 0;
-						stateTable.releaseSnapshot(referencedSnapshotId);
-					}
-
-				} else {
-					// if there is no more referenced snapshot, we create one
-					++snapshotCounter;
-					referencedSnapshotId = snapshotCounter;
-					snapshot = stateTable.snapshotTableArrays();
-					snapshotSize = stateTable.size();
-					reference = manualDeepDump(referenceMap);
-				}
-			}
-		}
-	}
-
-	/**
-	 * Test operations specific for StateIncrementalVisitor in {@code testRandomModificationsAndCopyOnWriteIsolation()}.
-	 *
-	 * <p>Check next, update and remove during global iteration of StateIncrementalVisitor.
-	 */
-	private static void testStateIteratorWithUpdate(
-		StateIncrementalVisitor<Integer, Integer, ArrayList<Integer>> updatingIterator,
-		CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable,
-		HashMap<Tuple2<Integer, Integer>, ArrayList<Integer>> referenceMap,
-		boolean update, boolean remove) {
-
-		for (StateEntry<Integer, Integer, ArrayList<Integer>> stateEntry : updatingIterator.nextEntries()) {
-			Integer key = stateEntry.getKey();
-			Integer namespace = stateEntry.getNamespace();
-			Tuple2<Integer, Integer> compositeKey = new Tuple2<>(key, namespace);
-			Assert.assertEquals(referenceMap.get(compositeKey), stateEntry.getState());
-
-			if (update) {
-				ArrayList<Integer> newState = new ArrayList<>(stateEntry.getState());
-				if (!newState.isEmpty()) {
-					newState.remove(0);
-				}
-				updatingIterator.update(stateEntry, newState);
-				referenceMap.put(compositeKey, new ArrayList<>(newState));
-				Assert.assertEquals(newState, stateTable.get(key, namespace));
-			}
-
-			if (remove) {
-				updatingIterator.remove(stateEntry);
-				referenceMap.remove(compositeKey);
-			}
-		}
-	}
-
-	/**
-	 * This tests for the copy-on-write contracts, e.g. ensures that no copy-on-write is active after all snapshots are
-	 * released.
-	 */
-	@Test
-	public void testCopyOnWriteContracts() {
-		RegisteredKeyValueStateBackendMetaInfo<Integer, ArrayList<Integer>> metaInfo =
-			new RegisteredKeyValueStateBackendMetaInfo<>(
-				StateDescriptor.Type.UNKNOWN,
-				"test",
-				IntSerializer.INSTANCE,
-				new ArrayListSerializer<>(IntSerializer.INSTANCE)); // we use mutable state objects.
-
-		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
-
-		final CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> stateTable =
-			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
-
-		ArrayList<Integer> originalState1 = new ArrayList<>(1);
-		ArrayList<Integer> originalState2 = new ArrayList<>(1);
-		ArrayList<Integer> originalState3 = new ArrayList<>(1);
-		ArrayList<Integer> originalState4 = new ArrayList<>(1);
-		ArrayList<Integer> originalState5 = new ArrayList<>(1);
-
-		originalState1.add(1);
-		originalState2.add(2);
-		originalState3.add(3);
-		originalState4.add(4);
-		originalState5.add(5);
-
-		stateTable.put(1, 1, originalState1);
-		stateTable.put(2, 1, originalState2);
-		stateTable.put(4, 1, originalState4);
-		stateTable.put(5, 1, originalState5);
-
-		// no snapshot taken, we get the original back
-		Assert.assertTrue(stateTable.get(1, 1) == originalState1);
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot1 = stateTable.stateSnapshot();
-		// after snapshot1 is taken, we get a copy...
-		final ArrayList<Integer> copyState = stateTable.get(1, 1);
-		Assert.assertFalse(copyState == originalState1);
-		// ...and the copy is equal
-		Assert.assertEquals(originalState1, copyState);
-
-		// we make an insert AFTER snapshot1
-		stateTable.put(3, 1, originalState3);
-
-		// on repeated lookups, we get the same copy because no further snapshot was taken
-		Assert.assertTrue(copyState == stateTable.get(1, 1));
-
-		// we take snapshot2
-		CopyOnWriteStateTableSnapshot<Integer, Integer, ArrayList<Integer>> snapshot2 = stateTable.stateSnapshot();
-		// after the second snapshot, copy-on-write is active again for old entries
-		Assert.assertFalse(copyState == stateTable.get(1, 1));
-		// and equality still holds
-		Assert.assertEquals(copyState, stateTable.get(1, 1));
-
-		// after releasing snapshot2
-		stateTable.releaseSnapshot(snapshot2);
-		// we still get the original of the untouched late insert (after snapshot1)
-		Assert.assertTrue(originalState3 == stateTable.get(3, 1));
-		// but copy-on-write is still active for older inserts (before snapshot1)
-		Assert.assertFalse(originalState4 == stateTable.get(4, 1));
-
-		// after releasing snapshot1
-		stateTable.releaseSnapshot(snapshot1);
-		// no copy-on-write is active
-		Assert.assertTrue(originalState5 == stateTable.get(5, 1));
-	}
+/**
+ * Test for {@link CopyOnWriteStateTable}.
+ */
+public class CopyOnWriteStateTableTest {
 
 	/**
 	 * This tests that serializers used for snapshots are duplicates of the ones used in
@@ -466,88 +62,13 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 
 		final CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.stateSnapshot();
 
-		try {
-			final StateSnapshot.StateKeyGroupWriter partitionedSnapshot = snapshot.getKeyGroupWriter();
-			namespaceSerializer.disable();
-			keySerializer.disable();
-			stateSerializer.disable();
-
-			partitionedSnapshot.writeStateInKeyGroup(
-				new DataOutputViewStreamWrapper(
-					new ByteArrayOutputStreamWithPos(1024)), 0);
+		final StateSnapshot.StateKeyGroupWriter partitionedSnapshot = snapshot.getKeyGroupWriter();
+		namespaceSerializer.disable();
+		keySerializer.disable();
+		stateSerializer.disable();
 
-		} finally {
-			table.releaseSnapshot(snapshot);
-		}
+		partitionedSnapshot.writeStateInKeyGroup(
+			new DataOutputViewStreamWrapper(
+				new ByteArrayOutputStreamWithPos(1024)), 0);
 	}
-
-	@SuppressWarnings("unchecked")
-	private static <K, N, S> Tuple3<K, N, S>[] convert(CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshot, int mapSize) {
-
-		Tuple3<K, N, S>[] result = new Tuple3[mapSize];
-		int pos = 0;
-		for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : snapshot) {
-			while (null != entry) {
-				result[pos++] = new Tuple3<>(entry.getKey(), entry.getNamespace(), entry.getState());
-				entry = entry.next;
-			}
-		}
-		Assert.assertEquals(mapSize, pos);
-		return result;
-	}
-
-	@SuppressWarnings("unchecked")
-	private Tuple3<Integer, Integer, ArrayList<Integer>>[] manualDeepDump(
-		HashMap<Tuple2<Integer, Integer>,
-			ArrayList<Integer>> map) {
-
-		Tuple3<Integer, Integer, ArrayList<Integer>>[] result = new Tuple3[map.size()];
-		int pos = 0;
-		for (Map.Entry<Tuple2<Integer, Integer>, ArrayList<Integer>> entry : map.entrySet()) {
-			Integer key = entry.getKey().f0;
-			Integer namespace = entry.getKey().f1;
-			result[pos++] = new Tuple3<>(key, namespace, new ArrayList<>(entry.getValue()));
-		}
-		return result;
-	}
-
-	private void deepCheck(
-		Tuple3<Integer, Integer, ArrayList<Integer>>[] a,
-		Tuple3<Integer, Integer, ArrayList<Integer>>[] b) {
-
-		if (a == b) {
-			return;
-		}
-
-		Assert.assertEquals(a.length, b.length);
-
-		Comparator<Tuple3<Integer, Integer, ArrayList<Integer>>> comparator =
-			new Comparator<Tuple3<Integer, Integer, ArrayList<Integer>>>() {
-
-				@Override
-				public int compare(Tuple3<Integer, Integer, ArrayList<Integer>> o1, Tuple3<Integer, Integer, ArrayList<Integer>> o2) {
-					int namespaceDiff = o1.f1 - o2.f1;
-					return namespaceDiff != 0 ? namespaceDiff : o1.f0 - o2.f0;
-				}
-			};
-
-		Arrays.sort(a, comparator);
-		Arrays.sort(b, comparator);
-
-		for (int i = 0; i < a.length; ++i) {
-			Tuple3<Integer, Integer, ArrayList<Integer>> av = a[i];
-			Tuple3<Integer, Integer, ArrayList<Integer>> bv = b[i];
-
-			Assert.assertEquals(av.f0, bv.f0);
-			Assert.assertEquals(av.f1, bv.f1);
-			Assert.assertEquals(av.f2, bv.f2);
-		}
-	}
-
-	static class MockInternalKeyContext<T> extends InternalKeyContextImpl<T> {
-		MockInternalKeyContext() {
-			super(new KeyGroupRange(0,0),1);
-		}
-	}
-
 }
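
For context on the remaining test: snapshot writing must never share a possibly stateful serializer with the processing thread, which is why it operates on duplicates. A minimal sketch of the guarded pattern, where serializer stands for any registered TypeSerializer:

    // The processing path keeps using the registered instance...
    TypeSerializer<Integer> forProcessing = serializer;
    // ...while the (potentially asynchronous) snapshot writes with an independent duplicate,
    // so a stateful serializer is never accessed from two threads at once.
    TypeSerializer<Integer> forSnapshot = serializer.duplicate();
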
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/MockInternalKeyContext.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/MockInternalKeyContext.java
new file mode 100644
index 0000000..865ebe2
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/MockInternalKeyContext.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.KeyGroupRange;
+
+/**
+ * Mock {@link InternalKeyContext}.
+ */
+public class MockInternalKeyContext<K> extends InternalKeyContextImpl<K> {
+	MockInternalKeyContext() {
+		super(new KeyGroupRange(0, 0), 1);
+	}
+
+	@Override
+	public void setCurrentKey(K key) {
+		super.setCurrentKey(key);
+		super.setCurrentKeyGroupIndex(0);
+	}
+}
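
The mock pins every key to key-group 0, matching its single-group KeyGroupRange(0, 0). A short sketch of how a table test is expected to drive it (mirroring the compatibility test below; metaInfo is the RegisteredKeyValueStateBackendMetaInfo from that test and is assumed here):

    MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
    CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> table =
        new CopyOnWriteStateTable<>(keyContext, metaInfo, IntSerializer.INSTANCE);

    // StateTable#put(N, S) and #get(N) are scoped by the context's current key and
    // key-group, so the key must be set before the namespaced accessors are used.
    keyContext.setCurrentKey(7);   // the mock also sets the current key-group index to 0
    table.put(1, new ArrayList<>());
    ArrayList<Integer> state = table.get(1);
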
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java
deleted file mode 100644
index 745719a..0000000
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state.heap;
-
-import org.apache.flink.runtime.state.KeyGroupPartitioner;
-import org.apache.flink.runtime.state.KeyGroupPartitionerTestBase;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.VoidNamespace;
-import org.apache.flink.runtime.state.heap.CopyOnWriteStateTable.StateTableEntry;
-
-import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
-
-import java.util.Random;
-import java.util.Set;
-
-/**
- * Test for {@link org.apache.flink.runtime.state.heap.CopyOnWriteStateTableSnapshot.StateTableKeyGroupPartitioner}.
- */
-public class StateTableKeyGroupPartitionerTest extends
-	KeyGroupPartitionerTestBase<StateTableEntry<Integer, VoidNamespace, Integer>> {
-
-	public StateTableKeyGroupPartitionerTest() {
-		super(random -> generateElement(random, null), StateTableEntry::getKey);
-	}
-
-	@SuppressWarnings("unchecked")
-	@Override
-	protected StateTableEntry<Integer, VoidNamespace, Integer>[] generateTestInput(
-		Random random,
-		int numElementsToGenerate,
-		Set<StateTableEntry<Integer, VoidNamespace, Integer>> allElementsIdentitySet) {
-
-		// we let the array size differ a bit from the test size to check this works
-		final int arraySize = numElementsToGenerate > 1 ? numElementsToGenerate + 5 : numElementsToGenerate;
-		final StateTableEntry<Integer, VoidNamespace, Integer>[] data = new StateTableEntry[arraySize];
-
-		while (numElementsToGenerate > 0) {
-
-			final int generateAsChainCount = Math.min(1 + random.nextInt(3) , numElementsToGenerate);
-
-			StateTableEntry<Integer, VoidNamespace, Integer> element = null;
-			for (int i = 0; i < generateAsChainCount; ++i) {
-				element = generateElement(random, element);
-				allElementsIdentitySet. add(element);
-			}
-
-			data[data.length - numElementsToGenerate + random.nextInt(generateAsChainCount)] = element;
-			numElementsToGenerate -= generateAsChainCount;
-		}
-
-		return data;
-	}
-
-	@Override
-	protected KeyGroupPartitioner<StateTableEntry<Integer, VoidNamespace, Integer>> createPartitioner(
-		StateTableEntry<Integer, VoidNamespace, Integer>[] data,
-		int numElements,
-		KeyGroupRange keyGroupRange,
-		int totalKeyGroups,
-		KeyGroupPartitioner.ElementWriterFunction<
-			StateTableEntry<Integer, VoidNamespace, Integer>> elementWriterFunction) {
-
-		return new CopyOnWriteStateTableSnapshot.StateTableKeyGroupPartitioner<>(
-			data,
-			numElements,
-			keyGroupRange,
-			totalKeyGroups,
-			elementWriterFunction);
-	}
-
-	private static StateTableEntry<Integer, VoidNamespace, Integer> generateElement(
-		@Nonnull Random random,
-		@Nullable StateTableEntry<Integer, VoidNamespace, Integer> next) {
-
-		Integer generatedKey =  random.nextInt() & Integer.MAX_VALUE;
-		return new StateTableEntry<>(
-			generatedKey,
-			VoidNamespace.INSTANCE,
-			random.nextInt(),
-			generatedKey.hashCode(),
-			next,
-			0,
-			0);
-	}
-}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
index 81d3e65..32e4977 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
@@ -40,6 +40,9 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Random;
 
+/**
+ * Test for snapshot compatibility between different state tables.
+ */
 public class StateTableSnapshotCompatibilityTest {
 	private final TypeSerializer<Integer> keySerializer = IntSerializer.INSTANCE;
 
@@ -57,8 +60,7 @@ public class StateTableSnapshotCompatibilityTest {
 				IntSerializer.INSTANCE,
 				new ArrayListSerializer<>(IntSerializer.INSTANCE));
 
-		final CopyOnWriteStateTableTest.MockInternalKeyContext<Integer> keyContext =
-			new CopyOnWriteStateTableTest.MockInternalKeyContext<>();
+		final MockInternalKeyContext<Integer> keyContext = new MockInternalKeyContext<>();
 
 		CopyOnWriteStateTable<Integer, Integer, ArrayList<Integer>> cowStateTable =
 			new CopyOnWriteStateTable<>(keyContext, metaInfo, keySerializer);
@@ -70,7 +72,8 @@ public class StateTableSnapshotCompatibilityTest {
 				list.add(r.nextInt(100));
 			}
 
-			cowStateTable.put(r.nextInt(10), r.nextInt(2), list);
+			keyContext.setCurrentKey(r.nextInt(10));
+			cowStateTable.put(r.nextInt(2), list);
 		}
 
 		StateSnapshot snapshot = cowStateTable.stateSnapshot();
@@ -81,7 +84,6 @@ public class StateTableSnapshotCompatibilityTest {
 		restoreStateTableFromSnapshot(nestedMapsStateTable, snapshot, keyContext.getKeyGroupRange());
 		snapshot.release();
 
-
 		Assert.assertEquals(cowStateTable.size(), nestedMapsStateTable.size());
 		for (StateEntry<Integer, Integer, ArrayList<Integer>> entry : cowStateTable) {
 			Assert.assertEquals(entry.getState(), nestedMapsStateTable.get(entry.getKey(), entry.getNamespace()));
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
index 5c92e67..2292516 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
@@ -25,7 +25,7 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SnapshotResult;
-import org.apache.flink.runtime.state.heap.CopyOnWriteStateTable;
+import org.apache.flink.runtime.state.heap.CopyOnWriteStateMap;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.util.StateMigrationException;
 
@@ -40,6 +40,7 @@ import java.util.List;
 import java.util.concurrent.RunnableFuture;
 import java.util.function.Consumer;
 
+import static org.apache.flink.runtime.state.ttl.StateBackendTestContext.NUMBER_OF_KEY_GROUPS;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.not;
 import static org.junit.Assert.assertEquals;
@@ -51,7 +52,7 @@ import static org.junit.Assume.assumeTrue;
 public abstract class TtlStateTestBase {
 	protected static final long TTL = 100;
 	private static final int INC_CLEANUP_ALL_KEYS =
-		(CopyOnWriteStateTable.DEFAULT_CAPACITY >> 1) + (CopyOnWriteStateTable.DEFAULT_CAPACITY >> 2) + 1;
+		((CopyOnWriteStateMap.DEFAULT_CAPACITY >> 1) + (CopyOnWriteStateMap.DEFAULT_CAPACITY >> 2) + 1) * NUMBER_OF_KEY_GROUPS;
 
 	protected MockTtlTimeProvider timeProvider;
 	protected StateBackendTestContext sbetc;
@@ -145,7 +146,7 @@ public abstract class TtlStateTestBase {
 	}
 
 	private void takeAndRestoreSnapshot() throws Exception {
-		restoreSnapshot(sbetc.takeSnapshot(), StateBackendTestContext.NUMBER_OF_KEY_GROUPS);
+		restoreSnapshot(sbetc.takeSnapshot(), NUMBER_OF_KEY_GROUPS);
 	}
 
 	protected void takeAndRestoreSnapshot(int numberOfKeyGroupsAfterRestore) throws Exception {
@@ -412,7 +413,7 @@ public abstract class TtlStateTestBase {
 		sbetc.setCurrentKey("k2");
 		ctx().update(ctx().updateUnexpired);
 
-		restoreSnapshot(snapshot, StateBackendTestContext.NUMBER_OF_KEY_GROUPS);
+		restoreSnapshot(snapshot, NUMBER_OF_KEY_GROUPS);
 
 		timeProvider.time = 180;
 		sbetc.setCurrentKey("k1");
@@ -443,7 +444,7 @@ public abstract class TtlStateTestBase {
 
 		initTest(getConfBuilder(TTL).cleanupIncrementally(5, true).build());
 
-		final int keysToUpdate = CopyOnWriteStateTable.DEFAULT_CAPACITY >> 3;
+		final int keysToUpdate = (CopyOnWriteStateMap.DEFAULT_CAPACITY >> 3) * NUMBER_OF_KEY_GROUPS;
 
 		timeProvider.time = 0;
 		// create enough keys to trigger incremental rehash
@@ -465,7 +466,7 @@ public abstract class TtlStateTestBase {
 		KeyedStateHandle snapshot = snapshotRunnableFuture.get().getJobManagerOwnedSnapshot();
 		// restore snapshot which should discard concurrent updates
 		timeProvider.time = 50;
-		restoreSnapshot(snapshot, StateBackendTestContext.NUMBER_OF_KEY_GROUPS);
+		restoreSnapshot(snapshot, NUMBER_OF_KEY_GROUPS);
 
 		// check rest unexpired, also after restore which should discard concurrent updates
 		checkUnexpiredKeys(keysToUpdate, INC_CLEANUP_ALL_KEYS, ctx().getUpdateEmpty);
@@ -481,7 +482,7 @@ public abstract class TtlStateTestBase {
 		checkUnexpiredKeys(0, keysToUpdate >> 1, ctx().getUnexpired);
 		triggerMoreIncrementalCleanupByOtherOps();
 		// check that concurrently updated and then restored with original values are expired
-		checkExpiredKeys(keysToUpdate, keysToUpdate *2);
+		checkExpiredKeys(keysToUpdate, keysToUpdate * 2);
 
 		timeProvider.time = 170;
 		// check rest expired and cleanup updated
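For illustration (not part of this commit): with state stored per key group, the incremental-cleanup test now has to touch the threshold number of keys in every key group, which is why INC_CLEANUP_ALL_KEYS and keysToUpdate above are multiplied by NUMBER_OF_KEY_GROUPS. A back-of-the-envelope sketch; the two constants below are assumptions made for the example, not values read from the Flink sources:

    // Assumed values, for the sketch only.
    int defaultCapacity = 128;      // stands in for CopyOnWriteStateMap.DEFAULT_CAPACITY
    int numberOfKeyGroups = 10;     // stands in for StateBackendTestContext.NUMBER_OF_KEY_GROUPS

    // Mirrors the expression used for INC_CLEANUP_ALL_KEYS in the hunk above.
    int incCleanupAllKeys = ((defaultCapacity >> 1) + (defaultCapacity >> 2) + 1) * numberOfKeyGroups;
    // (64 + 32 + 1) * 10 = 970 keys under these assumptions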