You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2017/04/22 13:26:19 UTC

[1/3] flink git commit: [FLINK-6014] [checkpoint] Additional review changes

Repository: flink
Updated Branches:
  refs/heads/master db31ca3f8 -> aa21f853a


[FLINK-6014] [checkpoint] Additional review changes


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/aa21f853
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/aa21f853
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/aa21f853

Branch: refs/heads/master
Commit: aa21f853ab0380ec1f68ae1d0b7c8d9268da4533
Parents: 218bed8
Author: Stefan Richter <s....@data-artisans.com>
Authored: Sat Apr 22 01:23:09 2017 +0200
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Sat Apr 22 15:25:56 2017 +0200

----------------------------------------------------------------------
 .../AbstractCompletedCheckpointStore.java       |  37 ++++
 .../checkpoint/CheckpointCoordinator.java       |  41 ++---
 .../runtime/checkpoint/CompletedCheckpoint.java | 169 +++++++++++--------
 .../checkpoint/CompletedCheckpointStore.java    |  10 +-
 .../runtime/checkpoint/PendingCheckpoint.java   |  12 +-
 .../StandaloneCompletedCheckpointStore.java     |   9 +-
 .../flink/runtime/checkpoint/SubtaskState.java  |   5 -
 .../flink/runtime/checkpoint/TaskState.java     |   7 -
 .../ZooKeeperCompletedCheckpointStore.java      |   8 +-
 .../runtime/state/CompositeStateHandle.java     |  22 +--
 .../flink/runtime/state/SharedStateHandle.java  |   2 +-
 .../runtime/state/SharedStateRegistry.java      |  84 ++++-----
 .../apache/flink/runtime/state/StateObject.java |   6 +-
 .../CheckpointCoordinatorFailureTest.java       |  10 +-
 .../checkpoint/CheckpointCoordinatorTest.java   |  23 +--
 .../checkpoint/CheckpointStateRestoreTest.java  |   5 +-
 .../CompletedCheckpointStoreTest.java           |  28 ++-
 ...ExecutionGraphCheckpointCoordinatorTest.java |   7 +-
 .../checkpoint/PendingCheckpointTest.java       |   4 -
 .../StandaloneCompletedCheckpointStoreTest.java |  26 ++-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  41 ++---
 .../ZooKeeperCompletedCheckpointStoreTest.java  |  10 +-
 .../runtime/state/SharedStateRegistryTest.java  |  42 +----
 .../RecoverableCompletedCheckpointStore.java    |  13 +-
 24 files changed, 292 insertions(+), 329 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
new file mode 100644
index 0000000..f42fd06
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/AbstractCompletedCheckpointStore.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.SharedStateRegistry;
+
+/**
+ * This is the base class that provides implementation of some aspects common for all
+ * {@link CompletedCheckpointStore}s.
+ */
+public abstract class AbstractCompletedCheckpointStore implements CompletedCheckpointStore {
+
+	/**
+	 * Registry for shared states.
+	 */
+	protected final SharedStateRegistry sharedStateRegistry;
+
+	public AbstractCompletedCheckpointStore() {
+		this.sharedStateRegistry = new SharedStateRegistry();
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 5309dd4..256321e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -37,9 +37,7 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.TaskStateHandles;
-
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
@@ -110,9 +108,6 @@ public class CheckpointCoordinator {
 	/** Completed checkpoints. Implementations can be blocking. Make sure calls to methods
 	 * accessing this don't block the job manager actor and run asynchronously. */
 	private final CompletedCheckpointStore completedCheckpointStore;
-	
-	/** Registry for shared states */
-	private final SharedStateRegistry sharedStateRegistry;
 
 	/** Default directory for persistent checkpoints; <code>null</code> if none configured.
 	 * THIS WILL BE REPLACED BY PROPER STATE-BACKEND METADATA WRITING */
@@ -223,7 +218,6 @@ public class CheckpointCoordinator {
 		this.completedCheckpointStore = checkNotNull(completedCheckpointStore);
 		this.checkpointDirectory = checkpointDirectory;
 		this.executor = checkNotNull(executor);
-		this.sharedStateRegistry = new SharedStateRegistry();
 
 		this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 
@@ -288,7 +282,7 @@ public class CheckpointCoordinator {
 				}
 				pendingCheckpoints.clear();
 
-				completedCheckpointStore.shutdown(jobStatus, sharedStateRegistry);
+				completedCheckpointStore.shutdown(jobStatus);
 				checkpointIdCounter.shutdown(jobStatus);
 			}
 		}
@@ -732,7 +726,7 @@ public class CheckpointCoordinator {
 								"the state handle to avoid lingering state.", message.getCheckpointId(),
 							message.getTaskExecutionId(), message.getJob());
 
-						discardState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
+						discardSubtaskState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
 
 						break;
 					case DISCARDED:
@@ -741,7 +735,7 @@ public class CheckpointCoordinator {
 								"state handle tp avoid lingering state.",
 							message.getCheckpointId(), message.getTaskExecutionId(), message.getJob());
 
-						discardState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
+						discardSubtaskState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
 				}
 
 				return true;
@@ -767,7 +761,7 @@ public class CheckpointCoordinator {
 				}
 
 				// try to discard the state so that we don't have lingering state lying around
-				discardState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
+				discardSubtaskState(message.getJob(), message.getTaskExecutionId(), message.getCheckpointId(), message.getSubtaskState());
 
 				return wasPendingCheckpoint;
 			}
@@ -805,16 +799,16 @@ public class CheckpointCoordinator {
 	
 			// the pending checkpoint must be discarded after the finalization
 			Preconditions.checkState(pendingCheckpoint.isDiscarded() && completedCheckpoint != null);
-	
+
 			try {
-				completedCheckpointStore.addCheckpoint(completedCheckpoint, sharedStateRegistry);
+				completedCheckpointStore.addCheckpoint(completedCheckpoint);
 			} catch (Exception exception) {
 				// we failed to store the completed checkpoint. Let's clean up
 				executor.execute(new Runnable() {
 					@Override
 					public void run() {
 						try {
-							completedCheckpoint.discardOnFail();
+							completedCheckpoint.discardOnFailedStoring();
 						} catch (Throwable t) {
 							LOG.warn("Could not properly discard completed checkpoint {}.", completedCheckpoint.getCheckpointID(), t);
 						}
@@ -953,7 +947,7 @@ public class CheckpointCoordinator {
 			}
 
 			// Recover the checkpoints
-			completedCheckpointStore.recover(sharedStateRegistry);
+			completedCheckpointStore.recover();
 
 			// restore from the latest checkpoint
 			CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint();
@@ -1017,7 +1011,7 @@ public class CheckpointCoordinator {
 		CompletedCheckpoint savepoint = SavepointLoader.loadAndValidateSavepoint(
 				job, tasks, savepointPath, userClassLoader, allowNonRestored);
 
-		completedCheckpointStore.addCheckpoint(savepoint, sharedStateRegistry);
+		completedCheckpointStore.addCheckpoint(savepoint);
 		
 		// Reset the checkpoint ID counter
 		long nextCheckpointId = savepoint.getCheckpointID() + 1;
@@ -1057,10 +1051,11 @@ public class CheckpointCoordinator {
 	public CompletedCheckpointStore getCheckpointStore() {
 		return completedCheckpointStore;
 	}
-	
-	public SharedStateRegistry getSharedStateRegistry() {
-		return sharedStateRegistry;
-	}
+
+//	@VisibleForTesting
+//	SharedStateRegistry getSharedStateRegistry() {
+//		return sharedStateRegistry;
+//	}
 
 	public CheckpointIDCounter getCheckpointIdCounter() {
 		return checkpointIdCounter;
@@ -1151,7 +1146,7 @@ public class CheckpointCoordinator {
 	 * @param checkpointId of the state object
 	 * @param subtaskState to discard asynchronously
 	 */
-	private void discardState(
+	private void discardSubtaskState(
 			final JobID jobId,
 			final ExecutionAttemptID executionAttemptID,
 			final long checkpointId,
@@ -1161,12 +1156,6 @@ public class CheckpointCoordinator {
 			executor.execute(new Runnable() {
 				@Override
 				public void run() {
-					try {
-						subtaskState.discardSharedStatesOnFail();
-					} catch (Throwable t1) {
-						LOG.warn("Could not properly discard shared states of checkpoint {} " +
-							"belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t1);
-					}
 
 					try {
 						subtaskState.discardState();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 58e91e1..79fc31f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -20,14 +20,12 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.ExceptionUtils;
-
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -107,7 +105,7 @@ public class CompletedCheckpoint implements Serializable {
 
 	/** Optional stats tracker callback for discard. */
 	@Nullable
-	private transient volatile DiscardCallback discardCallback;
+	private transient volatile CompletedCheckpointStats.DiscardCallback discardCallback;
 
 	// ------------------------------------------------------------------------
 
@@ -151,7 +149,7 @@ public class CompletedCheckpoint implements Serializable {
 		checkArgument((externalPointer == null) == (externalizedMetadata == null),
 				"external pointer without externalized metadata must be both null or both non-null");
 
-		checkArgument(!props.externalizeCheckpoint() || externalPointer != null, 
+		checkArgument(!props.externalizeCheckpoint() || externalPointer != null,
 			"Checkpoint properties require externalized checkpoint, but checkpoint is not externalized");
 
 		this.job = checkNotNull(job);
@@ -186,15 +184,14 @@ public class CompletedCheckpoint implements Serializable {
 		return props;
 	}
 
-	public void discardOnFail() throws Exception {
-		discard(null, true);
+	public void discardOnFailedStoring() throws Exception {
+		new UnstoredDiscardStategy().discard();
 	}
 
 	public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception {
-		Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null.");
 
 		if (props.discardOnSubsumed()) {
-			discard(sharedStateRegistry, false);
+			new StoredDiscardStrategy(sharedStateRegistry).discard();
 			return true;
 		}
 
@@ -202,14 +199,13 @@ public class CompletedCheckpoint implements Serializable {
 	}
 
 	public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
-		Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null.");
 
 		if (jobStatus == JobStatus.FINISHED && props.discardOnJobFinished() ||
 				jobStatus == JobStatus.CANCELED && props.discardOnJobCancelled() ||
 				jobStatus == JobStatus.FAILED && props.discardOnJobFailed() ||
 				jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) {
 
-			discard(sharedStateRegistry, false);
+			new StoredDiscardStrategy(sharedStateRegistry).discard();
 			return true;
 		} else {
 			if (externalPointer != null) {
@@ -221,53 +217,6 @@ public class CompletedCheckpoint implements Serializable {
 		}
 	}
 
-	private void discard(SharedStateRegistry sharedStateRegistry, boolean failed) throws Exception {
-		Preconditions.checkState(failed || (sharedStateRegistry != null),
-			"The registry must not be null if the complete checkpoint does not fail.");
-
-		try {
-			// collect exceptions and continue cleanup
-			Exception exception = null;
-
-			// drop the metadata, if we have some
-			if (externalizedMetadata != null) {
-				try {
-					externalizedMetadata.discardState();
-				} catch (Exception e) {
-					exception = e;
-				}
-			}
-
-			// In the cases where the completed checkpoint fails, the shared
-			// states have not been registered to the registry. It's the state
-			// handles' responsibility to discard their shared states.
-			if (!failed) {
-				unregisterSharedStates(sharedStateRegistry);
-			} else {
-				discardSharedStatesOnFail();
-			}
-
-			// discard private state objects
-			try {
-				StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
-			} catch (Exception e) {
-				exception = ExceptionUtils.firstOrSuppressed(e, exception);
-			}
-
-			if (exception != null) {
-				throw exception;
-			}
-		} finally {
-			taskStates.clear();
-
-			// to be null-pointer safe, copy reference to stack
-			DiscardCallback discardCallback = this.discardCallback;
-			if (discardCallback != null) {
-				discardCallback.notifyDiscardedCheckpoint();
-			}
-		}
-	}
-
 	public long getStateSize() {
 		long result = 0L;
 
@@ -319,30 +268,108 @@ public class CompletedCheckpoint implements Serializable {
 		sharedStateRegistry.registerAll(taskStates.values());
 	}
 
+	// --------------------------------------------------------------------------------------------
+
+	@Override
+	public String toString() {
+		return String.format("Checkpoint %d @ %d for %s", checkpointID, timestamp, job);
+	}
+
 	/**
-	 * Unregister all shared states from the given registry. This is method is
-	 * called when the completed checkpoint is subsumed or the job terminates.
-	 *
-	 * @param sharedStateRegistry The registry where shared states are registered
+	 * Base class for the discarding strategies of {@link CompletedCheckpoint}.
 	 */
-	private void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
-		sharedStateRegistry.unregisterAll(taskStates.values());
+	private abstract class DiscardStrategy {
+
+		protected Exception storedException;
+
+		public DiscardStrategy() {
+			this.storedException = null;
+		}
+
+		public void discard() throws Exception {
+
+			try {
+				// collect exceptions and continue cleanup
+				storedException = null;
+
+				doDiscardExternalizedMetaData();
+				doDiscardSharedState();
+				doDiscardPrivateState();
+				doReportStoredExceptions();
+			} finally {
+				clearTaskStatesAndNotifyDiscardCompleted();
+			}
+		}
+
+		protected void doDiscardExternalizedMetaData() {
+			// drop the metadata, if we have some
+			if (externalizedMetadata != null) {
+				try {
+					externalizedMetadata.discardState();
+				} catch (Exception e) {
+					storedException = e;
+				}
+			}
+		}
+
+		protected void doDiscardPrivateState() {
+			// discard private state objects
+			try {
+				StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
+			} catch (Exception e) {
+				storedException = ExceptionUtils.firstOrSuppressed(e, storedException);
+			}
+		}
+
+		protected abstract void doDiscardSharedState();
+
+		protected void doReportStoredExceptions() throws Exception {
+			if (storedException != null) {
+				throw storedException;
+			}
+		}
+
+		protected void clearTaskStatesAndNotifyDiscardCompleted() {
+			taskStates.clear();
+			// to be null-pointer safe, copy reference to stack
+			CompletedCheckpointStats.DiscardCallback discardCallback =
+				CompletedCheckpoint.this.discardCallback;
+
+			if (discardCallback != null) {
+				discardCallback.notifyDiscardedCheckpoint();
+			}
+		}
 	}
 
 	/**
-	 * Discard all shared states created in the checkpoint. This method is called
+	 * Discard all shared states created in the checkpoint. This strategy is applied
 	 * when the completed checkpoint fails to be added into the store.
 	 */
-	private void discardSharedStatesOnFail() throws Exception {
-		for (TaskState taskState : taskStates.values()) {
-			taskState.discardSharedStatesOnFail();
+	private class UnstoredDiscardStategy extends CompletedCheckpoint.DiscardStrategy {
+
+		@Override
+		protected void doDiscardSharedState() {
+			// nothing to do because we did not register any shared state yet. unregistered, new
+			// shared state is then still considered private state and deleted as part of
+			// doDiscardPrivateState().
 		}
 	}
 
-	// --------------------------------------------------------------------------------------------
+	/**
+	 * Unregister all shared states from the given registry. This strategy is
+	 * applied when the completed checkpoint is subsumed or the job terminates.
+	 */
+	private class StoredDiscardStrategy extends CompletedCheckpoint.DiscardStrategy {
 
-	@Override
-	public String toString() {
-		return String.format("Checkpoint %d @ %d for %s", checkpointID, timestamp, job);
+		SharedStateRegistry sharedStateRegistry;
+
+		public StoredDiscardStrategy(SharedStateRegistry sharedStateRegistry) {
+			this.sharedStateRegistry = Preconditions.checkNotNull(sharedStateRegistry);
+		}
+
+		@Override
+		protected void doDiscardSharedState() {
+			sharedStateRegistry.unregisterAll(taskStates.values());
+		}
 	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
index 0ade25c..82193b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import java.util.List;
 
@@ -34,16 +33,15 @@ public interface CompletedCheckpointStore {
 	 * <p>After a call to this method, {@link #getLatestCheckpoint()} returns the latest
 	 * available checkpoint.
 	 */
-	void recover(SharedStateRegistry sharedStateRegistry) throws Exception;
+	void recover() throws Exception;
 
 	/**
 	 * Adds a {@link CompletedCheckpoint} instance to the list of completed checkpoints.
 	 *
 	 * <p>Only a bounded number of checkpoints is kept. When exceeding the maximum number of
-	 * retained checkpoints, the oldest one will be discarded via {@link
-	 * CompletedCheckpoint#discardOnSubsume(SharedStateRegistry)} )}.
+	 * retained checkpoints, the oldest one will be discarded.
 	 */
-	void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception;
+	void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception;
 
 	/**
 	 * Returns the latest {@link CompletedCheckpoint} instance or <code>null</code> if none was
@@ -59,7 +57,7 @@ public interface CompletedCheckpointStore {
 	 *
 	 * @param jobStatus Job state on shut down
 	 */
-	void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception;
+	void shutdown(JobStatus jobStatus) throws Exception;
 
 	/**
 	 * Returns all {@link CompletedCheckpoint} instances.

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index 6805dea..900331b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -506,16 +506,8 @@ public class PendingCheckpoint {
 						@Override
 						public void run() {
 
-							// discard the shared states that are created in the checkpoint
-							for (TaskState taskState : taskStates.values()) {
-								try {
-									taskState.discardSharedStatesOnFail();
-								} catch (Throwable t) {
-									LOG.warn("Could not properly dispose unreferenced shared states.");
-								}
-							}
-
-							// discard the private states
+							// discard the private states.
+							// unregistered shared states are still considered private at this point.
 							try {
 								StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
 							} catch (Throwable t) {

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
index 9f833c3..f5e1db3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
@@ -20,7 +20,6 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -33,7 +32,7 @@ import static org.apache.flink.util.Preconditions.checkArgument;
 /**
  * {@link CompletedCheckpointStore} for JobManagers running in {@link HighAvailabilityMode#NONE}.
  */
-public class StandaloneCompletedCheckpointStore implements CompletedCheckpointStore {
+public class StandaloneCompletedCheckpointStore extends AbstractCompletedCheckpointStore {
 
 	private static final Logger LOG = LoggerFactory.getLogger(StandaloneCompletedCheckpointStore.class);
 
@@ -57,12 +56,12 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt
 	}
 
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		// Nothing to do
 	}
 
 	@Override
-	public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
 		
 		checkpoints.addLast(checkpoint);
 
@@ -99,7 +98,7 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt
 	}
 
 	@Override
-	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void shutdown(JobStatus jobStatus) throws Exception {
 		try {
 			LOG.info("Shutting down");
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index e968643..121ac57 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -161,11 +161,6 @@ public class SubtaskState implements CompositeStateHandle {
 	}
 
 	@Override
-	public void discardSharedStatesOnFail() {
-		// No shared states
-	}
-
-	@Override
 	public long getStateSize() {
 		return stateSize;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 19fe962..4f5f536 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -144,13 +144,6 @@ public class TaskState implements CompositeStateHandle {
 	}
 
 	@Override
-	public void discardSharedStatesOnFail() {
-		for (SubtaskState subtaskState : subtaskStates.values()) {
-			subtaskState.discardSharedStatesOnFail();
-		}
-	}
-
-	@Override
 	public long getStateSize() {
 		long result = 0L;
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
index 07546ea..52a4eea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
@@ -68,7 +68,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * checkpoints is consistent. Currently, after recovery we start out with only a single
  * checkpoint to circumvent those situations.
  */
-public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointStore {
+public class ZooKeeperCompletedCheckpointStore extends AbstractCompletedCheckpointStore {
 
 	private static final Logger LOG = LoggerFactory.getLogger(ZooKeeperCompletedCheckpointStore.class);
 
@@ -141,7 +141,7 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	 * that the history of checkpoints is consistent.
 	 */
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		LOG.info("Recovering checkpoints from ZooKeeper.");
 
 		// Clear local handles in order to prevent duplicates on
@@ -192,7 +192,7 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	 * @param checkpoint Completed checkpoint to add.
 	 */
 	@Override
-	public void addCheckpoint(final CompletedCheckpoint checkpoint, final SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void addCheckpoint(final CompletedCheckpoint checkpoint) throws Exception {
 		checkNotNull(checkpoint, "Checkpoint");
 		
 		final String path = checkpointIdToPath(checkpoint.getCheckpointID());
@@ -281,7 +281,7 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	}
 
 	@Override
-	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void shutdown(JobStatus jobStatus) throws Exception {
 		if (jobStatus.isGloballyTerminalState()) {
 			LOG.info("Shutting down");
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
index 2ea5bc9..002b7c3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
@@ -28,16 +28,24 @@ package org.apache.flink.runtime.state;
  * received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
  * and will be discarded when the checkpoint is discarded.
  * 
- * <p>The {@link SharedStateRegistry} is responsible for the discarding of the
- * shared states. The composite state handle should only delete those private
- * states in the {@link StateObject#discardState()} method.
+ * <p>The {@link SharedStateRegistry} is responsible for the discarding of registered
+ * shared states. Before their first registration through
+ * {@link #registerSharedStates(SharedStateRegistry)}, newly created shared state is still owned by
+ * this handle and considered as private state until it is registered for the first time. Registration
+ * transfers ownership to the {@link SharedStateRegistry}.
+ * The composite state handle should only delete all private states in the
+ * {@link StateObject#discardState()} method.
  */
 public interface CompositeStateHandle extends StateObject {
 
 	/**
-	 * Register both created and referenced shared states in the given
+	 * Register both newly created and already referenced shared states in the given
 	 * {@link SharedStateRegistry}. This method is called when the checkpoint
 	 * successfully completes or is recovered from failures.
+	 * <p>
+	 * After this is completed, newly created shared state is considered as published and is no longer
+	 * owned by this handle. This means that it should no longer be deleted as part of calls to
+	 * {@link #discardState()}.
 	 *
 	 * @param stateRegistry The registry where shared states are registered.
 	 */
@@ -51,10 +59,4 @@ public interface CompositeStateHandle extends StateObject {
 	 * @param stateRegistry The registry where shared states are registered.
 	 */
 	void unregisterSharedStates(SharedStateRegistry stateRegistry);
-
-	/**
-	 * Discard all shared states created in this checkpoint. This method is
-	 * called when the checkpoint fails to complete.
-	 */
-	void discardSharedStatesOnFail() throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
index f856052..c8c4046 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
@@ -35,5 +35,5 @@ public interface SharedStateHandle extends StateObject {
 	/**
 	 * Return the identifier of the shared state.
 	 */
-	String getKey();
+	String getRegistrationKey();
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
index b5048d0..2cb43ac 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
@@ -18,13 +18,10 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.Serializable;
-import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -33,73 +30,78 @@ import java.util.Map;
  * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to 
  * maintain the reference count of {@link SharedStateHandle}s which are shared
  * among different checkpoints.
+ *
  */
-public class SharedStateRegistry implements Serializable {
+public class SharedStateRegistry {
 
-	private static Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class);
+	private static final Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class);
 
-	private static final long serialVersionUID = -8357254413007773970L;
+	/** All registered state objects by an artificial key */
+	private final Map<String, SharedStateRegistry.SharedStateEntry> registeredStates;
 
-	/** All registered state objects */
-	private final Map<String, SharedStateEntry> registeredStates = new HashMap<>();
+	public SharedStateRegistry() {
+		this.registeredStates = new HashMap<>();
+	}
 
 	/**
-	 * Register the state in the registry
+	 * Register a reference to the given shared state in the registry. This increases the reference
+	 * count for this shared state by one. Returns the reference count after the update.
 	 *
-	 * @param state The state to register
-	 * @param isNew True if the shared state is newly created
+	 * @param state the shared state for which we register a reference.
+	 * @return the updated reference count for the given shared state.
 	 */
-	public void register(SharedStateHandle state, boolean isNew) {
+	public int register(SharedStateHandle state) {
 		if (state == null) {
-			return;
+			return 0;
 		}
 
 		synchronized (registeredStates) {
-			SharedStateEntry entry = registeredStates.get(state.getKey());
-
-			if (isNew) {
-				Preconditions.checkState(entry == null,
-					"The state cannot be created more than once.");
+			SharedStateRegistry.SharedStateEntry entry =
+				registeredStates.get(state.getRegistrationKey());
 
-				registeredStates.put(state.getKey(), new SharedStateEntry(state));
+			if (entry == null) {
+				SharedStateRegistry.SharedStateEntry stateEntry =
+					new SharedStateRegistry.SharedStateEntry(state);
+				registeredStates.put(state.getRegistrationKey(), stateEntry);
+				return 1;
 			} else {
-				Preconditions.checkState(entry != null,
-					"The state cannot be referenced if it has not been created yet.");
-
 				entry.increaseReferenceCount();
+				return entry.getReferenceCount();
 			}
 		}
 	}
 
 	/**
-	 * Unregister the state in the registry
+	 * Unregister one reference to the given shared state in the registry. This decreases the
+	 * reference count by one. Once the count reaches zero, the shared state is deleted.
 	 *
-	 * @param state The state to unregister
+	 * @param state the shared state for which we unregister a reference.
+	 * @return the reference count for the shared state after the update.
 	 */
-	public void unregister(SharedStateHandle state) {
+	public int unregister(SharedStateHandle state) {
 		if (state == null) {
-			return;
+			return 0;
 		}
 
 		synchronized (registeredStates) {
-			SharedStateEntry entry = registeredStates.get(state.getKey());
+			SharedStateRegistry.SharedStateEntry entry = registeredStates.get(state.getRegistrationKey());
 
-			if (entry == null) {
-				throw new IllegalStateException("Cannot unregister an unexisted state.");
-			}
+			Preconditions.checkState(entry != null, "Cannot unregister a state that is not registered.");
 
 			entry.decreaseReferenceCount();
 
-			// Remove the state from the registry when it's not referenced any more.
-			if (entry.getReferenceCount() == 0) {
-				registeredStates.remove(state.getKey());
+			final int newReferenceCount = entry.getReferenceCount();
 
+			// Remove the state from the registry when it's not referenced any more.
+			if (newReferenceCount <= 0) {
+				registeredStates.remove(state.getRegistrationKey());
 				try {
 					entry.getState().discardState();
 				} catch (Exception e) {
-					LOG.warn("Cannot properly discard the state " + entry.getState() + ".", e);
+					LOG.warn("Cannot properly discard the state {}.", entry.getState(), e);
 				}
 			}
+			return newReferenceCount;
 		}
 	}
 
@@ -108,7 +110,7 @@ public class SharedStateRegistry implements Serializable {
 	 *
 	 * @param stateHandles The shared states to register.
 	 */
-	public void registerAll(Collection<? extends CompositeStateHandle> stateHandles) {
+	public void registerAll(Iterable<? extends CompositeStateHandle> stateHandles) {
 		if (stateHandles == null) {
 			return;
 		}
@@ -127,7 +129,7 @@ public class SharedStateRegistry implements Serializable {
 	 *
 	 * @param stateHandles The shared states to unregister.
 	 */
-	public void unregisterAll(Collection<? extends CompositeStateHandle> stateHandles) {
+	public void unregisterAll(Iterable<? extends CompositeStateHandle> stateHandles) {
 		if (stateHandles == null) {
 			return;
 		}
@@ -140,6 +142,7 @@ public class SharedStateRegistry implements Serializable {
 	}
 
 	private static class SharedStateEntry {
+
 		/** The shared object */
 		private final SharedStateHandle state;
 
@@ -168,10 +171,13 @@ public class SharedStateRegistry implements Serializable {
 		}
 	}
 
-
-	@VisibleForTesting
 	public int getReferenceCount(SharedStateHandle state) {
-		SharedStateEntry entry = registeredStates.get(state.getKey());
+		if (state == null) {
+			return 0;
+		}
+
+		SharedStateRegistry.SharedStateEntry entry =
+			registeredStates.get(state.getRegistrationKey());
 
 		return entry == null ? 0 : entry.getReferenceCount();
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
index 7f1dd18..3b49df7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateObject.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.runtime.state;
 
+import java.io.Serializable;
+
 /**
  * Base of all handles that represent checkpointed state in some form. The object may hold
  * the (small) state directly, or contain a file path (state is in the file), or contain the
@@ -33,10 +35,10 @@ package org.apache.flink.runtime.state;
  * compatibility, they are not stored via {@link java.io.Serializable Java Serialization},
  * but through custom serializers.
  */
-public interface StateObject extends java.io.Serializable {
+public interface StateObject extends Serializable {
 
 	/**
-	 * Discards the state referred to by this handle, to free up resources in
+	 * Discards the state referred to and solely owned by this handle, to free up resources in
 	 * the persistent storage. This method is called when the state represented by this
 	 * object will not be used any more.
 	 */

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 632f2c0..90b7fe7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -25,7 +25,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -39,9 +38,7 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 @RunWith(PowerMockRunner.class)
@@ -105,19 +102,18 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// make sure that the subtask state has been discarded after we could not complete it.
-		verify(subtaskState, times(1)).discardSharedStatesOnFail();
 		verify(subtaskState).discardState();
 	}
 
 	private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {
 
 		@Override
-		public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+		public void recover() throws Exception {
 			throw new UnsupportedOperationException("Not implemented.");
 		}
 
 		@Override
-		public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+		public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
 			throw new Exception("The failing completed checkpoint store failed again... :-(");
 		}
 
@@ -127,7 +123,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		}
 
 		@Override
-		public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+		public void shutdown(JobStatus jobStatus) throws Exception {
 			throw new UnsupportedOperationException("Not implemented.");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index fabf3fc..24169f2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -88,7 +88,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -880,8 +879,6 @@ public class CheckpointCoordinatorTest {
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that all received subtask states in the first checkpoint have been discarded
-			verify(subtaskState1_1, times(1)).discardSharedStatesOnFail();
-			verify(subtaskState1_2, times(1)).discardSharedStatesOnFail();
 			verify(subtaskState1_1, times(1)).discardState();
 			verify(subtaskState1_2, times(1)).discardState();
 
@@ -907,7 +904,6 @@ public class CheckpointCoordinatorTest {
 			// send the last remaining ack for the first checkpoint. This should not do anything
 			SubtaskState subtaskState1_3 = mock(SubtaskState.class);
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3));
-			verify(subtaskState1_3, times(1)).discardSharedStatesOnFail();
 			verify(subtaskState1_3, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
@@ -993,7 +989,6 @@ public class CheckpointCoordinatorTest {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that the received states have been discarded
-			verify(subtaskState, times(1)).discardSharedStatesOnFail();
 			verify(subtaskState, times(1)).discardState();
 
 			// no confirm message must have been sent
@@ -1117,7 +1112,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
 		// verify that the subtask state has registered its shared states at the registry
-		verify(triggerSubtaskState, never()).discardSharedStatesOnFail();
 		verify(triggerSubtaskState, never()).discardState();
 
 		SubtaskState unknownSubtaskState = mock(SubtaskState.class);
@@ -1126,7 +1120,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState));
 
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
-		verify(unknownSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(unknownSubtaskState, times(1)).discardState();
 
 		SubtaskState differentJobSubtaskState = mock(SubtaskState.class);
@@ -1135,7 +1128,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
 
 		// we should not interfere with different jobs
-		verify(differentJobSubtaskState, never()).discardSharedStatesOnFail();
 		verify(differentJobSubtaskState, never()).discardState();
 
 		// duplicate acknowledge message for the trigger vertex
@@ -1143,7 +1135,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
 		// duplicate acknowledge messages for a known vertex should not trigger discarding the state
-		verify(triggerSubtaskState, never()).discardSharedStatesOnFail();
 		verify(triggerSubtaskState, never()).discardState();
 
 		// let the checkpoint fail at the first ack vertex
@@ -1153,7 +1144,6 @@ public class CheckpointCoordinatorTest {
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// check that we've cleaned up the already acknowledged state
-		verify(triggerSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(triggerSubtaskState, times(1)).discardState();
 
 		SubtaskState ackSubtaskState = mock(SubtaskState.class);
@@ -1162,7 +1152,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState));
 
 		// check that we also cleaned up this state
-		verify(ackSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(ackSubtaskState, times(1)).discardState();
 
 		// receive an acknowledge message from an unknown job
@@ -1170,7 +1159,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
 
 		// we should not interfere with different jobs
-		verify(differentJobSubtaskState, never()).discardSharedStatesOnFail();
 		verify(differentJobSubtaskState, never()).discardState();
 
 		SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
@@ -1179,7 +1167,6 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2));
 
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
-		verify(unknownSubtaskState2, times(1)).discardSharedStatesOnFail();
 		verify(unknownSubtaskState2, times(1)).discardState();
 	}
 
@@ -2013,14 +2000,13 @@ public class CheckpointCoordinatorTest {
 		assertEquals(1, completedCheckpoints.size());
 
 		// shutdown the store
-		SharedStateRegistry sharedStateRegistry = coord.getSharedStateRegistry();
-		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+		store.shutdown(JobStatus.SUSPENDED);
 
 		// All shared states should be unregistered once the store is shut down
 		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
 			for (TaskState taskState : completedCheckpoint.getTaskStates().values()) {
 				for (SubtaskState subtaskState : taskState.getStates()) {
-					verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry);
+					verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
 				}
 			}
 		}
@@ -2037,7 +2023,7 @@ public class CheckpointCoordinatorTest {
 		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
 			for (TaskState taskState : completedCheckpoint.getTaskStates().values()) {
 				for (SubtaskState subtaskState : taskState.getStates()) {
-					verify(subtaskState, times(2)).registerSharedStates(sharedStateRegistry);
+					verify(subtaskState, times(2)).registerSharedStates(any(SharedStateRegistry.class));
 				}
 			}
 		}
@@ -3150,8 +3136,7 @@ public class CheckpointCoordinatorTest {
 			Executors.directExecutor());
 
 		store.addCheckpoint(
-			new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.<JobVertexID, TaskState>emptyMap()),
-			coord.getSharedStateRegistry());
+			new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.<JobVertexID, TaskState>emptyMap()));
 
 		CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class);
 		coord.setCheckpointStatsTracker(tracker);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 9e372e1..2fc1de5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -30,7 +30,6 @@ import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -255,7 +254,7 @@ public class CheckpointStateRestoreTest {
 		}
 		CompletedCheckpoint checkpoint = new CompletedCheckpoint(new JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates));
 
-		coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry());
+		coord.getCheckpointStore().addCheckpoint(checkpoint);
 
 		coord.restoreLatestCheckpointedState(tasks, true, false);
 		coord.restoreLatestCheckpointedState(tasks, true, true);
@@ -273,7 +272,7 @@ public class CheckpointStateRestoreTest {
 
 		checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new HashMap<>(checkpointTaskStates));
 
-		coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry());
+		coord.getCheckpointStore().addCheckpoint(checkpoint);
 
 		// (i) Allow non restored state (should succeed)
 		coord.restoreLatestCheckpointedState(tasks, true, true);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index aa1726b..4a36dd2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -37,6 +37,7 @@ import java.util.concurrent.CountDownLatch;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -68,7 +69,6 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testAddAndGetLatestCheckpoint() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		
 		// Empty state
 		assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints());
@@ -78,11 +78,11 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 				createCheckpoint(0), createCheckpoint(1) };
 
 		// Add and get latest
-		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[0]);
 		assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 		verifyCheckpoint(expected[0], checkpoints.getLatestCheckpoint());
 
-		checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[1]);
 		assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints());
 		verifyCheckpoint(expected[1], checkpoints.getLatestCheckpoint());
 	}
@@ -93,8 +93,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	 */
 	@Test
 	public void testAddCheckpointMoreThanMaxRetained() throws Exception {
-		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1);   
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1);
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -102,13 +101,13 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		// Add checkpoints
-		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[0]);
 		assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 
 		for (int i = 1; i < expected.length; i++) {
 			Collection<TaskState> taskStates = expected[i - 1].getTaskStates().values();
 
-			checkpoints.addCheckpoint(expected[i], sharedStateRegistry);
+			checkpoints.addCheckpoint(expected[i]);
 
 			// The ZooKeeper implementation discards asynchronously
 			expected[i - 1].awaitDiscard();
@@ -117,7 +116,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 
 			for (TaskState taskState : taskStates) {
 				for (SubtaskState subtaskState : taskState.getStates()) {
-					verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry);
+					verify(subtaskState, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
 				}
 			}
 		}
@@ -146,7 +145,6 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testGetAllCheckpoints() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -154,7 +152,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		for (TestCompletedCheckpoint checkpoint : expected) {
-			checkpoints.addCheckpoint(checkpoint, sharedStateRegistry);
+			checkpoints.addCheckpoint(checkpoint);
 		}
 
 		List<CompletedCheckpoint> actual = checkpoints.getAllCheckpoints();
@@ -172,7 +170,6 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testDiscardAllCheckpoints() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -180,10 +177,10 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		for (TestCompletedCheckpoint checkpoint : expected) {
-			checkpoints.addCheckpoint(checkpoint, sharedStateRegistry);
+			checkpoints.addCheckpoint(checkpoint);
 		}
 
-		checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+		checkpoints.shutdown(JobStatus.FINISHED);
 
 		// Empty state
 		assertNull(checkpoints.getLatestCheckpoint());
@@ -235,10 +232,10 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		}
 	}
 
-	protected void verifyCheckpointRegistered(Collection<TaskState> taskStates, SharedStateRegistry sharedStateRegistry) {
+	protected void verifyCheckpointRegistered(Collection<TaskState> taskStates, SharedStateRegistry registry) {
 		for (TaskState taskState : taskStates) {
 			for (SubtaskState subtaskState : taskState.getStates()) {
-				verify(subtaskState, times(1)).registerSharedStates(eq(sharedStateRegistry));
+				verify(subtaskState, times(1)).registerSharedStates(eq(registry));
 			}
 		}
 	}
@@ -246,7 +243,6 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	protected void verifyCheckpointDiscarded(Collection<TaskState> taskStates) {
 		for (TaskState taskState : taskStates) {
 			for (SubtaskState subtaskState : taskState.getStates()) {
-				verify(subtaskState, times(1)).discardSharedStatesOnFail();
 				verify(subtaskState, times(1)).discardState();
 			}
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
index e7c1c3b..5fce62e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
@@ -31,17 +31,14 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.util.SerializedValue;
 
 import org.junit.Test;
-import org.mockito.Matchers;
 
 import java.net.URL;
 import java.util.Collections;
 
-import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -62,7 +59,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 		graph.fail(new Exception("Test Exception"));
 
 		verify(counter, times(1)).shutdown(JobStatus.FAILED);
-		verify(store, times(1)).shutdown(eq(JobStatus.FAILED), any(SharedStateRegistry.class));
+		verify(store, times(1)).shutdown(eq(JobStatus.FAILED));
 	}
 
 	/**
@@ -79,7 +76,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 
 		// No shutdown
 		verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED));
-		verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), any(SharedStateRegistry.class));
+		verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED));
 	}
 
 	private ExecutionGraph createExecutionGraphAndEnableCheckpointing(

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index d77fac1..2dd1803 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -207,7 +207,6 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
-		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort error
 		Mockito.reset(state);
@@ -219,7 +218,6 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
-		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort expired
 		Mockito.reset(state);
@@ -231,7 +229,6 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
-		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort subsumed
 		Mockito.reset(state);
@@ -243,7 +240,6 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
-		verify(state, times(1)).discardSharedStatesOnFail();
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
index 7a85897..64aeeba 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
@@ -30,6 +30,7 @@ import java.util.List;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
 import static org.powermock.api.mockito.PowerMockito.doReturn;
 import static org.powermock.api.mockito.PowerMockito.doThrow;
 import static org.powermock.api.mockito.PowerMockito.mock;
@@ -40,7 +41,7 @@ import static org.powermock.api.mockito.PowerMockito.mock;
 public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointStoreTest {
 
 	@Override
-	protected CompletedCheckpointStore createCompletedCheckpoints(
+	protected AbstractCompletedCheckpointStore createCompletedCheckpoints(
 			int maxNumberOfCheckpointsToRetain) throws Exception {
 
 		return new StandaloneCompletedCheckpointStore(maxNumberOfCheckpointsToRetain);
@@ -51,16 +52,15 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 	 */
 	@Test
 	public void testShutdownDiscardsCheckpoints() throws Exception {
-		CompletedCheckpointStore store = createCompletedCheckpoints(1);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		AbstractCompletedCheckpointStore store = createCompletedCheckpoints(1);
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 		Collection<TaskState> taskStates = checkpoint.getTaskStates().values();
 
-		store.addCheckpoint(checkpoint, sharedStateRegistry);
+		store.addCheckpoint(checkpoint);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
-		verifyCheckpointRegistered(taskStates, sharedStateRegistry);
+		verifyCheckpointRegistered(taskStates, store.sharedStateRegistry);
 
-		store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+		store.shutdown(JobStatus.FINISHED);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertTrue(checkpoint.isDiscarded());
 		verifyCheckpointDiscarded(taskStates);
@@ -72,16 +72,15 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 	 */
 	@Test
 	public void testSuspendDiscardsCheckpoints() throws Exception {
-		CompletedCheckpointStore store = createCompletedCheckpoints(1);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		AbstractCompletedCheckpointStore store = createCompletedCheckpoints(1);
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 		Collection<TaskState> taskStates = checkpoint.getTaskStates().values();
 
-		store.addCheckpoint(checkpoint, sharedStateRegistry);
+		store.addCheckpoint(checkpoint);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
-		verifyCheckpointRegistered(taskStates, sharedStateRegistry);
+		verifyCheckpointRegistered(taskStates, store.sharedStateRegistry);
 
-		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+		store.shutdown(JobStatus.SUSPENDED);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertTrue(checkpoint.isDiscarded());
 		verifyCheckpointDiscarded(taskStates);
@@ -96,16 +95,15 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 		
 		final int numCheckpointsToRetain = 1;
 		CompletedCheckpointStore store = createCompletedCheckpoints(numCheckpointsToRetain);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		
 		for (long i = 0; i <= numCheckpointsToRetain; ++i) {
 			CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class);
 			doReturn(i).when(checkpointToAdd).getCheckpointID();
 			doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
-			doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry);
+			doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(any(SharedStateRegistry.class));
 			
 			try {
-				store.addCheckpoint(checkpointToAdd, sharedStateRegistry);
+				store.addCheckpoint(checkpointToAdd);
 				
 				// The checkpoint should be in the store if we successfully add it into the store.
 				List<CompletedCheckpoint> addedCheckpoints = store.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 607e773..73fcf78 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -22,7 +22,6 @@ import org.apache.curator.framework.CuratorFramework;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.junit.AfterClass;
@@ -59,7 +58,7 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 	}
 
 	@Override
-	protected CompletedCheckpointStore createCompletedCheckpoints(
+	protected AbstractCompletedCheckpointStore createCompletedCheckpoints(
 			int maxNumberOfCheckpointsToRetain) throws Exception {
 
 		return new ZooKeeperCompletedCheckpointStore(maxNumberOfCheckpointsToRetain,
@@ -80,21 +79,20 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 	 */
 	@Test
 	public void testRecover() throws Exception {
-		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(3);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		AbstractCompletedCheckpointStore checkpoints = createCompletedCheckpoints(3);
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1), createCheckpoint(2)
 		};
 
 		// Add multiple checkpoints
-		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
-		checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
-		checkpoints.addCheckpoint(expected[2], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[0]);
+		checkpoints.addCheckpoint(expected[1]);
+		checkpoints.addCheckpoint(expected[2]);
 
-		verifyCheckpointRegistered(expected[0].getTaskStates().values(), sharedStateRegistry);
-		verifyCheckpointRegistered(expected[1].getTaskStates().values(), sharedStateRegistry);
-		verifyCheckpointRegistered(expected[2].getTaskStates().values(), sharedStateRegistry);
+		verifyCheckpointRegistered(expected[0].getTaskStates().values(), checkpoints.sharedStateRegistry);
+		verifyCheckpointRegistered(expected[1].getTaskStates().values(), checkpoints.sharedStateRegistry);
+		verifyCheckpointRegistered(expected[2].getTaskStates().values(), checkpoints.sharedStateRegistry);
 
 		// All three should be in ZK
 		assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
@@ -104,9 +102,8 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		resetCheckpoint(expected[1].getTaskStates().values());
 		resetCheckpoint(expected[2].getTaskStates().values());
 
-		// Recover
-		SharedStateRegistry newSharedStateRegistry = new SharedStateRegistry();
-		checkpoints.recover(newSharedStateRegistry);
+		// Recover. TODO: clear the shared state registry!
+		checkpoints.recover();
 
 		assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
 		assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
@@ -117,14 +114,14 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		expectedCheckpoints.add(expected[2]);
 		expectedCheckpoints.add(createCheckpoint(3));
 
-		checkpoints.addCheckpoint(expectedCheckpoints.get(2), newSharedStateRegistry);
+		checkpoints.addCheckpoint(expectedCheckpoints.get(2));
 
 		List<CompletedCheckpoint> actualCheckpoints = checkpoints.getAllCheckpoints();
 
 		assertEquals(expectedCheckpoints, actualCheckpoints);
 
 		for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) {
-			verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), newSharedStateRegistry);
+			verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), checkpoints.sharedStateRegistry);
 		}
 	}
 
@@ -136,18 +133,17 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		CuratorFramework client = ZooKeeper.getClient();
 
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-		store.addCheckpoint(checkpoint, sharedStateRegistry);
+		store.addCheckpoint(checkpoint);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
+		store.shutdown(JobStatus.FINISHED);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.recover(sharedStateRegistry);
+		store.recover();
 
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 	}
@@ -161,20 +157,19 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		CuratorFramework client = ZooKeeper.getClient();
 
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-		store.addCheckpoint(checkpoint, sharedStateRegistry);
+		store.addCheckpoint(checkpoint);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+		store.shutdown(JobStatus.SUSPENDED);
 
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
 		// Recover again
-		store.recover(sharedStateRegistry);
+		store.recover();
 
 		CompletedCheckpoint recovered = store.getLatestCheckpoint();
 		assertEquals(checkpoint, recovered);

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
index 1f5731d..66ef232 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
@@ -27,7 +27,6 @@ import org.apache.curator.utils.EnsurePath;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
-import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.TestLogger;
@@ -160,9 +159,7 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger {
 			stateSotrage,
 			Executors.directExecutor());
 
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
-
-		zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry);
+		zooKeeperCompletedCheckpointStore.recover();
 
 		CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint();
 
@@ -227,16 +224,13 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger {
 			stateSotrage,
 			Executors.directExecutor());
 
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
-		
-		
 		for (long i = 0; i <= numCheckpointsToRetain; ++i) {
 			CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class);
 			doReturn(i).when(checkpointToAdd).getCheckpointID();
 			doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
 			
 			try {
-				zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, sharedStateRegistry);
+				zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd);
 				
 				// The checkpoint should be in the store if we successfully add it into the store.
 				List<CompletedCheckpoint> addedCheckpoints = zooKeeperCompletedCheckpointStore.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
index cb14ff0..821bb69 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
@@ -34,50 +34,20 @@ public class SharedStateRegistryTest {
 
 		// register one state
 		TestSharedState firstState = new TestSharedState("first");
-		sharedStateRegistry.register(firstState, true);
-		assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
+		assertEquals(1, sharedStateRegistry.register(firstState));
 
 		// register another state
 		TestSharedState secondState = new TestSharedState("second");
-		sharedStateRegistry.register(secondState, true);
-		assertEquals(1, sharedStateRegistry.getReferenceCount(secondState));
+		assertEquals(1, sharedStateRegistry.register(secondState));
 
 		// register the first state again
-		sharedStateRegistry.register(firstState, false);
-		assertEquals(2, sharedStateRegistry.getReferenceCount(firstState));
+		assertEquals(2, sharedStateRegistry.register(firstState));
 
 		// unregister the second state
-		sharedStateRegistry.unregister(secondState);
-		assertEquals(0, sharedStateRegistry.getReferenceCount(secondState));
+		assertEquals(0, sharedStateRegistry.unregister(secondState));
 
 		// unregister the first state
-		sharedStateRegistry.unregister(firstState);
-		assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
-	}
-
-	/**
-	 * Validate that registering a handle referencing uncreated state will throw exception
-	 */
-	@Test(expected = IllegalStateException.class)
-	public void testRegisterWithUncreatedReference() {
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
-
-		// register one state
-		TestSharedState state = new TestSharedState("state");
-		sharedStateRegistry.register(state, false);
-	}
-
-	/**
-	 * Validate that registering duplicate creation of the same state will throw exception
-	 */
-	@Test(expected = IllegalStateException.class)
-	public void testRegisterWithDuplicateState() {
-		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
-
-		// register one state
-		TestSharedState state = new TestSharedState("state");
-		sharedStateRegistry.register(state, true);
-		sharedStateRegistry.register(state, true);
+		assertEquals(1, sharedStateRegistry.unregister(firstState));
 	}
 
 	/**
@@ -100,7 +70,7 @@ public class SharedStateRegistryTest {
 		}
 
 		@Override
-		public String getKey() {
+		public String getRegistrationKey() {
 			return key;
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/aa21f853/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
index 75b0f6f..a932c18 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
@@ -18,12 +18,9 @@
 
 package org.apache.flink.runtime.testutils;
 
+import org.apache.flink.runtime.checkpoint.AbstractCompletedCheckpointStore;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
-import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.jobgraph.JobStatus;
-import org.apache.flink.runtime.state.SharedStateRegistry;
-import org.apache.flink.runtime.state.StateObject;
-import org.apache.flink.runtime.state.StateUtil;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -35,7 +32,7 @@ import java.util.List;
  * A checkpoint store, which supports shutdown and suspend. You can use this to test HA
  * as long as the factory always returns the same store instance.
  */
-public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore {
+public class RecoverableCompletedCheckpointStore extends AbstractCompletedCheckpointStore {
 
 	private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class);
 
@@ -44,7 +41,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 	private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
 
 	@Override
-	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void recover() throws Exception {
 		checkpoints.addAll(suspended);
 		suspended.clear();
 
@@ -54,7 +51,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 	}
 
 	@Override
-	public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
 		checkpoints.addLast(checkpoint);
 
 		checkpoint.registerSharedStates(sharedStateRegistry);
@@ -71,7 +68,7 @@ public class RecoverableCompletedCheckpointStore implements CompletedCheckpointS
 	}
 
 	@Override
-	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+	public void shutdown(JobStatus jobStatus) throws Exception {
 		if (jobStatus.isGloballyTerminalState()) {
 			checkpoints.clear();
 			suspended.clear();


[2/3] flink git commit: [FLINK-6014] [checkpoint] Allow the registration of state objects in checkpoints

Posted by sr...@apache.org.
http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index f77c755..aa1726b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -21,13 +21,14 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.messages.CheckpointMessagesTest;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
+import org.mockito.Mockito;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -36,6 +37,9 @@ import java.util.concurrent.CountDownLatch;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 /**
  * Test for basic {@link CompletedCheckpointStore} contract.
@@ -64,7 +68,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testAddAndGetLatestCheckpoint() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
-
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		
 		// Empty state
 		assertEquals(0, checkpoints.getNumberOfRetainedCheckpoints());
 		assertEquals(0, checkpoints.getAllCheckpoints().size());
@@ -73,11 +78,11 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 				createCheckpoint(0), createCheckpoint(1) };
 
 		// Add and get latest
-		checkpoints.addCheckpoint(expected[0]);
+		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
 		assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 		verifyCheckpoint(expected[0], checkpoints.getLatestCheckpoint());
 
-		checkpoints.addCheckpoint(expected[1]);
+		checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
 		assertEquals(2, checkpoints.getNumberOfRetainedCheckpoints());
 		verifyCheckpoint(expected[1], checkpoints.getLatestCheckpoint());
 	}
@@ -88,7 +93,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	 */
 	@Test
 	public void testAddCheckpointMoreThanMaxRetained() throws Exception {
-		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1);
+		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(1);   
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -96,16 +102,24 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		// Add checkpoints
-		checkpoints.addCheckpoint(expected[0]);
+		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
 		assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
 
 		for (int i = 1; i < expected.length; i++) {
-			checkpoints.addCheckpoint(expected[i]);
+			Collection<TaskState> taskStates = expected[i - 1].getTaskStates().values();
+
+			checkpoints.addCheckpoint(expected[i], sharedStateRegistry);
 
 			// The ZooKeeper implementation discards asynchronously
 			expected[i - 1].awaitDiscard();
 			assertTrue(expected[i - 1].isDiscarded());
 			assertEquals(1, checkpoints.getNumberOfRetainedCheckpoints());
+
+			for (TaskState taskState : taskStates) {
+				for (SubtaskState subtaskState : taskState.getStates()) {
+					verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry);
+				}
+			}
 		}
 	}
 
@@ -132,6 +146,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testGetAllCheckpoints() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -139,7 +154,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		for (TestCompletedCheckpoint checkpoint : expected) {
-			checkpoints.addCheckpoint(checkpoint);
+			checkpoints.addCheckpoint(checkpoint, sharedStateRegistry);
 		}
 
 		List<CompletedCheckpoint> actual = checkpoints.getAllCheckpoints();
@@ -157,6 +172,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 	@Test
 	public void testDiscardAllCheckpoints() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(4);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1),
@@ -164,10 +180,10 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		};
 
 		for (TestCompletedCheckpoint checkpoint : expected) {
-			checkpoints.addCheckpoint(checkpoint);
+			checkpoints.addCheckpoint(checkpoint, sharedStateRegistry);
 		}
 
-		checkpoints.shutdown(JobStatus.FINISHED);
+		checkpoints.shutdown(JobStatus.FINISHED, sharedStateRegistry);
 
 		// Empty state
 		assertNull(checkpoints.getLatestCheckpoint());
@@ -203,15 +219,39 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		taskGroupStates.put(jvid, taskState);
 
 		for (int i = 0; i < numberOfStates; i++) {
-			ChainedStateHandle<StreamStateHandle> stateHandle = CheckpointCoordinatorTest.generateChainedStateHandle(
-					new CheckpointMessagesTest.MyHandle());
+			SubtaskState subtaskState = CheckpointCoordinatorTest.mockSubtaskState(jvid, i, new KeyGroupRange(i, i));
 
-			taskState.putState(i, new SubtaskState(stateHandle, null, null, null, null));
+			taskState.putState(i, subtaskState);
 		}
 
 		return new TestCompletedCheckpoint(new JobID(), id, 0, taskGroupStates, props);
 	}
 
+	protected void resetCheckpoint(Collection<TaskState> taskStates) {
+		for (TaskState taskState : taskStates) {
+			for (SubtaskState subtaskState : taskState.getStates()) {
+				Mockito.reset(subtaskState);
+			}
+		}
+	}
+
+	protected void verifyCheckpointRegistered(Collection<TaskState> taskStates, SharedStateRegistry sharedStateRegistry) {
+		for (TaskState taskState : taskStates) {
+			for (SubtaskState subtaskState : taskState.getStates()) {
+				verify(subtaskState, times(1)).registerSharedStates(eq(sharedStateRegistry));
+			}
+		}
+	}
+
+	protected void verifyCheckpointDiscarded(Collection<TaskState> taskStates) {
+		for (TaskState taskState : taskStates) {
+			for (SubtaskState subtaskState : taskState.getStates()) {
+				verify(subtaskState, times(1)).discardSharedStatesOnFail();
+				verify(subtaskState, times(1)).discardState();
+			}
+		}
+	}
+
 	private void verifyCheckpoint(CompletedCheckpoint expected, CompletedCheckpoint actual) {
 		assertEquals(expected, actual);
 	}
@@ -241,8 +281,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		}
 
 		@Override
-		public boolean subsume() throws Exception {
-			if (super.subsume()) {
+		public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception {
+			if (super.discardOnSubsume(sharedStateRegistry)) {
 				discard();
 				return true;
 			} else {
@@ -251,8 +291,8 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		}
 
 		@Override
-		public boolean discard(JobStatus jobStatus) throws Exception {
-			if (super.discard(jobStatus)) {
+		public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+			if (super.discardOnShutdown(jobStatus, sharedStateRegistry)) {
 				discard();
 				return true;
 			} else {

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
index b34e9a6..0b759d4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointTest.java
@@ -23,6 +23,8 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.junit.Rule;
 import org.junit.Test;
@@ -30,10 +32,12 @@ import org.junit.rules.TemporaryFolder;
 import org.mockito.Mockito;
 
 import java.io.File;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
 import static org.junit.Assert.assertEquals;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -61,7 +65,7 @@ public class CompletedCheckpointTest {
 				new FileStateHandle(new Path(file.toURI()), file.length()),
 				file.getAbsolutePath());
 
-		checkpoint.discard(JobStatus.FAILED);
+		checkpoint.discardOnShutdown(JobStatus.FAILED, new SharedStateRegistry());
 
 		assertEquals(false, file.exists());
 	}
@@ -80,10 +84,15 @@ public class CompletedCheckpointTest {
 		CompletedCheckpoint checkpoint = new CompletedCheckpoint(
 				new JobID(), 0, 0, 1, taskStates, props);
 
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+		checkpoint.registerSharedStates(sharedStateRegistry);
+		verify(state, times(1)).registerSharedStates(sharedStateRegistry);
+
 		// Subsume
-		checkpoint.subsume();
+		checkpoint.discardOnSubsume(sharedStateRegistry);
 
 		verify(state, times(1)).discardState();
+		verify(state, times(1)).unregisterSharedStates(sharedStateRegistry);
 	}
 
 	/**
@@ -112,17 +121,22 @@ public class CompletedCheckpointTest {
 					new FileStateHandle(new Path(file.toURI()), file.length()),
 					externalPath);
 
-			checkpoint.discard(status);
+			SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+			checkpoint.registerSharedStates(sharedStateRegistry);
+
+			checkpoint.discardOnShutdown(status, sharedStateRegistry);
 			verify(state, times(0)).discardState();
 			assertEquals(true, file.exists());
+			verify(state, times(0)).unregisterSharedStates(sharedStateRegistry);
 
 			// Discard
 			props = new CheckpointProperties(false, false, true, true, true, true, true);
 			checkpoint = new CompletedCheckpoint(
 					new JobID(), 0, 0, 1, new HashMap<>(taskStates), props);
 
-			checkpoint.discard(status);
+			checkpoint.discardOnShutdown(status, sharedStateRegistry);
 			verify(state, times(1)).discardState();
+			verify(state, times(1)).unregisterSharedStates(sharedStateRegistry);
 		}
 	}
 
@@ -146,7 +160,7 @@ public class CompletedCheckpointTest {
 		CompletedCheckpointStats.DiscardCallback callback = mock(CompletedCheckpointStats.DiscardCallback.class);
 		completed.setDiscardCallback(callback);
 
-		completed.discard(JobStatus.FINISHED);
+		completed.discardOnShutdown(JobStatus.FINISHED, new SharedStateRegistry());
 		verify(callback, times(1)).notifyDiscardedCheckpoint();
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
index 0ab031e..e7c1c3b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ExecutionGraphCheckpointCoordinatorTest.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobmanager.scheduler.Scheduler;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.util.SerializedValue;
 
@@ -40,6 +41,8 @@ import org.mockito.Matchers;
 import java.net.URL;
 import java.util.Collections;
 
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -59,7 +62,7 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 		graph.fail(new Exception("Test Exception"));
 
 		verify(counter, times(1)).shutdown(JobStatus.FAILED);
-		verify(store, times(1)).shutdown(JobStatus.FAILED);
+		verify(store, times(1)).shutdown(eq(JobStatus.FAILED), any(SharedStateRegistry.class));
 	}
 
 	/**
@@ -75,8 +78,8 @@ public class ExecutionGraphCheckpointCoordinatorTest {
 		graph.suspend(new Exception("Test Exception"));
 
 		// No shutdown
-		verify(counter, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED));
-		verify(store, times(1)).shutdown(Matchers.eq(JobStatus.SUSPENDED));
+		verify(counter, times(1)).shutdown(eq(JobStatus.SUSPENDED));
+		verify(store, times(1)).shutdown(eq(JobStatus.SUSPENDED), any(SharedStateRegistry.class));
 	}
 
 	private ExecutionGraph createExecutionGraphAndEnableCheckpointing(

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index a15684c..d77fac1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -24,15 +24,19 @@ import org.apache.flink.runtime.concurrent.Future;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
+import org.mockito.Mock;
 import org.mockito.Mockito;
 
 import java.io.File;
 import java.lang.reflect.Field;
 import java.util.ArrayDeque;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Queue;
@@ -45,7 +49,10 @@ import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.powermock.api.mockito.PowerMockito.when;
@@ -184,9 +191,12 @@ public class PendingCheckpointTest {
 	@SuppressWarnings("unchecked")
 	public void testAbortDiscardsState() throws Exception {
 		CheckpointProperties props = new CheckpointProperties(false, true, false, false, false, false, false);
-		TaskState state = mock(TaskState.class);
 		QueueExecutor executor = new QueueExecutor();
 
+		TaskState state = mock(TaskState.class);
+		doNothing().when(state).registerSharedStates(any(SharedStateRegistry.class));
+		doNothing().when(state).unregisterSharedStates(any(SharedStateRegistry.class));
+
 		String targetDir = tmpFolder.newFolder().getAbsolutePath();
 
 		// Abort declined
@@ -197,6 +207,7 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
+		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort error
 		Mockito.reset(state);
@@ -208,6 +219,7 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
+		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort expired
 		Mockito.reset(state);
@@ -219,6 +231,7 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
+		verify(state, times(1)).discardSharedStatesOnFail();
 
 		// Abort subsumed
 		Mockito.reset(state);
@@ -230,6 +243,7 @@ public class PendingCheckpointTest {
 		// execute asynchronous discard operation
 		executor.runQueuedCommands();
 		verify(state, times(1)).discardState();
+		verify(state, times(1)).discardSharedStatesOnFail();
 	}
 
 	/**
@@ -340,7 +354,11 @@ public class PendingCheckpointTest {
 		return createPendingCheckpoint(props, targetDirectory, Executors.directExecutor());
 	}
 
-	private static PendingCheckpoint createPendingCheckpoint(CheckpointProperties props, String targetDirectory, Executor executor) {
+	private static PendingCheckpoint createPendingCheckpoint(
+			CheckpointProperties props,
+			String targetDirectory,
+			Executor executor) {
+
 		Map<ExecutionAttemptID, ExecutionVertex> ackTasks = new HashMap<>(ACK_TASKS);
 		return new PendingCheckpoint(
 			new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
index cc7b2d0..7a85897 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStoreTest.java
@@ -19,9 +19,12 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
@@ -49,15 +52,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 	@Test
 	public void testShutdownDiscardsCheckpoints() throws Exception {
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
+		Collection<TaskState> taskStates = checkpoint.getTaskStates().values();
 
-		store.addCheckpoint(checkpoint);
+		store.addCheckpoint(checkpoint, sharedStateRegistry);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
+		verifyCheckpointRegistered(taskStates, sharedStateRegistry);
 
-		store.shutdown(JobStatus.FINISHED);
-
+		store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertTrue(checkpoint.isDiscarded());
+		verifyCheckpointDiscarded(taskStates);
 	}
 
 	/**
@@ -67,15 +73,18 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 	@Test
 	public void testSuspendDiscardsCheckpoints() throws Exception {
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
+		Collection<TaskState> taskStates = checkpoint.getTaskStates().values();
 
-		store.addCheckpoint(checkpoint);
+		store.addCheckpoint(checkpoint, sharedStateRegistry);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
+		verifyCheckpointRegistered(taskStates, sharedStateRegistry);
 
-		store.shutdown(JobStatus.SUSPENDED);
-
+		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertTrue(checkpoint.isDiscarded());
+		verifyCheckpointDiscarded(taskStates);
 	}
 	
 	/**
@@ -87,14 +96,16 @@ public class StandaloneCompletedCheckpointStoreTest extends CompletedCheckpointS
 		
 		final int numCheckpointsToRetain = 1;
 		CompletedCheckpointStore store = createCompletedCheckpoints(numCheckpointsToRetain);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		
 		for (long i = 0; i <= numCheckpointsToRetain; ++i) {
 			CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class);
 			doReturn(i).when(checkpointToAdd).getCheckpointID();
-			doThrow(new IOException()).when(checkpointToAdd).subsume();
+			doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
+			doThrow(new IOException()).when(checkpointToAdd).discardOnSubsume(sharedStateRegistry);
 			
 			try {
-				store.addCheckpoint(checkpointToAdd);
+				store.addCheckpoint(checkpointToAdd, sharedStateRegistry);
 				
 				// The checkpoint should be in the store if we successfully add it into the store.
 				List<CompletedCheckpoint> addedCheckpoints = store.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
index 625999a..607e773 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreITCase.java
@@ -22,6 +22,7 @@ import org.apache.curator.framework.CuratorFramework;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.junit.AfterClass;
@@ -80,22 +81,32 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 	@Test
 	public void testRecover() throws Exception {
 		CompletedCheckpointStore checkpoints = createCompletedCheckpoints(3);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 
 		TestCompletedCheckpoint[] expected = new TestCompletedCheckpoint[] {
 				createCheckpoint(0), createCheckpoint(1), createCheckpoint(2)
 		};
 
 		// Add multiple checkpoints
-		checkpoints.addCheckpoint(expected[0]);
-		checkpoints.addCheckpoint(expected[1]);
-		checkpoints.addCheckpoint(expected[2]);
+		checkpoints.addCheckpoint(expected[0], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[1], sharedStateRegistry);
+		checkpoints.addCheckpoint(expected[2], sharedStateRegistry);
+
+		verifyCheckpointRegistered(expected[0].getTaskStates().values(), sharedStateRegistry);
+		verifyCheckpointRegistered(expected[1].getTaskStates().values(), sharedStateRegistry);
+		verifyCheckpointRegistered(expected[2].getTaskStates().values(), sharedStateRegistry);
 
 		// All three should be in ZK
 		assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
 		assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
 
+		resetCheckpoint(expected[0].getTaskStates().values());
+		resetCheckpoint(expected[1].getTaskStates().values());
+		resetCheckpoint(expected[2].getTaskStates().values());
+
 		// Recover
-		checkpoints.recover();
+		SharedStateRegistry newSharedStateRegistry = new SharedStateRegistry();
+		checkpoints.recover(newSharedStateRegistry);
 
 		assertEquals(3, ZooKeeper.getClient().getChildren().forPath(CheckpointsPath).size());
 		assertEquals(3, checkpoints.getNumberOfRetainedCheckpoints());
@@ -106,11 +117,15 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		expectedCheckpoints.add(expected[2]);
 		expectedCheckpoints.add(createCheckpoint(3));
 
-		checkpoints.addCheckpoint(expectedCheckpoints.get(2));
+		checkpoints.addCheckpoint(expectedCheckpoints.get(2), newSharedStateRegistry);
 
 		List<CompletedCheckpoint> actualCheckpoints = checkpoints.getAllCheckpoints();
 
 		assertEquals(expectedCheckpoints, actualCheckpoints);
+
+		for (CompletedCheckpoint actualCheckpoint : actualCheckpoints) {
+			verifyCheckpointRegistered(actualCheckpoint.getTaskStates().values(), newSharedStateRegistry);
+		}
 	}
 
 	/**
@@ -121,18 +136,18 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		CuratorFramework client = ZooKeeper.getClient();
 
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-		store.addCheckpoint(checkpoint);
+		store.addCheckpoint(checkpoint, sharedStateRegistry);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.shutdown(JobStatus.FINISHED);
-
+		store.shutdown(JobStatus.FINISHED, sharedStateRegistry);
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.recover();
+		store.recover(sharedStateRegistry);
 
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 	}
@@ -146,19 +161,20 @@ public class ZooKeeperCompletedCheckpointStoreITCase extends CompletedCheckpoint
 		CuratorFramework client = ZooKeeper.getClient();
 
 		CompletedCheckpointStore store = createCompletedCheckpoints(1);
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		TestCompletedCheckpoint checkpoint = createCheckpoint(0);
 
-		store.addCheckpoint(checkpoint);
+		store.addCheckpoint(checkpoint, sharedStateRegistry);
 		assertEquals(1, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
-		store.shutdown(JobStatus.SUSPENDED);
+		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
 
 		assertEquals(0, store.getNumberOfRetainedCheckpoints());
 		assertNotNull(client.checkExists().forPath(CheckpointsPath + "/" + checkpoint.getCheckpointID()));
 
 		// Recover again
-		store.recover();
+		store.recover(sharedStateRegistry);
 
 		CompletedCheckpoint recovered = store.getLatestCheckpoint();
 		assertEquals(checkpoint, recovered);

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
index aa2ec85..1f5731d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStoreTest.java
@@ -27,6 +27,7 @@ import org.apache.curator.utils.EnsurePath;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.TestLogger;
@@ -40,6 +41,7 @@ import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.concurrent.Executor;
@@ -158,7 +160,9 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger {
 			stateSotrage,
 			Executors.directExecutor());
 
-		zooKeeperCompletedCheckpointStore.recover();
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		zooKeeperCompletedCheckpointStore.recover(sharedStateRegistry);
 
 		CompletedCheckpoint latestCompletedCheckpoint = zooKeeperCompletedCheckpointStore.getLatestCheckpoint();
 
@@ -222,14 +226,17 @@ public class ZooKeeperCompletedCheckpointStoreTest extends TestLogger {
 			checkpointsPath,
 			stateSotrage,
 			Executors.directExecutor());
+
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
 		
 		
 		for (long i = 0; i <= numCheckpointsToRetain; ++i) {
 			CompletedCheckpoint checkpointToAdd = mock(CompletedCheckpoint.class);
 			doReturn(i).when(checkpointToAdd).getCheckpointID();
+			doReturn(Collections.emptyMap()).when(checkpointToAdd).getTaskStates();
 			
 			try {
-				zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd);
+				zooKeeperCompletedCheckpointStore.addCheckpoint(checkpointToAdd, sharedStateRegistry);
 				
 				// The checkpoint should be in the store if we successfully add it into the store.
 				List<CompletedCheckpoint> addedCheckpoints = zooKeeperCompletedCheckpointStore.getAllCheckpoints();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 6eacaac..77eb566 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -71,6 +71,7 @@ import org.apache.flink.runtime.leaderretrieval.LeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.metrics.MetricRegistry;
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
@@ -81,6 +82,7 @@ import org.apache.flink.runtime.testingUtils.TestingMessages;
 import org.apache.flink.runtime.testingUtils.TestingTaskManager;
 import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
+import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 
@@ -164,7 +166,7 @@ public class JobManagerHARecoveryTest {
 			Scheduler scheduler = new Scheduler(TestingUtils.defaultExecutionContext());
 
 			MySubmittedJobGraphStore mySubmittedJobGraphStore = new MySubmittedJobGraphStore();
-			MyCheckpointStore checkpointStore = new MyCheckpointStore();
+			CompletedCheckpointStore checkpointStore = new RecoverableCompletedCheckpointStore();
 			CheckpointIDCounter checkpointCounter = new StandaloneCheckpointIDCounter();
 			CheckpointRecoveryFactory checkpointStateFactory = new MyCheckpointRecoveryFactory(checkpointStore, checkpointCounter);
 			TestingLeaderElectionService myLeaderElectionService = new TestingLeaderElectionService();
@@ -438,67 +440,6 @@ public class JobManagerHARecoveryTest {
 		}
 	}
 
-	/**
-	 * A checkpoint store, which supports shutdown and suspend. You can use this to test HA
-	 * as long as the factory always returns the same store instance.
-	 */
-	static class MyCheckpointStore implements CompletedCheckpointStore {
-
-		private final ArrayDeque<CompletedCheckpoint> checkpoints = new ArrayDeque<>(2);
-
-		private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
-
-		@Override
-		public void recover() throws Exception {
-			checkpoints.addAll(suspended);
-			suspended.clear();
-		}
-
-		@Override
-		public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
-			checkpoints.addLast(checkpoint);
-			if (checkpoints.size() > 1) {
-				checkpoints.removeFirst().subsume();
-			}
-		}
-
-		@Override
-		public CompletedCheckpoint getLatestCheckpoint() throws Exception {
-			return checkpoints.isEmpty() ? null : checkpoints.getLast();
-		}
-
-		@Override
-		public void shutdown(JobStatus jobStatus) throws Exception {
-			if (jobStatus.isGloballyTerminalState()) {
-				checkpoints.clear();
-				suspended.clear();
-			} else {
-				suspended.addAll(checkpoints);
-				checkpoints.clear();
-			}
-		}
-
-		@Override
-		public List<CompletedCheckpoint> getAllCheckpoints() throws Exception {
-			return new ArrayList<>(checkpoints);
-		}
-
-		@Override
-		public int getNumberOfRetainedCheckpoints() {
-			return checkpoints.size();
-		}
-
-		@Override
-		public int getMaxNumberOfRetainedCheckpoints() {
-			return 1;
-		}
-
-		@Override
-		public boolean requiresExternalizedCheckpoints() {
-			return false;
-		}
-	}
-
 	static class MyCheckpointRecoveryFactory implements CheckpointRecoveryFactory {
 
 		private final CompletedCheckpointStore store;

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
new file mode 100644
index 0000000..cb14ff0
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/SharedStateRegistryTest.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.runtime.state;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class SharedStateRegistryTest {
+
+	/**
+	 * Validate that all states can be correctly registered at the registry.
+	 */
+	@Test
+	public void testRegistryNormal() {
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		// register one state
+		TestSharedState firstState = new TestSharedState("first");
+		sharedStateRegistry.register(firstState, true);
+		assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
+
+		// register another state
+		TestSharedState secondState = new TestSharedState("second");
+		sharedStateRegistry.register(secondState, true);
+		assertEquals(1, sharedStateRegistry.getReferenceCount(secondState));
+
+		// register the first state again
+		sharedStateRegistry.register(firstState, false);
+		assertEquals(2, sharedStateRegistry.getReferenceCount(firstState));
+
+		// unregister the second state
+		sharedStateRegistry.unregister(secondState);
+		assertEquals(0, sharedStateRegistry.getReferenceCount(secondState));
+
+		// unregister the first state
+		sharedStateRegistry.unregister(firstState);
+		assertEquals(1, sharedStateRegistry.getReferenceCount(firstState));
+	}
+
+	/**
+	 * Validate that registering a reference to state that was never created throws an exception.
+	 */
+	@Test(expected = IllegalStateException.class)
+	public void testRegisterWithUncreatedReference() {
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		// register one state
+		TestSharedState state = new TestSharedState("state");
+		sharedStateRegistry.register(state, false);
+	}
+
+	/**
+	 * Validate that registering the creation of the same state twice throws an exception.
+	 */
+	@Test(expected = IllegalStateException.class)
+	public void testRegisterWithDuplicateState() {
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		// register one state
+		TestSharedState state = new TestSharedState("state");
+		sharedStateRegistry.register(state, true);
+		sharedStateRegistry.register(state, true);
+	}
+
+	/**
+	 * Validate that unregistering a non-existent key throws an exception.
+	 */
+	@Test(expected = IllegalStateException.class)
+	public void testUnregisterWithUnexistedKey() {
+		SharedStateRegistry sharedStateRegistry = new SharedStateRegistry();
+
+		sharedStateRegistry.unregister(new TestSharedState("unexisted"));
+	}
+
+	private static class TestSharedState implements SharedStateHandle {
+		private static final long serialVersionUID = 4468635881465159780L;
+
+		private String key;
+
+		TestSharedState(String key) {
+			this.key = key;
+		}
+
+		@Override
+		public String getKey() {
+			return key;
+		}
+
+		@Override
+		public void discardState() throws Exception {
+			// nothing to do
+		}
+
+		@Override
+		public long getStateSize() {
+			return key.length();
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (this == o) {
+				return true;
+			}
+			if (o == null || getClass() != o.getClass()) {
+				return false;
+			}
+
+			TestSharedState testState = (TestSharedState) o;
+
+			return key.equals(testState.key);
+		}
+
+		@Override
+		public int hashCode() {
+			return key.hashCode();
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
new file mode 100644
index 0000000..75b0f6f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/RecoverableCompletedCheckpointStore.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.testutils;
+
+import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
+import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A checkpoint store, which supports shutdown and suspend. You can use this to test HA
+ * as long as the factory always returns the same store instance.
+ */
+public class RecoverableCompletedCheckpointStore implements CompletedCheckpointStore {
+
+	private static final Logger LOG = LoggerFactory.getLogger(RecoverableCompletedCheckpointStore.class);
+
+	private final ArrayDeque<CompletedCheckpoint> checkpoints = new ArrayDeque<>(2);
+
+	private final ArrayDeque<CompletedCheckpoint> suspended = new ArrayDeque<>(2);
+
+	@Override
+	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
+		checkpoints.addAll(suspended);
+		suspended.clear();
+
+		for (CompletedCheckpoint checkpoint : checkpoints) {
+			checkpoint.registerSharedStates(sharedStateRegistry);
+		}
+	}
+
+	@Override
+	public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+		checkpoints.addLast(checkpoint);
+
+		checkpoint.registerSharedStates(sharedStateRegistry);
+
+		if (checkpoints.size() > 1) {
+			CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst();
+			checkpointToSubsume.discardOnSubsume(sharedStateRegistry);
+		}
+	}
+
+	@Override
+	public CompletedCheckpoint getLatestCheckpoint() throws Exception {
+		return checkpoints.isEmpty() ? null : checkpoints.getLast();
+	}
+
+	@Override
+	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+		if (jobStatus.isGloballyTerminalState()) {
+			checkpoints.clear();
+			suspended.clear();
+		} else {
+			suspended.clear();
+
+			for (CompletedCheckpoint checkpoint : checkpoints) {
+				sharedStateRegistry.unregisterAll(checkpoint.getTaskStates().values());
+				suspended.add(checkpoint);
+			}
+
+			checkpoints.clear();
+		}
+	}
+
+	@Override
+	public List<CompletedCheckpoint> getAllCheckpoints() throws Exception {
+		return new ArrayList<>(checkpoints);
+	}
+
+	@Override
+	public int getNumberOfRetainedCheckpoints() {
+		return checkpoints.size();
+	}
+
+	@Override
+	public int getMaxNumberOfRetainedCheckpoints() {
+		return 1;
+	}
+
+	@Override
+	public boolean requiresExternalizedCheckpoints() {
+		return false;
+	}
+}


[3/3] flink git commit: [FLINK-6014] [checkpoint] Allow the registration of state objects in checkpoints

Posted by sr...@apache.org.
[FLINK-6014] [checkpoint] Allow the registration of state objects in checkpoints


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/218bed8b
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/218bed8b
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/218bed8b

Branch: refs/heads/master
Commit: 218bed8b8e49b0e4c61c61f696a8f010eafea1b7
Parents: db31ca3
Author: xiaogang.sxg <xi...@alibaba-inc.com>
Authored: Mon Mar 13 19:23:47 2017 +0800
Committer: Stefan Richter <s....@data-artisans.com>
Committed: Sat Apr 22 15:25:56 2017 +0200

----------------------------------------------------------------------
 .../checkpoint/CheckpointCoordinator.java       | 136 ++++++++----
 .../runtime/checkpoint/CompletedCheckpoint.java |  73 +++++--
 .../checkpoint/CompletedCheckpointStore.java    |   9 +-
 .../runtime/checkpoint/PendingCheckpoint.java   |  16 +-
 .../StandaloneCompletedCheckpointStore.java     |  18 +-
 .../flink/runtime/checkpoint/SubtaskState.java  |  49 +++--
 .../flink/runtime/checkpoint/TaskState.java     |  30 ++-
 .../ZooKeeperCompletedCheckpointStore.java      |  67 ++++--
 .../runtime/state/CompositeStateHandle.java     |  60 ++++++
 .../flink/runtime/state/SharedStateHandle.java  |  39 ++++
 .../runtime/state/SharedStateRegistry.java      | 178 ++++++++++++++++
 .../flink/runtime/jobmanager/JobManager.scala   |  27 +--
 .../CheckpointCoordinatorFailureTest.java       |  26 ++-
 .../checkpoint/CheckpointCoordinatorTest.java   | 211 ++++++++++++++-----
 .../checkpoint/CheckpointStateRestoreTest.java  |   4 +-
 .../CompletedCheckpointStoreTest.java           |  78 +++++--
 .../checkpoint/CompletedCheckpointTest.java     |  24 ++-
 ...ExecutionGraphCheckpointCoordinatorTest.java |   9 +-
 .../checkpoint/PendingCheckpointTest.java       |  22 +-
 .../StandaloneCompletedCheckpointStoreTest.java |  27 ++-
 ...ZooKeeperCompletedCheckpointStoreITCase.java |  40 ++--
 .../ZooKeeperCompletedCheckpointStoreTest.java  |  11 +-
 .../jobmanager/JobManagerHARecoveryTest.java    |  65 +-----
 .../runtime/state/SharedStateRegistryTest.java  | 136 ++++++++++++
 .../RecoverableCompletedCheckpointStore.java    | 109 ++++++++++
 25 files changed, 1176 insertions(+), 288 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 7087540..5309dd4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.runtime.checkpoint.savepoint.SavepointLoader;
 import org.apache.flink.runtime.checkpoint.savepoint.SavepointStore;
 import org.apache.flink.runtime.concurrent.ApplyFunction;
 import org.apache.flink.runtime.concurrent.Future;
@@ -36,10 +37,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.TaskStateHandles;
 
 import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -108,6 +110,9 @@ public class CheckpointCoordinator {
 	/** Completed checkpoints. Implementations can be blocking. Make sure calls to methods
 	 * accessing this don't block the job manager actor and run asynchronously. */
 	private final CompletedCheckpointStore completedCheckpointStore;
+	
+	/** Registry for shared states */
+	private final SharedStateRegistry sharedStateRegistry;
 
 	/** Default directory for persistent checkpoints; <code>null</code> if none configured.
 	 * THIS WILL BE REPLACED BY PROPER STATE-BACKEND METADATA WRITING */
@@ -218,6 +223,7 @@ public class CheckpointCoordinator {
 		this.completedCheckpointStore = checkNotNull(completedCheckpointStore);
 		this.checkpointDirectory = checkpointDirectory;
 		this.executor = checkNotNull(executor);
+		this.sharedStateRegistry = new SharedStateRegistry();
 
 		this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 
@@ -282,7 +288,7 @@ public class CheckpointCoordinator {
 				}
 				pendingCheckpoints.clear();
 
-				completedCheckpointStore.shutdown(jobStatus);
+				completedCheckpointStore.shutdown(jobStatus, sharedStateRegistry);
 				checkpointIdCounter.shutdown(jobStatus);
 			}
 		}
@@ -615,7 +621,7 @@ public class CheckpointCoordinator {
 			throw new IllegalArgumentException("Received DeclineCheckpoint message for job " +
 				message.getJob() + " while this coordinator handles job " + job);
 		}
-
+		
 		final long checkpointId = message.getCheckpointId();
 		final String reason = (message.getReason() != null ? message.getReason().getMessage() : "");
 
@@ -695,7 +701,7 @@ public class CheckpointCoordinator {
 		}
 
 		final long checkpointId = message.getCheckpointId();
-
+		
 		synchronized (lock) {
 			// we need to check inside the lock for being shutdown as well, otherwise we
 			// get races and invalid error log messages
@@ -778,49 +784,55 @@ public class CheckpointCoordinator {
 	 */
 	private void completePendingCheckpoint(PendingCheckpoint pendingCheckpoint) throws CheckpointException {
 		final long checkpointId = pendingCheckpoint.getCheckpointId();
-		CompletedCheckpoint completedCheckpoint = null;
+		final CompletedCheckpoint completedCheckpoint;
 
 		try {
-			// externalize the checkpoint if required
-			if (pendingCheckpoint.getProps().externalizeCheckpoint()) {
-				completedCheckpoint = pendingCheckpoint.finalizeCheckpointExternalized();
-			} else {
-				completedCheckpoint = pendingCheckpoint.finalizeCheckpointNonExternalized();
-			}
-
-			completedCheckpointStore.addCheckpoint(completedCheckpoint);
-
-			rememberRecentCheckpointId(checkpointId);
-			dropSubsumedCheckpoints(checkpointId);
-		}
-		catch (Exception exception) {
-			// abort the current pending checkpoint if it has not been discarded yet
-			if (!pendingCheckpoint.isDiscarded()) {
-				pendingCheckpoint.abortError(exception);
+			try {
+				// externalize the checkpoint if required
+				if (pendingCheckpoint.getProps().externalizeCheckpoint()) {
+					completedCheckpoint = pendingCheckpoint.finalizeCheckpointExternalized();
+				} else {
+					completedCheckpoint = pendingCheckpoint.finalizeCheckpointNonExternalized();
+				}
+			} catch (Exception e1) {
+				// abort the current pending checkpoint if we fail to finalize the pending checkpoint.
+				if (!pendingCheckpoint.isDiscarded()) {
+					pendingCheckpoint.abortError(e1);
+				}
+	
+				throw new CheckpointException("Could not finalize the pending checkpoint " + checkpointId + '.', e1);
 			}
-
-			if (completedCheckpoint != null) {
+	
+			// the pending checkpoint must be discarded after the finalization
+			Preconditions.checkState(pendingCheckpoint.isDiscarded() && completedCheckpoint != null);
+	
+			try {
+				completedCheckpointStore.addCheckpoint(completedCheckpoint, sharedStateRegistry);
+			} catch (Exception exception) {
 				// we failed to store the completed checkpoint. Let's clean up
-				final CompletedCheckpoint cc = completedCheckpoint;
-
 				executor.execute(new Runnable() {
 					@Override
 					public void run() {
 						try {
-							cc.discard();
+							completedCheckpoint.discardOnFail();
 						} catch (Throwable t) {
-							LOG.warn("Could not properly discard completed checkpoint {}.", cc.getCheckpointID(), t);
+							LOG.warn("Could not properly discard completed checkpoint {}.", completedCheckpoint.getCheckpointID(), t);
 						}
 					}
 				});
+				
+				throw new CheckpointException("Could not complete the pending checkpoint " + checkpointId + '.', exception);
 			}
-
-			throw new CheckpointException("Could not complete the pending checkpoint " + checkpointId + '.', exception);
 		} finally {
 			pendingCheckpoints.remove(checkpointId);
 
 			triggerQueuedRequests();
 		}
+		
+		rememberRecentCheckpointId(checkpointId);
+		
+		// drop those pending checkpoints that are prior to the completed one
+		dropSubsumedCheckpoints(checkpointId);
 
 		// record the time when this was completed, to calculate
 		// the 'min delay between checkpoints'
@@ -941,7 +953,7 @@ public class CheckpointCoordinator {
 			}
 
 			// Recover the checkpoints
-			completedCheckpointStore.recover();
+			completedCheckpointStore.recover(sharedStateRegistry);
 
 			// restore from the latest checkpoint
 			CompletedCheckpoint latest = completedCheckpointStore.getLatestCheckpoint();
@@ -978,6 +990,44 @@ public class CheckpointCoordinator {
 		}
 	}
 
+	/**
+	 * Restore the state from the given savepoint.
+	 * 
+	 * @param savepointPath    Location of the savepoint
+	 * @param allowNonRestored True if allowing checkpoint state that cannot be 
+	 *                         mapped to any job vertex in tasks.
+	 * @param tasks            Map of job vertices to restore. State for these 
+	 *                         vertices is restored via 
+	 *                         {@link Execution#setInitialState(TaskStateHandles)}.
+	 * @param userClassLoader  The class loader to resolve serialized classes in 
+	 *                         legacy savepoint versions. 
+	 */
+	public boolean restoreSavepoint(
+			String savepointPath, 
+			boolean allowNonRestored,
+			Map<JobVertexID, ExecutionJobVertex> tasks,
+			ClassLoader userClassLoader) throws Exception {
+		
+		Preconditions.checkNotNull(savepointPath, "The savepoint path cannot be null.");
+		
+		LOG.info("Starting job from savepoint {} ({})", 
+				savepointPath, (allowNonRestored ? "allowing non restored state" : ""));
+
+		// Load the savepoint as a checkpoint into the system
+		CompletedCheckpoint savepoint = SavepointLoader.loadAndValidateSavepoint(
+				job, tasks, savepointPath, userClassLoader, allowNonRestored);
+
+		completedCheckpointStore.addCheckpoint(savepoint, sharedStateRegistry);
+		
+		// Reset the checkpoint ID counter
+		long nextCheckpointId = savepoint.getCheckpointID() + 1;
+		checkpointIdCounter.setCount(nextCheckpointId);
+		
+		LOG.info("Reset the checkpoint ID to {}.", nextCheckpointId);
+		
+		return restoreLatestCheckpointedState(tasks, true, allowNonRestored);
+	}
+
 	// --------------------------------------------------------------------------------------------
 	//  Accessors
 	// --------------------------------------------------------------------------------------------
@@ -1007,6 +1057,10 @@ public class CheckpointCoordinator {
 	public CompletedCheckpointStore getCheckpointStore() {
 		return completedCheckpointStore;
 	}
+	
+	public SharedStateRegistry getSharedStateRegistry() {
+		return sharedStateRegistry;
+	}
 
 	public CheckpointIDCounter getCheckpointIdCounter() {
 		return checkpointIdCounter;
@@ -1095,24 +1149,30 @@ public class CheckpointCoordinator {
 	 * @param jobId identifying the job to which the state object belongs
 	 * @param executionAttemptID identifying the task to which the state object belongs
 	 * @param checkpointId of the state object
-	 * @param stateObject to discard asynchronously
+	 * @param subtaskState to discard asynchronously
 	 */
 	private void discardState(
 			final JobID jobId,
 			final ExecutionAttemptID executionAttemptID,
 			final long checkpointId,
-			final StateObject stateObject) {
+			final SubtaskState subtaskState) {
 
-		if (stateObject != null) {
+		if (subtaskState != null) {
 			executor.execute(new Runnable() {
 				@Override
 				public void run() {
 					try {
-						stateObject.discardState();
-					} catch (Throwable throwable) {
-					LOG.warn("Could not properly discard state object of checkpoint {} " +
-						"belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId,
-						throwable);
+						subtaskState.discardSharedStatesOnFail();
+					} catch (Throwable t1) {
+						LOG.warn("Could not properly discard shared states of checkpoint {} " +
+							"belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t1);
+					}
+
+					try {
+						subtaskState.discardState();
+					} catch (Throwable t2) {
+						LOG.warn("Could not properly discard state object of checkpoint {} " +
+							"belonging to task {} of job {}.", checkpointId, executionAttemptID, jobId, t2);
 					}
 				}
 			});

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
index 17ce4d5..58e91e1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpoint.java
@@ -23,10 +23,12 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats.DiscardCallback;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.ExceptionUtils;
 
+import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -184,22 +186,30 @@ public class CompletedCheckpoint implements Serializable {
 		return props;
 	}
 
-	public boolean subsume() throws Exception {
+	public void discardOnFail() throws Exception {
+		discard(null, true);
+	}
+
+	public boolean discardOnSubsume(SharedStateRegistry sharedStateRegistry) throws Exception {
+		Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null.");
+
 		if (props.discardOnSubsumed()) {
-			discard();
+			discard(sharedStateRegistry, false);
 			return true;
 		}
 
 		return false;
 	}
 
-	public boolean discard(JobStatus jobStatus) throws Exception {
+	public boolean discardOnShutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
+		Preconditions.checkNotNull(sharedStateRegistry, "The registry cannot be null.");
+
 		if (jobStatus == JobStatus.FINISHED && props.discardOnJobFinished() ||
 				jobStatus == JobStatus.CANCELED && props.discardOnJobCancelled() ||
 				jobStatus == JobStatus.FAILED && props.discardOnJobFailed() ||
 				jobStatus == JobStatus.SUSPENDED && props.discardOnJobSuspended()) {
 
-			discard();
+			discard(sharedStateRegistry, false);
 			return true;
 		} else {
 			if (externalPointer != null) {
@@ -211,7 +221,10 @@ public class CompletedCheckpoint implements Serializable {
 		}
 	}
 
-	void discard() throws Exception {
+	private void discard(SharedStateRegistry sharedStateRegistry, boolean failed) throws Exception {
+		Preconditions.checkState(failed || (sharedStateRegistry != null),
+			"The registry must not be null if the complete checkpoint does not fail.");
+
 		try {
 			// collect exceptions and continue cleanup
 			Exception exception = null;
@@ -220,25 +233,31 @@ public class CompletedCheckpoint implements Serializable {
 			if (externalizedMetadata != null) {
 				try {
 					externalizedMetadata.discardState();
-				}
-				catch (Exception e) {
+				} catch (Exception e) {
 					exception = e;
 				}
 			}
 
-			// drop the actual state
+			// In the cases where the completed checkpoint fails, the shared
+			// states have not been registered to the registry. It's the state
+			// handles' responsibility to discard their shared states.
+			if (!failed) {
+				unregisterSharedStates(sharedStateRegistry);
+			} else {
+				discardSharedStatesOnFail();
+			}
+
+			// discard private state objects
 			try {
 				StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
-			}
-			catch (Exception e) {
+			} catch (Exception e) {
 				exception = ExceptionUtils.firstOrSuppressed(e, exception);
 			}
 
 			if (exception != null) {
 				throw exception;
 			}
-		}
-		finally {
+		} finally {
 			taskStates.clear();
 
 			// to be null-pointer safe, copy reference to stack
@@ -290,6 +309,36 @@ public class CompletedCheckpoint implements Serializable {
 		this.discardCallback = discardCallback;
 	}
 
+	/**
+	 * Register all shared states in the given registry. This method is called
+	 * when the completed checkpoint has been successfully added into the store.
+	 *
+	 * @param sharedStateRegistry The registry where shared states are registered
+	 */
+	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+		sharedStateRegistry.registerAll(taskStates.values());
+	}
+
+	/**
+	 * Unregister all shared states from the given registry. This method is
+	 * called when the completed checkpoint is subsumed or the job terminates.
+	 *
+	 * @param sharedStateRegistry The registry where shared states are registered
+	 */
+	private void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+		sharedStateRegistry.unregisterAll(taskStates.values());
+	}
+
+	/**
+	 * Discard all shared states created in the checkpoint. This method is called
+	 * when the completed checkpoint fails to be added into the store.
+	 */
+	private void discardSharedStatesOnFail() throws Exception {
+		for (TaskState taskState : taskStates.values()) {
+			taskState.discardSharedStatesOnFail();
+		}
+	}
+
 	// --------------------------------------------------------------------------------------------
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
index 9c2b199..0ade25c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStore.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 
 import java.util.List;
 
@@ -33,16 +34,16 @@ public interface CompletedCheckpointStore {
 	 * <p>After a call to this method, {@link #getLatestCheckpoint()} returns the latest
 	 * available checkpoint.
 	 */
-	void recover() throws Exception;
+	void recover(SharedStateRegistry sharedStateRegistry) throws Exception;
 
 	/**
 	 * Adds a {@link CompletedCheckpoint} instance to the list of completed checkpoints.
 	 *
 	 * <p>Only a bounded number of checkpoints is kept. When exceeding the maximum number of
 	 * retained checkpoints, the oldest one will be discarded via {@link
-	 * CompletedCheckpoint#discard()}.
+	 * CompletedCheckpoint#discardOnSubsume(SharedStateRegistry)}.
 	 */
-	void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception;
+	void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception;
 
 	/**
 	 * Returns the latest {@link CompletedCheckpoint} instance or <code>null</code> if none was
@@ -58,7 +59,7 @@ public interface CompletedCheckpointStore {
 	 *
 	 * @param jobStatus Job state on shut down
 	 */
-	void shutdown(JobStatus jobStatus) throws Exception;
+	void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception;
 
 	/**
 	 * Returns all {@link CompletedCheckpoint} instances.

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index e1182ae..6805dea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -497,6 +497,7 @@ public class PendingCheckpoint {
 	}
 
 	private void dispose(boolean releaseState) {
+
 		synchronized (lock) {
 			try {
 				numAcknowledgedTasks = -1;
@@ -504,11 +505,22 @@ public class PendingCheckpoint {
 					executor.execute(new Runnable() {
 						@Override
 						public void run() {
+
+							// discard the shared states that are created in the checkpoint
+							for (TaskState taskState : taskStates.values()) {
+								try {
+									taskState.discardSharedStatesOnFail();
+								} catch (Throwable t) {
+									LOG.warn("Could not properly dispose unreferenced shared states.");
+								}
+							}
+
+							// discard the private states
 							try {
 								StateUtil.bestEffortDiscardAllStateObjects(taskStates.values());
 							} catch (Throwable t) {
-								LOG.warn("Could not properly dispose the pending checkpoint {} of job {}.", 
-										checkpointId, jobId, t);
+								LOG.warn("Could not properly dispose the private states in the pending checkpoint {} of job {}.",
+									checkpointId, jobId, t);
 							} finally {
 								taskStates.clear();
 							}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
index 6eb5242..9f833c3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StandaloneCompletedCheckpointStore.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -56,16 +57,21 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt
 	}
 
 	@Override
-	public void recover() throws Exception {
+	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
 		// Nothing to do
 	}
 
 	@Override
-	public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
-		checkpoints.add(checkpoint);
+	public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
+		
+		checkpoints.addLast(checkpoint);
+
+		checkpoint.registerSharedStates(sharedStateRegistry);
+
 		if (checkpoints.size() > maxNumberOfCheckpointsToRetain) {
 			try {
-				checkpoints.remove().subsume();
+				CompletedCheckpoint checkpointToSubsume = checkpoints.removeFirst();
+				checkpointToSubsume.discardOnSubsume(sharedStateRegistry);
 			} catch (Exception e) {
 				LOG.warn("Fail to subsume the old checkpoint.", e);
 			}
@@ -93,12 +99,12 @@ public class StandaloneCompletedCheckpointStore implements CompletedCheckpointSt
 	}
 
 	@Override
-	public void shutdown(JobStatus jobStatus) throws Exception {
+	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
 		try {
 			LOG.info("Shutting down");
 
 			for (CompletedCheckpoint checkpoint : checkpoints) {
-				checkpoint.discard(jobStatus);
+				checkpoint.discardOnShutdown(jobStatus, sharedStateRegistry);
 			}
 		} finally {
 			checkpoints.clear();

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
index 97b08fc..e968643 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskState.java
@@ -19,11 +19,15 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.CompositeStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateObject;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.Arrays;
 
@@ -33,7 +37,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * Container for the chained state of one parallel subtask of an operator/task. This is part of the
  * {@link TaskState}.
  */
-public class SubtaskState implements StateObject {
+public class SubtaskState implements CompositeStateHandle {
+
+	private static final Logger LOG = LoggerFactory.getLogger(SubtaskState.class);
 
 	private static final long serialVersionUID = -2394696997971923995L;
 
@@ -130,19 +136,38 @@ public class SubtaskState implements StateObject {
 	}
 
 	@Override
-	public long getStateSize() {
-		return stateSize;
+	public void discardState() {
+		try {
+			StateUtil.bestEffortDiscardAllStateObjects(
+				Arrays.asList(
+					legacyOperatorState,
+					managedOperatorState,
+					rawOperatorState,
+					managedKeyedState,
+					rawKeyedState));
+		} catch (Exception e) {
+			LOG.warn("Error while discarding operator states.", e);
+		}
 	}
 
 	@Override
-	public void discardState() throws Exception {
-		StateUtil.bestEffortDiscardAllStateObjects(
-				Arrays.asList(
-						legacyOperatorState,
-						managedOperatorState,
-						rawOperatorState,
-						managedKeyedState,
-						rawKeyedState));
+	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+		// No shared states
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+		// No shared states
+	}
+
+	@Override
+	public void discardSharedStatesOnFail() {
+		// No shared states
+	}
+
+	@Override
+	public long getStateSize() {
+		return stateSize;
 	}
 
 	// --------------------------------------------------------------------------------------------
@@ -206,7 +231,7 @@ public class SubtaskState implements StateObject {
 				", operatorStateFromBackend=" + managedOperatorState +
 				", operatorStateFromStream=" + rawOperatorState +
 				", keyedStateFromBackend=" + managedKeyedState +
-				", keyedStateHandleFromStream=" + rawKeyedState +
+				", keyedStateFromStream=" + rawKeyedState +
 				", stateSize=" + stateSize +
 				'}';
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
index 76f1c51..19fe962 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/TaskState.java
@@ -19,8 +19,8 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.state.StateObject;
-import org.apache.flink.runtime.state.StateUtil;
+import org.apache.flink.runtime.state.CompositeStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.Preconditions;
 
 import java.util.Collection;
@@ -35,7 +35,7 @@ import java.util.Objects;
  *
  * This class basically groups all non-partitioned state and key-group state belonging to the same job vertex together.
  */
-public class TaskState implements StateObject {
+public class TaskState implements CompositeStateHandle {
 
 	private static final long serialVersionUID = -4845578005863201810L;
 
@@ -124,9 +124,31 @@ public class TaskState implements StateObject {
 
 	@Override
 	public void discardState() throws Exception {
-		StateUtil.bestEffortDiscardAllStateObjects(subtaskStates.values());
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			subtaskState.discardState();
+		}
+	}
+
+	@Override
+	public void registerSharedStates(SharedStateRegistry sharedStateRegistry) {
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			subtaskState.registerSharedStates(sharedStateRegistry);
+		}
+	}
+
+	@Override
+	public void unregisterSharedStates(SharedStateRegistry sharedStateRegistry) {
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			subtaskState.unregisterSharedStates(sharedStateRegistry);
+		}
 	}
 
+	@Override
+	public void discardSharedStatesOnFail() {
+		for (SubtaskState subtaskState : subtaskStates.values()) {
+			subtaskState.discardSharedStatesOnFail();
+		}
+	}
 
 	@Override
 	public long getStateSize() {

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
index af7bcc4..07546ea 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ZooKeeperCompletedCheckpointStore.java
@@ -27,6 +27,7 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobmanager.HighAvailabilityMode;
 import org.apache.flink.runtime.state.RetrievableStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperStateHandleStore;
 import org.apache.flink.util.FlinkException;
@@ -123,7 +124,7 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 		this.checkpointsInZooKeeper = new ZooKeeperStateHandleStore<>(this.client, stateStorage, executor);
 
 		this.checkpointStateHandles = new ArrayDeque<>(maxNumberOfCheckpointsToRetain + 1);
-
+		
 		LOG.info("Initialized in '{}'.", checkpointsPath);
 	}
 
@@ -140,7 +141,7 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	 * that the history of checkpoints is consistent.
 	 */
 	@Override
-	public void recover() throws Exception {
+	public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
 		LOG.info("Recovering checkpoints from ZooKeeper.");
 
 		// Clear local handles in order to prevent duplicates on
@@ -164,8 +165,24 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 
 		LOG.info("Found {} checkpoints in ZooKeeper.", numberOfInitialCheckpoints);
 
-		for (Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> checkpoint : initialCheckpoints) {
-			checkpointStateHandles.add(checkpoint);
+		for (Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> checkpointStateHandle : initialCheckpoints) {
+
+			CompletedCheckpoint completedCheckpoint = null;
+
+			try {
+				completedCheckpoint = retrieveCompletedCheckpoint(checkpointStateHandle);
+			} catch (Exception e) {
+				LOG.warn("Could not retrieve checkpoint. Removing it from the completed " +
+					"checkpoint store.", e);
+
+				// remove the checkpoint with broken state handle
+				removeBrokenStateHandle(checkpointStateHandle);
+			}
+
+			if (completedCheckpoint != null) {
+				completedCheckpoint.registerSharedStates(sharedStateRegistry);
+				checkpointStateHandles.add(checkpointStateHandle);
+			}
 		}
 	}
 
@@ -175,21 +192,24 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	 * @param checkpoint Completed checkpoint to add.
 	 */
 	@Override
-	public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
+	public void addCheckpoint(final CompletedCheckpoint checkpoint, final SharedStateRegistry sharedStateRegistry) throws Exception {
 		checkNotNull(checkpoint, "Checkpoint");
+		
+		final String path = checkpointIdToPath(checkpoint.getCheckpointID());
+		final RetrievableStateHandle<CompletedCheckpoint> stateHandle;
 
 		// First add the new one. If it fails, we don't want to loose existing data.
-		String path = checkpointIdToPath(checkpoint.getCheckpointID());
-
-		final RetrievableStateHandle<CompletedCheckpoint> stateHandle =
-				checkpointsInZooKeeper.add(path, checkpoint);
+		stateHandle = checkpointsInZooKeeper.add(path, checkpoint);
 
 		checkpointStateHandles.addLast(new Tuple2<>(stateHandle, path));
 
+		// Register all shared states in the checkpoint
+		checkpoint.registerSharedStates(sharedStateRegistry);
+
 		// Everything worked, let's remove a previous checkpoint if necessary.
 		while (checkpointStateHandles.size() > maxNumberOfCheckpointsToRetain) {
 			try {
-				removeSubsumed(checkpointStateHandles.removeFirst());
+				removeSubsumed(checkpointStateHandles.removeFirst(), sharedStateRegistry);
 			} catch (Exception e) {
 				LOG.warn("Failed to subsume the old checkpoint", e);
 			}
@@ -261,13 +281,13 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 	}
 
 	@Override
-	public void shutdown(JobStatus jobStatus) throws Exception {
+	public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
 		if (jobStatus.isGloballyTerminalState()) {
 			LOG.info("Shutting down");
 
 			for (Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> checkpoint : checkpointStateHandles) {
 				try {
-					removeShutdown(checkpoint, jobStatus);
+					removeShutdown(checkpoint, jobStatus, sharedStateRegistry);
 				} catch (Exception e) {
 					LOG.error("Failed to discard checkpoint.", e);
 				}
@@ -289,11 +309,19 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 
 	// ------------------------------------------------------------------------
 
-	private void removeSubsumed(final Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> stateHandleAndPath) throws Exception {
+	private void removeSubsumed(
+		final Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> stateHandleAndPath,
+		final SharedStateRegistry sharedStateRegistry) throws Exception {
+		
 		Callable<Void> action = new Callable<Void>() {
 			@Override
 			public Void call() throws Exception {
-				stateHandleAndPath.f0.retrieveState().subsume();
+				CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath);
+				
+				if (completedCheckpoint != null) {
+					completedCheckpoint.discardOnSubsume(sharedStateRegistry);
+				}
+
 				return null;
 			}
 		};
@@ -303,13 +331,18 @@ public class ZooKeeperCompletedCheckpointStore implements CompletedCheckpointSto
 
 	private void removeShutdown(
 			final Tuple2<RetrievableStateHandle<CompletedCheckpoint>, String> stateHandleAndPath,
-			final JobStatus jobStatus) throws Exception {
+			final JobStatus jobStatus,
+			final SharedStateRegistry sharedStateRegistry) throws Exception {
 
 		Callable<Void> action = new Callable<Void>() {
 			@Override
 			public Void call() throws Exception {
-				CompletedCheckpoint checkpoint = stateHandleAndPath.f0.retrieveState();
-				checkpoint.discard(jobStatus);
+				CompletedCheckpoint completedCheckpoint = retrieveCompletedCheckpoint(stateHandleAndPath);
+				
+				if (completedCheckpoint != null) {
+					completedCheckpoint.discardOnShutdown(jobStatus, sharedStateRegistry);
+				}
+
 				return null;
 			}
 		};

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
new file mode 100644
index 0000000..2ea5bc9
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CompositeStateHandle.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * Base of all snapshots that are taken by {@link StateBackend}s and some other
+ * components in tasks.
+ *
+ * <p>Each snapshot is composed of a collection of {@link StateObject}s some of 
+ * which may be referenced by other checkpoints. The shared states will be 
+ * registered at the given {@link SharedStateRegistry} when the handle is
+ * received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
+ * and will be discarded when the checkpoint is discarded.
+ * 
+ * <p>The {@link SharedStateRegistry} is responsible for the discarding of the
+ * shared states. The composite state handle should delete only its private
+ * states in the {@link StateObject#discardState()} method.
+ */
+public interface CompositeStateHandle extends StateObject {
+
+	/**
+	 * Register both created and referenced shared states in the given
+	 * {@link SharedStateRegistry}. This method is called when the checkpoint
+	 * successfully completes or is recovered from failures.
+	 *
+	 * @param stateRegistry The registry where shared states are registered.
+	 */
+	void registerSharedStates(SharedStateRegistry stateRegistry);
+
+	/**
+	 * Unregister both created and referenced shared states in the given
+	 * {@link SharedStateRegistry}. This method is called when the checkpoint is
+	 * subsumed or the job is shut down.
+	 *
+	 * @param stateRegistry The registry where shared states are registered.
+	 */
+	void unregisterSharedStates(SharedStateRegistry stateRegistry);
+
+	/**
+	 * Discard all shared states created in this checkpoint. This method is
+	 * called when the checkpoint fails to complete.
+	 */
+	void discardSharedStatesOnFail() throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
new file mode 100644
index 0000000..f856052
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateHandle.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+/**
+ * A handle to those states that are referenced by different checkpoints.
+ *
+ * <p> Each shared state handle is identified by a unique key. Two shared states
+ * are considered equal if their keys are identical.
+ *
+ * <p> All shared states are registered at the {@link SharedStateRegistry} once
+ * they are received by the {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator}
+ * and will be unregistered when the checkpoints are discarded. A shared state
+ * will be discarded once it is not referenced by any checkpoint. A shared state
+ * should not be referenced any more if it has been discarded.
+ */
+public interface SharedStateHandle extends StateObject {
+
+	/**
+	 * Return the identifier of the shared state.
+	 */
+	String getKey();
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
new file mode 100644
index 0000000..b5048d0
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/SharedStateRegistry.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.util.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A {@code SharedStateRegistry} will be deployed in the 
+ * {@link org.apache.flink.runtime.checkpoint.CheckpointCoordinator} to 
+ * maintain the reference count of {@link SharedStateHandle}s which are shared
+ * among different checkpoints.
+ */
+public class SharedStateRegistry implements Serializable {
+
+	private static Logger LOG = LoggerFactory.getLogger(SharedStateRegistry.class);
+
+	private static final long serialVersionUID = -8357254413007773970L;
+
+	/** All registered state objects */
+	private final Map<String, SharedStateEntry> registeredStates = new HashMap<>();
+
+	/**
+	 * Register the state in the registry
+	 *
+	 * @param state The state to register
+	 * @param isNew True if the shared state is newly created
+	 */
+	public void register(SharedStateHandle state, boolean isNew) {
+		if (state == null) {
+			return;
+		}
+
+		synchronized (registeredStates) {
+			SharedStateEntry entry = registeredStates.get(state.getKey());
+
+			if (isNew) {
+				Preconditions.checkState(entry == null,
+					"The state cannot be created more than once.");
+
+				registeredStates.put(state.getKey(), new SharedStateEntry(state));
+			} else {
+				Preconditions.checkState(entry != null,
+					"The state cannot be referenced if it has not been created yet.");
+
+				entry.increaseReferenceCount();
+			}
+		}
+	}
+
+	/**
+	 * Unregister the state in the registry
+	 *
+	 * @param state The state to unregister
+	 */
+	public void unregister(SharedStateHandle state) {
+		if (state == null) {
+			return;
+		}
+
+		synchronized (registeredStates) {
+			SharedStateEntry entry = registeredStates.get(state.getKey());
+
+			if (entry == null) {
+				throw new IllegalStateException("Cannot unregister an unexisted state.");
+			}
+
+			entry.decreaseReferenceCount();
+
+			// Remove the state from the registry when it's not referenced any more.
+			if (entry.getReferenceCount() == 0) {
+				registeredStates.remove(state.getKey());
+
+				try {
+					entry.getState().discardState();
+				} catch (Exception e) {
+					LOG.warn("Cannot properly discard the state " + entry.getState() + ".", e);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Register given shared states in the registry.
+	 *
+	 * @param stateHandles The shared states to register.
+	 */
+	public void registerAll(Collection<? extends CompositeStateHandle> stateHandles) {
+		if (stateHandles == null) {
+			return;
+		}
+
+		synchronized (registeredStates) {
+			for (CompositeStateHandle stateHandle : stateHandles) {
+				stateHandle.registerSharedStates(this);
+			}
+		}
+	}
+
+
+
+	/**
+	 * Unregister all the shared states referenced by the given state handles.
+	 *
+	 * @param stateHandles The shared states to unregister.
+	 */
+	public void unregisterAll(Collection<? extends CompositeStateHandle> stateHandles) {
+		if (stateHandles == null) {
+			return;
+		}
+
+		synchronized (registeredStates) {
+			for (CompositeStateHandle stateHandle : stateHandles) {
+				stateHandle.unregisterSharedStates(this);
+			}
+		}
+	}
+
+	private static class SharedStateEntry {
+		/** The shared object */
+		private final SharedStateHandle state;
+
+		/** The reference count of the object */
+		private int referenceCount;
+
+		SharedStateEntry(SharedStateHandle value) {
+			this.state = value;
+			this.referenceCount = 1;
+		}
+
+		SharedStateHandle getState() {
+			return state;
+		}
+
+		int getReferenceCount() {
+			return referenceCount;
+		}
+
+		void increaseReferenceCount() {
+			++referenceCount;
+		}
+
+		void decreaseReferenceCount() {
+			--referenceCount;
+		}
+	}
+
+
+	@VisibleForTesting
+	public int getReferenceCount(SharedStateHandle state) {
+		SharedStateEntry entry = registeredStates.get(state.getKey());
+
+		return entry == null ? 0 : entry.getReferenceCount();
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
index f2ecde5..40e2c2a 100644
--- a/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
+++ b/flink-runtime/src/main/scala/org/apache/flink/runtime/jobmanager/JobManager.scala
@@ -1365,27 +1365,12 @@ class JobManager(
                 val savepointPath = savepointSettings.getRestorePath()
                 val allowNonRestored = savepointSettings.allowNonRestoredState()
 
-                log.info(s"Starting job from savepoint '$savepointPath'" +
-                  (if (allowNonRestored) " (allowing non restored state)" else "") + ".")
-
-                  // load the savepoint as a checkpoint into the system
-                  val savepoint: CompletedCheckpoint = SavepointLoader.loadAndValidateSavepoint(
-                    jobId,
-                    executionGraph.getAllVertices,
-                    savepointPath,
-                    executionGraph.getUserClassLoader,
-                    allowNonRestored)
-
-                executionGraph.getCheckpointCoordinator.getCheckpointStore
-                  .addCheckpoint(savepoint)
-
-                // Reset the checkpoint ID counter
-                val nextCheckpointId: Long = savepoint.getCheckpointID + 1
-                log.info(s"Reset the checkpoint ID to $nextCheckpointId")
-                executionGraph.getCheckpointCoordinator.getCheckpointIdCounter
-                  .setCount(nextCheckpointId)
-
-                executionGraph.restoreLatestCheckpointedState(true, allowNonRestored)
+                executionGraph.getCheckpointCoordinator.restoreSavepoint(
+                  savepointPath, 
+                  allowNonRestored,
+                  executionGraph.getAllVertices,
+                  executionGraph.getUserClassLoader
+                )
               } catch {
                 case e: Exception =>
                   jobInfo.notifyClients(

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 340e2a7..632f2c0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -38,7 +39,9 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 @RunWith(PowerMockRunner.class)
@@ -83,12 +86,13 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertFalse(pendingCheckpoint.isDiscarded());
 
 		final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next();
-
-		AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId);
-
-		CompletedCheckpoint completedCheckpoint = mock(CompletedCheckpoint.class);
-		PowerMockito.whenNew(CompletedCheckpoint.class).withAnyArguments().thenReturn(completedCheckpoint);
-
+		
+		SubtaskState subtaskState = mock(SubtaskState.class);
+		PowerMockito.when(subtaskState.getLegacyOperatorState()).thenReturn(null);
+		PowerMockito.when(subtaskState.getManagedOperatorState()).thenReturn(null);
+		
+		AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState);
+		
 		try {
 			coord.receiveAcknowledgeMessage(acknowledgeMessage);
 			fail("Expected a checkpoint exception because the completed checkpoint store could not " +
@@ -100,18 +104,20 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		// make sure that the pending checkpoint has been discarded after we could not complete it
 		assertTrue(pendingCheckpoint.isDiscarded());
 
-		verify(completedCheckpoint).discard();
+		// make sure that the subtask state has been discarded after we could not complete it.
+		verify(subtaskState, times(1)).discardSharedStatesOnFail();
+		verify(subtaskState).discardState();
 	}
 
 	private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {
 
 		@Override
-		public void recover() throws Exception {
+		public void recover(SharedStateRegistry sharedStateRegistry) throws Exception {
 			throw new UnsupportedOperationException("Not implemented.");
 		}
 
 		@Override
-		public void addCheckpoint(CompletedCheckpoint checkpoint) throws Exception {
+		public void addCheckpoint(CompletedCheckpoint checkpoint, SharedStateRegistry sharedStateRegistry) throws Exception {
 			throw new Exception("The failing completed checkpoint store failed again... :-(");
 		}
 
@@ -121,7 +127,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		}
 
 		@Override
-		public void shutdown(JobStatus jobStatus) throws Exception {
+		public void shutdown(JobStatus jobStatus, SharedStateRegistry sharedStateRegistry) throws Exception {
 			throw new UnsupportedOperationException("Not implemented.");
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 117c70d..fabf3fc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -43,11 +43,13 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
@@ -86,13 +88,17 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doNothing;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.withSettings;
 
 /**
  * Tests for the checkpoint coordinator.
@@ -545,19 +551,24 @@ public class CheckpointCoordinatorTest {
 			}
 
 			// acknowledge from one of the tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+			SubtaskState subtaskState2 = mock(SubtaskState.class);
+			AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2);
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
 			assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks());
 			assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks());
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
+			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the same task again (should not matter)
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
+			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the other task.
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId));
+			SubtaskState subtaskState1 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1));
 
 			// the checkpoint is internally converted to a successful checkpoint and the
 			// pending checkpoint object is disposed
@@ -570,6 +581,12 @@ public class CheckpointCoordinatorTest {
 			// the canceler should be removed now
 			assertEquals(0, coord.getNumScheduledTasks());
 
+			// validate that the subtask states have registered their shared states.
+			{
+				verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class));
+				verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class));
+			}
+
 			// validate that the relevant tasks got a confirmation message
 			{
 				verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class));
@@ -580,7 +597,7 @@ public class CheckpointCoordinatorTest {
 			assertEquals(jid, success.getJobId());
 			assertEquals(timestamp, success.getTimestamp());
 			assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID());
-			assertTrue(success.getTaskStates().isEmpty());
+			assertEquals(2, success.getTaskStates().size());
 
 			// ---------------
 			// trigger another checkpoint and see that this one replaces the other checkpoint
@@ -602,6 +619,12 @@ public class CheckpointCoordinatorTest {
 			assertEquals(checkpointIdNew, successNew.getCheckpointID());
 			assertTrue(successNew.getTaskStates().isEmpty());
 
+			// validate that the subtask states in the old checkpoint have unregistered their shared states
+			{
+				verify(subtaskState1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+				verify(subtaskState2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+			}
+
 			// validate that the relevant tasks got a confirmation message
 			{
 				verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class));
@@ -678,8 +701,6 @@ public class CheckpointCoordinatorTest {
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class));
 			verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class));
 
-			CheckpointMetaData checkpointMetaData1 = new CheckpointMetaData(checkpointId1, 0L);
-
 			// acknowledge one of the three tasks
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1));
 
@@ -699,8 +720,6 @@ public class CheckpointCoordinatorTest {
 			}
 			long checkpointId2 = pending2.getCheckpointId();
 
-			CheckpointMetaData checkpointMetaData2 = new CheckpointMetaData(checkpointId2, 0L);
-
 			// trigger messages should have been sent
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
 			verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -812,10 +831,9 @@ public class CheckpointCoordinatorTest {
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class));
 			verify(triggerVertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId1), eq(timestamp1), any(CheckpointOptions.class));
 
-			CheckpointMetaData checkpointMetaData1 = new CheckpointMetaData(checkpointId1, 0L);
-
 			// acknowledge one of the three tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1));
+			SubtaskState subtaskState1_2 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), subtaskState1_2));
 
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
@@ -839,12 +857,18 @@ public class CheckpointCoordinatorTest {
 
 			// we acknowledge one more task from the first checkpoint and the second
 			// checkpoint completely. The second checkpoint should then subsume the first checkpoint
-			CheckpointMetaData checkpointMetaData2= new CheckpointMetaData(checkpointId2, 0L);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2));
+			SubtaskState subtaskState2_3 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), subtaskState2_3));
+
+			SubtaskState subtaskState2_1 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), subtaskState2_1));
+
+			SubtaskState subtaskState1_1 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), subtaskState1_1));
+
+			SubtaskState subtaskState2_2 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), subtaskState2_2));
 
 			// now, the second checkpoint should be confirmed, and the first discarded
 			// actually both pending checkpoints are discarded, and the second has been transformed
@@ -855,21 +879,47 @@ public class CheckpointCoordinatorTest {
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
+			// validate that all received subtask states in the first checkpoint have been discarded
+			verify(subtaskState1_1, times(1)).discardSharedStatesOnFail();
+			verify(subtaskState1_2, times(1)).discardSharedStatesOnFail();
+			verify(subtaskState1_1, times(1)).discardState();
+			verify(subtaskState1_2, times(1)).discardState();
+
+			// validate that all subtask states in the second checkpoint are not discarded
+			verify(subtaskState2_1, never()).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_2, never()).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_3, never()).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_1, never()).discardState();
+			verify(subtaskState2_2, never()).discardState();
+			verify(subtaskState2_3, never()).discardState();
+
 			// validate the committed checkpoints
 			List<CompletedCheckpoint> scs = coord.getSuccessfulCheckpoints();
 			CompletedCheckpoint success = scs.get(0);
 			assertEquals(checkpointId2, success.getCheckpointID());
 			assertEquals(timestamp2, success.getTimestamp());
 			assertEquals(jid, success.getJobId());
-			assertTrue(success.getTaskStates().isEmpty());
+			assertEquals(3, success.getTaskStates().size());
 
 			// the first confirm message should be out
 			verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
 			// send the last remaining ack for the first checkpoint. This should not do anything
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1));
+			SubtaskState subtaskState1_3 = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3));
+			verify(subtaskState1_3, times(1)).discardSharedStatesOnFail();
+			verify(subtaskState1_3, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
+
+			// validate that the states in the second checkpoint have been discarded
+			verify(subtaskState2_1, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_2, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_3, times(1)).unregisterSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2_1, times(1)).discardState();
+			verify(subtaskState2_2, times(1)).discardState();
+			verify(subtaskState2_3, times(1)).discardState();
+
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -924,7 +974,8 @@ public class CheckpointCoordinatorTest {
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().values().iterator().next();
 			assertFalse(checkpoint.isDiscarded());
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId()));
+			SubtaskState subtaskState = mock(SubtaskState.class);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), subtaskState));
 
 			// wait until the checkpoint must have expired.
 			// we check every 250 msecs conservatively for 5 seconds
@@ -941,6 +992,10 @@ public class CheckpointCoordinatorTest {
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
+			// validate that the received states have been discarded
+			verify(subtaskState, times(1)).discardSharedStatesOnFail();
+			verify(subtaskState, times(1)).discardState();
+
 			// no confirm message must have been sent
 			verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong());
 
@@ -993,8 +1048,6 @@ public class CheckpointCoordinatorTest {
 			// of the vertices that need to be acknowledged.
 			// non of the messages should throw an exception
 
-			CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
-
 			// wrong job id
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), ackAttemptID1, checkpointId));
 
@@ -1058,19 +1111,22 @@ public class CheckpointCoordinatorTest {
 
 		long checkpointId = pendingCheckpoint.getCheckpointId();
 
-		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
-
 		SubtaskState triggerSubtaskState = mock(SubtaskState.class);
 
 		// acknowledge the first trigger vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
+		// verify that the acknowledged subtask state has not been discarded
+		verify(triggerSubtaskState, never()).discardSharedStatesOnFail();
+		verify(triggerSubtaskState, never()).discardState();
+
 		SubtaskState unknownSubtaskState = mock(SubtaskState.class);
 
 		// receive an acknowledge message for an unknown vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState));
 
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
+		verify(unknownSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(unknownSubtaskState, times(1)).discardState();
 
 		SubtaskState differentJobSubtaskState = mock(SubtaskState.class);
@@ -1079,20 +1135,25 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
 
 		// we should not interfere with different jobs
+		verify(differentJobSubtaskState, never()).discardSharedStatesOnFail();
 		verify(differentJobSubtaskState, never()).discardState();
 
 		// duplicate acknowledge message for the trigger vertex
+		reset(triggerSubtaskState);
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
 		// duplicate acknowledge messages for a known vertex should not trigger discarding the state
+		verify(triggerSubtaskState, never()).discardSharedStatesOnFail();
 		verify(triggerSubtaskState, never()).discardState();
 
 		// let the checkpoint fail at the first ack vertex
+		reset(triggerSubtaskState);
 		coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId));
 
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// check that we've cleaned up the already acknowledged state
+		verify(triggerSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(triggerSubtaskState, times(1)).discardState();
 
 		SubtaskState ackSubtaskState = mock(SubtaskState.class);
@@ -1101,12 +1162,15 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState));
 
 		// check that we also cleaned up this state
+		verify(ackSubtaskState, times(1)).discardSharedStatesOnFail();
 		verify(ackSubtaskState, times(1)).discardState();
 
 		// receive an acknowledge message from an unknown job
+		reset(differentJobSubtaskState);
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
 
 		// we should not interfere with different jobs
+		verify(differentJobSubtaskState, never()).discardSharedStatesOnFail();
 		verify(differentJobSubtaskState, never()).discardState();
 
 		SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
@@ -1115,6 +1179,7 @@ public class CheckpointCoordinatorTest {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2));
 
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
+		verify(unknownSubtaskState2, times(1)).discardSharedStatesOnFail();
 		verify(unknownSubtaskState2, times(1)).discardState();
 	}
 
@@ -1363,12 +1428,11 @@ public class CheckpointCoordinatorTest {
 		assertFalse(pending.isDiscarded());
 		assertFalse(pending.isFullyAcknowledged());
 		assertFalse(pending.canBeSubsumed());
-		assertTrue(pending instanceof PendingCheckpoint);
-
-		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
 		// acknowledge from one of the tasks
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+		SubtaskState subtaskState2 = mock(SubtaskState.class);
+		AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), subtaskState2);
+		coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
 		assertEquals(1, pending.getNumberOfAcknowledgedTasks());
 		assertEquals(1, pending.getNumberOfNonAcknowledgedTasks());
 		assertFalse(pending.isDiscarded());
@@ -1376,13 +1440,14 @@ public class CheckpointCoordinatorTest {
 		assertFalse(savepointFuture.isDone());
 
 		// acknowledge the same task again (should not matter)
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId));
+		coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
 		assertFalse(pending.isDiscarded());
 		assertFalse(pending.isFullyAcknowledged());
 		assertFalse(savepointFuture.isDone());
 
 		// acknowledge the other task.
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId));
+		SubtaskState subtaskState1 = mock(SubtaskState.class);
+		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), subtaskState1));
 
 		// the checkpoint is internally converted to a successful checkpoint and the
 		// pending checkpoint object is disposed
@@ -1399,11 +1464,17 @@ public class CheckpointCoordinatorTest {
 			verify(vertex2.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId), eq(timestamp));
 		}
 
+		// validate that the shared states are registered
+		{
+			verify(subtaskState1, times(1)).registerSharedStates(any(SharedStateRegistry.class));
+			verify(subtaskState2, times(1)).registerSharedStates(any(SharedStateRegistry.class));
+		}
+
 		CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0);
 		assertEquals(jid, success.getJobId());
 		assertEquals(timestamp, success.getTimestamp());
 		assertEquals(pending.getCheckpointId(), success.getCheckpointID());
-		assertTrue(success.getTaskStates().isEmpty());
+		assertEquals(2, success.getTaskStates().size());
 
 		// ---------------
 		// trigger another checkpoint and see that this one replaces the other checkpoint
@@ -1426,6 +1497,14 @@ public class CheckpointCoordinatorTest {
 		assertTrue(successNew.getTaskStates().isEmpty());
 		assertTrue(savepointFuture.isDone());
 
+		// validate that the first savepoint does not discard its private states.
+		verify(subtaskState1, never()).discardState();
+		verify(subtaskState2, never()).discardState();
+
+		// Savepoints are not supposed to have any shared state.
+		verify(subtaskState1, never()).unregisterSharedStates(any(SharedStateRegistry.class));
+		verify(subtaskState2, never()).unregisterSharedStates(any(SharedStateRegistry.class));
+
 		// validate that the relevant tasks got a confirmation message
 		{
 			verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointIdNew), eq(timestampNew), any(CheckpointOptions.class));
@@ -1478,7 +1557,6 @@ public class CheckpointCoordinatorTest {
 		// Trigger savepoint and checkpoint
 		Future<CompletedCheckpoint> savepointFuture1 = coord.triggerSavepoint(timestamp, savepointDir);
 		long savepointId1 = counter.getLast();
-		CheckpointMetaData checkpointMetaDataS1 = new CheckpointMetaData(savepointId1, 0L);
 		assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
 		assertTrue(coord.triggerCheckpoint(timestamp + 1, false));
@@ -1488,8 +1566,6 @@ public class CheckpointCoordinatorTest {
 		long checkpointId2 = counter.getLast();
 		assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
-		CheckpointMetaData checkpointMetaData2 = new CheckpointMetaData(checkpointId2, 0L);
-
 		// 2nd checkpoint should subsume the 1st checkpoint, but not the savepoint
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId2));
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointId2));
@@ -1505,7 +1581,6 @@ public class CheckpointCoordinatorTest {
 
 		Future<CompletedCheckpoint> savepointFuture2 = coord.triggerSavepoint(timestamp + 4, savepointDir);
 		long savepointId2 = counter.getLast();
-		CheckpointMetaData checkpointMetaDataS2 = new CheckpointMetaData(savepointId2, 0L);
 		assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
 		// 2nd savepoint should subsume the last checkpoint, but not the 1st savepoint
@@ -1880,6 +1955,8 @@ public class CheckpointCoordinatorTest {
 		ExecutionVertex[] arrayExecutionVertices =
 				allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
 
+		CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore();
+
 		// set up the coordinator and validate the initial state
 		CheckpointCoordinator coord = new CheckpointCoordinator(
 			jid,
@@ -1892,7 +1969,7 @@ public class CheckpointCoordinatorTest {
 			arrayExecutionVertices,
 			arrayExecutionVertices,
 			new StandaloneCheckpointIDCounter(),
-			new StandaloneCompletedCheckpointStore(1),
+			store,
 			null,
 			Executors.directExecutor());
 
@@ -1901,38 +1978,32 @@ public class CheckpointCoordinatorTest {
 
 		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
-		CheckpointMetaData checkpointMetaData = new CheckpointMetaData(checkpointId, 0L);
 
 		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
 		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
-			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
 
-			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					subtaskState);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID2, index);
-			ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
-			KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-			SubtaskState checkpointStateHandles = new SubtaskState(nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null);
+			SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					subtaskState);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -1941,6 +2012,20 @@ public class CheckpointCoordinatorTest {
 
 		assertEquals(1, completedCheckpoints.size());
 
+		// shut down the store
+		SharedStateRegistry sharedStateRegistry = coord.getSharedStateRegistry();
+		store.shutdown(JobStatus.SUSPENDED, sharedStateRegistry);
+
+		// All shared states should be unregistered once the store is shut down
+		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
+			for (TaskState taskState : completedCheckpoint.getTaskStates().values()) {
+				for (SubtaskState subtaskState : taskState.getStates()) {
+					verify(subtaskState, times(1)).unregisterSharedStates(sharedStateRegistry);
+				}
+			}
+		}
+
+		// restore the store
 		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
 
 		tasks.put(jobVertexID1, jobVertex1);
@@ -1948,6 +2033,15 @@ public class CheckpointCoordinatorTest {
 
 		coord.restoreLatestCheckpointedState(tasks, true, false);
 
+		// validate that all shared states are registered again after the recovery.
+		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
+			for (TaskState taskState : completedCheckpoint.getTaskStates().values()) {
+				for (SubtaskState subtaskState : taskState.getStates()) {
+					verify(subtaskState, times(2)).registerSharedStates(sharedStateRegistry);
+				}
+			}
+		}
+
 		// verify the restored state
 		verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
 		verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
@@ -2666,6 +2760,26 @@ public class CheckpointCoordinatorTest {
 		return vertex;
 	}
 
+	static SubtaskState mockSubtaskState(
+		JobVertexID jobVertexID,
+		int index,
+		KeyGroupRange keyGroupRange) throws IOException {
+		// Test helper: builds a Mockito mock of SubtaskState whose state getters are stubbed with generated handles.
+		ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID, index);
+		ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
+		// The mock is created serializable -- presumably because the checkpoint store used by callers serializes checkpoints (TODO confirm).
+		SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable());
+		// Stub the getters; the raw operator state and raw keyed state getters are stubbed to return null.
+		doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
+		doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
+		doReturn(null).when(subtaskState).getRawOperatorState();
+		doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
+		doReturn(null).when(subtaskState).getRawKeyedState();
+
+		return subtaskState;
+	}
+
 	public static void verifyStateRestore(
 			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
 			List<KeyGroupRange> keyGroupPartitions) throws Exception {
@@ -3018,7 +3132,6 @@ public class CheckpointCoordinatorTest {
 		ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID());
 
 		StandaloneCompletedCheckpointStore store = new StandaloneCompletedCheckpointStore(1);
-		store.addCheckpoint(new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.<JobVertexID, TaskState>emptyMap()));
 
 		// set up the coordinator and validate the initial state
 		CheckpointCoordinator coord = new CheckpointCoordinator(
@@ -3036,6 +3149,10 @@ public class CheckpointCoordinatorTest {
 			null,
 			Executors.directExecutor());
 
+		store.addCheckpoint(
+			new CompletedCheckpoint(new JobID(), 0, 0, 0, Collections.<JobVertexID, TaskState>emptyMap()),
+			coord.getSharedStateRegistry());
+
 		CheckpointStatsTracker tracker = mock(CheckpointStatsTracker.class);
 		coord.setCheckpointStatsTracker(tracker);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/218bed8b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 7e0a7c1..9e372e1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -255,7 +255,7 @@ public class CheckpointStateRestoreTest {
 		}
 		CompletedCheckpoint checkpoint = new CompletedCheckpoint(new JobID(), 0, 1, 2, new HashMap<>(checkpointTaskStates));
 
-		coord.getCheckpointStore().addCheckpoint(checkpoint);
+		coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry());
 
 		coord.restoreLatestCheckpointedState(tasks, true, false);
 		coord.restoreLatestCheckpointedState(tasks, true, true);
@@ -273,7 +273,7 @@ public class CheckpointStateRestoreTest {
 
 		checkpoint = new CompletedCheckpoint(new JobID(), 1, 2, 3, new HashMap<>(checkpointTaskStates));
 
-		coord.getCheckpointStore().addCheckpoint(checkpoint);
+		coord.getCheckpointStore().addCheckpoint(checkpoint, coord.getSharedStateRegistry());
 
 		// (i) Allow non restored state (should succeed)
 		coord.restoreLatestCheckpointedState(tasks, true, true);