Posted to commits@flink.apache.org by sr...@apache.org on 2018/06/15 19:06:51 UTC

[1/2] flink git commit: [FLINK-9487][state] Prepare InternalTimerHeap for asynchronous snapshots

Repository: flink
Updated Branches:
  refs/heads/master 4a0dc8273 -> ae9178fa0


[FLINK-9487][state] Prepare InternalTimerHeap for asynchronous snapshots

This closes #6159.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/7e0eafa7
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/7e0eafa7
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/7e0eafa7

Branch: refs/heads/master
Commit: 7e0eafa74d81a805d1445fb74e967acaf455e967
Parents: 4a0dc82
Author: Stefan Richter <s....@data-artisans.com>
Authored: Mon Jun 11 14:48:06 2018 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Fri Jun 15 21:06:19 2018 +0200

----------------------------------------------------------------------
 .../runtime/state/KeyGroupPartitioner.java      | 297 +++++++++++++++++++
 .../flink/runtime/state/StateSnapshot.java      |  67 +++++
 .../state/heap/AbstractStateTableSnapshot.java  |   9 +-
 .../state/heap/CopyOnWriteStateTable.java       |  44 +--
 .../heap/CopyOnWriteStateTableSnapshot.java     | 170 ++++++-----
 .../state/heap/HeapKeyedStateBackend.java       |  11 +-
 .../state/heap/NestedMapsStateTable.java        |  14 +-
 .../flink/runtime/state/heap/StateTable.java    |   3 +-
 .../runtime/state/heap/StateTableSnapshot.java  |  45 ---
 .../state/KeyGroupPartitionerTestBase.java      | 176 +++++++++++
 .../state/heap/CopyOnWriteStateTableTest.java   |   7 +-
 .../heap/StateTableKeyGroupPartitionerTest.java | 102 +++++++
 .../StateTableSnapshotCompatibilityTest.java    |  10 +-
 .../api/operators/InternalTimerHeap.java        |  18 +-
 .../operators/InternalTimerHeapSnapshot.java    |  97 ++++++
 .../KeyGroupPartitionerForTimersTest.java       |  36 +++
 16 files changed, 944 insertions(+), 162 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
new file mode 100644
index 0000000..673a0ef
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
@@ -0,0 +1,297 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+
+/**
+ * Class that contains the base algorithm for partitioning data into key-groups. This algorithm currently works
+ * with two arrays (input, output) for optimal algorithmic complexity. Notice that this could also be implemented over a
+ * single array, using some cuckoo-hashing-style element replacement. This would have worse algorithmic complexity but
+ * better space efficiency. We currently prefer the trade-off in favor of better algorithmic complexity.
+ *
+ * @param <T> type of the partitioned elements.
+ */
+public class KeyGroupPartitioner<T> {
+
+	/**
+	 * The input data for the partitioning. All elements to consider must be densely packed in the index interval
+	 * [0, {@link #numberOfElements}[, without null values.
+	 */
+	@Nonnull
+	protected final T[] partitioningSource;
+
+	/**
+	 * The output array for the partitioning. The size must be at least {@link #numberOfElements}.
+	 */
+	@Nonnull
+	protected final T[] partitioningDestination;
+
+	/** Total number of input elements. */
+	@Nonnegative
+	protected final int numberOfElements;
+
+	/** The total number of key-groups in the job. */
+	@Nonnegative
+	protected final int totalKeyGroups;
+
+	/** The key-group range for the input data, covered in this partitioning. */
+	@Nonnull
+	protected final KeyGroupRange keyGroupRange;
+
+	/**
+	 * This bookkeeping array is used to count the elements in each key-group. In a second step, it is transformed into
+	 * a histogram by accumulation.
+	 */
+	@Nonnull
+	protected final int[] counterHistogram;
+
+	/**
+	 * This is a helper array that caches the key-group for each element, so we do not have to compute them twice.
+	 */
+	@Nonnull
+	protected final int[] elementKeyGroups;
+
+	/** Cached value of keyGroupRange#firstKeyGroup. */
+	@Nonnegative
+	protected final int firstKeyGroup;
+
+	/** Function to extract the key from a given element. */
+	@Nonnull
+	protected final KeyExtractorFunction<T> keyExtractorFunction;
+
+	/** Function to write an element to a {@link DataOutputView}. */
+	@Nonnull
+	protected final ElementWriterFunction<T> elementWriterFunction;
+
+	/** Cached result. */
+	@Nullable
+	protected StateSnapshot.KeyGroupPartitionedSnapshot computedResult;
+
+	/**
+	 * Creates a new {@link KeyGroupPartitioner}.
+	 *
+	 * @param partitioningSource the input for the partitioning. All elements must be densely packed in the index
+	 *                              interval [0, {@link #numberOfElements}[, without null values.
+	 * @param numberOfElements the number of elements to consider from the input, starting at input index 0.
+	 * @param partitioningDestination the output of the partitioning. Must have capacity of at least numberOfElements.
+	 * @param keyGroupRange the key-group range of the data that will be partitioned by this instance.
+	 * @param totalKeyGroups the total number of key groups in the job.
+	 * @param keyExtractorFunction this function extracts the partition key from an element.
+	 */
+	public KeyGroupPartitioner(
+		@Nonnull T[] partitioningSource,
+		@Nonnegative int numberOfElements,
+		@Nonnull T[] partitioningDestination,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int totalKeyGroups,
+		@Nonnull KeyExtractorFunction<T> keyExtractorFunction,
+		@Nonnull ElementWriterFunction<T> elementWriterFunction) {
+
+		Preconditions.checkState(partitioningSource != partitioningDestination);
+		Preconditions.checkState(partitioningSource.length >= numberOfElements);
+		Preconditions.checkState(partitioningDestination.length >= numberOfElements);
+
+		this.partitioningSource = partitioningSource;
+		this.partitioningDestination = partitioningDestination;
+		this.numberOfElements = numberOfElements;
+		this.keyGroupRange = keyGroupRange;
+		this.totalKeyGroups = totalKeyGroups;
+		this.keyExtractorFunction = keyExtractorFunction;
+		this.elementWriterFunction = elementWriterFunction;
+		this.firstKeyGroup = keyGroupRange.getStartKeyGroup();
+		this.elementKeyGroups = new int[numberOfElements];
+		this.counterHistogram = new int[keyGroupRange.getNumberOfKeyGroups()];
+		this.computedResult = null;
+	}
+
+	/**
+	 * Partitions the data into key-groups and returns the result via {@link PartitioningResult}.
+	 */
+	public StateSnapshot.KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+		if (computedResult == null) {
+			reportAllElementKeyGroups();
+			buildHistogramByAccumulatingCounts();
+			executePartitioning();
+		}
+		return computedResult;
+	}
+
+	/**
+	 * This method iterates over the input data and reports the key-group for each element.
+	 */
+	protected void reportAllElementKeyGroups() {
+
+		Preconditions.checkState(partitioningSource.length >= numberOfElements);
+
+		for (int i = 0; i < numberOfElements; ++i) {
+			int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(
+				keyExtractorFunction.extractKeyFromElement(partitioningSource[i]), totalKeyGroups);
+			reportKeyGroupOfElementAtIndex(i, keyGroup);
+		}
+	}
+
+	/**
+	 * This method reports in the bookkeeping data that the element at the given index belongs to the given key-group.
+	 */
+	protected void reportKeyGroupOfElementAtIndex(int index, int keyGroup) {
+		final int keyGroupIndex = keyGroup - firstKeyGroup;
+		elementKeyGroups[index] = keyGroupIndex;
+		++counterHistogram[keyGroupIndex];
+	}
+
+	/**
+	 * This method creates a histogram from the counts per key-group in {@link #counterHistogram}.
+	 */
+	private void buildHistogramByAccumulatingCounts() {
+		int sum = 0;
+		for (int i = 0; i < counterHistogram.length; ++i) {
+			int currentSlotValue = counterHistogram[i];
+			counterHistogram[i] = sum;
+			sum += currentSlotValue;
+		}
+
+		// sanity check that the sum matches the expected number of elements.
+		Preconditions.checkState(sum == numberOfElements);
+	}
+
+	private void executePartitioning() {
+
+		// We repartition the entries by their pre-computed key-groups, using the histogram values as write indexes
+		for (int inIdx = 0; inIdx < numberOfElements; ++inIdx) {
+			int effectiveKgIdx = elementKeyGroups[inIdx];
+			int outIdx = counterHistogram[effectiveKgIdx]++;
+			partitioningDestination[outIdx] = partitioningSource[inIdx];
+		}
+
+		this.computedResult = new PartitioningResult<>(
+			elementWriterFunction,
+			firstKeyGroup,
+			counterHistogram,
+			partitioningDestination);
+	}
+
+	/**
+	 * This represents the result of key-group partitioning. The data in {@link #partitionedElements} is partitioned
+	 * w.r.t. {@link KeyGroupPartitioner#keyGroupRange}.
+	 */
+	public static class PartitioningResult<T> implements StateSnapshot.KeyGroupPartitionedSnapshot {
+
+		/**
+		 * Function to write one element to a {@link DataOutputView}.
+		 */
+		@Nonnull
+		private final ElementWriterFunction<T> elementWriterFunction;
+
+		/**
+		 * The exclusive end-offsets for all key-groups of the range covered by this partitioning. The exclusive
+		 * end-offset for key-group n is stored at keyGroupOffsets[n - firstKeyGroup].
+		 */
+		@Nonnull
+		private final int[] keyGroupOffsets;
+
+		/**
+		 * Array with elements that are partitioned w.r.t. the covered key-group range. The start offset for each
+		 * key-group is in {@link #keyGroupOffsets}.
+		 */
+		@Nonnull
+		private final T[] partitionedElements;
+
+		/**
+		 * The first key-group of the range covered in the partitioning.
+		 */
+		@Nonnegative
+		private final int firstKeyGroup;
+
+		PartitioningResult(
+			@Nonnull ElementWriterFunction<T> elementWriterFunction,
+			@Nonnegative int firstKeyGroup,
+			@Nonnull int[] keyGroupEndOffsets,
+			@Nonnull T[] partitionedElements) {
+			this.elementWriterFunction = elementWriterFunction;
+			this.firstKeyGroup = firstKeyGroup;
+			this.keyGroupOffsets = keyGroupEndOffsets;
+			this.partitionedElements = partitionedElements;
+		}
+
+		@Nonnegative
+		private int getKeyGroupStartOffsetInclusive(int keyGroup) {
+			int idx = keyGroup - firstKeyGroup - 1;
+			return idx < 0 ? 0 : keyGroupOffsets[idx];
+		}
+
+		@Nonnegative
+		private int getKeyGroupEndOffsetExclusive(int keyGroup) {
+			return keyGroupOffsets[keyGroup - firstKeyGroup];
+		}
+
+		@Override
+		public void writeMappingsInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
+
+			int startOffset = getKeyGroupStartOffsetInclusive(keyGroupId);
+			int endOffset = getKeyGroupEndOffsetExclusive(keyGroupId);
+
+			// write number of mappings in key-group
+			dov.writeInt(endOffset - startOffset);
+
+			// write mappings
+			for (int i = startOffset; i < endOffset; ++i) {
+				elementWriterFunction.writeElement(partitionedElements[i], dov);
+			}
+		}
+	}
+
+	/**
+	 * @param <T> type of the element from which we extract the key.
+	 */
+	@FunctionalInterface
+	public interface KeyExtractorFunction<T> {
+
+		/**
+		 * Returns the key for the given element by which the key-group can be computed.
+		 */
+		@Nonnull
+		Object extractKeyFromElement(@Nonnull T element);
+	}
+
+	/**
+	 * This functional interface defines how one element is written to a {@link DataOutputView}.
+	 *
+	 * @param <T> type of the written elements.
+	 */
+	@FunctionalInterface
+	public interface ElementWriterFunction<T> {
+
+		/**
+		 * This method defines how to write a single element to the output.
+		 *
+		 * @param element the element to be written.
+		 * @param dov     the output view to write the element.
+		 * @throws IOException on write-related problems.
+		 */
+		void writeElement(@Nonnull T element, @Nonnull DataOutputView dov) throws IOException;
+	}
+}
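
In short, KeyGroupPartitioner is a counting sort keyed on key-groups: reportAllElementKeyGroups() computes and caches each element's key-group while counting occurrences, buildHistogramByAccumulatingCounts() turns the counts into per-key-group write offsets via a prefix sum, and executePartitioning() scatters the elements to their final positions. The following minimal, self-contained sketch shows the same two-pass scheme; it is not part of this commit, and the modulo key-group assignment is a simplification of the real KeyGroupRangeAssignment.

// Standalone sketch of the histogram-based (counting sort) partitioning scheme.
// The modulo key-group assignment and all names here are illustrative only.
public final class CountingSortPartitionDemo {

	public static void main(String[] args) {
		final String[] input = {"cherry", "apple", "banana", "fig", "date"};
		final int numKeyGroups = 3;

		// Pass 1: compute each element's key-group once and count elements per key-group.
		final int[] keyGroups = new int[input.length];
		final int[] histogram = new int[numKeyGroups];
		for (int i = 0; i < input.length; ++i) {
			keyGroups[i] = Math.floorMod(input[i].hashCode(), numKeyGroups); // simplified assignment
			++histogram[keyGroups[i]];
		}

		// Prefix sum: turn per-key-group counts into start offsets in the output array.
		int sum = 0;
		for (int kg = 0; kg < numKeyGroups; ++kg) {
			final int count = histogram[kg];
			histogram[kg] = sum;
			sum += count;
		}

		// Pass 2: scatter elements, advancing each key-group's write offset as we go.
		final String[] output = new String[input.length];
		for (int i = 0; i < input.length; ++i) {
			output[histogram[keyGroups[i]]++] = input[i];
		}

		// Afterwards, histogram[kg] is the exclusive end offset of key-group kg,
		// which is exactly how PartitioningResult's keyGroupOffsets are produced.
		System.out.println(java.util.Arrays.toString(output));
	}
}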

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshot.java
new file mode 100644
index 0000000..1fcac5c
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshot.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.memory.DataOutputView;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+
+/**
+ * General interface for state snapshots that should be written partitioned by key-groups.
+ * All snapshots should be released after usage. This interface outlines the asynchronous snapshot life-cycle, which
+ * typically looks as follows. In the synchronous part of a checkpoint, an instance of {@link StateSnapshot} is produced
+ * for a state and captures the state at this point in time. Then, in the asynchronous part of the checkpoint, the user
+ * calls {@link #partitionByKeyGroup()} to ensure that the snapshot is partitioned into key-groups. For state that is
+ * already partitioned, this can be a NOP. The returned {@link KeyGroupPartitionedSnapshot} can be used by the caller
+ * to write the state by key-group. As a last step, when the state is completely written, the user calls
+ * {@link #release()}.
+ */
+public interface StateSnapshot {
+
+	/**
+	 * This method partitions the snapshot by key-group and then returns a {@link KeyGroupPartitionedSnapshot}.
+	 */
+	@Nonnull
+	KeyGroupPartitionedSnapshot partitionByKeyGroup();
+
+	/**
+	 * Releases the snapshot. All snapshots should be released when they are no longer used, because some
+	 * implementations can only free their resources after a release. A previously produced
+	 * {@link KeyGroupPartitionedSnapshot} must no longer be used after calling this method.
+	 */
+	void release();
+
+	/**
+	 * Interface for writing a snapshot after it is partitioned into key-groups.
+	 */
+	interface KeyGroupPartitionedSnapshot {
+		/**
+		 * Writes the data for the specified key-group to the output. {@link StateSnapshot#partitionByKeyGroup()} must
+		 * have been called at least once before the first call to this method.
+		 *
+		 * @param dov        the output.
+		 * @param keyGroupId the key-group to write.
+		 * @throws IOException on write-related problems.
+		 */
+		void writeMappingsInKeyGroup(@Nonnull DataOutputView dov, @Nonnegative int keyGroupId) throws IOException;
+	}
+}
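
The life-cycle described in the javadoc above translates into a call sequence of roughly the following shape. This is a hedged sketch, not code from this commit: the snapshot and output view are assumed to be provided by the surrounding backend (compare the HeapKeyedStateBackend changes below), and only the interface methods introduced here are used. KeyGroupRange is iterable over its key-group ids, as the tests further below also rely on.

import java.io.IOException;

import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.StateSnapshot;

// Hedged sketch of the StateSnapshot life-cycle; this class is not part of the commit.
final class SnapshotLifecycleSketch {

	static void writeSnapshot(
			StateSnapshot snapshot,      // produced in the synchronous part of the checkpoint
			KeyGroupRange keyGroupRange,
			DataOutputView out) throws IOException {
		try {
			// Asynchronous part: partition lazily (a NOP for already-partitioned state),
			// then serialize the state key-group by key-group.
			StateSnapshot.KeyGroupPartitionedSnapshot partitioned = snapshot.partitionByKeyGroup();
			for (int keyGroupId : keyGroupRange) {
				partitioned.writeMappingsInKeyGroup(out, keyGroupId);
			}
		} finally {
			// Always release, so implementations can free resources (e.g. copy-on-write bookkeeping).
			snapshot.release();
		}
	}
}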

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
index b0d7727..03f253b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/AbstractStateTableSnapshot.java
@@ -19,14 +19,15 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.util.Preconditions;
 
 /**
- * Abstract class to encapsulate the logic to take snapshots of {@link StateTable} implementations and also defines how
- * the snapshot is written during the serialization phase of checkpointing.
+ * Abstract base class for snapshots of a {@link StateTable}. Offers a way to serialize the snapshot (by key-group).
+ * All snapshots should be released after usage.
  */
 @Internal
-abstract class AbstractStateTableSnapshot<K, N, S, T extends StateTable<K, N, S>> implements StateTableSnapshot {
+abstract class AbstractStateTableSnapshot<K, N, S, T extends StateTable<K, N, S>> implements StateSnapshot {
 
 	/**
 	 * The {@link StateTable} from which this snapshot was created.
@@ -48,4 +49,4 @@ abstract class AbstractStateTableSnapshot<K, N, S, T extends StateTable<K, N, S>
 	@Override
 	public void release() {
 	}
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
index 4ecb0ed..d28ed46 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
@@ -28,6 +28,9 @@ import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
 import java.util.Arrays;
 import java.util.ConcurrentModificationException;
 import java.util.Iterator;
@@ -129,7 +132,8 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	/**
 	 * Empty entry that we use to bootstrap our {@link CopyOnWriteStateTable.StateEntryIterator}.
 	 */
-	private static final StateTableEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY = new StateTableEntry<>();
+	private static final StateTableEntry<?, ?, ?> ITERATOR_BOOTSTRAP_ENTRY =
+		new StateTableEntry<>(new Object(), new Object(), new Object(), 0, null, 0, 0);
 
 	/**
 	 * Maintains an ordered set of version ids that are still in use by unreleased snapshots.
@@ -291,10 +295,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 	@Override
 	public Stream<K> getKeys(N namespace) {
-		Iterable<StateEntry<K, N, S>> iterable = () -> iterator();
-		return StreamSupport.stream(iterable.spliterator(), false)
+		return StreamSupport.stream(spliterator(), false)
 			.filter(entry -> entry.getNamespace().equals(namespace))
-			.map(entry -> entry.getKey());
+			.map(StateEntry::getKey);
 	}
 
 	@Override
@@ -554,6 +557,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 
 	// Iteration  ------------------------------------------------------------------------------------------------------
 
+	@Nonnull
 	@Override
 	public Iterator<StateEntry<K, N, S>> iterator() {
 		return new StateEntryIterator();
@@ -878,27 +882,32 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 * @param <N> type of namespace.
 	 * @param <S> type of state.
 	 */
-	static class StateTableEntry<K, N, S> implements StateEntry<K, N, S> {
+	@VisibleForTesting
+	protected static class StateTableEntry<K, N, S> implements StateEntry<K, N, S> {
 
 		/**
 		 * The key. Assumed to be immutable and not null.
 		 */
+		@Nonnull
 		final K key;
 
 		/**
 		 * The namespace. Assumed to be immutable and not null.
 		 */
+		@Nonnull
 		final N namespace;
 
 		/**
 		 * The state. This is not final to allow exchanging the object for copy-on-write. Can be null.
 		 */
+		@Nullable
 		S state;
 
 		/**
 		 * Link to another {@link StateTableEntry}. This is used to resolve collisions in the
 		 * {@link CopyOnWriteStateTable} through chaining.
 		 */
+		@Nullable
 		StateTableEntry<K, N, S> next;
 
 		/**
@@ -916,22 +925,18 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		 */
 		final int hash;
 
-		StateTableEntry() {
-			this(null, null, null, 0, null, 0, 0);
-		}
-
 		StateTableEntry(StateTableEntry<K, N, S> other, int entryVersion) {
 			this(other.key, other.namespace, other.state, other.hash, other.next, entryVersion, other.stateVersion);
 		}
 
 		StateTableEntry(
-				K key,
-				N namespace,
-				S state,
-				int hash,
-				StateTableEntry<K, N, S> next,
-				int entryVersion,
-				int stateVersion) {
+			@Nonnull K key,
+			@Nonnull N namespace,
+			@Nullable S state,
+			int hash,
+			@Nullable StateTableEntry<K, N, S> next,
+			int entryVersion,
+			int stateVersion) {
 			this.key = key;
 			this.namespace = namespace;
 			this.hash = hash;
@@ -941,7 +946,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 			this.stateVersion = stateVersion;
 		}
 
-		public final void setState(S value, int mapVersion) {
+		public final void setState(@Nullable S value, int mapVersion) {
 			// naturally, we can update the state version every time we replace the old state with a different object
 			if (value != state) {
 				this.state = value;
@@ -949,16 +954,19 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 			}
 		}
 
+		@Nonnull
 		@Override
 		public K getKey() {
 			return key;
 		}
 
+		@Nonnull
 		@Override
 		public N getNamespace() {
 			return namespace;
 		}
 
+		@Nullable
 		@Override
 		public S getState() {
 			return state;
@@ -1010,7 +1018,7 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 		private StateTableEntry<K, N, S>[] activeTable;
 		private int nextTablePosition;
 		private StateTableEntry<K, N, S> nextEntry;
-		private int expectedModCount = modCount;
+		private int expectedModCount;
 
 		StateEntryIterator() {
 			this.activeTable = primaryTable;

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index 2ac88b3..cf0056e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -19,12 +19,16 @@
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.StateSnapshot;
 
-import java.io.IOException;
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 /**
  * This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. Besides
@@ -53,39 +57,43 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	private final int snapshotVersion;
 
 	/**
-	 * The number of entries in the {@link CopyOnWriteStateTable} at the time of creating this snapshot.
-	 */
-	private final int stateTableSize;
-
-	/**
 	 * The state table entries, as by the time this snapshot was created. Objects in this array may or may not be deep
 	 * copies of the current entries in the {@link CopyOnWriteStateTable} that created this snapshot. This depends for each entry
 	 * on whether or not it was subject to copy-on-write operations by the {@link CopyOnWriteStateTable}.
 	 */
+	@Nonnull
 	private final CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData;
 
-	/**
-	 * Offsets for the individual key-groups. This is lazily created when the snapshot is grouped by key-group during
-	 * the process of writing this snapshot to an output as part of checkpointing.
-	 */
-	private int[] keyGroupOffsets;
+	/** The number of (non-null) entries in snapshotData. */
+	@Nonnegative
+	private final int numberOfEntriesInSnapshotData;
 
 	/**
 	 * A local duplicate of the table's key serializer.
 	 */
+	@Nonnull
 	private final TypeSerializer<K> localKeySerializer;
 
 	/**
 	 * A local duplicate of the table's namespace serializer.
 	 */
+	@Nonnull
 	private final TypeSerializer<N> localNamespaceSerializer;
 
 	/**
 	 * A local duplicate of the table's state serializer.
 	 */
+	@Nonnull
 	private final TypeSerializer<S> localStateSerializer;
 
 	/**
+	 * Result of partitioning the snapshot by key-group. This is lazily created in the process of writing this snapshot
+	 * to an output as part of checkpointing.
+	 */
+	@Nullable
+	private StateSnapshot.KeyGroupPartitionedSnapshot partitionedStateTableSnapshot;
+
+	/**
 	 * Creates a new {@link CopyOnWriteStateTableSnapshot}.
 	 *
 	 * @param owningStateTable the {@link CopyOnWriteStateTable} for which this object represents a snapshot.
@@ -95,7 +103,8 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		super(owningStateTable);
 		this.snapshotData = owningStateTable.snapshotTableArrays();
 		this.snapshotVersion = owningStateTable.getStateTableVersion();
-		this.stateTableSize = owningStateTable.size();
+		this.numberOfEntriesInSnapshotData = owningStateTable.size();
+
 
 		// We create duplicates of the serializers for the async snapshot, because TypeSerializer
 		// might be stateful and shared with the event processing thread.
@@ -103,7 +112,7 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		this.localNamespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer().duplicate();
 		this.localStateSerializer = owningStateTable.metaInfo.getStateSerializer().duplicate();
 
-		this.keyGroupOffsets = null;
+		this.partitionedStateTableSnapshot = null;
 	}
 
 	/**
@@ -123,47 +132,32 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	 * As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
 	 * cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint.
 	 */
+	@Nonnull
 	@SuppressWarnings("unchecked")
-	private void partitionEntriesByKeyGroup() {
-
-		// We only have to perform this step once before the first key-group is written
-		if (null != keyGroupOffsets) {
-			return;
-		}
-
-		final KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange();
-		final int totalKeyGroups = owningStateTable.keyContext.getNumberOfKeyGroups();
-		final int baseKgIdx = keyGroupRange.getStartKeyGroup();
-		final int[] histogram = new int[keyGroupRange.getNumberOfKeyGroups() + 1];
-
-		CopyOnWriteStateTable.StateTableEntry<K, N, S>[] unfold = new CopyOnWriteStateTable.StateTableEntry[stateTableSize];
-
-		// 1) In this step we i) 'unfold' the linked list of entries to a flat array and ii) build a histogram for key-groups
-		int unfoldIndex = 0;
-		for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : snapshotData) {
-			while (null != entry) {
-				int effectiveKgIdx =
-						KeyGroupRangeAssignment.computeKeyGroupForKeyHash(entry.key.hashCode(), totalKeyGroups) - baseKgIdx + 1;
-				++histogram[effectiveKgIdx];
-				unfold[unfoldIndex++] = entry;
-				entry = entry.next;
-			}
-		}
-
-		// 2) We accumulate the histogram bins to obtain key-group ranges in the final array
-		for (int i = 1; i < histogram.length; ++i) {
-			histogram[i] += histogram[i - 1];
-		}
-
-		// 3) We repartition the entries by key-group, using the histogram values as write indexes
-		for (CopyOnWriteStateTable.StateTableEntry<K, N, S> t : unfold) {
-			int effectiveKgIdx =
-					KeyGroupRangeAssignment.computeKeyGroupForKeyHash(t.key.hashCode(), totalKeyGroups) - baseKgIdx;
-			snapshotData[histogram[effectiveKgIdx]++] = t;
+	@Override
+	public KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+
+		if (partitionedStateTableSnapshot == null) {
+
+			final InternalKeyContext<K> keyContext = owningStateTable.keyContext;
+			final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
+			final int numberOfKeyGroups = keyContext.getNumberOfKeyGroups();
+
+			final StateTableKeyGroupPartitioner<K, N, S> keyGroupPartitioner = new StateTableKeyGroupPartitioner<>(
+				snapshotData,
+				numberOfEntriesInSnapshotData,
+				keyGroupRange,
+				numberOfKeyGroups,
+				(element, dov) -> {
+					localNamespaceSerializer.serialize(element.namespace, dov);
+					localKeySerializer.serialize(element.key, dov);
+					localStateSerializer.serialize(element.state, dov);
+				});
+
+			partitionedStateTableSnapshot = keyGroupPartitioner.partitionByKeyGroup();
 		}
 
-		// 4) As byproduct, we also created the key-group offsets
-		this.keyGroupOffsets = histogram;
+		return partitionedStateTableSnapshot;
 	}
 
 	@Override
@@ -171,36 +165,56 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		owningStateTable.releaseSnapshot(this);
 	}
 
-	@Override
-	public void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException {
-
-		if (null == keyGroupOffsets) {
-			partitionEntriesByKeyGroup();
-		}
-
-		final CopyOnWriteStateTable.StateTableEntry<K, N, S>[] groupedOut = snapshotData;
-		KeyGroupRange keyGroupRange = owningStateTable.keyContext.getKeyGroupRange();
-		int keyGroupOffsetIdx = keyGroupId - keyGroupRange.getStartKeyGroup() - 1;
-		int startOffset = keyGroupOffsetIdx < 0 ? 0 : keyGroupOffsets[keyGroupOffsetIdx];
-		int endOffset = keyGroupOffsets[keyGroupOffsetIdx + 1];
-
-		// write number of mappings in key-group
-		dov.writeInt(endOffset - startOffset);
-
-		// write mappings
-		for (int i = startOffset; i < endOffset; ++i) {
-			CopyOnWriteStateTable.StateTableEntry<K, N, S> toWrite = groupedOut[i];
-			groupedOut[i] = null; // free asap for GC
-			localNamespaceSerializer.serialize(toWrite.namespace, dov);
-			localKeySerializer.serialize(toWrite.key, dov);
-			localStateSerializer.serialize(toWrite.state, dov);
-		}
-	}
-
 	/**
 	 * Returns true iff the given state table is the owner of this snapshot object.
 	 */
 	boolean isOwner(CopyOnWriteStateTable<K, N, S> stateTable) {
 		return stateTable == owningStateTable;
 	}
+
+	/**
+	 * This class is the implementation of {@link KeyGroupPartitioner} for {@link CopyOnWriteStateTable}. This class
+	 * swaps input and output in {@link #reportAllElementKeyGroups()} for performance reasons, so that we can reuse
+	 * the non-flattened original snapshot array as partitioning output.
+	 *
+	 * @param <K> type of key.
+	 * @param <N> type of namespace.
+	 * @param <S> type of state value.
+	 */
+	@VisibleForTesting
+	protected static final class StateTableKeyGroupPartitioner<K, N, S>
+		extends KeyGroupPartitioner<CopyOnWriteStateTable.StateTableEntry<K, N, S>> {
+
+		@SuppressWarnings("unchecked")
+		StateTableKeyGroupPartitioner(
+			@Nonnull CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData,
+			@Nonnegative int stateTableSize,
+			@Nonnull KeyGroupRange keyGroupRange,
+			@Nonnegative int totalKeyGroups,
+			@Nonnull ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction) {
+
+			super(
+				new CopyOnWriteStateTable.StateTableEntry[stateTableSize],
+				stateTableSize,
+				snapshotData,
+				keyGroupRange,
+				totalKeyGroups,
+				CopyOnWriteStateTable.StateTableEntry::getKey,
+				elementWriterFunction);
+		}
+
+		@Override
+		protected void reportAllElementKeyGroups() {
+			// In this step we i) 'flatten' the linked list of entries to a second array and ii) report key-groups.
+			int flattenIndex = 0;
+			for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : partitioningDestination) {
+				while (null != entry) {
+					final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
+					reportKeyGroupOfElementAtIndex(flattenIndex, keyGroup);
+					partitioningSource[flattenIndex++] = entry;
+					entry = entry.next;
+				}
+			}
+		}
+	}
 }
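
The subtle point in StateTableKeyGroupPartitioner above is the argument swap in its super call: the chained snapshot array is handed to the base class as the partitioning destination, while a freshly allocated flat array becomes the source. Flattening the collision chains, which has to happen anyway, then doubles as the copy into the source array, and the original snapshot array can be overwritten with the partitioned result. Below is a toy sketch of that flattening pass, not code from this commit; Node is a hypothetical stand-in for StateTableEntry, and floorMod simplifies the real key-group assignment.

// Toy illustration of flattening hash-bucket collision chains into a dense array
// while reporting each element's key-group. Node stands in for StateTableEntry.
final class ChainFlatteningSketch {

	static final class Node {
		final int key;
		final Node next; // next entry in the same hash bucket, or null
		Node(int key, Node next) { this.key = key; this.next = next; }
	}

	/** Flattens all chains in {@code buckets} into {@code flat}; returns the element count. */
	static int flatten(Node[] buckets, Node[] flat, int[] keyGroups, int totalKeyGroups) {
		int flattenIndex = 0;
		for (Node entry : buckets) {
			while (entry != null) { // walk the chain; empty buckets are simply skipped
				keyGroups[flattenIndex] = Math.floorMod(entry.key, totalKeyGroups); // simplified assignment
				flat[flattenIndex++] = entry;
				entry = entry.next;
			}
		}
		return flattenIndex;
	}
}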

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index ab91ee1..76479b0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -55,6 +55,7 @@ import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
+import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -599,7 +600,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			final Map<String, Integer> kVStateToId = new HashMap<>(stateTables.size());
 
-			final Map<String, StateTableSnapshot> cowStateStableSnapshots =
+			final Map<String, StateSnapshot> cowStateStableSnapshots =
 				new HashMap<>(stateTables.size());
 
 			for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
@@ -653,7 +654,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 						unregisterAndCloseStreamAndResultExtractor();
 
-						for (StateTableSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
+						for (StateSnapshot tableSnapshot : cowStateStableSnapshots.values()) {
 							tableSnapshot.release();
 						}
 					}
@@ -689,12 +690,14 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 							keyGroupRangeOffsets[keyGroupPos] = localStream.getPos();
 							outView.writeInt(keyGroupId);
 
-							for (Map.Entry<String, StateTableSnapshot> kvState : cowStateStableSnapshots.entrySet()) {
+							for (Map.Entry<String, StateSnapshot> kvState : cowStateStableSnapshots.entrySet()) {
+								StateSnapshot.KeyGroupPartitionedSnapshot partitionedSnapshot =
+									kvState.getValue().partitionByKeyGroup();
 								try (OutputStream kgCompressionOut = keyGroupCompressionDecorator.decorateWithCompression(localStream)) {
 									String stateName = kvState.getKey();
 									DataOutputViewStreamWrapper kgCompressionView = new DataOutputViewStreamWrapper(kgCompressionOut);
 									kgCompressionView.writeShort(kVStateToId.get(stateName));
-									kvState.getValue().writeMappingsInKeyGroup(kgCompressionView, keyGroupId);
+									partitionedSnapshot.writeMappingsInKeyGroup(kgCompressionView, keyGroupId);
 								} // this will just close the outer compression stream
 							}
 						}

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index 870ecbf..18551b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -23,9 +23,12 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.Preconditions;
 
+import javax.annotation.Nonnull;
+
 import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
@@ -333,12 +336,19 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	 * @param <S> type of state.
 	 */
 	static class NestedMapsStateTableSnapshot<K, N, S>
-			extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>> {
+			extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
+			implements StateSnapshot.KeyGroupPartitionedSnapshot {
 
 		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable) {
 			super(owningTable);
 		}
 
+		@Nonnull
+		@Override
+		public KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+			return this;
+		}
+
 		/**
 		 * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
 		 * {@link CopyOnWriteStateTable}.
@@ -349,7 +359,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		 * implementations).
 		 */
 		@Override
-		public void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException {
+		public void writeMappingsInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
 			final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
 			if (null != keyGroupMap) {
 				TypeSerializer<K> keySerializer = owningStateTable.keyContext.getKeySerializer();

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index 8c07b25..de2290a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.state.heap;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.Preconditions;
 
@@ -182,7 +183,7 @@ public abstract class StateTable<K, N, S> {
 
 	// Snapshot / Restore -------------------------------------------------------------------------
 
-	abstract StateTableSnapshot createSnapshot();
+	abstract StateSnapshot createSnapshot();
 
 	public abstract void put(K key, int keyGroup, N namespace, S state);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
deleted file mode 100644
index d4244d7..0000000
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTableSnapshot.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state.heap;
-
-import org.apache.flink.core.memory.DataOutputView;
-
-import java.io.IOException;
-
-/**
- * Interface for the snapshots of a {@link StateTable}. Offers a way to serialize the snapshot (by key-group). All
- * snapshots should be released after usage.
- */
-interface StateTableSnapshot {
-
-	/**
-	 * Writes the data for the specified key-group to the output.
-	 *
-	 * @param dov the output
-	 * @param keyGroupId the key-group to write
-	 * @throws IOException on write related problems
-	 */
-	void writeMappingsInKeyGroup(DataOutputView dov, int keyGroupId) throws IOException;
-
-	/**
-	 * Release the snapshot. All snapshots should be released when they are no longer used because some implementation
-	 * can only release resources after a release.
-	 */
-	void release();
-}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
new file mode 100644
index 0000000..9298743
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/KeyGroupPartitionerTestBase.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.util.Collections;
+import java.util.IdentityHashMap;
+import java.util.Random;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * Abstract test base for implementations of {@link KeyGroupPartitioner}.
+ */
+public abstract class KeyGroupPartitionerTestBase<T> extends TestLogger {
+
+	private static final DataOutputView DUMMY_OUT_VIEW =
+		new DataOutputViewStreamWrapper(new ByteArrayOutputStreamWithPos(0));
+
+	@Nonnull
+	protected final KeyGroupPartitioner.KeyExtractorFunction<T> keyExtractorFunction;
+
+	@Nonnull
+	protected final Function<Random, T> elementGenerator;
+
+	protected KeyGroupPartitionerTestBase(
+		@Nonnull Function<Random, T> elementGenerator,
+		@Nonnull KeyGroupPartitioner.KeyExtractorFunction<T> keyExtractorFunction) {
+
+		this.elementGenerator = elementGenerator;
+		this.keyExtractorFunction = keyExtractorFunction;
+	}
+
+	@Test
+	public void testPartitionByKeyGroup() throws IOException {
+
+		final Random random = new Random(0x42);
+		testPartitionByKeyGroupForSize(0, random);
+		testPartitionByKeyGroupForSize(1, random);
+		testPartitionByKeyGroupForSize(2, random);
+		testPartitionByKeyGroupForSize(10, random);
+	}
+
+	@SuppressWarnings("unchecked")
+	private void testPartitionByKeyGroupForSize(int testSize, Random random) throws IOException {
+
+		final Set<T> allElementsIdentitySet = Collections.newSetFromMap(new IdentityHashMap<>());
+		final T[] data = generateTestInput(random, testSize, allElementsIdentitySet);
+
+		Assert.assertEquals(testSize, allElementsIdentitySet.size());
+
+		// Test with 5 key-groups.
+		final KeyGroupRange range = new KeyGroupRange(0, 4);
+		final int numberOfKeyGroups = range.getNumberOfKeyGroups();
+
+		final ValidatingElementWriterDummy<T> validatingElementWriter =
+			new ValidatingElementWriterDummy<>(keyExtractorFunction, numberOfKeyGroups, allElementsIdentitySet);
+
+		final KeyGroupPartitioner<T> testInstance = createPartitioner(data, testSize, range, numberOfKeyGroups, validatingElementWriter);
+		final StateSnapshot.KeyGroupPartitionedSnapshot result = testInstance.partitionByKeyGroup();
+
+		for (int keyGroup = 0; keyGroup < numberOfKeyGroups; ++keyGroup) {
+			validatingElementWriter.setCurrentKeyGroup(keyGroup);
+			result.writeMappingsInKeyGroup(DUMMY_OUT_VIEW, keyGroup);
+		}
+
+		validatingElementWriter.validateAllElementsSeen();
+	}
+
+	@SuppressWarnings("unchecked")
+	protected T[] generateTestInput(Random random, int numElementsToGenerate, Set<T> allElementsIdentitySet) {
+
+		final int arraySize = numElementsToGenerate > 1 ? numElementsToGenerate + 5 : numElementsToGenerate; // array larger than element count on purpose
+		T element = elementGenerator.apply(random);
+		final T[] partitioningIn = (T[]) Array.newInstance(element.getClass(), arraySize);
+
+		for (int i = 0; i < numElementsToGenerate; ++i) {
+			partitioningIn[i] = element;
+			allElementsIdentitySet.add(element);
+			element = elementGenerator.apply(random);
+		}
+
+		Assert.assertEquals(numElementsToGenerate, allElementsIdentitySet.size());
+		return partitioningIn;
+	}
+
+	@SuppressWarnings("unchecked")
+	protected KeyGroupPartitioner<T> createPartitioner(
+		T[] data,
+		int numElements,
+		KeyGroupRange keyGroupRange,
+		int totalKeyGroups,
+		KeyGroupPartitioner.ElementWriterFunction<T> elementWriterFunction) {
+
+		final T[] partitioningOut = (T[]) Array.newInstance(data.getClass().getComponentType(), numElements);
+		return new KeyGroupPartitioner<>(
+			data,
+			numElements,
+			partitioningOut,
+			keyGroupRange,
+			totalKeyGroups,
+			keyExtractorFunction,
+			elementWriterFunction);
+	}
+
+
+	/**
+	 * Simple test implementation of an element writer that validates the partitioning.
+	 */
+	static final class ValidatingElementWriterDummy<T> implements KeyGroupPartitioner.ElementWriterFunction<T> {
+
+		@Nonnull
+		private final KeyGroupPartitioner.KeyExtractorFunction<T> keyExtractorFunction;
+		@Nonnegative
+		private final int numberOfKeyGroups;
+		@Nonnull
+		private final Set<T> allElementsSet;
+		@Nonnegative
+		private int currentKeyGroup;
+
+		ValidatingElementWriterDummy(
+			@Nonnull KeyGroupPartitioner.KeyExtractorFunction<T> keyExtractorFunction,
+			@Nonnegative int numberOfKeyGroups,
+			@Nonnull Set<T> allElementsSet) {
+			this.keyExtractorFunction = keyExtractorFunction;
+			this.numberOfKeyGroups = numberOfKeyGroups;
+			this.allElementsSet = allElementsSet;
+		}
+
+		@Override
+		public void writeElement(@Nonnull T element, @Nonnull DataOutputView dov) {
+			Assert.assertTrue(allElementsSet.remove(element));
+			Assert.assertEquals(
+				currentKeyGroup,
+				KeyGroupRangeAssignment.assignToKeyGroup(
+					keyExtractorFunction.extractKeyFromElement(element),
+					numberOfKeyGroups));
+		}
+
+		void validateAllElementsSeen() {
+			Assert.assertTrue(allElementsSet.isEmpty());
+		}
+
+		void setCurrentKeyGroup(int currentKeyGroup) {
+			this.currentKeyGroup = currentKeyGroup;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
index 09a0620..4f36d62 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
@@ -442,15 +443,15 @@ public class CopyOnWriteStateTableTest extends TestLogger {
 		table.put(2, 0, 1, 2);
 
 
-		CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.createSnapshot();
+		final CopyOnWriteStateTableSnapshot<Integer, Integer, Integer> snapshot = table.createSnapshot();
 
 		try {
-
+			final StateSnapshot.KeyGroupPartitionedSnapshot partitionedSnapshot = snapshot.partitionByKeyGroup();
 			namespaceSerializer.disable();
 			keySerializer.disable();
 			stateSerializer.disable();
 
-			snapshot.writeMappingsInKeyGroup(
+			partitionedSnapshot.writeMappingsInKeyGroup(
 				new DataOutputViewStreamWrapper(
 					new ByteArrayOutputStreamWithPos(1024)), 0);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java
new file mode 100644
index 0000000..745719a
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableKeyGroupPartitionerTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.heap;
+
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupPartitionerTestBase;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.VoidNamespace;
+import org.apache.flink.runtime.state.heap.CopyOnWriteStateTable.StateTableEntry;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * Test for {@link org.apache.flink.runtime.state.heap.CopyOnWriteStateTableSnapshot.StateTableKeyGroupPartitioner}.
+ */
+public class StateTableKeyGroupPartitionerTest extends
+	KeyGroupPartitionerTestBase<StateTableEntry<Integer, VoidNamespace, Integer>> {
+
+	public StateTableKeyGroupPartitionerTest() {
+		super(random -> generateElement(random, null), StateTableEntry::getKey);
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	protected StateTableEntry<Integer, VoidNamespace, Integer>[] generateTestInput(
+		Random random,
+		int numElementsToGenerate,
+		Set<StateTableEntry<Integer, VoidNamespace, Integer>> allElementsIdentitySet) {
+
+		// we let the array size differ a bit from the test size to check this works
+		final int arraySize = numElementsToGenerate > 1 ? numElementsToGenerate + 5 : numElementsToGenerate;
+		final StateTableEntry<Integer, VoidNamespace, Integer>[] data = new StateTableEntry[arraySize];
+
+		while (numElementsToGenerate > 0) {
+
+			final int generateAsChainCount = Math.min(1 + random.nextInt(3), numElementsToGenerate);
+
+			StateTableEntry<Integer, VoidNamespace, Integer> element = null;
+			for (int i = 0; i < generateAsChainCount; ++i) {
+				element = generateElement(random, element);
+				allElementsIdentitySet.add(element);
+			}
+
+			data[data.length - numElementsToGenerate + random.nextInt(generateAsChainCount)] = element;
+			numElementsToGenerate -= generateAsChainCount;
+		}
+
+		return data;
+	}
+
+	@Override
+	protected KeyGroupPartitioner<StateTableEntry<Integer, VoidNamespace, Integer>> createPartitioner(
+		StateTableEntry<Integer, VoidNamespace, Integer>[] data,
+		int numElements,
+		KeyGroupRange keyGroupRange,
+		int totalKeyGroups,
+		KeyGroupPartitioner.ElementWriterFunction<
+			StateTableEntry<Integer, VoidNamespace, Integer>> elementWriterFunction) {
+
+		return new CopyOnWriteStateTableSnapshot.StateTableKeyGroupPartitioner<>(
+			data,
+			numElements,
+			keyGroupRange,
+			totalKeyGroups,
+			elementWriterFunction);
+	}
+
+	private static StateTableEntry<Integer, VoidNamespace, Integer> generateElement(
+		@Nonnull Random random,
+		@Nullable StateTableEntry<Integer, VoidNamespace, Integer> next) {
+
+		Integer generatedKey = random.nextInt() & Integer.MAX_VALUE;
+		return new StateTableEntry<>(
+			generatedKey,
+			VoidNamespace.INSTANCE,
+			random.nextInt(),
+			generatedKey.hashCode(),
+			next,
+			0,
+			0);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
index 85bc177..0c8e8fe 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/heap/StateTableSnapshotCompatibilityTest.java
@@ -28,6 +28,8 @@ import org.apache.flink.runtime.state.ArrayListSerializer;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
 import org.apache.flink.runtime.state.RegisteredKeyedBackendStateMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshot;
+
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -67,7 +69,7 @@ public class StateTableSnapshotCompatibilityTest {
 			cowStateTable.put(r.nextInt(10), r.nextInt(2), list);
 		}
 
-		StateTableSnapshot snapshot = cowStateTable.createSnapshot();
+		StateSnapshot snapshot = cowStateTable.createSnapshot();
 
 		final NestedMapsStateTable<Integer, Integer, ArrayList<Integer>> nestedMapsStateTable =
 				new NestedMapsStateTable<>(keyContext, metaInfo);
@@ -95,14 +97,14 @@ public class StateTableSnapshotCompatibilityTest {
 
 	private static <K, N, S> void restoreStateTableFromSnapshot(
 			StateTable<K, N, S> stateTable,
-			StateTableSnapshot snapshot,
+			StateSnapshot snapshot,
 			KeyGroupRange keyGroupRange) throws IOException {
 
 		final ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(1024 * 1024);
 		final DataOutputViewStreamWrapper dov = new DataOutputViewStreamWrapper(out);
-
+		final StateSnapshot.KeyGroupPartitionedSnapshot keyGroupPartitionedSnapshot = snapshot.partitionByKeyGroup();
 		for (Integer keyGroup : keyGroupRange) {
-			snapshot.writeMappingsInKeyGroup(dov, keyGroup);
+			keyGroupPartitionedSnapshot.writeMappingsInKeyGroup(dov, keyGroup);
 		}
 
 		final ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(out.getBuf());
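
For context on the API change exercised by this test: writing out a snapshot is now a two-step affair, with partitioning as an explicit first step. A minimal, self-contained sketch of that write path (the SnapshotWriteSketch class is illustrative and not part of the commit; all other names appear in the diff):

import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.StateSnapshot;

import java.io.IOException;

final class SnapshotWriteSketch {

	/** Writes all key-groups of a snapshot; partitioning is now an explicit first step. */
	static byte[] writeSnapshot(StateSnapshot snapshot, KeyGroupRange keyGroupRange) throws IOException {
		final ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(1024 * 1024);
		final DataOutputViewStreamWrapper dov = new DataOutputViewStreamWrapper(out);
		try {
			// Step 1: partition once, up front (previously implicit in the snapshot itself).
			final StateSnapshot.KeyGroupPartitionedSnapshot partitioned = snapshot.partitionByKeyGroup();
			// Step 2: per-key-group writing, unchanged from before.
			for (Integer keyGroup : keyGroupRange) {
				partitioned.writeMappingsInKeyGroup(dov, keyGroup);
			}
			return out.getBuf();
		} finally {
			snapshot.release();
		}
	}
}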

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeap.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeap.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeap.java
index 94b3572..0abc29d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeap.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeap.java
@@ -21,6 +21,7 @@ package org.apache.flink.streaming.api.operators;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.StateSnapshot;
 
 import javax.annotation.Nonnegative;
 import javax.annotation.Nonnull;
@@ -54,7 +55,7 @@ import static org.apache.flink.util.Preconditions.checkArgument;
  * @param <K> type of the key of the internal timers managed by this priority queue.
  * @param <N> type of the namespace of the internal timers managed by this priority queue.
  */
-public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {//implements Queue<TimerHeapInternalTimer<K, N>>, Set<TimerHeapInternalTimer<K, N>> {
+public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {
 
 	/**
 	 * A safe maximum size for arrays in the JVM.
@@ -136,7 +137,7 @@ public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {/
 	 * @param namespace the timer namespace.
 	 * @return true iff a new timer with given timestamp, key, and namespace was added to the heap.
 	 */
-	public boolean scheduleTimer(long timestamp, K key, N namespace) {
+	public boolean scheduleTimer(long timestamp, @Nonnull K key, @Nonnull N namespace) {
 		return addInternal(new TimerHeapInternalTimer<>(timestamp, key, namespace));
 	}
 
@@ -148,7 +149,7 @@ public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {/
 	 * @param namespace the timer namespace.
 	 * @return true iff a timer with given timestamp, key, and namespace was found and removed from the heap.
 	 */
-	public boolean stopTimer(long timestamp, K key, N namespace) {
+	public boolean stopTimer(long timestamp, @Nonnull K key, @Nonnull N namespace) {
 		return removeInternal(new TimerHeapInternalTimer<>(timestamp, key, namespace));
 	}
 
@@ -210,12 +211,14 @@ public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {/
 	/**
 	 * Returns an unmodifiable set of all timers in the given key-group.
 	 */
+	@Nonnull
 	Set<InternalTimer<K, N>> getTimersForKeyGroup(@Nonnegative int keyGroupIdx) {
 		return Collections.unmodifiableSet(getDedupMapForKeyGroup(keyGroupIdx).keySet());
 	}
 
 	@VisibleForTesting
 	@SuppressWarnings("unchecked")
+	@Nonnull
 	List<Set<InternalTimer<K, N>>> getTimersByKeyGroup() {
 		List<Set<InternalTimer<K, N>>> result = new ArrayList<>(deduplicationMapsByKeyGroup.length);
 		for (int i = 0; i < deduplicationMapsByKeyGroup.length; ++i) {
@@ -224,6 +227,15 @@ public class InternalTimerHeap<K, N> implements Iterable<InternalTimer<K, N>> {/
 		return result;
 	}
 
+	@Nonnull
+	StateSnapshot snapshot(TimerHeapInternalTimer.TimerSerializer<K, N> serializer) {
+		return new InternalTimerHeapSnapshot<>(
+			Arrays.copyOfRange(queue, 1, size + 1),
+			serializer,
+			keyGroupRange,
+			totalNumberOfKeyGroups);
+	}
+
 	private boolean addInternal(TimerHeapInternalTimer<K, N> timer) {
 
 		if (getDedupMapForTimer(timer).putIfAbsent(timer, timer) == null) {

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeapSnapshot.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeapSnapshot.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeapSnapshot.java
new file mode 100644
index 0000000..6eb4057
--- /dev/null
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/InternalTimerHeapSnapshot.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.StateSnapshot;
+
+import javax.annotation.Nonnegative;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+/**
+ * This class represents a snapshot of an {@link InternalTimerHeap}.
+ *
+ * @param <K> type of key.
+ * @param <N> type of namespace.
+ */
+public class InternalTimerHeapSnapshot<K, N> implements StateSnapshot {
+
+	/** Copy of the heap array containing all the (immutable) timers. */
+	@Nonnull
+	private final TimerHeapInternalTimer<K, N>[] timerHeapArrayCopy;
+
+	/** The timer serializer. */
+	@Nonnull
+	private final TimerHeapInternalTimer.TimerSerializer<K, N> timerSerializer;
+
+	/** The key-group range covered by this snapshot. */
+	@Nonnull
+	private final KeyGroupRange keyGroupRange;
+
+	/** The total number of key-groups in the job. */
+	@Nonnegative
+	private final int totalKeyGroups;
+
+	/** Result of partitioning the snapshot by key-group. */
+	@Nullable
+	private KeyGroupPartitionedSnapshot partitionedSnapshot;
+
+	InternalTimerHeapSnapshot(
+		@Nonnull TimerHeapInternalTimer<K, N>[] timerHeapArrayCopy,
+		@Nonnull TimerHeapInternalTimer.TimerSerializer<K, N> timerSerializer,
+		@Nonnull KeyGroupRange keyGroupRange,
+		@Nonnegative int totalKeyGroups) {
+
+		this.timerHeapArrayCopy = timerHeapArrayCopy;
+		this.timerSerializer = timerSerializer;
+		this.keyGroupRange = keyGroupRange;
+		this.totalKeyGroups = totalKeyGroups;
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nonnull
+	@Override
+	public KeyGroupPartitionedSnapshot partitionByKeyGroup() {
+
+		if (partitionedSnapshot == null) {
+
+			TimerHeapInternalTimer<K, N>[] partitioningOutput = new TimerHeapInternalTimer[timerHeapArrayCopy.length];
+
+			KeyGroupPartitioner<TimerHeapInternalTimer<K, N>> timerPartitioner =
+				new KeyGroupPartitioner<>(
+					timerHeapArrayCopy,
+					timerHeapArrayCopy.length,
+					partitioningOutput,
+					keyGroupRange,
+					totalKeyGroups,
+					TimerHeapInternalTimer::getKey,
+					timerSerializer::serialize);
+
+			partitionedSnapshot = timerPartitioner.partitionByKeyGroup();
+		}
+
+		return partitionedSnapshot;
+	}
+
+	@Override
+	public void release() {
+	}
+}
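
The partitionByKeyGroup() above computes the partitioning once and caches it, so repeated calls, e.g. one per key-group written in the asynchronous phase of a snapshot, do not repeat the O(n) work. The same compute-once idiom in isolation, as a generic sketch (single-threaded use assumed, as in the snapshot code; the class name is hypothetical):

import java.util.function.Supplier;

import javax.annotation.Nullable;

final class LazyResultSketch<T> {

	private final Supplier<T> compute;

	/** Cached result; null until the first call to get(). */
	@Nullable
	private T cached;

	LazyResultSketch(Supplier<T> compute) {
		this.compute = compute;
	}

	/** Computes the result on first use and reuses it afterwards, like partitionByKeyGroup(). */
	T get() {
		if (cached == null) {
			cached = compute.get();
		}
		return cached;
	}
}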

http://git-wip-us.apache.org/repos/asf/flink/blob/7e0eafa7/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyGroupPartitionerForTimersTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyGroupPartitionerForTimersTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyGroupPartitionerForTimersTest.java
new file mode 100644
index 0000000..8ad7115
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/KeyGroupPartitionerForTimersTest.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.api.operators;
+
+import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupPartitionerTestBase;
+import org.apache.flink.runtime.state.VoidNamespace;
+
+/**
+ * Test of {@link KeyGroupPartitioner} for timers.
+ */
+public class KeyGroupPartitionerForTimersTest
+	extends KeyGroupPartitionerTestBase<TimerHeapInternalTimer<Integer, VoidNamespace>> {
+
+	public KeyGroupPartitionerForTimersTest() {
+		super(
+			random -> new TimerHeapInternalTimer<>(42L, random.nextInt() & Integer.MAX_VALUE, VoidNamespace.INSTANCE),
+			TimerHeapInternalTimer::getKey);
+	}
+}


[2/2] flink git commit: [hotfix][state] Introduce hard size limit for CopyOnWriteStateTable

Posted by sr...@apache.org.
[hotfix][state] Introduce hard size limit for CopyOnWriteStateTable


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ae9178fa
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ae9178fa
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ae9178fa

Branch: refs/heads/master
Commit: ae9178fa0f93c67d081a3404c1088f3b2c494cc7
Parents: 7e0eafa
Author: Stefan Richter <s....@data-artisans.com>
Authored: Fri Jun 15 17:12:43 2018 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Fri Jun 15 21:06:24 2018 +0200

----------------------------------------------------------------------
 .../state/heap/CopyOnWriteStateTable.java       | 21 ++++++++++++++++----
 1 file changed, 17 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ae9178fa/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
index d28ed46..bb90c37 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTable.java
@@ -105,6 +105,9 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 */
 	private static final Logger LOG = LoggerFactory.getLogger(HeapKeyedStateBackend.class);
 
+	/** Maximum safe array size to allocate in a JVM. */
+	private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;
+
 	/**
 	 * Min capacity (other than zero) for a {@link CopyOnWriteStateTable}. Must be a power of two
 	 * greater than 1 (and less than 1 << 30).
@@ -630,13 +633,23 @@ public class CopyOnWriteStateTable<K, N, S> extends StateTable<K, N, S> implemen
 	 */
 	private StateTableEntry<K, N, S>[] makeTable(int newCapacity) {
 
-		if (MAXIMUM_CAPACITY == newCapacity) {
-			LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can lead " +
-					"to more collisions and lower performance. Please consider scaling-out your job or using a " +
+		if (newCapacity < MAXIMUM_CAPACITY) {
+			threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
+		} else {
+			if (size() > MAX_ARRAY_SIZE) {
+
+				throw new IllegalStateException("Maximum capacity of CopyOnWriteStateTable is reached and the job " +
+					"cannot continue. Please consider scaling-out your job or using a different keyed state backend " +
+					"implementation!");
+			} else {
+
+				LOG.warn("Maximum capacity of 2^30 in StateTable reached. Cannot increase hash table size. This can " +
+					"lead to more collisions and lower performance. Please consider scaling-out your job or using a " +
 					"different keyed state backend implementation!");
+				threshold = MAX_ARRAY_SIZE;
+			}
 		}
 
-		threshold = (newCapacity >> 1) + (newCapacity >> 2); // 3/4 capacity
 		@SuppressWarnings("unchecked") StateTableEntry<K, N, S>[] newTable
 				= (StateTableEntry<K, N, S>[]) new StateTableEntry[newCapacity];
 		return newTable;
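
To make the new guard concrete: the bucket array stops growing at MAXIMUM_CAPACITY = 2^30; from then on the resize threshold is lifted to MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8 and entries keep chaining, and only when size() itself exceeds that bound does the table fail fast instead of degrading further. A compact, standalone restatement of the decision logic (constants as in the diff; this sketch is not the actual CopyOnWriteStateTable code):

final class ResizeGuardSketch {

	private static final int MAXIMUM_CAPACITY = 1 << 30;
	private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

	/**
	 * Returns the next insert threshold for the given capacity, or throws once
	 * even chaining cannot absorb more entries. Mirrors makeTable() above.
	 */
	static int nextThreshold(int newCapacity, int currentSize) {
		if (newCapacity < MAXIMUM_CAPACITY) {
			return (newCapacity >> 1) + (newCapacity >> 2); // grow again at 3/4 capacity
		}
		if (currentSize > MAX_ARRAY_SIZE) {
			// Hard limit reached: scaling out is the only way forward.
			throw new IllegalStateException("Maximum capacity of CopyOnWriteStateTable reached.");
		}
		return MAX_ARRAY_SIZE; // stay at 2^30 buckets, keep chaining entries
	}

	public static void main(String[] args) {
		System.out.println(nextThreshold(1 << 20, 0));   // 786432, i.e. 3/4 of 2^20
		System.out.println(nextThreshold(1 << 30, 100)); // 2147483639, the hard threshold
	}
}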