You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by se...@apache.org on 2016/09/22 12:42:57 UTC

[1/2] flink git commit: [FLINK-4628] [core] Provide user class loader during input split assignment

Repository: flink
Updated Branches:
  refs/heads/master e6fbda906 -> 345b2529a


[FLINK-4628] [core] Provide user class loader during input split assignment

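Analogous to the configure() method, this sets the user code class
loader as the thread's context class loader while input splits are
created and the input split assigner is obtained, and restores the
previous context class loader afterwards.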

This closes #2505


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/345b2529
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/345b2529
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/345b2529

Branch: refs/heads/master
Commit: 345b2529a8acdd59d67e89ea930ec69ad69a55d3
Parents: 3b8fe95
Author: Maximilian Michels <mx...@apache.org>
Authored: Fri Sep 16 12:21:54 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Thu Sep 22 14:42:12 2016 +0200

----------------------------------------------------------------------
 .../runtime/executiongraph/ExecutionJobVertex.java   | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/345b2529/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
index 1ac9522..ead0852 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertex.java
@@ -165,10 +165,17 @@ public class ExecutionJobVertex {
 			InputSplitSource<InputSplit> splitSource = (InputSplitSource<InputSplit>) jobVertex.getInputSplitSource();
 			
 			if (splitSource != null) {
-				inputSplits = splitSource.createInputSplits(numTaskVertices);
-				
-				if (inputSplits != null) {
-					splitAssigner = splitSource.getInputSplitAssigner(inputSplits);
+				Thread currentThread = Thread.currentThread();
+				ClassLoader oldContextClassLoader = currentThread.getContextClassLoader();
+				currentThread.setContextClassLoader(graph.getUserClassLoader());
+				try {
+					inputSplits = splitSource.createInputSplits(numTaskVertices);
+
+					if (inputSplits != null) {
+						splitAssigner = splitSource.getInputSplitAssigner(inputSplits);
+					}
+				} finally {
+					currentThread.setContextClassLoader(oldContextClassLoader);
 				}
 			}
 			else {


[2/2] flink git commit: [FLINK-4603] [checkpoints] Fix user code classloading in KeyedStateBackend

Posted by se...@apache.org.
[FLINK-4603] [checkpoints] Fix user code classloading in KeyedStateBackend

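Pass the user code class loader to the KeyedStateBackend and use it
when deserializing state descriptors (RocksDB backend) and
namespace/state serializers (heap backend) during restore, so that
user-defined classes in checkpointed state can be resolved.

This closes #2533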


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b8fe95e
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b8fe95e
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b8fe95e

Branch: refs/heads/master
Commit: 3b8fe95ec728d59e3ffba2901450c56d7cca2b24
Parents: e6fbda9
Author: Stefan Richter <s....@data-artisans.com>
Authored: Wed Sep 21 14:55:58 2016 +0200
Committer: Stephan Ewen <se...@apache.org>
Committed: Thu Sep 22 14:42:12 2016 +0200

----------------------------------------------------------------------
 .../state/RocksDBKeyedStateBackend.java         |  19 +-
 .../streaming/state/RocksDBStateBackend.java    |   2 +
 .../apache/flink/util/InstantiationUtil.java    |   6 +-
 .../flink/runtime/state/KeyedStateBackend.java  |   4 +
 .../state/filesystem/FsStateBackend.java        |   2 +
 .../state/heap/HeapKeyedStateBackend.java       |  31 +--
 .../state/memory/MemoryStateBackend.java        |   5 +-
 .../streaming/runtime/tasks/StreamTask.java     |   8 +-
 flink-tests/pom.xml                             |  19 ++
 ...t-checkpointing-custom_kv_state-assembly.xml |  38 +++
 .../test/classloading/ClassLoaderITCase.java    |  25 +-
 .../jar/CheckpointingCustomKvStateProgram.java  | 233 +++++++++++++++++++
 12 files changed, 363 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 177c09f..d5a96af 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -47,6 +47,7 @@ import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.util.SerializableObject;
+import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.rocksdb.ColumnFamilyDescriptor;
 import org.rocksdb.ColumnFamilyHandle;
@@ -63,8 +64,6 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.concurrent.GuardedBy;
 import java.io.File;
 import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -125,6 +124,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	public RocksDBKeyedStateBackend(
 			JobID jobId,
 			String operatorIdentifier,
+			ClassLoader userCodeClassLoader,
 			File instanceBasePath,
 			DBOptions dbOptions,
 			ColumnFamilyOptions columnFamilyOptions,
@@ -134,7 +134,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			KeyGroupRange keyGroupRange
 	) throws Exception {
 
-		super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange);
+		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange);
 
 		this.operatorIdentifier = operatorIdentifier;
 		this.jobId = jobId;
@@ -177,6 +177,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	public RocksDBKeyedStateBackend(
 			JobID jobId,
 			String operatorIdentifier,
+			ClassLoader userCodeClassLoader,
 			File instanceBasePath,
 			DBOptions dbOptions,
 			ColumnFamilyOptions columnFamilyOptions,
@@ -189,6 +190,7 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 		this(
 			jobId,
 			operatorIdentifier,
+			userCodeClassLoader,
 			instanceBasePath,
 			dbOptions,
 			columnFamilyOptions,
@@ -455,8 +457,8 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 				checkInterrupted();
 
 				//write StateDescriptor for this k/v state
-				ObjectOutputStream ooOut = new ObjectOutputStream(outStream);
-				ooOut.writeObject(column.getValue().f1);
+				InstantiationUtil.serializeObject(outStream, column.getValue().f1);
+
 				//retrieve iterator for this k/v states
 				ReadOptions readOptions = new ReadOptions();
 				readOptions.setSnapshot(snapshot);
@@ -649,8 +651,11 @@ public class RocksDBKeyedStateBackend<K> extends KeyedStateBackend<K> {
 
 			//restore the empty columns for the k/v states through the metadata
 			for (int i = 0; i < numColumns; i++) {
-				ObjectInputStream ooIn = new ObjectInputStream(currentStateHandleInStream);
-				StateDescriptor stateDescriptor = (StateDescriptor) ooIn.readObject();
+
+				StateDescriptor stateDescriptor = InstantiationUtil.deserializeObject(
+						currentStateHandleInStream,
+						rocksDBKeyedStateBackend.userCodeClassLoader);
+
 				Tuple2<ColumnFamilyHandle, StateDescriptor> columnFamily = rocksDBKeyedStateBackend.
 						kvStateInformation.get(stateDescriptor.getName());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
index 0fdbd5f..b6ce224 100644
--- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
+++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateBackend.java
@@ -240,6 +240,7 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 		return new RocksDBKeyedStateBackend<>(
 				jobID,
 				operatorIdentifier,
+				env.getUserClassLoader(),
 				instanceBasePath,
 				getDbOptions(),
 				getColumnOptions(),
@@ -264,6 +265,7 @@ public class RocksDBStateBackend extends AbstractStateBackend {
 		return new RocksDBKeyedStateBackend<>(
 				jobID,
 				operatorIdentifier,
+				env.getUserClassLoader(),
 				instanceBasePath,
 				getDbOptions(),
 				getColumnOptions(),

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
index b1dddae..de4cffb 100644
--- a/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
+++ b/flink-core/src/main/java/org/apache/flink/util/InstantiationUtil.java
@@ -299,7 +299,10 @@ public final class InstantiationUtil {
 	@SuppressWarnings("unchecked")
 	public static <T> T deserializeObject(InputStream in, ClassLoader cl) throws IOException, ClassNotFoundException {
 		final ClassLoader old = Thread.currentThread().getContextClassLoader();
-		try (ObjectInputStream oois = new ClassLoaderObjectInputStream(in, cl)) {
+		ObjectInputStream oois;
+		// not using resource try to avoid AutoClosable's close() on the given stream
+		try {
+			oois = new ClassLoaderObjectInputStream(in, cl);
 			Thread.currentThread().setContextClassLoader(cl);
 			return (T) oois.readObject();
 		}
@@ -332,7 +335,6 @@ public final class InstantiationUtil {
 	public static void serializeObject(OutputStream out, Object o) throws IOException {
 		ObjectOutputStream oos = new ObjectOutputStream(out);
 		oos.writeObject(o);
-		oos.flush();
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
index bf9018e..8db63ee 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateBackend.java
@@ -77,14 +77,18 @@ public abstract class KeyedStateBackend<K> {
 	/** KvStateRegistry helper for this task */
 	protected final TaskKvStateRegistry kvStateRegistry;
 
+	protected final ClassLoader userCodeClassLoader;
+
 	public KeyedStateBackend(
 			TaskKvStateRegistry kvStateRegistry,
 			TypeSerializer<K> keySerializer,
+			ClassLoader userCodeClassLoader,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange) {
 
 		this.kvStateRegistry = Preconditions.checkNotNull(kvStateRegistry);
 		this.keySerializer = Preconditions.checkNotNull(keySerializer);
+		this.userCodeClassLoader = Preconditions.checkNotNull(userCodeClassLoader);
 		this.numberOfKeyGroups = Preconditions.checkNotNull(numberOfKeyGroups);
 		this.keyGroupRange = Preconditions.checkNotNull(keyGroupRange);
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
index 6d92a4d..99e3684 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/filesystem/FsStateBackend.java
@@ -186,6 +186,7 @@ public class FsStateBackend extends AbstractStateBackend {
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
 				keySerializer,
+				env.getUserClassLoader(),
 				numberOfKeyGroups,
 				keyGroupRange);
 	}
@@ -203,6 +204,7 @@ public class FsStateBackend extends AbstractStateBackend {
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
 				keySerializer,
+				env.getUserClassLoader(),
 				numberOfKeyGroups,
 				keyGroupRange,
 				restoredState);

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 8d13941..c13be70 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -39,12 +39,11 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -75,20 +74,23 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 	public HeapKeyedStateBackend(
 			TaskKvStateRegistry kvStateRegistry,
 			TypeSerializer<K> keySerializer,
+			ClassLoader userCodeClassLoader,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange) {
 
-		super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange);
+		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange);
 
 		LOG.info("Initializing heap keyed state backend with stream factory.");
 	}
 
-	public HeapKeyedStateBackend(TaskKvStateRegistry kvStateRegistry,
+	public HeapKeyedStateBackend(
+			TaskKvStateRegistry kvStateRegistry,
 			TypeSerializer<K> keySerializer,
+			ClassLoader userCodeClassLoader,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
 			List<KeyGroupsStateHandle> restoredState) throws Exception {
-		super(kvStateRegistry, keySerializer, numberOfKeyGroups, keyGroupRange);
+		super(kvStateRegistry, keySerializer, userCodeClassLoader, numberOfKeyGroups, keyGroupRange);
 
 		LOG.info("Initializing heap keyed state backend from snapshot.");
 
@@ -135,7 +137,6 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 		@SuppressWarnings("unchecked,rawtypes")
 		StateTable<K, N, T> stateTable = (StateTable) stateTables.get(stateDesc.getName());
 
-
 		if (stateTable == null) {
 			stateTable = new StateTable<>(stateDesc.getSerializer(), namespaceSerializer, keyGroupRange);
 			stateTables.put(stateDesc.getName(), stateTable);
@@ -190,10 +191,8 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			TypeSerializer namespaceSerializer = kvState.getValue().getNamespaceSerializer();
 			TypeSerializer stateSerializer = kvState.getValue().getStateSerializer();
 
-			ObjectOutputStream oos = new ObjectOutputStream(outView);
-			oos.writeObject(namespaceSerializer);
-			oos.writeObject(stateSerializer);
-			oos.flush();
+			InstantiationUtil.serializeObject(stream, namespaceSerializer);
+			InstantiationUtil.serializeObject(stream, stateSerializer);
 
 			kVStateToId.put(kvState.getKey(), kVStateToId.size());
 		}
@@ -266,18 +265,20 @@ public class HeapKeyedStateBackend<K> extends KeyedStateBackend<K> {
 			for (int i = 0; i < numKvStates; ++i) {
 				String stateName = inView.readUTF();
 
-				ObjectInputStream ois = new ObjectInputStream(inView);
+				TypeSerializer namespaceSerializer =
+						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
+				TypeSerializer stateSerializer =
+						InstantiationUtil.deserializeObject(fsDataInputStream, userCodeClassLoader);
 
-				TypeSerializer namespaceSerializer = (TypeSerializer) ois.readObject();
-				TypeSerializer stateSerializer = (TypeSerializer) ois.readObject();
-				StateTable<K, ?, ?> stateTable = new StateTable(stateSerializer,
+				StateTable<K, ?, ?> stateTable = new StateTable(
+						stateSerializer,
 						namespaceSerializer,
 						keyGroupRange);
 				stateTables.put(stateName, stateTable);
 				kvStatesById.put(i, stateName);
 			}
 
-			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup(); keyGroupIndex <= keyGroupRange.getEndKeyGroup(); keyGroupIndex++) {
+			for (int keyGroupIndex = keyGroupRange.getStartKeyGroup();  keyGroupIndex <= keyGroupRange.getEndKeyGroup(); ++keyGroupIndex) {
 				long offset = keyGroupsHandle.getOffsetForKeyGroup(keyGroupIndex);
 				fsDataInputStream.seek(offset);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
index 179dfe7..cc145ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/memory/MemoryStateBackend.java
@@ -78,7 +78,8 @@ public class MemoryStateBackend extends AbstractStateBackend {
 	@Override
 	public <K> KeyedStateBackend<K> createKeyedStateBackend(
 			Environment env, JobID jobID,
-			String operatorIdentifier, TypeSerializer<K> keySerializer,
+			String operatorIdentifier,
+			TypeSerializer<K> keySerializer,
 			int numberOfKeyGroups,
 			KeyGroupRange keyGroupRange,
 			TaskKvStateRegistry kvStateRegistry) throws IOException {
@@ -86,6 +87,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
 				keySerializer,
+				env.getUserClassLoader(),
 				numberOfKeyGroups,
 				keyGroupRange);
 	}
@@ -103,6 +105,7 @@ public class MemoryStateBackend extends AbstractStateBackend {
 		return new HeapKeyedStateBackend<>(
 				kvStateRegistry,
 				keySerializer,
+				env.getUserClassLoader(),
 				numberOfKeyGroups,
 				keyGroupRange,
 				restoredState);

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 9c26509..d4638a4 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
@@ -585,7 +586,12 @@ public abstract class StreamTask<OUT, Operator extends StreamOperator<OUT>>
 
 						if (operator != null) {
 							LOG.debug("Restore state of task {} in chain ({}).", i, getName());
-							operator.restoreState(state.openInputStream());
+							FSDataInputStream inputStream = state.openInputStream();
+							try {
+								operator.restoreState(inputStream);
+							} finally {
+								inputStream.close();
+							}
 						}
 					}
 				}

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/pom.xml
----------------------------------------------------------------------
diff --git a/flink-tests/pom.xml b/flink-tests/pom.xml
index b09db1f..efc95ab 100644
--- a/flink-tests/pom.xml
+++ b/flink-tests/pom.xml
@@ -485,6 +485,25 @@ under the License.
 							</descriptors>
 						</configuration>
 					</execution>
+					<execution>
+						<id>create-checkpointing_custom_kv_state-jar</id>
+						<phase>process-test-classes</phase>
+						<goals>
+							<goal>single</goal>
+						</goals>
+						<configuration>
+							<archive>
+								<manifest>
+									<mainClass>org.apache.flink.test.classloading.jar.CheckpointingCustomKvStateProgram</mainClass>
+								</manifest>
+							</archive>
+							<finalName>checkpointing_custom_kv_state</finalName>
+							<attach>false</attach>
+							<descriptors>
+								<descriptor>src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml</descriptor>
+							</descriptors>
+						</configuration>
+					</execution>
 				</executions>
 			</plugin>
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
new file mode 100644
index 0000000..fdebfdd
--- /dev/null
+++ b/flink-tests/src/test/assembly/test-checkpointing-custom_kv_state-assembly.xml
@@ -0,0 +1,38 @@
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+
+-->
+
+<assembly>
+	<id>test-jar</id> <!-- appended to finalName, yielding checkpointing_custom_kv_state-test-jar.jar -->
+	<formats>
+		<format>jar</format>
+	</formats>
+	<includeBaseDirectory>false</includeBaseDirectory>
+	<fileSets>
+		<fileSet>
+			<directory>${project.build.testOutputDirectory}</directory>
+			<outputDirectory>/</outputDirectory>
+			<!-- The test program class plus its nested/anonymous classes; the second pattern already matches the first. -->
+			<includes>
+				<include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.class</include>
+				<include>org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram*.class</include>
+			</includes>
+		</fileSet>
+	</fileSets>
+</assembly>

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
index 7afafe4..65da33f 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/ClassLoaderITCase.java
@@ -39,6 +39,7 @@ import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages.WaitForAllVerticesToBeRunning;
 import org.apache.flink.test.testdata.KMeansData;
+import org.apache.flink.test.util.SuccessException;
 import org.apache.flink.util.TestLogger;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
@@ -46,7 +47,6 @@ import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
-import scala.Option;
 import scala.concurrent.Await;
 import scala.concurrent.Future;
 import scala.concurrent.duration.Deadline;
@@ -79,6 +79,8 @@ public class ClassLoaderITCase extends TestLogger {
 
 	private static final String CUSTOM_KV_STATE_JAR_PATH = "custom_kv_state-test-jar.jar";
 
+	private static final String CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH = "checkpointing_custom_kv_state-test-jar.jar";
+
 	public static final TemporaryFolder FOLDER = new TemporaryFolder();
 
 	private static TestingCluster testCluster;
@@ -199,9 +201,26 @@ public class ClassLoaderITCase extends TestLogger {
 					});
 
 			userCodeTypeProg.invokeInteractiveModeForExecution();
+
+			File checkpointDir = FOLDER.newFolder();
+			File outputDir = FOLDER.newFolder();
+
+			final PackagedProgram program = new PackagedProgram(
+					new File(CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH),
+					new String[] {
+							CHECKPOINTING_CUSTOM_KV_STATE_JAR_PATH,
+							"localhost",
+							String.valueOf(port),
+							checkpointDir.toURI().toString(),
+							outputDir.toURI().toString()
+					});
+
+			program.invokeInteractiveModeForExecution();
+
 		} catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
+			if (!(e.getCause().getCause() instanceof SuccessException)) {
+				fail(e.getMessage());
+			}
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/3b8fe95e/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
new file mode 100644
index 0000000..6796cb0
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/classloading/jar/CheckpointingCustomKvStateProgram.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.classloading.jar;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ReducingState;
+import org.apache.flink.api.common.state.ReducingStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.state.CheckpointListener;
+import org.apache.flink.runtime.state.filesystem.FsStateBackend;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
+import org.apache.flink.test.util.SuccessException;
+import org.apache.flink.util.Collector;
+
+import java.io.IOException;
+import java.util.concurrent.ThreadLocalRandom;
+
+public class CheckpointingCustomKvStateProgram {
+	// Test job using keyed ReducingState with a custom serializer: fails after the first completed checkpoint, then throws SuccessException after restoring.
+	public static void main(String[] args) throws Exception {
+		final String jarFile = args[0];
+		final String host = args[1];
+		final int port = Integer.parseInt(args[2]);
+		final String checkpointPath = args[3];
+		final String outputPath = args[4];
+		final int parallelism = 1;
+		// Arguments (see ClassLoaderITCase): jar file, host, port, checkpoint path, output path.
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.createRemoteEnvironment(host, port, jarFile);
+		// Checkpoint every 100 ms and allow a single restart (1000 ms delay) to recover from the intended failure.
+		env.setParallelism(parallelism);
+		env.getConfig().disableSysoutLogging();
+		env.enableCheckpointing(100);
+		env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 1000));
+		env.setStateBackend(new FsStateBackend(checkpointPath));
+		// The map assigns a random key in [0, parallelism), i.e. always 0 with parallelism 1.
+		DataStream<Integer> source = env.addSource(new InfiniteIntegerSource());
+		source
+				.map(new MapFunction<Integer, Tuple2<Integer, Integer>>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public Tuple2<Integer, Integer> map(Integer value) throws Exception {
+						return new Tuple2<>(ThreadLocalRandom.current().nextInt(parallelism), value);
+					}
+				})
+				.keyBy(new KeySelector<Tuple2<Integer,Integer>, Integer>() {
+					private static final long serialVersionUID = 1L;
+
+					@Override
+					public Integer getKey(Tuple2<Integer, Integer> value) throws Exception {
+						return value.f0;
+					}
+				}).flatMap(new ReducingStateFlatMap()).writeAsText(outputPath, FileSystem.WriteMode.OVERWRITE);
+
+		env.execute();
+	}
+	/** Parallel source emitting an increasing counter until cancelled; its checkpointed state is a constant dummy value. */
+	private static class InfiniteIntegerSource implements ParallelSourceFunction<Integer>, Checkpointed<Integer> {
+		private static final long serialVersionUID = -7517574288730066280L;
+		private volatile boolean running = true;
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			int counter = 0;
+			while (running) {
+				synchronized (ctx.getCheckpointLock()) {
+					ctx.collect(counter++);
+				}
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return 0;
+		}
+
+		@Override
+		public void restoreState(Integer state) throws Exception {
+			// Nothing to restore: snapshotState() always returns 0.
+		}
+	}
+	/** Adds values to keyed ReducingState; after a completed checkpoint it fails once, and after a restore it throws SuccessException. */
+	private static class ReducingStateFlatMap extends RichFlatMapFunction<Tuple2<Integer, Integer>, Integer> implements Checkpointed<ReducingStateFlatMap>, CheckpointListener {
+
+		private static final long serialVersionUID = -5939722892793950253L;
+		private transient ReducingState<Integer> kvState;
+		// Set by notifyCheckpointComplete() and restoreState(); evaluated in flatMap().
+		private boolean atLeastOneSnapshotComplete = false;
+		private boolean restored = false;
+
+		@Override
+		public void open(Configuration parameters) throws Exception {
+			ReducingStateDescriptor<Integer> stateDescriptor =
+					new ReducingStateDescriptor<>(
+							"reducing-state",
+							new ReduceSum(),
+							CustomIntSerializer.INSTANCE);
+
+			this.kvState = getRuntimeContext().getReducingState(stateDescriptor);
+		}
+
+		// Fail only once a checkpoint has completed: first to force a restore, then (after restoring) to report success.
+		@Override
+		public void flatMap(Tuple2<Integer, Integer> value, Collector<Integer> out) throws Exception {
+			kvState.add(value.f1);
+
+			if(atLeastOneSnapshotComplete) {
+				if (restored) {
+					throw new SuccessException();
+				} else {
+					throw new RuntimeException("Intended failure, to trigger restore");
+				}
+			}
+		}
+		// Returns the function object itself as its checkpointed state (kvState is a transient field).
+		@Override
+		public ReducingStateFlatMap snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return this;
+		}
+		// The restored instance is not used; being restored at all is what this test checks for.
+		@Override
+		public void restoreState(ReducingStateFlatMap state) throws Exception {
+			restored = true;
+			atLeastOneSnapshotComplete = true;
+		}
+
+		@Override
+		public void notifyCheckpointComplete(long checkpointId) throws Exception {
+			atLeastOneSnapshotComplete = true;
+		}
+		/** Sums two Integer values. */
+		private static class ReduceSum implements ReduceFunction<Integer> {
+			private static final long serialVersionUID = 1L;
+
+			@Override
+			public Integer reduce(Integer value1, Integer value2) throws Exception {
+				return value1 + value2;
+			}
+		}
+	}
+	/** User-defined Integer serializer; it is written into keyed-state snapshots, so restoring them must load this class via the user code class loader. */
+	private static final class CustomIntSerializer extends TypeSerializerSingleton<Integer> {
+
+		private static final long serialVersionUID = 4572452915892737448L;
+
+		public static final TypeSerializer<Integer> INSTANCE = new CustomIntSerializer();
+
+		@Override
+		public boolean isImmutableType() {
+			return true;
+		}
+
+		@Override
+		public Integer createInstance() {
+			return 0;
+		}
+
+		@Override
+		public Integer copy(Integer from) {
+			return from;
+		}
+
+		@Override
+		public Integer copy(Integer from, Integer reuse) {
+			return from;
+		}
+		// Every value is written as a single 4-byte int.
+		@Override
+		public int getLength() {
+			return 4;
+		}
+
+		@Override
+		public void serialize(Integer record, DataOutputView target) throws IOException {
+			target.writeInt(record.intValue());
+		}
+
+		@Override
+		public Integer deserialize(DataInputView source) throws IOException {
+			return Integer.valueOf(source.readInt());
+		}
+
+		@Override
+		public Integer deserialize(Integer reuse, DataInputView source) throws IOException {
+			return Integer.valueOf(source.readInt());
+		}
+
+		@Override
+		public void copy(DataInputView source, DataOutputView target) throws IOException {
+			target.writeInt(source.readInt());
+		}
+
+		@Override
+		public boolean canEqual(Object obj) {
+			return obj instanceof CustomIntSerializer;
+		}
+
+	}
+}