You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by al...@apache.org on 2017/01/13 20:37:28 UTC

[1/2] flink git commit: Hide broadcast state / remove from public API

Repository: flink
Updated Branches:
  refs/heads/master 8b1b4a1cc -> 6a86e9d62


Hide broadcast state / remove from public API


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6a86e9d6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6a86e9d6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6a86e9d6

Branch: refs/heads/master
Commit: 6a86e9d62045386b76026f4deead8baa559f008e
Parents: 1020ba2
Author: Stefan Richter <s....@data-artisans.com>
Authored: Fri Jan 13 16:38:33 2017 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Fri Jan 13 21:29:19 2017 +0100

----------------------------------------------------------------------
 .../api/common/state/OperatorStateStore.java    | 27 --------------------
 .../state/DefaultOperatorStateBackend.java      |  2 --
 .../runtime/state/OperatorStateBackendTest.java | 12 +++++----
 .../test/checkpointing/RescalingITCase.java     |  5 +++-
 4 files changed, 11 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6a86e9d6/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 87a7759..c1cdfe4 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -57,33 +57,6 @@ public interface OperatorStateStore {
 	<T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception;
 
 	/**
-	 * Creates (or restores) a list state. Each state is registered under a unique name.
-	 * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
-	 *
-	 * On restore, all items in the list are broadcasted to all parallel operator instances.
-	 *
-	 * @param stateDescriptor The descriptor for this state, providing a name and serializer.
-	 * @param <S> The generic type of the state
-	 *
-	 * @return A list for all state partitions.
-	 * @throws Exception
-	 */
-	<S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
-
-	/**
-	 * Creates a state of the given name that uses Java serialization to persist the state. On restore, all items
-	 * in the list are broadcasted to all parallel operator instances.
-	 *
-	 * <p>This is a simple convenience method. For more flexibility on how state serialization
-	 * should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method.
-	 *
-	 * @param stateName The name of state to create
-	 * @return A list state using Java serialization to serialize state objects.
-	 * @throws Exception
-	 */
-	<T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception;
-
-	/**
 	 * Returns a set with the names of all currently registered states.
 	 * @return set of names for all registered states.
 	 */

http://git-wip-us.apache.org/repos/asf/flink/blob/6a86e9d6/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index 6c65088..1cd1da7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -91,12 +91,10 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	}
 
 	@SuppressWarnings("unchecked")
-	@Override
 	public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception {
 		return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
 	}
 
-	@Override
 	public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception {
 		return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/6a86e9d6/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index cd0391f..5bd085f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -45,8 +45,9 @@ public class OperatorStateBackendTest {
 		return env;
 	}
 
-	private OperatorStateBackend createNewOperatorStateBackend() throws Exception {
-		return abstractStateBackend.createOperatorStateBackend(
+	private DefaultOperatorStateBackend createNewOperatorStateBackend() throws Exception {
+		//TODO this is temporarily casted to test already functionality that we do not yet expose through public API
+		return (DefaultOperatorStateBackend) abstractStateBackend.createOperatorStateBackend(
 				createMockEnvironment(),
 				"test-operator");
 	}
@@ -60,7 +61,7 @@ public class OperatorStateBackendTest {
 
 	@Test
 	public void testRegisterStates() throws Exception {
-		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		DefaultOperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
 		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
@@ -143,7 +144,7 @@ public class OperatorStateBackendTest {
 
 	@Test
 	public void testSnapshotRestore() throws Exception {
-		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
+		DefaultOperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
 		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
@@ -171,7 +172,8 @@ public class OperatorStateBackendTest {
 			operatorStateBackend.close();
 			operatorStateBackend.dispose();
 
-			operatorStateBackend = abstractStateBackend.createOperatorStateBackend(
+			//TODO this is temporarily casted to test already functionality that we do not yet expose through public API
+			operatorStateBackend = (DefaultOperatorStateBackend) abstractStateBackend.createOperatorStateBackend(
 					createMockEnvironment(),
 					"testOperator");
 

http://git-wip-us.apache.org/repos/asf/flink/blob/6a86e9d6/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 45fcc25..bd1678e 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
 import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.DefaultOperatorStateBackend;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
@@ -969,8 +970,10 @@ public class RescalingITCase extends TestLogger {
 		public void initializeState(FunctionInitializationContext context) throws Exception {
 
 			if (broadcast) {
+				//TODO this is temporarily casted to test already functionality that we do not yet expose through public API
+				DefaultOperatorStateBackend operatorStateStore = (DefaultOperatorStateBackend) context.getOperatorStateStore();
 				this.counterPartitions =
-						context.getOperatorStateStore().getBroadcastSerializableListState("counter_partitions");
+						operatorStateStore.getBroadcastSerializableListState("counter_partitions");
 			} else {
 				this.counterPartitions =
 						context.getOperatorStateStore().getSerializableListState("counter_partitions");


[2/2] flink git commit: [FLINK-5265] Introduce state handle replication mode for CheckpointCoordinator

Posted by al...@apache.org.
[FLINK-5265] Introduce state handle replication mode for CheckpointCoordinator


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/1020ba2c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/1020ba2c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/1020ba2c

Branch: refs/heads/master
Commit: 1020ba2c9cfc1d01703e97c72e20a922bae0732d
Parents: 8b1b4a1
Author: Stefan Richter <s....@data-artisans.com>
Authored: Sat Dec 3 02:42:25 2016 +0100
Committer: Aljoscha Krettek <al...@gmail.com>
Committed: Fri Jan 13 21:29:19 2017 +0100

----------------------------------------------------------------------
 .../api/common/state/OperatorStateStore.java    |  37 +++-
 .../RoundRobinOperatorStateRepartitioner.java   | 133 ++++++++++++---
 .../checkpoint/StateAssignmentOperation.java    |  15 +-
 .../savepoint/SavepointV1Serializer.java        |  24 ++-
 .../state/DefaultOperatorStateBackend.java      | 160 +++++++++++------
 .../OperatorBackendSerializationProxy.java      |  51 +++++-
 .../OperatorStateCheckpointOutputStream.java    |  10 +-
 .../runtime/state/OperatorStateHandle.java      |  86 +++++++++-
 .../state/StateInitializationContextImpl.java   |  18 +-
 .../checkpoint/CheckpointCoordinatorTest.java   | 119 +++++++++----
 .../checkpoint/savepoint/SavepointV1Test.java   |   7 +-
 .../runtime/state/OperatorStateBackendTest.java |  54 +++++-
 .../runtime/state/OperatorStateHandleTest.java  |  39 +++++
 ...OperatorStateOutputCheckpointStreamTest.java |  11 +-
 .../runtime/state/SerializationProxiesTest.java |  63 ++++++-
 .../StateInitializationContextImplTest.java     |   6 +-
 .../tasks/InterruptSensitiveRestoreTest.java    |   6 +-
 .../test/checkpointing/RescalingITCase.java     | 171 +++++++++++--------
 18 files changed, 787 insertions(+), 223 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
index 43dbe51..87a7759 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/OperatorStateStore.java
@@ -30,8 +30,22 @@ import java.util.Set;
 public interface OperatorStateStore {
 
 	/**
-	 * Creates a state descriptor of the given name that uses Java serialization to persist the
-	 * state.
+	 * Creates (or restores) a list state. Each state is registered under a unique name.
+	 * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
+	 *
+	 * The items in the list are repartitionable by the system in case of changed operator parallelism.
+	 *
+	 * @param stateDescriptor The descriptor for this state, providing a name and serializer.
+	 * @param <S> The generic type of the state
+	 *
+	 * @return A list for all state partitions.
+	 * @throws Exception
+	 */
+	<S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+	/**
+	 * Creates a state of the given name that uses Java serialization to persist the state. The items in the list
+	 * are repartitionable by the system in case of changed operator parallelism.
 	 * 
 	 * <p>This is a simple convenience method. For more flexibility on how state serialization
 	 * should happen, use the {@link #getOperatorState(ListStateDescriptor)} method.
@@ -46,13 +60,28 @@ public interface OperatorStateStore {
 	 * Creates (or restores) a list state. Each state is registered under a unique name.
 	 * The provided serializer is used to de/serialize the state in case of checkpointing (snapshot/restore).
 	 *
+	 * On restore, all items in the list are broadcasted to all parallel operator instances.
+	 *
 	 * @param stateDescriptor The descriptor for this state, providing a name and serializer.
 	 * @param <S> The generic type of the state
-	 * 
+	 *
 	 * @return A list for all state partitions.
 	 * @throws Exception
 	 */
-	<S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+	<S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception;
+
+	/**
+	 * Creates a state of the given name that uses Java serialization to persist the state. On restore, all items
+	 * in the list are broadcasted to all parallel operator instances.
+	 *
+	 * <p>This is a simple convenience method. For more flexibility on how state serialization
+	 * should happen, use the {@link #getBroadcastOperatorState(ListStateDescriptor)} method.
+	 *
+	 * @param stateName The name of state to create
+	 * @return A list state using Java serialization to serialize state objects.
+	 * @throws Exception
+	 */
+	<T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception;
 
 	/**
 	 * Returns a set with the names of all currently registered states.

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
index 16a7e27..046096f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/RoundRobinOperatorStateRepartitioner.java
@@ -26,6 +26,7 @@ import org.apache.flink.util.Preconditions;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.EnumMap;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -47,8 +48,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 		Preconditions.checkArgument(parallelism > 0);
 
 		// Reorganize: group by (State Name -> StreamStateHandle + Offsets)
-		Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState =
-				groupByStateName(previousParallelSubtaskStates);
+		GroupByStateNameResults nameToStateByMode = groupByStateName(previousParallelSubtaskStates);
 
 		if (OPTIMIZE_MEMORY_USE) {
 			previousParallelSubtaskStates.clear(); // free for GC at to cost that old handles are no longer available
@@ -59,7 +59,7 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 
 		// Do the actual repartitioning for all named states
 		List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList =
-				repartition(nameToState, parallelism);
+				repartition(nameToStateByMode, parallelism);
 
 		for (int i = 0; i < mergeMapList.size(); ++i) {
 			result.add(i, new ArrayList<>(mergeMapList.get(i).values()));
@@ -71,16 +71,33 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 	/**
 	 * Group by the different named states.
 	 */
-	private Map<String, List<Tuple2<StreamStateHandle, long[]>>> groupByStateName(
+	@SuppressWarnings("unchecked, rawtype")
+	private GroupByStateNameResults groupByStateName(
 			List<OperatorStateHandle> previousParallelSubtaskStates) {
 
-		//Reorganize: group by (State Name -> StreamStateHandle + Offsets)
-		Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState = new HashMap<>();
+		//Reorganize: group by (State Name -> StreamStateHandle + StateMetaInfo)
+		EnumMap<OperatorStateHandle.Mode,
+				Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> nameToStateByMode =
+				new EnumMap<>(OperatorStateHandle.Mode.class);
+
+		for (OperatorStateHandle.Mode mode : OperatorStateHandle.Mode.values()) {
+			Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> map = new HashMap<>();
+			nameToStateByMode.put(
+					mode,
+					new HashMap<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>());
+		}
+
 		for (OperatorStateHandle psh : previousParallelSubtaskStates) {
 
-			for (Map.Entry<String, long[]> e : psh.getStateNameToPartitionOffsets().entrySet()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e :
+					psh.getStateNameToPartitionOffsets().entrySet()) {
+				OperatorStateHandle.StateMetaInfo metaInfo = e.getValue();
+
+				Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> nameToState =
+						nameToStateByMode.get(metaInfo.getDistributionMode());
 
-				List<Tuple2<StreamStateHandle, long[]>> stateLocations = nameToState.get(e.getKey());
+				List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> stateLocations =
+						nameToState.get(e.getKey());
 
 				if (stateLocations == null) {
 					stateLocations = new ArrayList<>();
@@ -90,32 +107,40 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 				stateLocations.add(new Tuple2<>(psh.getDelegateStateHandle(), e.getValue()));
 			}
 		}
-		return nameToState;
+
+		return new GroupByStateNameResults(nameToStateByMode);
 	}
 
 	/**
 	 * Repartition all named states.
 	 */
 	private List<Map<StreamStateHandle, OperatorStateHandle>> repartition(
-			Map<String, List<Tuple2<StreamStateHandle, long[]>>> nameToState, int parallelism) {
+			GroupByStateNameResults nameToStateByMode,
+			int parallelism) {
 
 		// We will use this to merge w.r.t. StreamStateHandles for each parallel subtask inside the maps
 		List<Map<StreamStateHandle, OperatorStateHandle>> mergeMapList = new ArrayList<>(parallelism);
+
 		// Initialize
 		for (int i = 0; i < parallelism; ++i) {
 			mergeMapList.add(new HashMap<StreamStateHandle, OperatorStateHandle>());
 		}
 
-		int startParallelOP = 0;
+		// Start with the state handles we distribute round robin by splitting by offsets
+		Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> distributeNameToState =
+				nameToStateByMode.getByMode(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+
+		int startParallelOp = 0;
 		// Iterate all named states and repartition one named state at a time per iteration
-		for (Map.Entry<String, List<Tuple2<StreamStateHandle, long[]>>> e : nameToState.entrySet()) {
+		for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
+				distributeNameToState.entrySet()) {
 
-			List<Tuple2<StreamStateHandle, long[]>> current = e.getValue();
+			List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
 
 			// Determine actual number of partitions for this named state
 			int totalPartitions = 0;
-			for (Tuple2<StreamStateHandle, long[]> offsets : current) {
-				totalPartitions += offsets.f1.length;
+			for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> offsets : current) {
+				totalPartitions += offsets.f1.getOffsets().length;
 			}
 
 			// Repartition the state across the parallel operator instances
@@ -124,12 +149,12 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 			int baseFraction = totalPartitions / parallelism;
 			int remainder = totalPartitions % parallelism;
 
-			int newStartParallelOp = startParallelOP;
+			int newStartParallelOp = startParallelOp;
 
 			for (int i = 0; i < parallelism; ++i) {
 
 				// Preparation: calculate the actual index considering wrap around
-				int parallelOpIdx = (i + startParallelOP) % parallelism;
+				int parallelOpIdx = (i + startParallelOp) % parallelism;
 
 				// Now calculate the number of partitions we will assign to the parallel instance in this round ...
 				int numberOfPartitionsToAssign = baseFraction;
@@ -146,11 +171,14 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 				}
 
 				// Now start collection the partitions for the parallel instance into this list
-				List<Tuple2<StreamStateHandle, long[]>> parallelOperatorState = new ArrayList<>();
+				List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> parallelOperatorState =
+						new ArrayList<>();
 
 				while (numberOfPartitionsToAssign > 0) {
-					Tuple2<StreamStateHandle, long[]> handleWithOffsets = current.get(lstIdx);
-					long[] offsets = handleWithOffsets.f1;
+					Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithOffsets =
+							current.get(lstIdx);
+
+					long[] offsets = handleWithOffsets.f1.getOffsets();
 					int remaining = offsets.length - offsetIdx;
 					// Repartition offsets
 					long[] offs;
@@ -166,25 +194,74 @@ public class RoundRobinOperatorStateRepartitioner implements OperatorStateRepart
 						++lstIdx;
 					}
 
-					parallelOperatorState.add(
-							new Tuple2<>(handleWithOffsets.f0, offs));
+					parallelOperatorState.add(new Tuple2<>(
+							handleWithOffsets.f0,
+							new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE)));
 
 					numberOfPartitionsToAssign -= remaining;
 
 					// As a last step we merge partitions that use the same StreamStateHandle in a single
 					// OperatorStateHandle
 					Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(parallelOpIdx);
-					OperatorStateHandle psh = mergeMap.get(handleWithOffsets.f0);
-					if (psh == null) {
-						psh = new OperatorStateHandle(new HashMap<String, long[]>(), handleWithOffsets.f0);
-						mergeMap.put(handleWithOffsets.f0, psh);
+					OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithOffsets.f0);
+					if (operatorStateHandle == null) {
+						operatorStateHandle = new OperatorStateHandle(
+								new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
+								handleWithOffsets.f0);
+
+						mergeMap.put(handleWithOffsets.f0, operatorStateHandle);
 					}
-					psh.getStateNameToPartitionOffsets().put(e.getKey(), offs);
+					operatorStateHandle.getStateNameToPartitionOffsets().put(
+							e.getKey(),
+							new OperatorStateHandle.StateMetaInfo(offs, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 				}
 			}
-			startParallelOP = newStartParallelOp;
+			startParallelOp = newStartParallelOp;
 			e.setValue(null);
 		}
+
+		// Now we also add the state handles marked for broadcast to all parallel instances
+		Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> broadcastNameToState =
+				nameToStateByMode.getByMode(OperatorStateHandle.Mode.BROADCAST);
+
+		for (int i = 0; i < parallelism; ++i) {
+
+			Map<StreamStateHandle, OperatorStateHandle> mergeMap = mergeMapList.get(i);
+
+			for (Map.Entry<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> e :
+					broadcastNameToState.entrySet()) {
+
+				List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>> current = e.getValue();
+
+				for (Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo> handleWithMetaInfo : current) {
+					OperatorStateHandle operatorStateHandle = mergeMap.get(handleWithMetaInfo.f0);
+					if (operatorStateHandle == null) {
+						operatorStateHandle = new OperatorStateHandle(
+								new HashMap<String, OperatorStateHandle.StateMetaInfo>(),
+								handleWithMetaInfo.f0);
+
+						mergeMap.put(handleWithMetaInfo.f0, operatorStateHandle);
+					}
+					operatorStateHandle.getStateNameToPartitionOffsets().put(e.getKey(), handleWithMetaInfo.f1);
+				}
+			}
+		}
 		return mergeMapList;
 	}
+
+	private static final class GroupByStateNameResults {
+		private final EnumMap<OperatorStateHandle.Mode,
+				Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode;
+
+		public GroupByStateNameResults(
+				EnumMap<OperatorStateHandle.Mode,
+						Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>>> byMode) {
+			this.byMode = Preconditions.checkNotNull(byMode);
+		}
+
+		public Map<String, List<Tuple2<StreamStateHandle, OperatorStateHandle.StateMetaInfo>>> getByMode(
+				OperatorStateHandle.Mode mode) {
+			return byMode.get(mode);
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
index 2e05a85..f11f69b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java
@@ -338,9 +338,22 @@ public class StateAssignmentOperation {
 					chainOpParallelStates,
 					newParallelism);
 		} else {
-
 			List<Collection<OperatorStateHandle>> repackStream = new ArrayList<>(newParallelism);
 			for (OperatorStateHandle operatorStateHandle : chainOpParallelStates) {
+
+				Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsets =
+						operatorStateHandle.getStateNameToPartitionOffsets();
+
+				for (OperatorStateHandle.StateMetaInfo metaInfo : partitionOffsets.values()) {
+
+					// if we find any broadcast state, we cannot take the shortcut and need to go through repartitioning
+					if (OperatorStateHandle.Mode.BROADCAST.equals(metaInfo.getDistributionMode())) {
+						return opStateRepartitioner.repartitionState(
+								chainOpParallelStates,
+								newParallelism);
+					}
+				}
+
 				repackStream.add(Collections.singletonList(operatorStateHandle));
 			}
 			return repackStream;

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
index 48324ca..ba1949a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Serializer.java
@@ -250,11 +250,18 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 
 		if (stateHandle != null) {
 			dos.writeByte(PARTITIONABLE_OPERATOR_STATE_HANDLE);
-			Map<String, long[]> partitionOffsetsMap = stateHandle.getStateNameToPartitionOffsets();
+			Map<String, OperatorStateHandle.StateMetaInfo> partitionOffsetsMap =
+					stateHandle.getStateNameToPartitionOffsets();
 			dos.writeInt(partitionOffsetsMap.size());
-			for (Map.Entry<String, long[]> entry : partitionOffsetsMap.entrySet()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : partitionOffsetsMap.entrySet()) {
 				dos.writeUTF(entry.getKey());
-				long[] offsets = entry.getValue();
+
+				OperatorStateHandle.StateMetaInfo stateMetaInfo = entry.getValue();
+
+				int mode = stateMetaInfo.getDistributionMode().ordinal();
+				dos.writeByte(mode);
+
+				long[] offsets = stateMetaInfo.getOffsets();
 				dos.writeInt(offsets.length);
 				for (long offset : offsets) {
 					dos.writeLong(offset);
@@ -274,14 +281,21 @@ class SavepointV1Serializer implements SavepointSerializer<SavepointV1> {
 			return null;
 		} else if (PARTITIONABLE_OPERATOR_STATE_HANDLE == type) {
 			int mapSize = dis.readInt();
-			Map<String, long[]> offsetsMap = new HashMap<>(mapSize);
+			Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(mapSize);
 			for (int i = 0; i < mapSize; ++i) {
 				String key = dis.readUTF();
+
+				int modeOrdinal = dis.readByte();
+				OperatorStateHandle.Mode mode = OperatorStateHandle.Mode.values()[modeOrdinal];
+
 				long[] offsets = new long[dis.readInt()];
 				for (int j = 0; j < offsets.length; ++j) {
 					offsets[j] = dis.readLong();
 				}
-				offsetsMap.put(key, offsets);
+
+				OperatorStateHandle.StateMetaInfo metaInfo =
+						new OperatorStateHandle.StateMetaInfo(offsets, mode);
+				offsetsMap.put(key, metaInfo);
 			}
 			StreamStateHandle stateHandle = deserializeStreamStateHandle(dis);
 			return new OperatorStateHandle(offsetsMap, stateHandle);

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
index 10bb409..6c65088 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.state;
 
 import org.apache.commons.io.IOUtils;
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
@@ -44,6 +45,7 @@ import java.util.concurrent.RunnableFuture;
 /**
  * Default implementation of OperatorStateStore that provides the ability to make snapshots.
  */
+@Internal
 public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 	/** The default namespace for state in cases where no state name is provided */
@@ -62,14 +64,46 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		this.registeredStates = new HashMap<>();
 	}
 
+	@Override
+	public Set<String> getRegisteredStateNames() {
+		return registeredStates.keySet();
+	}
+
+	@Override
+	public void close() throws IOException {
+		closeStreamOnCancelRegistry.close();
+	}
+
+	@Override
+	public void dispose() {
+		registeredStates.clear();
+	}
+
 	@SuppressWarnings("unchecked")
 	@Override
 	public <T extends Serializable> ListState<T> getSerializableListState(String stateName) throws Exception {
 		return (ListState<T>) getOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
 	}
-	
+
 	@Override
 	public <S> ListState<S> getOperatorState(ListStateDescriptor<S> stateDescriptor) throws IOException {
+		return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	public <T extends Serializable> ListState<T> getBroadcastSerializableListState(String stateName) throws Exception {
+		return (ListState<T>) getBroadcastOperatorState(new ListStateDescriptor<>(stateName, javaSerializer));
+	}
+
+	@Override
+	public <S> ListState<S> getBroadcastOperatorState(ListStateDescriptor<S> stateDescriptor) throws Exception {
+		return getOperatorState(stateDescriptor, OperatorStateHandle.Mode.BROADCAST);
+	}
+
+	private <S> ListState<S> getOperatorState(
+			ListStateDescriptor<S> stateDescriptor,
+			OperatorStateHandle.Mode mode) throws IOException {
 
 		Preconditions.checkNotNull(stateDescriptor);
 
@@ -81,10 +115,18 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 
 		if (null == partitionableListState) {
 
-			partitionableListState = new PartitionableListState<>(name, partitionStateSerializer);
+			partitionableListState = new PartitionableListState<>(
+					name,
+					partitionStateSerializer,
+					mode);
+
 			registeredStates.put(name, partitionableListState);
 		} else {
 			Preconditions.checkState(
+					partitionableListState.getAssignmentMode().equals(mode),
+					"Incompatible assignment mode. Provided: " + mode + ", expected: " +
+							partitionableListState.getAssignmentMode());
+			Preconditions.checkState(
 					partitionableListState.getPartitionStateSerializer().
 							isCompatibleWith(stateDescriptor.getSerializer()),
 					"Incompatible type serializers. Provided: " + stateDescriptor.getSerializer() +
@@ -97,16 +139,21 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 	private static <S> void deserializeStateValues(
 			PartitionableListState<S> stateListForName,
 			FSDataInputStream in,
-			long[] offsets) throws IOException {
-
-		DataInputView div = new DataInputViewStreamWrapper(in);
-		TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer();
-		for (long offset : offsets) {
-			in.seek(offset);
-			stateListForName.add(serializer.deserialize(div));
+			OperatorStateHandle.StateMetaInfo metaInfo) throws IOException {
+
+		if (null != metaInfo) {
+			long[] offsets = metaInfo.getOffsets();
+			if (null != offsets) {
+				DataInputView div = new DataInputViewStreamWrapper(in);
+				TypeSerializer<S> serializer = stateListForName.getPartitionStateSerializer();
+				for (long offset : offsets) {
+					in.seek(offset);
+					stateListForName.add(serializer.deserialize(div));
+				}
+			}
 		}
 	}
-	
+
 	@Override
 	public RunnableFuture<OperatorStateHandle> snapshot(
 			long checkpointId, long timestamp, CheckpointStreamFactory streamFactory) throws Exception {
@@ -123,11 +170,12 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 			OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo =
 					new OperatorBackendSerializationProxy.StateMetaInfo<>(
 							state.getName(),
-							state.getPartitionStateSerializer());
+							state.getPartitionStateSerializer(),
+							state.getAssignmentMode());
 			metaInfoList.add(metaInfo);
 		}
 
-		Map<String, long[]> writtenStatesMetaData = new HashMap<>(registeredStates.size());
+		Map<String, OperatorStateHandle.StateMetaInfo> writtenStatesMetaData = new HashMap<>(registeredStates.size());
 
 		CheckpointStreamFactory.CheckpointStateOutputStream out = streamFactory.
 				createCheckpointStateOutputStream(checkpointId, timestamp);
@@ -145,8 +193,10 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 			dov.writeInt(registeredStates.size());
 			for (Map.Entry<String, PartitionableListState<?>> entry : registeredStates.entrySet()) {
 
-				long[] partitionOffsets = entry.getValue().write(out);
-				writtenStatesMetaData.put(entry.getKey(), partitionOffsets);
+				PartitionableListState<?> value = entry.getValue();
+				long[] partitionOffsets = value.write(out);
+				OperatorStateHandle.Mode mode = value.getAssignmentMode();
+				writtenStatesMetaData.put(entry.getKey(), new OperatorStateHandle.StateMetaInfo(partitionOffsets, mode));
 			}
 
 			OperatorStateHandle handle = new OperatorStateHandle(writtenStatesMetaData, out.closeAndGetHandle());
@@ -193,7 +243,8 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 					if (null == listState) {
 						listState = new PartitionableListState<>(
 								stateMetaInfo.getName(),
-								stateMetaInfo.getStateSerializer());
+								stateMetaInfo.getStateSerializer(),
+								stateMetaInfo.getMode());
 
 						registeredStates.put(listState.getName(), listState);
 					} else {
@@ -205,7 +256,9 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 				}
 
 				// Restore all the state in PartitionableListStates
-				for (Map.Entry<String, long[]> nameToOffsets : stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> nameToOffsets :
+						stateHandle.getStateNameToPartitionOffsets().entrySet()) {
+
 					PartitionableListState<?> stateListForName = registeredStates.get(nameToOffsets.getKey());
 
 					Preconditions.checkState(null != stateListForName, "Found state without " +
@@ -222,60 +275,40 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		}
 	}
 
-	@Override
-	public void dispose() {
-		registeredStates.clear();
-	}
-
-	@Override
-	public Set<String> getRegisteredStateNames() {
-		return registeredStates.keySet();
-	}
-
-	@Override
-	public void close() throws IOException {
-		closeStreamOnCancelRegistry.close();
-	}
-
 	static final class PartitionableListState<S> implements ListState<S> {
 
-		private final List<S> internalList;
 		private final String name;
 		private final TypeSerializer<S> partitionStateSerializer;
+		private final OperatorStateHandle.Mode assignmentMode;
+		private final List<S> internalList;
 
-		public PartitionableListState(String name, TypeSerializer<S> partitionStateSerializer) {
-			this.internalList = new ArrayList<>();
-			this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
-			this.name = Preconditions.checkNotNull(name);
-		}
-
-		public long[] write(FSDataOutputStream out) throws IOException {
-
-			long[] partitionOffsets = new long[internalList.size()];
-
-			DataOutputView dov = new DataOutputViewStreamWrapper(out);
-
-			for (int i = 0; i < internalList.size(); ++i) {
-				S element = internalList.get(i);
-				partitionOffsets[i] = out.getPos();
-				partitionStateSerializer.serialize(element, dov);
-			}
-
-			return partitionOffsets;
-		}
+		public PartitionableListState(
+				String name,
+				TypeSerializer<S> partitionStateSerializer,
+				OperatorStateHandle.Mode assignmentMode) {
 
-		public List<S> getInternalList() {
-			return internalList;
+			this.name = Preconditions.checkNotNull(name);
+			this.partitionStateSerializer = Preconditions.checkNotNull(partitionStateSerializer);
+			this.assignmentMode = Preconditions.checkNotNull(assignmentMode);
+			this.internalList = new ArrayList<>();
 		}
 
 		public String getName() {
 			return name;
 		}
 
+		public OperatorStateHandle.Mode getAssignmentMode() {
+			return assignmentMode;
+		}
+
 		public TypeSerializer<S> getPartitionStateSerializer() {
 			return partitionStateSerializer;
 		}
 
+		public List<S> getInternalList() {
+			return internalList;
+		}
+
 		@Override
 		public void clear() {
 			internalList.clear();
@@ -294,8 +327,25 @@ public class DefaultOperatorStateBackend implements OperatorStateBackend {
 		@Override
 		public String toString() {
 			return "PartitionableListState{" +
-					"listState=" + internalList +
+					"name='" + name + '\'' +
+					", assignmentMode=" + assignmentMode +
+					", internalList=" + internalList +
 					'}';
 		}
+
+		public long[] write(FSDataOutputStream out) throws IOException {
+
+			long[] partitionOffsets = new long[internalList.size()];
+
+			DataOutputView dov = new DataOutputViewStreamWrapper(out);
+
+			for (int i = 0; i < internalList.size(); ++i) {
+				S element = internalList.get(i);
+				partitionOffsets[i] = out.getPos();
+				partitionStateSerializer.serialize(element, dov);
+			}
+
+			return partitionOffsets;
+		}
 	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
index 61df979..d571dcc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorBackendSerializationProxy.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.typeutils.runtime.DataInputViewStream;
 import org.apache.flink.api.java.typeutils.runtime.DataOutputViewStream;
@@ -91,15 +92,19 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
 
 		private String name;
 		private TypeSerializer<S> stateSerializer;
+		private OperatorStateHandle.Mode mode;
+
 		private ClassLoader userClassLoader;
 
-		private StateMetaInfo(ClassLoader userClassLoader) {
+		@VisibleForTesting
+		public StateMetaInfo(ClassLoader userClassLoader) {
 			this.userClassLoader = Preconditions.checkNotNull(userClassLoader);
 		}
 
-		public StateMetaInfo(String name, TypeSerializer<S> stateSerializer) {
+		public StateMetaInfo(String name, TypeSerializer<S> stateSerializer, OperatorStateHandle.Mode mode) {
 			this.name = Preconditions.checkNotNull(name);
 			this.stateSerializer = Preconditions.checkNotNull(stateSerializer);
+			this.mode = Preconditions.checkNotNull(mode);
 		}
 
 		public String getName() {
@@ -118,9 +123,18 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
 			this.stateSerializer = stateSerializer;
 		}
 
+		public OperatorStateHandle.Mode getMode() {
+			return mode;
+		}
+
+		public void setMode(OperatorStateHandle.Mode mode) {
+			this.mode = mode;
+		}
+
 		@Override
 		public void write(DataOutputView out) throws IOException {
 			out.writeUTF(getName());
+			out.writeByte(getMode().ordinal());
 			DataOutputViewStream dos = new DataOutputViewStream(out);
 			InstantiationUtil.serializeObject(dos, getStateSerializer());
 		}
@@ -128,6 +142,7 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
 		@Override
 		public void read(DataInputView in) throws IOException {
 			setName(in.readUTF());
+			setMode(OperatorStateHandle.Mode.values()[in.readByte()]);
 			DataInputViewStream dis = new DataInputViewStream(in);
 			try {
 				TypeSerializer<S> stateSerializer = InstantiationUtil.deserializeObject(dis, userClassLoader);
@@ -136,5 +151,37 @@ public class OperatorBackendSerializationProxy extends VersionedIOReadableWritab
 				throw new IOException(exception);
 			}
 		}
+
+		@Override
+		public boolean equals(Object o) {
+
+			if (this == o) {
+				return true;
+			}
+
+			if (o == null || getClass() != o.getClass()) {
+				return false;
+			}
+
+			StateMetaInfo<?> metaInfo = (StateMetaInfo<?>) o;
+
+			if (!getName().equals(metaInfo.getName())) {
+				return false;
+			}
+
+			if (!getStateSerializer().equals(metaInfo.getStateSerializer())) {
+				return false;
+			}
+
+			return getMode() == metaInfo.getMode();
+		}
+
+		@Override
+		public int hashCode() {
+			int result = getName().hashCode();
+			result = 31 * result + getStateSerializer().hashCode();
+			result = 31 * result + getMode().hashCode();
+			return result;
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
index eaa9fd9..036aed0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateCheckpointOutputStream.java
@@ -66,8 +66,14 @@ public final class OperatorStateCheckpointOutputStream
 			startNewPartition();
 		}
 
-		Map<String, long[]> offsetsMap = new HashMap<>(1);
-		offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, partitionOffsets.toArray());
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(1);
+
+		OperatorStateHandle.StateMetaInfo metaInfo =
+				new OperatorStateHandle.StateMetaInfo(
+						partitionOffsets.toArray(),
+						OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+
+		offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo);
 
 		return new OperatorStateHandle(offsetsMap, streamStateHandle);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
index 3cd37c9..c59fbad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/OperatorStateHandle.java
@@ -22,6 +22,7 @@ import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.io.Serializable;
 import java.util.Arrays;
 import java.util.Map;
 
@@ -31,21 +32,27 @@ import java.util.Map;
  */
 public class OperatorStateHandle implements StreamStateHandle {
 
+	public enum Mode {
+		SPLIT_DISTRIBUTE, BROADCAST
+	}
+
 	private static final long serialVersionUID = 35876522969227335L;
 
-	/** unique state name -> offsets for available partitions in the handle stream */
-	private final Map<String, long[]> stateNameToPartitionOffsets;
+	/**
+	 * unique state name -> offsets for available partitions in the handle stream
+	 */
+	private final Map<String, StateMetaInfo> stateNameToPartitionOffsets;
 	private final StreamStateHandle delegateStateHandle;
 
 	public OperatorStateHandle(
-			Map<String, long[]> stateNameToPartitionOffsets,
+			Map<String, StateMetaInfo> stateNameToPartitionOffsets,
 			StreamStateHandle delegateStateHandle) {
 
 		this.delegateStateHandle = Preconditions.checkNotNull(delegateStateHandle);
 		this.stateNameToPartitionOffsets = Preconditions.checkNotNull(stateNameToPartitionOffsets);
 	}
 
-	public Map<String, long[]> getStateNameToPartitionOffsets() {
+	public Map<String, StateMetaInfo> getStateNameToPartitionOffsets() {
 		return stateNameToPartitionOffsets;
 	}
 
@@ -80,12 +87,12 @@ public class OperatorStateHandle implements StreamStateHandle {
 
 		OperatorStateHandle that = (OperatorStateHandle) o;
 
-		if(stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
+		if (stateNameToPartitionOffsets.size() != that.stateNameToPartitionOffsets.size()) {
 			return false;
 		}
 
-		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
-			if (!Arrays.equals(entry.getValue(), that.stateNameToPartitionOffsets.get(entry.getKey()))) {
+		for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
+			if (!entry.getValue().equals(that.stateNameToPartitionOffsets.get(entry.getKey()))) {
 				return false;
 			}
 		}
@@ -96,14 +103,75 @@ public class OperatorStateHandle implements StreamStateHandle {
 	@Override
 	public int hashCode() {
 		int result = delegateStateHandle.hashCode();
-		for (Map.Entry<String, long[]> entry : stateNameToPartitionOffsets.entrySet()) {
+		for (Map.Entry<String, StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
 
 			int entryHash = entry.getKey().hashCode();
 			if (entry.getValue() != null) {
-				entryHash += Arrays.hashCode(entry.getValue());
+				entryHash += entry.getValue().hashCode();
 			}
 			result = 31 * result + entryHash;
 		}
 		return result;
 	}
+
+	@Override
+	public String toString() {
+		return "OperatorStateHandle{" +
+				"stateNameToPartitionOffsets=" + stateNameToPartitionOffsets +
+				", delegateStateHandle=" + delegateStateHandle +
+				'}';
+	}
+
+	public static class StateMetaInfo implements Serializable {
+
+		private static final long serialVersionUID = 3593817615858941166L;
+
+		private final long[] offsets;
+		private final Mode distributionMode;
+
+		public StateMetaInfo(long[] offsets, Mode distributionMode) {
+			this.offsets = Preconditions.checkNotNull(offsets);
+			this.distributionMode = Preconditions.checkNotNull(distributionMode);
+		}
+
+		public long[] getOffsets() {
+			return offsets;
+		}
+
+		public Mode getDistributionMode() {
+			return distributionMode;
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (this == o) {
+				return true;
+			}
+			if (o == null || getClass() != o.getClass()) {
+				return false;
+			}
+
+			StateMetaInfo that = (StateMetaInfo) o;
+
+			if (!Arrays.equals(getOffsets(), that.getOffsets())) {
+				return false;
+			}
+			return getDistributionMode() == that.getDistributionMode();
+		}
+
+		@Override
+		public int hashCode() {
+			int result = Arrays.hashCode(getOffsets());
+			result = 31 * result + getDistributionMode().hashCode();
+			return result;
+		}
+
+		@Override
+		public String toString() {
+			return "StateMetaInfo{" +
+					"offsets=" + Arrays.toString(offsets) +
+					", distributionMode=" + distributionMode +
+					'}';
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
index 46445d2..886d214 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateInitializationContextImpl.java
@@ -220,13 +220,21 @@ public class StateInitializationContextImpl implements StateInitializationContex
 
 			while (stateHandleIterator.hasNext()) {
 				currentStateHandle = stateHandleIterator.next();
-				long[] offsets = currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
-				if (null != offsets && offsets.length > 0) {
+				OperatorStateHandle.StateMetaInfo metaInfo =
+						currentStateHandle.getStateNameToPartitionOffsets().get(stateName);
 
-					this.offsets = offsets;
-					this.offPos = 0;
+				if (null != metaInfo) {
+					long[] metaOffsets = metaInfo.getOffsets();
+					if (null != metaOffsets && metaOffsets.length > 0) {
+						this.offsets = metaOffsets;
+						this.offPos = 0;
 
-					return true;
+						closableRegistry.unregisterClosable(currentStream);
+						IOUtils.closeQuietly(currentStream);
+						currentStream = null;
+
+						return true;
+					}
 				}
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index daacbfb..ca9dbc2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -941,7 +941,7 @@ public class CheckpointCoordinatorTest {
 	}
 
 	@Test
-	public void handleMessagesForNonExistingCheckpoints() {
+	public void testHandleMessagesForNonExistingCheckpoints() {
 		try {
 			final JobID jid = new JobID();
 			final long timestamp = System.currentTimeMillis();
@@ -1937,8 +1937,8 @@ public class CheckpointCoordinatorTest {
 		coord.restoreLatestCheckpointedState(tasks, true, false);
 
 		// verify the restored state
-		verifiyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
-		verifiyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
+		verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
+		verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
 	}
 
 	/**
@@ -2318,7 +2318,7 @@ public class CheckpointCoordinatorTest {
 		coord.restoreLatestCheckpointedState(tasks, true, false);
 
 		// verify the restored state
-		verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
+		verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
@@ -2390,6 +2390,49 @@ public class CheckpointCoordinatorTest {
 		}
 	}
 
+	@Test
+	public void testReplicateModeStateHandle() {
+		Map<String, OperatorStateHandle.StateMetaInfo> metaInfoMap = new HashMap<>(1);
+		metaInfoMap.put("t-1", new OperatorStateHandle.StateMetaInfo(new long[]{0, 23}, OperatorStateHandle.Mode.BROADCAST));
+		metaInfoMap.put("t-2", new OperatorStateHandle.StateMetaInfo(new long[]{42, 64}, OperatorStateHandle.Mode.BROADCAST));
+		metaInfoMap.put("t-3", new OperatorStateHandle.StateMetaInfo(new long[]{72, 83}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+		OperatorStateHandle osh = new OperatorStateHandle(metaInfoMap, new ByteStreamStateHandle("test", new byte[100]));
+
+		OperatorStateRepartitioner repartitioner = RoundRobinOperatorStateRepartitioner.INSTANCE;
+		List<Collection<OperatorStateHandle>> repartitionedStates =
+				repartitioner.repartitionState(Collections.singletonList(osh), 3);
+
+		Map<String, Integer> checkCounts = new HashMap<>(3);
+
+		for (Collection<OperatorStateHandle> operatorStateHandles : repartitionedStates) {
+			for (OperatorStateHandle operatorStateHandle : operatorStateHandles) {
+				for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> stateNameToMetaInfo :
+						operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+
+					String stateName = stateNameToMetaInfo.getKey();
+					Integer count = checkCounts.get(stateName);
+					if (null == count) {
+						checkCounts.put(stateName, 1);
+					} else {
+						checkCounts.put(stateName, 1 + count);
+					}
+
+					OperatorStateHandle.StateMetaInfo stateMetaInfo = stateNameToMetaInfo.getValue();
+					if (OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.equals(stateMetaInfo.getDistributionMode())) {
+						Assert.assertEquals(1, stateNameToMetaInfo.getValue().getOffsets().length);
+					} else {
+						Assert.assertEquals(2, stateNameToMetaInfo.getValue().getOffsets().length);
+					}
+				}
+			}
+		}
+
+		Assert.assertEquals(3, checkCounts.size());
+		Assert.assertEquals(3, checkCounts.get("t-1").intValue());
+		Assert.assertEquals(3, checkCounts.get("t-2").intValue());
+		Assert.assertEquals(2, checkCounts.get("t-3").intValue());
+	}
+
 	// ------------------------------------------------------------------------
 	//  Utilities
 	// ------------------------------------------------------------------------
@@ -2520,11 +2563,15 @@ public class CheckpointCoordinatorTest {
 
 		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
 
-		Map<String, long[]> offsetsMap = new HashMap<>(states.size());
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
 
 		int idx = 0;
 		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
-			offsetsMap.put(entry.getKey(), serializationWithOffsets.f1.get(idx));
+			offsetsMap.put(
+					entry.getKey(),
+					new OperatorStateHandle.StateMetaInfo(
+							serializationWithOffsets.f1.get(idx),
+							OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 			++idx;
 		}
 
@@ -2601,7 +2648,7 @@ public class CheckpointCoordinatorTest {
 		return vertex;
 	}
 
-	public static void verifiyStateRestore(
+	public static void verifyStateRestore(
 			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
 			List<KeyGroupRange> keyGroupPartitions) throws Exception {
 
@@ -2697,8 +2744,8 @@ public class CheckpointCoordinatorTest {
 
 	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
 		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-			for (Map.Entry<String, long[]> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-				for (long offset : entry.getValue()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue().getOffsets()) {
 					in.seek(offset);
 					Integer state = InstantiationUtil.
 							deserializeObject(in, Thread.currentThread().getContextClassLoader());
@@ -2801,17 +2848,22 @@ public class CheckpointCoordinatorTest {
 
 		for (int i = 0; i < oldParallelism; ++i) {
 			Path fakePath = new Path("/fake-" + i);
-			Map<String, long[]> namedStatesToOffsets = new HashMap<>();
+			Map<String, OperatorStateHandle.StateMetaInfo> namedStatesToOffsets = new HashMap<>();
 			int off = 0;
 			for (int s = 0; s < numNamedStates; ++s) {
 				long[] offs = new long[1 + r.nextInt(maxPartitionsPerState)];
-				if (offs.length > 0) {
-					for (int o = 0; o < offs.length; ++o) {
-						offs[o] = off;
-						++off;
-					}
-					namedStatesToOffsets.put("State-" + s, offs);
+
+				for (int o = 0; o < offs.length; ++o) {
+					offs[o] = off;
+					++off;
 				}
+
+				OperatorStateHandle.Mode mode = r.nextInt(10) == 0 ?
+						OperatorStateHandle.Mode.BROADCAST : OperatorStateHandle.Mode.SPLIT_DISTRIBUTE;
+				namedStatesToOffsets.put(
+						"State-" + s,
+						new OperatorStateHandle.StateMetaInfo(offs, mode));
+
 			}
 
 			previousParallelOpInstanceStates.add(
@@ -2822,14 +2874,21 @@ public class CheckpointCoordinatorTest {
 
 		int expectedTotalPartitions = 0;
 		for (OperatorStateHandle psh : previousParallelOpInstanceStates) {
-			Map<String, long[]> offsMap = psh.getStateNameToPartitionOffsets();
+			Map<String, OperatorStateHandle.StateMetaInfo> offsMap = psh.getStateNameToPartitionOffsets();
 			Map<String, List<Long>> offsMapWithList = new HashMap<>(offsMap.size());
-			for (Map.Entry<String, long[]> e : offsMap.entrySet()) {
-				long[] offs = e.getValue();
-				expectedTotalPartitions += offs.length;
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> e : offsMap.entrySet()) {
+
+				long[] offs = e.getValue().getOffsets();
+				int replication = e.getValue().getDistributionMode().equals(OperatorStateHandle.Mode.BROADCAST) ?
+						newParallelism : 1;
+
+				expectedTotalPartitions += replication * offs.length;
 				List<Long> offsList = new ArrayList<>(offs.length);
+
 				for (int i = 0; i < offs.length; ++i) {
-					offsList.add(i, offs[i]);
+					for(int p = 0; p < replication; ++p) {
+						offsList.add(offs[i]);
+					}
 				}
 				offsMapWithList.put(e.getKey(), offsList);
 			}
@@ -2851,25 +2910,25 @@ public class CheckpointCoordinatorTest {
 
 			Collection<OperatorStateHandle> pshc = pshs.get(p);
 			for (OperatorStateHandle sh : pshc) {
-				for (Map.Entry<String, long[]> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
+				for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> namedState : sh.getStateNameToPartitionOffsets().entrySet()) {
 
-					Map<String, List<Long>> x = actual.get(sh.getDelegateStateHandle());
-					if (x == null) {
-						x = new HashMap<>();
-						actual.put(sh.getDelegateStateHandle(), x);
+					Map<String, List<Long>> stateToOffsets = actual.get(sh.getDelegateStateHandle());
+					if (stateToOffsets == null) {
+						stateToOffsets = new HashMap<>();
+						actual.put(sh.getDelegateStateHandle(), stateToOffsets);
 					}
 
-					List<Long> actualOffs = x.get(namedState.getKey());
+					List<Long> actualOffs = stateToOffsets.get(namedState.getKey());
 					if (actualOffs == null) {
 						actualOffs = new ArrayList<>();
-						x.put(namedState.getKey(), actualOffs);
+						stateToOffsets.put(namedState.getKey(), actualOffs);
 					}
-					long[] add = namedState.getValue();
+					long[] add = namedState.getValue().getOffsets();
 					for (int i = 0; i < add.length; ++i) {
 						actualOffs.add(add[i]);
 					}
 
-					partitionCount += namedState.getValue().length;
+					partitionCount += namedState.getValue().getOffsets().length;
 				}
 			}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
index db5c35b..5184db8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/SavepointV1Test.java
@@ -99,9 +99,10 @@ public class SavepointV1Test {
 							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
 					StreamStateHandle operatorStateStream =
 							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes());
-					Map<String, long[]> offsetsMap = new HashMap<>();
-					offsetsMap.put("A", new long[]{0, 10, 20});
-					offsetsMap.put("B", new long[]{30, 40, 50});
+					Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
+					offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+					offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+					offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.BROADCAST));
 
 					if (chainIdx != noNonPartitionableStateAtIndex) {
 						nonPartitionableStates.add(nonPartitionableState);

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
index 515011f..cd0391f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateBackendTest.java
@@ -31,12 +31,13 @@ import java.util.Iterator;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 public class OperatorStateBackendTest {
 
-	AbstractStateBackend abstractStateBackend = new MemoryStateBackend(1024);
+	AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096);
 
 	static Environment createMockEnvironment() {
 		Environment env = mock(Environment.class);
@@ -62,6 +63,7 @@ public class OperatorStateBackendTest {
 		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
 		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
 		ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
 		assertNotNull(listState1);
 		assertEquals(1, operatorStateBackend.getRegisteredStateNames().size());
@@ -89,6 +91,20 @@ public class OperatorStateBackendTest {
 		assertEquals(23, it.next());
 		assertTrue(!it.hasNext());
 
+		ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
+		assertNotNull(listState3);
+		assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
+		assertTrue(!it.hasNext());
+		listState3.add(17);
+		listState3.add(3);
+		listState3.add(123);
+
+		it = listState3.get().iterator();
+		assertEquals(17, it.next());
+		assertEquals(3, it.next());
+		assertEquals(123, it.next());
+		assertTrue(!it.hasNext());
+
 		ListState<Serializable> listState1b = operatorStateBackend.getOperatorState(stateDescriptor1);
 		assertNotNull(listState1b);
 		listState1b.add(123);
@@ -109,6 +125,20 @@ public class OperatorStateBackendTest {
 		assertEquals(4711, it.next());
 		assertEquals(123, it.next());
 		assertTrue(!it.hasNext());
+
+		try {
+			operatorStateBackend.getBroadcastOperatorState(stateDescriptor2);
+			fail("Did not detect changed mode");
+		} catch (IllegalStateException ignored) {
+
+		}
+
+		try {
+			operatorStateBackend.getOperatorState(stateDescriptor3);
+			fail("Did not detect changed mode");
+		} catch (IllegalStateException ignored) {
+
+		}
 	}
 
 	@Test
@@ -116,8 +146,10 @@ public class OperatorStateBackendTest {
 		OperatorStateBackend operatorStateBackend = createNewOperatorStateBackend();
 		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
 		ListStateDescriptor<Serializable> stateDescriptor2 = new ListStateDescriptor<>("test2", new JavaSerializer<>());
+		ListStateDescriptor<Serializable> stateDescriptor3 = new ListStateDescriptor<>("test3", new JavaSerializer<>());
 		ListState<Serializable> listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
 		ListState<Serializable> listState2 = operatorStateBackend.getOperatorState(stateDescriptor2);
+		ListState<Serializable> listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
 
 		listState1.add(42);
 		listState1.add(4711);
@@ -126,11 +158,17 @@ public class OperatorStateBackendTest {
 		listState2.add(13);
 		listState2.add(23);
 
+		listState3.add(17);
+		listState3.add(18);
+		listState3.add(19);
+		listState3.add(20);
+
 		CheckpointStreamFactory streamFactory = abstractStateBackend.createStreamFactory(new JobID(), "testOperator");
 		OperatorStateHandle stateHandle = operatorStateBackend.snapshot(1, 1, streamFactory).get();
 
 		try {
 
+			operatorStateBackend.close();
 			operatorStateBackend.dispose();
 
 			operatorStateBackend = abstractStateBackend.createOperatorStateBackend(
@@ -139,13 +177,13 @@ public class OperatorStateBackendTest {
 
 			operatorStateBackend.restore(Collections.singletonList(stateHandle));
 
-			assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
+			assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
 
 			listState1 = operatorStateBackend.getOperatorState(stateDescriptor1);
 			listState2 = operatorStateBackend.getOperatorState(stateDescriptor2);
+			listState3 = operatorStateBackend.getBroadcastOperatorState(stateDescriptor3);
 
-			assertEquals(2, operatorStateBackend.getRegisteredStateNames().size());
-
+			assertEquals(3, operatorStateBackend.getRegisteredStateNames().size());
 
 			Iterator<Serializable> it = listState1.get().iterator();
 			assertEquals(42, it.next());
@@ -158,6 +196,14 @@ public class OperatorStateBackendTest {
 			assertEquals(23, it.next());
 			assertTrue(!it.hasNext());
 
+			it = listState3.get().iterator();
+			assertEquals(17, it.next());
+			assertEquals(18, it.next());
+			assertEquals(19, it.next());
+			assertEquals(20, it.next());
+			assertTrue(!it.hasNext());
+
+			operatorStateBackend.close();
 			operatorStateBackend.dispose();
 		} finally {
 			stateHandle.discardState();

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
new file mode 100644
index 0000000..ab801b6
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateHandleTest.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class OperatorStateHandleTest {
+
+	@Test
+	public void testFixedEnumOrder() {
+		// Pins the layout of OperatorStateHandle.Mode; changing these assertions implies a change of the serialized format.
+		// The ordinal of each Mode constant must stay fixed, because the ordinal is what gets serialized.
+		Assert.assertEquals(0, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE.ordinal());
+		Assert.assertEquals(1, OperatorStateHandle.Mode.BROADCAST.ordinal());
+
+		// Fail if a constant is added or removed without revisiting the serialization format.
+		Assert.assertEquals(2, OperatorStateHandle.Mode.values().length);
+
+		// The ordinal is encoded as a single (signed) byte, so the number of constants must fit into it.
+		Assert.assertTrue(OperatorStateHandle.Mode.values().length <= Byte.MAX_VALUE);
+	}
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
index c6ef0f0..7efcd0d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/OperatorStateOutputCheckpointStreamTest.java
@@ -27,6 +27,7 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.util.Map;
 
 public class OperatorStateOutputCheckpointStreamTest {
 
@@ -77,15 +78,23 @@ public class OperatorStateOutputCheckpointStreamTest {
 		OperatorStateHandle fullHandle = writeAllTestKeyGroups(stream, numPartitions);
 		Assert.assertNotNull(fullHandle);
 
+		Map<String, OperatorStateHandle.StateMetaInfo> stateNameToPartitionOffsets =
+				fullHandle.getStateNameToPartitionOffsets();
+		for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : stateNameToPartitionOffsets.entrySet()) {
+
+			Assert.assertEquals(OperatorStateHandle.Mode.SPLIT_DISTRIBUTE, entry.getValue().getDistributionMode());
+		}
 		verifyRead(fullHandle, numPartitions);
 	}
 
 	private static void verifyRead(OperatorStateHandle fullHandle, int numPartitions) throws IOException {
 		int count = 0;
 		try (FSDataInputStream in = fullHandle.openInputStream()) {
-			long[] offsets = fullHandle.getStateNameToPartitionOffsets().
+			OperatorStateHandle.StateMetaInfo metaInfo = fullHandle.getStateNameToPartitionOffsets().
 					get(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME);
 
+			long[] offsets = metaInfo.getOffsets();
+
 			Assert.assertNotNull(offsets);
 
 			DataInputView div = new DataInputViewStreamWrapper(in);

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
index 832b022..2448540 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SerializationProxiesTest.java
@@ -36,7 +36,7 @@ import java.util.List;
 public class SerializationProxiesTest {
 
 	@Test
-	public void testSerializationRoundtrip() throws Exception {
+	public void testKeyedBackendSerializationProxyRoundtrip() throws Exception {
 
 		TypeSerializer<?> keySerializer = IntSerializer.INSTANCE;
 		TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
@@ -67,13 +67,12 @@ public class SerializationProxiesTest {
 			serializationProxy.read(new DataInputViewStreamWrapper(in));
 		}
 
-
 		Assert.assertEquals(keySerializer, serializationProxy.getKeySerializerProxy().getTypeSerializer());
 		Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies());
 	}
 
 	@Test
-	public void testMetaInfoSerialization() throws Exception {
+	public void testKeyedStateMetaInfoSerialization() throws Exception {
 
 		String name = "test";
 		TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
@@ -97,6 +96,64 @@ public class SerializationProxiesTest {
 		Assert.assertEquals(name, metaInfo.getStateName());
 	}
 
+
+	@Test
+	public void testOperatorBackendSerializationProxyRoundtrip() throws Exception {
+		// Writes a proxy holding several state meta infos and checks that the proxy read back holds an equal list.
+		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+		// Mix SPLIT_DISTRIBUTE and BROADCAST entries so that both distribution modes go through the round-trip.
+		List<OperatorBackendSerializationProxy.StateMetaInfo<?>> stateMetaInfoList = new ArrayList<>();
+
+		stateMetaInfoList.add(
+				new OperatorBackendSerializationProxy.StateMetaInfo<>("a", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+		stateMetaInfoList.add(
+				new OperatorBackendSerializationProxy.StateMetaInfo<>("b", stateSerializer, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+		stateMetaInfoList.add(
+				new OperatorBackendSerializationProxy.StateMetaInfo<>("c", stateSerializer, OperatorStateHandle.Mode.BROADCAST));
+
+		OperatorBackendSerializationProxy serializationProxy =
+				new OperatorBackendSerializationProxy(stateMetaInfoList);
+
+		byte[] serialized;
+		try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) {
+			serializationProxy.write(new DataOutputViewStreamWrapper(out));
+			serialized = out.toByteArray();
+		}
+		// Read into a fresh proxy; the class loader argument is presumably used to restore the state serializers.
+		serializationProxy =
+				new OperatorBackendSerializationProxy(Thread.currentThread().getContextClassLoader());
+
+		try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) {
+			serializationProxy.read(new DataInputViewStreamWrapper(in));
+		}
+
+		Assert.assertEquals(stateMetaInfoList, serializationProxy.getNamedStateSerializationProxies());
+	}
+
+	@Test
+	public void testOperatorStateMetaInfoSerialization() throws Exception {
+		// Round-trips a single BROADCAST meta info; besides the name, the restored instance is compared as a whole via equals.
+		String name = "test";
+		TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+
+		OperatorBackendSerializationProxy.StateMetaInfo<?> metaInfo =
+				new OperatorBackendSerializationProxy.StateMetaInfo<>(name, stateSerializer, OperatorStateHandle.Mode.BROADCAST);
+		byte[] serialized;
+		try (ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos()) {
+			metaInfo.write(new DataOutputViewStreamWrapper(out));
+			serialized = out.toByteArray();
+		}
+
+		// Read into a separate instance (instead of overwriting 'metaInfo') so the written one remains available for comparison.
+		OperatorBackendSerializationProxy.StateMetaInfo<?> restoredMetaInfo =
+				new OperatorBackendSerializationProxy.StateMetaInfo<>(Thread.currentThread().getContextClassLoader());
+		try (ByteArrayInputStreamWithPos in = new ByteArrayInputStreamWithPos(serialized)) {
+			restoredMetaInfo.read(new DataInputViewStreamWrapper(in));
+		}
+		Assert.assertEquals(name, restoredMetaInfo.getName());
+		Assert.assertEquals(metaInfo, restoredMetaInfo);
+	}
+
 	/**
 	 * This test fixes the order of elements in the enum which is important for serialization. Do not modify this test
 	 * except if you are entirely sure what you are doing.

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
index 39dc5d6..963c42c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/StateInitializationContextImplTest.java
@@ -111,8 +111,10 @@ public class StateInitializationContextImplTest {
 				writtenOperatorStates.add(val);
 			}
 
-			Map<String, long[]> offsetsMap = new HashMap<>();
-			offsetsMap.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, offsets.toArray());
+			Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
+			offsetsMap.put(
+					DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME,
+					new OperatorStateHandle.StateMetaInfo(offsets.toArray(), OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 			OperatorStateHandle operatorStateHandle =
 					new OperatorStateHandle(offsetsMap, new ByteStateHandleCloseChecking("os-" + i, out.toByteArray()));
 			operatorStateHandles.add(operatorStateHandle);

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 0206cf5..58cfefd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -191,8 +191,10 @@ public class InterruptSensitiveRestoreTest {
 		List<Collection<OperatorStateHandle>> operatorStateBackend = Collections.emptyList();
 		List<Collection<OperatorStateHandle>> operatorStateStream = Collections.emptyList();
 
-		Map<String, long[]> operatorStateMetadata = new HashMap<>(1);
-		operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, new long[]{0});
+		Map<String, OperatorStateHandle.StateMetaInfo> operatorStateMetadata = new HashMap<>(1);
+		OperatorStateHandle.StateMetaInfo metaInfo =
+				new OperatorStateHandle.StateMetaInfo(new long[]{0}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+		operatorStateMetadata.put(DefaultOperatorStateBackend.DEFAULT_OPERATOR_STATE_NAME, metaInfo);
 
 		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(new KeyGroupRange(0,0));
 

http://git-wip-us.apache.org/repos/asf/flink/blob/1020ba2c/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index da4a01b..45fcc25 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -86,7 +86,7 @@ public class RescalingITCase extends TestLogger {
 	private static final int numSlots = numTaskManagers * slotsPerTaskManager;
 
 	enum OperatorCheckpointMethod {
-		NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED
+		NON_PARTITIONED, CHECKPOINTED_FUNCTION, CHECKPOINTED_FUNCTION_BROADCAST, LIST_CHECKPOINTED
 	}
 
 	private static TestingCluster cluster;
@@ -179,7 +179,7 @@ public class RescalingITCase extends TestLogger {
 			Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
 
 			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
-				Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+					Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
 
 			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
 
@@ -270,7 +270,7 @@ public class RescalingITCase extends TestLogger {
 
 			assertTrue(String.valueOf(savepointResponse), savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
 
-			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
 
 			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
 
@@ -339,16 +339,16 @@ public class RescalingITCase extends TestLogger {
 		JobID jobID = null;
 
 		try {
-			 jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
 			JobGraph jobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
-				parallelism,
-				maxParallelism,
-				parallelism,
-				numberKeys,
-				numberElements,
-				false,
-				100);
+					parallelism,
+					maxParallelism,
+					parallelism,
+					numberKeys,
+					numberElements,
+					false,
+					100);
 
 			jobID = jobGraph.getJobID();
 
@@ -366,7 +366,7 @@ public class RescalingITCase extends TestLogger {
 			for (int key = 0; key < numberKeys; key++) {
 				int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
 
-				expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key));
+				expectedResult.add(Tuple2.of(KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex), numberElements * key));
 			}
 
 			assertEquals(expectedResult, actualResult);
@@ -377,7 +377,7 @@ public class RescalingITCase extends TestLogger {
 			Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID, Option.<String>empty()), deadline.timeLeft());
 
 			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
-				Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+					Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
 
 			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
 
@@ -392,13 +392,13 @@ public class RescalingITCase extends TestLogger {
 			jobID = null;
 
 			JobGraph scaledJobGraph = createJobGraphWithKeyedAndNonPartitionedOperatorState(
-				parallelism2,
-				maxParallelism,
-				parallelism,
-				numberKeys,
-				numberElements + numberElements2,
-				true,
-				100);
+					parallelism2,
+					maxParallelism,
+					parallelism,
+					numberKeys,
+					numberElements + numberElements2,
+					true,
+					100);
 
 			scaledJobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(savepointPath));
 
@@ -447,6 +447,16 @@ public class RescalingITCase extends TestLogger {
 	}
 
 	@Test
+	public void testSavepointRescalingInBroadcastOperatorState() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST); // false = rescale in
+	}
+
+	@Test
+	public void testSavepointRescalingOutBroadcastOperatorState() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST); // true = rescale out
+	}
+
+	@Test
 	public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
 		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
 	}
@@ -474,7 +484,8 @@ public class RescalingITCase extends TestLogger {
 
 		int counterSize = Math.max(parallelism, parallelism2);
 
-		if(checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+		if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION ||
+				checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
 			PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
 			PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
 		} else {
@@ -505,11 +516,12 @@ public class RescalingITCase extends TestLogger {
 				if (savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess) {
 					break;
 				}
+				System.out.println(savepointResponse);
 			}
 
 			assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
 
-			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess) savepointResponse).savepointPath();
 
 			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
 
@@ -543,6 +555,16 @@ public class RescalingITCase extends TestLogger {
 				for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
 					sumAct += c;
 				}
+			} else if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
+				for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
+
+				for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
+
+				sumExp *= parallelism2;
 			} else {
 				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
 					sumExp += c;
@@ -587,7 +609,10 @@ public class RescalingITCase extends TestLogger {
 
 		switch (checkpointMethod) {
 			case CHECKPOINTED_FUNCTION:
-				src = new PartitionedStateSource();
+				src = new PartitionedStateSource(false);
+				break;
+			case CHECKPOINTED_FUNCTION_BROADCAST:
+				src = new PartitionedStateSource(true);
 				break;
 			case LIST_CHECKPOINTED:
 				src = new PartitionedStateSourceListCheckpointed();
@@ -607,12 +632,12 @@ public class RescalingITCase extends TestLogger {
 	}
 
 	private static JobGraph createJobGraphWithKeyedState(
-		int parallelism,
-		int maxParallelism,
-		int numberKeys,
-		int numberElements,
-		boolean terminateAfterEmission,
-		int checkpointingInterval) {
+			int parallelism,
+			int maxParallelism,
+			int numberKeys,
+			int numberElements,
+			boolean terminateAfterEmission,
+			int checkpointingInterval) {
 
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
@@ -621,17 +646,17 @@ public class RescalingITCase extends TestLogger {
 		env.setRestartStrategy(RestartStrategies.noRestart());
 
 		DataStream<Integer> input = env.addSource(new SubtaskIndexSource(
-			numberKeys,
-			numberElements,
-			terminateAfterEmission))
-			.keyBy(new KeySelector<Integer, Integer>() {
-				private static final long serialVersionUID = -7952298871120320940L;
-
-				@Override
-				public Integer getKey(Integer value) throws Exception {
-					return value;
-				}
-			});
+				numberKeys,
+				numberElements,
+				terminateAfterEmission))
+				.keyBy(new KeySelector<Integer, Integer>() {
+					private static final long serialVersionUID = -7952298871120320940L;
+
+					@Override
+					public Integer getKey(Integer value) throws Exception {
+						return value;
+					}
+				});
 
 		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
 
@@ -643,13 +668,13 @@ public class RescalingITCase extends TestLogger {
 	}
 
 	private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
-		int parallelism,
-		int maxParallelism,
-		int fixedParallelism,
-		int numberKeys,
-		int numberElements,
-		boolean terminateAfterEmission,
-		int checkpointingInterval) {
+			int parallelism,
+			int maxParallelism,
+			int fixedParallelism,
+			int numberKeys,
+			int numberElements,
+			boolean terminateAfterEmission,
+			int checkpointingInterval) {
 
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
@@ -658,18 +683,18 @@ public class RescalingITCase extends TestLogger {
 		env.setRestartStrategy(RestartStrategies.noRestart());
 
 		DataStream<Integer> input = env.addSource(new SubtaskIndexNonPartitionedStateSource(
-			numberKeys,
-			numberElements,
-			terminateAfterEmission))
-			.setParallelism(fixedParallelism)
-			.keyBy(new KeySelector<Integer, Integer>() {
-				private static final long serialVersionUID = -7952298871120320940L;
-
-				@Override
-				public Integer getKey(Integer value) throws Exception {
-					return value;
-				}
-			});
+				numberKeys,
+				numberElements,
+				terminateAfterEmission))
+				.setParallelism(fixedParallelism)
+				.keyBy(new KeySelector<Integer, Integer>() {
+					private static final long serialVersionUID = -7952298871120320940L;
+
+					@Override
+					public Integer getKey(Integer value) throws Exception {
+						return value;
+					}
+				});
 
 		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
 
@@ -681,7 +706,7 @@ public class RescalingITCase extends TestLogger {
 	}
 
 	private static class SubtaskIndexSource
-		extends RichParallelSourceFunction<Integer> {
+			extends RichParallelSourceFunction<Integer> {
 
 		private static final long serialVersionUID = -400066323594122516L;
 
@@ -694,9 +719,9 @@ public class RescalingITCase extends TestLogger {
 		private boolean running = true;
 
 		SubtaskIndexSource(
-			int numberKeys,
-			int numberElements,
-			boolean terminateAfterEmission) {
+				int numberKeys,
+				int numberElements,
+				boolean terminateAfterEmission) {
 
 			this.numberKeys = numberKeys;
 			this.numberElements = numberElements;
@@ -713,8 +738,8 @@ public class RescalingITCase extends TestLogger {
 				if (counter < numberElements) {
 					synchronized (lock) {
 						for (int value = subtaskIndex;
-							 value < numberKeys;
-							 value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+						     value < numberKeys;
+						     value += getRuntimeContext().getNumberOfParallelSubtasks()) {
 
 							ctx.collect(value);
 						}
@@ -836,6 +861,7 @@ public class RescalingITCase extends TestLogger {
 				}
 
 				Thread.sleep(2);
+
 				if (counter == 10) {
 					workStartedLatch.countDown();
 				}
@@ -910,10 +936,14 @@ public class RescalingITCase extends TestLogger {
 		private static final int NUM_PARTITIONS = 7;
 
 		private ListState<Integer> counterPartitions;
+		private final boolean broadcast; // true: counter state is registered via getBroadcastSerializableListState
 
 		private static int[] CHECK_CORRECT_SNAPSHOT;
 		private static int[] CHECK_CORRECT_RESTORE;
 
+		public PartitionedStateSource(boolean broadcast) { // broadcast selects broadcast (vs. split) list state in initializeState
+			this.broadcast = broadcast;
+		}
 
 		@Override
 		public void snapshotState(FunctionSnapshotContext context) throws Exception {
@@ -937,8 +967,15 @@ public class RescalingITCase extends TestLogger {
 
 		@Override
 		public void initializeState(FunctionInitializationContext context) throws Exception {
-			this.counterPartitions =
-					context.getOperatorStateStore().getSerializableListState("counter_partitions");
+
+			if (broadcast) {
+				this.counterPartitions =
+						context.getOperatorStateStore().getBroadcastSerializableListState("counter_partitions");
+			} else {
+				this.counterPartitions =
+						context.getOperatorStateStore().getSerializableListState("counter_partitions");
+			}
+
 			if (context.isRestored()) {
 				for (int v : counterPartitions.get()) {
 					counter += v;