You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by tr...@apache.org on 2019/01/25 21:41:42 UTC

[flink] 02/02: [FLINK-11171] Avoid concurrent usage of StateSnapshotTransformer

This is an automated email from the ASF dual-hosted git repository.

trohrmann pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit bced96a5a0b8f7b7848add316c12071e0398404a
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Mon Dec 17 17:09:47 2018 +0100

    [FLINK-11171] Avoid concurrent usage of StateSnapshotTransformer
    
    Test non-concurrent access of StateSnapshotTransformer
    
    Refactor out testNonConcurrentSnapshotTransformerAccess into a separate StateSnapshotTransformerTest
    
    Use the element serializer from the new meta info, duplicate it in the RocksDB transformer factory, and test concurrent access for the element serializer
    
    This closes #7320.
---
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  27 +-
 .../runtime/state/StateSnapshotTransformer.java    |  90 +-----
 ...sformer.java => StateSnapshotTransformers.java} | 111 +++-----
 .../state/heap/CopyOnWriteStateTableSnapshot.java  |   7 +-
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  31 +--
 .../runtime/state/heap/NestedMapsStateTable.java   |   9 +-
 .../state/ttl/TtlStateSnapshotTransformer.java     |   2 +-
 .../flink/runtime/state/StateBackendTestBase.java  |  16 ++
 .../state/StateSnapshotTransformerTest.java        | 305 +++++++++++++++++++++
 .../state/ttl/mock/MockKeyedStateBackend.java      |   7 +-
 .../runtime/state/ttl/mock/MockStateBackend.java   |  25 +-
 .../streaming/state/RocksDBKeyedStateBackend.java  |  58 ++--
 .../RocksDBSnapshotTransformFactoryAdaptor.java    | 105 +++++++
 .../state/snapshot/RocksFullSnapshotStrategy.java  |  49 +++-
 14 files changed, 586 insertions(+), 256 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index b2d1cdc..1ce728d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
 
@@ -48,8 +49,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 	private final StateSerializerProvider<N> namespaceSerializerProvider;
 	@Nonnull
 	private final StateSerializerProvider<S> stateSerializerProvider;
-	@Nullable
-	private StateSnapshotTransformer<S> snapshotTransformer;
+	@Nonnull
+	private StateSnapshotTransformFactory<S> stateSnapshotTransformFactory;
 
 	public RegisteredKeyValueStateBackendMetaInfo(
 		@Nonnull StateDescriptor.Type stateType,
@@ -62,7 +63,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 			name,
 			StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
 			StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-			null);
+			StateSnapshotTransformFactory.noTransform());
 	}
 
 	public RegisteredKeyValueStateBackendMetaInfo(
@@ -70,14 +71,14 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		@Nonnull String name,
 		@Nonnull TypeSerializer<N> namespaceSerializer,
 		@Nonnull TypeSerializer<S> stateSerializer,
-		@Nullable StateSnapshotTransformer<S> snapshotTransformer) {
+		@Nonnull StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
 
 		this(
 			stateType,
 			name,
 			StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
 			StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-			snapshotTransformer);
+			stateSnapshotTransformFactory);
 	}
 
 	@SuppressWarnings("unchecked")
@@ -91,7 +92,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 			StateSerializerProvider.fromPreviousSerializerSnapshot(
 				(TypeSerializerSnapshot<S>) Preconditions.checkNotNull(
 					snapshot.getTypeSerializerSnapshot(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER))),
-			null);
+			StateSnapshotTransformFactory.noTransform());
 
 		Preconditions.checkState(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE == snapshot.getBackendStateType());
 	}
@@ -101,13 +102,13 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		@Nonnull String name,
 		@Nonnull StateSerializerProvider<N> namespaceSerializerProvider,
 		@Nonnull StateSerializerProvider<S> stateSerializerProvider,
-		@Nullable StateSnapshotTransformer<S> snapshotTransformer) {
+		@Nonnull StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
 
 		super(name);
 		this.stateType = stateType;
 		this.namespaceSerializerProvider = namespaceSerializerProvider;
 		this.stateSerializerProvider = stateSerializerProvider;
-		this.snapshotTransformer = snapshotTransformer;
+		this.stateSnapshotTransformFactory = stateSnapshotTransformFactory;
 	}
 
 	@Nonnull
@@ -145,13 +146,13 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		return stateSerializerProvider.previousSchemaSerializer();
 	}
 
-	@Nullable
-	public StateSnapshotTransformer<S> getSnapshotTransformer() {
-		return snapshotTransformer;
+	@Nonnull
+	public StateSnapshotTransformFactory<S> getStateSnapshotTransformFactory() {
+		return stateSnapshotTransformFactory;
 	}
 
-	public void updateSnapshotTransformer(StateSnapshotTransformer<S> snapshotTransformer) {
-		this.snapshotTransformer = snapshotTransformer;
+	public void updateSnapshotTransformFactory(StateSnapshotTransformFactory<S> stateSnapshotTransformFactory) {
+		this.stateSnapshotTransformFactory = stateSnapshotTransformFactory;
 	}
 
 	@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
index cd2c7bf..2eb4c3f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
@@ -18,19 +18,11 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
-
 import javax.annotation.Nullable;
+import javax.annotation.concurrent.NotThreadSafe;
 
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
 import java.util.Optional;
 
-import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
-
 /**
  * Transformer of state values which are included or skipped in the snapshot.
  *
@@ -44,6 +36,7 @@ import static org.apache.flink.runtime.state.StateSnapshotTransformer.Collection
  * @param <T> type of state
  */
 @FunctionalInterface
+@NotThreadSafe
 public interface StateSnapshotTransformer<T> {
 	/**
 	 * Transform or filter out state values which are included or skipped in the snapshot.
@@ -75,84 +68,6 @@ public interface StateSnapshotTransformer<T> {
 	}
 
 	/**
-	 * General implementation of list state transformer.
-	 *
-	 * <p>This transformer wraps a transformer per-entry
-	 * and transforms the whole list state.
-	 * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
-	 * it respects its {@link TransformStrategy}.
-	 */
-	class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
-		private final StateSnapshotTransformer<T> entryValueTransformer;
-		private final TransformStrategy transformStrategy;
-
-		public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
-			this.entryValueTransformer = entryValueTransformer;
-			this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
-				((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
-				TransformStrategy.TRANSFORM_ALL;
-		}
-
-		@Override
-		@Nullable
-		public List<T> filterOrTransform(@Nullable List<T> list) {
-			if (list == null) {
-				return null;
-			}
-			List<T> transformedList = new ArrayList<>();
-			boolean anyChange = false;
-			for (int i = 0; i < list.size(); i++) {
-				T entry = list.get(i);
-				T transformedEntry = entryValueTransformer.filterOrTransform(entry);
-				if (transformedEntry != null) {
-					if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
-						transformedList = list.subList(i, list.size());
-						anyChange = i > 0;
-						break;
-					} else {
-						transformedList.add(transformedEntry);
-					}
-				}
-				anyChange |= transformedEntry == null || !Objects.equals(entry, transformedEntry);
-			}
-			transformedList = anyChange ? transformedList : list;
-			return transformedList.isEmpty() ? null : transformedList;
-		}
-	}
-
-	/**
-	 * General implementation of map state transformer.
-	 *
-	 * <p>This transformer wraps a transformer per-entry
-	 * and transforms the whole map state.
-	 */
-	class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
-		private final StateSnapshotTransformer<V> entryValueTransformer;
-
-		public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
-			this.entryValueTransformer = entryValueTransformer;
-		}
-
-		@Nullable
-		@Override
-		public Map<K, V> filterOrTransform(@Nullable Map<K, V> map) {
-			if (map == null) {
-				return null;
-			}
-			Map<K, V> transformedMap = new HashMap<>();
-			boolean anyChange = false;
-			for (Map.Entry<K, V> entry : map.entrySet()) {
-				V transformedValue = entryValueTransformer.filterOrTransform(entry.getValue());
-				if (transformedValue != null) {
-					transformedMap.put(entry.getKey(), transformedValue);
-				}
-				anyChange |= transformedValue == null || !Objects.equals(entry.getValue(), transformedValue);
-			}
-			return anyChange ? (transformedMap.isEmpty() ? null : transformedMap) : map;
-		}
-	}
-
-	/**
 	 * This factory creates state transformers depending on the form of values to transform.
 	 *
 	 * <p>If there is no transforming needed, the factory methods return {@code Optional.empty()}.
@@ -183,4 +98,5 @@ public interface StateSnapshotTransformer<T> {
 
 		Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
 	}
+
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
similarity index 56%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
index cd2c7bf..0b9306a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformers.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
-import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 
 import javax.annotation.Nullable;
 
@@ -31,66 +31,25 @@ import java.util.Optional;
 
 import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
 
-/**
- * Transformer of state values which are included or skipped in the snapshot.
- *
- * <p>This transformer can be applied to state values
- * to decide which entries should be included into the snapshot.
- * The included entries can be optionally modified before.
- *
- * <p>Unless specified differently, the transformer should be applied per entry
- * for collection types of state, like list or map.
- *
- * @param <T> type of state
- */
-@FunctionalInterface
-public interface StateSnapshotTransformer<T> {
-	/**
-	 * Transform or filter out state values which are included or skipped in the snapshot.
-	 *
-	 * @param value non-serialized form of value
-	 * @return value to snapshot or null which means the entry is not included
-	 */
-	@Nullable
-	T filterOrTransform(@Nullable T value);
-
-	/** Collection state specific transformer which says how to transform entries of the collection. */
-	interface CollectionStateSnapshotTransformer<T> extends StateSnapshotTransformer<T> {
-		enum TransformStrategy {
-			/** Transform all entries. */
-			TRANSFORM_ALL,
-
-			/**
-			 * Skip first null entries.
-			 *
-			 * <p>While traversing collection entries, as optimisation, stops transforming
-			 * if encounters first non-null included entry and returns it plus the rest untouched.
-			 */
-			STOP_ON_FIRST_INCLUDED
-		}
-
-		default TransformStrategy getFilterStrategy() {
-			return TransformStrategy.TRANSFORM_ALL;
-		}
-	}
-
+/** Collection of common state snapshot transformers and their factories. */
+public class StateSnapshotTransformers {
 	/**
 	 * General implementation of list state transformer.
 	 *
 	 * <p>This transformer wraps a transformer per-entry
 	 * and transforms the whole list state.
 	 * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
-	 * it respects its {@link TransformStrategy}.
+	 * it respects its {@link CollectionStateSnapshotTransformer.TransformStrategy}.
 	 */
-	class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
+	public static class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
 		private final StateSnapshotTransformer<T> entryValueTransformer;
-		private final TransformStrategy transformStrategy;
+		private final CollectionStateSnapshotTransformer.TransformStrategy transformStrategy;
 
 		public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
 			this.entryValueTransformer = entryValueTransformer;
 			this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
 				((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
-				TransformStrategy.TRANSFORM_ALL;
+				CollectionStateSnapshotTransformer.TransformStrategy.TRANSFORM_ALL;
 		}
 
 		@Override
@@ -120,13 +79,24 @@ public interface StateSnapshotTransformer<T> {
 		}
 	}
 
+	public static class ListStateSnapshotTransformFactory<T> extends StateSnapshotTransformFactoryWrapAdaptor<T, List<T>> {
+		public ListStateSnapshotTransformFactory(StateSnapshotTransformFactory<T> originalSnapshotTransformFactory) {
+			super(originalSnapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<List<T>>> createForDeserializedState() {
+			return originalSnapshotTransformFactory.createForDeserializedState().map(ListStateSnapshotTransformer::new);
+		}
+	}
+
 	/**
 	 * General implementation of map state transformer.
 	 *
 	 * <p>This transformer wraps a transformer per-entry
 	 * and transforms the whole map state.
 	 */
-	class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
+	public static class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
 		private final StateSnapshotTransformer<V> entryValueTransformer;
 
 		public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
@@ -152,35 +122,32 @@ public interface StateSnapshotTransformer<T> {
 		}
 	}
 
-	/**
-	 * This factory creates state transformers depending on the form of values to transform.
-	 *
-	 * <p>If there is no transforming needed, the factory methods return {@code Optional.empty()}.
-	 */
-	interface StateSnapshotTransformFactory<T> {
-		StateSnapshotTransformFactory<?> NO_TRANSFORM = createNoTransform();
+	public static class MapStateSnapshotTransformFactory<K, V> extends StateSnapshotTransformFactoryWrapAdaptor<V, Map<K, V>> {
+		public MapStateSnapshotTransformFactory(StateSnapshotTransformFactory<V> originalSnapshotTransformFactory) {
+			super(originalSnapshotTransformFactory);
+		}
 
-		@SuppressWarnings("unchecked")
-		static <T> StateSnapshotTransformFactory<T> noTransform() {
-			return (StateSnapshotTransformFactory<T>) NO_TRANSFORM;
+		@Override
+		public Optional<StateSnapshotTransformer<Map<K, V>>> createForDeserializedState() {
+			return originalSnapshotTransformFactory.createForDeserializedState().map(MapStateSnapshotTransformer::new);
 		}
+	}
 
-		static <T> StateSnapshotTransformFactory<T> createNoTransform() {
-			return new StateSnapshotTransformFactory<T>() {
-				@Override
-				public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
-					return Optional.empty();
-				}
+	public abstract static class StateSnapshotTransformFactoryWrapAdaptor<S, T> implements StateSnapshotTransformFactory<T> {
+		final StateSnapshotTransformFactory<S> originalSnapshotTransformFactory;
 
-				@Override
-				public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
-					return Optional.empty();
-				}
-			};
+		StateSnapshotTransformFactoryWrapAdaptor(StateSnapshotTransformFactory<S> originalSnapshotTransformFactory) {
+			this.originalSnapshotTransformFactory = originalSnapshotTransformFactory;
 		}
 
-		Optional<StateSnapshotTransformer<T>> createForDeserializedState();
+		@Override
+		public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+			throw new UnsupportedOperationException();
+		}
 
-		Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			throw new UnsupportedOperationException();
+		}
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index 21abf8d..12afcbc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -88,6 +88,9 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	@Nonnull
 	private final TypeSerializer<S> localStateSerializer;
 
+	@Nullable
+	private final StateSnapshotTransformer<S> stateSnapshotTransformer;
+
 	/**
 	 * Result of partitioning the snapshot by key-group. This is lazily created in the process of writing this snapshot
 	 * to an output as part of checkpointing.
@@ -114,6 +117,9 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		this.localStateSerializer = owningStateTable.metaInfo.getStateSerializer().duplicate();
 
 		this.partitionedStateTableSnapshot = null;
+
+		this.stateSnapshotTransformer = owningStateTable.metaInfo.
+			getStateSnapshotTransformFactory().createForDeserializedState().orElse(null);
 	}
 
 	/**
@@ -147,7 +153,6 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 					localKeySerializer.serialize(element.key, dov);
 					localStateSerializer.serialize(element.state, dov);
 				};
-			StateSnapshotTransformer<S> stateSnapshotTransformer = owningStateTable.metaInfo.getSnapshotTransformer();
 			StateTableKeyGroupPartitioner<K, N, S> stateTableKeyGroupPartitioner = stateSnapshotTransformer != null ?
 				new TransformingStateTableKeyGroupPartitioner<>(
 					snapshotData,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 55d5a6f..56374fe 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -61,8 +61,8 @@ import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformers;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -78,7 +78,6 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
-import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -89,7 +88,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
@@ -229,7 +227,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	private <N, V> StateTable<K, N, V> tryRegisterStateTable(
 			TypeSerializer<N> namespaceSerializer,
 			StateDescriptor<?, V> stateDesc,
-			@Nullable StateSnapshotTransformer<V> snapshotTransformer) throws StateMigrationException {
+			@Nonnull StateSnapshotTransformFactory<V> snapshotTransformFactory) throws StateMigrationException {
 
 		@SuppressWarnings("unchecked")
 		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
@@ -239,7 +237,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		if (stateTable != null) {
 			RegisteredKeyValueStateBackendMetaInfo<N, V> restoredKvMetaInfo = stateTable.getMetaInfo();
 
-			restoredKvMetaInfo.updateSnapshotTransformer(snapshotTransformer);
+			restoredKvMetaInfo.updateSnapshotTransformFactory(snapshotTransformFactory);
 
 			TypeSerializerSchemaCompatibility<N> namespaceCompatibility =
 				restoredKvMetaInfo.updateNamespaceSerializer(namespaceSerializer);
@@ -263,7 +261,7 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getName(),
 				namespaceSerializer,
 				newStateSerializer,
-				snapshotTransformer);
+				snapshotTransformFactory);
 
 			stateTable = snapshotStrategy.newStateTable(newMetaInfo);
 			registeredKVStates.put(stateDesc.getName(), stateTable);
@@ -301,27 +299,20 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			throw new FlinkRuntimeException(message);
 		}
 		StateTable<K, N, SV> stateTable = tryRegisterStateTable(
-			namespaceSerializer, stateDesc, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
+			namespaceSerializer, stateDesc, getStateSnapshotTransformFactory(stateDesc, snapshotTransformFactory));
 		return stateFactory.createState(stateDesc, stateTable, getKeySerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+	private <SV, SEV> StateSnapshotTransformFactory<SV> getStateSnapshotTransformFactory(
 		StateDescriptor<?, SV> stateDesc,
 		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
-		Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
-		if (original.isPresent()) {
-			if (stateDesc instanceof ListStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.ListStateSnapshotTransformer<>(original.get());
-			} else if (stateDesc instanceof MapStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.MapStateSnapshotTransformer<>(original.get());
-			} else {
-				return (StateSnapshotTransformer<SV>) original.get();
-			}
+		if (stateDesc instanceof ListStateDescriptor) {
+			return (StateSnapshotTransformFactory<SV>) new StateSnapshotTransformers.ListStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		} else if (stateDesc instanceof MapStateDescriptor) {
+			return (StateSnapshotTransformFactory<SV>) new StateSnapshotTransformers.MapStateSnapshotTransformFactory<>(snapshotTransformFactory);
 		} else {
-			return null;
+			return (StateSnapshotTransformFactory<SV>) snapshotTransformFactory;
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index f982370..167d90f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
@@ -319,7 +320,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	@Nonnull
 	@Override
 	public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
+		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getStateSnapshotTransformFactory());
 	}
 
 	/**
@@ -337,9 +338,11 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		private final TypeSerializer<S> stateSerializer;
 		private final StateSnapshotTransformer<S> snapshotFilter;
 
-		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformer<S> snapshotFilter) {
+		NestedMapsStateTableSnapshot(
+			NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformFactory<S> snapshotTransformFactory) {
+
 			super(owningTable);
-			this.snapshotFilter = snapshotFilter;
+			this.snapshotFilter = snapshotTransformFactory.createForDeserializedState().orElse(null);
 			this.keySerializer = owningStateTable.keyContext.getKeySerializer();
 			this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
 			this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
index e3706ec..fd29271 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
@@ -91,7 +91,7 @@ abstract class TtlStateSnapshotTransformer<T> implements CollectionStateSnapshot
 			try {
 				ts = deserializeTs(value);
 			} catch (IOException e) {
-				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure");
+				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure", e);
 			}
 			return expired(ts) ? null : value;
 		}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
index f1269fe..a4306fd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java
@@ -3614,6 +3614,22 @@ public abstract class StateBackendTestBase<B extends AbstractStateBackend> exten
 	}
 
 	@Test
+	public void testNonConcurrentSnapshotTransformerAccess() throws Exception {
+		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
+		AbstractKeyedStateBackend<Integer> backend = null;
+		try {
+			backend = createKeyedBackend(IntSerializer.INSTANCE);
+			new StateSnapshotTransformerTest(backend, streamFactory)
+				.testNonConcurrentSnapshotTransformerAccess();
+		} finally {
+			if (backend != null) {
+				IOUtils.closeQuietly(backend);
+				backend.dispose();
+			}
+		}
+	}
+
+	@Test
 	public void testAsyncSnapshot() throws Exception {
 		OneShotLatch waiter = new OneShotLatch();
 		BlockerCheckpointStreamFactory streamFactory = new BlockerCheckpointStreamFactory(1024 * 1024);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java
new file mode 100644
index 0000000..42bda6e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateSnapshotTransformerTest.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.internal.InternalListState;
+import org.apache.flink.runtime.state.internal.InternalMapState;
+import org.apache.flink.runtime.state.internal.InternalValueState;
+import org.apache.flink.runtime.util.BlockerCheckpointStreamFactory;
+import org.apache.flink.util.StringUtils;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+import java.util.concurrent.RunnableFuture;
+
+import static org.junit.Assert.assertEquals;
+
+class StateSnapshotTransformerTest {
+	private final AbstractKeyedStateBackend<Integer> backend;
+	private final BlockerCheckpointStreamFactory streamFactory;
+	private final StateSnapshotTransformFactory<?> snapshotTransformFactory;
+
+	/**
+	 * Creates the test helper for the given backend.
+	 *
+	 * @param backend keyed state backend in which the test states are created and snapshotted
+	 * @param streamFactory checkpoint stream factory handed to every snapshot call
+	 */
+	StateSnapshotTransformerTest(
+		AbstractKeyedStateBackend<Integer> backend,
+		BlockerCheckpointStreamFactory streamFactory) {
+
+		// one access-checking factory instance is shared by all states created by this test
+		this.snapshotTransformFactory = SingleThreadAccessCheckingSnapshotTransformFactory.create();
+		this.backend = backend;
+		this.streamFactory = streamFactory;
+	}
+
+	void testNonConcurrentSnapshotTransformerAccess() throws Exception {
+		List<TestState> testStates = Arrays.asList(
+			new TestValueState(),
+			new TestListState(),
+			new TestMapState()
+		);
+
+		for (TestState state : testStates) {
+			for (int i = 0; i < 100; i++) {
+				backend.setCurrentKey(i);
+				state.setToRandomValue();
+			}
+
+			CheckpointOptions checkpointOptions = CheckpointOptions.forCheckpointWithDefaultLocation();
+
+			RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot1 =
+				backend.snapshot(1L, 0L, streamFactory, checkpointOptions);
+
+			RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot2 =
+				backend.snapshot(2L, 0L, streamFactory, checkpointOptions);
+
+			Thread runner1 = new Thread(snapshot1, "snapshot1");
+			runner1.start();
+			Thread runner2 = new Thread(snapshot2, "snapshot2");
+			runner2.start();
+
+			runner1.join();
+			runner2.join();
+
+			snapshot1.get();
+			snapshot2.get();
+		}
+	}
+
+	/**
+	 * Base class of the state wrappers used by this test. It owns the random source and
+	 * leaves the actual state update to the subclasses.
+	 */
+	private abstract class TestState {
+		final Random rnd = new Random();
+
+		/** Writes a randomly generated value into the wrapped state for the backend's current key. */
+		abstract void setToRandomValue() throws Exception;
+
+		/** Returns the result of {@code StringUtils.getRandomString(rnd, 5, 10)}. */
+		String getRandomString() {
+			return StringUtils.getRandomString(rnd, 5, 10);
+		}
+	}
+
+	/** Value state of strings registered with the access-checking snapshot transform factory. */
+	private class TestValueState extends TestState {
+		private final InternalValueState<Integer, VoidNamespace, String> state;
+
+		private TestValueState() throws Exception {
+			final ValueStateDescriptor<String> descriptor =
+				new ValueStateDescriptor<>("TestValueState", StringSerializer.INSTANCE);
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				descriptor,
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		/** Replaces the current value with a fresh random string. */
+		@Override
+		void setToRandomValue() throws Exception {
+			final String randomValue = getRandomString();
+			state.update(randomValue);
+		}
+	}
+
+	/**
+	 * List state of strings registered with the access-checking snapshot transform factory.
+	 * Its element serializer is a {@link SingleThreadAccessCheckingTypeSerializer}, so the
+	 * serializer instance is checked for single-thread access as well.
+	 */
+	private class TestListState extends TestState {
+		private final InternalListState<Integer, VoidNamespace, String> state;
+
+		private TestListState() throws Exception {
+			final ListStateDescriptor<String> descriptor =
+				new ListStateDescriptor<>("TestListState", new SingleThreadAccessCheckingTypeSerializer());
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				descriptor,
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		/** Appends between 0 and 9 random strings to the list. */
+		@Override
+		void setToRandomValue() throws Exception {
+			final int elementCount = rnd.nextInt(10);
+			for (int added = 0; added < elementCount; added++) {
+				state.add(getRandomString());
+			}
+		}
+	}
+
+	/** Map state (string to string) registered with the access-checking snapshot transform factory. */
+	private class TestMapState extends TestState {
+		private final InternalMapState<Integer, VoidNamespace, String, String> state;
+
+		private TestMapState() throws Exception {
+			final MapStateDescriptor<String, String> descriptor =
+				new MapStateDescriptor<>("TestMapState", StringSerializer.INSTANCE, StringSerializer.INSTANCE);
+			this.state = backend.createInternalState(
+				VoidNamespaceSerializer.INSTANCE,
+				descriptor,
+				snapshotTransformFactory);
+			state.setCurrentNamespace(VoidNamespace.INSTANCE);
+		}
+
+		/** Puts between 0 and 9 random key/value pairs into the map. */
+		@Override
+		void setToRandomValue() throws Exception {
+			final int entryCount = rnd.nextInt(10);
+			for (int added = 0; added < entryCount; added++) {
+				// key is generated before the value, as argument evaluation is left to right
+				state.put(getRandomString(), getRandomString());
+			}
+		}
+	}
+
+	private static class SingleThreadAccessCheckingSnapshotTransformFactory<T>
+		implements StateSnapshotTransformFactory<T> {
+
+		private final SingleThreadAccessChecker singleThreadAccessChecker = new SingleThreadAccessChecker();
+
+		static <T> StateSnapshotTransformFactory<T> create() {
+			return new SingleThreadAccessCheckingSnapshotTransformFactory<>();
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return createStateSnapshotTransformer();
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			singleThreadAccessChecker.checkSingleThreadAccess();
+			return createStateSnapshotTransformer();
+		}
+
+		private <T1> Optional<StateSnapshotTransformer<T1>> createStateSnapshotTransformer() {
+			return Optional.of(new StateSnapshotTransformer<T1>() {
+				private final SingleThreadAccessChecker singleThreadAccessChecker = new SingleThreadAccessChecker();
+
+				@Nullable
+				@Override
+				public T1 filterOrTransform(@Nullable T1 value) {
+					singleThreadAccessChecker.checkSingleThreadAccess();
+					return value;
+				}
+			});
+		}
+	}
+
+	/**
+	 * String serializer for tests: every method first asserts that this instance is only used
+	 * by a single thread and then forwards the call to {@link StringSerializer#INSTANCE}.
+	 * {@link #duplicate()} returns a new instance with its own, not yet bound, access checker.
+	 */
+	private static class SingleThreadAccessCheckingTypeSerializer extends TypeSerializer<String> {
+		/** Serializer that does the real work. */
+		private static final TypeSerializer<String> DELEGATE = StringSerializer.INSTANCE;
+
+		private final SingleThreadAccessChecker accessChecker = new SingleThreadAccessChecker();
+
+		@Override
+		public boolean isImmutableType() {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.isImmutableType();
+		}
+
+		@Override
+		public TypeSerializer<String> duplicate() {
+			accessChecker.checkSingleThreadAccess();
+			return new SingleThreadAccessCheckingTypeSerializer();
+		}
+
+		@Override
+		public String createInstance() {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.createInstance();
+		}
+
+		@Override
+		public String copy(String from) {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.copy(from);
+		}
+
+		@Override
+		public String copy(String from, String reuse) {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.copy(from, reuse);
+		}
+
+		@Override
+		public int getLength() {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.getLength();
+		}
+
+		@Override
+		public void serialize(String record, DataOutputView target) throws IOException {
+			accessChecker.checkSingleThreadAccess();
+			DELEGATE.serialize(record, target);
+		}
+
+		@Override
+		public String deserialize(DataInputView source) throws IOException {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.deserialize(source);
+		}
+
+		@Override
+		public String deserialize(String reuse, DataInputView source) throws IOException {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.deserialize(reuse, source);
+		}
+
+		@Override
+		public void copy(DataInputView source, DataOutputView target) throws IOException {
+			accessChecker.checkSingleThreadAccess();
+			DELEGATE.copy(source, target);
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			accessChecker.checkSingleThreadAccess();
+			if (obj == this) {
+				return true;
+			}
+			if (obj == null || obj.getClass() != getClass()) {
+				return false;
+			}
+			// NOTE(review): obj is an instance of this test class here, not a StringSerializer;
+			// whether the delegate can report equality for it should be confirmed.
+			return DELEGATE.equals(obj);
+		}
+
+		@Override
+		public boolean canEqual(Object obj) {
+			accessChecker.checkSingleThreadAccess();
+			if (obj == null || obj.getClass() != getClass()) {
+				return false;
+			}
+			// NOTE(review): same question as in equals(Object) - confirm the delegate's answer
+			// for an instance of this test class.
+			return DELEGATE.canEqual(obj);
+		}
+
+		@Override
+		public int hashCode() {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.hashCode();
+		}
+
+		@Override
+		public TypeSerializerSnapshot<String> snapshotConfiguration() {
+			accessChecker.checkSingleThreadAccess();
+			return DELEGATE.snapshotConfiguration();
+		}
+	}
+
+	/**
+	 * Binds itself to the first thread that calls {@link #checkSingleThreadAccess()} and fails
+	 * with an assertion error when any other thread calls it afterwards.
+	 */
+	private static class SingleThreadAccessChecker {
+		/** Thread of the first caller; {@code null} until the first call. */
+		private Thread currentThread = null;
+
+		/**
+		 * Registers the calling thread on the first call and asserts on every later call that
+		 * the caller is still that thread.
+		 *
+		 * <p>The method is synchronized because the checker is, by design, reachable from
+		 * several threads. Without it two threads could both observe {@code null}, or one
+		 * thread could miss the other's registration, and a concurrent access would pass
+		 * undetected.
+		 */
+		synchronized void checkSingleThreadAccess() {
+			final Thread caller = Thread.currentThread();
+			if (currentThread == null) {
+				currentThread = caller;
+			} else {
+				assertEquals("Concurrent access from another thread",
+					currentThread, caller);
+			}
+		}
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index f88e6d7..2725051 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformers;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
 import org.apache.flink.runtime.state.ttl.TtlStateFactory;
@@ -131,11 +132,9 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
 		if (original.isPresent()) {
 			if (stateDesc instanceof ListStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.ListStateSnapshotTransformer<>(original.get());
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformers.ListStateSnapshotTransformer<>(original.get());
 			} else if (stateDesc instanceof MapStateDescriptor) {
-				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
-					.MapStateSnapshotTransformer<>(original.get());
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformers.MapStateSnapshotTransformer<>(original.get());
 			} else {
 				return (StateSnapshotTransformer<SV>) original.get();
 			}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
index 8ed84c0..9a899f4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockStateBackend.java
@@ -27,10 +27,12 @@ import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
 import org.apache.flink.runtime.state.AbstractStateBackend;
+import org.apache.flink.runtime.state.CheckpointMetadataOutputStream;
 import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.CheckpointStorageLocation;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.runtime.state.CheckpointedStateScope;
 import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.OperatorStateBackend;
@@ -65,7 +67,28 @@ public class MockStateBackend extends AbstractStateBackend {
 
 			@Override
 			public CheckpointStorageLocation initializeLocationForCheckpoint(long checkpointId) {
-				return null;
+				return new CheckpointStorageLocation() {
+
+					@Override
+					public CheckpointStateOutputStream createCheckpointStateOutputStream(CheckpointedStateScope scope) {
+						return null;
+					}
+
+					@Override
+					public CheckpointMetadataOutputStream createMetadataOutputStream() {
+						return null;
+					}
+
+					@Override
+					public void disposeOnFailure() {
+
+					}
+
+					@Override
+					public CheckpointStorageLocationReference getLocationReference() {
+						return null;
+					}
+				};
 			}
 
 			@Override
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index e994682..7a585db 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -70,7 +70,6 @@ import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.StateHandleID;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -114,7 +113,6 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.Optional;
 import java.util.Set;
 import java.util.SortedMap;
 import java.util.Spliterator;
@@ -126,6 +124,7 @@ import java.util.stream.Collectors;
 import java.util.stream.Stream;
 import java.util.stream.StreamSupport;
 
+import static org.apache.flink.contrib.streaming.state.RocksDBSnapshotTransformFactoryAdaptor.wrapStateSnapshotTransformFactory;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.END_OF_KEY_GROUP_MARK;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.SST_FILE_SUFFIX;
 import static org.apache.flink.contrib.streaming.state.snapshot.RocksSnapshotUtil.clearMetaDataFollowsFlag;
@@ -368,7 +367,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@VisibleForTesting
-	public ColumnFamilyHandle getColumnFamilyHandle(String state) {
+	ColumnFamilyHandle getColumnFamilyHandle(String state) {
 		Tuple2<ColumnFamilyHandle, ?> columnInfo = kvStateInformation.get(state);
 		return columnInfo != null ? columnInfo.f0 : null;
 	}
@@ -688,7 +687,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 *
 		 * @param rocksDBKeyedStateBackend the state backend into which we restore
 		 */
-		public RocksDBFullRestoreOperation(RocksDBKeyedStateBackend<K> rocksDBKeyedStateBackend) {
+		RocksDBFullRestoreOperation(RocksDBKeyedStateBackend<K> rocksDBKeyedStateBackend) {
 			this.rocksDBKeyedStateBackend = Preconditions.checkNotNull(rocksDBKeyedStateBackend);
 		}
 
@@ -697,7 +696,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 *
 		 * @param keyedStateHandles List of all key groups state handles that shall be restored.
 		 */
-		public void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
+		void doRestore(Collection<KeyedStateHandle> keyedStateHandles)
 			throws IOException, StateMigrationException, RocksDBException {
 
 			rocksDBKeyedStateBackend.createDB();
@@ -1344,10 +1343,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * already have a registered entry for that and return it (after some necessary state compatibility checks)
 	 * or create a new one if it does not exist.
 	 */
-	private <N, S extends State, SV> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> tryRegisterKvStateInformation(
+	private <N, S extends State, SV, SEV> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> tryRegisterKvStateInformation(
 			StateDescriptor<S, SV> stateDesc,
 			TypeSerializer<N> namespaceSerializer,
-			@Nullable StateSnapshotTransformer<SV> snapshotTransformer) throws Exception {
+			@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
 
 		Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> oldStateInfo =
 			kvStateInformation.get(stateDesc.getName());
@@ -1364,8 +1363,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				Tuple2.of(oldStateInfo.f0, castedMetaInfo),
 				stateDesc,
 				namespaceSerializer,
-				stateSerializer,
-				snapshotTransformer);
+				stateSerializer);
 
 			oldStateInfo.f1 = newMetaInfo;
 			newColumnFamily = oldStateInfo.f0;
@@ -1375,12 +1373,16 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getName(),
 				namespaceSerializer,
 				stateSerializer,
-				snapshotTransformer);
+				StateSnapshotTransformFactory.noTransform());
 
 			newColumnFamily = createColumnFamily(stateDesc.getName());
 			registerKvStateInformation(stateDesc.getName(), Tuple2.of(newColumnFamily, newMetaInfo));
 		}
 
+		StateSnapshotTransformFactory<SV> wrappedSnapshotTransformFactory = wrapStateSnapshotTransformFactory(
+			stateDesc, snapshotTransformFactory, newMetaInfo.getStateSerializer());
+		newMetaInfo.updateSnapshotTransformFactory(wrappedSnapshotTransformFactory);
+
 		return Tuple2.of(newColumnFamily, newMetaInfo);
 	}
 
@@ -1388,14 +1390,11 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> oldStateInfo,
 			StateDescriptor<S, SV> stateDesc,
 			TypeSerializer<N> namespaceSerializer,
-			TypeSerializer<SV> stateSerializer,
-			@Nullable StateSnapshotTransformer<SV> snapshotTransformer) throws Exception {
+			TypeSerializer<SV> stateSerializer) throws Exception {
 
 		@SuppressWarnings("unchecked")
 		RegisteredKeyValueStateBackendMetaInfo<N, SV> restoredKvStateMetaInfo = oldStateInfo.f1;
 
-		restoredKvStateMetaInfo.updateSnapshotTransformer(snapshotTransformer);
-
 		TypeSerializerSchemaCompatibility<N> s = restoredKvStateMetaInfo.updateNamespaceSerializer(namespaceSerializer);
 		if (s.isCompatibleAfterMigration() || s.isIncompatible()) {
 			throw new StateMigrationException("The new namespace serializer must be compatible.");
@@ -1512,39 +1511,14 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			throw new FlinkRuntimeException(message);
 		}
 		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult = tryRegisterKvStateInformation(
-			stateDesc, namespaceSerializer, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
+			stateDesc, namespaceSerializer, snapshotTransformFactory);
 		return stateFactory.createState(stateDesc, registerResult, RocksDBKeyedStateBackend.this);
 	}
 
-	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
-		StateDescriptor<?, SV> stateDesc,
-		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
-		if (stateDesc instanceof ListStateDescriptor) {
-			Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
-			return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null);
-		} else if (stateDesc instanceof MapStateDescriptor) {
-			Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
-			return (StateSnapshotTransformer<SV>) original
-				.map(RocksDBMapState.StateSnapshotTransformerWrapper::new).orElse(null);
-		} else {
-			Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
-			return (StateSnapshotTransformer<SV>) original.orElse(null);
-		}
-	}
-
-	@SuppressWarnings("unchecked")
-	private <SV, SEV> StateSnapshotTransformer<SV> createRocksDBListStateTransformer(
-		StateDescriptor<?, SV> stateDesc,
-		StateSnapshotTransformer<SEV> elementTransformer) {
-		return (StateSnapshotTransformer<SV>) new RocksDBListState.StateSnapshotTransformerWrapper<>(
-			elementTransformer, ((ListStateDescriptor<SEV>) stateDesc).getElementSerializer());
-	}
-
 	/**
 	 * Only visible for testing, DO NOT USE.
 	 */
-	public File getInstanceBasePath() {
+	File getInstanceBasePath() {
 		return instanceBasePath;
 	}
 
@@ -1578,7 +1552,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return new RocksIteratorWrapper(db.newIterator());
 	}
 
-	public static RocksIteratorWrapper getRocksIterator(
+	static RocksIteratorWrapper getRocksIterator(
 		RocksDB db,
 		ColumnFamilyHandle columnFamilyHandle) {
 		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle));
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java
new file mode 100644
index 0000000..5b018c8
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBSnapshotTransformFactoryAdaptor.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.MapStateDescriptor;
+import org.apache.flink.api.common.state.StateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.ListSerializer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
+
+import java.util.Optional;
+
+abstract class RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> implements StateSnapshotTransformFactory<SV> {
+	final StateSnapshotTransformFactory<SEV> snapshotTransformFactory;
+
+	RocksDBSnapshotTransformFactoryAdaptor(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+		this.snapshotTransformFactory = snapshotTransformFactory;
+	}
+
+	@Override
+	public Optional<StateSnapshotTransformer<SV>> createForDeserializedState() {
+		throw new UnsupportedOperationException("Only serialized state filtering is supported in RocksDB backend");
+	}
+
+	@SuppressWarnings("unchecked")
+	static <SV, SEV> StateSnapshotTransformFactory<SV> wrapStateSnapshotTransformFactory(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformFactory<SEV> snapshotTransformFactory,
+		TypeSerializer<SV> stateSerializer) {
+		if (stateDesc instanceof ListStateDescriptor) {
+			TypeSerializer<SEV> elementSerializer = ((ListSerializer<SEV>) stateSerializer).getElementSerializer();
+			return new RocksDBListStateSnapshotTransformFactory<>(snapshotTransformFactory, elementSerializer);
+		} else if (stateDesc instanceof MapStateDescriptor) {
+			return new RocksDBMapStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		} else {
+			return new RocksDBValueStateSnapshotTransformFactory<>(snapshotTransformFactory);
+		}
+	}
+
+	private static class RocksDBValueStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private RocksDBValueStateSnapshotTransformFactory(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+			super(snapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForSerializedState();
+		}
+	}
+
+	private static class RocksDBMapStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private RocksDBMapStateSnapshotTransformFactory(StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+			super(snapshotTransformFactory);
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForSerializedState()
+				.map(RocksDBMapState.StateSnapshotTransformerWrapper::new);
+		}
+	}
+
+	private static class RocksDBListStateSnapshotTransformFactory<SV, SEV>
+		extends RocksDBSnapshotTransformFactoryAdaptor<SV, SEV> {
+
+		private final TypeSerializer<SEV> elementSerializer;
+
+		@SuppressWarnings("unchecked")
+		private RocksDBListStateSnapshotTransformFactory(
+			StateSnapshotTransformFactory<SEV> snapshotTransformFactory,
+			TypeSerializer<SEV> elementSerializer) {
+
+			super(snapshotTransformFactory);
+			this.elementSerializer = elementSerializer;
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return snapshotTransformFactory.createForDeserializedState()
+				.map(est -> new RocksDBListState.StateSnapshotTransformerWrapper<>(est, elementSerializer.duplicate()));
+		}
+	}
+}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
index 817f684..f556e12 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/snapshot/RocksFullSnapshotStrategy.java
@@ -192,7 +192,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 		private List<StateMetaInfoSnapshot> stateMetaInfoSnapshots;
 
 		@Nonnull
-		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy;
+		private List<MetaData> metaData;
 
 		@Nonnull
 		private final String logPathString;
@@ -209,7 +209,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 			this.dbLease = dbLease;
 			this.snapshot = snapshot;
 			this.stateMetaInfoSnapshots = stateMetaInfoSnapshots;
-			this.metaDataCopy = metaDataCopy;
+			this.metaData = fillMetaData(metaDataCopy);
 			this.logPathString = logPathString;
 		}
 
@@ -248,7 +248,7 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 			@Nonnull KeyGroupRangeOffsets keyGroupRangeOffsets) throws IOException, InterruptedException {
 
 			final List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators =
-				new ArrayList<>(metaDataCopy.size());
+				new ArrayList<>(metaData.size());
 			final DataOutputView outputView =
 				new DataOutputViewStreamWrapper(checkpointStreamWithResultProvider.getCheckpointOutputStream());
 			final ReadOptions readOptions = new ReadOptions();
@@ -273,10 +273,10 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 
 			int kvStateId = 0;
 
-			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : metaDataCopy) {
+			for (MetaData metaDataEntry : metaData) {
 
-				RocksIteratorWrapper rocksIteratorWrapper =
-					getRocksIterator(db, tuple2.f0, tuple2.f1, readOptions);
+				RocksIteratorWrapper rocksIteratorWrapper = getRocksIterator(
+					db, metaDataEntry.columnFamilyHandle, metaDataEntry.stateSnapshotTransformer, readOptions);
 
 				kvStateIterators.add(Tuple2.of(rocksIteratorWrapper, kvStateId));
 				++kvStateId;
@@ -402,20 +402,64 @@ public class RocksFullSnapshotStrategy<K> extends RocksDBSnapshotStrategyBase<K>
 		}
 	}
 
-	@SuppressWarnings("unchecked")
+	/**
+	 * Builds one {@link MetaData} entry per element of {@code metaDataCopy}.
+	 *
+	 * <p>For meta info of type {@link RegisteredKeyValueStateBackendMetaInfo}, a transformer is requested
+	 * from the state's snapshot transform factory via {@code createForSerializedState()}. The transformer
+	 * is {@code null} when the factory returns an empty optional or the meta info is of another type.
+	 *
+	 * @param metaDataCopy column family handles paired with their registered state meta info
+	 * @return one entry per element of {@code metaDataCopy}, in the same order
+	 */
+	private static List<MetaData> fillMetaData(
+		List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> metaDataCopy) {
+		List<MetaData> metaData = new ArrayList<>(metaDataCopy.size());
+		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> metaInfo : metaDataCopy) {
+			StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
+			if (metaInfo.f1 instanceof RegisteredKeyValueStateBackendMetaInfo) {
+				stateSnapshotTransformer = ((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo.f1).
+					getStateSnapshotTransformFactory().createForSerializedState().orElse(null);
+			}
+			metaData.add(new MetaData(metaInfo.f0, metaInfo.f1, stateSnapshotTransformer));
+		}
+		return metaData;
+	}
+
+	/**
+	 * Opens a RocksDB iterator over the given column family.
+	 *
+	 * @param stateSnapshotTransformer transformer handed to a {@link RocksTransformingIteratorWrapper};
+	 *     if {@code null}, a plain {@link RocksIteratorWrapper} is returned instead
+	 */
 	private static RocksIteratorWrapper getRocksIterator(
 		RocksDB db,
 		ColumnFamilyHandle columnFamilyHandle,
-		RegisteredStateMetaInfoBase metaInfo,
+		StateSnapshotTransformer<byte[]> stateSnapshotTransformer,
 		ReadOptions readOptions) {
-		StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
-		if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
-			stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
-				((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
-		}
 		RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
 		return stateSnapshotTransformer == null ?
 			new RocksIteratorWrapper(rocksIterator) :
 			new RocksTransformingIteratorWrapper(rocksIterator, stateSnapshotTransformer);
 	}
+
+	/**
+	 * Column family handle and registered state meta info, together with the snapshot transformer
+	 * (possibly {@code null}) that {@code fillMetaData} obtained for them.
+	 */
+	private static class MetaData {
+		final ColumnFamilyHandle columnFamilyHandle;
+		final RegisteredStateMetaInfoBase registeredStateMetaInfoBase;
+		final StateSnapshotTransformer<byte[]> stateSnapshotTransformer;
+
+		private MetaData(
+			ColumnFamilyHandle columnFamilyHandle,
+			RegisteredStateMetaInfoBase registeredStateMetaInfoBase,
+			StateSnapshotTransformer<byte[]> stateSnapshotTransformer) {
+
+			this.columnFamilyHandle = columnFamilyHandle;
+			this.registeredStateMetaInfoBase = registeredStateMetaInfoBase;
+			this.stateSnapshotTransformer = stateSnapshotTransformer;
+		}
+	}
 }