Posted to commits@flink.apache.org by tr...@apache.org on 2019/01/25 21:41:40 UTC

[flink] branch master updated (4c4ce45 -> bced96a)

This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 4c4ce45  [FLINK-11008][state] Parallelize file upload for RocksDB incremental snapshots.
     new c95e9f6  [FLINK-11390][tests] Port testTaskManagerFailure to new codebase.
     new bced96a  [FLINK-11171] Avoid concurrent usage of StateSnapshotTransformer

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../client/program/rest/RestClusterClient.java     |  21 +-
 .../org/apache/flink/runtime/rest/RestClient.java  |   9 +
 .../job/metrics/JobMetricsMessageParameters.java   |   2 +-
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  27 +-
 .../runtime/state/StateSnapshotTransformer.java    |  90 +-----
 ...sformer.java => StateSnapshotTransformers.java} | 111 +++----
 .../state/heap/CopyOnWriteStateTableSnapshot.java  |   7 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  31 +-
 .../runtime/state/heap/NestedMapsStateTable.java   |   9 +-
 .../state/ttl/TtlStateSnapshotTransformer.java     |   2 +-
 .../flink/runtime/state/StateBackendTestBase.java  |  16 +
 .../state/StateSnapshotTransformerTest.java        | 305 ++++++++++++++++++
 .../state/ttl/mock/MockKeyedStateBackend.java      |   7 +-
 .../runtime/state/ttl/mock/MockStateBackend.java   |  25 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  58 +---
 .../RocksDBSnapshotTransformFactoryAdaptor.java    | 105 +++++++
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  49 ++-
 .../flink/yarn/YARNHighAvailabilityITCase.java     | 232 +++++++++++---
 .../yarn/YARNSessionCapacitySchedulerITCase.java   | 345 +++++++++++----------
 .../java/org/apache/flink/yarn/YarnTestBase.java   |  20 +-
 .../org/apache/flink/yarn/testjob/YarnTestJob.java |  56 +++-
 21 files changed, 1053 insertions(+), 474 deletions(-)
 copy flink-runtime/src/main/java/org/apache/flink/runtime/state/{StateSnapshotTransformer.java => StateSnapshotTransformers.java} (56%)
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java
 create mode 100644 flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java


[flink] 01/02: [FLINK-11390][tests] Port testTaskManagerFailure to new codebase.

Posted by tr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c95e9f642288bb2816cf84868709ea2543a90ae5
Author: Gary Yao <ga...@data-artisans.com>
AuthorDate: Fri Jan 18 22:32:22 2019 +0100

    [FLINK-11390][tests] Port testTaskManagerFailure to new codebase.
    
    Port YARNSessionCapacitySchedulerITCase#testTaskManagerFailure to flip6 codebase:
    * Remove assertions that rely on log messages only
    * Move part where TMs are killed to YARNHighAvailabilityITCase
    * Rename test to a proper name that describes what it does
    * Add Javadoc explaining what this test does
    
    [FLINK-11390][tests] Move comment to right position
    
    [FLINK-11390][tests] Reuse YarnClient from super class
    
    Move waitUntilCondition to YarnTestBase
    
    [FLINK-11390][tests] Extract method parse hostname
    
    Extract method getOnlyApplicationReport
    
    Extract method submitJob
    
    Extract method getNumberOfSlotsPerTaskManager
    
    Extract method getFlinkConfigFromRestApi
    
    Delete useless comment
    
    Rename: runner -> yarnSessionClusterRunner
    
    Delete useless sleep & refactor
    
    Reorder methods and add static keyword where possible
    
    This closes #7546.
---
 .../client/program/rest/RestClusterClient.java     |  21 +-
 .../org/apache/flink/runtime/rest/RestClient.java  |   9 +
 .../job/metrics/JobMetricsMessageParameters.java   |   2 +-
 .../flink/yarn/YARNHighAvailabilityITCase.java     | 232 +++++++++++---
 .../yarn/YARNSessionCapacitySchedulerITCase.java   | 345 +++++++++++----------
 .../java/org/apache/flink/yarn/YarnTestBase.java   |  20 +-
 .../org/apache/flink/yarn/testjob/YarnTestJob.java |  56 +++-
 7 files changed, 467 insertions(+), 218 deletions(-)
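
For readers following the refactoring, here is a minimal, self-contained sketch
of the polling pattern behind the waitUntilCondition helper that this change
moves into YarnTestBase. It uses plain-Java stand-ins; only the loop shape and
the timeout semantics come from the diff below, while the class name and the
example condition are made up for illustration:

    import java.util.concurrent.TimeoutException;
    import java.util.function.BooleanSupplier;

    final class PollingExample {

        /** Polls the condition every ~100ms until it holds or the deadline passes. */
        static void waitUntilCondition(BooleanSupplier condition, long timeoutMillis) throws Exception {
            final long deadline = System.currentTimeMillis() + timeoutMillis;
            while (System.currentTimeMillis() < deadline && !condition.getAsBoolean()) {
                Thread.sleep(Math.min(100L, Math.max(1L, deadline - System.currentTimeMillis())));
            }
            if (!condition.getAsBoolean()) {
                throw new TimeoutException("Condition was not met in given timeout.");
            }
        }

        public static void main(String[] args) throws Exception {
            final long start = System.currentTimeMillis();
            // Example condition: becomes true after roughly 300ms.
            waitUntilCondition(() -> System.currentTimeMillis() - start > 300L, 5_000L);
            System.out.println("condition met");
        }
    }

Helpers such as waitForApplicationAttempt and waitUntilJobIsRunning in the diff
below wrap exactly this kind of predicate (a YARN or REST query) around the
shared helper instead of the previous ad-hoc sleep loops.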

diff --git a/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java b/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
index c6dc37e..eea83c6 100644
--- a/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
+++ b/flink-clients/src/main/java/org/apache/flink/client/program/rest/RestClusterClient.java
@@ -272,17 +272,25 @@ public class RestClusterClient<T> extends ClusterClient<T> implements NewCluster
 		}
 	}
 
-	@Override
-	public CompletableFuture<JobStatus> getJobStatus(JobID jobId) {
-		JobDetailsHeaders detailsHeaders = JobDetailsHeaders.getInstance();
+	/**
+	 * Requests the job details.
+	 *
+	 * @param jobId The job id
+	 * @return Job details
+	 */
+	public CompletableFuture<JobDetailsInfo> getJobDetails(JobID jobId) {
+		final JobDetailsHeaders detailsHeaders = JobDetailsHeaders.getInstance();
 		final JobMessageParameters  params = new JobMessageParameters();
 		params.jobPathParameter.resolve(jobId);
 
-		CompletableFuture<JobDetailsInfo> responseFuture = sendRequest(
+		return sendRequest(
 			detailsHeaders,
 			params);
+	}
 
-		return responseFuture.thenApply(JobDetailsInfo::getJobStatus);
+	@Override
+	public CompletableFuture<JobStatus> getJobStatus(JobID jobId) {
+		return getJobDetails(jobId).thenApply(JobDetailsInfo::getJobStatus);
 	}
 
 	/**
@@ -694,7 +702,8 @@ public class RestClusterClient<T> extends ClusterClient<T> implements NewCluster
 		return sendRequest(messageHeaders, EmptyMessageParameters.getInstance(), EmptyRequestBody.getInstance());
 	}
 
-	private <M extends MessageHeaders<R, P, U>, U extends MessageParameters, R extends RequestBody, P extends ResponseBody> CompletableFuture<P>
+	@VisibleForTesting
+	public <M extends MessageHeaders<R, P, U>, U extends MessageParameters, R extends RequestBody, P extends ResponseBody> CompletableFuture<P>
 			sendRequest(M messageHeaders, U messageParameters, R request) {
 		return sendRetriableRequest(
 			messageHeaders, messageParameters, request, isConnectionProblemOrServiceUnavailable());
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
index a25b13c..c478cf5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/RestClient.java
@@ -23,6 +23,8 @@ import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.RestOptions;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.io.network.netty.SSLHandlerFactory;
+import org.apache.flink.runtime.rest.messages.EmptyMessageParameters;
+import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
 import org.apache.flink.runtime.rest.messages.ErrorResponseBody;
 import org.apache.flink.runtime.rest.messages.MessageHeaders;
 import org.apache.flink.runtime.rest.messages.MessageParameters;
@@ -185,6 +187,13 @@ public class RestClient implements AutoCloseableAsync {
 		return terminationFuture;
 	}
 
+	public <M extends MessageHeaders<EmptyRequestBody, P, EmptyMessageParameters>, P extends ResponseBody> CompletableFuture<P> sendRequest(
+			String targetAddress,
+			int targetPort,
+			M messageHeaders) throws IOException {
+		return sendRequest(targetAddress, targetPort, messageHeaders, EmptyMessageParameters.getInstance(), EmptyRequestBody.getInstance());
+	}
+
 	public <M extends MessageHeaders<R, P, U>, U extends MessageParameters, R extends RequestBody, P extends ResponseBody> CompletableFuture<P> sendRequest(
 			String targetAddress,
 			int targetPort,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/job/metrics/JobMetricsMessageParameters.java b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/job/metrics/JobMetricsMessageParameters.java
index f8bab83..3f6f8af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/job/metrics/JobMetricsMessageParameters.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/rest/messages/job/metrics/JobMetricsMessageParameters.java
@@ -31,7 +31,7 @@ import java.util.Collections;
  */
 public class JobMetricsMessageParameters extends JobMessageParameters {
 
-	private final MetricsFilterParameter metricsFilterParameter = new MetricsFilterParameter();
+	public final MetricsFilterParameter metricsFilterParameter = new MetricsFilterParameter();
 
 	@Override
 	public Collection<MessageQueryParameter<?>> getQueryParameters() {
diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNHighAvailabilityITCase.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNHighAvailabilityITCase.java
index 1a2eb92..de7e02a 100644
--- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNHighAvailabilityITCase.java
+++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNHighAvailabilityITCase.java
@@ -25,28 +25,43 @@ import org.apache.flink.client.deployment.ClusterDeploymentException;
 import org.apache.flink.client.deployment.ClusterSpecification;
 import org.apache.flink.client.program.ClusterClient;
 import org.apache.flink.client.program.rest.RestClusterClient;
+import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.ResourceManagerOptions;
-import org.apache.flink.runtime.client.JobStatusMessage;
+import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.jobgraph.JobGraph;
-import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobmaster.JobResult;
+import org.apache.flink.runtime.rest.messages.EmptyRequestBody;
+import org.apache.flink.runtime.rest.messages.job.JobDetailsInfo;
+import org.apache.flink.runtime.rest.messages.job.metrics.JobMetricsHeaders;
+import org.apache.flink.runtime.rest.messages.job.metrics.JobMetricsMessageParameters;
+import org.apache.flink.runtime.rest.messages.job.metrics.Metric;
 import org.apache.flink.util.OperatingSystem;
-import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.yarn.configuration.YarnConfigOptions;
 import org.apache.flink.yarn.entrypoint.YarnSessionClusterEntrypoint;
 import org.apache.flink.yarn.testjob.YarnTestJob;
 import org.apache.flink.yarn.util.YarnTestUtils;
 
+import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
+
+import org.apache.commons.lang3.StringUtils;
 import org.apache.curator.test.TestingServer;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.yarn.api.protocolrecords.StopContainersRequest;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ApplicationReport;
+import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.YarnApplicationState;
 import org.apache.hadoop.yarn.client.api.YarnClient;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.security.NMTokenIdentifier;
+import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
+import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
 import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler;
 import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacityScheduler;
 import org.junit.AfterClass;
-import org.junit.Assert;
+import org.junit.Before;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
 import org.junit.Test;
@@ -58,13 +73,21 @@ import java.io.File;
 import java.io.IOException;
 import java.time.Duration;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.EnumSet;
+import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.TimeoutException;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.function.Predicate;
 
+import static org.apache.flink.util.Preconditions.checkState;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assume.assumeTrue;
 
@@ -78,11 +101,13 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 
 	private static final String LOG_DIR = "flink-yarn-tests-ha";
 	private static final Duration TIMEOUT = Duration.ofSeconds(200L);
-	private static final long RETRY_TIMEOUT = 100L;
 
 	private static TestingServer zkServer;
 	private static String storageDir;
 
+	private YarnTestJob.StopJobSignal stopJobSignal;
+	private JobGraph job;
+
 	@BeforeClass
 	public static void setup() throws Exception {
 		zkServer = new TestingServer();
@@ -104,6 +129,22 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 		}
 	}
 
+	@Before
+	public void setUp() throws Exception {
+		initJobGraph();
+	}
+
+	private void initJobGraph() throws IOException {
+		stopJobSignal = YarnTestJob.StopJobSignal.usingMarkerFile(FOLDER.newFile().toPath());
+		job = YarnTestJob.stoppableJob(stopJobSignal);
+		final File testingJar =
+			YarnTestBase.findFile("..", new YarnTestUtils.TestJarFinder("flink-yarn-tests"));
+
+		assertThat(testingJar, notNullValue());
+
+		job.addJar(new org.apache.flink.core.fs.Path(testingJar.toURI()));
+	}
+
 	/**
 	 * Tests that Yarn will restart a killed {@link YarnSessionClusterEntrypoint} which will then resume
 	 * a persisted {@link JobGraph}.
@@ -115,33 +156,107 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 			OperatingSystem.isLinux() || OperatingSystem.isMac() || OperatingSystem.isFreeBSD() || OperatingSystem.isSolaris());
 
 		final YarnClusterDescriptor yarnClusterDescriptor = setupYarnClusterDescriptor();
-
 		final RestClusterClient<ApplicationId> restClusterClient = deploySessionCluster(yarnClusterDescriptor);
 
-		final JobGraph job = createJobGraph();
+		try {
+			final JobID jobId = submitJob(restClusterClient);
+			final ApplicationId id = restClusterClient.getClusterId();
+
+			waitUntilJobIsRunning(restClusterClient, jobId);
 
-		final JobID jobId = submitJob(restClusterClient, job);
+			killApplicationMaster(yarnClusterDescriptor.getYarnSessionClusterEntrypoint());
+			waitForApplicationAttempt(id, 2);
 
-		final ApplicationId id = restClusterClient.getClusterId();
+			waitForJobTermination(restClusterClient, jobId);
 
-		waitUntilJobIsRunning(restClusterClient, jobId, RETRY_TIMEOUT);
+			killApplicationAndWait(id);
+		} finally {
+			restClusterClient.shutdown();
+		}
+	}
+
+	@Test
+	public void testJobRecoversAfterKillingTaskManager() throws Exception {
+		final YarnClusterDescriptor yarnClusterDescriptor = setupYarnClusterDescriptor();
+		final RestClusterClient<ApplicationId> restClusterClient = deploySessionCluster(yarnClusterDescriptor);
+		try {
+			final JobID jobId = submitJob(restClusterClient);
+			waitUntilJobIsRunning(restClusterClient, jobId);
+
+			stopTaskManagerContainer();
+			waitUntilJobIsRestarted(restClusterClient, jobId, 1);
 
-		killApplicationMaster(yarnClusterDescriptor.getYarnSessionClusterEntrypoint());
+			waitForJobTermination(restClusterClient, jobId);
 
+			killApplicationAndWait(restClusterClient.getClusterId());
+		} finally {
+			restClusterClient.shutdown();
+		}
+	}
+
+	private void waitForApplicationAttempt(final ApplicationId applicationId, final int attemptId) throws Exception {
 		final YarnClient yarnClient = getYarnClient();
-		Assert.assertNotNull(yarnClient);
+		checkState(yarnClient != null, "yarnClient must be initialized");
 
-		while (yarnClient.getApplicationReport(id).getCurrentApplicationAttemptId().getAttemptId() < 2) {
-			Thread.sleep(RETRY_TIMEOUT);
+		waitUntilCondition(() -> {
+			final ApplicationReport applicationReport = yarnClient.getApplicationReport(applicationId);
+			return applicationReport.getCurrentApplicationAttemptId().getAttemptId() >= attemptId;
+		}, Deadline.fromNow(TIMEOUT));
+	}
+
+	/**
+	 * Stops a container running {@link YarnTaskExecutorRunner}.
+	 */
+	private void stopTaskManagerContainer() throws Exception {
+		// find container id of taskManager:
+		ContainerId taskManagerContainer = null;
+		NodeManager nodeManager = null;
+		NMTokenIdentifier nmIdent = null;
+		UserGroupInformation remoteUgi = UserGroupInformation.getCurrentUser();
+
+		for (int nmId = 0; nmId < NUM_NODEMANAGERS; nmId++) {
+			NodeManager nm = yarnCluster.getNodeManager(nmId);
+			ConcurrentMap<ContainerId, Container> containers = nm.getNMContext().getContainers();
+			for (Map.Entry<ContainerId, Container> entry : containers.entrySet()) {
+				String command = StringUtils.join(entry.getValue().getLaunchContext().getCommands(), " ");
+				if (command.contains(YarnTaskExecutorRunner.class.getSimpleName())) {
+					taskManagerContainer = entry.getKey();
+					nodeManager = nm;
+					nmIdent = new NMTokenIdentifier(taskManagerContainer.getApplicationAttemptId(), null, "", 0);
+					// allow myself to do stuff with the container
+					// remoteUgi.addCredentials(entry.getValue().getCredentials());
+					remoteUgi.addTokenIdentifier(nmIdent);
+				}
+			}
 		}
 
-		waitUntilJobIsRunning(restClusterClient, jobId, RETRY_TIMEOUT);
+		assertNotNull("Unable to find container with TaskManager", taskManagerContainer);
+		assertNotNull("Illegal state", nodeManager);
+
+		StopContainersRequest scr = StopContainersRequest.newInstance(Collections.singletonList(taskManagerContainer));
+
+		nodeManager.getNMContext().getContainerManager().stopContainers(scr);
+
+		// cleanup auth for the subsequent tests.
+		remoteUgi.getTokenIdentifiers().remove(nmIdent);
+	}
+
+	private void killApplicationAndWait(final ApplicationId id) throws Exception {
+		final YarnClient yarnClient = getYarnClient();
+		checkState(yarnClient != null, "yarnClient must be initialized");
 
 		yarnClient.killApplication(id);
 
-		while (yarnClient.getApplications(EnumSet.of(YarnApplicationState.KILLED, YarnApplicationState.FINISHED)).isEmpty()) {
-			Thread.sleep(RETRY_TIMEOUT);
-		}
+		waitUntilCondition(() -> !yarnClient.getApplications(EnumSet.of(YarnApplicationState.KILLED, YarnApplicationState.FINISHED)).isEmpty(),
+			Deadline.fromNow(TIMEOUT));
+	}
+
+	private void waitForJobTermination(
+			final RestClusterClient<ApplicationId> restClusterClient,
+			final JobID jobId) throws Exception {
+		stopJobSignal.signal();
+		final CompletableFuture<JobResult> jobResult = restClusterClient.requestJobResult(jobId);
+		jobResult.get(TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
 	}
 
 	@Nonnull
@@ -153,6 +268,9 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 		flinkConfiguration.setString(HighAvailabilityOptions.HA_ZOOKEEPER_QUORUM, zkServer.getConnectString());
 		flinkConfiguration.setInteger(HighAvailabilityOptions.ZOOKEEPER_SESSION_TIMEOUT, 1000);
 
+		flinkConfiguration.setString(ConfigConstants.RESTART_STRATEGY, "fixed-delay");
+		flinkConfiguration.setInteger(ConfigConstants.RESTART_STRATEGY_FIXED_DELAY_ATTEMPTS, Integer.MAX_VALUE);
+
 		final int minMemory = 100;
 		flinkConfiguration.setInteger(ResourceManagerOptions.CONTAINERIZED_HEAP_CUTOFF_MIN, minMemory);
 
@@ -172,8 +290,9 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 		return (RestClusterClient<ApplicationId>) yarnClusterClient;
 	}
 
-	private JobID submitJob(RestClusterClient<ApplicationId> restClusterClient, JobGraph job) throws InterruptedException, java.util.concurrent.ExecutionException {
-		final CompletableFuture<JobSubmissionResult> jobSubmissionResultCompletableFuture = restClusterClient.submitJob(job);
+	private JobID submitJob(RestClusterClient<ApplicationId> restClusterClient) throws InterruptedException, java.util.concurrent.ExecutionException {
+		final CompletableFuture<JobSubmissionResult> jobSubmissionResultCompletableFuture =
+			restClusterClient.submitJob(job);
 
 		final JobSubmissionResult jobSubmissionResult = jobSubmissionResultCompletableFuture.get();
 		return jobSubmissionResult.getJobID();
@@ -184,38 +303,61 @@ public class YARNHighAvailabilityITCase extends YarnTestBase {
 		assertThat(exec.waitFor(), is(0));
 	}
 
-	@Nonnull
-	private JobGraph createJobGraph() {
-		final JobGraph job = YarnTestJob.createJob();
-		final File testingJar =
-			YarnTestBase.findFile("..", new YarnTestUtils.TestJarFinder("flink-yarn-tests"));
+	private static void waitUntilJobIsRunning(RestClusterClient<ApplicationId> restClusterClient, JobID jobId) throws Exception {
+		waitUntilCondition(
+			() -> {
+				final JobDetailsInfo jobDetails = restClusterClient.getJobDetails(jobId).get();
+				return jobDetails.getJobVertexInfos()
+					.stream()
+					.map(toExecutionState())
+					.allMatch(isRunning());
+			},
+			Deadline.fromNow(TIMEOUT));
+	}
 
-		assertThat(testingJar, notNullValue());
+	private static Function<JobDetailsInfo.JobVertexDetailsInfo, ExecutionState> toExecutionState() {
+		return JobDetailsInfo.JobVertexDetailsInfo::getExecutionState;
+	}
 
-		job.addJar(new org.apache.flink.core.fs.Path(testingJar.toURI()));
-		return job;
+	private static Predicate<ExecutionState> isRunning() {
+		return executionState -> executionState == ExecutionState.RUNNING;
 	}
 
-	private void waitUntilJobIsRunning(RestClusterClient<ApplicationId> restClusterClient, JobID jobId, long retryTimeout) throws Exception {
+	private static void waitUntilJobIsRestarted(
+		final RestClusterClient<ApplicationId> restClusterClient,
+		final JobID jobId,
+		final int expectedFullRestarts) throws Exception {
 		waitUntilCondition(
-			() -> {
-				final Collection<JobStatusMessage> jobStatusMessages = restClusterClient.listJobs().get();
+			() -> getJobFullRestarts(restClusterClient, jobId) >= expectedFullRestarts,
+			Deadline.fromNow(TIMEOUT));
+	}
 
-				return jobStatusMessages.stream()
-					.filter(jobStatusMessage -> jobStatusMessage.getJobId().equals(jobId))
-					.anyMatch(jobStatusMessage -> jobStatusMessage.getJobState() == JobStatus.RUNNING);
-			},
-			Deadline.fromNow(TIMEOUT),
-			retryTimeout);
+	private static int getJobFullRestarts(
+		final RestClusterClient<ApplicationId> restClusterClient,
+		final JobID jobId) throws Exception {
+
+		return getJobMetric(restClusterClient, jobId, "fullRestarts")
+			.map(Metric::getValue)
+			.map(Integer::parseInt)
+			.orElse(0);
 	}
 
-	private void waitUntilCondition(SupplierWithException<Boolean, Exception> condition, Deadline timeout, long retryTimeout) throws Exception {
-		while (timeout.hasTimeLeft() && !condition.get()) {
-			Thread.sleep(Math.min(retryTimeout, timeout.timeLeft().toMillis()));
-		}
+	private static Optional<Metric> getJobMetric(
+		final RestClusterClient<ApplicationId> restClusterClient,
+		final JobID jobId,
+		final String metricName) throws Exception {
 
-		if (!timeout.hasTimeLeft()) {
-			throw new TimeoutException("Condition was not met in given timeout.");
-		}
+		final JobMetricsMessageParameters messageParameters = new JobMetricsMessageParameters();
+		messageParameters.jobPathParameter.resolve(jobId);
+		messageParameters.metricsFilterParameter.resolveFromString(metricName);
+
+		final Collection<Metric> metrics = restClusterClient.sendRequest(
+			JobMetricsHeaders.getInstance(),
+			messageParameters,
+			EmptyRequestBody.getInstance()).get().getMetrics();
+
+		final Metric metric = Iterables.getOnlyElement(metrics, null);
+		checkState(metric == null || metric.getId().equals(metricName));
+		return Optional.ofNullable(metric);
 	}
 }
diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionCapacitySchedulerITCase.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionCapacitySchedulerITCase.java
index 83abd8a..93b4c9d 100644
--- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionCapacitySchedulerITCase.java
+++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YARNSessionCapacitySchedulerITCase.java
@@ -18,40 +18,43 @@
 
 package org.apache.flink.yarn;
 
+import org.apache.flink.api.common.time.Deadline;
+import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.GlobalConfiguration;
 import org.apache.flink.configuration.JobManagerOptions;
 import org.apache.flink.configuration.ResourceManagerOptions;
 import org.apache.flink.runtime.client.JobClient;
+import org.apache.flink.runtime.rest.RestClient;
+import org.apache.flink.runtime.rest.RestClientConfiguration;
+import org.apache.flink.runtime.rest.handler.legacy.messages.ClusterOverviewWithVersion;
+import org.apache.flink.runtime.rest.messages.ClusterConfigurationInfo;
+import org.apache.flink.runtime.rest.messages.ClusterConfigurationInfoEntry;
+import org.apache.flink.runtime.rest.messages.ClusterConfigurationInfoHeaders;
+import org.apache.flink.runtime.rest.messages.ClusterOverviewHeaders;
+import org.apache.flink.runtime.rest.messages.taskmanager.TaskManagerInfo;
+import org.apache.flink.runtime.rest.messages.taskmanager.TaskManagersHeaders;
+import org.apache.flink.runtime.rest.messages.taskmanager.TaskManagersInfo;
 import org.apache.flink.runtime.taskexecutor.TaskManagerServices;
-import org.apache.flink.runtime.webmonitor.WebMonitorUtils;
 import org.apache.flink.test.testdata.WordCountData;
-import org.apache.flink.test.util.TestBaseUtils;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.yarn.cli.FlinkYarnSessionCli;
 import org.apache.flink.yarn.configuration.YarnConfigOptions;
 
-import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode;
-import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
-import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.ArrayNode;
+import org.apache.flink.shaded.guava18.com.google.common.net.HostAndPort;
 
 import org.apache.commons.io.FileUtils;
-import org.apache.commons.lang3.StringUtils;
-import org.apache.hadoop.security.UserGroupInformation;
-import org.apache.hadoop.yarn.api.protocolrecords.StopContainersRequest;
 import org.apache.hadoop.yarn.api.records.ApplicationId;
 import org.apache.hadoop.yarn.api.records.ApplicationReport;
-import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.api.records.YarnApplicationState;
 import org.apache.hadoop.yarn.client.api.YarnClient;
 import org.apache.hadoop.yarn.conf.YarnConfiguration;
-import org.apache.hadoop.yarn.security.NMTokenIdentifier;
-import org.apache.hadoop.yarn.server.nodemanager.NodeManager;
-import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
+import org.apache.hadoop.yarn.exceptions.YarnException;
 import org.apache.hadoop.yarn.server.resourcemanager.scheduler.ResourceScheduler;
 import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacityScheduler;
 import org.apache.log4j.Level;
 import org.junit.After;
+import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -63,23 +66,31 @@ import java.io.FilenameFilter;
 import java.io.IOException;
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
+import java.time.Duration;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.EnumSet;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
+import java.util.stream.Collectors;
 
 import static junit.framework.TestCase.assertTrue;
+import static org.apache.flink.util.Preconditions.checkState;
 import static org.apache.flink.yarn.UtilsTest.addTestAppender;
 import static org.apache.flink.yarn.UtilsTest.checkForLogString;
 import static org.apache.flink.yarn.util.YarnTestUtils.getTestJarPath;
-import static org.junit.Assume.assumeTrue;
+import static org.hamcrest.Matchers.hasEntry;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
 
 /**
  * This test starts a MiniYARNCluster with a CapacityScheduler.
@@ -88,14 +99,46 @@ import static org.junit.Assume.assumeTrue;
 public class YARNSessionCapacitySchedulerITCase extends YarnTestBase {
 	private static final Logger LOG = LoggerFactory.getLogger(YARNSessionCapacitySchedulerITCase.class);
 
+	/**
+	 * RestClient to query Flink cluster.
+	 */
+	private static RestClient restClient;
+
+	/**
+	 * ExecutorService for {@link RestClient}.
+	 * @see #restClient
+	 */
+	private static ExecutorService restClientExecutor;
+
+	/** Toggles checking for prohibited strings in logs after the test has run. */
+	private boolean checkForProhibitedLogContents = true;
+
 	@BeforeClass
-	public static void setup() {
+	public static void setup() throws Exception {
 		YARN_CONFIGURATION.setClass(YarnConfiguration.RM_SCHEDULER, CapacityScheduler.class, ResourceScheduler.class);
 		YARN_CONFIGURATION.set("yarn.scheduler.capacity.root.queues", "default,qa-team");
 		YARN_CONFIGURATION.setInt("yarn.scheduler.capacity.root.default.capacity", 40);
 		YARN_CONFIGURATION.setInt("yarn.scheduler.capacity.root.qa-team.capacity", 60);
 		YARN_CONFIGURATION.set(YarnTestBase.TEST_CLUSTER_NAME_KEY, "flink-yarn-tests-capacityscheduler");
 		startYARNWithConfig(YARN_CONFIGURATION);
+
+		restClientExecutor = Executors.newSingleThreadExecutor();
+		restClient = new RestClient(RestClientConfiguration.fromConfiguration(new Configuration()), restClientExecutor);
+	}
+
+	@AfterClass
+	public static void tearDown() throws Exception {
+		try {
+			YarnTestBase.teardown();
+		} finally {
+			if (restClient != null) {
+				restClient.shutdown(Time.seconds(5));
+			}
+
+			if (restClientExecutor != null) {
+				restClientExecutor.shutdownNow();
+			}
+		}
 	}
 
 	/**
@@ -184,14 +227,23 @@ public class YARNSessionCapacitySchedulerITCase extends YarnTestBase {
 	}
 
 	/**
-	 * Test TaskManager failure and also if the vcores are set correctly (see issue FLINK-2213).
+	 * Starts a session cluster on YARN, and submits a streaming job.
+	 *
+	 * <p>Tests
+	 * <ul>
+	 * <li>if a custom YARN application name can be set from the command line,
+	 * <li>if the number of TaskManager slots can be set from the command line,
+	 * <li>if dynamic properties from the command line are set,
+	 * <li>if the vcores are set correctly (FLINK-2213),
+	 * <li>if jobmanager hostname/port are shown in web interface (FLINK-1902)
+	 * </ul>
+	 *
+	 * <p><b>Hint: </b> If you think it is a good idea to add more assertions to this test, think again!
 	 */
-	@Test(timeout = 100000) // timeout after 100 seconds
-	public void testTaskManagerFailure() throws Exception {
-		assumeTrue("The new mode does not start TMs upfront.", !isNewMode);
-		LOG.info("Starting testTaskManagerFailure()");
-		Runner runner = startWithArgs(new String[]{"-j", flinkUberjar.getAbsolutePath(), "-t", flinkLibFolder.getAbsolutePath(),
-				"-n", "1",
+	@Test(timeout = 100_000)
+	public void testVCoresAreSetCorrectlyAndJobManagerHostnameAreShownInWebInterfaceAndDynamicPropertiesAndYarnApplicationNameAndTaskManagerSlots() throws Exception {
+		checkForProhibitedLogContents = false;
+		final Runner yarnSessionClusterRunner = startWithArgs(new String[]{"-j", flinkUberjar.getAbsolutePath(), "-t", flinkLibFolder.getAbsolutePath(),
 				"-jm", "768m",
 				"-tm", "1024m",
 				"-s", "3", // set the slots 3 to check if the vCores are set properly!
@@ -199,163 +251,134 @@ public class YARNSessionCapacitySchedulerITCase extends YarnTestBase {
 				"-Dfancy-configuration-value=veryFancy",
 				"-Dyarn.maximum-failed-containers=3",
 				"-D" + YarnConfigOptions.VCORES.key() + "=2"},
-			"Number of connected TaskManagers changed to 1. Slots available: 3",
+			"Flink JobManager is now running on ",
 			RunTypes.YARN_SESSION);
 
-		Assert.assertEquals(2, getRunningContainers());
-
-		// ------------------------ Test if JobManager web interface is accessible -------
-
-		final YarnClient yc = YarnClient.createYarnClient();
-		yc.init(YARN_CONFIGURATION);
-		yc.start();
-
-		List<ApplicationReport> apps = yc.getApplications(EnumSet.of(YarnApplicationState.RUNNING));
-		Assert.assertEquals(1, apps.size()); // Only one running
-		ApplicationReport app = apps.get(0);
-		Assert.assertEquals("customName", app.getName());
-		String url = app.getTrackingUrl();
-		if (!url.endsWith("/")) {
-			url += "/";
-		}
-		if (!url.startsWith("http://")) {
-			url = "http://" + url;
-		}
-		LOG.info("Got application URL from YARN {}", url);
-
-		String response = TestBaseUtils.getFromHTTP(url + "taskmanagers/");
-
-		JsonNode parsedTMs = new ObjectMapper().readTree(response);
-		ArrayNode taskManagers = (ArrayNode) parsedTMs.get("taskmanagers");
-		Assert.assertNotNull(taskManagers);
-		Assert.assertEquals(1, taskManagers.size());
-		Assert.assertEquals(3, taskManagers.get(0).get("slotsNumber").asInt());
-
-		// get the configuration from webinterface & check if the dynamic properties from YARN show up there.
-		String jsonConfig = TestBaseUtils.getFromHTTP(url + "jobmanager/config");
-		Map<String, String> parsedConfig = WebMonitorUtils.fromKeyValueJsonArray(jsonConfig);
-
-		Assert.assertEquals("veryFancy", parsedConfig.get("fancy-configuration-value"));
-		Assert.assertEquals("3", parsedConfig.get("yarn.maximum-failed-containers"));
-		Assert.assertEquals("2", parsedConfig.get(YarnConfigOptions.VCORES.key()));
+		final String logs = outContent.toString();
+		final HostAndPort hostAndPort = parseJobManagerHostname(logs);
+		final String host = hostAndPort.getHostText();
+		final int port = hostAndPort.getPort();
+		LOG.info("Extracted hostname:port: {}", host, port);
+
+		submitJob("WindowJoin.jar");
+
+		//
+		// Assert that custom YARN application name "customName" is set
+		//
+		final ApplicationReport applicationReport = getOnlyApplicationReport();
+		assertEquals("customName", applicationReport.getName());
+
+		//
+		// Assert the number of TaskManager slots are set
+		//
+		waitForTaskManagerRegistration(host, port, Duration.ofMillis(30_000));
+		assertNumberOfSlotsPerTask(host, port, 3);
+
+		final Map<String, String> flinkConfig = getFlinkConfig(host, port);
+
+		//
+		// Assert dynamic properties
+		//
+		assertThat(flinkConfig, hasEntry("fancy-configuration-value", "veryFancy"));
+		assertThat(flinkConfig, hasEntry("yarn.maximum-failed-containers", "3"));
+
+		//
+		// FLINK-2213: assert that vcores are set
+		//
+		assertThat(flinkConfig, hasEntry(YarnConfigOptions.VCORES.key(), "2"));
+
+		//
+		// FLINK-1902: check if jobmanager hostname is shown in web interface
+		//
+		assertThat(flinkConfig, hasEntry(JobManagerOptions.ADDRESS.key(), host));
+
+		yarnSessionClusterRunner.sendStop();
+		yarnSessionClusterRunner.join();
+	}
 
-		// -------------- FLINK-1902: check if jobmanager hostname/port are shown in web interface
-		// first, get the hostname/port
-		String oC = outContent.toString();
-		Pattern p = Pattern.compile("Flink JobManager is now running on ([a-zA-Z0-9.-]+):([0-9]+)");
-		Matcher matches = p.matcher(oC);
+	private static HostAndPort parseJobManagerHostname(final String logs) {
+		final Pattern p = Pattern.compile("Flink JobManager is now running on ([a-zA-Z0-9.-]+):([0-9]+)");
+		final Matcher matches = p.matcher(logs);
 		String hostname = null;
 		String port = null;
+
 		while (matches.find()) {
 			hostname = matches.group(1).toLowerCase();
 			port = matches.group(2);
 		}
-		LOG.info("Extracted hostname:port: {} {}", hostname, port);
-
-		Assert.assertEquals("unable to find hostname in " + jsonConfig, hostname,
-			parsedConfig.get(JobManagerOptions.ADDRESS.key()));
-		Assert.assertEquals("unable to find port in " + jsonConfig, port,
-			parsedConfig.get(JobManagerOptions.PORT.key()));
-
-		// test logfile access
-		String logs = TestBaseUtils.getFromHTTP(url + "jobmanager/log");
-		Assert.assertTrue(logs.contains("Starting YARN ApplicationMaster"));
-		Assert.assertTrue(logs.contains("Starting JobManager"));
-		Assert.assertTrue(logs.contains("Starting JobManager Web Frontend"));
-
-		// ------------------------ Kill container with TaskManager and check if vcores are set correctly -------
-
-		// find container id of taskManager:
-		ContainerId taskManagerContainer = null;
-		NodeManager nodeManager = null;
-		UserGroupInformation remoteUgi = null;
-		NMTokenIdentifier nmIdent = null;
-		try {
-			remoteUgi = UserGroupInformation.getCurrentUser();
-		} catch (IOException e) {
-			LOG.warn("Unable to get curr user", e);
-			Assert.fail();
-		}
-		for (int nmId = 0; nmId < NUM_NODEMANAGERS; nmId++) {
-			NodeManager nm = yarnCluster.getNodeManager(nmId);
-			ConcurrentMap<ContainerId, Container> containers = nm.getNMContext().getContainers();
-			for (Map.Entry<ContainerId, Container> entry : containers.entrySet()) {
-				String command = StringUtils.join(entry.getValue().getLaunchContext().getCommands(), " ");
-				if (command.contains(YarnTaskManager.class.getSimpleName())) {
-					taskManagerContainer = entry.getKey();
-					nodeManager = nm;
-					nmIdent = new NMTokenIdentifier(taskManagerContainer.getApplicationAttemptId(), null, "", 0);
-					// allow myself to do stuff with the container
-					// remoteUgi.addCredentials(entry.getValue().getCredentials());
-					remoteUgi.addTokenIdentifier(nmIdent);
-				}
-			}
-			sleep(500);
-		}
 
-		Assert.assertNotNull("Unable to find container with TaskManager", taskManagerContainer);
-		Assert.assertNotNull("Illegal state", nodeManager);
+		checkState(hostname != null, "hostname not found in log");
+		checkState(port != null, "port not found in log");
 
-		yc.stop();
+		return HostAndPort.fromParts(hostname, Integer.parseInt(port));
+	}
 
-		List<ContainerId> toStop = new LinkedList<ContainerId>();
-		toStop.add(taskManagerContainer);
-		StopContainersRequest scr = StopContainersRequest.newInstance(toStop);
+	private ApplicationReport getOnlyApplicationReport() throws IOException, YarnException {
+		final YarnClient yarnClient = getYarnClient();
+		checkState(yarnClient != null);
 
-		try {
-			nodeManager.getNMContext().getContainerManager().stopContainers(scr);
-		} catch (Throwable e) {
-			LOG.warn("Error stopping container", e);
-			Assert.fail("Error stopping container: " + e.getMessage());
-		}
+		final List<ApplicationReport> apps = yarnClient.getApplications(EnumSet.of(YarnApplicationState.RUNNING));
+		assertEquals(1, apps.size()); // Only one running
+		return apps.get(0);
+	}
 
-		// stateful termination check:
-		// wait until we saw a container being killed and AFTERWARDS a new one launched
-		boolean ok = false;
-		do {
-			LOG.debug("Waiting for correct order of events. Output: {}", errContent.toString());
-
-			String o = errContent.toString();
-			int killedOff = o.indexOf("Container killed by the ApplicationMaster");
-			if (killedOff != -1) {
-				o = o.substring(killedOff);
-				ok = o.indexOf("Launching TaskManager") > 0;
-			}
-			sleep(1000);
-		} while(!ok);
+	private void submitJob(final String jobFileName) throws IOException, InterruptedException {
+		Runner jobRunner = startWithArgs(new String[]{"run",
+				"--detached", getTestJarPath(jobFileName).getAbsolutePath()},
+			"Job has been submitted with JobID", RunTypes.CLI_FRONTEND);
+		jobRunner.join();
+	}
+
+	private static void waitForTaskManagerRegistration(
+			final String host,
+			final int port,
+			final Duration waitDuration) throws Exception {
+		waitUntilCondition(() -> getNumberOfTaskManagers(host, port) > 0, Deadline.fromNow(waitDuration));
+	}
 
-		// send "stop" command to command line interface
-		runner.sendStop();
-		// wait for the thread to stop
+	private static void assertNumberOfSlotsPerTask(
+			final String host,
+			final int port,
+			final int slotsNumber) throws Exception {
 		try {
-			runner.join();
-		} catch (InterruptedException e) {
-			LOG.warn("Interrupted while stopping runner", e);
+			waitUntilCondition(() -> getNumberOfSlotsPerTaskManager(host, port) == slotsNumber, Deadline.fromNow(Duration.ofSeconds(30)));
+		} catch (final TimeoutException e) {
+			final int currentNumberOfSlots = getNumberOfSlotsPerTaskManager(host, port);
+			fail(String.format("Expected slots per TM to be %d, was: %d", slotsNumber, currentNumberOfSlots));
 		}
-		LOG.warn("stopped");
-
-		// ----------- Send output to logger
-		System.setOut(ORIGINAL_STDOUT);
-		System.setErr(ORIGINAL_STDERR);
-		oC = outContent.toString();
-		String eC = errContent.toString();
-		LOG.info("Sending stdout content through logger: \n\n{}\n\n", oC);
-		LOG.info("Sending stderr content through logger: \n\n{}\n\n", eC);
+	}
 
-		// ------ Check if everything happened correctly
-		Assert.assertTrue("Expect to see failed container",
-			eC.contains("New messages from the YARN cluster"));
+	private static int getNumberOfTaskManagers(final String host, final int port) throws Exception {
+		final ClusterOverviewWithVersion clusterOverviewWithVersion = restClient.sendRequest(
+			host,
+			port,
+			ClusterOverviewHeaders.getInstance()).get(30_000, TimeUnit.MILLISECONDS);
 
-		Assert.assertTrue("Expect to see failed container",
-			eC.contains("Container killed by the ApplicationMaster"));
+		return clusterOverviewWithVersion.getNumTaskManagersConnected();
+	}
 
-		Assert.assertTrue("Expect to see new container started",
-			eC.contains("Launching TaskManager") && eC.contains("on host"));
+	private static int getNumberOfSlotsPerTaskManager(final String host, final int port) throws Exception {
+		final TaskManagersInfo taskManagersInfo = restClient.sendRequest(
+			host,
+			port,
+			TaskManagersHeaders.getInstance()).get();
+
+		return taskManagersInfo.getTaskManagerInfos()
+			.stream()
+			.map(TaskManagerInfo::getNumberSlots)
+			.findFirst()
+			.orElse(0);
+	}
 
-		// cleanup auth for the subsequent tests.
-		remoteUgi.getTokenIdentifiers().remove(nmIdent);
+	private static Map<String, String> getFlinkConfig(final String host, final int port) throws Exception {
+		final ClusterConfigurationInfo clusterConfigurationInfoEntries = restClient.sendRequest(
+			host,
+			port,
+			ClusterConfigurationInfoHeaders.getInstance()).get();
 
-		LOG.info("Finished testTaskManagerFailure()");
+		return clusterConfigurationInfoEntries.stream().collect(Collectors.toMap(
+			ClusterConfigurationInfoEntry::getKey,
+			ClusterConfigurationInfoEntry::getValue));
 	}
 
 	/**
@@ -620,11 +643,13 @@ public class YARNSessionCapacitySchedulerITCase extends YarnTestBase {
 		@SuppressWarnings("unchecked")
 		Set<String> applicationTags = (Set<String>) applicationTagsMethod.invoke(report);
 
-		Assert.assertEquals(Collections.singleton("test-tag"), applicationTags);
+		assertEquals(Collections.singleton("test-tag"), applicationTags);
 	}
 
 	@After
 	public void checkForProhibitedLogContents() {
-		ensureNoProhibitedStringInLogFiles(PROHIBITED_STRINGS, WHITELISTED_STRINGS);
+		if (checkForProhibitedLogContents) {
+			ensureNoProhibitedStringInLogFiles(PROHIBITED_STRINGS, WHITELISTED_STRINGS);
+		}
 	}
 }
diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java
index b7b08ae..242588f 100644
--- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java
+++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/YarnTestBase.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.clusterframework.BootstrapTools;
 import org.apache.flink.test.util.TestBaseUtils;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.yarn.cli.FlinkYarnSessionCli;
 
 import org.apache.commons.io.FileUtils;
@@ -81,6 +82,7 @@ import java.util.Scanner;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.TimeoutException;
 import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
@@ -104,6 +106,8 @@ public abstract class YarnTestBase extends TestLogger {
 
 	protected static final int NUM_NODEMANAGERS = 2;
 
+	private static final long RETRY_TIMEOUT = 100L;
+
 	/** The tests are scanning for these strings in the final output. */
 	protected static final String[] PROHIBITED_STRINGS = {
 			"Exception", // we don't want any exceptions to happen
@@ -122,7 +126,11 @@ public abstract class YarnTestBase extends TestLogger {
 		"java.io.IOException: Connection reset by peer",
 
 		// this can happen in Akka 2.4 on shutdown.
-		"java.util.concurrent.RejectedExecutionException: Worker has already been shutdown"
+		"java.util.concurrent.RejectedExecutionException: Worker has already been shutdown",
+
+		"org.apache.flink.util.FlinkException: Stopping JobMaster",
+		"org.apache.flink.util.FlinkException: JobManager is shutting down.",
+		"lost the leadership."
 	};
 
 	// Temp directory which is deleted after the unit test.
@@ -271,6 +279,16 @@ public abstract class YarnTestBase extends TestLogger {
 		return null;
 	}
 
+	protected static void waitUntilCondition(SupplierWithException<Boolean, Exception> condition, Deadline timeout) throws Exception {
+		while (timeout.hasTimeLeft() && !condition.get()) {
+			Thread.sleep(Math.min(RETRY_TIMEOUT, timeout.timeLeft().toMillis()));
+		}
+
+		if (!timeout.hasTimeLeft()) {
+			throw new TimeoutException("Condition was not met in given timeout.");
+		}
+	}
+
 	@Nonnull
 	YarnClusterDescriptor createYarnClusterDescriptor(org.apache.flink.configuration.Configuration flinkConfiguration) {
 		final YarnClusterDescriptor yarnClusterDescriptor = new YarnClusterDescriptor(
diff --git a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/testjob/YarnTestJob.java b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/testjob/YarnTestJob.java
index f86d6e4..59bad7e 100644
--- a/flink-yarn-tests/src/test/java/org/apache/flink/yarn/testjob/YarnTestJob.java
+++ b/flink-yarn-tests/src/test/java/org/apache/flink/yarn/testjob/YarnTestJob.java
@@ -23,6 +23,12 @@ import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 
+import java.io.IOException;
+import java.io.Serializable;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
 /**
  * Testing job for {@link org.apache.flink.runtime.jobmaster.JobMaster} failover.
  * Covering stream case that have a infinite source and a sink, scheduling by
@@ -30,10 +36,10 @@ import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunctio
  */
 public class YarnTestJob {
 
-	public static JobGraph createJob() {
+	public static JobGraph stoppableJob(final StopJobSignal stopJobSignal) {
 		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 
-		env.addSource(new InfiniteSourceFunction())
+		env.addSource(new InfiniteSourceFunction(stopJobSignal))
 			.setParallelism(2)
 			.shuffle()
 			.addSink(new DiscardingSink<>())
@@ -42,21 +48,61 @@ public class YarnTestJob {
 		return env.getStreamGraph().getJobGraph();
 	}
 
+	/**
+	 * Helper class to signal between multiple processes that a job should stop.
+	 */
+	public static class StopJobSignal implements Serializable {
+
+		private final String stopJobMarkerFile;
+
+		public static StopJobSignal usingMarkerFile(final Path stopJobMarkerFile) {
+			return new StopJobSignal(stopJobMarkerFile.toString());
+		}
+
+		private StopJobSignal(final String stopJobMarkerFile) {
+			this.stopJobMarkerFile = stopJobMarkerFile;
+		}
+
+		/**
+		 * Signals that the job should stop.
+		 */
+		public void signal() {
+			try {
+				Files.delete(Paths.get(stopJobMarkerFile));
+			} catch (final IOException e) {
+				throw new RuntimeException(e);
+			}
+		}
+
+		/**
+		 * True if job should stop.
+		 */
+		public boolean isSignaled() {
+			return !Files.exists(Paths.get(stopJobMarkerFile));
+		}
+
+	}
+
 	// *************************************************************************
 	//     USER FUNCTIONS
 	// *************************************************************************
 
 	private static final class InfiniteSourceFunction extends RichParallelSourceFunction<Integer> {
+
 		private static final long serialVersionUID = -8758033916372648233L;
+
 		private boolean running;
 
-		InfiniteSourceFunction() {
-			running = true;
+		private final StopJobSignal stopJobSignal;
+
+		InfiniteSourceFunction(final StopJobSignal stopJobSignal) {
+			this.running = true;
+			this.stopJobSignal = stopJobSignal;
 		}
 
 		@Override
 		public void run(SourceContext<Integer> ctx) throws Exception {
-			while (running) {
+			while (running && !stopJobSignal.isSignaled()) {
 				synchronized (ctx.getCheckpointLock()) {
 					ctx.collect(0);
 				}
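
A note on the StopJobSignal added above: it coordinates the test JVM and the
task managers through a marker file on a file system both can see, which is
why only the path string is kept in the Serializable class that is shipped
inside the job graph. A minimal sketch of the protocol, run single-process
outside Flink for brevity (the class and file names here are illustrative):

    import java.nio.file.Files;
    import java.nio.file.Path;

    final class StopSignalExample {
        public static void main(String[] args) throws Exception {
            // Test side: create the marker up front; its presence means "keep running".
            final Path marker = Files.createTempFile("stop-job-signal", ".marker");

            // Job side: the source keeps emitting while the marker still exists.
            System.out.println("isSignaled: " + !Files.exists(marker)); // false

            // Test side: signal() deletes the marker, and every polling source observes it.
            Files.delete(marker);
            System.out.println("isSignaled: " + !Files.exists(marker)); // true
        }
    }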


[flink] 02/02: [FLINK-11171] Avoid concurrent usage of StateSnapshotTransformer

Posted by tr...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bced96a5a0b8f7b7848add316c12071e0398404a
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Mon Dec 17 17:09:47 2018 +0100

    [FLINK-11171] Avoid concurrent usage of StateSnapshotTransformer
    
    Test non-concurrent access of StateSnapshotTransformer
    
    Refactor out testNonConcurrentSnapshotTransformerAccess to separate StateSnapshotTransformerTest
    
    Use element serializer from new meta info, duplicate it in RocksDB transformer factory, test concurrent access for element serializer
    
    This closes #7320.
---
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  27 +-
 .../runtime/state/StateSnapshotTransformer.java    |  90 +-----
 ...sformer.java => StateSnapshotTransformers.java} | 111 +++-----
 .../state/heap/CopyOnWriteStateTableSnapshot.java  |   7 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  31 +--
 .../runtime/state/heap/NestedMapsStateTable.java   |   9 +-
 .../state/ttl/TtlStateSnapshotTransformer.java     |   2 +-
 .../flink/runtime/state/StateBackendTestBase.java  |  16 ++
 .../state/StateSnapshotTransformerTest.java        | 305 +++++++++++++++++++++
 .../state/ttl/mock/MockKeyedStateBackend.java      |   7 +-
 .../runtime/state/ttl/mock/MockStateBackend.java   |  25 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  58 ++--
 .../RocksDBSnapshotTransformFactoryAdaptor.java    | 105 +++++++
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  49 +++-
 14 files changed, 586 insertions(+), 256 deletions(-)
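
The crux of the change is visible in the first file below: the registered meta
info now stores a StateSnapshotTransformFactory<S> instead of a single
StateSnapshotTransformer<S> instance, so each snapshot obtains its own
transformer rather than sharing one @NotThreadSafe object across concurrently
running snapshots. An illustrative sketch of that shape; only noTransform()
appears in this diff, and the createForDeserializedState() method name plus
everything else here is an assumption for illustration:

    import java.util.Optional;

    interface StateSnapshotTransformer<T> {
        T filterOrTransform(T value); // not thread-safe by contract
    }

    interface StateSnapshotTransformFactory<T> {
        // Called once per snapshot, yielding a private transformer instance
        // (or empty when no transformation is configured).
        Optional<StateSnapshotTransformer<T>> createForDeserializedState();

        static <T> StateSnapshotTransformFactory<T> noTransform() {
            return Optional::empty;
        }
    }

    final class SnapshotExample {
        public static void main(String[] args) {
            final StateSnapshotTransformFactory<String> factory = StateSnapshotTransformFactory.noTransform();
            // Each snapshot thread asks the factory instead of reusing a shared transformer:
            final Optional<StateSnapshotTransformer<String>> transformer = factory.createForDeserializedState();
            System.out.println("transformer present: " + transformer.isPresent());
        }
    }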

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index b2d1cdc..1ce728d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 
@@ -48,8 +49,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 	private final StateSerializerProvider<N> namespaceSerializerProvider;
 	@Nonnull
 	private final StateSerializerProvider<S> stateSerializerProvider;
-	@Nullable
-	private StateSnapshotTransformer<S> snapshotTransformer;
+	@Nonnull
+	private StateSnapshotTransformFactory<S> stateSnapshotTransformFactory;
 
 	public RegisteredKeyValueStateBackendMetaInfo(
 		@Nonnull StateDescriptor.Type stateType,
@@ -62,7 +63,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 			name,
 			StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
 			StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-			null);
+			StateSnapshotTransformFactory.noTransform());
 	}
 
 	public RegisteredKeyValueStateBackendMetaInfo(
@@ -70,14 +71,14 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		@Nonnull String name,
 		@Nonnull TypeSerializer<N> namespaceSerializer,
 		@Nonnull TypeSerializer<S> stateSerializer,
-		@Nullable StateSnapshotTransformer<S> snapshotTransformer) {
+		@Nonnull StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
 
 		this(
 			stateType,
 			name,
 			StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
 			StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-			snapshotTransformer);
+			stateSnapshotTransformFactory);
 	}
 
 	@SuppressWarnings("unchecked")
@@ -91,7 +92,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 			StateSerializerProvider.fromPreviousSerializerSnapshot(
 				(TypeSerializerSnapshot<S>) Preconditions.checkNotNull(
 					snapshot.getTypeSerializerSnapshot(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER))),
-			null);
+			StateSnapshotTransformFactory.noTransform());
 
 		Preconditions.checkState(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE == snapshot.getBackendStateType());
 	}
@@ -101,13 +102,13 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		@Nonnull String name,
 		@Nonnull StateSerializerProvider<N> namespaceSerializerProvider,
 		@Nonnull StateSerializerProvider<S> stateSerializerProvider,
-		@Nullable StateSnapshotTransformer<S> snapshotTransformer) {
+		@Nonnull StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
 
 		super(name);
 		this.stateType = stateType;
 		this.namespaceSerializerProvider = namespaceSerializerProvider;
 		this.stateSerializerProvider = stateSerializerProvider;
-		this.snapshotTransformer = snapshotTransformer;
+		this.stateSnapshotTransformFactory = stateSnapshotTransformFactory;
 	}
 
 	@Nonnull
@@ -145,13 +146,13 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		return stateSerializerProvider.previousSchemaSerializer();
 	}
 
-	@Nullable
-	public StateSnapshotTransformer<S> getSnapshotTransformer() {
-		return snapshotTransformer;
+	@Nonnull
+	public StateSnapshotTransformFactory<S> getStateSnapshotTransformFactory() {
+		return stateSnapshotTransformFactory;
 	}
 
-	public void updateSnapshotTransformer(StateSnapshotTransformer<S> snapshotTransformer) {
-		this.snapshotTransformer = snapshotTransformer;
+	public void updateSnapshotTransformFactory(StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
+		this.stateSnapshotTransformFactory = stateSnapshotTransformFactory;
 	}
 
 	@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
index cd2c7bf..2eb4c3f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
@@ -18,19 +18,11 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
-
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
 import java.util.Optional;
 
-import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
-
 /**
  * Transformer of state values which are included or skipped in the snapshot.
  *
@@ -44,6 +36,7 @@ import static org.apache.flink.runtime.state.StateSnapshotTransformer.Collection
  * @param <T> type of state
  */
 @FunctionalInterface
+@NotThreadSafe
 public interface StateSnapshotTransformer<T> {
 	/**
 	 * Transform or filter out state values which are included or skipped in the snapshot.
@@ -75,84 +68,6 @@ public interface StateSnapshotTransformer<T> {
 	}
 
 	/**
-	 * General implementation of list state transformer.
-	 *
-	 * <p>This transformer wraps a transformer per-entry
-	 * and transforms the whole list state.
-	 * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
-	 * it respects its {@link TransformStrategy}.
-	 */
-	class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
-		private final StateSnapshotTransformer<T> entryValueTransformer;
-		private final TransformStrategy transformStrategy;
-
-		public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
-			this.entryValueTransformer = entryValueTransformer;
-			this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
-				((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
-				TransformStrategy.TRANSFORM_ALL;
-		}
-
-		@Override
-		@Nullable
-		public List<T> filterOrTransform(@Nullable List<T> list) {
-			if (list == null) {
-				return null;
-			}
-			List<T> transformedList = new ArrayList<>();
-			boolean anyChange = false;
-			for (int i = 0; i < list.size(); i++) {
-				T entry = list.get(i);
-				T transformedEntry = entryValueTransformer.filterOrTransform(entry);
-				if (transformedEntry != null) {
-					if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
-						transformedList = list.subList(i, list.size());
-						anyChange = i > 0;
-						break;
-					} else {
-						transformedList.add(transformedEntry);
-					}
-				}
-				anyChange |= transformedEntry == null || !Objects.equals(entry, transformedEntry);
-			}
-			transformedList = anyChange ? transformedList : list;
-			return transformedList.isEmpty() ? null : transformedList;
-		}
-	}
-
-	/**
-	 * General implementation of map state transformer.
-	 *
-	 * <p>This transformer wraps a transformer per-entry
-	 * and transforms the whole map state.
-	 */
-	class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
-		private final StateSnapshotTransformer<V> entryValueTransformer;
-
-		public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
-			this.entryValueTransformer = entryValueTransformer;
-		}
-
-		@Nullable
-		@Override
-		public Map<K, V> filterOrTransform(@Nullable Map<K, V> map) {
-			if (map == null) {
-				return null;
-			}
-			Map<K, V> transformedMap = new HashMap<>();
-			boolean anyChange = false;
-			for (Map.Entry<K, V> entry : map.entrySet()) {
-				V transformedValue = entryValueTransformer.filterOrTransform(entry.getValue());
-				if (transformedValue != null) {
-					transformedMap.put(entry.getKey(), transformedValue);
-				}
-				anyChange |= transformedValue == null || !Objects.equals(entry.getValue(), transformedValue);
-			}
-			return anyChange ? (transformedMap.isEmpty() ? null : transformedMap) : map;
-		}
-	}
-
-	/**
 	 * This factory creates state transformers depending on the form of values to transform.
 	 *
 	 * <p>If there is no transforming needed, the factory methods return {@code Optional.empty()}.
@@ -183,4 +98,5 @@ public interface StateSnapshotTransformer<T> {
 
 		Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
 	}
+
 }
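
The new @NotThreadSafe annotation documents why a single transformer must not be
shared between snapshots: implementations may carry per-traversal state. A
hypothetical transformer (not part of this commit) that would silently misbehave
under concurrent use:

    import javax.annotation.Nullable;
    import org.apache.flink.runtime.state.StateSnapshotTransformer;

    // Keeps only the first 100 entries it sees: correct when owned by a single
    // snapshot, wrong when two snapshot threads share it, since both bump 'seen'.
    class FirstHundredTransformer implements StateSnapshotTransformer<byte[]> {
        private int seen;

        @Nullable
        @Override
        public byte[] filterOrTransform(@Nullable byte[] value) {
            return seen++ < 100 ? value : null;
        }
    }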
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
similarity index 56%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
index cd2c7bf..0b9306a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 
 import javax.annotation.Nullable;
 
@@ -31,66 +31,25 @@ import java.util.Optional;
 
 import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
 
-/**
- * Transformer of state values which are included or skipped in the snapshot.
- *
- * <p>This transformer can be applied to state values
- * to decide which entries should be included into the snapshot.
- * The included entries can be optionally modified before.
- *
- * <p>Unless specified differently, the transformer should be applied per entry
- * for collection types of state, like list or map.
- *
- * @param <T> type of state
- */
-@FunctionalInterface
-public interface StateSnapshotTransformer<T> {
-	/**
-	 * Transform or filter out state values which are included or skipped in the snapshot.
-	 *
-	 * @param value non-serialized form of value
-	 * @return value to snapshot or null which means the entry is not included
-	 */
-	@Nullable
-	T filterOrTransform(@Nullable T value);
-
-	/** Collection state specific transformer which says how to transform entries of the collection. */
-	interface CollectionStateSnapshotTransformer<T> extends StateSnapshotTransformer<T> {
-		enum TransformStrategy {
-			/** Transform all entries. */
-			TRANSFORM_ALL,
-
-			/**
-			 * Skip first null entries.
-			 *
-			 * <p>While traversing collection entries, as optimisation, stops transforming
-			 * if encounters first non-null included entry and returns it plus the rest untouched.
-			 */
-			STOP_ON_FIRST_INCLUDED
-		}
-
-		default TransformStrategy getFilterStrategy() {
-			return TransformStrategy.TRANSFORM_ALL;
-		}
-	}
-
+/** Collection of common state snapshot transformers and their factories. */
+public class StateSnapshotTransformers {
 	/**
 	 * General implementation of list state transformer.
 	 *
 	 * <p>This transformer wraps a transformer per-entry
 	 * and transforms the whole list state.
 	 * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
-	 * it respects its {@link TransformStrategy}.
+	 * it respects its {@link CollectionStateSnapshotTransformer.TransformStrategy}.
 	 */
-	class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
+	public static class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
 		private final StateSnapshotTransformer<T> entryValueTransformer;
-		private final TransformStrategy transformStrategy;
+		private final CollectionStateSnapshotTransformer.TransformStrategy transformStrategy;
 
 		public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
 			this.entryValueTransformer = entryValueTransformer;
 			this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
 				((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
-				TransformStrategy.TRANSFORM_ALL;
+				CollectionStateSnapshotTransformer.TransformStrategy.TRANSFORM_ALL;
 		}
 
 		@Override
@@ -120,13 +79,24 @@ public interface StateSnapshotTransformer<T> {
 		}
 	}
 
+	public static class ListStateSnapshotTransformFactory<T> extends StateSnapshotTransformFactoryWrapAdaptor<T, List<T>> {
+		public ListStateSnapshotTransformFactory(StateSnapshotTransformFactory<T> originalSnapshotTransformFactory) {
+			super(originalSnapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<List<T>>> createForDeserializedState() {
+			return originalSnapshotTransformFactory.createForDeserializedState().map(ListStateSnapshotTransformer::new);
+		}
+	}
+
 	/**
 	 * General implementation of map state transformer.
 	 *
 	 * <p>This transformer wraps a transformer per-entry
 	 * and transforms the whole map state.
 	 */
-	class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
+	public static class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
 		private final StateSnapshotTransformer<V> entryValueTransformer;
 
 		public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
@@ -152,35 +122,32 @@ public interface StateSnapshotTransformer<T> {
 		}
 	}
 
-	/**
-	 * This factory creates state transformers depending on the form of values to transform.
-	 *
-	 * <p>If there is no transforming needed, the factory methods return {@code Optional.empty()}.
-	 */
-	interface StateSnapshotTransformFactory<T> {
-		StateSnapshotTransformFactory<?> NO_TRANSFORM = createNoTransform();
+	public static class MapStateSnapshotTransformFactory<K, V> extends StateSnapshotTransformFactoryWrapAdaptor<V, Map<K, V>> {
+		public MapStateSnapshotTransformFactory(StateSnapshotTransformFactory<V> originalSnapshotTransformFactory) {
+			super(originalSnapshotTransformFactory);
+		}
 
-		@SuppressWarnings("unchecked")
-		static <T> StateSnapshotTransformFactory<T> noTransform() {
-			return (StateSnapshotTransformFactory<T>) NO_TRANSFORM;
+		@Override
+		public Optional<StateSnapshotTransformer<Map<K, V>>> createForDeserializedState() {
+			return originalSnapshotTransformFactory.createForDeserializedState().map(MapStateSnapshotTransformer::new);
 		}
+	}
 
-		static <T> StateSnapshotTransformFactory<T> createNoTransform() {
-			return new StateSnapshotTransformFactory<T>() {
-				@Override
-				public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
-					return Optional.empty();
-				}
+	public abstract static class StateSnapshotTransformFactoryWrapAdaptor<S, T> implements StateSnapshotTransformFactory<T> {
+		final StateSnapshotTransformFactory<S> originalSnapshotTransformFactory;
 
-				@Override
-				public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
-					return Optional.empty();
-				}
-			};
+		StateSnapshotTransformFactoryWrapAdaptor(StateSnapshotTransformFactory<S> originalSnapshotTransformFactory) {
+			this.originalSnapshotTransformFactory = originalSnapshotTransformFactory;
 		}
 
-		Optional<StateSnapshotTransformer<T>> createForDeserializedState();
+		@Override
+		public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+			throw new UnsupportedOperationException();
+		}
 
-		Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			throw new UnsupportedOperationException();
+		}
 	}
 }
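
StateSnapshotTransformers lifts a per-element factory to a whole-collection factory
once, at state registration time; the per-entry transformer itself is still only
instantiated inside each snapshot. A minimal usage sketch, with noTransform()
standing in for a real element factory:

    import java.util.List;
    import java.util.Optional;
    import org.apache.flink.runtime.state.StateSnapshotTransformer;
    import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
    import org.apache.flink.runtime.state.StateSnapshotTransformers;

    class ListFactoryWrapSketch {
        static Optional<StateSnapshotTransformer<List<String>>> perSnapshotTransformer() {
            // Factory for transforming single String elements:
            StateSnapshotTransformFactory<String> elementFactory =
                StateSnapshotTransformFactory.noTransform();

            // Lifted to the whole list state:
            StateSnapshotTransformFactory<List<String>> listFactory =
                new StateSnapshotTransformers.ListStateSnapshotTransformFactory<>(elementFactory);

            // Each snapshot creates its own transformer; empty here, because
            // noTransform() yields no element transformer to wrap.
            return listFactory.createForDeserializedState();
        }
    }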
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index 21abf8d..12afcbc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -88,6 +88,9 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	@Nonnull
 	private final TypeSerializer<S> localStateSerializer;
 
+	@Nullable
+	private final StateSnapshotTransformer<S> stateSnapshotTransformer;
+
 	/**
 	 * Result of partitioning the snapshot by key-group. This is lazily created in the process of writing this snapshot
 	 * to an output as part of checkpointing.
@@ -114,6 +117,9 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		this.localStateSerializer = owningStateTable.metaInfo.getStateSerializer().duplicate();
 
 		this.partitionedStateTableSnapshot = null;
+
+		this.stateSnapshotTransformer = owningStateTable.metaInfo.
+			getStateSnapshotTransformFactory().createForDeserializedState().orElse(null);
 	}
 
 	/**
@@ -147,7 +153,6 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 					localKeySerializer.serialize(element.key, dov);
 					localStateSerializer.serialize(element.state, dov);
 				};
-			StateSnapshotTransformer<S> stateSnapshotTransformer = owningStateTable.metaInfo.getSnapshotTransformer();
 			StateTableKeyGroupPartitioner<K, N, S> stateTableKeyGroupPartitioner = stateSnapshotTransformer != null ?
 				new TransformingStateTableKeyGroupPartitioner<>(
 					snapshotData,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 55d5a6f..56374fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -61,8 +61,8 @@ import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformers;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -78,7 +78,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -89,7 +88,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
@@ -229,7 +227,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private <N, V> StateTable<K, N, V> tryRegisterStateTable(
 			TypeSerializer<N> namespaceSerializer,
 			StateDescriptor<?, V> stateDesc,
-			@Nullable StateSnapshotTransformer<V> snapshotTransformer) throws StateMigrationException {
+			@Nonnull StateSnapshotTransformFactory<V> snapshotTransformFactory) throws StateMigrationException {
 
 		@SuppressWarnings("unchecked")
 		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
@@ -239,7 +237,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		if (stateTable != null) {
 			RegisteredKeyValueStateBackendMetaInfo<N, V> restoredKvMetaInfo = stateTable.getMetaInfo();
 
-			restoredKvMetaInfo.updateSnapshotTransformer(snapshotTransformer);
+			restoredKvMetaInfo.updateSnapshotTransformFactory(snapshotTransformFactory);
 
 			TypeSerializerSchemaCompatibility<N> namespaceCompatibility =
 				restoredKvMetaInfo.updateNamespaceSerializer(namespaceSerializer);
@@ -263,7 +261,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getName(),
 				namespaceSerializer,
 				newStateSerializer,
-				snapshotTransformer);
+				snapshotTransformFactory);
 
 			stateTable = snapshotStrategy.newStateTable(newMetaInfo);
 			registeredKVStates.put(stateDesc.getName(), stateTable);
@@ -301,27 +299,20 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			throw new FlinkRuntimeException(message);
 		}
 		StateTable<K, N, SV> stateTable = tryRegisterStateTable(
-			namespaceSerializer, stateDesc, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
+			namespaceSerializer, stateDesc, getStateSnapshotTransformFactory(stateDesc, snapshotTransformFactory));
 		return stateFactory.createState(stateDesc, stateTable, getKeySerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+	private <SV, SEV> StateSnapshotTransformFactory<SV> getStateSnapshotTransformFactory(
 		StateDescriptor<?, SV> stateDesc,
 		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
-		Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
-		if (original.isPresent()) {
-			if (stateDesc instanceof ListStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.ListStateSnapshotTransformer<>(original.get());
-			} else if (stateDesc instanceof MapStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.MapStateSnapshotTransformer<>(original.get());
-			} else {
-				return (StateSnapshotTransformer<SV>) original.get();
-			}
+		if (stateDesc instanceof ListStateDescriptor) {
+			return (StateSnapshotTransformFactory<SV>) new StateSnapshotTransformers.ListStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		} else if (stateDesc instanceof MapStateDescriptor) {
+			return (StateSnapshotTransformFactory<SV>) new StateSnapshotTransformers.MapStateSnapshotTransformFactory<>(snapshotTransformFactory);
 		} else {
-			return null;
+			return (StateSnapshotTransformFactory<SV>) snapshotTransformFactory;
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index f982370..167d90f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
@@ -319,7 +320,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	@Nonnull
 	@Override
 	public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
+		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getStateSnapshotTransformFactory());
 	}
 
 	/**
@@ -337,9 +338,11 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		private final TypeSerializer<S> stateSerializer;
 		private final StateSnapshotTransformer<S> snapshotFilter;
 
-		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformer<S> snapshotFilter) {
+		NestedMapsStateTableSnapshot(
+			NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformFactory<S> snapshotTransformFactory) {
+
 			super(owningTable);
-			this.snapshotFilter = snapshotFilter;
+			this.snapshotFilter = snapshotTransformFactory.createForDeserializedState().orElse(null);
 			this.keySerializer = owningStateTable.keyContext.getKeySerializer();
 			this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
 			this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
index e3706ec..fd29271 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
@@ -91,7 +91,7 @@ abstract class TtlStateSnapshotTransformer<T> implements CollectionStateSnapshot
 			try {
 				ts = deserializeTs(value);
 			} catch (IOException e) {
-				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure");
+				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure", e);
 			}
 			return expired(ts) ? null : value;
 		}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index f1269fe..a4306fd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -3614,6 +3614,22 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	}
 
 	@Test
+	public void testNonConcurrentSnapshotTransformerAccess() throws Exception {
+		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
+		AbstractKeyedStateBackend<Integer> backend = null;
+		try {
+			backend = createKeyedBackend(IntSerializer.INSTANCE);
+			new StateSnapshotTransformerTest(backend, streamFactory)
+				.testNonConcurrentSnapshotTransformerAccess();
+		} finally {
+			if (backend != null) {
+				IOUtils.closeQuietly(backend);
+				backend.dispose();
+			}
+		}
+	}
+
+	@Test
 	public void testAsyncSnapshot() throws Exception {
 		OneShotLatch waiter = new OneShotLatch();
 		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java
new file mode 100644
index 0000000..42bda6e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
+import org.apache.flink.runtime.state.internal.InternalValueState;
+import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
+import org.apache.flink.util.StringUtils;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.RunnableFuture;
+
+import static org.junit.Assert.assertEquals;
+
+class StateSnapshotTransformerTest {
+	private final AbstractKeyedStateBackend<Integer> backend;
+	private final BlockerCheckpointStreamFactory streamFactory;
+	private final StateSnapshotTransformFactory<?> snapshotTransformFactory;
+
+	StateSnapshotTransformerTest(
+		AbstractKeyedStateBackend<Integer> backend,
+		BlockerCheckpointStreamFactory streamFactory) {
+
+		this.backend = backend;
+		this.streamFactory = streamFactory;
+		this.snapshotTransformFactory = SingleThreadAccessCheckingSnapshotTransformFactory.create();
+	}
+
+	void testNonConcurrentSnapshotTransformerAccess() throws Exception {
+		List<TestState> testStates = Arrays.asList(
+			new TestValueState(),
+			new TestListState(),
+			new TestMapState()
+		);
+
+		for (TestState state : testStates) {
+			for (int i = 0; i < 100; i++) {
+				backend.setCurrentKey(i);
+				state.setToRandomValue();
+			}
+
+			CheckpointOptions checkpointOptions = CheckpointOptions.forCheckpointWithDefaultLocation();
+
+			RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot1 =
+				backend.snapshot(1L, 0L, streamFactory, checkpointOptions);
+
+			RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot2 =
+				backend.snapshot(2L, 0L, streamFactory, checkpointOptions);
+
+			Thread runner1 = new Thread(snapshot1, "snapshot1");
+			runner1.start();
+			Thread runner2 = new Thread(snapshot2, "snapshot2");
+			runner2.start();
+
+			runner1.join();
+			runner2.join();
+
+			snapshot1.get();
+			snapshot2.get();
+		}
+	}
+
+	private abstract class TestState {
+		final Random rnd;
+
+		private TestState() {
+			this.rnd = new Random();
+		}
+
+		abstract void setToRandomValue() throws Exception;
+
+		String getRandomString() {
+			return StringUtils.getRandomString(rnd, 5, 10);
+		}
+	}
+
+	private class TestValueState extends TestState {
+		private final InternalValueState<Integer, VoidNamespace, String> state;
+
+		private TestValueState() throws Exception {
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				new ValueStateDescriptor<>("TestValueState", StringSerializer.INSTANCE),
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		@Override
+		void setToRandomValue() throws Exception {
+			state.update(getRandomString());
+		}
+	}
+
+	private class TestListState extends TestState {
+		private final InternalListState<Integer, VoidNamespace, String> state;
+
+		private TestListState() throws Exception {
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				new ListStateDescriptor<>("TestListState", new SingleThreadAccessCheckingTypeSerializer()),
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		@Override
+		void setToRandomValue() throws Exception {
+			int length = rnd.nextInt(10);
+			for (int i = 0; i < length; i++) {
+				state.add(getRandomString());
+			}
+		}
+	}
+
+	private class TestMapState extends TestState {
+		private final InternalMapState<Integer, VoidNamespace, String, String> state;
+
+		private TestMapState() throws Exception {
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				new MapStateDescriptor<>("TestMapState", StringSerializer.INSTANCE, StringSerializer.INSTANCE),
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		@Override
+		void setToRandomValue() throws Exception {
+			int length = rnd.nextInt(10);
+			for (int i = 0; i < length; i++) {
+				state.put(getRandomString(), getRandomString());
+			}
+		}
+	}
+
+	private static class SingleThreadAccessCheckingSnapshotTransformFactory<T>
+		implements StateSnapshotTransformFactory<T> {
+
+		private final SingleThreadAccessChecker singleThreadAccessChecker = new SingleThreadAccessChecker();
+
+		static <T> StateSnapshotTransformFactory<T> create() {
+			return new SingleThreadAccessCheckingSnapshotTransformFactory<>();
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return createStateSnapshotTransformer();
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return createStateSnapshotTransformer();
+		}
+
+		private <T1> Optional<StateSnapshotTransformer<T1>> createStateSnapshotTransformer() {
+			return Optional.of(new StateSnapshotTransformer<T1>() {
+				private final SingleThreadAccessChecker singleThreadAccessChecker = new SingleThreadAccessChecker();
+
+				@Nullable
+				@Override
+				public T1 filterOrTransform(@Nullable T1 value) {
+					singleThreadAccessChecker.checkSingleThreadAccess();
+					return value;
+				}
+			});
+		}
+	}
+
+	private static class SingleThreadAccessCheckingTypeSerializer extends TypeSerializer<String> {
+		private final SingleThreadAccessChecker singleThreadAccessChecker = new SingleThreadAccessChecker();
+
+		@Override
+		public boolean isImmutableType() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.isImmutableType();
+		}
+
+		@Override
+		public TypeSerializer<String> duplicate() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return new SingleThreadAccessCheckingTypeSerializer();
+		}
+
+		@Override
+		public String createInstance() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.createInstance();
+		}
+
+		@Override
+		public String copy(String from) {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.copy(from);
+		}
+
+		@Override
+		public String copy(String from, String reuse) {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.copy(from, reuse);
+		}
+
+		@Override
+		public int getLength() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.getLength();
+		}
+
+		@Override
+		public void serialize(String record, DataOutputView target) throws IOException {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			StringSerializer.INSTANCE.serialize(record, target);
+		}
+
+		@Override
+		public String deserialize(DataInputView source) throws IOException {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.deserialize(source);
+		}
+
+		@Override
+		public String deserialize(String reuse, DataInputView source) throws IOException {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.deserialize(reuse, source);
+		}
+
+		@Override
+		public void copy(DataInputView source, DataOutputView target) throws IOException {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			StringSerializer.INSTANCE.copy(source, target);
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return obj == this ||
+				(obj != null && obj.getClass() == getClass() &&
+					StringSerializer.INSTANCE.equals(obj));
+		}
+
+		@Override
+		public boolean canEqual(Object obj) {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return (obj != null && obj.getClass() == getClass() &&
+				StringSerializer.INSTANCE.canEqual(obj));
+		}
+
+		@Override
+		public int hashCode() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.hashCode();
+		}
+
+		@Override
+		public TypeSerializerSnapshot<String> snapshotConfiguration() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return StringSerializer.INSTANCE.snapshotConfiguration();
+		}
+	}
+
+	private static class SingleThreadAccessChecker {
+		private Thread currentThread = null;
+
+		void checkSingleThreadAccess() {
+			if (currentThread == null) {
+				currentThread = Thread.currentThread();
+			} else {
+				assertEquals("Concurrent access from another thread",
+					currentThread, Thread.currentThread());
+			}
+		}
+	}
+}
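
The checker above records the first accessing thread in a plain field and asserts
that all later accesses come from it. Note the field is not volatile, so in
unlucky interleavings a violation could in principle go undetected; a race-free
variant could claim ownership with an AtomicReference (a sketch, not part of this
commit):

    import java.util.concurrent.atomic.AtomicReference;

    class StrictSingleThreadAccessChecker {
        private final AtomicReference<Thread> owner = new AtomicReference<>();

        void checkSingleThreadAccess() {
            final Thread current = Thread.currentThread();
            // The CAS makes claiming ownership atomic, so a second thread can
            // never slip past the null check unnoticed.
            if (!owner.compareAndSet(null, current) && owner.get() != current) {
                throw new AssertionError("Concurrent access from thread " + current);
            }
        }
    }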
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index f88e6d7..2725051 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformers;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
 import org.apache.flink.runtime.state.ttl.TtlStateFactory;
@@ -131,11 +132,9 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
 		if (original.isPresent()) {
 			if (stateDesc instanceof ListStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.ListStateSnapshotTransformer<>(original.get());
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformers.ListStateSnapshotTransformer<>(original.get());
 			} else if (stateDesc instanceof MapStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.MapStateSnapshotTransformer<>(original.get());
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformers.MapStateSnapshotTransformer<>(original.get());
 			} else {
 				return (StateSnapshotTransformer<SV>) original.get();
 			}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
index 8ed84c0..9a899f4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
@@ -27,10 +27,12 @@ import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.CheckpointMetadataOutputStream;
 import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.CheckpointStorageLocation;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.OperatorStateBackend;
@@ -65,7 +67,28 @@ public class MockStateBackend extends AbstractStateBackend {
 
 			@Override
 			public CheckpointStorageLocation initializeLocationForCheckpoint(long checkpointId) {
-				return null;
+				return new CheckpointStorageLocation() {
+
+					@Override
+					public CheckpointStateOutputStream createCheckpointStateOutputStream(CheckpointedStateScope scope) {
+						return null;
+					}
+
+					@Override
+					public CheckpointMetadataOutputStream createMetadataOutputStream() {
+						return null;
+					}
+
+					@Override
+					public void disposeOnFailure() {
+
+					}
+
+					@Override
+					public CheckpointStorageLocationReference getLocationReference() {
+						return null;
+					}
+				};
 			}
 
 			@Override
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index e994682..7a585db 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -70,7 +70,6 @@ import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateHandleID;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -114,7 +113,6 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.Spliterator;
@@ -126,6 +124,7 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
+import static org.apache.flink.contrib.streaming.state.RocksDBSnapshotTransformFactoryAdaptor.wrapStateSnapshotTransformFactory;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
@@ -368,7 +367,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@VisibleForTesting
-	public ColumnFamilyHandle getColumnFamilyHandle(String state) {
+	ColumnFamilyHandle getColumnFamilyHandle(String state) {
 		Tuple2<ColumnFamilyHandle, ?> columnInfo = kvStateInformation.get(state);
 		return columnInfo != null ? columnInfo.f0 : null;
 	}
@@ -688,7 +687,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 *
 		 * @param rocksDBKeyedStateBackend the state backend into which we restore
 		 */
-		public RocksDBFullRestoreOperation(RocksDBKeyedStateBackend<K> rocksDBKeyedStateBackend) {
+		RocksDBFullRestoreOperation(RocksDBKeyedStateBackend<K> rocksDBKeyedStateBackend) {
 			this.rocksDBKeyedStateBackend = Preconditions.checkNotNull(rocksDBKeyedStateBackend);
 		}
 
@@ -697,7 +696,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 *
 		 * @param keyedStateHandles List of all key groups state handles that shall be restored.
 		 */
-		public void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
+		void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
 			throws IOException, StateMigrationException, RocksDBException {
 
 			rocksDBKeyedStateBackend.createDB();
@@ -1344,10 +1343,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * already have a registered entry for that and return it (after some necessary state compatibility checks)
 	 * or create a new one if it does not exist.
 	 */
-	private <N, S extends State, SV> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> tryRegisterKvStateInformation(
+	private <N, S extends State, SV, SEV> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> tryRegisterKvStateInformation(
 			StateDescriptor<S, SV> stateDesc,
 			TypeSerializer<N> namespaceSerializer,
-			@Nullable StateSnapshotTransformer<SV> snapshotTransformer) throws Exception {
+			@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
 
 		Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> oldStateInfo =
 			kvStateInformation.get(stateDesc.getName());
@@ -1364,8 +1363,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				Tuple2.of(oldStateInfo.f0, castedMetaInfo),
 				stateDesc,
 				namespaceSerializer,
-				stateSerializer,
-				snapshotTransformer);
+				stateSerializer);
 
 			oldStateInfo.f1 = newMetaInfo;
 			newColumnFamily = oldStateInfo.f0;
@@ -1375,12 +1373,16 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getName(),
 				namespaceSerializer,
 				stateSerializer,
-				snapshotTransformer);
+				StateSnapshotTransformFactory.noTransform());
 
 			newColumnFamily = createColumnFamily(stateDesc.getName());
 			registerKvStateInformation(stateDesc.getName(), Tuple2.of(newColumnFamily, newMetaInfo));
 		}
 
+		StateSnapshotTransformFactory<SV> wrappedSnapshotTransformFactory = wrapStateSnapshotTransformFactory(
+			stateDesc, snapshotTransformFactory, newMetaInfo.getStateSerializer());
+		newMetaInfo.updateSnapshotTransformFactory(wrappedSnapshotTransformFactory);
+
 		return Tuple2.of(newColumnFamily, newMetaInfo);
 	}
 
@@ -1388,14 +1390,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> oldStateInfo,
 			StateDescriptor<S, SV> stateDesc,
 			TypeSerializer<N> namespaceSerializer,
-			TypeSerializer<SV> stateSerializer,
-			@Nullable StateSnapshotTransformer<SV> snapshotTransformer) throws Exception {
+			TypeSerializer<SV> stateSerializer) throws Exception {
 
 		@SuppressWarnings("unchecked")
 		RegisteredKeyValueStateBackendMetaInfo<N, SV> restoredKvStateMetaInfo = oldStateInfo.f1;
 
-		restoredKvStateMetaInfo.updateSnapshotTransformer(snapshotTransformer);
-
 		TypeSerializerSchemaCompatibility<N> s = restoredKvStateMetaInfo.updateNamespaceSerializer(namespaceSerializer);
 		if (s.isCompatibleAfterMigration() || s.isIncompatible()) {
 			throw new StateMigrationException("The new namespace serializer must be compatible.");
@@ -1512,39 +1511,14 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			throw new FlinkRuntimeException(message);
 		}
 		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult = tryRegisterKvStateInformation(
-			stateDesc, namespaceSerializer, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
+			stateDesc, namespaceSerializer, snapshotTransformFactory);
 		return stateFactory.createState(stateDesc, registerResult, RocksDBKeyedStateBackend.this);
 	}
 
-	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
-		StateDescriptor<?, SV> stateDesc,
-		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
-		if (stateDesc instanceof ListStateDescriptor) {
-			Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
-			return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null);
-		} else if (stateDesc instanceof MapStateDescriptor) {
-			Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
-			return (StateSnapshotTransformer<SV>) original
-				.map(RocksDBMapState.StateSnapshotTransformerWrapper::new).orElse(null);
-		} else {
-			Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
-			return (StateSnapshotTransformer<SV>) original.orElse(null);
-		}
-	}
-
-	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> createRocksDBListStateTransformer(
-		StateDescriptor<?, SV> stateDesc,
-		StateSnapshotTransformer<SEV> elementTransformer) {
-		return (StateSnapshotTransformer<SV>) new RocksDBListState.StateSnapshotTransformerWrapper<>(
-			elementTransformer, ((ListStateDescriptor<SEV>) stateDesc).getElementSerializer());
-	}
-
 	/**
 	 * Only visible for testing, DO NOT USE.
 	 */
-	public File getInstanceBasePath() {
+	File getInstanceBasePath() {
 		return instanceBasePath;
 	}
 
@@ -1578,7 +1552,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return new RocksIteratorWrapper(db.newIterator());
 	}
 
-	public static RocksIteratorWrapper getRocksIterator(
+	static RocksIteratorWrapper getRocksIterator(
 		RocksDB db,
 		ColumnFamilyHandle columnFamilyHandle) {
 		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle));
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java
new file mode 100644
index 0000000..5b018c8
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.ListSerializer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+
+import java.util.Optional;
+
+abstract class RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> implements StateSnapshotTransformFactory<SV> {
+	final StateSnapshotTransformFactory<SEV> snapshotTransformFactory;
+
+	RocksDBSnapshotTransformFactoryAdaptor(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+		this.snapshotTransformFactory = snapshotTransformFactory;
+	}
+
+	@Override
+	public Optional<StateSnapshotTransformer<SV>> createForDeserializedState() {
+		throw new UnsupportedOperationException("Only serialized state filtering is supported in RocksDB backend");
+	}
+
+	@SuppressWarnings("unchecked")
+	static <SV, SEV> StateSnapshotTransformFactory<SV> wrapStateSnapshotTransformFactory(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformFactory<SEV> snapshotTransformFactory,
+		TypeSerializer<SV> stateSerializer) {
+		if (stateDesc instanceof ListStateDescriptor) {
+			TypeSerializer<SEV> elementSerializer = ((ListSerializer<SEV>) stateSerializer).getElementSerializer();
+			return new RocksDBListStateSnapshotTransformFactory<>(snapshotTransformFactory, elementSerializer);
+		} else if (stateDesc instanceof MapStateDescriptor) {
+			return new RocksDBMapStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		} else {
+			return new RocksDBValueStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		}
+	}
+
+	private static class RocksDBValueStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private RocksDBValueStateSnapshotTransformFactory(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+			super(snapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForSerializedState();
+		}
+	}
+
+	private static class RocksDBMapStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private RocksDBMapStateSnapshotTransformFactory(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+			super(snapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForSerializedState()
+				.map(RocksDBMapState.StateSnapshotTransformerWrapper::new);
+		}
+	}
+
+	private static class RocksDBListStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private final TypeSerializer<SEV> elementSerializer;
+
+		@SuppressWarnings("unchecked")
+		private RocksDBListStateSnapshotTransformFactory(
+			StateSnapshotTransformFactory<SEV> snapshotTransformFactory,
+			TypeSerializer<SEV> elementSerializer) {
+
+			super(snapshotTransformFactory);
+			this.elementSerializer = elementSerializer;
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForDeserializedState()
+				.map(est -> new RocksDBListState.StateSnapshotTransformerWrapper<>(est, elementSerializer.duplicate()));
+		}
+	}
+}
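
In the RocksDB backend, snapshot filtering operates on serialized bytes, which is
why createForDeserializedState() above always throws. List state is the one
exception: its per-element transformer works on deserialized elements, so the
adaptor pairs it with a duplicate() of the element serializer, giving each
snapshot's transformer a private serializer instance. For illustration, a
serialized-state transformer is just a function over the raw value bytes
(hypothetical example, not from this commit):

    import org.apache.flink.runtime.state.StateSnapshotTransformer;

    // Excludes entries whose serialized value is empty; returning null means
    // "do not include this entry in the snapshot".
    StateSnapshotTransformer<byte[]> dropEmptyValues =
        value -> (value == null || value.length == 0) ? null : value;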
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
index 817f684..f556e12 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -192,7 +192,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
 		@Nonnull
-		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy;
+		private List<MetaData> metaData;
 
 		@Nonnull
 		private final String logPathString;
@@ -209,7 +209,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 			this.dbLease = dbLease;
 			this.snapshot = snapshot;
 			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
-			this.metaDataCopy = metaDataCopy;
+			this.metaData = fillMetaData(metaDataCopy);
 			this.logPathString = logPathString;
 		}
 
@@ -248,7 +248,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 			@Nonnull KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators =
-				new ArrayList<>(metaDataCopy.size());
+				new ArrayList<>(metaData.size());
 			final DataOutputView outputView =
 				new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
 			final ReadOptions readOptions = new ReadOptions();
@@ -273,10 +273,10 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 
 			int kvStateId = 0;
 
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : metaDataCopy) {
+			for (MetaData metaDataEntry : metaData) {
 
-				RocksIteratorWrapper rocksIteratorWrapper =
-					getRocksIterator(db, tuple2.f0, tuple2.f1, readOptions);
+				RocksIteratorWrapper rocksIteratorWrapper = getRocksIterator(
+					db, metaDataEntry.columnFamilyHandle, metaDataEntry.stateSnapshotTransformer, readOptions);
 
 				kvStateIterators.add(Tuple2.of(rocksIteratorWrapper, kvStateId));
 				++kvStateId;
@@ -402,20 +402,45 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 		}
 	}
 
+	private static List<MetaData> fillMetaData(
+		List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy) {
+		List<MetaData> metaData = new ArrayList<>(metaDataCopy.size());
+		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfo : metaDataCopy) {
+			StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
+			if (metaInfo.f1 instanceof RegisteredKeyValueStateBackendMetaInfo) {
+				stateSnapshotTransformer = ((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo.f1).
+					getStateSnapshotTransformFactory().createForSerializedState().orElse(null);
+			}
+			metaData.add(new MetaData(metaInfo.f0, metaInfo.f1, stateSnapshotTransformer));
+		}
+		return metaData;
+	}
+
 	@SuppressWarnings("unchecked")
 	private static RocksIteratorWrapper getRocksIterator(
 		RocksDB db,
 		ColumnFamilyHandle columnFamilyHandle,
-		RegisteredStateMetaInfoBase metaInfo,
+		StateSnapshotTransformer<byte[]> stateSnapshotTransformer,
 		ReadOptions readOptions) {
-		StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
-		if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
-			stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
-				((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
-		}
 		RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
 		return stateSnapshotTransformer == null ?
 			new RocksIteratorWrapper(rocksIterator) :
 			new RocksTransformingIteratorWrapper(rocksIterator, stateSnapshotTransformer);
 	}
+
+	private static class MetaData {
+		final ColumnFamilyHandle columnFamilyHandle;
+		final RegisteredStateMetaInfoBase registeredStateMetaInfoBase;
+		final StateSnapshotTransformer<byte[]> stateSnapshotTransformer;
+
+		private MetaData(
+			ColumnFamilyHandle columnFamilyHandle,
+			RegisteredStateMetaInfoBase registeredStateMetaInfoBase,
+			StateSnapshotTransformer<byte[]> stateSnapshotTransformer) {
+
+			this.columnFamilyHandle = columnFamilyHandle;
+			this.registeredStateMetaInfoBase = registeredStateMetaInfoBase;
+			this.stateSnapshotTransformer = stateSnapshotTransformer;
+		}
+	}
 }
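
With the MetaData holder above, fillMetaData() resolves the serialized-state
transformer during the synchronous snapshot phase, and the asynchronous writer of
that snapshot is then the only user of the instance. A toy demonstration of the
per-call instantiation this relies on, which is also what the new
StateSnapshotTransformerTest exercises (a sketch, not the committed code):

    import java.util.Optional;
    import javax.annotation.Nullable;
    import org.apache.flink.runtime.state.StateSnapshotTransformer;
    import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;

    class PerSnapshotInstanceSketch {
        public static void main(String[] args) {
            StateSnapshotTransformFactory<byte[]> factory = new StateSnapshotTransformFactory<byte[]>() {
                @Override
                public Optional<StateSnapshotTransformer<byte[]>> createForDeserializedState() {
                    return Optional.empty();
                }

                @Override
                public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
                    // An anonymous class, so every call returns a new instance:
                    return Optional.of(new StateSnapshotTransformer<byte[]>() {
                        @Nullable
                        @Override
                        public byte[] filterOrTransform(@Nullable byte[] value) {
                            return value;
                        }
                    });
                }
            };

            // Each snapshot asks once, in its synchronous phase:
            StateSnapshotTransformer<byte[]> forSnapshot1 = factory.createForSerializedState().get();
            StateSnapshotTransformer<byte[]> forSnapshot2 = factory.createForSerializedState().get();
            assert forSnapshot1 != forSnapshot2; // no instance sharing across snapshots
        }
    }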