You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/08/03 17:40:59 UTC

[flink] branch master updated: [FLINK-9938][state] Clean up full snapshot from expired state with TTL

This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new ce96c40  [FLINK-9938][state] Clean up full snapshot from expired state with TTL
ce96c40 is described below

commit ce96c409148d1a9bc40f581e13900818b5f11f6a
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Fri Aug 3 18:18:07 2018 +0200

    [FLINK-9938][state] Clean up full snapshot from expired state with TTL
    
    This closes #6460.
---
 .../flink/api/common/state/StateDescriptor.java    |  12 +-
 .../flink/api/common/state/StateTtlConfig.java     | 267 +++++++++++++++++++++
 .../api/common/state/StateTtlConfiguration.java    | 205 ----------------
 .../tests/DataStreamStateTTLTestProgram.java       |   4 +-
 .../streaming/tests/TtlVerifyUpdateFunction.java   |   6 +-
 .../tests/verify/AbstractTtlStateVerifier.java     |   4 +-
 .../streaming/tests/verify/TtlStateVerifier.java   |   4 +-
 .../flink/runtime/state/KeyGroupPartitioner.java   |  18 +-
 .../flink/runtime/state/KeyedStateFactory.java     |  32 ++-
 .../RegisteredKeyValueStateBackendMetaInfo.java    |  34 ++-
 .../runtime/state/StateSnapshotTransformer.java    | 186 ++++++++++++++
 .../state/heap/CopyOnWriteStateTableSnapshot.java  | 106 ++++++--
 .../runtime/state/heap/HeapKeyedStateBackend.java  |  44 +++-
 .../runtime/state/heap/NestedMapsStateTable.java   |  64 +++--
 .../flink/runtime/state/heap/StateTable.java       |   2 +-
 .../runtime/state/ttl/AbstractTtlDecorator.java    |  26 +-
 .../flink/runtime/state/ttl/AbstractTtlState.java  |   4 +-
 .../runtime/state/ttl/TtlAggregateFunction.java    |   4 +-
 .../runtime/state/ttl/TtlAggregatingState.java     |   4 +-
 .../flink/runtime/state/ttl/TtlFoldFunction.java   |   4 +-
 .../flink/runtime/state/ttl/TtlFoldingState.java   |   4 +-
 .../flink/runtime/state/ttl/TtlListState.java      |   4 +-
 .../flink/runtime/state/ttl/TtlMapState.java       |   4 +-
 .../flink/runtime/state/ttl/TtlReduceFunction.java |   4 +-
 .../flink/runtime/state/ttl/TtlReducingState.java  |   4 +-
 .../flink/runtime/state/ttl/TtlStateFactory.java   |  96 ++++----
 .../state/ttl/TtlStateSnapshotTransformer.java     | 121 ++++++++++
 .../apache/flink/runtime/state/ttl/TtlUtils.java   |  39 +++
 .../flink/runtime/state/ttl/TtlValueState.java     |   4 +-
 .../flink/runtime/state/ttl/TtlStateTestBase.java  |  63 +++--
 .../state/ttl/mock/MockKeyedStateBackend.java      |  94 ++++++--
 .../streaming/state/RocksDBKeyedStateBackend.java  | 162 ++++++++++---
 .../contrib/streaming/state/RocksDBListState.java  | 100 ++++++--
 33 files changed, 1264 insertions(+), 465 deletions(-)

diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
index 191eb6f..422d77f 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
@@ -96,7 +96,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
 
 	/** Name for queries against state created from this StateDescriptor. */
 	@Nonnull
-	private StateTtlConfiguration ttlConfig = StateTtlConfiguration.DISABLED;
+	private StateTtlConfig ttlConfig = StateTtlConfig.DISABLED;
 
 	/** The default value returned by the state when no other value is bound to a key. */
 	@Nullable
@@ -210,7 +210,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
 	 */
 	public void setQueryable(String queryableStateName) {
 		Preconditions.checkArgument(
-			ttlConfig.getTtlUpdateType() == StateTtlConfiguration.TtlUpdateType.Disabled,
+			ttlConfig.getUpdateType() == StateTtlConfig.UpdateType.Disabled,
 			"Queryable state is currently not supported with TTL");
 		if (this.queryableStateName == null) {
 			this.queryableStateName = Preconditions.checkNotNull(queryableStateName, "Registration name");
@@ -243,14 +243,14 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
 	 * Configures optional activation of state time-to-live (TTL).
 	 *
 	 * <p>State user value will expire, become unavailable and be cleaned up in storage
-	 * depending on configured {@link StateTtlConfiguration}.
+	 * depending on configured {@link StateTtlConfig}.
 	 *
 	 * @param ttlConfig configuration of state TTL
 	 */
-	public void enableTimeToLive(StateTtlConfiguration ttlConfig) {
+	public void enableTimeToLive(StateTtlConfig ttlConfig) {
 		Preconditions.checkNotNull(ttlConfig);
 		Preconditions.checkArgument(
-			ttlConfig.getTtlUpdateType() != StateTtlConfiguration.TtlUpdateType.Disabled &&
+			ttlConfig.getUpdateType() != StateTtlConfig.UpdateType.Disabled &&
 				queryableStateName == null,
 			"Queryable state is currently not supported with TTL");
 		this.ttlConfig = ttlConfig;
@@ -258,7 +258,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
 
 	@Nonnull
 	@Internal
-	public StateTtlConfiguration getTtlConfig() {
+	public StateTtlConfig getTtlConfig() {
 		return ttlConfig;
 	}
 
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java
new file mode 100644
index 0000000..f4ed929
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.EnumMap;
+
+import static org.apache.flink.api.common.state.StateTtlConfig.StateVisibility.NeverReturnExpired;
+import static org.apache.flink.api.common.state.StateTtlConfig.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.api.common.state.StateTtlConfig.UpdateType.OnCreateAndWrite;
+
+/**
+ * Configuration of state TTL logic.
+ *
+ * <p>Use {@link #newBuilder(Time)} to create instances.
+ */
+public class StateTtlConfig implements Serializable {
+
+	private static final long serialVersionUID = -7592693245044289793L;
+
+	/** Configuration that disables TTL: state does not expire. */
+	public static final StateTtlConfig DISABLED =
+		newBuilder(Time.milliseconds(Long.MAX_VALUE)).setUpdateType(UpdateType.Disabled).build();
+
+	/**
+	 * This option value configures when to update last access timestamp which prolongs state TTL.
+	 */
+	public enum UpdateType {
+		/** TTL is disabled. State does not expire. */
+		Disabled,
+		/** Last access timestamp is initialised when state is created and updated on every write operation. */
+		OnCreateAndWrite,
+		/** The same as <code>OnCreateAndWrite</code> but also updated on read. */
+		OnReadAndWrite
+	}
+
+	/**
+	 * This option configures whether expired user value can be returned or not.
+	 */
+	public enum StateVisibility {
+		/** Return expired user value if it is not cleaned up yet. */
+		ReturnExpiredIfNotCleanedUp,
+		/** Never return expired user value. */
+		NeverReturnExpired
+	}
+
+	/**
+	 * This option configures time scale to use for ttl.
+	 */
+	public enum TimeCharacteristic {
+		/** Processing time, see also <code>TimeCharacteristic.ProcessingTime</code>. */
+		ProcessingTime
+	}
+
+	private final UpdateType updateType;
+	private final StateVisibility stateVisibility;
+	private final TimeCharacteristic timeCharacteristic;
+	private final Time ttl;
+	private final CleanupStrategies cleanupStrategies;
+
+	private StateTtlConfig(
+		UpdateType updateType,
+		StateVisibility stateVisibility,
+		TimeCharacteristic timeCharacteristic,
+		Time ttl,
+		CleanupStrategies cleanupStrategies) {
+		this.updateType = Preconditions.checkNotNull(updateType);
+		this.stateVisibility = Preconditions.checkNotNull(stateVisibility);
+		this.timeCharacteristic = Preconditions.checkNotNull(timeCharacteristic);
+		this.ttl = Preconditions.checkNotNull(ttl);
+		// getCleanupStrategies() is declared @Nonnull, so enforce it here like the other fields.
+		this.cleanupStrategies = Preconditions.checkNotNull(cleanupStrategies);
+		Preconditions.checkArgument(ttl.toMilliseconds() > 0,
+			"TTL is expected to be positive");
+	}
+
+	/** Returns when the last access timestamp, which prolongs state TTL, is updated. */
+	@Nonnull
+	public UpdateType getUpdateType() {
+		return updateType;
+	}
+
+	/** Returns whether an expired user value can be returned if it is not cleaned up yet. */
+	@Nonnull
+	public StateVisibility getStateVisibility() {
+		return stateVisibility;
+	}
+
+	/** Returns the configured time-to-live; it is always positive. */
+	@Nonnull
+	public Time getTtl() {
+		return ttl;
+	}
+
+	/** Returns the time scale used for the TTL. */
+	@Nonnull
+	public TimeCharacteristic getTimeCharacteristic() {
+		return timeCharacteristic;
+	}
+
+	/** Returns {@code true} unless the update type is {@link UpdateType#Disabled}. */
+	public boolean isEnabled() {
+		return updateType != UpdateType.Disabled;
+	}
+
+	/** Returns the configured strategies for cleaning up expired state. */
+	@Nonnull
+	public CleanupStrategies getCleanupStrategies() {
+		return cleanupStrategies;
+	}
+
+	@Override
+	public String toString() {
+		return "StateTtlConfig{" +
+			"updateType=" + updateType +
+			", stateVisibility=" + stateVisibility +
+			", timeCharacteristic=" + timeCharacteristic +
+			", ttl=" + ttl +
+			", cleanupStrategies=" + cleanupStrategies +
+			'}';
+	}
+
+	/**
+	 * Creates a new {@link Builder} for a configuration with the given TTL.
+	 *
+	 * @param ttl the time-to-live of state values
+	 */
+	@Nonnull
+	public static Builder newBuilder(@Nonnull Time ttl) {
+		return new Builder(ttl);
+	}
+
+	/**
+	 * Builder for the {@link StateTtlConfig}.
+	 *
+	 * <p>Defaults: {@link UpdateType#OnCreateAndWrite}, {@link StateVisibility#NeverReturnExpired},
+	 * {@link TimeCharacteristic#ProcessingTime} and no additional cleanup strategies.
+	 */
+	public static class Builder {
+
+		private UpdateType updateType = OnCreateAndWrite;
+		private StateVisibility stateVisibility = NeverReturnExpired;
+		private TimeCharacteristic timeCharacteristic = ProcessingTime;
+		private Time ttl;
+		private CleanupStrategies cleanupStrategies = new CleanupStrategies();
+
+		public Builder(@Nonnull Time ttl) {
+			this.ttl = ttl;
+		}
+
+		/**
+		 * Sets the ttl update type.
+		 *
+		 * @param updateType The ttl update type configures when to update last access timestamp which prolongs state TTL.
+		 */
+		@Nonnull
+		public Builder setUpdateType(UpdateType updateType) {
+			this.updateType = updateType;
+			return this;
+		}
+
+		@Nonnull
+		public Builder updateTtlOnCreateAndWrite() {
+			return setUpdateType(UpdateType.OnCreateAndWrite);
+		}
+
+		@Nonnull
+		public Builder updateTtlOnReadAndWrite() {
+			return setUpdateType(UpdateType.OnReadAndWrite);
+		}
+
+		/**
+		 * Sets the state visibility.
+		 *
+		 * @param stateVisibility The state visibility configures whether expired user value can be returned or not.
+		 */
+		@Nonnull
+		public Builder setStateVisibility(@Nonnull StateVisibility stateVisibility) {
+			this.stateVisibility = stateVisibility;
+			return this;
+		}
+
+		@Nonnull
+		public Builder returnExpiredIfNotCleanedUp() {
+			return setStateVisibility(StateVisibility.ReturnExpiredIfNotCleanedUp);
+		}
+
+		@Nonnull
+		public Builder neverReturnExpired() {
+			return setStateVisibility(StateVisibility.NeverReturnExpired);
+		}
+
+		/**
+		 * Sets the time characteristic.
+		 *
+		 * @param timeCharacteristic The time characteristic configures time scale to use for ttl.
+		 */
+		@Nonnull
+		public Builder setTimeCharacteristic(@Nonnull TimeCharacteristic timeCharacteristic) {
+			this.timeCharacteristic = timeCharacteristic;
+			return this;
+		}
+
+		@Nonnull
+		public Builder useProcessingTime() {
+			return setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+		}
+
+		/** Cleanup expired state in full snapshot on checkpoint. */
+		@Nonnull
+		public Builder cleanupFullSnapshot() {
+			// Use the shared static instance: an anonymous CleanupStrategy created in this method
+			// would capture this non-serializable Builder and break serialization of the config.
+			cleanupStrategies.strategies.put(
+				CleanupStrategies.Strategies.FULL_STATE_SCAN_SNAPSHOT,
+				CleanupStrategies.EMPTY_STRATEGY);
+			return this;
+		}
+
+		/**
+		 * Sets the ttl time.
+		 * @param ttl The ttl time.
+		 */
+		@Nonnull
+		public Builder setTtl(@Nonnull Time ttl) {
+			this.ttl = ttl;
+			return this;
+		}
+
+		/**
+		 * Builds the {@link StateTtlConfig}.
+		 *
+		 * <p>The cleanup strategies are copied, so later calls on this builder
+		 * do not affect configurations that were already built.
+		 */
+		@Nonnull
+		public StateTtlConfig build() {
+			return new StateTtlConfig(
+				updateType,
+				stateVisibility,
+				timeCharacteristic,
+				ttl,
+				new CleanupStrategies(cleanupStrategies));
+		}
+	}
+
+	/**
+	 * TTL cleanup strategies.
+	 *
+	 * <p>This class configures when to clean up expired state with TTL.
+	 * By default, state is always cleaned up on explicit read access if found expired.
+	 * Currently, cleanup of expired state in full snapshots can be additionally activated.
+	 */
+	public static class CleanupStrategies implements Serializable {
+		private static final long serialVersionUID = -1617740467277313524L;
+
+		/** Fixed strategies ordinals in {@code strategies} config field. */
+		enum Strategies {
+			FULL_STATE_SCAN_SNAPSHOT
+		}
+
+		/** Base interface for cleanup strategies configurations. */
+		interface CleanupStrategy extends Serializable {
+
+		}
+
+		/**
+		 * Shared configuration for strategies without parameters. It is backed by a static nested
+		 * class so that, unlike an anonymous class, it holds no reference to an enclosing instance.
+		 */
+		static final CleanupStrategy EMPTY_STRATEGY = new EmptyCleanupStrategy();
+
+		final EnumMap<Strategies, CleanupStrategy> strategies = new EnumMap<>(Strategies.class);
+
+		/** Creates an empty set of cleanup strategies. */
+		public CleanupStrategies() {
+		}
+
+		/** Creates a copy of the given cleanup strategies. */
+		CleanupStrategies(@Nonnull CleanupStrategies other) {
+			strategies.putAll(other.strategies);
+		}
+
+		/** Returns whether expired state is cleaned up in full snapshots. */
+		public boolean inFullSnapshot() {
+			return strategies.containsKey(Strategies.FULL_STATE_SCAN_SNAPSHOT);
+		}
+
+		@Override
+		public String toString() {
+			return "CleanupStrategies{" +
+				"strategies=" + strategies.keySet() +
+				'}';
+		}
+
+		/** Parameterless strategy configuration, see {@link #EMPTY_STRATEGY}. */
+		private static class EmptyCleanupStrategy implements CleanupStrategy {
+			private static final long serialVersionUID = 1373998465131443873L;
+		}
+	}
+}
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java
deleted file mode 100644
index 55ec29c..0000000
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.common.state;
-
-import org.apache.flink.api.common.time.Time;
-import org.apache.flink.util.Preconditions;
-
-import java.io.Serializable;
-
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired;
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlTimeCharacteristic.ProcessingTime;
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite;
-
-/**
- * Configuration of state TTL logic.
- */
-public class StateTtlConfiguration implements Serializable {
-
-	private static final long serialVersionUID = -7592693245044289793L;
-
-	public static final StateTtlConfiguration DISABLED =
-		newBuilder(Time.milliseconds(Long.MAX_VALUE)).setTtlUpdateType(TtlUpdateType.Disabled).build();
-
-	/**
-	 * This option value configures when to update last access timestamp which prolongs state TTL.
-	 */
-	public enum TtlUpdateType {
-		/** TTL is disabled. State does not expire. */
-		Disabled,
-		/** Last access timestamp is initialised when state is created and updated on every write operation. */
-		OnCreateAndWrite,
-		/** The same as <code>OnCreateAndWrite</code> but also updated on read. */
-		OnReadAndWrite
-	}
-
-	/**
-	 * This option configures whether expired user value can be returned or not.
-	 */
-	public enum TtlStateVisibility {
-		/** Return expired user value if it is not cleaned up yet. */
-		ReturnExpiredIfNotCleanedUp,
-		/** Never return expired user value. */
-		NeverReturnExpired
-	}
-
-	/**
-	 * This option configures time scale to use for ttl.
-	 */
-	public enum TtlTimeCharacteristic {
-		/** Processing time, see also <code>TimeCharacteristic.ProcessingTime</code>. */
-		ProcessingTime
-	}
-
-	private final TtlUpdateType ttlUpdateType;
-	private final TtlStateVisibility stateVisibility;
-	private final TtlTimeCharacteristic timeCharacteristic;
-	private final Time ttl;
-
-	private StateTtlConfiguration(
-		TtlUpdateType ttlUpdateType,
-		TtlStateVisibility stateVisibility,
-		TtlTimeCharacteristic timeCharacteristic,
-		Time ttl) {
-		this.ttlUpdateType = Preconditions.checkNotNull(ttlUpdateType);
-		this.stateVisibility = Preconditions.checkNotNull(stateVisibility);
-		this.timeCharacteristic = Preconditions.checkNotNull(timeCharacteristic);
-		this.ttl = Preconditions.checkNotNull(ttl);
-		Preconditions.checkArgument(ttl.toMilliseconds() > 0,
-			"TTL is expected to be positive");
-	}
-
-	public TtlUpdateType getTtlUpdateType() {
-		return ttlUpdateType;
-	}
-
-	public TtlStateVisibility getStateVisibility() {
-		return stateVisibility;
-	}
-
-	public Time getTtl() {
-		return ttl;
-	}
-
-	public TtlTimeCharacteristic getTimeCharacteristic() {
-		return timeCharacteristic;
-	}
-
-	public boolean isEnabled() {
-		return ttlUpdateType != TtlUpdateType.Disabled;
-	}
-
-	@Override
-	public String toString() {
-		return "StateTtlConfiguration{" +
-			"ttlUpdateType=" + ttlUpdateType +
-			", stateVisibility=" + stateVisibility +
-			", timeCharacteristic=" + timeCharacteristic +
-			", ttl=" + ttl +
-			'}';
-	}
-
-	public static Builder newBuilder(Time ttl) {
-		return new Builder(ttl);
-	}
-
-	/**
-	 * Builder for the {@link StateTtlConfiguration}.
-	 */
-	public static class Builder {
-
-		private TtlUpdateType ttlUpdateType = OnCreateAndWrite;
-		private TtlStateVisibility stateVisibility = NeverReturnExpired;
-		private TtlTimeCharacteristic timeCharacteristic = ProcessingTime;
-		private Time ttl;
-
-		public Builder(Time ttl) {
-			this.ttl = ttl;
-		}
-
-		/**
-		 * Sets the ttl update type.
-		 *
-		 * @param ttlUpdateType The ttl update type configures when to update last access timestamp which prolongs state TTL.
-		 */
-		public Builder setTtlUpdateType(TtlUpdateType ttlUpdateType) {
-			this.ttlUpdateType = ttlUpdateType;
-			return this;
-		}
-
-		public Builder updateTtlOnCreateAndWrite() {
-			return setTtlUpdateType(TtlUpdateType.OnCreateAndWrite);
-		}
-
-		public Builder updateTtlOnReadAndWrite() {
-			return setTtlUpdateType(TtlUpdateType.OnReadAndWrite);
-		}
-
-		/**
-		 * Sets the state visibility.
-		 *
-		 * @param stateVisibility The state visibility configures whether expired user value can be returned or not.
-		 */
-		public Builder setStateVisibility(TtlStateVisibility stateVisibility) {
-			this.stateVisibility = stateVisibility;
-			return this;
-		}
-
-		public Builder returnExpiredIfNotCleanedUp() {
-			return setStateVisibility(TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
-		}
-
-		public Builder neverReturnExpired() {
-			return setStateVisibility(TtlStateVisibility.NeverReturnExpired);
-		}
-
-		/**
-		 * Sets the time characteristic.
-		 *
-		 * @param timeCharacteristic The time characteristic configures time scale to use for ttl.
-		 */
-		public Builder setTimeCharacteristic(TtlTimeCharacteristic timeCharacteristic) {
-			this.timeCharacteristic = timeCharacteristic;
-			return this;
-		}
-
-		public Builder useProcessingTime() {
-			return setTimeCharacteristic(TtlTimeCharacteristic.ProcessingTime);
-		}
-
-		/**
-		 * Sets the ttl time.
-		 * @param ttl The ttl time.
-		 */
-		public Builder setTtl(Time ttl) {
-			this.ttl = ttl;
-			return this;
-		}
-
-		public StateTtlConfiguration build() {
-			return new StateTtlConfiguration(
-				ttlUpdateType,
-				stateVisibility,
-				timeCharacteristic,
-				ttl
-			);
-		}
-
-	}
-}
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
index f4c9619..3b2e474 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.streaming.tests;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.utils.ParameterTool;
 import org.apache.flink.configuration.ConfigOption;
@@ -84,7 +84,7 @@ public class DataStreamStateTTLTestProgram {
 		long reportStatAfterUpdatesNum = pt.getLong(REPORT_STAT_AFTER_UPDATES_NUM.key(),
 			REPORT_STAT_AFTER_UPDATES_NUM.defaultValue());
 
-		StateTtlConfiguration ttlConfig = StateTtlConfiguration.newBuilder(ttl).build();
+		StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(ttl).build();
 
 		env
 			.addSource(new TtlStateUpdateSource(keySpace, sleepAfterElements, sleepTime))
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
index a99a45f..3cfb0e2 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
@@ -23,7 +23,7 @@ import org.apache.flink.api.common.state.KeyedStateStore;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
@@ -67,14 +67,14 @@ class TtlVerifyUpdateFunction
 	private static final Logger LOG = LoggerFactory.getLogger(TtlVerifyUpdateFunction.class);
 
 	@Nonnull
-	private final StateTtlConfiguration ttlConfig;
+	private final StateTtlConfig ttlConfig;
 	private final long ttl;
 	private final UpdateStat stat;
 
 	private transient Map<String, State> states;
 	private transient Map<String, ListState<ValueWithTs<?>>> prevUpdatesByVerifierId;
 
-	TtlVerifyUpdateFunction(@Nonnull StateTtlConfiguration ttlConfig, long reportStatAfterUpdatesNum) {
+	TtlVerifyUpdateFunction(@Nonnull StateTtlConfig ttlConfig, long reportStatAfterUpdatesNum) {
 		this.ttlConfig = ttlConfig;
 		this.ttl = ttlConfig.getTtl().toMilliseconds();
 		this.stat = new UpdateStat(reportStatAfterUpdatesNum);
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
index c56ff19..7b6def2 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
@@ -20,7 +20,7 @@ package org.apache.flink.streaming.tests.verify;
 
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.util.StringUtils;
@@ -52,7 +52,7 @@ abstract class AbstractTtlStateVerifier<D extends StateDescriptor<S, SV>, S exte
 	@SuppressWarnings("unchecked")
 	@Override
 	@Nonnull
-	public State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfiguration ttlConfig) {
+	public State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfig ttlConfig) {
 		stateDesc.enableTimeToLive(ttlConfig);
 		return createState(context);
 	}
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
index e1c2e07..ec5d8b0 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
@@ -19,7 +19,7 @@
 package org.apache.flink.streaming.tests.verify;
 
 import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 
@@ -45,7 +45,7 @@ public interface TtlStateVerifier<UV, GV> {
 	}
 
 	@Nonnull
-	State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfiguration ttlConfig);
+	State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfig ttlConfig);
 
 	@Nonnull
 	TypeSerializer<UV> getUpdateSerializer();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
index 27d411c..95f8369 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
@@ -135,8 +135,8 @@ public class KeyGroupPartitioner<T> {
 	public StateSnapshot.StateKeyGroupWriter partitionByKeyGroup() {
 		if (computedResult == null) {
 			reportAllElementKeyGroups();
-			buildHistogramByAccumulatingCounts();
-			executePartitioning();
+			int outputNumberOfElements = buildHistogramByAccumulatingCounts();
+			executePartitioning(outputNumberOfElements);
 		}
 		return computedResult;
 	}
@@ -167,22 +167,20 @@ public class KeyGroupPartitioner<T> {
 	/**
 	 * This method creates a histogram from the counts per key-group in {@link #counterHistogram}.
 	 */
-	private void buildHistogramByAccumulatingCounts() {
+	private int buildHistogramByAccumulatingCounts() {
 		int sum = 0;
 		for (int i = 0; i < counterHistogram.length; ++i) {
 			int currentSlotValue = counterHistogram[i];
 			counterHistogram[i] = sum;
 			sum += currentSlotValue;
 		}
-
-		// sanity check that the sum matches the expected number of elements.
-		Preconditions.checkState(sum == numberOfElements);
+		return sum;
 	}
 
-	private void executePartitioning() {
+	private void executePartitioning(int outputNumberOfElements) {
 
 		// We repartition the entries by their pre-computed key-groups, using the histogram values as write indexes
-		for (int inIdx = 0; inIdx < numberOfElements; ++inIdx) {
+		for (int inIdx = 0; inIdx < outputNumberOfElements; ++inIdx) {
 			int effectiveKgIdx = elementKeyGroups[inIdx];
 			int outIdx = counterHistogram[effectiveKgIdx]++;
 			partitioningDestination[outIdx] = partitioningSource[inIdx];
@@ -272,7 +270,7 @@ public class KeyGroupPartitioner<T> {
 	}
 
 	/**
-	 * General algorithm to read key-grouped state that was written from a {@link PartitioningResult}
+	 * General algorithm to read key-grouped state that was written from a {@link PartitioningResult}.
 	 *
 	 * @param <T> type of the elements to read.
 	 */
@@ -339,8 +337,6 @@ public class KeyGroupPartitioner<T> {
 	 */
 	@FunctionalInterface
 	public interface KeyGroupElementsConsumer<T> {
-
-
 		void consume(@Nonnull T element, @Nonnegative int keyGroupId) throws IOException;
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
index de35979..ca16662 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
@@ -21,22 +21,51 @@ package org.apache.flink.runtime.state;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 
+import javax.annotation.Nonnull;
+
 /** This factory produces concrete internal state objects. */
 public interface KeyedStateFactory {
+
+	/**
+	 * Creates and returns a new {@link InternalKvState} with no state snapshot transformation.
+	 *
+	 * <p>Delegates to the three-argument overload, passing
+	 * {@link StateSnapshotTransformFactory#noTransform()} as the snapshot transform factory.
+	 *
+	 * @param namespaceSerializer TypeSerializer for the state namespace.
+	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 *
+	 * @param <N> The type of the namespace.
+	 * @param <SV> The type of the stored state value.
+	 * @param <S> The type of the public API state.
+	 * @param <IS> The type of internal state.
+	 */
+	@Nonnull
+	default <N, SV, S extends State, IS extends S> IS createInternalState(
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull StateDescriptor<S, SV> stateDesc) throws Exception {
+		return createInternalState(namespaceSerializer, stateDesc, StateSnapshotTransformFactory.noTransform());
+	}
+
 	/**
 	 * Creates and returns a new {@link InternalKvState}.
 	 *
 	 * @param namespaceSerializer TypeSerializer for the state namespace.
 	 * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+	 * @param snapshotTransformFactory factory of state snapshot transformer.
 	 *
 	 * @param <N> The type of the namespace.
 	 * @param <SV> The type of the stored state value.
+	 * @param <SEV> The type of the stored state value or entry for collection types (list or map).
 	 * @param <S> The type of the public API state.
 	 * @param <IS> The type of internal state.
 	 */
-	<N, SV, S extends State, IS extends S> IS createInternalState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception;
+	@Nonnull
+	<N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull StateDescriptor<S, SV> stateDesc,
+		@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index b0248fc..789551d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -29,6 +29,7 @@ import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StateMigrationException;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.util.Collections;
 import java.util.HashMap;
@@ -50,17 +51,29 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 	private final TypeSerializer<N> namespaceSerializer;
 	@Nonnull
 	private final TypeSerializer<S> stateSerializer;
+	@Nullable
+	private final StateSnapshotTransformer<S> snapshotTransformer;
 
 	public RegisteredKeyValueStateBackendMetaInfo(
-			@Nonnull StateDescriptor.Type stateType,
-			@Nonnull String name,
-			@Nonnull TypeSerializer<N> namespaceSerializer,
-			@Nonnull TypeSerializer<S> stateSerializer) {
+		@Nonnull StateDescriptor.Type stateType,
+		@Nonnull String name,
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull TypeSerializer<S> stateSerializer) {
+		this(stateType, name, namespaceSerializer, stateSerializer, null);
+	}
+
+	public RegisteredKeyValueStateBackendMetaInfo(
+		@Nonnull StateDescriptor.Type stateType,
+		@Nonnull String name,
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull TypeSerializer<S> stateSerializer,
+		@Nullable StateSnapshotTransformer<S> snapshotTransformer) {
 
 		super(name);
 		this.stateType = stateType;
 		this.namespaceSerializer = namespaceSerializer;
 		this.stateSerializer = stateSerializer;
+		this.snapshotTransformer = snapshotTransformer;
 	}
 
 	@SuppressWarnings("unchecked")
@@ -71,7 +84,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 			(TypeSerializer<N>) Preconditions.checkNotNull(
 				snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.NAMESPACE_SERIALIZER)),
 			(TypeSerializer<S>) Preconditions.checkNotNull(
-				snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER)));
+				snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER)), null);
 		Preconditions.checkState(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE == snapshot.getBackendStateType());
 	}
 
@@ -90,6 +103,11 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 		return stateSerializer;
 	}
 
+	@Nullable
+	public StateSnapshotTransformer<S> getSnapshotTransformer() {
+		return snapshotTransformer;
+	}
+
 	@Override
 	public boolean equals(Object o) {
 		if (this == o) {
@@ -142,7 +160,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 	public static <N, S> RegisteredKeyValueStateBackendMetaInfo<N, S> resolveKvStateCompatibility(
 		StateMetaInfoSnapshot restoredStateMetaInfoSnapshot,
 		TypeSerializer<N> newNamespaceSerializer,
-		StateDescriptor<?, S> newStateDescriptor) throws StateMigrationException {
+		StateDescriptor<?, S> newStateDescriptor,
+		@Nullable StateSnapshotTransformer<S> snapshotTransformer) throws StateMigrationException {
 
 		Preconditions.checkState(restoredStateMetaInfoSnapshot.getBackendStateType()
 				== StateMetaInfoSnapshot.BackendStateType.KEY_VALUE,
@@ -196,7 +215,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 				newStateDescriptor.getType(),
 				newStateDescriptor.getName(),
 				newNamespaceSerializer,
-				newStateSerializer);
+				newStateSerializer,
+				snapshotTransformer);
 		}
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
new file mode 100644
index 0000000..cd2c7bf
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
+
+/**
+ * Transformer of state values which are included or skipped in the snapshot.
+ *
+ * <p>This transformer can be applied to state values
+ * to decide which entries should be included into the snapshot.
+ * The included entries can be optionally modified before.
+ *
+ * <p>Unless specified differently, the transformer should be applied per entry
+ * for collection types of state, like list or map.
+ *
+ * @param <T> type of state
+ */
+@FunctionalInterface
+public interface StateSnapshotTransformer<T> {
+	/**
+	 * Transform or filter out state values which are included or skipped in the snapshot.
+	 *
+	 * @param value non-serialized form of value
+	 * @return value to snapshot or null which means the entry is not included
+	 */
+	@Nullable
+	T filterOrTransform(@Nullable T value);
+
+	/** Collection state specific transformer which says how to transform entries of the collection. */
+	interface CollectionStateSnapshotTransformer<T> extends StateSnapshotTransformer<T> {
+		enum TransformStrategy {
+			/** Transform all entries. */
+			TRANSFORM_ALL,
+
+			/**
+			 * Skip first null entries.
+			 *
+			 * <p>While traversing collection entries, as optimisation, stops transforming
+			 * if encounters first non-null included entry and returns it plus the rest untouched.
+			 */
+			STOP_ON_FIRST_INCLUDED
+		}
+
+		default TransformStrategy getFilterStrategy() {
+			return TransformStrategy.TRANSFORM_ALL;
+		}
+	}
+
+	/**
+	 * General implementation of list state transformer.
+	 *
+	 * <p>This transformer wraps a transformer per-entry
+	 * and transforms the whole list state.
+	 * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
+	 * it respects its {@link TransformStrategy}.
+	 */
+	class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
+		private final StateSnapshotTransformer<T> entryValueTransformer;
+		private final TransformStrategy transformStrategy;
+
+		public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
+			this.entryValueTransformer = entryValueTransformer;
+			this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
+				((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
+				TransformStrategy.TRANSFORM_ALL;
+		}
+
+		@Override
+		@Nullable
+		public List<T> filterOrTransform(@Nullable List<T> list) {
+			if (list == null) {
+				return null;
+			}
+			List<T> transformedList = new ArrayList<>();
+			boolean anyChange = false;
+			for (int i = 0; i < list.size(); i++) {
+				T entry = list.get(i);
+				T transformedEntry = entryValueTransformer.filterOrTransform(entry);
+				if (transformedEntry != null) {
+					if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
+						transformedList = list.subList(i, list.size());
+						anyChange = i > 0;
+						break;
+					} else {
+						transformedList.add(transformedEntry);
+					}
+				}
+				anyChange |= transformedEntry == null || !Objects.equals(entry, transformedEntry);
+			}
+			transformedList = anyChange ? transformedList : list;
+			return transformedList.isEmpty() ? null : transformedList;
+		}
+	}
+
+	/**
+	 * General implementation of map state transformer.
+	 *
+	 * <p>This transformer wraps a transformer per-entry
+	 * and transforms the whole map state.
+	 */
+	class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
+		private final StateSnapshotTransformer<V> entryValueTransformer;
+
+		public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
+			this.entryValueTransformer = entryValueTransformer;
+		}
+
+		@Nullable
+		@Override
+		public Map<K, V> filterOrTransform(@Nullable Map<K, V> map) {
+			if (map == null) {
+				return null;
+			}
+			Map<K, V> transformedMap = new HashMap<>();
+			boolean anyChange = false;
+			for (Map.Entry<K, V> entry : map.entrySet()) {
+				V transformedValue = entryValueTransformer.filterOrTransform(entry.getValue());
+				if (transformedValue != null) {
+					transformedMap.put(entry.getKey(), transformedValue);
+				}
+				anyChange |= transformedValue == null || !Objects.equals(entry.getValue(), transformedValue);
+			}
+			return anyChange ? (transformedMap.isEmpty() ? null : transformedMap) : map;
+		}
+	}
+
+	/**
+	 * This factory creates state transformers depending on the form of values to transform.
+	 *
+	 * <p>If there is no transforming needed, the factory methods return {@code Optional.empty()}.
+	 */
+	interface StateSnapshotTransformFactory<T> {
+		StateSnapshotTransformFactory<?> NO_TRANSFORM = createNoTransform();
+
+		@SuppressWarnings("unchecked")
+		static <T> StateSnapshotTransformFactory<T> noTransform() {
+			return (StateSnapshotTransformFactory<T>) NO_TRANSFORM;
+		}
+
+		static <T> StateSnapshotTransformFactory<T> createNoTransform() {
+			return new StateSnapshotTransformFactory<T>() {
+				@Override
+				public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+					return Optional.empty();
+				}
+
+				@Override
+				public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+					return Optional.empty();
+				}
+			};
+		}
+
+		Optional<StateSnapshotTransformer<T>> createForDeserializedState();
+
+		Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index f3f21dd..11a23bf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -22,8 +22,10 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupPartitioner.ElementWriterFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 
 import javax.annotation.Nonnegative;
@@ -34,8 +36,8 @@ import javax.annotation.Nullable;
  * This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. Besides
  * holding the {@link CopyOnWriteStateTable}s internal entries at the time of the snapshot, this class is also responsible for
  * preparing and writing the state in the process of checkpointing.
- * <p>
- * IMPORTANT: Please notice that snapshot integrity of entries in this class rely on proper copy-on-write semantics
+ *
+ * <p>IMPORTANT: Please notice that snapshot integrity of entries in this class rely on proper copy-on-write semantics
  * through the {@link CopyOnWriteStateTable} that created the snapshot object, but all objects in this snapshot must be considered
  * as READ-ONLY!. The reason is that the objects held by this class may or may not be deep copies of original objects
  * that may still used in the {@link CopyOnWriteStateTable}. This depends for each entry on whether or not it was subject to
@@ -105,7 +107,6 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 		this.snapshotVersion = owningStateTable.getStateTableVersion();
 		this.numberOfEntriesInSnapshotData = owningStateTable.size();
 
-
 		// We create duplicates of the serializers for the async snapshot, because TypeSerializer
 		// might be stateful and shared with the event processing thread.
 		this.localKeySerializer = owningStateTable.keyContext.getKeySerializer().duplicate();
@@ -128,35 +129,41 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	 * into key-groups. Then, the histogram is accumulated to obtain the boundaries of each key-group in an array.
 	 * Last, we use the accumulated counts as write position pointers for the key-group's bins when reordering the
 	 * entries by key-group. This operation is lazily performed before the first writing of a key-group.
-	 * <p>
-	 * As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
+	 *
+	 * <p>As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
 	 * cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint.
 	 */
 	@Nonnull
 	@SuppressWarnings("unchecked")
 	@Override
 	public StateKeyGroupWriter getKeyGroupWriter() {
-
 		if (partitionedStateTableSnapshot == null) {
-
 			final InternalKeyContext<K> keyContext = owningStateTable.keyContext;
-			final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
 			final int numberOfKeyGroups = keyContext.getNumberOfKeyGroups();
-
-			final StateTableKeyGroupPartitioner<K, N, S> keyGroupPartitioner = new StateTableKeyGroupPartitioner<>(
-				snapshotData,
-				numberOfEntriesInSnapshotData,
-				keyGroupRange,
-				numberOfKeyGroups,
+			final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
+			ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction =
 				(element, dov) -> {
 					localNamespaceSerializer.serialize(element.namespace, dov);
 					localKeySerializer.serialize(element.key, dov);
 					localStateSerializer.serialize(element.state, dov);
-				});
-
-			partitionedStateTableSnapshot = keyGroupPartitioner.partitionByKeyGroup();
+				};
+			StateSnapshotTransformer<S> stateSnapshotTransformer = owningStateTable.metaInfo.getSnapshotTransformer();
+			StateTableKeyGroupPartitioner<K, N, S> stateTableKeyGroupPartitioner = stateSnapshotTransformer != null ?
+				new TransformingStateTableKeyGroupPartitioner<>(
+					snapshotData,
+					numberOfEntriesInSnapshotData,
+					keyGroupRange,
+					numberOfKeyGroups,
+					elementWriterFunction,
+					stateSnapshotTransformer) :
+				new StateTableKeyGroupPartitioner<>(
+					snapshotData,
+					numberOfEntriesInSnapshotData,
+					keyGroupRange,
+					numberOfKeyGroups,
+					elementWriterFunction);
+			partitionedStateTableSnapshot = stateTableKeyGroupPartitioner.partitionByKeyGroup();
 		}
-
 		return partitionedStateTableSnapshot;
 	}
 
@@ -188,7 +195,7 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 	 * @param <S> type of state value.
 	 */
 	@VisibleForTesting
-	protected static final class StateTableKeyGroupPartitioner<K, N, S>
+	protected static class StateTableKeyGroupPartitioner<K, N, S>
 		extends KeyGroupPartitioner<CopyOnWriteStateTable.StateTableEntry<K, N, S>> {
 
 		@SuppressWarnings("unchecked")
@@ -217,12 +224,67 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
 			int flattenIndex = 0;
 			for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : partitioningDestination) {
 				while (null != entry) {
-					final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
-					reportKeyGroupOfElementAtIndex(flattenIndex, keyGroup);
-					partitioningSource[flattenIndex++] = entry;
+					flattenIndex = tryAddToSource(flattenIndex, entry);
 					entry = entry.next;
 				}
 			}
 		}
+
+		/** Tries to append next entry to {@code partitioningSource} array snapshot and returns next index.*/
+		int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+			final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
+			reportKeyGroupOfElementAtIndex(currentIndex, keyGroup);
+			partitioningSource[currentIndex] = entry;
+			return currentIndex + 1;
+		}
+	}
+
+	/**
+	 * Extended state snapshot transforming {@link StateTableKeyGroupPartitioner}.
+	 *
+	 * <p>This partitioner can additionally transform state before including or not into the snapshot.
+	 */
+	protected static final class TransformingStateTableKeyGroupPartitioner<K, N, S>
+		extends StateTableKeyGroupPartitioner<K, N, S> {
+		private final StateSnapshotTransformer<S> stateSnapshotTransformer;
+
+		TransformingStateTableKeyGroupPartitioner(
+			@Nonnull CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData,
+			int stateTableSize,
+			@Nonnull KeyGroupRange keyGroupRange,
+			int totalKeyGroups,
+			@Nonnull ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction,
+			@Nonnull StateSnapshotTransformer<S> stateSnapshotTransformer) {
+			super(
+				snapshotData,
+				stateTableSize,
+				keyGroupRange,
+				totalKeyGroups,
+				elementWriterFunction);
+			this.stateSnapshotTransformer = stateSnapshotTransformer;
+		}
+
+		@Override
+		int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+			CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = filterEntry(entry);
+			if (filteredEntry != null) {
+				return super.tryAddToSource(currentIndex, filteredEntry); // unqualified call would recurse into this override forever
+			}
+			return currentIndex; // entry dropped from snapshot; index not advanced
+		}
+
+		private CopyOnWriteStateTable.StateTableEntry<K, N, S> filterEntry(
+			CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+			S transformedValue = stateSnapshotTransformer.filterOrTransform(entry.state);
+			if (transformedValue != null) {
+				CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = entry;
+				if (transformedValue != entry.state) {
+					filteredEntry = new CopyOnWriteStateTable.StateTableEntry<>(entry, entry.entryVersion);
+					filteredEntry.state = transformedValue;
+				}
+				return filteredEntry;
+			}
+			return null;
+		}
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 34c9698..6d2bfef 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -62,8 +62,10 @@ import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
 import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -89,6 +91,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -247,7 +250,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	private <N, V> StateTable<K, N, V> tryRegisterStateTable(
-			TypeSerializer<N> namespaceSerializer, StateDescriptor<?, V> stateDesc) throws StateMigrationException {
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<?, V> stateDesc,
+			StateSnapshotTransformer<V> snapshotTransformer) throws StateMigrationException {
 
 		@SuppressWarnings("unchecked")
 		StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
@@ -267,7 +272,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			newMetaInfo = RegisteredKeyValueStateBackendMetaInfo.resolveKvStateCompatibility(
 				restoredMetaInfoSnapshot,
 				namespaceSerializer,
-				stateDesc);
+				stateDesc,
+				snapshotTransformer);
 
 			stateTable.setMetaInfo(newMetaInfo);
 		} else {
@@ -275,7 +281,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getType(),
 				stateDesc.getName(),
 				namespaceSerializer,
-				stateDesc.getSerializer());
+				stateDesc.getSerializer(),
+				snapshotTransformer);
 
 			stateTable = snapshotStrategy.newStateTable(newMetaInfo);
 			registeredKVStates.put(stateDesc.getName(), stateTable);
@@ -301,19 +308,42 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	public <N, SV, S extends State, IS extends S> IS createInternalState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	@Nonnull
+	public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull StateDescriptor<S, SV> stateDesc,
+		@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
 		StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
 		if (stateFactory == null) {
 			String message = String.format("State %s is not supported by %s",
 				stateDesc.getClass(), this.getClass());
 			throw new FlinkRuntimeException(message);
 		}
-		StateTable<K, N, SV> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc);
+		StateTable<K, N, SV> stateTable = tryRegisterStateTable(
+			namespaceSerializer, stateDesc, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
 		return stateFactory.createState(stateDesc, stateTable, keySerializer);
 	}
 
+	@SuppressWarnings("unchecked")
+	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+		Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
+		if (original.isPresent()) {
+			if (stateDesc instanceof ListStateDescriptor) {
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+					.ListStateSnapshotTransformer<>(original.get());
+			} else if (stateDesc instanceof MapStateDescriptor) {
+				return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+					.MapStateSnapshotTransformer<>(original.get());
+			} else {
+				return (StateSnapshotTransformer<SV>) original.get();
+			}
+		} else {
+			return null;
+		}
+	}
+
 	@Override
 	@SuppressWarnings("unchecked")
 	public  RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index efed1cc..f982370 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -15,6 +15,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.flink.runtime.state.heap;
 
 import org.apache.flink.annotation.Internal;
@@ -24,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.StateTransformationFunction;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 import org.apache.flink.util.Preconditions;
@@ -41,8 +43,8 @@ import java.util.stream.Stream;
 /**
  * This implementation of {@link StateTable} uses nested {@link HashMap} objects. It is also maintaining a partitioning
  * by key-group.
- * <p>
- * In contrast to {@link CopyOnWriteStateTable}, this implementation does not support asynchronous snapshots. However,
+ *
+ * <p>In contrast to {@link CopyOnWriteStateTable}, this implementation does not support asynchronous snapshots. However,
  * it might have a better memory footprint for some use-cases, e.g. it is naturally de-duplicating namespace objects.
  *
  * @param <K> type of key.
@@ -59,7 +61,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	private final Map<N, Map<K, S>>[] state;
 
 	/**
-	 * The offset to the contiguous key groups
+	 * The offset to the contiguous key groups.
 	 */
 	private final int keyGroupOffset;
 
@@ -317,7 +319,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	@Nonnull
 	@Override
 	public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
-		return new NestedMapsStateTableSnapshot<>(this);
+		return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
 	}
 
 	/**
@@ -330,9 +332,17 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 	static class NestedMapsStateTableSnapshot<K, N, S>
 			extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
 			implements StateSnapshot.StateKeyGroupWriter {
+		private final TypeSerializer<K> keySerializer;
+		private final TypeSerializer<N> namespaceSerializer;
+		private final TypeSerializer<S> stateSerializer;
+		private final StateSnapshotTransformer<S> snapshotFilter;
 
-		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable) {
+		NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformer<S> snapshotFilter) {
 			super(owningTable);
+			this.snapshotFilter = snapshotFilter;
+			this.keySerializer = owningStateTable.keyContext.getKeySerializer();
+			this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
+			this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
 		}
 
 		@Nonnull
@@ -350,8 +360,8 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		/**
 		 * Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
 		 * {@link CopyOnWriteStateTable}.
-		 * <p>
-		 * {@link NestedMapsStateTable} could naturally support a kind of
+		 *
+		 * <p>{@link NestedMapsStateTable} could naturally support a kind of
 		 * prefix-compressed format (grouping by namespace, writing the namespace only once per group instead for each
 		 * mapping). We might implement support for different formats later (tailored towards different state table
 		 * implementations).
@@ -360,23 +370,45 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
 		public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
 			final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
 			if (null != keyGroupMap) {
-				TypeSerializer<K> keySerializer = owningStateTable.keyContext.getKeySerializer();
-				TypeSerializer<N> namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
-				TypeSerializer<S> stateSerializer = owningStateTable.metaInfo.getStateSerializer();
-				dov.writeInt(countMappingsInKeyGroup(keyGroupMap));
-				for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
+				Map<N, Map<K, S>> filteredMappings = filterMappingsInKeyGroupIfNeeded(keyGroupMap);
+				dov.writeInt(countMappingsInKeyGroup(filteredMappings));
+				for (Map.Entry<N, Map<K, S>> namespaceEntry : filteredMappings.entrySet()) {
 					final N namespace = namespaceEntry.getKey();
 					final Map<K, S> namespaceMap = namespaceEntry.getValue();
-
 					for (Map.Entry<K, S> keyEntry : namespaceMap.entrySet()) {
-						namespaceSerializer.serialize(namespace, dov);
-						keySerializer.serialize(keyEntry.getKey(), dov);
-						stateSerializer.serialize(keyEntry.getValue(), dov);
+						writeElement(namespace, keyEntry, dov);
 					}
 				}
 			} else {
 				dov.writeInt(0);
 			}
 		}
+
+		private void writeElement(N namespace, Map.Entry<K, S> keyEntry, DataOutputView dov) throws IOException {
+			namespaceSerializer.serialize(namespace, dov);
+			keySerializer.serialize(keyEntry.getKey(), dov);
+			stateSerializer.serialize(keyEntry.getValue(), dov);
+		}
+
+		private Map<N, Map<K, S>> filterMappingsInKeyGroupIfNeeded(final Map<N, Map<K, S>> keyGroupMap) {
+			return snapshotFilter == null ?
+				keyGroupMap : filterMappingsInKeyGroup(keyGroupMap);
+		}
+
+		private Map<N, Map<K, S>> filterMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
+			Map<N, Map<K, S>> filtered = new HashMap<>();
+			for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
+				N namespace = namespaceEntry.getKey();
+				Map<K, S> filteredNamespaceMap = filtered.computeIfAbsent(namespace, n -> new HashMap<>());
+				for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
+					K key = keyEntry.getKey();
+					S transformedvalue = snapshotFilter.filterOrTransform(keyEntry.getValue());
+					if (transformedvalue != null) {
+						filteredNamespaceMap.put(key, transformedvalue);
+					}
+				}
+			}
+			return filtered;
+		}
 	}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index 58a2f97..30f96f3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -46,7 +46,7 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
 	protected final InternalKeyContext<K> keyContext;
 
 	/**
-	 * Combined meta information such as name and serializers for this state
+	 * Combined meta information such as name and serializers for this state.
 	 */
 	protected RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo;
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
index 1b72c54..268f84a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
@@ -18,14 +18,12 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.util.function.ThrowingConsumer;
 import org.apache.flink.util.function.ThrowingRunnable;
 
-import javax.annotation.Nonnull;
-
 /**
  * Base class for TTL logic wrappers.
  *
@@ -35,7 +33,7 @@ abstract class AbstractTtlDecorator<T> {
 	/** Wrapped original state handler. */
 	final T original;
 
-	final StateTtlConfiguration config;
+	final StateTtlConfig config;
 
 	final TtlTimeProvider timeProvider;
 
@@ -50,7 +48,7 @@ abstract class AbstractTtlDecorator<T> {
 
 	AbstractTtlDecorator(
 		T original,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider) {
 		Preconditions.checkNotNull(original);
 		Preconditions.checkNotNull(config);
@@ -58,8 +56,8 @@ abstract class AbstractTtlDecorator<T> {
 		this.original = original;
 		this.config = config;
 		this.timeProvider = timeProvider;
-		this.updateTsOnRead = config.getTtlUpdateType() == StateTtlConfiguration.TtlUpdateType.OnReadAndWrite;
-		this.returnExpired = config.getStateVisibility() == StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp;
+		this.updateTsOnRead = config.getUpdateType() == StateTtlConfig.UpdateType.OnReadAndWrite;
+		this.returnExpired = config.getStateVisibility() == StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp;
 		this.ttl = config.getTtl().toMilliseconds();
 	}
 
@@ -68,21 +66,13 @@ abstract class AbstractTtlDecorator<T> {
 	}
 
+	/** Returns whether {@code ttlValue} is not {@code null} and has outlived the configured TTL. */
 	<V> boolean expired(TtlValue<V> ttlValue) {
-		return ttlValue != null && getExpirationTimestamp(ttlValue) <= timeProvider.currentTimestamp();
-	}
-
-	private long getExpirationTimestamp(@Nonnull TtlValue<?> ttlValue) {
-		long ts = ttlValue.getLastAccessTimestamp();
-		long ttlWithoutOverflow = ts > 0 ? Math.min(Long.MAX_VALUE - ts, ttl) : ttl;
-		return ts + ttlWithoutOverflow;
+		return TtlUtils.expired(ttlValue, ttl, timeProvider);
 	}
 
+	/** Wraps {@code value} with the current timestamp; returns {@code null} for a {@code null} value. */
 	<V> TtlValue<V> wrapWithTs(V value) {
-		return wrapWithTs(value, timeProvider.currentTimestamp());
-	}
-
-	static <V> TtlValue<V> wrapWithTs(V value, long ts) {
-		return value == null ? null : new TtlValue<>(value, ts);
+		return TtlUtils.wrapWithTs(value, timeProvider.currentTimestamp());
 	}
 
 	<V> TtlValue<V> rewrapWithNewTs(TtlValue<V> ttlValue) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
index 7903796..5d1af8a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalKvState;
 import org.apache.flink.util.FlinkRuntimeException;
@@ -39,7 +39,8 @@ abstract class AbstractTtlState<K, N, SV, TTLSV, S extends InternalKvState<K, N,
 	implements InternalKvState<K, N, SV> {
 	private final TypeSerializer<SV> valueSerializer;
 
-	AbstractTtlState(S original, StateTtlConfiguration config, TtlTimeProvider timeProvider, TypeSerializer<SV> valueSerializer) {
+	/** Wraps the {@code original} state with TTL logic configured by {@code config}. */
+	AbstractTtlState(S original, StateTtlConfig config, TtlTimeProvider timeProvider, TypeSerializer<SV> valueSerializer) {
 		super(original, config, timeProvider);
 		this.valueSerializer = valueSerializer;
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
index 5448ba1..07f538a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.state.ttl;
 
 import org.apache.flink.api.common.functions.AggregateFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.ThrowingConsumer;
@@ -38,7 +38,8 @@ class TtlAggregateFunction<IN, ACC, OUT>
 	ThrowingRunnable<Exception> stateClear;
 	ThrowingConsumer<TtlValue<ACC>, Exception> updater;
 
-	TtlAggregateFunction(AggregateFunction<IN, ACC, OUT> aggFunction, StateTtlConfiguration config, TtlTimeProvider timeProvider) {
+	/** Wraps {@code aggFunction} with TTL logic configured by {@code config}. */
+	TtlAggregateFunction(AggregateFunction<IN, ACC, OUT> aggFunction, StateTtlConfig config, TtlTimeProvider timeProvider) {
 		super(aggFunction, config, timeProvider);
 	}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
index a90698e..94f489d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalAggregatingState;
 
@@ -40,7 +40,8 @@ class TtlAggregatingState<K, N, IN, ACC, OUT>
 
+	/** Wraps {@code originalState}, which stores accumulators as {@link TtlValue}s, with TTL logic. */
 	TtlAggregatingState(
 		InternalAggregatingState<K, N, IN, TtlValue<ACC>, OUT> originalState,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<ACC> valueSerializer,
 		TtlAggregateFunction<IN, ACC, OUT> aggregateFunction) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
index c7305bd..a8f4e6a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.state.ttl;
 
 import org.apache.flink.api.common.functions.FoldFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 
 /**
  * This class wraps folding function with TTL logic.
@@ -36,7 +36,8 @@ class TtlFoldFunction<T, ACC>
 	private final ACC defaultAccumulator;
 
+	/** Wraps the {@code original} fold function with TTL logic configured by {@code config}. */
 	TtlFoldFunction(
-		FoldFunction<T, ACC> original, StateTtlConfiguration config, TtlTimeProvider timeProvider, ACC defaultAccumulator) {
+		FoldFunction<T, ACC> original, StateTtlConfig config, TtlTimeProvider timeProvider, ACC defaultAccumulator) {
 		super(original, config, timeProvider);
 		this.defaultAccumulator = defaultAccumulator;
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
index c3a75e4..4c64ae3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.state.ttl;
 
 import org.apache.flink.api.common.state.AggregatingState;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalFoldingState;
 
@@ -37,7 +37,8 @@ class TtlFoldingState<K, N, T, ACC>
 	implements InternalFoldingState<K, N, T, ACC> {
+	/** Wraps {@code originalState}, which stores accumulators as {@link TtlValue}s, with TTL logic. */
 	TtlFoldingState(
 		InternalFoldingState<K, N, T, TtlValue<ACC>> originalState,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<ACC> valueSerializer) {
 		super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
index 77e92f6..cb64df7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.util.Preconditions;
@@ -43,7 +43,8 @@ class TtlListState<K, N, T> extends
 	implements InternalListState<K, N, T> {
+	/** Wraps {@code originalState}, which stores list elements as {@link TtlValue}s, with TTL logic. */
 	TtlListState(
 		InternalListState<K, N, TtlValue<T>> originalState,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<List<T>> valueSerializer) {
 		super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
index 21145e5..f6f81ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalMapState;
 import org.apache.flink.util.FlinkRuntimeException;
@@ -46,7 +46,8 @@ class TtlMapState<K, N, UK, UV>
 	implements InternalMapState<K, N, UK, UV> {
+	/** Wraps the {@code original} state, which stores map values as {@link TtlValue}s, with TTL logic. */
 	TtlMapState(
 		InternalMapState<K, N, UK, TtlValue<UV>> original,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<Map<UK, UV>> valueSerializer) {
 		super(original, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
index fa7c8bf..823c55d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
@@ -19,7 +19,7 @@
 package org.apache.flink.runtime.state.ttl;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 
 /**
  * This class wraps reducing function with TTL logic.
@@ -32,7 +32,8 @@ class TtlReduceFunction<T>
 
+	/** Wraps {@code originalReduceFunction} with TTL logic configured by {@code config}. */
 	TtlReduceFunction(
 		ReduceFunction<T> originalReduceFunction,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider) {
 		super(originalReduceFunction, config, timeProvider);
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
index c0aa465..0710808 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalReducingState;
 
@@ -36,7 +36,8 @@ class TtlReducingState<K, N, T>
 	implements InternalReducingState<K, N, T> {
+	/** Wraps {@code originalState}, which stores values as {@link TtlValue}s, with TTL logic. */
 	TtlReducingState(
 		InternalReducingState<K, N, TtlValue<T>> originalState,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<T> valueSerializer) {
 		super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
index e12ba90..303285a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
@@ -25,15 +25,17 @@ import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.state.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.typeutils.CompositeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.state.KeyedStateFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
 
 import javax.annotation.Nonnull;
 
@@ -44,7 +46,7 @@ import java.util.stream.Stream;
 /**
  * This state factory wraps state objects, produced by backends, with TTL logic.
  */
-public class TtlStateFactory {
+public class TtlStateFactory<N, SV, S extends State, IS extends S> {
 	public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled(
 		TypeSerializer<N> namespaceSerializer,
 		StateDescriptor<S, SV> stateDesc,
@@ -55,103 +57,103 @@ public class TtlStateFactory {
 		Preconditions.checkNotNull(originalStateFactory);
 		Preconditions.checkNotNull(timeProvider);
 		return  stateDesc.getTtlConfig().isEnabled() ?
-			new TtlStateFactory(originalStateFactory, stateDesc.getTtlConfig(), timeProvider)
-				.createState(namespaceSerializer, stateDesc) :
+			new TtlStateFactory<N, SV, S, IS>(
+				namespaceSerializer, stateDesc, originalStateFactory, timeProvider)
+				.createState() :
 			originalStateFactory.createInternalState(namespaceSerializer, stateDesc);
 	}
 
-	private final Map<Class<? extends StateDescriptor>, KeyedStateFactory> stateFactories;
+	private final Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> stateFactories;
 
+	private final TypeSerializer<N> namespaceSerializer;
+	private final StateDescriptor<S, SV> stateDesc;
 	private final KeyedStateFactory originalStateFactory;
-	private final StateTtlConfiguration ttlConfig;
+	private final StateTtlConfig ttlConfig;
 	private final TtlTimeProvider timeProvider;
+	private final long ttl;
 
-	private TtlStateFactory(KeyedStateFactory originalStateFactory, StateTtlConfiguration ttlConfig, TtlTimeProvider timeProvider) {
+	private TtlStateFactory(
+		TypeSerializer<N> namespaceSerializer,
+		StateDescriptor<S, SV> stateDesc,
+		KeyedStateFactory originalStateFactory,
+		TtlTimeProvider timeProvider) {
+		this.namespaceSerializer = namespaceSerializer;
+		this.stateDesc = stateDesc;
 		this.originalStateFactory = originalStateFactory;
-		this.ttlConfig = ttlConfig;
+		this.ttlConfig = stateDesc.getTtlConfig();
 		this.timeProvider = timeProvider;
+		this.ttl = ttlConfig.getTtl().toMilliseconds();
 		this.stateFactories = createStateFactories();
 	}
 
 	@SuppressWarnings("deprecation")
-	private Map<Class<? extends StateDescriptor>, KeyedStateFactory> createStateFactories() {
+	private Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> createStateFactories() {
 		return Stream.of(
-			Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) this::createValueState),
-			Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) this::createListState),
-			Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) this::createMapState),
-			Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) this::createReducingState),
-			Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) this::createAggregatingState),
-			Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) this::createFoldingState)
+			Tuple2.of(ValueStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createValueState),
+			Tuple2.of(ListStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createListState),
+			Tuple2.of(MapStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createMapState),
+			Tuple2.of(ReducingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createReducingState),
+			Tuple2.of(AggregatingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createAggregatingState),
+			Tuple2.of(FoldingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createFoldingState)
 		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 	}
 
-	private <N, SV, S extends State, IS extends S> IS createState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
-		KeyedStateFactory stateFactory = stateFactories.get(stateDesc.getClass());
+	private IS createState() throws Exception {
+		SupplierWithException<IS, Exception> stateFactory = stateFactories.get(stateDesc.getClass());
 		if (stateFactory == null) {
 			String message = String.format("State %s is not supported by %s",
 				stateDesc.getClass(), TtlStateFactory.class);
 			throw new FlinkRuntimeException(message);
 		}
-		return stateFactory.createInternalState(namespaceSerializer, stateDesc);
+		return stateFactory.get();
 	}
 
 	@SuppressWarnings("unchecked")
-	private <N, SV, S extends State, IS extends S> IS createValueState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private IS createValueState() throws Exception {
 		ValueStateDescriptor<TtlValue<SV>> ttlDescriptor = new ValueStateDescriptor<>(
 			stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()));
 		return (IS) new TtlValueState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, stateDesc.getSerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <T, N, SV, S extends State, IS extends S> IS createListState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private <T> IS createListState() throws Exception {
 		ListStateDescriptor<T> listStateDesc = (ListStateDescriptor<T>) stateDesc;
 		ListStateDescriptor<TtlValue<T>> ttlDescriptor = new ListStateDescriptor<>(
 			stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer()));
 		return (IS) new TtlListState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(
+				namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, listStateDesc.getSerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <UK, UV, N, SV, S extends State, IS extends S> IS createMapState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private <UK, UV> IS createMapState() throws Exception {
 		MapStateDescriptor<UK, UV> mapStateDesc = (MapStateDescriptor<UK, UV>) stateDesc;
 		MapStateDescriptor<UK, TtlValue<UV>> ttlDescriptor = new MapStateDescriptor<>(
 			stateDesc.getName(),
 			mapStateDesc.getKeySerializer(),
 			new TtlSerializer<>(mapStateDesc.getValueSerializer()));
 		return (IS) new TtlMapState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, mapStateDesc.getSerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <N, SV, S extends State, IS extends S> IS createReducingState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private IS createReducingState() throws Exception {
 		ReducingStateDescriptor<SV> reducingStateDesc = (ReducingStateDescriptor<SV>) stateDesc;
 		ReducingStateDescriptor<TtlValue<SV>> ttlDescriptor = new ReducingStateDescriptor<>(
 			stateDesc.getName(),
 			new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider),
 			new TtlSerializer<>(stateDesc.getSerializer()));
 		return (IS) new TtlReducingState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, stateDesc.getSerializer());
 	}
 
 	@SuppressWarnings("unchecked")
-	private <IN, OUT, N, SV, S extends State, IS extends S> IS createAggregatingState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private <IN, OUT> IS createAggregatingState() throws Exception {
 		AggregatingStateDescriptor<IN, SV, OUT> aggregatingStateDescriptor =
 			(AggregatingStateDescriptor<IN, SV, OUT>) stateDesc;
 		TtlAggregateFunction<IN, SV, OUT> ttlAggregateFunction = new TtlAggregateFunction<>(
@@ -159,14 +161,12 @@ public class TtlStateFactory {
 		AggregatingStateDescriptor<IN, TtlValue<SV>, OUT> ttlDescriptor = new AggregatingStateDescriptor<>(
 			stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer()));
 		return (IS) new TtlAggregatingState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction);
 	}
 
 	@SuppressWarnings({"deprecation", "unchecked"})
-	private <T, N, SV, S extends State, IS extends S> IS createFoldingState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	private <T> IS createFoldingState() throws Exception {
 		FoldingStateDescriptor<T, SV> foldingStateDescriptor = (FoldingStateDescriptor<T, SV>) stateDesc;
 		SV initAcc = stateDesc.getDefaultValue();
 		TtlValue<SV> ttlInitAcc = initAcc == null ? null : new TtlValue<>(initAcc, Long.MAX_VALUE);
@@ -176,10 +176,22 @@ public class TtlStateFactory {
 			new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider, initAcc),
 			new TtlSerializer<>(stateDesc.getSerializer()));
 		return (IS) new TtlFoldingState<>(
-			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+			originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
 			ttlConfig, timeProvider, stateDesc.getSerializer());
 	}
 
+	/**
+	 * Returns the snapshot transform factory handed to the original state factory: it filters out
+	 * expired values when full-snapshot cleanup is configured and applies no transform otherwise.
+	 */
+	private StateSnapshotTransformFactory<?> getSnapshotTransformFactory() {
+		if (!ttlConfig.getCleanupStrategies().inFullSnapshot()) {
+			return StateSnapshotTransformFactory.noTransform();
+		} else {
+			return new TtlStateSnapshotTransformer.Factory<>(timeProvider, ttl);
+		}
+	}
+
 	/** Serializer for user state value with TTL. */
 	private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>> {
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
new file mode 100644
index 0000000..228d045
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/**
+ * State snapshot filter of expired values with TTL.
+ *
+ * <p>A value is dropped (transformed to {@code null}) when its last access timestamp plus the TTL
+ * is not later than the current timestamp of the {@link TtlTimeProvider}.
+ */
+abstract class TtlStateSnapshotTransformer<T> implements CollectionStateSnapshotTransformer<T> {
+	private final TtlTimeProvider ttlTimeProvider;
+	final long ttl;
+
+	TtlStateSnapshotTransformer(@Nonnull TtlTimeProvider ttlTimeProvider, long ttl) {
+		this.ttlTimeProvider = ttlTimeProvider;
+		this.ttl = ttl;
+	}
+
+	/**
+	 * Filters out an expired TTL value.
+	 *
+	 * @param value the value to check, may be {@code null}
+	 * @return {@code null} if {@code value} is {@code null} or expired, otherwise {@code value}
+	 */
+	@Nullable
+	<V> TtlValue<V> filterTtlValue(@Nullable TtlValue<V> value) {
+		// Check null first: filterOrTransform accepts @Nullable input, and
+		// getLastAccessTimestamp() would otherwise throw a NullPointerException.
+		return value == null || expired(value) ? null : value;
+	}
+
+	private boolean expired(@Nonnull TtlValue<?> ttlValue) {
+		return expired(ttlValue.getLastAccessTimestamp());
+	}
+
+	/** Returns whether a value last accessed at timestamp {@code ts} has outlived {@link #ttl}. */
+	boolean expired(long ts) {
+		return TtlUtils.expired(ts, ttl, ttlTimeProvider);
+	}
+
+	/** Deserializes the {@code Long.BYTES}-long timestamp stored at {@code offset} in {@code value}. */
+	private static long deserializeTs(
+		byte[] value, int offset) throws IOException {
+		return LongSerializer.INSTANCE.deserialize(
+			new DataInputViewStreamWrapper(new ByteArrayInputStream(value, offset, Long.BYTES)));
+	}
+
+	@Override
+	public TransformStrategy getFilterStrategy() {
+		return TransformStrategy.STOP_ON_FIRST_INCLUDED;
+	}
+
+	/** Transformer for state values that are available as deserialized {@link TtlValue} objects. */
+	static class TtlDeserializedValueStateSnapshotTransformer<T> extends TtlStateSnapshotTransformer<TtlValue<T>> {
+		TtlDeserializedValueStateSnapshotTransformer(TtlTimeProvider ttlTimeProvider, long ttl) {
+			super(ttlTimeProvider, ttl);
+		}
+
+		@Override
+		@Nullable
+		public TtlValue<T> filterOrTransform(@Nullable TtlValue<T> value) {
+			return filterTtlValue(value);
+		}
+	}
+
+	/**
+	 * Transformer for state values that are only available in serialized form.
+	 *
+	 * <p>The last access timestamp is read from the trailing {@code Long.BYTES} bytes of the value.
+	 */
+	static class TtlSerializedValueStateSnapshotTransformer extends TtlStateSnapshotTransformer<byte[]> {
+		TtlSerializedValueStateSnapshotTransformer(TtlTimeProvider ttlTimeProvider, long ttl) {
+			super(ttlTimeProvider, ttl);
+		}
+
+		@Override
+		@Nullable
+		public byte[] filterOrTransform(@Nullable byte[] value) {
+			if (value == null) {
+				return null;
+			}
+			Preconditions.checkArgument(value.length >= Long.BYTES, "Serialized TTL value is too short to contain a timestamp");
+			long ts;
+			try {
+				ts = deserializeTs(value, value.length - Long.BYTES);
+			} catch (IOException e) {
+				// Keep the cause so that the underlying I/O problem remains diagnosable.
+				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure", e);
+			}
+			return expired(ts) ? null : value;
+		}
+	}
+
+	/** Factory creating TTL snapshot transformers for deserialized and serialized state values. */
+	static class Factory<T> implements StateSnapshotTransformFactory<TtlValue<T>> {
+		private final TtlTimeProvider ttlTimeProvider;
+		private final long ttl;
+
+		Factory(@Nonnull TtlTimeProvider ttlTimeProvider, long ttl) {
+			this.ttlTimeProvider = ttlTimeProvider;
+			this.ttl = ttl;
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<TtlValue<T>>> createForDeserializedState() {
+			return Optional.of(new TtlDeserializedValueStateSnapshotTransformer<>(ttlTimeProvider, ttl));
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return Optional.of(new TtlSerializedValueStateSnapshotTransformer(ttlTimeProvider, ttl));
+		}
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java
new file mode 100644
index 0000000..9d9e5e1
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl;
+
+/** Common functions related to State TTL. */
+final class TtlUtils {
+	private TtlUtils() {
+		// Utility class with static methods only; not meant to be instantiated.
+	}
+
+	/**
+	 * Checks whether a TTL value has expired.
+	 *
+	 * @param ttlValue value together with its last access timestamp, may be {@code null}
+	 * @param ttl time-to-live, in the same time unit as the timestamps
+	 * @param timeProvider provider of the current timestamp
+	 * @return {@code true} if {@code ttlValue} is not {@code null} and has expired
+	 */
+	static <V> boolean expired(TtlValue<V> ttlValue, long ttl, TtlTimeProvider timeProvider) {
+		return ttlValue != null && expired(ttlValue.getLastAccessTimestamp(), ttl, timeProvider);
+	}
+
+	/**
+	 * Checks whether a value last accessed at {@code ts} has expired, i.e. whether its expiration
+	 * timestamp is not later than the current timestamp of {@code timeProvider}.
+	 */
+	static boolean expired(long ts, long ttl, TtlTimeProvider timeProvider) {
+		return getExpirationTimestamp(ts, ttl) <= timeProvider.currentTimestamp();
+	}
+
+	/**
+	 * Returns {@code ts + ttl}, capped at {@link Long#MAX_VALUE} for positive {@code ts} so that a
+	 * large TTL cannot overflow into a negative, i.e. already expired, timestamp.
+	 */
+	private static long getExpirationTimestamp(long ts, long ttl) {
+		long ttlWithoutOverflow = ts > 0 ? Math.min(Long.MAX_VALUE - ts, ttl) : ttl;
+		return ts + ttlWithoutOverflow;
+	}
+
+	/** Wraps {@code value} with timestamp {@code ts}; returns {@code null} if {@code value} is {@code null}. */
+	static <V> TtlValue<V> wrapWithTs(V value, long ts) {
+		return value == null ? null : new TtlValue<>(value, ts);
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
index c14a583..7b19341 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
@@ -18,7 +18,7 @@
 
 package org.apache.flink.runtime.state.ttl;
 
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.state.internal.InternalValueState;
 
@@ -36,7 +36,8 @@ class TtlValueState<K, N, T>
 	implements InternalValueState<K, N, T> {
+	/** Wraps {@code originalState}, which stores its value as a {@link TtlValue}, with TTL logic. */
 	TtlValueState(
 		InternalValueState<K, N, TtlValue<T>> originalState,
-		StateTtlConfiguration config,
+		StateTtlConfig config,
 		TtlTimeProvider timeProvider,
 		TypeSerializer<T> valueSerializer) {
 		super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
index 5820f13..6b3a15f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.state.ttl;
 
 import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.state.KeyedStateHandle;
@@ -49,7 +49,7 @@ public abstract class TtlStateTestBase {
 
 	private MockTtlTimeProvider timeProvider;
 	private StateBackendTestContext sbetc;
-	private StateTtlConfiguration ttlConfig;
+	private StateTtlConfig ttlConfig;
 
 	@Before
 	public void setup() {
@@ -85,24 +85,31 @@ public abstract class TtlStateTestBase {
 	}
 
 	private void initTest() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+		initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
 	}
 
 	private void initTest(
-				StateTtlConfiguration.TtlUpdateType updateType,
-				StateTtlConfiguration.TtlStateVisibility visibility) throws Exception {
+				StateTtlConfig.UpdateType updateType,
+				StateTtlConfig.StateVisibility visibility) throws Exception {
 		initTest(updateType, visibility, TTL);
 	}
 
 	private void initTest(
-		StateTtlConfiguration.TtlUpdateType updateType,
-		StateTtlConfiguration.TtlStateVisibility visibility,
+		StateTtlConfig.UpdateType updateType,
+		StateTtlConfig.StateVisibility visibility,
 		long ttl) throws Exception {
-		ttlConfig = StateTtlConfiguration
-			.newBuilder(Time.milliseconds(ttl))
-			.setTtlUpdateType(updateType)
+		initTest(getConfBuilder(ttl)
+			.setUpdateType(updateType)
 			.setStateVisibility(visibility)
-			.build();
+			.build());
+	}
+	/** Creates a TTL config builder whose TTL is {@code ttl} milliseconds. */
+	private static StateTtlConfig.Builder getConfBuilder(long ttl) {
+		return StateTtlConfig.newBuilder(Time.milliseconds(ttl));
+	}
+
+	private void initTest(StateTtlConfig ttlConfig) throws Exception {
+		this.ttlConfig = ttlConfig;
 		sbetc.createAndRestoreKeyedStateBackend();
 		sbetc.restoreSnapshot(null);
 		createState();
@@ -132,7 +139,7 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testExactExpirationOnWrite() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+		initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
 
 		takeAndRestoreSnapshot();
 
@@ -173,7 +180,7 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testRelaxedExpirationOnWrite() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
+		initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp);
 
 		timeProvider.time = 0;
 		ctx().update(ctx().updateEmpty);
@@ -188,7 +195,7 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testExactExpirationOnRead() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnReadAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+		initTest(StateTtlConfig.UpdateType.OnReadAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
 
 		timeProvider.time = 0;
 		ctx().update(ctx().updateEmpty);
@@ -212,7 +219,7 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testRelaxedExpirationOnRead() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnReadAndWrite, StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
+		initTest(StateTtlConfig.UpdateType.OnReadAndWrite, StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp);
 
 		timeProvider.time = 0;
 		ctx().update(ctx().updateEmpty);
@@ -231,7 +238,7 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testExpirationTimestampOverflow() throws Exception {
-		initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired, Long.MAX_VALUE);
+		initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired, Long.MAX_VALUE);
 
 		timeProvider.time = 10;
 		ctx().update(ctx().updateEmpty);
@@ -275,16 +282,26 @@ public abstract class TtlStateTestBase {
 
 	@Test
 	public void testMultipleKeys() throws Exception {
-		testMultipleStateIds(id -> sbetc.setCurrentKey(id));
+		testMultipleStateIdsWithSnapshotCleanup(id -> sbetc.setCurrentKey(id));
 	}
 
 	@Test
 	public void testMultipleNamespaces() throws Exception {
-		testMultipleStateIds(id -> ctx().ttlState.setCurrentNamespace(id));
+		testMultipleStateIdsWithSnapshotCleanup(id -> ctx().ttlState.setCurrentNamespace(id));
 	}
 
-	private void testMultipleStateIds(Consumer<String> idChanger) throws Exception {
+	private void testMultipleStateIdsWithSnapshotCleanup(Consumer<String> idChanger) throws Exception {
 		initTest();
+		testMultipleStateIds(idChanger, false);
+		// Repeat the same scenario with full-snapshot cleanup enabled.
+		initTest(getConfBuilder(TTL).cleanupFullSnapshot().build());
+		// time is moved back after restores, so an expired entry not cleaned up in the snapshot would wrongly look unexpired
+		testMultipleStateIds(idChanger, true);
+	}
+
+	private void testMultipleStateIds(Consumer<String> idChanger, boolean timeBackAfterRestore) throws Exception {
+		// test empty storage snapshot/restore
+		takeAndRestoreSnapshot();
 
 		timeProvider.time = 0;
 		idChanger.accept("id2");
@@ -298,9 +315,9 @@ public abstract class TtlStateTestBase {
 		idChanger.accept("id2");
 		ctx().update(ctx().updateUnexpired);
 
+		timeProvider.time = 120;
 		takeAndRestoreSnapshot();
 
-		timeProvider.time = 120;
 		idChanger.accept("id1");
 		assertEquals("Unexpired state should be available", ctx().getUpdateEmpty, ctx().get());
 		idChanger.accept("id2");
@@ -312,17 +329,19 @@ public abstract class TtlStateTestBase {
 		idChanger.accept("id2");
 		ctx().update(ctx().updateExpired);
 
+		timeProvider.time = 230;
 		takeAndRestoreSnapshot();
 
-		timeProvider.time = 230;
+		timeProvider.time = timeBackAfterRestore ? 170 : timeProvider.time;
 		idChanger.accept("id1");
 		assertEquals("Expired state should be unavailable", ctx().emptyValue, ctx().get());
 		idChanger.accept("id2");
 		assertEquals("Unexpired state should be available after update", ctx().getUpdateExpired, ctx().get());
 
+		timeProvider.time = 300;
 		takeAndRestoreSnapshot();
 
-		timeProvider.time = 300;
+		timeProvider.time = timeBackAfterRestore ? 230 : timeProvider.time;
 		idChanger.accept("id1");
 		assertEquals("Expired state should be unavailable", ctx().emptyValue, ctx().get());
 		idChanger.accept("id2");
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 805ae1c..0b5931c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -37,12 +37,13 @@ import org.apache.flink.runtime.state.KeyExtractorFunction;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
 import org.apache.flink.runtime.state.Keyed;
-import org.apache.flink.runtime.state.KeyedStateFactory;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.PriorityComparable;
 import org.apache.flink.runtime.state.PriorityComparator;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
 import org.apache.flink.runtime.state.ttl.TtlStateFactory;
@@ -57,6 +58,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.concurrent.FutureTask;
 import java.util.concurrent.RunnableFuture;
 import java.util.stream.Collectors;
@@ -65,19 +67,27 @@ import java.util.stream.Stream;
 /** State backend which produces in memory mock state objects. */
 public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
+	private interface StateFactory { // target type of the MockInternal*State::createState references in STATE_FACTORIES
+		<N, SV, S extends State, IS extends S> IS createInternalState(
+			TypeSerializer<N> namespaceSerializer,
+			StateDescriptor<S, SV> stateDesc) throws Exception;
+	}
+
 	@SuppressWarnings("deprecation")
-	private static final Map<Class<? extends StateDescriptor>, KeyedStateFactory> STATE_FACTORIES =
+	private static final Map<Class<? extends StateDescriptor>, StateFactory> STATE_FACTORIES =
 		Stream.of(
-			Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) MockInternalValueState::createState),
-			Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) MockInternalListState::createState),
-			Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) MockInternalMapState::createState),
-			Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) MockInternalReducingState::createState),
-			Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) MockInternalAggregatingState::createState),
-			Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) MockInternalFoldingState::createState)
+			Tuple2.of(ValueStateDescriptor.class, (StateFactory) MockInternalValueState::createState),
+			Tuple2.of(ListStateDescriptor.class, (StateFactory) MockInternalListState::createState),
+			Tuple2.of(MapStateDescriptor.class, (StateFactory) MockInternalMapState::createState),
+			Tuple2.of(ReducingStateDescriptor.class, (StateFactory) MockInternalReducingState::createState),
+			Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) MockInternalAggregatingState::createState),
+			Tuple2.of(FoldingStateDescriptor.class, (StateFactory) MockInternalFoldingState::createState)
 		).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
 
 	private final Map<String, Map<K, Map<Object, Object>>> stateValues = new HashMap<>();
 
+	private final Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters = new HashMap<>();
+
 	MockKeyedStateBackend(
 		TaskKvStateRegistry kvStateRegistry,
 		TypeSerializer<K> keySerializer,
@@ -92,22 +102,46 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 	@Override
 	@SuppressWarnings("unchecked")
-	public <N, SV, S extends State, IS extends S> IS createInternalState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
-		KeyedStateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+	@Nonnull
+	public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull StateDescriptor<S, SV> stateDesc,
+		@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
+		StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
 		if (stateFactory == null) {
 			String message = String.format("State %s is not supported by %s",
 				stateDesc.getClass(), TtlStateFactory.class);
 			throw new FlinkRuntimeException(message);
 		}
 		IS state = stateFactory.createInternalState(namespaceSerializer, stateDesc);
+		stateSnapshotFilters.put(stateDesc.getName(),
+			(StateSnapshotTransformer<Object>) getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
 		((MockInternalKvState<K, N, SV>) state).values = () -> stateValues
 			.computeIfAbsent(stateDesc.getName(), n -> new HashMap<>())
 			.computeIfAbsent(getCurrentKey(), k -> new HashMap<>());
 		return state;
 	}
 
+	@SuppressWarnings("unchecked")
+	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+		Optional<StateSnapshotTransformer<SEV>> elementTransformer = snapshotTransformFactory.createForDeserializedState();
+		if (!elementTransformer.isPresent()) {
+			return null; // no transformer registered for this state
+		}
+		// List and map states get the element transformer wrapped in the matching collection transformer.
+		if (stateDesc instanceof ListStateDescriptor) {
+			return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+				.ListStateSnapshotTransformer<>(elementTransformer.get());
+		}
+		if (stateDesc instanceof MapStateDescriptor) {
+			return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+				.MapStateSnapshotTransformer<>(elementTransformer.get());
+		}
+		return (StateSnapshotTransformer<SV>) elementTransformer.get();
+	}
+
 	@Override
 	public int numKeyValueStateEntries() {
 		int count = 0;
@@ -142,7 +176,8 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		long timestamp,
 		CheckpointStreamFactory streamFactory,
 		CheckpointOptions checkpointOptions) {
-		return new FutureTask<>(() -> SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues))));
+		return new FutureTask<>(() ->
+			SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues, stateSnapshotFilters))));
 	}
 
 	@SuppressWarnings("unchecked")
@@ -153,32 +188,51 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		state.forEach(ksh -> stateValues.putAll(copy(((MockKeyedStateHandle<K>) ksh).snapshotStates)));
 	}
 
-	@SuppressWarnings("unchecked")
 	private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
 		Map<String, Map<K, Map<Object, Object>>> stateValues) {
+		return copy(stateValues, Collections.emptyMap());
+	}
+
+	private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
+		Map<String, Map<K, Map<Object, Object>>> stateValues, Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters) {
 		Map<String, Map<K, Map<Object, Object>>> snapshotStates = new HashMap<>();
 		for (String stateName : stateValues.keySet()) {
+			StateSnapshotTransformer<Object> stateSnapshotTransformer = stateSnapshotFilters.get(stateName); // null: no transformer
 			Map<K, Map<Object, Object>> keyedValues = snapshotStates.computeIfAbsent(stateName, s -> new HashMap<>());
 			for (K key : stateValues.get(stateName).keySet()) {
-				Map<Object, Object> values = keyedValues.computeIfAbsent(key, s -> new HashMap<>());
+				Map<Object, Object> snapshotedValues = keyedValues.computeIfAbsent(key, s -> new HashMap<>());
 				for (Object namespace : stateValues.get(stateName).get(key).keySet()) {
-					Object value = stateValues.get(stateName).get(key).get(namespace);
-					value = value instanceof List ? new ArrayList<>((List) value) : value;
-					value = value instanceof Map ? new HashMap<>((Map) value) : value;
-					values.put(namespace, value);
+					copyEntry(stateValues, snapshotedValues, stateName, key, namespace, stateSnapshotTransformer);
 				}
 			}
 		}
 		return snapshotStates;
 	}
 
+	@SuppressWarnings("unchecked")
+	private static <K> void copyEntry(
+		Map<String, Map<K, Map<Object, Object>>> stateValues,
+		Map<Object, Object> snapshotedValues,
+		String stateName,
+		K key,
+		Object namespace,
+		StateSnapshotTransformer<Object> stateSnapshotTransformer) {
+		Object original = stateValues.get(stateName).get(key).get(namespace);
+		Object copied = original instanceof List ? new ArrayList<>((List) original) // shallow-copy list/map values
+			: original instanceof Map ? new HashMap<>((Map) original) : original;
+		Object transformed = stateSnapshotTransformer == null ? copied : stateSnapshotTransformer.filterOrTransform(copied);
+		if (transformed != null) { // a null result means the entry is filtered out of the snapshot
+			snapshotedValues.put(namespace, transformed);
+		}
+	}
+
 	@Nonnull
 	@Override
 	public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
 	create(
 		@Nonnull String stateName,
 		@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
-		return new HeapPriorityQueueSet<T>(
+		return new HeapPriorityQueueSet<>(
 			PriorityComparator.forPriorityComparableObjects(),
 			KeyExtractorFunction.forKeyedObjects(),
 			0,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 7ead620..f7af354 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -84,6 +84,8 @@ import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategy;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.StateUtil;
 import org.apache.flink.runtime.state.StreamCompressionDecorator;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -110,12 +112,14 @@ import org.rocksdb.DBOptions;
 import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksDBException;
+import org.rocksdb.RocksIterator;
 import org.rocksdb.Snapshot;
 import org.rocksdb.WriteOptions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.File;
 import java.io.IOException;
@@ -135,6 +139,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.NoSuchElementException;
 import java.util.Objects;
+import java.util.Optional;
 import java.util.PriorityQueue;
 import java.util.Set;
 import java.util.SortedMap;
@@ -1312,7 +1317,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 */
 	private <N, S> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, S>> tryRegisterKvStateInformation(
 			StateDescriptor<?, S> stateDesc,
-			TypeSerializer<N> namespaceSerializer) throws StateMigrationException {
+			TypeSerializer<N> namespaceSerializer,
+			@Nullable StateSnapshotTransformer<S> snapshotTransformer) throws StateMigrationException {
 
 		Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> stateInfo =
 			kvStateInformation.get(stateDesc.getName());
@@ -1330,7 +1336,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			newMetaInfo = RegisteredKeyValueStateBackendMetaInfo.resolveKvStateCompatibility(
 				restoredMetaInfoSnapshot,
 				namespaceSerializer,
-				stateDesc);
+				stateDesc,
+				snapshotTransformer);
 
 			stateInfo.f1 = newMetaInfo;
 		} else {
@@ -1340,7 +1347,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 				stateDesc.getType(),
 				stateName,
 				namespaceSerializer,
-				stateDesc.getSerializer());
+				stateDesc.getSerializer(),
+				snapshotTransformer);
 
 			ColumnFamilyHandle columnFamily = createColumnFamily(stateName);
 
@@ -1369,20 +1377,43 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	}
 
 	@Override
-	public <N, SV, S extends State, IS extends S> IS createInternalState(
-		TypeSerializer<N> namespaceSerializer,
-		StateDescriptor<S, SV> stateDesc) throws Exception {
+	@Nonnull
+	public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+		@Nonnull TypeSerializer<N> namespaceSerializer,
+		@Nonnull StateDescriptor<S, SV> stateDesc,
+		@Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
 		StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
 		if (stateFactory == null) {
 			String message = String.format("State %s is not supported by %s",
 				stateDesc.getClass(), this.getClass());
 			throw new FlinkRuntimeException(message);
 		}
-		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult =
-			tryRegisterKvStateInformation(stateDesc, namespaceSerializer);
+		Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult = tryRegisterKvStateInformation(
+			stateDesc, namespaceSerializer, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
 		return stateFactory.createState(stateDesc, registerResult, RocksDBKeyedStateBackend.this);
 	}
 
+	@SuppressWarnings("unchecked")
+	private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+		if (!(stateDesc instanceof ListStateDescriptor)) {
+			// Non-list states are transformed directly on their serialized value bytes.
+			return (StateSnapshotTransformer<SV>) snapshotTransformFactory.createForSerializedState().orElse(null);
+		}
+		// List states wrap the per-element (deserialized) transformer together with the element serializer.
+		return snapshotTransformFactory.createForDeserializedState()
+			.map(elementTransformer -> createRocksDBListStateTransformer(stateDesc, elementTransformer)).orElse(null);
+	}
+
+	@SuppressWarnings("unchecked")
+	private <SV, SEV> StateSnapshotTransformer<SV> createRocksDBListStateTransformer(
+		StateDescriptor<?, SV> stateDesc,
+		StateSnapshotTransformer<SEV> elementTransformer) {
+		TypeSerializer<SEV> elementSerializer = ((ListStateDescriptor<SEV>) stateDesc).getElementSerializer();
+		return (StateSnapshotTransformer<SV>) new RocksDBListState.StateSnapshotTransformerWrapper<>(elementTransformer, elementSerializer);
+	}
+
 	/**
 	 * Only visible for testing, DO NOT USE.
 	 */
@@ -1402,7 +1433,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		int count = 0;
 
 		for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> column : kvStateInformation.values()) {
-			//TODO maybe filter only for k/v states
+			//TODO maybe filter only for k/v states
 			try (RocksIteratorWrapper rocksIterator = getRocksIterator(db, column.f0)) {
 				rocksIterator.seekToFirst();
 
@@ -1416,14 +1447,12 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return count;
 	}
 
-
-
 	/**
 	 * Iterator that merges multiple RocksDB iterators to partition all states into contiguous key-groups.
 	 * The resulting iteration sequence is ordered by (key-group, kv-state).
 	 */
 	@VisibleForTesting
-	static final class RocksDBMergeIterator implements AutoCloseable {
+	static class RocksDBMergeIterator implements AutoCloseable {
 
 		private final PriorityQueue<RocksDBKeyedStateBackend.MergeIterator> heap;
 		private final int keyGroupPrefixByteCount;
@@ -1431,7 +1460,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		private boolean newKVState;
 		private boolean valid;
 
-		private MergeIterator currentSubIterator;
+		MergeIterator currentSubIterator;
 
 		private static final List<Comparator<MergeIterator>> COMPARATORS;
 
@@ -1440,18 +1469,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			COMPARATORS = new ArrayList<>(maxBytes);
 			for (int i = 0; i < maxBytes; ++i) {
 				final int currentBytes = i + 1;
-				COMPARATORS.add(new Comparator<MergeIterator>() {
-					@Override
-					public int compare(MergeIterator o1, MergeIterator o2) {
-						int arrayCmpRes = compareKeyGroupsForByteArrays(
-							o1.currentKey, o2.currentKey, currentBytes);
-						return arrayCmpRes == 0 ? o1.getKvStateId() - o2.getKvStateId() : arrayCmpRes;
-					}
+				COMPARATORS.add((o1, o2) -> {
+					int arrayCmpRes = compareKeyGroupsForByteArrays(
+						o1.currentKey, o2.currentKey, currentBytes);
+					return arrayCmpRes == 0 ? o1.getKvStateId() - o2.getKvStateId() : arrayCmpRes;
 				});
 			}
 		}
 
-		RocksDBMergeIterator(List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators, final int keyGroupPrefixByteCount) throws RocksDBException {
+		RocksDBMergeIterator(
+			List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+			final int keyGroupPrefixByteCount) {
 			Preconditions.checkNotNull(kvStateIterators);
 			Preconditions.checkArgument(keyGroupPrefixByteCount >= 1);
 
@@ -1492,7 +1520,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		 * Advance the iterator. Should only be called if {@link #isValid()} returned true. Valid can only chance after
 		 * calls to {@link #next()}.
 		 */
-		public void next() throws RocksDBException {
+		public void next() {
 			newKeyGroup = false;
 			newKVState = false;
 
@@ -1612,7 +1640,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 	 * Wraps a RocksDB iterator to cache it's current key and assigns an id for the key/value state to the iterator.
 	 * Used by #MergeIterator.
 	 */
-	private static final class MergeIterator implements AutoCloseable {
+	@VisibleForTesting
+	protected static final class MergeIterator implements AutoCloseable {
 
 		/**
 		 * @param iterator  The #RocksIterator to wrap .
@@ -1650,6 +1679,57 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		}
 	}
 
+	private static final class TransformingRocksIteratorWrapper extends RocksIteratorWrapper { // applies a snapshot transformer to values
+		@Nonnull
+		private final StateSnapshotTransformer<byte[]> stateSnapshotTransformer;
+		private byte[] current; // transformed value of the entry the iterator is positioned on
+		// NOTE(review): if the base wrapper exposes seek(byte[]), it is not overridden here, so value() after a seek would return a stale value — confirm unused.
+		public TransformingRocksIteratorWrapper(
+			@Nonnull RocksIterator iterator,
+			@Nonnull StateSnapshotTransformer<byte[]> stateSnapshotTransformer) {
+			super(iterator);
+			this.stateSnapshotTransformer = stateSnapshotTransformer;
+		}
+		// Each positioning method moves the underlying iterator, then keeps moving the same way past entries the transformer maps to null.
+		@Override
+		public void seekToFirst() {
+			super.seekToFirst();
+			filterOrTransform(super::next);
+		}
+
+		@Override
+		public void seekToLast() {
+			super.seekToLast();
+			filterOrTransform(super::prev);
+		}
+
+		@Override
+		public void next() {
+			super.next();
+			filterOrTransform(super::next);
+		}
+
+		@Override
+		public void prev() {
+			super.prev();
+			filterOrTransform(super::prev);
+		}
+		/** Moves with {@code advance} while the transformer returns null for the raw value; caches the first non-null result in {@code current}. */
+		private void filterOrTransform(Runnable advance) {
+			while (isValid() && (current = stateSnapshotTransformer.filterOrTransform(super.value())) == null) {
+				advance.run();
+			}
+		}
+		/** Returns the transformed value cached by the last positioning call rather than the raw RocksDB value. */
+		@Override
+		public byte[] value() {
+			if (!isValid()) {
+				throw new IllegalStateException("value() method cannot be called if isValid() is false");
+			}
+			return current;
+		}
+	}
+
 	/**
 	 * Adapter class to bridge between {@link RocksIteratorWrapper} and {@link Iterator} to iterate over the keys. This class
 	 * is not thread safe.
@@ -1972,7 +2052,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		/**
 		 * The copied column handle.
 		 */
-		private List<ColumnFamilyHandle> copiedColumnFamilyHandles;
+		private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> copiedMeta;
 
 		private List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators;
 
@@ -2000,15 +2080,13 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 			this.stateMetaInfoSnapshots = new ArrayList<>(stateBackend.kvStateInformation.size());
 
-			this.copiedColumnFamilyHandles = new ArrayList<>(stateBackend.kvStateInformation.size());
+			this.copiedMeta = new ArrayList<>(stateBackend.kvStateInformation.size());
 
 			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 :
 				stateBackend.kvStateInformation.values()) {
 				// snapshot meta info
 				this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
-
-				// copy column family handle
-				this.copiedColumnFamilyHandles.add(tuple2.f0);
+				this.copiedMeta.add(tuple2);
 			}
 			this.snapshot = stateBackend.db.getSnapshot();
 		}
@@ -2095,7 +2173,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
 		private void writeKVStateMetaData() throws IOException {
 
-			this.kvStateIterators = new ArrayList<>(copiedColumnFamilyHandles.size());
+			this.kvStateIterators = new ArrayList<>(copiedMeta.size());
 
 			int kvStateId = 0;
 
@@ -2103,11 +2181,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			readOptions = new ReadOptions();
 			readOptions.setSnapshot(snapshot);
 
-			for (ColumnFamilyHandle columnFamilyHandle : copiedColumnFamilyHandles) {
-
-				kvStateIterators.add(
-					new Tuple2<>(getRocksIterator(stateBackend.db, columnFamilyHandle, readOptions), kvStateId));
-
+			for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : copiedMeta) {
+				RocksIteratorWrapper rocksIteratorWrapper =
+					getRocksIterator(stateBackend.db, tuple2.f0, tuple2.f1, readOptions);
+				kvStateIterators.add(new Tuple2<>(rocksIteratorWrapper, kvStateId));
 				++kvStateId;
 			}
 
@@ -2124,8 +2201,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 			serializationProxy.write(outputView);
 		}
 
-		private void writeKVStateData() throws IOException, InterruptedException, RocksDBException {
-
+		private void writeKVStateData() throws IOException, InterruptedException {
 			byte[] previousKey = null;
 			byte[] previousValue = null;
 			DataOutputView kgOutView = null;
@@ -2635,11 +2711,21 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle));
 	}
 
-	public static RocksIteratorWrapper getRocksIterator(
+	@SuppressWarnings("unchecked")
+	private static RocksIteratorWrapper getRocksIterator(
 		RocksDB db,
 		ColumnFamilyHandle columnFamilyHandle,
+		RegisteredStateMetaInfoBase metaInfo,
 		ReadOptions readOptions) {
-		return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle, readOptions));
+		StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
+		if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
+			stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
+				((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
+		}
+		RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
+		return stateSnapshotTransformer == null ?
+			new RocksIteratorWrapper(rocksIterator) :
+			new TransformingRocksIteratorWrapper(rocksIterator, stateSnapshotTransformer);
 	}
 
 	/**
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
index aa5e93a..176f48c 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
@@ -24,9 +24,12 @@ import org.apache.flink.api.common.state.State;
 import org.apache.flink.api.common.state.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.ByteArrayDataInputView;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.internal.InternalListState;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
@@ -34,12 +37,16 @@ import org.apache.flink.util.Preconditions;
 import org.rocksdb.ColumnFamilyHandle;
 import org.rocksdb.RocksDBException;
 
-import java.io.ByteArrayInputStream;
+import javax.annotation.Nullable;
+
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
+import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
+
 /**
  * {@link ListState} implementation that stores state in RocksDB.
  *
@@ -111,25 +118,42 @@ class RocksDBListState<K, N, V>
 			writeCurrentKeyWithGroupAndNamespace();
 			byte[] key = keySerializationStream.toByteArray();
 			byte[] valueBytes = backend.db.get(columnFamily, key);
+			return deserializeList(valueBytes, elementSerializer);
+		} catch (IOException | RocksDBException e) {
+			throw new FlinkRuntimeException("Error while retrieving data from RocksDB", e);
+		}
+	}
 
-			if (valueBytes == null) {
-				return null;
-			}
+	/**
+	 * Deserializes the byte representation of a list state value into a list of elements.
+	 *
+	 * @param valueBytes serialized list value (e.g. as returned by {@code db.get}), may be {@code null}
+	 * @param elementSerializer serializer used to read each individual element
+	 * @return the deserialized elements, or {@code null} if {@code valueBytes} is {@code null}
+	 */
+	private static <V> List<V> deserializeList(
+		byte[] valueBytes, TypeSerializer<V> elementSerializer) {
+		// no stored value: report absence as null rather than an empty list
+		if (valueBytes == null) {
+			return null;
+		}
 
-			ByteArrayInputStream bais = new ByteArrayInputStream(valueBytes);
-			DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais);
+		DataInputViewStreamWrapper in = new ByteArrayDataInputView(valueBytes);
 
-			List<V> result = new ArrayList<>();
-			while (in.available() > 0) {
-				result.add(elementSerializer.deserialize(in));
+		List<V> result = new ArrayList<>();
+		V next;
+		// deserializeNextElement returns null once the input is exhausted, which ends the loop
+		while ((next = deserializeNextElement(in, elementSerializer)) != null) {
+			result.add(next);
+		}
+		return result;
+	}
+
+	/**
+	 * Reads the next element of a serialized list from {@code in}.
+	 *
+	 * <p>After reading an element, one additional byte is consumed if input remains
+	 * (presumably the separator written between elements — confirm against getPreMergedValue).
+	 * A {@code null} return is used as the end-of-input signal by callers, so the stored
+	 * elements themselves are assumed to be non-null.
+	 *
+	 * @param in view positioned at the start of the next element (or at the end of input)
+	 * @param elementSerializer serializer used to read one element
+	 * @return the next element, or {@code null} if {@code in} has no remaining bytes
+	 * @throws FlinkRuntimeException if reading from {@code in} fails; the I/O error is kept as cause
+	 */
+	private static <V> V deserializeNextElement(
+		DataInputViewStreamWrapper in, TypeSerializer<V> elementSerializer) {
+		try {
+			if (in.available() > 0) {
+				V element = elementSerializer.deserialize(in);
+				// skip the single byte that follows every element except the last one
 				if (in.available() > 0) {
 					in.readByte();
 				}
+				return element;
 			}
-			return result;
-		} catch (IOException | RocksDBException e) {
-			throw new FlinkRuntimeException("Error while retrieving data from RocksDB", e);
+		} catch (IOException e) {
+			// propagate the original I/O failure as the cause so corrupt/truncated values are diagnosable
+			throw new FlinkRuntimeException("Unexpected list element deserialization failure", e);
 		}
+		return null;
 	}
 
 	@Override
@@ -203,7 +227,7 @@ class RocksDBListState<K, N, V>
 				writeCurrentKeyWithGroupAndNamespace();
 				byte[] key = keySerializationStream.toByteArray();
 
-				byte[] premerge = getPreMergedValue(values);
+				byte[] premerge = getPreMergedValue(values, elementSerializer, keySerializationStream);
 				if (premerge != null) {
 					backend.db.put(columnFamily, writeOptions, key, premerge);
 				} else {
@@ -224,7 +248,7 @@ class RocksDBListState<K, N, V>
 				writeCurrentKeyWithGroupAndNamespace();
 				byte[] key = keySerializationStream.toByteArray();
 
-				byte[] premerge = getPreMergedValue(values);
+				byte[] premerge = getPreMergedValue(values, elementSerializer, keySerializationStream);
 				if (premerge != null) {
 					backend.db.merge(columnFamily, writeOptions, key, premerge);
 				} else {
@@ -236,7 +260,10 @@ class RocksDBListState<K, N, V>
 		}
 	}
 
-	private byte[] getPreMergedValue(List<V> values) throws IOException {
+	private static <V> byte[] getPreMergedValue(
+		List<V> values,
+		TypeSerializer<V> elementSerializer,
+		ByteArrayOutputStreamWithPos keySerializationStream) throws IOException {
 		DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);
 
 		keySerializationStream.reset();
@@ -267,4 +294,47 @@ class RocksDBListState<K, N, V>
 			((ListStateDescriptor<E>) stateDesc).getElementSerializer(),
 			backend);
 	}
+
+	/**
+	 * Adapts an element-level {@link StateSnapshotTransformer} to the serialized (byte array) form
+	 * of a list state value, so that individual list elements can be filtered or transformed when
+	 * a snapshot is taken.
+	 *
+	 * <p>If the element transformer is a {@link CollectionStateSnapshotTransformer} reporting
+	 * {@code STOP_ON_FIRST_INCLUDED}, the serialized bytes from the first element it keeps onward
+	 * are returned verbatim and later elements are not inspected. Otherwise every element is passed
+	 * through the transformer and the surviving results are re-serialized.
+	 *
+	 * <p>Not thread-safe: a single output buffer is reused across calls.
+	 *
+	 * @param <T> type of the list elements
+	 */
+	static class StateSnapshotTransformerWrapper<T> implements StateSnapshotTransformer<byte[]> {
+		private final StateSnapshotTransformer<T> elementTransformer;
+		private final TypeSerializer<T> elementSerializer;
+		/** Reusable buffer for re-serializing the transformed list; reset by getPreMergedValue. */
+		private final ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(128);
+		private final CollectionStateSnapshotTransformer.TransformStrategy transformStrategy;
+
+		StateSnapshotTransformerWrapper(StateSnapshotTransformer<T> elementTransformer, TypeSerializer<T> elementSerializer) {
+			this.elementTransformer = elementTransformer;
+			this.elementSerializer = elementSerializer;
+			// only collection transformers can opt into a strategy; everything else transforms all elements
+			this.transformStrategy = elementTransformer instanceof CollectionStateSnapshotTransformer ?
+				((CollectionStateSnapshotTransformer<?>) elementTransformer).getFilterStrategy() :
+				CollectionStateSnapshotTransformer.TransformStrategy.TRANSFORM_ALL;
+		}
+
+		/**
+		 * Filters or transforms the elements of a serialized list.
+		 *
+		 * @param value serialized list value, may be {@code null}
+		 * @return the serialized remaining list, or {@code null} if {@code value} is {@code null}
+		 *     or no element is kept
+		 */
+		@Override
+		@Nullable
+		public byte[] filterOrTransform(@Nullable byte[] value) {
+			if (value == null) {
+				return null;
+			}
+			List<T> result = new ArrayList<>();
+			ByteArrayDataInputView in = new ByteArrayDataInputView(value);
+			T next;
+			// byte offset at which the element currently being examined starts
+			int prevPosition = 0;
+			while ((next = deserializeNextElement(in, elementSerializer)) != null) {
+				T transformedElement = elementTransformer.filterOrTransform(next);
+				if (transformedElement != null) {
+					if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
+						// keep this element and everything after it untouched, without re-serialization
+						return Arrays.copyOfRange(value, prevPosition, value.length);
+					} else {
+						result.add(transformedElement);
+					}
+				}
+				prevPosition = in.getPosition();
+			}
+			try {
+				return result.isEmpty() ? null : getPreMergedValue(result, elementSerializer, out);
+			} catch (IOException e) {
+				throw new FlinkRuntimeException("Failed to serialize transformed list", e);
+			}
+		}
+	}
 }