Posted to commits@flink.apache.org by uc...@apache.org on 2016/10/07 12:15:34 UTC

flink git commit: [FLINK-4731] Fix HeapKeyedStateBackend Scale-In

Repository: flink
Updated Branches:
  refs/heads/master 97c71675a -> 8d953bf26


[FLINK-4731] Fix HeapKeyedStateBackend Scale-In

Adds additional tests in RescalingITCase for scale-in

This closes #2584.
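
For context on the fix itself: during scale-in, a single subtask restores its
keyed state from several KeyGroupsStateHandles. The old restore path recreated
the StateTable (and re-registered the kv-state ids) once per handle, so key
groups restored from earlier handles were lost. A minimal sketch of the
corrected create-once-then-merge pattern, with plain maps standing in for
Flink's internal StateTable (class and method names here are illustrative,
not Flink API):

    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    class ScaleInRestoreSketch {
        // stateName -> (keyGroupIndex -> restored state)
        private final Map<String, Map<Integer, Object>> stateTables = new HashMap<>();

        void restore(List<Map<String, Map<Integer, Object>>> handles) {
            for (Map<String, Map<Integer, Object>> handle : handles) {
                for (Map.Entry<String, Map<Integer, Object>> kvState : handle.entrySet()) {
                    // important: create the table only on first encounter;
                    // recreating it per handle drops previously restored key groups
                    stateTables
                            .computeIfAbsent(kvState.getKey(), name -> new HashMap<>())
                            .putAll(kvState.getValue());
                }
            }
        }
    }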


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8d953bf2
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8d953bf2
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8d953bf2

Branch: refs/heads/master
Commit: 8d953bf2626012e3e497334641962bd8f96098de
Parents: 97c7167
Author: Stefan Richter <s....@data-artisans.com>
Authored: Sun Oct 2 16:56:41 2016 +0200
Committer: Ufuk Celebi <uc...@apache.org>
Committed: Fri Oct 7 14:14:27 2016 +0200

----------------------------------------------------------------------
 .../savepoint/SavepointV1Serializer.java        |   4 +-
 .../filesystem/FsCheckpointStreamFactory.java   |  11 +-
 .../state/heap/HeapKeyedStateBackend.java       | 164 ++++++------
 .../state/memory/ByteStreamStateHandle.java     |  79 +++---
 .../memory/MemCheckpointStreamFactory.java      |   3 +-
 .../checkpoint/CheckpointCoordinatorTest.java   |  20 +-
 .../checkpoint/savepoint/SavepointV1Test.java   |   9 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |   4 +-
 .../ZooKeeperSubmittedJobGraphsStoreITCase.java |   7 +-
 .../runtime/testutils/CommonTestUtils.java      |  23 ++
 .../TestByteStreamStateHandleDeepCompare.java   |  54 ++++
 .../test/checkpointing/RescalingITCase.java     | 253 ++++++++++++++++---
 12 files changed, 465 insertions(+), 166 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index f120e1d..666176b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -290,6 +290,7 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 		} else if (stateHandle instanceof ByteStreamStateHandle) {
 			dos.writeByte(BYTE_STREAM_STATE_HANDLE);
 			ByteStreamStateHandle byteStreamStateHandle = (ByteStreamStateHandle) stateHandle;
+			dos.writeUTF(byteStreamStateHandle.getHandleName());
 			byte[] internalData = byteStreamStateHandle.getData();
 			dos.writeInt(internalData.length);
 			dos.write(byteStreamStateHandle.getData());
@@ -310,10 +311,11 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			String pathString = dis.readUTF();
 			return new FileStateHandle(new Path(pathString), size);
 		} else if (BYTE_STREAM_STATE_HANDLE == type) {
+			String handleName = dis.readUTF();
 			int numBytes = dis.readInt();
 			byte[] data = new byte[numBytes];
 			dis.readFully(data);
-			return new ByteStreamStateHandle(data);
+			return new ByteStreamStateHandle(handleName, data);
 		} else {
 			throw new IOException("Unknown implementation of StreamStateHandle, code: " + type);
 		}
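
Note on the hunk above: the serialized format for BYTE_STREAM_STATE_HANDLE now
carries the handle name in front of the payload, and it is read back in the
same order. A hedged round-trip sketch using plain data streams (helper names
are invented for illustration):

    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    class HandleFormatSketch {
        static void write(DataOutputStream dos, String handleName, byte[] data) throws IOException {
            dos.writeUTF(handleName);   // new field: unique handle name
            dos.writeInt(data.length);  // payload length
            dos.write(data);            // payload bytes
        }

        static byte[] read(DataInputStream dis) throws IOException {
            String handleName = dis.readUTF(); // must be consumed before the payload
            byte[] data = new byte[dis.readInt()];
            dis.readFully(data);
            return data; // a ByteStreamStateHandle is rebuilt from (handleName, data)
        }
    }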

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
index e4f7eba..fcc97b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsCheckpointStreamFactory.java
@@ -177,7 +177,6 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
 			this.localStateThreshold = localStateThreshold;
 		}
 
-
 		@Override
 		public void write(int b) throws IOException {
 			if (pos >= writeBuffer.length) {
@@ -219,7 +218,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
 
 		@Override
 		public long getPos() throws IOException {
-			return outStream == null ? pos : outStream.getPos();
+			return pos + (outStream == null ? 0 : outStream.getPos());
 		}
 
 		@Override
@@ -233,7 +232,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
 					Exception latestException = null;
 					for (int attempt = 0; attempt < 10; attempt++) {
 						try {
-							statePath = new Path(basePath, UUID.randomUUID().toString());
+							statePath = createStatePath();
 							outStream = fs.create(statePath, false);
 							break;
 						}
@@ -297,7 +296,7 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
 					if (outStream == null && pos <= localStateThreshold) {
 						closed = true;
 						byte[] bytes = Arrays.copyOf(writeBuffer, pos);
-						return new ByteStreamStateHandle(bytes);
+						return new ByteStreamStateHandle(createStatePath().toString(), bytes);
 					}
 					else {
 						flush();
@@ -318,5 +317,9 @@ public class FsCheckpointStreamFactory implements CheckpointStreamFactory {
 				}
 			}
 		}
+
+		private Path createStatePath() {
+			return new Path(basePath, UUID.randomUUID().toString());
+		}
 	}
 }
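
Note on the getPos() change above: reading the hunk, pos counts bytes still
sitting in the in-memory write buffer, while outStream.getPos() counts bytes
already flushed to the backing file, so the logical stream position is the sum
of the two. The key-group offsets recorded during a snapshot come from
getPos(), which is why the old either/or logic could under-report offsets once
a file stream was open. A tiny illustration (the numbers are invented):

    class LogicalPosSketch {
        // buffered: bytes in the write buffer; flushed: bytes already in the file
        static long logicalPos(long buffered, long flushed) {
            return buffered + flushed; // what the fixed getPos() computes
        }

        public static void main(String[] args) {
            // 4096 bytes flushed, 100 bytes still buffered -> position 4196, not 4096
            System.out.println(logicalPos(100, 4096));
        }
    }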

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 040677b..b283494 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -30,7 +30,9 @@ import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
@@ -46,6 +48,7 @@ import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -106,12 +109,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	// ------------------------------------------------------------------------
 	//  state backend operations
 	// ------------------------------------------------------------------------
-
+	@SuppressWarnings("unchecked")
 	@Override
 	public <N, V> ValueState<V> createValueState(TypeSerializer<N> namespaceSerializer, ValueStateDescriptor<V> stateDesc) throws Exception {
-		@SuppressWarnings("unchecked,rawtypes")
-		StateTable<K, N, V> stateTable = (StateTable) stateTables.get(stateDesc.getName());
-
+		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) stateTables.get(stateDesc.getName());
 
 		if (stateTable == null) {
 			stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -121,10 +122,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return new HeapValueState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
 	}
 
+	@SuppressWarnings("unchecked")
 	@Override
 	public <N, T> ListState<T> createListState(TypeSerializer<N> namespaceSerializer, ListStateDescriptor<T> stateDesc) throws Exception {
-		@SuppressWarnings("unchecked,rawtypes")
-		StateTable<K, N, ArrayList<T>> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+		StateTable<K, N, ArrayList<T>> stateTable = (StateTable<K, N, ArrayList<T>>) stateTables.get(stateDesc.getName());
 
 		if (stateTable == null) {
 			stateTable = new StateTable<>(new ArrayListSerializer<>(stateDesc.getSerializer()), namespaceSerializer, keyGroupRange);
@@ -133,11 +134,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		return new HeapListState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
 	}
-
+	@SuppressWarnings("unchecked")
 	@Override
 	public <N, T> ReducingState<T> createReducingState(TypeSerializer<N> namespaceSerializer, ReducingStateDescriptor<T> stateDesc) throws Exception {
-		@SuppressWarnings("unchecked,rawtypes")
-		StateTable<K, N, T> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+		StateTable<K, N, T> stateTable = (StateTable<K, N, T>) stateTables.get(stateDesc.getName());
 
 		if (stateTable == null) {
 			stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -146,11 +146,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		return new HeapReducingState<>(this, stateDesc, stateTable, keySerializer, namespaceSerializer);
 	}
-
+	@SuppressWarnings("unchecked")
 	@Override
 	protected <N, T, ACC> FoldingState<T, ACC> createFoldingState(TypeSerializer<N> namespaceSerializer, FoldingStateDescriptor<T, ACC> stateDesc) throws Exception {
-		@SuppressWarnings("unchecked,rawtypes")
-		StateTable<K, N, ACC> stateTable = (StateTable) stateTables.get(stateDesc.getName());
+		StateTable<K, N, ACC> stateTable = (StateTable<K, N, ACC>) stateTables.get(stateDesc.getName());
 
 		if (stateTable == null) {
 			stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
@@ -161,7 +160,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	@SuppressWarnings("rawtypes,unchecked")
+	@SuppressWarnings("unchecked")
 	public RunnableFuture<KeyGroupsStateHandle> snapshot(
 			long checkpointId,
 			long timestamp,
@@ -188,8 +187,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 				outView.writeUTF(kvState.getKey());
 
-				TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-				TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
+				TypeSerializer<?> namespaceSerializer = kvState.getValue().getNamespaceSerializer();
+				TypeSerializer<?> stateSerializer = kvState.getValue().getStateSerializer();
 
 				InstantiationUtil.serializeObject(stream, namespaceSerializer);
 				InstantiationUtil.serializeObject(stream, stateSerializer);
@@ -203,39 +202,10 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
 				keyGroupRangeOffsets[offsetCounter++] = stream.getPos();
 				outView.writeInt(keyGroupIndex);
-
 				for (Map.Entry<String, StateTable<K, ?, ?>> kvState : stateTables.entrySet()) {
-
 					outView.writeShort(kVStateToId.get(kvState.getKey()));
-
-					TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
-					TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
-
-					// Map<NamespaceT, Map<KeyT, StateT>>
-					Map<?, ? extends Map<K, ?>> namespaceMap = kvState.getValue().get(keyGroupIndex);
-					if (namespaceMap == null) {
-						outView.writeByte(0);
-						continue;
-					}
-
-					outView.writeByte(1);
-
-					// number of namespaces
-					outView.writeInt(namespaceMap.size());
-					for (Map.Entry<?, ? extends Map<K, ?>> namespace : namespaceMap.entrySet()) {
-						namespaceSerializer.serialize(namespace.getKey(), outView);
-
-						Map<K, ?> entryMap = namespace.getValue();
-
-						// number of entries
-						outView.writeInt(entryMap.size());
-						for (Map.Entry<K, ?> entry : entryMap.entrySet()) {
-							keySerializer.serialize(entry.getKey(), outView);
-							stateSerializer.serialize(entry.getValue(), outView);
-						}
-					}
+					writeStateTableForKeyGroup(outView, kvState.getValue(), keyGroupIndex);
 				}
-				outView.flush();
 			}
 
 			StreamStateHandle streamStateHandle = stream.closeAndGetHandle();
@@ -246,8 +216,42 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
-	@SuppressWarnings({"unchecked", "rawtypes"})
-	public void restorePartitionedState(List<KeyGroupsStateHandle> state) throws Exception {
+	private <N, S> void writeStateTableForKeyGroup(
+			DataOutputView outView,
+			StateTable<K, N, S> stateTable,
+			int keyGroupIndex) throws IOException {
+
+		TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+		TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
+
+		Map<N, Map<K, S>> namespaceMap = stateTable.get(keyGroupIndex);
+		if (namespaceMap == null) {
+			outView.writeByte(0);
+		} else {
+			outView.writeByte(1);
+
+			// number of namespaces
+			outView.writeInt(namespaceMap.size());
+			for (Map.Entry<N, Map<K, S>> namespace : namespaceMap.entrySet()) {
+				namespaceSerializer.serialize(namespace.getKey(), outView);
+
+				Map<K, S> entryMap = namespace.getValue();
+
+				// number of entries
+				outView.writeInt(entryMap.size());
+				for (Map.Entry<K, S> entry : entryMap.entrySet()) {
+					keySerializer.serialize(entry.getKey(), outView);
+					stateSerializer.serialize(entry.getValue(), outView);
+				}
+			}
+		}
+	}
+
+	@SuppressWarnings({"unchecked"})
+	private void restorePartitionedState(List<KeyGroupsStateHandle> state) throws Exception {
+
+		int numRegisteredKvStates = 0;
+		Map<Integer, String> kvStatesById = new HashMap<>();
 
 		for (KeyGroupsStateHandle keyGroupsHandle : state) {
 
@@ -266,21 +270,23 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 				int numKvStates = inView.readShort();
 
-				Map<Integer, String> kvStatesById = new HashMap<>(numKvStates);
-
 				for (int i = 0; i < numKvStates; ++i) {
 					String stateName = inView.readUTF();
 
-					TypeSerializer namespaceSerializer =
+					TypeSerializer<?> namespaceSerializer =
 							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
-					TypeSerializer stateSerializer =
+					TypeSerializer<?> stateSerializer =
 							InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
 
-					StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer,
-							namespaceSerializer,
-							keyGroupRange);
-					stateTables.put(stateName, stateTable);
-					kvStatesById.put(i, stateName);
+					StateTable<K, ?, ?> stateTable = stateTables.get(stateName);
+
+					// important: only create a new table if we did not already create it previously
+					if (null == stateTable) {
+						stateTable = new StateTable<>(stateSerializer, namespaceSerializer, keyGroupRange);
+						stateTables.put(stateName, stateTable);
+						kvStatesById.put(numRegisteredKvStates, stateName);
+						++numRegisteredKvStates;
+					}
 				}
 
 				for (Tuple2<Integer, Long> groupOffset : keyGroupsHandle.getGroupRangeOffsets()) {
@@ -302,25 +308,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 						StateTable<K, ?, ?> stateTable = stateTables.get(kvStatesById.get(kvStateId));
 						Preconditions.checkNotNull(stateTable);
 
-						TypeSerializer namespaceSerializer = stateTable.getNamespaceSerializer();
-						TypeSerializer stateSerializer = stateTable.getStateSerializer();
-
-						Map namespaceMap = new HashMap<>();
-						stateTable.set(keyGroupIndex, namespaceMap);
-
-						int numNamespaces = inView.readInt();
-						for (int k = 0; k < numNamespaces; k++) {
-							Object namespace = namespaceSerializer.deserialize(inView);
-							Map entryMap = new HashMap<>();
-							namespaceMap.put(namespace, entryMap);
-
-							int numEntries = inView.readInt();
-							for (int l = 0; l < numEntries; l++) {
-								Object key = keySerializer.deserialize(inView);
-								Object value = stateSerializer.deserialize(inView);
-								entryMap.put(key, value);
-							}
-						}
+						readStateTableForKeyGroup(inView, stateTable, keyGroupIndex);
 					}
 				}
 			} finally {
@@ -330,6 +318,32 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	private <N, S> void readStateTableForKeyGroup(
+			DataInputView inView,
+			StateTable<K, N, S> stateTable,
+			int keyGroupIndex) throws IOException {
+
+		TypeSerializer<N> namespaceSerializer = stateTable.getNamespaceSerializer();
+		TypeSerializer<S> stateSerializer = stateTable.getStateSerializer();
+
+		Map<N, Map<K, S>> namespaceMap = new HashMap<>();
+		stateTable.set(keyGroupIndex, namespaceMap);
+
+		int numNamespaces = inView.readInt();
+		for (int k = 0; k < numNamespaces; k++) {
+			N namespace = namespaceSerializer.deserialize(inView);
+			Map<K, S> entryMap = new HashMap<>();
+			namespaceMap.put(namespace, entryMap);
+
+			int numEntries = inView.readInt();
+			for (int l = 0; l < numEntries; l++) {
+				K key = keySerializer.deserialize(inView);
+				S state = stateSerializer.deserialize(inView);
+				entryMap.put(key, state);
+			}
+		}
+	}
+
 	@Override
 	public String toString() {
 		return "HeapKeyedStateBackend";

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
index 7d8b6ce..42703f8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/ByteStreamStateHandle.java
@@ -20,62 +20,49 @@ package org.apache.flink.runtime.state.memory;
 
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
-import java.io.Serializable;
-import java.util.Arrays;
 
 /**
  * A state handle that contains stream state in a byte array.
  */
 public class ByteStreamStateHandle implements StreamStateHandle {
 
-	private static final long serialVersionUID = -5280226231200217594L;
+	private static final long serialVersionUID = -5280226231202517594L;
 
 	/**
-	 * the state data
+	 * The state data.
 	 */
 	protected final byte[] data;
 
 	/**
+	 * A unique name by which this state handle is identified and compared. Like a filename, all
+	 * {@link ByteStreamStateHandle}s with the exact same name must also have exactly the same content in data.
+	 */
+	protected final String handleName;
+
+	/**
 	 * Creates a new ByteStreamStateHandle containing the given data.
-	 *
-	 * @param data The state data.
 	 */
-	public ByteStreamStateHandle(byte[] data) {
-		this.data = data;
+	public ByteStreamStateHandle(String handleName, byte[] data) {
+		this.handleName = Preconditions.checkNotNull(handleName);
+		this.data = Preconditions.checkNotNull(data);
 	}
 
 	@Override
 	public FSDataInputStream openInputStream() throws IOException {
-
-		return new FSDataInputStream() {
-			int index = 0;
-
-			@Override
-			public void seek(long desired) throws IOException {
-				Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
-				index = (int) desired;
-			}
-
-			@Override
-			public long getPos() throws IOException {
-				return index;
-			}
-
-			@Override
-			public int read() throws IOException {
-				return index < data.length ? data[index++] & 0xFF : -1;
-			}
-		};
+		return new ByteStateHandleInputStream();
 	}
 
 	public byte[] getData() {
 		return data;
 	}
 
+	public String getHandleName() {
+		return handleName;
+	}
+
 	@Override
 	public void discardState() {
 	}
@@ -94,17 +81,41 @@ public class ByteStreamStateHandle implements StreamStateHandle {
 			return false;
 		}
 
-		ByteStreamStateHandle that = (ByteStreamStateHandle) o;
-		return Arrays.equals(data, that.data);
 
+		ByteStreamStateHandle that = (ByteStreamStateHandle) o;
+		return handleName.equals(that.handleName);
 	}
 
 	@Override
 	public int hashCode() {
-		return Arrays.hashCode(data);
+		return 31 * handleName.hashCode();
 	}
 
-	public static StreamStateHandle fromSerializable(Serializable value) throws IOException {
-		return new ByteStreamStateHandle(InstantiationUtil.serializeObject(value));
+	/**
+	 * An input stream view on a byte array.
+	 */
+	private final class ByteStateHandleInputStream extends FSDataInputStream {
+
+		private int index;
+
+		public ByteStateHandleInputStream() {
+			this.index = 0;
+		}
+
+		@Override
+		public void seek(long desired) throws IOException {
+			Preconditions.checkArgument(desired >= 0 && desired < Integer.MAX_VALUE);
+			index = (int) desired;
+		}
+
+		@Override
+		public long getPos() throws IOException {
+			return index;
+		}
+
+		@Override
+		public int read() throws IOException {
+			return index < data.length ? data[index++] & 0xFF : -1;
+		}
 	}
 }
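
Worth noting: equals() and hashCode() are now based on the handle name alone.
Like a file path, the name is documented to imply identical content, so two
handles with the same name compare equal even if they were deserialized
independently. A hedged sketch (the handle name is made up):

    import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;

    class NameEqualitySketch {
        public static void main(String[] args) {
            byte[] data = {1, 2, 3};
            ByteStreamStateHandle a = new ByteStreamStateHandle("handle-xyz", data);
            ByteStreamStateHandle b = new ByteStreamStateHandle("handle-xyz", data.clone());
            System.out.println(a.equals(b)); // true: same name, content assumed equal
        }
    }

Tests that still need byte-wise comparison use the new
TestByteStreamStateHandleDeepCompare or CommonTestUtils.isSteamContentEqual
added further down in this commit.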

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
index 6f0a814..028f8c8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemCheckpointStreamFactory.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.StreamStateHandle;
 
 import java.io.IOException;
+import java.util.UUID;
 
 /**
  * {@link CheckpointStreamFactory} that produces streams that write to in-memory byte arrays.
@@ -118,7 +119,7 @@ public class MemCheckpointStreamFactory implements CheckpointStreamFactory {
 			if (isEmpty) {
 				return null;
 			}
-			return new ByteStreamStateHandle(closeAndGetBytes());
+			return new ByteStreamStateHandle(String.valueOf(UUID.randomUUID()), closeAndGetBytes());
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 6289fcb..728c7d5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -45,6 +45,8 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.junit.Assert;
@@ -65,6 +67,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.UUID;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -2287,7 +2290,8 @@ public class CheckpointCoordinatorTest {
 
 		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
 
-		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+		ByteStreamStateHandle allSerializedStatesHandle = new TestByteStreamStateHandleDeepCompare(
+				String.valueOf(UUID.randomUUID()),
 				serializedDataWithOffsets.f0);
 		KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle(
 				keyGroupRangeOffsets,
@@ -2343,7 +2347,8 @@ public class CheckpointCoordinatorTest {
 
 	public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
 			Serializable value) throws IOException {
-		return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value));
+		return ChainedStateHandle.wrapSingleHandle(
+				TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value));
 	}
 
 	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
@@ -2387,7 +2392,8 @@ public class CheckpointCoordinatorTest {
 			++idx;
 		}
 
-		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+		ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare(
+				String.valueOf(UUID.randomUUID()),
 				serializationWithOffsets.f0);
 
 		OperatorStateHandle operatorStateHandle =
@@ -2468,7 +2474,9 @@ public class CheckpointCoordinatorTest {
 			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
 			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = executionJobVertex.
 					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle();
-			assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0));
+			assertTrue(CommonTestUtils.isSteamContentEqual(
+					expectNonPartitionedState.get(0).openInputStream(),
+					actualNonPartitionedState.get(0).openInputStream()));
 
 			ChainedStateHandle<OperatorStateHandle> expectedPartitionableState =
 					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8);
@@ -2476,7 +2484,9 @@ public class CheckpointCoordinatorTest {
 			List<Collection<OperatorStateHandle>> actualPartitionableState = executionJobVertex.
 					getTaskVertices()[i].getCurrentExecutionAttempt().getChainedPartitionableStateHandle();
 
-			assertEquals(expectedPartitionableState.get(0), actualPartitionableState.get(0).iterator().next());
+			assertTrue(CommonTestUtils.isSteamContentEqual(
+					expectedPartitionableState.get(0).openInputStream(),
+					actualPartitionableState.get(0).iterator().next().openInputStream()));
 
 			List<KeyGroupsStateHandle> expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID,

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index c82be18..e38e5fb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -26,7 +26,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.junit.Test;
 
 import java.io.IOException;
@@ -72,11 +72,11 @@ public class SavepointV1Test {
 		for (int i = 0; i < numTaskStates; i++) {
 			TaskState taskState = new TaskState(new JobVertexID(), numSubtaskStates, numSubtaskStates, 1);
 			for (int j = 0; j < numSubtaskStates; j++) {
-				StreamStateHandle stateHandle = new ByteStreamStateHandle("Hello".getBytes());
+				StreamStateHandle stateHandle = new TestByteStreamStateHandleDeepCompare("a", "Hello".getBytes());
 				taskState.putState(i, new SubtaskState(
 						new ChainedStateHandle<>(Collections.singletonList(stateHandle)), 0));
 
-				stateHandle = new ByteStreamStateHandle("Beautiful".getBytes());
+				stateHandle = new TestByteStreamStateHandleDeepCompare("b", "Beautiful".getBytes());
 				Map<String, long[]> offsetsMap = new HashMap<>();
 				offsetsMap.put("A", new long[]{0, 10, 20});
 				offsetsMap.put("B", new long[]{30, 40, 50});
@@ -93,7 +93,8 @@ public class SavepointV1Test {
 			taskState.putKeyedState(
 					0,
 					new KeyGroupsStateHandle(
-							new KeyGroupRangeOffsets(1,1, new long[] {42}), new ByteStreamStateHandle("World".getBytes())));
+							new KeyGroupRangeOffsets(1, 1, new long[]{42}),
+							new TestByteStreamStateHandleDeepCompare("c", "World".getBytes())));
 
 			taskStates.add(taskState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 6100856..b9c2bdf 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -69,6 +69,7 @@ import org.apache.flink.runtime.testingUtils.TestingMessages;
 import org.apache.flink.runtime.testingUtils.TestingTaskManager;
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -461,7 +462,8 @@ public class JobManagerHARecoveryTest {
 		@Override
 		public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData) {
 			try {
-				ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
+				ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+						String.valueOf(UUID.randomUUID()),
 						InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
 
 				RetrievableStreamStateHandle<Long> state = new RetrievableStreamStateHandle<Long>(byteStreamStateHandle);

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
index 6ef184d..7d21cfd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobmanager.SubmittedJobGraphStore.SubmittedJobGr
 import org.apache.flink.runtime.state.RetrievableStateHandle;
 import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.apache.flink.util.InstantiationUtil;
@@ -40,6 +41,7 @@ import org.mockito.stubbing.Answer;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.List;
+import java.util.UUID;
 import java.util.concurrent.CountDownLatch;
 
 import static org.junit.Assert.assertEquals;
@@ -62,9 +64,10 @@ public class ZooKeeperSubmittedJobGraphsStoreITCase extends TestLogger {
 	private final static RetrievableStateStorageHelper<SubmittedJobGraph> localStateStorage = new RetrievableStateStorageHelper<SubmittedJobGraph>() {
 		@Override
 		public RetrievableStateHandle<SubmittedJobGraph> store(SubmittedJobGraph state) throws IOException {
-			ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
+			ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+					String.valueOf(UUID.randomUUID()),
 					InstantiationUtil.serializeObject(state));
-			return new RetrievableStreamStateHandle<SubmittedJobGraph>(byteStreamStateHandle);
+			return new RetrievableStreamStateHandle<>(byteStreamStateHandle);
 		}
 	};
 

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
index 2a787f3..d857a19 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.testutils;
 
 import org.apache.flink.util.FileUtils;
 
+import java.io.BufferedInputStream;
 import java.io.File;
 import java.io.FileWriter;
 import java.io.IOException;
@@ -193,4 +194,26 @@ public class CommonTestUtils {
 			}
 		}
 	}
+
+	public static boolean isSteamContentEqual(InputStream input1, InputStream input2) throws IOException {
+
+		if (!(input1 instanceof BufferedInputStream)) {
+			input1 = new BufferedInputStream(input1);
+		}
+		if (!(input2 instanceof BufferedInputStream)) {
+			input2 = new BufferedInputStream(input2);
+		}
+
+		int ch = input1.read();
+		while (-1 != ch) {
+			int ch2 = input2.read();
+			if (ch != ch2) {
+				return false;
+			}
+			ch = input1.read();
+		}
+
+		int ch2 = input2.read();
+		return (ch2 == -1);
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
new file mode 100644
index 0000000..7d8797b
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestByteStreamStateHandleDeepCompare.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.util;
+
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+
+public class TestByteStreamStateHandleDeepCompare extends ByteStreamStateHandle {
+
+	private static final long serialVersionUID = -4946526195523509L;
+
+	public TestByteStreamStateHandleDeepCompare(String handleName, byte[] data) {
+		super(handleName, data);
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (!super.equals(o)) {
+			return false;
+		}
+		ByteStreamStateHandle other = (ByteStreamStateHandle) o;
+		return Arrays.equals(getData(), other.getData());
+	}
+
+	@Override
+	public int hashCode() {
+		return 31 * super.hashCode() + Arrays.hashCode(getData());
+	}
+
+	public static StreamStateHandle fromSerializable(String handleName, Serializable value) throws IOException {
+		return new TestByteStreamStateHandleDeepCompare(handleName, InstantiationUtil.serializeObject(value));
+	}
+}
\ No newline at end of file
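
A quick usage sketch for the new test handle: it keeps the name-based equals()
contract of its parent but additionally requires byte-wise equal data, which is
what the updated assertions in the tests above rely on:

    import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;

    class DeepCompareSketch {
        public static void main(String[] args) {
            TestByteStreamStateHandleDeepCompare a =
                    new TestByteStreamStateHandleDeepCompare("h", "Hello".getBytes());
            TestByteStreamStateHandleDeepCompare b =
                    new TestByteStreamStateHandleDeepCompare("h", "Hello".getBytes());
            System.out.println(a.equals(b)); // true: same name AND same bytes
        }
    }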

http://git-wip-us.apache.org/repos/asf/flink/blob/8d953bf2/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 263bf79..848a579 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -39,6 +39,7 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
@@ -57,7 +58,9 @@ import scala.concurrent.duration.Deadline;
 import scala.concurrent.duration.FiniteDuration;
 
 import java.io.File;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
@@ -67,6 +70,10 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
+/**
+ * TODO: parameterize to test all different state backends!
+ * TODO: reactivate ignored test as soon as savepoints work with deactivated checkpoints.
+ */
 public class RescalingITCase extends TestLogger {
 
 	private static final int numTaskManagers = 2;
@@ -103,17 +110,26 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
+	@Test
+	public void testSavepointRescalingInKeyedState() throws Exception {
+		testSavepointRescalingKeyedState(false);
+	}
+
+	@Test
+	public void testSavepointRescalingOutKeyedState() throws Exception {
+		testSavepointRescalingKeyedState(true);
+	}
+
 	/**
-	 * Tests that a a job with purely partitioned state can be restarted from a savepoint
+	 * Tests that a job with purely keyed state can be restarted from a savepoint
 	 * with a different parallelism.
 	 */
-	@Test
-	public void testSavepointRescalingWithPartitionedState() throws Exception {
+	public void testSavepointRescalingKeyedState(boolean scaleOut) throws Exception {
 		final int numberKeys = 42;
 		final int numberElements = 1000;
 		final int numberElements2 = 500;
-		final int parallelism = numSlots / 2;
-		final int parallelism2 = numSlots;
+		final int parallelism = scaleOut ? numSlots / 2 : numSlots;
+		final int parallelism2 = scaleOut ? numSlots : numSlots / 2;
 		final int maxParallelism = 13;
 
 		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
@@ -125,7 +141,7 @@ public class RescalingITCase extends TestLogger {
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createPartitionedStateJobGraph(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+			JobGraph jobGraph = createJobGraphWithKeyedState(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
 
 			jobID = jobGraph.getJobID();
 
@@ -168,7 +184,7 @@ public class RescalingITCase extends TestLogger {
 
 			jobID = null;
 
-			JobGraph scaledJobGraph = createPartitionedStateJobGraph(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
+			JobGraph scaledJobGraph = createJobGraphWithKeyedState(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -189,6 +205,7 @@ public class RescalingITCase extends TestLogger {
 
 			assertEquals(expectedResult2, actualResult2);
 
+
 		} finally {
 			// clear the CollectionSink set for the restarted job
 			CollectionSink.clearElementsSet();
@@ -213,7 +230,7 @@ public class RescalingITCase extends TestLogger {
 	 * @throws Exception
 	 */
 	@Test
-	public void testSavepointRescalingFailureWithNonPartitionedState() throws Exception {
+	public void testSavepointRescalingNonPartitionedStateCausesException() throws Exception {
 		final int parallelism = numSlots / 2;
 		final int parallelism2 = numSlots;
 		final int maxParallelism = 13;
@@ -227,7 +244,7 @@ public class RescalingITCase extends TestLogger {
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createNonPartitionedStateJobGraph(parallelism, maxParallelism, 500);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, false);
 
 			jobID = jobGraph.getJobID();
 
@@ -236,7 +253,7 @@ public class RescalingITCase extends TestLogger {
 			Object savepointResponse = null;
 
 			// wait until the operator is started
-			NonPartitionedStateSource.workStartedLatch.await();
+			StateSourceBase.workStartedLatch.await();
 
 			while (deadline.hasTimeLeft()) {
 
@@ -266,7 +283,7 @@ public class RescalingITCase extends TestLogger {
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createNonPartitionedStateJobGraph(parallelism2, maxParallelism, 500);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, false);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -311,7 +328,7 @@ public class RescalingITCase extends TestLogger {
 	 * @throws Exception
 	 */
 	@Test
-	public void testSavepointRescalingWithPartiallyNonPartitionedState() throws Exception {
+	public void testSavepointRescalingWithKeyedAndNonPartitionedState() throws Exception {
 		int numberKeys = 42;
 		int numberElements = 1000;
 		int numberElements2 = 500;
@@ -328,7 +345,7 @@ public class RescalingITCase extends TestLogger {
 		try {
 			 jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createPartitionedNonPartitionedStateJobGraph(
+			JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
 				parallelism,
 				maxParallelism,
 				parallelism,
@@ -378,7 +395,7 @@ public class RescalingITCase extends TestLogger {
 
 			jobID = null;
 
-			JobGraph scaledJobGraph = createPartitionedNonPartitionedStateJobGraph(
+			JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
 				parallelism2,
 				maxParallelism,
 				parallelism,
@@ -423,23 +440,137 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static JobGraph createNonPartitionedStateJobGraph(int parallelism, int maxParallelism, long checkpointInterval) {
+	@Test
+	public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(false);
+	}
+
+	@Test
+	public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(true);
+	}
+
+
+	/**
+	 * Tests rescaling of partitioned operator state. More specifically, we test the mechanism with {@link ListCheckpointed}
+	 * as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
+	 */
+	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) throws Exception {
+		final int parallelism = scaleOut ? numSlots : numSlots / 2;
+		final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
+		final int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		JobID jobID = null;
+		ActorGateway jobManager = null;
+
+		int counterSize = Math.max(parallelism, parallelism2);
+
+		PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+		PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, true);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			Object savepointResponse = null;
+
+			// wait until the operator is started
+			StateSourceBase.workStartedLatch.await();
+
+			while (deadline.hasTimeLeft()) {
+
+				Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+				FiniteDuration waitingTime = new FiniteDuration(10, TimeUnit.SECONDS);
+				savepointResponse = Await.result(savepointPathFuture, waitingTime);
+
+				if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) {
+					break;
+				}
+			}
+
+			assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
+
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+
+			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+			Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+			Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+			assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+			Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+			// job successfully removed
+			jobID = null;
+
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, true);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			int sumExp = 0;
+			int sumAct = 0;
+
+			for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+				sumExp += c;
+			}
+
+			for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+				sumAct += c;
+			}
+
+			assertEquals(sumExp, sumAct);
+			jobID = null;
+
+		} finally {
+			// clear any left overs from a possibly failed job
+			// clear any leftovers from a possibly failed job
+				Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+				try {
+					Await.ready(jobRemovedFuture, timeout);
+				} catch (TimeoutException | InterruptedException ie) {
+					fail("Failed while cleaning up the cluster.");
+				}
+			}
+		}
+	}
+
+	//------------------------------------------------------------------------------------------------------------------
+
+	private static JobGraph createJobGraphWithOperatorState(
+			int parallelism, int maxParallelism, boolean partitionedOperatorState) {
+
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
 		env.getConfig().setMaxParallelism(maxParallelism);
-		env.enableCheckpointing(checkpointInterval);
+		env.enableCheckpointing(Long.MAX_VALUE);
 		env.setRestartStrategy(RestartStrategies.noRestart());
 
-		NonPartitionedStateSource.workStartedLatch = new CountDownLatch(1);
+		StateSourceBase.workStartedLatch = new CountDownLatch(1);
 
-		DataStream<Integer> input = env.addSource(new NonPartitionedStateSource());
+		DataStream<Integer> input = env.addSource(
+				partitionedOperatorState ? new PartitionedStateSource() : new NonPartitionedStateSource());
 
 		input.addSink(new DiscardingSink<Integer>());
 
 		return env.getStreamGraph().getJobGraph();
 	}
 
-	private static JobGraph createPartitionedStateJobGraph(
+	private static JobGraph createJobGraphWithKeyedState(
 		int parallelism,
 		int maxParallelism,
 		int numberKeys,
@@ -475,7 +606,7 @@ public class RescalingITCase extends TestLogger {
 		return env.getStreamGraph().getJobGraph();
 	}
 
-	private static JobGraph createPartitionedNonPartitionedStateJobGraph(
+	private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
 		int parallelism,
 		int maxParallelism,
 		int fixedParallelism,
@@ -606,12 +737,13 @@ public class RescalingITCase extends TestLogger {
 
 		@Override
 		public void open(Configuration configuration) {
-			counter = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("counter", Integer.class, 0));
-			sum = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("sum", Integer.class, 0));
+			counter = getRuntimeContext().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+			sum = getRuntimeContext().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
 		}
 
 		@Override
 		public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
+
 			int count = counter.value() + 1;
 			counter.update(count);
 
@@ -645,14 +777,43 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static class NonPartitionedStateSource extends RichParallelSourceFunction<Integer> implements Checkpointed<Integer> {
-
-		private static final long serialVersionUID = -8108185918123186841L;
+	private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
 
 		private static volatile CountDownLatch workStartedLatch = new CountDownLatch(1);
 
-		private volatile int counter = 0;
-		private volatile boolean running = true;
+		protected volatile int counter = 0;
+		protected volatile boolean running = true;
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lock = ctx.getCheckpointLock();
+
+			while (running) {
+				synchronized (lock) {
+					++counter;
+					ctx.collect(1);
+				}
+
+				Thread.sleep(2);
+				if (counter == 10) {
+					workStartedLatch.countDown();
+				}
+
+				if (counter >= 500) {
+					break;
+				}
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+	}
+
+	private static class NonPartitionedStateSource extends StateSourceBase implements Checkpointed<Integer> {
+
+		private static final long serialVersionUID = -8108185918123186841L;
 
 		@Override
 		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
@@ -663,28 +824,42 @@ public class RescalingITCase extends TestLogger {
 		public void restoreState(Integer state) throws Exception {
 			counter = state;
 		}
+	}
+
+	private static class PartitionedStateSource extends StateSourceBase implements ListCheckpointed<Integer> {
+
+		private static final long serialVersionUID = -359715965103593462L;
+		private static final int NUM_PARTITIONS = 7;
+
+		private static int[] CHECK_CORRECT_SNAPSHOT;
+		private static int[] CHECK_CORRECT_RESTORE;
 
 		@Override
-		public void run(SourceContext<Integer> ctx) throws Exception {
-			final Object lock = ctx.getCheckpointLock();
+		public List<Integer> snapshotState(long checkpointId, long timestamp) throws Exception {
 
-			while (running) {
-				synchronized (lock) {
-					counter++;
+			CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
 
-					ctx.collect(counter * getRuntimeContext().getIndexOfThisSubtask());
-				}
+			int div = counter / NUM_PARTITIONS;
+			int mod = counter % NUM_PARTITIONS;
 
-				Thread.sleep(2);
-				if(counter == 10) {
-					workStartedLatch.countDown();
+			List<Integer> split = new ArrayList<>();
+			for (int i = 0; i < NUM_PARTITIONS; ++i) {
+				int partitionValue = div;
+				if (mod > 0) {
+					--mod;
+					++partitionValue;
 				}
+				split.add(partitionValue);
 			}
+			return split;
 		}
 
 		@Override
-		public void cancel() {
-			running = false;
+		public void restoreState(List<Integer> state) throws Exception {
+			for (Integer v : state) {
+				counter += v;
+			}
+			CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
 		}
 	}
 }
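
On the redistribution arithmetic in PartitionedStateSource.snapshotState: the
counter is split into NUM_PARTITIONS near-equal parts whose sum is exactly the
counter, so summing the restored parts recovers the total no matter how the
list state is reshuffled across subtasks. A standalone sketch of the same
arithmetic:

    import java.util.ArrayList;
    import java.util.List;

    class SplitSketch {
        static List<Integer> split(int counter, int numPartitions) {
            int div = counter / numPartitions;
            int mod = counter % numPartitions;
            List<Integer> parts = new ArrayList<>();
            for (int i = 0; i < numPartitions; ++i) {
                parts.add(div + (mod-- > 0 ? 1 : 0)); // spread the remainder
            }
            return parts;
        }

        public static void main(String[] args) {
            System.out.println(split(10, 7)); // [2, 2, 2, 1, 1, 1, 1] -> sums to 10
        }
    }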