You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by sr...@apache.org on 2018/08/03 17:40:59 UTC
[flink] branch master updated: [FLINK-9938][state] Clean up full
snapshot from expired state with TTL
This is an automated email from the ASF dual-hosted git repository.
srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new ce96c40 [FLINK-9938][state] Clean up full snapshot from expired state with TTL
ce96c40 is described below
commit ce96c409148d1a9bc40f581e13900818b5f11f6a
Author: Andrey Zagrebin <az...@gmail.com>
AuthorDate: Fri Aug 3 18:18:07 2018 +0200
[FLINK-9938][state] Clean up full snapshot from expired state with TTL
This closes #6460.
---
.../flink/api/common/state/StateDescriptor.java | 12 +-
.../flink/api/common/state/StateTtlConfig.java | 267 +++++++++++++++++++++
.../api/common/state/StateTtlConfiguration.java | 205 ----------------
.../tests/DataStreamStateTTLTestProgram.java | 4 +-
.../streaming/tests/TtlVerifyUpdateFunction.java | 6 +-
.../tests/verify/AbstractTtlStateVerifier.java | 4 +-
.../streaming/tests/verify/TtlStateVerifier.java | 4 +-
.../flink/runtime/state/KeyGroupPartitioner.java | 18 +-
.../flink/runtime/state/KeyedStateFactory.java | 32 ++-
.../RegisteredKeyValueStateBackendMetaInfo.java | 34 ++-
.../runtime/state/StateSnapshotTransformer.java | 186 ++++++++++++++
.../state/heap/CopyOnWriteStateTableSnapshot.java | 106 ++++++--
.../runtime/state/heap/HeapKeyedStateBackend.java | 44 +++-
.../runtime/state/heap/NestedMapsStateTable.java | 64 +++--
.../flink/runtime/state/heap/StateTable.java | 2 +-
.../runtime/state/ttl/AbstractTtlDecorator.java | 26 +-
.../flink/runtime/state/ttl/AbstractTtlState.java | 4 +-
.../runtime/state/ttl/TtlAggregateFunction.java | 4 +-
.../runtime/state/ttl/TtlAggregatingState.java | 4 +-
.../flink/runtime/state/ttl/TtlFoldFunction.java | 4 +-
.../flink/runtime/state/ttl/TtlFoldingState.java | 4 +-
.../flink/runtime/state/ttl/TtlListState.java | 4 +-
.../flink/runtime/state/ttl/TtlMapState.java | 4 +-
.../flink/runtime/state/ttl/TtlReduceFunction.java | 4 +-
.../flink/runtime/state/ttl/TtlReducingState.java | 4 +-
.../flink/runtime/state/ttl/TtlStateFactory.java | 96 ++++----
.../state/ttl/TtlStateSnapshotTransformer.java | 121 ++++++++++
.../apache/flink/runtime/state/ttl/TtlUtils.java | 39 +++
.../flink/runtime/state/ttl/TtlValueState.java | 4 +-
.../flink/runtime/state/ttl/TtlStateTestBase.java | 63 +++--
.../state/ttl/mock/MockKeyedStateBackend.java | 94 ++++++--
.../streaming/state/RocksDBKeyedStateBackend.java | 162 ++++++++++---
.../contrib/streaming/state/RocksDBListState.java | 100 ++++++--
33 files changed, 1264 insertions(+), 465 deletions(-)
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
index 191eb6f..422d77f 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateDescriptor.java
@@ -96,7 +96,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
/** Name for queries against state created from this StateDescriptor. */
@Nonnull
- private StateTtlConfiguration ttlConfig = StateTtlConfiguration.DISABLED;
+ private StateTtlConfig ttlConfig = StateTtlConfig.DISABLED;
/** The default value returned by the state when no other value is bound to a key. */
@Nullable
@@ -210,7 +210,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
*/
public void setQueryable(String queryableStateName) {
Preconditions.checkArgument(
- ttlConfig.getTtlUpdateType() == StateTtlConfiguration.TtlUpdateType.Disabled,
+ ttlConfig.getUpdateType() == StateTtlConfig.UpdateType.Disabled,
"Queryable state is currently not supported with TTL");
if (this.queryableStateName == null) {
this.queryableStateName = Preconditions.checkNotNull(queryableStateName, "Registration name");
@@ -243,14 +243,14 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
* Configures optional activation of state time-to-live (TTL).
*
* <p>State user value will expire, become unavailable and be cleaned up in storage
- * depending on configured {@link StateTtlConfiguration}.
+ * depending on configured {@link StateTtlConfig}.
*
* @param ttlConfig configuration of state TTL
*/
- public void enableTimeToLive(StateTtlConfiguration ttlConfig) {
+ public void enableTimeToLive(StateTtlConfig ttlConfig) {
Preconditions.checkNotNull(ttlConfig);
Preconditions.checkArgument(
- ttlConfig.getTtlUpdateType() != StateTtlConfiguration.TtlUpdateType.Disabled &&
+ ttlConfig.getUpdateType() != StateTtlConfig.UpdateType.Disabled &&
queryableStateName == null,
"Queryable state is currently not supported with TTL");
this.ttlConfig = ttlConfig;
@@ -258,7 +258,7 @@ public abstract class StateDescriptor<S extends State, T> implements Serializabl
@Nonnull
@Internal
- public StateTtlConfiguration getTtlConfig() {
+ public StateTtlConfig getTtlConfig() {
return ttlConfig;
}
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java
new file mode 100644
index 0000000..f4ed929
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfig.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.state;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnull;
+
+import java.io.Serializable;
+import java.util.EnumMap;
+
+import static org.apache.flink.api.common.state.StateTtlConfig.StateVisibility.NeverReturnExpired;
+import static org.apache.flink.api.common.state.StateTtlConfig.TimeCharacteristic.ProcessingTime;
+import static org.apache.flink.api.common.state.StateTtlConfig.UpdateType.OnCreateAndWrite;
+
+/**
+ * Configuration of state TTL logic.
+ */
+public class StateTtlConfig implements Serializable {
+
+ private static final long serialVersionUID = -7592693245044289793L;
+
+ public static final StateTtlConfig DISABLED =
+ newBuilder(Time.milliseconds(Long.MAX_VALUE)).setUpdateType(UpdateType.Disabled).build();
+
+ /**
+ * This option value configures when to update last access timestamp which prolongs state TTL.
+ */
+ public enum UpdateType {
+ /** TTL is disabled. State does not expire. */
+ Disabled,
+ /** Last access timestamp is initialised when state is created and updated on every write operation. */
+ OnCreateAndWrite,
+ /** The same as <code>OnCreateAndWrite</code> but also updated on read. */
+ OnReadAndWrite
+ }
+
+ /**
+ * This option configures whether expired user value can be returned or not.
+ */
+ public enum StateVisibility {
+ /** Return expired user value if it is not cleaned up yet. */
+ ReturnExpiredIfNotCleanedUp,
+ /** Never return expired user value. */
+ NeverReturnExpired
+ }
+
+ /**
+ * This option configures time scale to use for ttl.
+ */
+ public enum TimeCharacteristic {
+ /** Processing time, see also <code>TimeCharacteristic.ProcessingTime</code>. */
+ ProcessingTime
+ }
+
+ private final UpdateType updateType;
+ private final StateVisibility stateVisibility;
+ private final TimeCharacteristic timeCharacteristic;
+ private final Time ttl;
+ private final CleanupStrategies cleanupStrategies;
+
+ private StateTtlConfig(
+ UpdateType updateType,
+ StateVisibility stateVisibility,
+ TimeCharacteristic timeCharacteristic,
+ Time ttl,
+ CleanupStrategies cleanupStrategies) {
+ this.updateType = Preconditions.checkNotNull(updateType);
+ this.stateVisibility = Preconditions.checkNotNull(stateVisibility);
+ this.timeCharacteristic = Preconditions.checkNotNull(timeCharacteristic);
+ this.ttl = Preconditions.checkNotNull(ttl);
+ this.cleanupStrategies = cleanupStrategies;
+ Preconditions.checkArgument(ttl.toMilliseconds() > 0,
+ "TTL is expected to be positive");
+ }
+
+ @Nonnull
+ public UpdateType getUpdateType() {
+ return updateType;
+ }
+
+ @Nonnull
+ public StateVisibility getStateVisibility() {
+ return stateVisibility;
+ }
+
+ @Nonnull
+ public Time getTtl() {
+ return ttl;
+ }
+
+ @Nonnull
+ public TimeCharacteristic getTimeCharacteristic() {
+ return timeCharacteristic;
+ }
+
+ public boolean isEnabled() {
+ return updateType != UpdateType.Disabled;
+ }
+
+ @Nonnull
+ public CleanupStrategies getCleanupStrategies() {
+ return cleanupStrategies;
+ }
+
+ @Override
+ public String toString() {
+ return "StateTtlConfig{" +
+ "updateType=" + updateType +
+ ", stateVisibility=" + stateVisibility +
+ ", timeCharacteristic=" + timeCharacteristic +
+ ", ttl=" + ttl +
+ '}';
+ }
+
+ @Nonnull
+ public static Builder newBuilder(@Nonnull Time ttl) {
+ return new Builder(ttl);
+ }
+
+ /**
+ * Builder for the {@link StateTtlConfig}.
+ */
+ public static class Builder {
+
+ private UpdateType updateType = OnCreateAndWrite;
+ private StateVisibility stateVisibility = NeverReturnExpired;
+ private TimeCharacteristic timeCharacteristic = ProcessingTime;
+ private Time ttl;
+ private CleanupStrategies cleanupStrategies = new CleanupStrategies();
+
+ public Builder(@Nonnull Time ttl) {
+ this.ttl = ttl;
+ }
+
+ /**
+ * Sets the ttl update type.
+ *
+ * @param updateType The ttl update type configures when to update last access timestamp which prolongs state TTL.
+ */
+ @Nonnull
+ public Builder setUpdateType(UpdateType updateType) {
+ this.updateType = updateType;
+ return this;
+ }
+
+ @Nonnull
+ public Builder updateTtlOnCreateAndWrite() {
+ return setUpdateType(UpdateType.OnCreateAndWrite);
+ }
+
+ @Nonnull
+ public Builder updateTtlOnReadAndWrite() {
+ return setUpdateType(UpdateType.OnReadAndWrite);
+ }
+
+ /**
+ * Sets the state visibility.
+ *
+ * @param stateVisibility The state visibility configures whether expired user value can be returned or not.
+ */
+ @Nonnull
+ public Builder setStateVisibility(@Nonnull StateVisibility stateVisibility) {
+ this.stateVisibility = stateVisibility;
+ return this;
+ }
+
+ @Nonnull
+ public Builder returnExpiredIfNotCleanedUp() {
+ return setStateVisibility(StateVisibility.ReturnExpiredIfNotCleanedUp);
+ }
+
+ @Nonnull
+ public Builder neverReturnExpired() {
+ return setStateVisibility(StateVisibility.NeverReturnExpired);
+ }
+
+ /**
+ * Sets the time characteristic.
+ *
+ * @param timeCharacteristic The time characteristic configures time scale to use for ttl.
+ */
+ @Nonnull
+ public Builder setTimeCharacteristic(@Nonnull TimeCharacteristic timeCharacteristic) {
+ this.timeCharacteristic = timeCharacteristic;
+ return this;
+ }
+
+ @Nonnull
+ public Builder useProcessingTime() {
+ return setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+ }
+
+ /** Clean up expired state in the full snapshot on checkpoint. */
+ @Nonnull
+ public Builder cleanupFullSnapshot() {
+ cleanupStrategies.strategies.put(
+ CleanupStrategies.Strategies.FULL_STATE_SCAN_SNAPSHOT,
+ new CleanupStrategies.CleanupStrategy() { });
+ return this;
+ }
+
+ /**
+ * Sets the ttl time.
+ * @param ttl The ttl time.
+ */
+ @Nonnull
+ public Builder setTtl(@Nonnull Time ttl) {
+ this.ttl = ttl;
+ return this;
+ }
+
+ @Nonnull
+ public StateTtlConfig build() {
+ return new StateTtlConfig(
+ updateType,
+ stateVisibility,
+ timeCharacteristic,
+ ttl,
+ cleanupStrategies);
+ }
+ }
+
+ /**
+ * TTL cleanup strategies.
+ *
+ * <p>This class configures when to clean up expired state with TTL.
+ * By default, state is always cleaned up on explicit read access if found expired.
+ * Currently, cleanup of expired state in the full snapshot can additionally be activated.
+ */
+ public static class CleanupStrategies implements Serializable {
+ private static final long serialVersionUID = -1617740467277313524L;
+
+ /** Fixed set of strategy keys used in the {@code strategies} config field. */
+ enum Strategies {
+ FULL_STATE_SCAN_SNAPSHOT
+ }
+
+ /** Base interface for cleanup strategies configurations. */
+ interface CleanupStrategy extends Serializable {
+
+ }
+
+ final EnumMap<Strategies, CleanupStrategy> strategies = new EnumMap<>(Strategies.class);
+
+ public boolean inFullSnapshot() {
+ return strategies.containsKey(Strategies.FULL_STATE_SCAN_SNAPSHOT);
+ }
+ }
+}
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java b/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java
deleted file mode 100644
index 55ec29c..0000000
--- a/flink-core/src/main/java/org/apache/flink/api/common/state/StateTtlConfiguration.java
+++ /dev/null
@@ -1,205 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.api.common.state;
-
-import org.apache.flink.api.common.time.Time;
-import org.apache.flink.util.Preconditions;
-
-import java.io.Serializable;
-
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired;
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlTimeCharacteristic.ProcessingTime;
-import static org.apache.flink.api.common.state.StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite;
-
-/**
- * Configuration of state TTL logic.
- */
-public class StateTtlConfiguration implements Serializable {
-
- private static final long serialVersionUID = -7592693245044289793L;
-
- public static final StateTtlConfiguration DISABLED =
- newBuilder(Time.milliseconds(Long.MAX_VALUE)).setTtlUpdateType(TtlUpdateType.Disabled).build();
-
- /**
- * This option value configures when to update last access timestamp which prolongs state TTL.
- */
- public enum TtlUpdateType {
- /** TTL is disabled. State does not expire. */
- Disabled,
- /** Last access timestamp is initialised when state is created and updated on every write operation. */
- OnCreateAndWrite,
- /** The same as <code>OnCreateAndWrite</code> but also updated on read. */
- OnReadAndWrite
- }
-
- /**
- * This option configures whether expired user value can be returned or not.
- */
- public enum TtlStateVisibility {
- /** Return expired user value if it is not cleaned up yet. */
- ReturnExpiredIfNotCleanedUp,
- /** Never return expired user value. */
- NeverReturnExpired
- }
-
- /**
- * This option configures time scale to use for ttl.
- */
- public enum TtlTimeCharacteristic {
- /** Processing time, see also <code>TimeCharacteristic.ProcessingTime</code>. */
- ProcessingTime
- }
-
- private final TtlUpdateType ttlUpdateType;
- private final TtlStateVisibility stateVisibility;
- private final TtlTimeCharacteristic timeCharacteristic;
- private final Time ttl;
-
- private StateTtlConfiguration(
- TtlUpdateType ttlUpdateType,
- TtlStateVisibility stateVisibility,
- TtlTimeCharacteristic timeCharacteristic,
- Time ttl) {
- this.ttlUpdateType = Preconditions.checkNotNull(ttlUpdateType);
- this.stateVisibility = Preconditions.checkNotNull(stateVisibility);
- this.timeCharacteristic = Preconditions.checkNotNull(timeCharacteristic);
- this.ttl = Preconditions.checkNotNull(ttl);
- Preconditions.checkArgument(ttl.toMilliseconds() > 0,
- "TTL is expected to be positive");
- }
-
- public TtlUpdateType getTtlUpdateType() {
- return ttlUpdateType;
- }
-
- public TtlStateVisibility getStateVisibility() {
- return stateVisibility;
- }
-
- public Time getTtl() {
- return ttl;
- }
-
- public TtlTimeCharacteristic getTimeCharacteristic() {
- return timeCharacteristic;
- }
-
- public boolean isEnabled() {
- return ttlUpdateType != TtlUpdateType.Disabled;
- }
-
- @Override
- public String toString() {
- return "StateTtlConfiguration{" +
- "ttlUpdateType=" + ttlUpdateType +
- ", stateVisibility=" + stateVisibility +
- ", timeCharacteristic=" + timeCharacteristic +
- ", ttl=" + ttl +
- '}';
- }
-
- public static Builder newBuilder(Time ttl) {
- return new Builder(ttl);
- }
-
- /**
- * Builder for the {@link StateTtlConfiguration}.
- */
- public static class Builder {
-
- private TtlUpdateType ttlUpdateType = OnCreateAndWrite;
- private TtlStateVisibility stateVisibility = NeverReturnExpired;
- private TtlTimeCharacteristic timeCharacteristic = ProcessingTime;
- private Time ttl;
-
- public Builder(Time ttl) {
- this.ttl = ttl;
- }
-
- /**
- * Sets the ttl update type.
- *
- * @param ttlUpdateType The ttl update type configures when to update last access timestamp which prolongs state TTL.
- */
- public Builder setTtlUpdateType(TtlUpdateType ttlUpdateType) {
- this.ttlUpdateType = ttlUpdateType;
- return this;
- }
-
- public Builder updateTtlOnCreateAndWrite() {
- return setTtlUpdateType(TtlUpdateType.OnCreateAndWrite);
- }
-
- public Builder updateTtlOnReadAndWrite() {
- return setTtlUpdateType(TtlUpdateType.OnReadAndWrite);
- }
-
- /**
- * Sets the state visibility.
- *
- * @param stateVisibility The state visibility configures whether expired user value can be returned or not.
- */
- public Builder setStateVisibility(TtlStateVisibility stateVisibility) {
- this.stateVisibility = stateVisibility;
- return this;
- }
-
- public Builder returnExpiredIfNotCleanedUp() {
- return setStateVisibility(TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
- }
-
- public Builder neverReturnExpired() {
- return setStateVisibility(TtlStateVisibility.NeverReturnExpired);
- }
-
- /**
- * Sets the time characteristic.
- *
- * @param timeCharacteristic The time characteristic configures time scale to use for ttl.
- */
- public Builder setTimeCharacteristic(TtlTimeCharacteristic timeCharacteristic) {
- this.timeCharacteristic = timeCharacteristic;
- return this;
- }
-
- public Builder useProcessingTime() {
- return setTimeCharacteristic(TtlTimeCharacteristic.ProcessingTime);
- }
-
- /**
- * Sets the ttl time.
- * @param ttl The ttl time.
- */
- public Builder setTtl(Time ttl) {
- this.ttl = ttl;
- return this;
- }
-
- public StateTtlConfiguration build() {
- return new StateTtlConfiguration(
- ttlUpdateType,
- stateVisibility,
- timeCharacteristic,
- ttl
- );
- }
-
- }
-}
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
index f4c9619..3b2e474 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/DataStreamStateTTLTestProgram.java
@@ -18,7 +18,7 @@
package org.apache.flink.streaming.tests;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.utils.ParameterTool;
import org.apache.flink.configuration.ConfigOption;
@@ -84,7 +84,7 @@ public class DataStreamStateTTLTestProgram {
long reportStatAfterUpdatesNum = pt.getLong(REPORT_STAT_AFTER_UPDATES_NUM.key(),
REPORT_STAT_AFTER_UPDATES_NUM.defaultValue());
- StateTtlConfiguration ttlConfig = StateTtlConfiguration.newBuilder(ttl).build();
+ StateTtlConfig ttlConfig = StateTtlConfig.newBuilder(ttl).build();
env
.addSource(new TtlStateUpdateSource(keySpace, sleepAfterElements, sleepTime))
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
index a99a45f..3cfb0e2 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/TtlVerifyUpdateFunction.java
@@ -23,7 +23,7 @@ import org.apache.flink.api.common.state.KeyedStateStore;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
@@ -67,14 +67,14 @@ class TtlVerifyUpdateFunction
private static final Logger LOG = LoggerFactory.getLogger(TtlVerifyUpdateFunction.class);
@Nonnull
- private final StateTtlConfiguration ttlConfig;
+ private final StateTtlConfig ttlConfig;
private final long ttl;
private final UpdateStat stat;
private transient Map<String, State> states;
private transient Map<String, ListState<ValueWithTs<?>>> prevUpdatesByVerifierId;
- TtlVerifyUpdateFunction(@Nonnull StateTtlConfiguration ttlConfig, long reportStatAfterUpdatesNum) {
+ TtlVerifyUpdateFunction(@Nonnull StateTtlConfig ttlConfig, long reportStatAfterUpdatesNum) {
this.ttlConfig = ttlConfig;
this.ttl = ttlConfig.getTtl().toMilliseconds();
this.stat = new UpdateStat(reportStatAfterUpdatesNum);
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
index c56ff19..7b6def2 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/AbstractTtlStateVerifier.java
@@ -20,7 +20,7 @@ package org.apache.flink.streaming.tests.verify;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.util.StringUtils;
@@ -52,7 +52,7 @@ abstract class AbstractTtlStateVerifier<D extends StateDescriptor<S, SV>, S exte
@SuppressWarnings("unchecked")
@Override
@Nonnull
- public State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfiguration ttlConfig) {
+ public State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfig ttlConfig) {
stateDesc.enableTimeToLive(ttlConfig);
return createState(context);
}
diff --git a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
index e1c2e07..ec5d8b0 100644
--- a/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
+++ b/flink-end-to-end-tests/flink-stream-state-ttl-test/src/main/java/org/apache/flink/streaming/tests/verify/TtlStateVerifier.java
@@ -19,7 +19,7 @@
package org.apache.flink.streaming.tests.verify;
import org.apache.flink.api.common.state.State;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.FunctionInitializationContext;
@@ -45,7 +45,7 @@ public interface TtlStateVerifier<UV, GV> {
}
@Nonnull
- State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfiguration ttlConfig);
+ State createState(@Nonnull FunctionInitializationContext context, @Nonnull StateTtlConfig ttlConfig);
@Nonnull
TypeSerializer<UV> getUpdateSerializer();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
index 27d411c..95f8369 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyGroupPartitioner.java
@@ -135,8 +135,8 @@ public class KeyGroupPartitioner<T> {
public StateSnapshot.StateKeyGroupWriter partitionByKeyGroup() {
if (computedResult == null) {
reportAllElementKeyGroups();
- buildHistogramByAccumulatingCounts();
- executePartitioning();
+ int outputNumberOfElements = buildHistogramByAccumulatingCounts();
+ executePartitioning(outputNumberOfElements);
}
return computedResult;
}
@@ -167,22 +167,20 @@ public class KeyGroupPartitioner<T> {
/**
* This method creates a histogram from the counts per key-group in {@link #counterHistogram}.
*/
- private void buildHistogramByAccumulatingCounts() {
+ private int buildHistogramByAccumulatingCounts() {
int sum = 0;
for (int i = 0; i < counterHistogram.length; ++i) {
int currentSlotValue = counterHistogram[i];
counterHistogram[i] = sum;
sum += currentSlotValue;
}
-
- // sanity check that the sum matches the expected number of elements.
- Preconditions.checkState(sum == numberOfElements);
+ return sum;
}
- private void executePartitioning() {
+ private void executePartitioning(int outputNumberOfElements) {
// We repartition the entries by their pre-computed key-groups, using the histogram values as write indexes
- for (int inIdx = 0; inIdx < numberOfElements; ++inIdx) {
+ for (int inIdx = 0; inIdx < outputNumberOfElements; ++inIdx) {
int effectiveKgIdx = elementKeyGroups[inIdx];
int outIdx = counterHistogram[effectiveKgIdx]++;
partitioningDestination[outIdx] = partitioningSource[inIdx];
@@ -272,7 +270,7 @@ public class KeyGroupPartitioner<T> {
}
/**
- * General algorithm to read key-grouped state that was written from a {@link PartitioningResult}
+ * General algorithm to read key-grouped state that was written from a {@link PartitioningResult}.
*
* @param <T> type of the elements to read.
*/
@@ -339,8 +337,6 @@ public class KeyGroupPartitioner<T> {
*/
@FunctionalInterface
public interface KeyGroupElementsConsumer<T> {
-
-
void consume(@Nonnull T element, @Nonnegative int keyGroupId) throws IOException;
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
index de35979..ca16662 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/KeyedStateFactory.java
@@ -21,22 +21,48 @@ package org.apache.flink.runtime.state;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.runtime.state.internal.InternalKvState;
+import javax.annotation.Nonnull;
+
/** This factory produces concrete internal state objects. */
public interface KeyedStateFactory {
+
+ /**
+ * Creates and returns a new {@link InternalKvState}.
+ *
+ * @param namespaceSerializer TypeSerializer for the state namespace.
+ * @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+ *
+ * @param <N> The type of the namespace.
+ * @param <SV> The type of the stored state value.
+ * @param <S> The type of the public API state.
+ * @param <IS> The type of internal state.
+ */
+ @Nonnull
+ default <N, SV, S extends State, IS extends S> IS createInternalState(
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull StateDescriptor<S, SV> stateDesc) throws Exception {
+ return createInternalState(namespaceSerializer, stateDesc, StateSnapshotTransformFactory.noTransform());
+ }
+
/**
* Creates and returns a new {@link InternalKvState}.
*
* @param namespaceSerializer TypeSerializer for the state namespace.
* @param stateDesc The {@code StateDescriptor} that contains the name of the state.
+ * @param snapshotTransformFactory factory of state snapshot transformer.
*
* @param <N> The type of the namespace.
* @param <SV> The type of the stored state value.
+ * @param <SEV> The type of the stored state value or entry for collection types (list or map).
* @param <S> The type of the public API state.
* @param <IS> The type of internal state.
*/
- <N, SV, S extends State, IS extends S> IS createInternalState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception;
+ @Nonnull
+ <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull StateDescriptor<S, SV> stateDesc,
+ @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
index b0248fc..789551d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredKeyValueStateBackendMetaInfo.java
@@ -29,6 +29,7 @@ import org.apache.flink.util.Preconditions;
import org.apache.flink.util.StateMigrationException;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import java.util.Collections;
import java.util.HashMap;
@@ -50,17 +51,29 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
private final TypeSerializer<N> namespaceSerializer;
@Nonnull
private final TypeSerializer<S> stateSerializer;
+ @Nullable
+ private final StateSnapshotTransformer<S> snapshotTransformer;
public RegisteredKeyValueStateBackendMetaInfo(
- @Nonnull StateDescriptor.Type stateType,
- @Nonnull String name,
- @Nonnull TypeSerializer<N> namespaceSerializer,
- @Nonnull TypeSerializer<S> stateSerializer) {
+ @Nonnull StateDescriptor.Type stateType,
+ @Nonnull String name,
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull TypeSerializer<S> stateSerializer) {
+ this(stateType, name, namespaceSerializer, stateSerializer, null);
+ }
+
+ public RegisteredKeyValueStateBackendMetaInfo(
+ @Nonnull StateDescriptor.Type stateType,
+ @Nonnull String name,
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull TypeSerializer<S> stateSerializer,
+ @Nullable StateSnapshotTransformer<S> snapshotTransformer) {
super(name);
this.stateType = stateType;
this.namespaceSerializer = namespaceSerializer;
this.stateSerializer = stateSerializer;
+ this.snapshotTransformer = snapshotTransformer;
}
@SuppressWarnings("unchecked")
@@ -71,7 +84,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
(TypeSerializer<N>) Preconditions.checkNotNull(
snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.NAMESPACE_SERIALIZER)),
(TypeSerializer<S>) Preconditions.checkNotNull(
- snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER)));
+ snapshot.getTypeSerializer(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER)), null);
Preconditions.checkState(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE == snapshot.getBackendStateType());
}
@@ -90,6 +103,11 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
return stateSerializer;
}
+ @Nullable
+ public StateSnapshotTransformer<S> getSnapshotTransformer() {
+ return snapshotTransformer;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -142,7 +160,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
public static <N, S> RegisteredKeyValueStateBackendMetaInfo<N, S> resolveKvStateCompatibility(
StateMetaInfoSnapshot restoredStateMetaInfoSnapshot,
TypeSerializer<N> newNamespaceSerializer,
- StateDescriptor<?, S> newStateDescriptor) throws StateMigrationException {
+ StateDescriptor<?, S> newStateDescriptor,
+ @Nullable StateSnapshotTransformer<S> snapshotTransformer) throws StateMigrationException {
Preconditions.checkState(restoredStateMetaInfoSnapshot.getBackendStateType()
== StateMetaInfoSnapshot.BackendStateType.KEY_VALUE,
@@ -196,7 +215,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
newStateDescriptor.getType(),
newStateDescriptor.getName(),
newNamespaceSerializer,
- newStateSerializer);
+ newStateSerializer,
+ snapshotTransformer);
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
new file mode 100644
index 0000000..cd2c7bf
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/StateSnapshotTransformer.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy;
+
+import javax.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
+
+/**
+ * Transformer of state values which are included or skipped in the snapshot.
+ *
+ * <p>This transformer can be applied to state values
+ * to decide which entries should be included into the snapshot.
+ * The included entries can be optionally modified before.
+ *
+ * <p>Unless specified differently, the transformer should be applied per entry
+ * for collection types of state, like list or map.
+ *
+ * @param <T> type of state
+ */
+@FunctionalInterface
+public interface StateSnapshotTransformer<T> {
+ /**
+ * Transform or filter out state values which are included or skipped in the snapshot.
+ *
+ * @param value non-serialized form of value
+ * @return value to snapshot or null which means the entry is not included
+ */
+ @Nullable
+ T filterOrTransform(@Nullable T value);
+
+ /** Per-entry transformer for collection state that also declares how the collection is traversed. */
+ interface CollectionStateSnapshotTransformer<T> extends StateSnapshotTransformer<T> {
+ enum TransformStrategy {
+ /** Apply the transformer to every entry of the collection. */
+ TRANSFORM_ALL,
+
+ /**
+ * Skip only the leading entries that are filtered out (transformed to null).
+ *
+ * <p>As an optimisation, traversal stops at the first entry that is included (non-null);
+ * that entry and all entries after it are returned untouched.
+ */
+ STOP_ON_FIRST_INCLUDED
+ }
+
+ default TransformStrategy getFilterStrategy() {
+ return TransformStrategy.TRANSFORM_ALL;
+ }
+ }
+
+ /**
+ * General implementation of list state transformer.
+ *
+ * <p>This transformer wraps a transformer per-entry
+ * and transforms the whole list state.
+ * If the wrapped per entry transformer is {@link CollectionStateSnapshotTransformer},
+ * it respects its {@link TransformStrategy}.
+ */
+ class ListStateSnapshotTransformer<T> implements StateSnapshotTransformer<List<T>> {
+ private final StateSnapshotTransformer<T> entryValueTransformer;
+ private final TransformStrategy transformStrategy;
+
+ public ListStateSnapshotTransformer(StateSnapshotTransformer<T> entryValueTransformer) {
+ this.entryValueTransformer = entryValueTransformer;
+ this.transformStrategy = entryValueTransformer instanceof CollectionStateSnapshotTransformer ?
+ ((CollectionStateSnapshotTransformer) entryValueTransformer).getFilterStrategy() :
+ TransformStrategy.TRANSFORM_ALL;
+ }
+
+ @Override
+ @Nullable
+ public List<T> filterOrTransform(@Nullable List<T> list) {
+ if (list == null) {
+ return null;
+ }
+ List<T> transformedList = new ArrayList<>();
+ boolean anyChange = false;
+ for (int i = 0; i < list.size(); i++) {
+ T entry = list.get(i);
+ T transformedEntry = entryValueTransformer.filterOrTransform(entry);
+ if (transformedEntry != null) {
+ if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
+ transformedList = list.subList(i, list.size());
+ anyChange = i > 0;
+ break;
+ } else {
+ transformedList.add(transformedEntry);
+ }
+ }
+ anyChange |= transformedEntry == null || !Objects.equals(entry, transformedEntry);
+ }
+ transformedList = anyChange ? transformedList : list;
+ return transformedList.isEmpty() ? null : transformedList;
+ }
+ }
+
+ /**
+ * General implementation of map state transformer.
+ *
+ * <p>This transformer wraps a transformer per-entry
+ * and transforms the whole map state.
+ */
+ class MapStateSnapshotTransformer<K, V> implements StateSnapshotTransformer<Map<K, V>> {
+ private final StateSnapshotTransformer<V> entryValueTransformer;
+
+ public MapStateSnapshotTransformer(StateSnapshotTransformer<V> entryValueTransformer) {
+ this.entryValueTransformer = entryValueTransformer;
+ }
+
+ @Nullable
+ @Override
+ public Map<K, V> filterOrTransform(@Nullable Map<K, V> map) {
+ if (map == null) {
+ return null;
+ }
+ Map<K, V> transformedMap = new HashMap<>();
+ boolean anyChange = false;
+ for (Map.Entry<K, V> entry : map.entrySet()) {
+ V transformedValue = entryValueTransformer.filterOrTransform(entry.getValue());
+ if (transformedValue != null) {
+ transformedMap.put(entry.getKey(), transformedValue);
+ }
+ anyChange |= transformedValue == null || !Objects.equals(entry.getValue(), transformedValue);
+ }
+ return anyChange ? (transformedMap.isEmpty() ? null : transformedMap) : map;
+ }
+ }
+
+ /**
+ * This factory creates state transformers depending on the form of values to transform.
+ *
+ * <p>If no transformation is needed, the factory methods return {@code Optional.empty()}.
+ */
+ interface StateSnapshotTransformFactory<T> {
+ StateSnapshotTransformFactory<?> NO_TRANSFORM = createNoTransform();
+
+ @SuppressWarnings("unchecked")
+ static <T> StateSnapshotTransformFactory<T> noTransform() {
+ return (StateSnapshotTransformFactory<T>) NO_TRANSFORM;
+ }
+
+ static <T> StateSnapshotTransformFactory<T> createNoTransform() {
+ return new StateSnapshotTransformFactory<T>() {
+ @Override
+ public Optional<StateSnapshotTransformer<T>> createForDeserializedState() {
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+ return Optional.empty();
+ }
+ };
+ }
+
+ Optional<StateSnapshotTransformer<T>> createForDeserializedState();
+
+ Optional<StateSnapshotTransformer<byte[]>> createForSerializedState();
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
index f3f21dd..11a23bf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/CopyOnWriteStateTableSnapshot.java
@@ -22,8 +22,10 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.KeyGroupPartitioner;
+import org.apache.flink.runtime.state.KeyGroupPartitioner.ElementWriterFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
import javax.annotation.Nonnegative;
@@ -34,8 +36,8 @@ import javax.annotation.Nullable;
* This class represents the snapshot of a {@link CopyOnWriteStateTable} and has a role in operator state checkpointing. Besides
* holding the {@link CopyOnWriteStateTable}s internal entries at the time of the snapshot, this class is also responsible for
* preparing and writing the state in the process of checkpointing.
- * <p>
- * IMPORTANT: Please notice that snapshot integrity of entries in this class rely on proper copy-on-write semantics
+ *
+ * <p>IMPORTANT: Please notice that snapshot integrity of entries in this class rely on proper copy-on-write semantics
* through the {@link CopyOnWriteStateTable} that created the snapshot object, but all objects in this snapshot must be considered
* as READ-ONLY!. The reason is that the objects held by this class may or may not be deep copies of original objects
* that may still used in the {@link CopyOnWriteStateTable}. This depends for each entry on whether or not it was subject to
@@ -105,7 +107,6 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
this.snapshotVersion = owningStateTable.getStateTableVersion();
this.numberOfEntriesInSnapshotData = owningStateTable.size();
-
// We create duplicates of the serializers for the async snapshot, because TypeSerializer
// might be stateful and shared with the event processing thread.
this.localKeySerializer = owningStateTable.keyContext.getKeySerializer().duplicate();
@@ -128,35 +129,41 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
* into key-groups. Then, the histogram is accumulated to obtain the boundaries of each key-group in an array.
* Last, we use the accumulated counts as write position pointers for the key-group's bins when reordering the
* entries by key-group. This operation is lazily performed before the first writing of a key-group.
- * <p>
- * As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
+ *
+ * <p>As a possible future optimization, we could perform the repartitioning in-place, using a scheme similar to the
* cuckoo cycles in cuckoo hashing. This can trade some performance for a smaller memory footprint.
*/
@Nonnull
@SuppressWarnings("unchecked")
@Override
public StateKeyGroupWriter getKeyGroupWriter() {
-
if (partitionedStateTableSnapshot == null) {
-
final InternalKeyContext<K> keyContext = owningStateTable.keyContext;
- final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
final int numberOfKeyGroups = keyContext.getNumberOfKeyGroups();
-
- final StateTableKeyGroupPartitioner<K, N, S> keyGroupPartitioner = new StateTableKeyGroupPartitioner<>(
- snapshotData,
- numberOfEntriesInSnapshotData,
- keyGroupRange,
- numberOfKeyGroups,
+ final KeyGroupRange keyGroupRange = keyContext.getKeyGroupRange();
+ ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction =
(element, dov) -> {
localNamespaceSerializer.serialize(element.namespace, dov);
localKeySerializer.serialize(element.key, dov);
localStateSerializer.serialize(element.state, dov);
- });
-
- partitionedStateTableSnapshot = keyGroupPartitioner.partitionByKeyGroup();
+ };
+ StateSnapshotTransformer<S> stateSnapshotTransformer = owningStateTable.metaInfo.getSnapshotTransformer();
+ StateTableKeyGroupPartitioner<K, N, S> stateTableKeyGroupPartitioner = stateSnapshotTransformer != null ?
+ new TransformingStateTableKeyGroupPartitioner<>(
+ snapshotData,
+ numberOfEntriesInSnapshotData,
+ keyGroupRange,
+ numberOfKeyGroups,
+ elementWriterFunction,
+ stateSnapshotTransformer) :
+ new StateTableKeyGroupPartitioner<>(
+ snapshotData,
+ numberOfEntriesInSnapshotData,
+ keyGroupRange,
+ numberOfKeyGroups,
+ elementWriterFunction);
+ partitionedStateTableSnapshot = stateTableKeyGroupPartitioner.partitionByKeyGroup();
}
-
return partitionedStateTableSnapshot;
}
@@ -188,7 +195,7 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
* @param <S> type of state value.
*/
@VisibleForTesting
- protected static final class StateTableKeyGroupPartitioner<K, N, S>
+ protected static class StateTableKeyGroupPartitioner<K, N, S>
extends KeyGroupPartitioner<CopyOnWriteStateTable.StateTableEntry<K, N, S>> {
@SuppressWarnings("unchecked")
@@ -217,12 +224,67 @@ public class CopyOnWriteStateTableSnapshot<K, N, S>
int flattenIndex = 0;
for (CopyOnWriteStateTable.StateTableEntry<K, N, S> entry : partitioningDestination) {
while (null != entry) {
- final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
- reportKeyGroupOfElementAtIndex(flattenIndex, keyGroup);
- partitioningSource[flattenIndex++] = entry;
+ flattenIndex = tryAddToSource(flattenIndex, entry);
entry = entry.next;
}
}
}
+
+ /** Reports the entry's key-group, stores the entry at {@code currentIndex} in {@code partitioningSource} and returns the next free index. */
+ int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+ final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(entry.key, totalKeyGroups);
+ reportKeyGroupOfElementAtIndex(currentIndex, keyGroup);
+ partitioningSource[currentIndex] = entry;
+ return currentIndex + 1;
+ }
+ }
+
+ /**
+ * Extended state snapshot transforming {@link StateTableKeyGroupPartitioner}.
+ *
+ * <p>This partitioner additionally applies a {@link StateSnapshotTransformer} to each entry's state and either includes the (possibly transformed) entry in the snapshot or drops it.
+ */
+ protected static final class TransformingStateTableKeyGroupPartitioner<K, N, S>
+ extends StateTableKeyGroupPartitioner<K, N, S> {
+ private final StateSnapshotTransformer<S> stateSnapshotTransformer;
+
+ TransformingStateTableKeyGroupPartitioner(
+ @Nonnull CopyOnWriteStateTable.StateTableEntry<K, N, S>[] snapshotData,
+ int stateTableSize,
+ @Nonnull KeyGroupRange keyGroupRange,
+ int totalKeyGroups,
+ @Nonnull ElementWriterFunction<CopyOnWriteStateTable.StateTableEntry<K, N, S>> elementWriterFunction,
+ @Nonnull StateSnapshotTransformer<S> stateSnapshotTransformer) {
+ super(
+ snapshotData,
+ stateTableSize,
+ keyGroupRange,
+ totalKeyGroups,
+ elementWriterFunction);
+ this.stateSnapshotTransformer = stateSnapshotTransformer;
+ }
+
+ @Override
+ int tryAddToSource(int currentIndex, CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+ CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = filterEntry(entry); // null means: drop entry from snapshot
+ if (filteredEntry != null) {
+ return super.tryAddToSource(currentIndex, filteredEntry); // must be super: an unqualified call re-enters this override and never terminates
+ }
+ return currentIndex; // entry filtered out, index does not advance
+ }
+
+ private CopyOnWriteStateTable.StateTableEntry<K, N, S> filterEntry(
+ CopyOnWriteStateTable.StateTableEntry<K, N, S> entry) {
+ S transformedValue = stateSnapshotTransformer.filterOrTransform(entry.state);
+ if (transformedValue != null) {
+ CopyOnWriteStateTable.StateTableEntry<K, N, S> filteredEntry = entry;
+ if (transformedValue != entry.state) {
+ filteredEntry = new CopyOnWriteStateTable.StateTableEntry<>(entry, entry.entryVersion);
+ filteredEntry.state = transformedValue;
+ }
+ return filteredEntry;
+ }
+ return null;
+ }
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
index 34c9698..6d2bfef 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapKeyedStateBackend.java
@@ -62,8 +62,10 @@ import org.apache.flink.runtime.state.SnappyStreamCompressionDecorator;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.SnapshotStrategy;
import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateSnapshotKeyGroupReader;
import org.apache.flink.runtime.state.StateSnapshotRestore;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.runtime.state.StreamCompressionDecorator;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
@@ -89,6 +91,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
import java.util.concurrent.RunnableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -247,7 +250,9 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
private <N, V> StateTable<K, N, V> tryRegisterStateTable(
- TypeSerializer<N> namespaceSerializer, StateDescriptor<?, V> stateDesc) throws StateMigrationException {
+ TypeSerializer<N> namespaceSerializer,
+ StateDescriptor<?, V> stateDesc,
+ StateSnapshotTransformer<V> snapshotTransformer) throws StateMigrationException {
@SuppressWarnings("unchecked")
StateTable<K, N, V> stateTable = (StateTable<K, N, V>) registeredKVStates.get(stateDesc.getName());
@@ -267,7 +272,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
newMetaInfo = RegisteredKeyValueStateBackendMetaInfo.resolveKvStateCompatibility(
restoredMetaInfoSnapshot,
namespaceSerializer,
- stateDesc);
+ stateDesc,
+ snapshotTransformer);
stateTable.setMetaInfo(newMetaInfo);
} else {
@@ -275,7 +281,8 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
stateDesc.getType(),
stateDesc.getName(),
namespaceSerializer,
- stateDesc.getSerializer());
+ stateDesc.getSerializer(),
+ snapshotTransformer);
stateTable = snapshotStrategy.newStateTable(newMetaInfo);
registeredKVStates.put(stateDesc.getName(), stateTable);
@@ -301,19 +308,42 @@ public class HeapKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
@Override
- public <N, SV, S extends State, IS extends S> IS createInternalState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ @Nonnull
+ public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull StateDescriptor<S, SV> stateDesc,
+ @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
if (stateFactory == null) {
String message = String.format("State %s is not supported by %s",
stateDesc.getClass(), this.getClass());
throw new FlinkRuntimeException(message);
}
- StateTable<K, N, SV> stateTable = tryRegisterStateTable(namespaceSerializer, stateDesc);
+ StateTable<K, N, SV> stateTable = tryRegisterStateTable(
+ namespaceSerializer, stateDesc, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
return stateFactory.createState(stateDesc, stateTable, keySerializer);
}
+ @SuppressWarnings("unchecked")
+ private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+ StateDescriptor<?, SV> stateDesc,
+ StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+ Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
+ if (original.isPresent()) {
+ if (stateDesc instanceof ListStateDescriptor) {
+ return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+ .ListStateSnapshotTransformer<>(original.get());
+ } else if (stateDesc instanceof MapStateDescriptor) {
+ return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+ .MapStateSnapshotTransformer<>(original.get());
+ } else {
+ return (StateSnapshotTransformer<SV>) original.get();
+ }
+ } else {
+ return null;
+ }
+ }
+
@Override
@SuppressWarnings("unchecked")
public RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
index efed1cc..f982370 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/NestedMapsStateTable.java
@@ -15,6 +15,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package org.apache.flink.runtime.state.heap;
import org.apache.flink.annotation.Internal;
@@ -24,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputView;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
import org.apache.flink.runtime.state.StateSnapshot;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.StateTransformationFunction;
import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
import org.apache.flink.util.Preconditions;
@@ -41,8 +43,8 @@ import java.util.stream.Stream;
/**
* This implementation of {@link StateTable} uses nested {@link HashMap} objects. It is also maintaining a partitioning
* by key-group.
- * <p>
- * In contrast to {@link CopyOnWriteStateTable}, this implementation does not support asynchronous snapshots. However,
+ *
+ * <p>In contrast to {@link CopyOnWriteStateTable}, this implementation does not support asynchronous snapshots. However,
* it might have a better memory footprint for some use-cases, e.g. it is naturally de-duplicating namespace objects.
*
* @param <K> type of key.
@@ -59,7 +61,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
private final Map<N, Map<K, S>>[] state;
/**
- * The offset to the contiguous key groups
+ * The offset to the contiguous key groups.
*/
private final int keyGroupOffset;
@@ -317,7 +319,7 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
@Nonnull
@Override
public NestedMapsStateTableSnapshot<K, N, S> stateSnapshot() {
- return new NestedMapsStateTableSnapshot<>(this);
+ return new NestedMapsStateTableSnapshot<>(this, metaInfo.getSnapshotTransformer());
}
/**
@@ -330,9 +332,17 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
static class NestedMapsStateTableSnapshot<K, N, S>
extends AbstractStateTableSnapshot<K, N, S, NestedMapsStateTable<K, N, S>>
implements StateSnapshot.StateKeyGroupWriter {
+ private final TypeSerializer<K> keySerializer;
+ private final TypeSerializer<N> namespaceSerializer;
+ private final TypeSerializer<S> stateSerializer;
+ private final StateSnapshotTransformer<S> snapshotFilter;
- NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable) {
+ NestedMapsStateTableSnapshot(NestedMapsStateTable<K, N, S> owningTable, StateSnapshotTransformer<S> snapshotFilter) {
super(owningTable);
+ this.snapshotFilter = snapshotFilter;
+ this.keySerializer = owningStateTable.keyContext.getKeySerializer();
+ this.namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
+ this.stateSerializer = owningStateTable.metaInfo.getStateSerializer();
}
@Nonnull
@@ -350,8 +360,8 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
/**
* Implementation note: we currently chose the same format between {@link NestedMapsStateTable} and
* {@link CopyOnWriteStateTable}.
- * <p>
- * {@link NestedMapsStateTable} could naturally support a kind of
+ *
+ * <p>{@link NestedMapsStateTable} could naturally support a kind of
* prefix-compressed format (grouping by namespace, writing the namespace only once per group instead for each
* mapping). We might implement support for different formats later (tailored towards different state table
* implementations).
@@ -360,23 +370,45 @@ public class NestedMapsStateTable<K, N, S> extends StateTable<K, N, S> {
public void writeStateInKeyGroup(@Nonnull DataOutputView dov, int keyGroupId) throws IOException {
final Map<N, Map<K, S>> keyGroupMap = owningStateTable.getMapForKeyGroup(keyGroupId);
if (null != keyGroupMap) {
- TypeSerializer<K> keySerializer = owningStateTable.keyContext.getKeySerializer();
- TypeSerializer<N> namespaceSerializer = owningStateTable.metaInfo.getNamespaceSerializer();
- TypeSerializer<S> stateSerializer = owningStateTable.metaInfo.getStateSerializer();
- dov.writeInt(countMappingsInKeyGroup(keyGroupMap));
- for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
+ Map<N, Map<K, S>> filteredMappings = filterMappingsInKeyGroupIfNeeded(keyGroupMap);
+ dov.writeInt(countMappingsInKeyGroup(filteredMappings));
+ for (Map.Entry<N, Map<K, S>> namespaceEntry : filteredMappings.entrySet()) {
final N namespace = namespaceEntry.getKey();
final Map<K, S> namespaceMap = namespaceEntry.getValue();
-
for (Map.Entry<K, S> keyEntry : namespaceMap.entrySet()) {
- namespaceSerializer.serialize(namespace, dov);
- keySerializer.serialize(keyEntry.getKey(), dov);
- stateSerializer.serialize(keyEntry.getValue(), dov);
+ writeElement(namespace, keyEntry, dov);
}
}
} else {
dov.writeInt(0);
}
}
+
+ private void writeElement(N namespace, Map.Entry<K, S> keyEntry, DataOutputView dov) throws IOException {
+ namespaceSerializer.serialize(namespace, dov);
+ keySerializer.serialize(keyEntry.getKey(), dov);
+ stateSerializer.serialize(keyEntry.getValue(), dov);
+ }
+
+ private Map<N, Map<K, S>> filterMappingsInKeyGroupIfNeeded(final Map<N, Map<K, S>> keyGroupMap) {
+ return snapshotFilter == null ?
+ keyGroupMap : filterMappingsInKeyGroup(keyGroupMap);
+ }
+
+ private Map<N, Map<K, S>> filterMappingsInKeyGroup(final Map<N, Map<K, S>> keyGroupMap) {
+ Map<N, Map<K, S>> filtered = new HashMap<>();
+ for (Map.Entry<N, Map<K, S>> namespaceEntry : keyGroupMap.entrySet()) {
+ N namespace = namespaceEntry.getKey();
+ Map<K, S> filteredNamespaceMap = filtered.computeIfAbsent(namespace, n -> new HashMap<>());
+ for (Map.Entry<K, S> keyEntry : namespaceEntry.getValue().entrySet()) {
+ K key = keyEntry.getKey();
+ S transformedvalue = snapshotFilter.filterOrTransform(keyEntry.getValue());
+ if (transformedvalue != null) {
+ filteredNamespaceMap.put(key, transformedvalue);
+ }
+ }
+ }
+ return filtered;
+ }
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
index 58a2f97..30f96f3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/StateTable.java
@@ -46,7 +46,7 @@ public abstract class StateTable<K, N, S> implements StateSnapshotRestore {
protected final InternalKeyContext<K> keyContext;
/**
- * Combined meta information such as name and serializers for this state
+ * Combined meta information such as name and serializers for this state.
*/
protected RegisteredKeyValueStateBackendMetaInfo<N, S> metaInfo;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
index 1b72c54..268f84a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlDecorator.java
@@ -18,14 +18,12 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.SupplierWithException;
import org.apache.flink.util.function.ThrowingConsumer;
import org.apache.flink.util.function.ThrowingRunnable;
-import javax.annotation.Nonnull;
-
/**
* Base class for TTL logic wrappers.
*
@@ -35,7 +33,7 @@ abstract class AbstractTtlDecorator<T> {
/** Wrapped original state handler. */
final T original;
- final StateTtlConfiguration config;
+ final StateTtlConfig config;
final TtlTimeProvider timeProvider;
@@ -50,7 +48,7 @@ abstract class AbstractTtlDecorator<T> {
AbstractTtlDecorator(
T original,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider) {
Preconditions.checkNotNull(original);
Preconditions.checkNotNull(config);
@@ -58,8 +56,8 @@ abstract class AbstractTtlDecorator<T> {
this.original = original;
this.config = config;
this.timeProvider = timeProvider;
- this.updateTsOnRead = config.getTtlUpdateType() == StateTtlConfiguration.TtlUpdateType.OnReadAndWrite;
- this.returnExpired = config.getStateVisibility() == StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp;
+ this.updateTsOnRead = config.getUpdateType() == StateTtlConfig.UpdateType.OnReadAndWrite;
+ this.returnExpired = config.getStateVisibility() == StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp;
this.ttl = config.getTtl().toMilliseconds();
}
@@ -68,21 +66,11 @@ abstract class AbstractTtlDecorator<T> {
}
<V> boolean expired(TtlValue<V> ttlValue) {
- return ttlValue != null && getExpirationTimestamp(ttlValue) <= timeProvider.currentTimestamp();
- }
-
- private long getExpirationTimestamp(@Nonnull TtlValue<?> ttlValue) {
- long ts = ttlValue.getLastAccessTimestamp();
- long ttlWithoutOverflow = ts > 0 ? Math.min(Long.MAX_VALUE - ts, ttl) : ttl;
- return ts + ttlWithoutOverflow;
+ return TtlUtils.expired(ttlValue, ttl, timeProvider);
}
<V> TtlValue<V> wrapWithTs(V value) {
- return wrapWithTs(value, timeProvider.currentTimestamp());
- }
-
- static <V> TtlValue<V> wrapWithTs(V value, long ts) {
- return value == null ? null : new TtlValue<>(value, ts);
+ return TtlUtils.wrapWithTs(value, timeProvider.currentTimestamp());
}
<V> TtlValue<V> rewrapWithNewTs(TtlValue<V> ttlValue) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
index 7903796..5d1af8a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/AbstractTtlState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalKvState;
import org.apache.flink.util.FlinkRuntimeException;
@@ -39,7 +39,7 @@ abstract class AbstractTtlState<K, N, SV, TTLSV, S extends InternalKvState<K, N,
implements InternalKvState<K, N, SV> {
private final TypeSerializer<SV> valueSerializer;
- AbstractTtlState(S original, StateTtlConfiguration config, TtlTimeProvider timeProvider, TypeSerializer<SV> valueSerializer) {
+ AbstractTtlState(S original, StateTtlConfig config, TtlTimeProvider timeProvider, TypeSerializer<SV> valueSerializer) {
super(original, config, timeProvider);
this.valueSerializer = valueSerializer;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
index 5448ba1..07f538a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregateFunction.java
@@ -19,7 +19,7 @@
package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.functions.AggregateFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.ThrowingConsumer;
@@ -38,7 +38,7 @@ class TtlAggregateFunction<IN, ACC, OUT>
ThrowingRunnable<Exception> stateClear;
ThrowingConsumer<TtlValue<ACC>, Exception> updater;
- TtlAggregateFunction(AggregateFunction<IN, ACC, OUT> aggFunction, StateTtlConfiguration config, TtlTimeProvider timeProvider) {
+ TtlAggregateFunction(AggregateFunction<IN, ACC, OUT> aggFunction, StateTtlConfig config, TtlTimeProvider timeProvider) {
super(aggFunction, config, timeProvider);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
index a90698e..94f489d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlAggregatingState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalAggregatingState;
@@ -40,7 +40,7 @@ class TtlAggregatingState<K, N, IN, ACC, OUT>
TtlAggregatingState(
InternalAggregatingState<K, N, IN, TtlValue<ACC>, OUT> originalState,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<ACC> valueSerializer,
TtlAggregateFunction<IN, ACC, OUT> aggregateFunction) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
index c7305bd..a8f4e6a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldFunction.java
@@ -19,7 +19,7 @@
package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.functions.FoldFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
/**
* This class wraps folding function with TTL logic.
@@ -36,7 +36,7 @@ class TtlFoldFunction<T, ACC>
private final ACC defaultAccumulator;
TtlFoldFunction(
- FoldFunction<T, ACC> original, StateTtlConfiguration config, TtlTimeProvider timeProvider, ACC defaultAccumulator) {
+ FoldFunction<T, ACC> original, StateTtlConfig config, TtlTimeProvider timeProvider, ACC defaultAccumulator) {
super(original, config, timeProvider);
this.defaultAccumulator = defaultAccumulator;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
index c3a75e4..4c64ae3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlFoldingState.java
@@ -19,7 +19,7 @@
package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.state.AggregatingState;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalFoldingState;
@@ -37,7 +37,7 @@ class TtlFoldingState<K, N, T, ACC>
implements InternalFoldingState<K, N, T, ACC> {
TtlFoldingState(
InternalFoldingState<K, N, T, TtlValue<ACC>> originalState,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<ACC> valueSerializer) {
super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
index 77e92f6..cb64df7 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlListState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.util.Preconditions;
@@ -43,7 +43,7 @@ class TtlListState<K, N, T> extends
implements InternalListState<K, N, T> {
TtlListState(
InternalListState<K, N, TtlValue<T>> originalState,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<List<T>> valueSerializer) {
super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
index 21145e5..f6f81ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlMapState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalMapState;
import org.apache.flink.util.FlinkRuntimeException;
@@ -46,7 +46,7 @@ class TtlMapState<K, N, UK, UV>
implements InternalMapState<K, N, UK, UV> {
TtlMapState(
InternalMapState<K, N, UK, TtlValue<UV>> original,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<Map<UK, UV>> valueSerializer) {
super(original, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
index fa7c8bf..823c55d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReduceFunction.java
@@ -19,7 +19,7 @@
package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.functions.ReduceFunction;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
/**
* This class wraps reducing function with TTL logic.
@@ -32,7 +32,7 @@ class TtlReduceFunction<T>
TtlReduceFunction(
ReduceFunction<T> originalReduceFunction,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider) {
super(originalReduceFunction, config, timeProvider);
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
index c0aa465..0710808 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlReducingState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalReducingState;
@@ -36,7 +36,7 @@ class TtlReducingState<K, N, T>
implements InternalReducingState<K, N, T> {
TtlReducingState(
InternalReducingState<K, N, TtlValue<T>> originalState,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<T> valueSerializer) {
super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
index e12ba90..303285a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateFactory.java
@@ -25,15 +25,17 @@ import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeutils.CompositeSerializer;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyedStateFactory;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.function.SupplierWithException;
import javax.annotation.Nonnull;
@@ -44,7 +46,7 @@ import java.util.stream.Stream;
/**
* This state factory wraps state objects, produced by backends, with TTL logic.
*/
-public class TtlStateFactory {
+public class TtlStateFactory<N, SV, S extends State, IS extends S> {
public static <N, SV, S extends State, IS extends S> IS createStateAndWrapWithTtlIfEnabled(
TypeSerializer<N> namespaceSerializer,
StateDescriptor<S, SV> stateDesc,
@@ -55,103 +57,103 @@ public class TtlStateFactory {
Preconditions.checkNotNull(originalStateFactory);
Preconditions.checkNotNull(timeProvider);
return stateDesc.getTtlConfig().isEnabled() ?
- new TtlStateFactory(originalStateFactory, stateDesc.getTtlConfig(), timeProvider)
- .createState(namespaceSerializer, stateDesc) :
+ new TtlStateFactory<N, SV, S, IS>(
+ namespaceSerializer, stateDesc, originalStateFactory, timeProvider)
+ .createState() :
originalStateFactory.createInternalState(namespaceSerializer, stateDesc);
}
- private final Map<Class<? extends StateDescriptor>, KeyedStateFactory> stateFactories;
+ private final Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> stateFactories;
+ private final TypeSerializer<N> namespaceSerializer;
+ private final StateDescriptor<S, SV> stateDesc;
private final KeyedStateFactory originalStateFactory;
- private final StateTtlConfiguration ttlConfig;
+ private final StateTtlConfig ttlConfig;
private final TtlTimeProvider timeProvider;
+ private final long ttl;
- private TtlStateFactory(KeyedStateFactory originalStateFactory, StateTtlConfiguration ttlConfig, TtlTimeProvider timeProvider) {
+ private TtlStateFactory(
+ TypeSerializer<N> namespaceSerializer,
+ StateDescriptor<S, SV> stateDesc,
+ KeyedStateFactory originalStateFactory,
+ TtlTimeProvider timeProvider) {
+ this.namespaceSerializer = namespaceSerializer;
+ this.stateDesc = stateDesc;
this.originalStateFactory = originalStateFactory;
- this.ttlConfig = ttlConfig;
+ this.ttlConfig = stateDesc.getTtlConfig();
this.timeProvider = timeProvider;
+ this.ttl = ttlConfig.getTtl().toMilliseconds();
this.stateFactories = createStateFactories();
}
@SuppressWarnings("deprecation")
- private Map<Class<? extends StateDescriptor>, KeyedStateFactory> createStateFactories() {
+ private Map<Class<? extends StateDescriptor>, SupplierWithException<IS, Exception>> createStateFactories() {
return Stream.of(
- Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) this::createValueState),
- Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) this::createListState),
- Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) this::createMapState),
- Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) this::createReducingState),
- Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) this::createAggregatingState),
- Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) this::createFoldingState)
+ Tuple2.of(ValueStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createValueState),
+ Tuple2.of(ListStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createListState),
+ Tuple2.of(MapStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createMapState),
+ Tuple2.of(ReducingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createReducingState),
+ Tuple2.of(AggregatingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createAggregatingState),
+ Tuple2.of(FoldingStateDescriptor.class, (SupplierWithException<IS, Exception>) this::createFoldingState)
).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
}
- private <N, SV, S extends State, IS extends S> IS createState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
- KeyedStateFactory stateFactory = stateFactories.get(stateDesc.getClass());
+ private IS createState() throws Exception {
+ SupplierWithException<IS, Exception> stateFactory = stateFactories.get(stateDesc.getClass());
if (stateFactory == null) {
String message = String.format("State %s is not supported by %s",
stateDesc.getClass(), TtlStateFactory.class);
throw new FlinkRuntimeException(message);
}
- return stateFactory.createInternalState(namespaceSerializer, stateDesc);
+ return stateFactory.get();
}
@SuppressWarnings("unchecked")
- private <N, SV, S extends State, IS extends S> IS createValueState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private IS createValueState() throws Exception {
ValueStateDescriptor<TtlValue<SV>> ttlDescriptor = new ValueStateDescriptor<>(
stateDesc.getName(), new TtlSerializer<>(stateDesc.getSerializer()));
return (IS) new TtlValueState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, stateDesc.getSerializer());
}
@SuppressWarnings("unchecked")
- private <T, N, SV, S extends State, IS extends S> IS createListState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private <T> IS createListState() throws Exception {
ListStateDescriptor<T> listStateDesc = (ListStateDescriptor<T>) stateDesc;
ListStateDescriptor<TtlValue<T>> ttlDescriptor = new ListStateDescriptor<>(
stateDesc.getName(), new TtlSerializer<>(listStateDesc.getElementSerializer()));
return (IS) new TtlListState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(
+ namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, listStateDesc.getSerializer());
}
@SuppressWarnings("unchecked")
- private <UK, UV, N, SV, S extends State, IS extends S> IS createMapState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private <UK, UV> IS createMapState() throws Exception {
MapStateDescriptor<UK, UV> mapStateDesc = (MapStateDescriptor<UK, UV>) stateDesc;
MapStateDescriptor<UK, TtlValue<UV>> ttlDescriptor = new MapStateDescriptor<>(
stateDesc.getName(),
mapStateDesc.getKeySerializer(),
new TtlSerializer<>(mapStateDesc.getValueSerializer()));
return (IS) new TtlMapState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, mapStateDesc.getSerializer());
}
@SuppressWarnings("unchecked")
- private <N, SV, S extends State, IS extends S> IS createReducingState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private IS createReducingState() throws Exception {
ReducingStateDescriptor<SV> reducingStateDesc = (ReducingStateDescriptor<SV>) stateDesc;
ReducingStateDescriptor<TtlValue<SV>> ttlDescriptor = new ReducingStateDescriptor<>(
stateDesc.getName(),
new TtlReduceFunction<>(reducingStateDesc.getReduceFunction(), ttlConfig, timeProvider),
new TtlSerializer<>(stateDesc.getSerializer()));
return (IS) new TtlReducingState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, stateDesc.getSerializer());
}
@SuppressWarnings("unchecked")
- private <IN, OUT, N, SV, S extends State, IS extends S> IS createAggregatingState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private <IN, OUT> IS createAggregatingState() throws Exception {
AggregatingStateDescriptor<IN, SV, OUT> aggregatingStateDescriptor =
(AggregatingStateDescriptor<IN, SV, OUT>) stateDesc;
TtlAggregateFunction<IN, SV, OUT> ttlAggregateFunction = new TtlAggregateFunction<>(
@@ -159,14 +161,12 @@ public class TtlStateFactory {
AggregatingStateDescriptor<IN, TtlValue<SV>, OUT> ttlDescriptor = new AggregatingStateDescriptor<>(
stateDesc.getName(), ttlAggregateFunction, new TtlSerializer<>(stateDesc.getSerializer()));
return (IS) new TtlAggregatingState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, stateDesc.getSerializer(), ttlAggregateFunction);
}
@SuppressWarnings({"deprecation", "unchecked"})
- private <T, N, SV, S extends State, IS extends S> IS createFoldingState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ private <T> IS createFoldingState() throws Exception {
FoldingStateDescriptor<T, SV> foldingStateDescriptor = (FoldingStateDescriptor<T, SV>) stateDesc;
SV initAcc = stateDesc.getDefaultValue();
TtlValue<SV> ttlInitAcc = initAcc == null ? null : new TtlValue<>(initAcc, Long.MAX_VALUE);
@@ -176,10 +176,18 @@ public class TtlStateFactory {
new TtlFoldFunction<>(foldingStateDescriptor.getFoldFunction(), ttlConfig, timeProvider, initAcc),
new TtlSerializer<>(stateDesc.getSerializer()));
return (IS) new TtlFoldingState<>(
- originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor),
+ originalStateFactory.createInternalState(namespaceSerializer, ttlDescriptor, getSnapshotTransformFactory()),
ttlConfig, timeProvider, stateDesc.getSerializer());
}
+	/**
+	 * Selects the snapshot transform factory handed to the original state factory: the TTL
+	 * filtering factory when cleanup in full snapshot is configured, otherwise the factory
+	 * returned by {@code StateSnapshotTransformFactory.noTransform()}.
+	 */
+	private StateSnapshotTransformFactory<?> getSnapshotTransformFactory() {
+		boolean cleanupInFullSnapshot = ttlConfig.getCleanupStrategies().inFullSnapshot();
+		if (cleanupInFullSnapshot) {
+			return new TtlStateSnapshotTransformer.Factory<>(timeProvider, ttl);
+		}
+		return StateSnapshotTransformFactory.noTransform();
+	}
+
/** Serializer for user state value with TTL. */
private static class TtlSerializer<T> extends CompositeSerializer<TtlValue<T>> {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
new file mode 100644
index 0000000..228d045
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlStateSnapshotTransformer.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl;
+
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.apache.flink.util.Preconditions;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.util.Optional;
+
+/**
+ * State snapshot filter of expired values with TTL.
+ *
+ * <p>The concrete transformers below return {@code null} for a value whose last access
+ * timestamp is reported as expired by {@link TtlUtils}; any other value is returned unchanged.
+ *
+ * @param <T> type of the state value passed to {@code filterOrTransform}
+ */
+abstract class TtlStateSnapshotTransformer<T> implements CollectionStateSnapshotTransformer<T> {
+	private final TtlTimeProvider ttlTimeProvider;
+	final long ttl;
+
+	TtlStateSnapshotTransformer(@Nonnull TtlTimeProvider ttlTimeProvider, long ttl) {
+		this.ttlTimeProvider = ttlTimeProvider;
+		this.ttl = ttl;
+	}
+
+	/** Returns {@code null} if the value is absent or expired, otherwise the value itself. */
+	@Nullable
+	<V> TtlValue<V> filterTtlValue(@Nullable TtlValue<V> value) {
+		// filterOrTransform forwards its @Nullable argument, so guard before dereferencing it
+		return value == null || expired(value) ? null : value;
+	}
+
+	private boolean expired(@Nonnull TtlValue<?> ttlValue) {
+		return expired(ttlValue.getLastAccessTimestamp());
+	}
+
+	/** Checks the given last access timestamp against the configured TTL and the time provider. */
+	boolean expired(long ts) {
+		return TtlUtils.expired(ts, ttl, ttlTimeProvider);
+	}
+
+	/** Reads one {@code long} from the {@link Long#BYTES} bytes of {@code value} starting at {@code offset}. */
+	private static long deserializeTs(
+		byte[] value, int offset) throws IOException {
+		return LongSerializer.INSTANCE.deserialize(
+			new DataInputViewStreamWrapper(new ByteArrayInputStream(value, offset, Long.BYTES)));
+	}
+
+	@Override
+	public TransformStrategy getFilterStrategy() {
+		return TransformStrategy.STOP_ON_FIRST_INCLUDED;
+	}
+
+	/** Transformer for state values that are available as deserialized {@link TtlValue} objects. */
+	static class TtlDeserializedValueStateSnapshotTransformer<T> extends TtlStateSnapshotTransformer<TtlValue<T>> {
+		TtlDeserializedValueStateSnapshotTransformer(TtlTimeProvider ttlTimeProvider, long ttl) {
+			super(ttlTimeProvider, ttl);
+		}
+
+		@Override
+		@Nullable
+		public TtlValue<T> filterOrTransform(@Nullable TtlValue<T> value) {
+			return filterTtlValue(value);
+		}
+	}
+
+	/** Transformer for state values that are only available in serialized form. */
+	static class TtlSerializedValueStateSnapshotTransformer extends TtlStateSnapshotTransformer<byte[]> {
+		TtlSerializedValueStateSnapshotTransformer(TtlTimeProvider ttlTimeProvider, long ttl) {
+			super(ttlTimeProvider, ttl);
+		}
+
+		@Override
+		@Nullable
+		public byte[] filterOrTransform(@Nullable byte[] value) {
+			if (value == null) {
+				return null;
+			}
+			Preconditions.checkArgument(value.length >= Long.BYTES);
+			long ts;
+			try {
+				// the timestamp is read from the trailing Long.BYTES bytes of the serialized value
+				ts = deserializeTs(value, value.length - Long.BYTES);
+			} catch (IOException e) {
+				// keep the original failure as the cause instead of dropping it
+				throw new FlinkRuntimeException("Unexpected timestamp deserialization failure", e);
+			}
+			return expired(ts) ? null : value;
+		}
+	}
+
+	/** Creates the deserialized and serialized TTL transformers for one time provider and TTL. */
+	static class Factory<T> implements StateSnapshotTransformFactory<TtlValue<T>> {
+		private final TtlTimeProvider ttlTimeProvider;
+		private final long ttl;
+
+		Factory(@Nonnull TtlTimeProvider ttlTimeProvider, long ttl) {
+			this.ttlTimeProvider = ttlTimeProvider;
+			this.ttl = ttl;
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<TtlValue<T>>> createForDeserializedState() {
+			return Optional.of(new TtlDeserializedValueStateSnapshotTransformer<>(ttlTimeProvider, ttl));
+		}
+
+		@Override
+		public Optional<StateSnapshotTransformer<byte[]>> createForSerializedState() {
+			return Optional.of(new TtlSerializedValueStateSnapshotTransformer(ttlTimeProvider, ttl));
+		}
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java
new file mode 100644
index 0000000..9d9e5e1
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlUtils.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.ttl;
+
+/** Common functions related to State TTL. */
+class TtlUtils {
+	/** Static helpers only; this class is not meant to be instantiated. */
+	private TtlUtils() {
+	}
+
+	/**
+	 * Checks whether the given TTL value is expired.
+	 *
+	 * @return {@code false} if {@code ttlValue} is {@code null}, otherwise the result of the
+	 *     timestamp based check on its last access timestamp
+	 */
+	static <V> boolean expired(TtlValue<V> ttlValue, long ttl, TtlTimeProvider timeProvider) {
+		return ttlValue != null && expired(ttlValue.getLastAccessTimestamp(), ttl, timeProvider);
+	}
+
+	/**
+	 * Checks whether {@code ts + ttl} (capped against overflow) is less than or equal to the
+	 * current timestamp of the given time provider.
+	 */
+	static boolean expired(long ts, long ttl, TtlTimeProvider timeProvider) {
+		return getExpirationTimestamp(ts, ttl) <= timeProvider.currentTimestamp();
+	}
+
+	/** Returns {@code ts + ttl}; for positive {@code ts} the TTL is capped so the sum cannot exceed {@code Long.MAX_VALUE}. */
+	private static long getExpirationTimestamp(long ts, long ttl) {
+		long ttlWithoutOverflow = ts > 0 ? Math.min(Long.MAX_VALUE - ts, ttl) : ttl;
+		return ts + ttlWithoutOverflow;
+	}
+
+	/** Wraps the value with the given timestamp; a {@code null} value stays {@code null}. */
+	static <V> TtlValue<V> wrapWithTs(V value, long ts) {
+		return value == null ? null : new TtlValue<>(value, ts);
+	}
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
index c14a583..7b19341 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/ttl/TtlValueState.java
@@ -18,7 +18,7 @@
package org.apache.flink.runtime.state.ttl;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.runtime.state.internal.InternalValueState;
@@ -36,7 +36,7 @@ class TtlValueState<K, N, T>
implements InternalValueState<K, N, T> {
TtlValueState(
InternalValueState<K, N, TtlValue<T>> originalState,
- StateTtlConfiguration config,
+ StateTtlConfig config,
TtlTimeProvider timeProvider,
TypeSerializer<T> valueSerializer) {
super(originalState, config, timeProvider, valueSerializer);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
index 5820f13..6b3a15f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/TtlStateTestBase.java
@@ -20,7 +20,7 @@ package org.apache.flink.runtime.state.ttl;
import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
-import org.apache.flink.api.common.state.StateTtlConfiguration;
+import org.apache.flink.api.common.state.StateTtlConfig;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.state.KeyedStateHandle;
@@ -49,7 +49,7 @@ public abstract class TtlStateTestBase {
private MockTtlTimeProvider timeProvider;
private StateBackendTestContext sbetc;
- private StateTtlConfiguration ttlConfig;
+ private StateTtlConfig ttlConfig;
@Before
public void setup() {
@@ -85,24 +85,31 @@ public abstract class TtlStateTestBase {
}
private void initTest() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+ initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
}
private void initTest(
- StateTtlConfiguration.TtlUpdateType updateType,
- StateTtlConfiguration.TtlStateVisibility visibility) throws Exception {
+ StateTtlConfig.UpdateType updateType,
+ StateTtlConfig.StateVisibility visibility) throws Exception {
initTest(updateType, visibility, TTL);
}
private void initTest(
- StateTtlConfiguration.TtlUpdateType updateType,
- StateTtlConfiguration.TtlStateVisibility visibility,
+ StateTtlConfig.UpdateType updateType,
+ StateTtlConfig.StateVisibility visibility,
long ttl) throws Exception {
- ttlConfig = StateTtlConfiguration
- .newBuilder(Time.milliseconds(ttl))
- .setTtlUpdateType(updateType)
+ initTest(getConfBuilder(ttl)
+ .setUpdateType(updateType)
.setStateVisibility(visibility)
- .build();
+ .build());
+ }
+
+ /** Creates a {@link StateTtlConfig} builder whose TTL is the given value interpreted as milliseconds. */
+ private static StateTtlConfig.Builder getConfBuilder(long ttl) {
+ return StateTtlConfig.newBuilder(Time.milliseconds(ttl));
+ }
+
+ private void initTest(StateTtlConfig ttlConfig) throws Exception {
+ this.ttlConfig = ttlConfig;
sbetc.createAndRestoreKeyedStateBackend();
sbetc.restoreSnapshot(null);
createState();
@@ -132,7 +139,7 @@ public abstract class TtlStateTestBase {
@Test
public void testExactExpirationOnWrite() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+ initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
takeAndRestoreSnapshot();
@@ -173,7 +180,7 @@ public abstract class TtlStateTestBase {
@Test
public void testRelaxedExpirationOnWrite() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
+ initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp);
timeProvider.time = 0;
ctx().update(ctx().updateEmpty);
@@ -188,7 +195,7 @@ public abstract class TtlStateTestBase {
@Test
public void testExactExpirationOnRead() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnReadAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired);
+ initTest(StateTtlConfig.UpdateType.OnReadAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired);
timeProvider.time = 0;
ctx().update(ctx().updateEmpty);
@@ -212,7 +219,7 @@ public abstract class TtlStateTestBase {
@Test
public void testRelaxedExpirationOnRead() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnReadAndWrite, StateTtlConfiguration.TtlStateVisibility.ReturnExpiredIfNotCleanedUp);
+ initTest(StateTtlConfig.UpdateType.OnReadAndWrite, StateTtlConfig.StateVisibility.ReturnExpiredIfNotCleanedUp);
timeProvider.time = 0;
ctx().update(ctx().updateEmpty);
@@ -231,7 +238,7 @@ public abstract class TtlStateTestBase {
@Test
public void testExpirationTimestampOverflow() throws Exception {
- initTest(StateTtlConfiguration.TtlUpdateType.OnCreateAndWrite, StateTtlConfiguration.TtlStateVisibility.NeverReturnExpired, Long.MAX_VALUE);
+ initTest(StateTtlConfig.UpdateType.OnCreateAndWrite, StateTtlConfig.StateVisibility.NeverReturnExpired, Long.MAX_VALUE);
timeProvider.time = 10;
ctx().update(ctx().updateEmpty);
@@ -275,16 +282,26 @@ public abstract class TtlStateTestBase {
@Test
public void testMultipleKeys() throws Exception {
- testMultipleStateIds(id -> sbetc.setCurrentKey(id));
+ testMultipleStateIdsWithSnapshotCleanup(id -> sbetc.setCurrentKey(id));
}
@Test
public void testMultipleNamespaces() throws Exception {
- testMultipleStateIds(id -> ctx().ttlState.setCurrentNamespace(id));
+ testMultipleStateIdsWithSnapshotCleanup(id -> ctx().ttlState.setCurrentNamespace(id));
}
- private void testMultipleStateIds(Consumer<String> idChanger) throws Exception {
+ private void testMultipleStateIdsWithSnapshotCleanup(Consumer<String> idChanger) throws Exception {
initTest();
+ testMultipleStateIds(idChanger, false);
+
+ initTest(getConfBuilder(TTL).cleanupFullSnapshot().build());
+ // set time back after restore to see entry unexpired if it was not cleaned up in snapshot properly
+ testMultipleStateIds(idChanger, true);
+ }
+
+ private void testMultipleStateIds(Consumer<String> idChanger, boolean timeBackAfterRestore) throws Exception {
+ // test empty storage snapshot/restore
+ takeAndRestoreSnapshot();
timeProvider.time = 0;
idChanger.accept("id2");
@@ -298,9 +315,9 @@ public abstract class TtlStateTestBase {
idChanger.accept("id2");
ctx().update(ctx().updateUnexpired);
+ timeProvider.time = 120;
takeAndRestoreSnapshot();
- timeProvider.time = 120;
idChanger.accept("id1");
assertEquals("Unexpired state should be available", ctx().getUpdateEmpty, ctx().get());
idChanger.accept("id2");
@@ -312,17 +329,19 @@ public abstract class TtlStateTestBase {
idChanger.accept("id2");
ctx().update(ctx().updateExpired);
+ timeProvider.time = 230;
takeAndRestoreSnapshot();
- timeProvider.time = 230;
+ timeProvider.time = timeBackAfterRestore ? 170 : timeProvider.time;
idChanger.accept("id1");
assertEquals("Expired state should be unavailable", ctx().emptyValue, ctx().get());
idChanger.accept("id2");
assertEquals("Unexpired state should be available after update", ctx().getUpdateExpired, ctx().get());
+ timeProvider.time = 300;
takeAndRestoreSnapshot();
- timeProvider.time = 300;
+ timeProvider.time = timeBackAfterRestore ? 230 : timeProvider.time;
idChanger.accept("id1");
assertEquals("Expired state should be unavailable", ctx().emptyValue, ctx().get());
idChanger.accept("id2");
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index 805ae1c..0b5931c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -37,12 +37,13 @@ import org.apache.flink.runtime.state.KeyExtractorFunction;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupedInternalPriorityQueue;
import org.apache.flink.runtime.state.Keyed;
-import org.apache.flink.runtime.state.KeyedStateFactory;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.PriorityComparable;
import org.apache.flink.runtime.state.PriorityComparator;
import org.apache.flink.runtime.state.SharedStateRegistry;
import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
import org.apache.flink.runtime.state.heap.HeapPriorityQueueSet;
import org.apache.flink.runtime.state.ttl.TtlStateFactory;
@@ -57,6 +58,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.concurrent.FutureTask;
import java.util.concurrent.RunnableFuture;
import java.util.stream.Collectors;
@@ -65,19 +67,27 @@ import java.util.stream.Stream;
/** State backend which produces in memory mock state objects. */
public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
+ private interface StateFactory {
+ <N, SV, S extends State, IS extends S> IS createInternalState(
+ TypeSerializer<N> namespaceSerializer,
+ StateDescriptor<S, SV> stateDesc) throws Exception;
+ }
+
@SuppressWarnings("deprecation")
- private static final Map<Class<? extends StateDescriptor>, KeyedStateFactory> STATE_FACTORIES =
+ private static final Map<Class<? extends StateDescriptor>, StateFactory> STATE_FACTORIES =
Stream.of(
- Tuple2.of(ValueStateDescriptor.class, (KeyedStateFactory) MockInternalValueState::createState),
- Tuple2.of(ListStateDescriptor.class, (KeyedStateFactory) MockInternalListState::createState),
- Tuple2.of(MapStateDescriptor.class, (KeyedStateFactory) MockInternalMapState::createState),
- Tuple2.of(ReducingStateDescriptor.class, (KeyedStateFactory) MockInternalReducingState::createState),
- Tuple2.of(AggregatingStateDescriptor.class, (KeyedStateFactory) MockInternalAggregatingState::createState),
- Tuple2.of(FoldingStateDescriptor.class, (KeyedStateFactory) MockInternalFoldingState::createState)
+ Tuple2.of(ValueStateDescriptor.class, (StateFactory) MockInternalValueState::createState),
+ Tuple2.of(ListStateDescriptor.class, (StateFactory) MockInternalListState::createState),
+ Tuple2.of(MapStateDescriptor.class, (StateFactory) MockInternalMapState::createState),
+ Tuple2.of(ReducingStateDescriptor.class, (StateFactory) MockInternalReducingState::createState),
+ Tuple2.of(AggregatingStateDescriptor.class, (StateFactory) MockInternalAggregatingState::createState),
+ Tuple2.of(FoldingStateDescriptor.class, (StateFactory) MockInternalFoldingState::createState)
).collect(Collectors.toMap(t -> t.f0, t -> t.f1));
private final Map<String, Map<K, Map<Object, Object>>> stateValues = new HashMap<>();
+ private final Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters = new HashMap<>();
+
MockKeyedStateBackend(
TaskKvStateRegistry kvStateRegistry,
TypeSerializer<K> keySerializer,
@@ -92,22 +102,46 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
@Override
@SuppressWarnings("unchecked")
- public <N, SV, S extends State, IS extends S> IS createInternalState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
- KeyedStateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
+ @Nonnull
+ public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull StateDescriptor<S, SV> stateDesc,
+ @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
+ StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
if (stateFactory == null) {
String message = String.format("State %s is not supported by %s",
stateDesc.getClass(), TtlStateFactory.class);
throw new FlinkRuntimeException(message);
}
IS state = stateFactory.createInternalState(namespaceSerializer, stateDesc);
+ stateSnapshotFilters.put(stateDesc.getName(),
+ (StateSnapshotTransformer<Object>) getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
((MockInternalKvState<K, N, SV>) state).values = () -> stateValues
.computeIfAbsent(stateDesc.getName(), n -> new HashMap<>())
.computeIfAbsent(getCurrentKey(), k -> new HashMap<>());
return state;
}
+ @SuppressWarnings("unchecked")
+ private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+ StateDescriptor<?, SV> stateDesc,
+ StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+ Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
+ if (original.isPresent()) {
+ if (stateDesc instanceof ListStateDescriptor) {
+ return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+ .ListStateSnapshotTransformer<>(original.get());
+ } else if (stateDesc instanceof MapStateDescriptor) {
+ return (StateSnapshotTransformer<SV>) new StateSnapshotTransformer
+ .MapStateSnapshotTransformer<>(original.get());
+ } else {
+ return (StateSnapshotTransformer<SV>) original.get();
+ }
+ } else {
+ return null;
+ }
+ }
+
@Override
public int numKeyValueStateEntries() {
int count = 0;
@@ -142,7 +176,8 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
long timestamp,
CheckpointStreamFactory streamFactory,
CheckpointOptions checkpointOptions) {
- return new FutureTask<>(() -> SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues))));
+ return new FutureTask<>(() ->
+ SnapshotResult.of(new MockKeyedStateHandle<>(copy(stateValues, stateSnapshotFilters))));
}
@SuppressWarnings("unchecked")
@@ -153,32 +188,51 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
state.forEach(ksh -> stateValues.putAll(copy(((MockKeyedStateHandle<K>) ksh).snapshotStates)));
}
- @SuppressWarnings("unchecked")
private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
Map<String, Map<K, Map<Object, Object>>> stateValues) {
+ return copy(stateValues, Collections.emptyMap());
+ }
+
+ private static <K> Map<String, Map<K, Map<Object, Object>>> copy(
+ Map<String, Map<K, Map<Object, Object>>> stateValues, Map<String, StateSnapshotTransformer<Object>> stateSnapshotFilters) {
Map<String, Map<K, Map<Object, Object>>> snapshotStates = new HashMap<>();
for (String stateName : stateValues.keySet()) {
+ StateSnapshotTransformer<Object> stateSnapshotTransformer = stateSnapshotFilters.getOrDefault(stateName, null);
Map<K, Map<Object, Object>> keyedValues = snapshotStates.computeIfAbsent(stateName, s -> new HashMap<>());
for (K key : stateValues.get(stateName).keySet()) {
- Map<Object, Object> values = keyedValues.computeIfAbsent(key, s -> new HashMap<>());
+ Map<Object, Object> snapshotedValues = keyedValues.computeIfAbsent(key, s -> new HashMap<>());
for (Object namespace : stateValues.get(stateName).get(key).keySet()) {
- Object value = stateValues.get(stateName).get(key).get(namespace);
- value = value instanceof List ? new ArrayList<>((List) value) : value;
- value = value instanceof Map ? new HashMap<>((Map) value) : value;
- values.put(namespace, value);
+ copyEntry(stateValues, snapshotedValues, stateName, key, namespace, stateSnapshotTransformer);
}
}
}
return snapshotStates;
}
+ @SuppressWarnings("unchecked")
+ private static <K> void copyEntry(
+ Map<String, Map<K, Map<Object, Object>>> stateValues,
+ Map<Object, Object> snapshotedValues,
+ String stateName,
+ K key,
+ Object namespace,
+ StateSnapshotTransformer<Object> stateSnapshotTransformer) {
+ Object value = stateValues.get(stateName).get(key).get(namespace);
+ value = value instanceof List ? new ArrayList<>((List) value) : value;
+ value = value instanceof Map ? new HashMap<>((Map) value) : value;
+ Object filteredValue = stateSnapshotTransformer == null ? value : stateSnapshotTransformer.filterOrTransform(value);
+ if (filteredValue != null) {
+ snapshotedValues.put(namespace, filteredValue);
+ }
+ }
+
@Nonnull
@Override
public <T extends HeapPriorityQueueElement & PriorityComparable & Keyed> KeyGroupedInternalPriorityQueue<T>
create(
@Nonnull String stateName,
@Nonnull TypeSerializer<T> byteOrderedElementSerializer) {
- return new HeapPriorityQueueSet<T>(
+ return new HeapPriorityQueueSet<>(
PriorityComparator.forPriorityComparableObjects(),
KeyExtractorFunction.forKeyedObjects(),
0,
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
index 7ead620..f7af354 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBKeyedStateBackend.java
@@ -84,6 +84,8 @@ import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.SnapshotStrategy;
import org.apache.flink.runtime.state.StateHandleID;
import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
+import org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
import org.apache.flink.runtime.state.StateUtil;
import org.apache.flink.runtime.state.StreamCompressionDecorator;
import org.apache.flink.runtime.state.StreamStateHandle;
@@ -110,12 +112,14 @@ import org.rocksdb.DBOptions;
import org.rocksdb.ReadOptions;
import org.rocksdb.RocksDB;
import org.rocksdb.RocksDBException;
+import org.rocksdb.RocksIterator;
import org.rocksdb.Snapshot;
import org.rocksdb.WriteOptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
@@ -135,6 +139,7 @@ import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
+import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.SortedMap;
@@ -1312,7 +1317,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
*/
private <N, S> Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, S>> tryRegisterKvStateInformation(
StateDescriptor<?, S> stateDesc,
- TypeSerializer<N> namespaceSerializer) throws StateMigrationException {
+ TypeSerializer<N> namespaceSerializer,
+ @Nullable StateSnapshotTransformer<S> snapshotTransformer) throws StateMigrationException {
Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> stateInfo =
kvStateInformation.get(stateDesc.getName());
@@ -1330,7 +1336,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
newMetaInfo = RegisteredKeyValueStateBackendMetaInfo.resolveKvStateCompatibility(
restoredMetaInfoSnapshot,
namespaceSerializer,
- stateDesc);
+ stateDesc,
+ snapshotTransformer);
stateInfo.f1 = newMetaInfo;
} else {
@@ -1340,7 +1347,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
stateDesc.getType(),
stateName,
namespaceSerializer,
- stateDesc.getSerializer());
+ stateDesc.getSerializer(),
+ snapshotTransformer);
ColumnFamilyHandle columnFamily = createColumnFamily(stateName);
@@ -1369,20 +1377,43 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
@Override
- public <N, SV, S extends State, IS extends S> IS createInternalState(
- TypeSerializer<N> namespaceSerializer,
- StateDescriptor<S, SV> stateDesc) throws Exception {
+ @Nonnull
+ public <N, SV, SEV, S extends State, IS extends S> IS createInternalState(
+ @Nonnull TypeSerializer<N> namespaceSerializer,
+ @Nonnull StateDescriptor<S, SV> stateDesc,
+ @Nonnull StateSnapshotTransformFactory<SEV> snapshotTransformFactory) throws Exception {
StateFactory stateFactory = STATE_FACTORIES.get(stateDesc.getClass());
if (stateFactory == null) {
String message = String.format("State %s is not supported by %s",
stateDesc.getClass(), this.getClass());
throw new FlinkRuntimeException(message);
}
- Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult =
- tryRegisterKvStateInformation(stateDesc, namespaceSerializer);
+ Tuple2<ColumnFamilyHandle, RegisteredKeyValueStateBackendMetaInfo<N, SV>> registerResult = tryRegisterKvStateInformation(
+ stateDesc, namespaceSerializer, getStateSnapshotTransformer(stateDesc, snapshotTransformFactory));
return stateFactory.createState(stateDesc, registerResult, RocksDBKeyedStateBackend.this);
}
+ @SuppressWarnings("unchecked")
+ private <SV, SEV> StateSnapshotTransformer<SV> getStateSnapshotTransformer(
+ StateDescriptor<?, SV> stateDesc,
+ StateSnapshotTransformFactory<SEV> snapshotTransformFactory) {
+ if (stateDesc instanceof ListStateDescriptor) {
+ Optional<StateSnapshotTransformer<SEV>> original = snapshotTransformFactory.createForDeserializedState();
+ return original.map(est -> createRocksDBListStateTransformer(stateDesc, est)).orElse(null);
+ } else {
+ Optional<StateSnapshotTransformer<byte[]>> original = snapshotTransformFactory.createForSerializedState();
+ return (StateSnapshotTransformer<SV>) original.orElse(null);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private <SV, SEV> StateSnapshotTransformer<SV> createRocksDBListStateTransformer(
+ StateDescriptor<?, SV> stateDesc,
+ StateSnapshotTransformer<SEV> elementTransformer) {
+ return (StateSnapshotTransformer<SV>) new RocksDBListState.StateSnapshotTransformerWrapper<>(
+ elementTransformer, ((ListStateDescriptor<SEV>) stateDesc).getElementSerializer());
+ }
+
/**
* Only visible for testing, DO NOT USE.
*/
@@ -1402,7 +1433,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
int count = 0;
for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> column : kvStateInformation.values()) {
- //TODO maybe filter only for k/v states
+ //TODO maybe filterOrTransform only for k/v states
try (RocksIteratorWrapper rocksIterator = getRocksIterator(db, column.f0)) {
rocksIterator.seekToFirst();
@@ -1416,14 +1447,12 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return count;
}
-
-
/**
* Iterator that merges multiple RocksDB iterators to partition all states into contiguous key-groups.
* The resulting iteration sequence is ordered by (key-group, kv-state).
*/
@VisibleForTesting
- static final class RocksDBMergeIterator implements AutoCloseable {
+ static class RocksDBMergeIterator implements AutoCloseable {
private final PriorityQueue<RocksDBKeyedStateBackend.MergeIterator> heap;
private final int keyGroupPrefixByteCount;
@@ -1431,7 +1460,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
private boolean newKVState;
private boolean valid;
- private MergeIterator currentSubIterator;
+ MergeIterator currentSubIterator;
private static final List<Comparator<MergeIterator>> COMPARATORS;
@@ -1440,18 +1469,17 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
COMPARATORS = new ArrayList<>(maxBytes);
for (int i = 0; i < maxBytes; ++i) {
final int currentBytes = i + 1;
- COMPARATORS.add(new Comparator<MergeIterator>() {
- @Override
- public int compare(MergeIterator o1, MergeIterator o2) {
- int arrayCmpRes = compareKeyGroupsForByteArrays(
- o1.currentKey, o2.currentKey, currentBytes);
- return arrayCmpRes == 0 ? o1.getKvStateId() - o2.getKvStateId() : arrayCmpRes;
- }
+ COMPARATORS.add((o1, o2) -> {
+ int arrayCmpRes = compareKeyGroupsForByteArrays(
+ o1.currentKey, o2.currentKey, currentBytes);
+ return arrayCmpRes == 0 ? o1.getKvStateId() - o2.getKvStateId() : arrayCmpRes;
});
}
}
- RocksDBMergeIterator(List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators, final int keyGroupPrefixByteCount) throws RocksDBException {
+ RocksDBMergeIterator(
+ List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators,
+ final int keyGroupPrefixByteCount) {
Preconditions.checkNotNull(kvStateIterators);
Preconditions.checkArgument(keyGroupPrefixByteCount >= 1);
@@ -1492,7 +1520,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
* Advance the iterator. Should only be called if {@link #isValid()} returned true. Valid can only change after
* calls to {@link #next()}.
*/
- public void next() throws RocksDBException {
+ public void next() {
newKeyGroup = false;
newKVState = false;
@@ -1612,7 +1640,8 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
* Wraps a RocksDB iterator to cache its current key and assigns an id for the key/value state to the iterator.
* Used by #MergeIterator.
*/
- private static final class MergeIterator implements AutoCloseable {
+ @VisibleForTesting
+ protected static final class MergeIterator implements AutoCloseable {
/**
* @param iterator The #RocksIterator to wrap.
@@ -1650,6 +1679,57 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
}
}
+ private static final class TransformingRocksIteratorWrapper extends RocksIteratorWrapper {
+ @Nonnull
+ private final StateSnapshotTransformer<byte[]> stateSnapshotTransformer;
+ private byte[] current;
+
+ public TransformingRocksIteratorWrapper(
+ @Nonnull RocksIterator iterator,
+ @Nonnull StateSnapshotTransformer<byte[]> stateSnapshotTransformer) {
+ super(iterator);
+ this.stateSnapshotTransformer = stateSnapshotTransformer;
+ }
+
+ @Override
+ public void seekToFirst() {
+ super.seekToFirst();
+ filterOrTransform(super::next);
+ }
+
+ @Override
+ public void seekToLast() {
+ super.seekToLast();
+ filterOrTransform(super::prev);
+ }
+
+ @Override
+ public void next() {
+ super.next();
+ filterOrTransform(super::next);
+ }
+
+ @Override
+ public void prev() {
+ super.prev();
+ filterOrTransform(super::prev);
+ }
+
+ private void filterOrTransform(Runnable advance) {
+ while (isValid() && (current = stateSnapshotTransformer.filterOrTransform(super.value())) == null) {
+ advance.run();
+ }
+ }
+
+ @Override
+ public byte[] value() {
+ if (!isValid()) {
+ throw new IllegalStateException("value() method cannot be called if isValid() is false");
+ }
+ return current;
+ }
+ }
+
/**
* Adapter class to bridge between {@link RocksIteratorWrapper} and {@link Iterator} to iterate over the keys. This class
* is not thread safe.
@@ -1972,7 +2052,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
/**
* The copied column handle.
*/
- private List<ColumnFamilyHandle> copiedColumnFamilyHandles;
+ private List<Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase>> copiedMeta;
private List<Tuple2<RocksIteratorWrapper, Integer>> kvStateIterators;
@@ -2000,15 +2080,13 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
this.stateMetaInfoSnapshots = new ArrayList<>(stateBackend.kvStateInformation.size());
- this.copiedColumnFamilyHandles = new ArrayList<>(stateBackend.kvStateInformation.size());
+ this.copiedMeta = new ArrayList<>(stateBackend.kvStateInformation.size());
for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 :
stateBackend.kvStateInformation.values()) {
// snapshot meta info
this.stateMetaInfoSnapshots.add(tuple2.f1.snapshot());
-
- // copy column family handle
- this.copiedColumnFamilyHandles.add(tuple2.f0);
+ this.copiedMeta.add(tuple2);
}
this.snapshot = stateBackend.db.getSnapshot();
}
@@ -2095,7 +2173,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
private void writeKVStateMetaData() throws IOException {
- this.kvStateIterators = new ArrayList<>(copiedColumnFamilyHandles.size());
+ this.kvStateIterators = new ArrayList<>(copiedMeta.size());
int kvStateId = 0;
@@ -2103,11 +2181,10 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
readOptions = new ReadOptions();
readOptions.setSnapshot(snapshot);
- for (ColumnFamilyHandle columnFamilyHandle : copiedColumnFamilyHandles) {
-
- kvStateIterators.add(
- new Tuple2<>(getRocksIterator(stateBackend.db, columnFamilyHandle, readOptions), kvStateId));
-
+ for (Tuple2<ColumnFamilyHandle, RegisteredStateMetaInfoBase> tuple2 : copiedMeta) {
+ RocksIteratorWrapper rocksIteratorWrapper =
+ getRocksIterator(stateBackend.db, tuple2.f0, tuple2.f1, readOptions);
+ kvStateIterators.add(new Tuple2<>(rocksIteratorWrapper, kvStateId));
++kvStateId;
}
@@ -2124,8 +2201,7 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
serializationProxy.write(outputView);
}
- private void writeKVStateData() throws IOException, InterruptedException, RocksDBException {
-
+ private void writeKVStateData() throws IOException, InterruptedException {
byte[] previousKey = null;
byte[] previousValue = null;
DataOutputView kgOutView = null;
@@ -2635,11 +2711,21 @@ public class RocksDBKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle));
}
- public static RocksIteratorWrapper getRocksIterator(
+ @SuppressWarnings("unchecked")
+ private static RocksIteratorWrapper getRocksIterator(
RocksDB db,
ColumnFamilyHandle columnFamilyHandle,
+ RegisteredStateMetaInfoBase metaInfo,
ReadOptions readOptions) {
- return new RocksIteratorWrapper(db.newIterator(columnFamilyHandle, readOptions));
+ StateSnapshotTransformer<byte[]> stateSnapshotTransformer = null;
+ if (metaInfo instanceof RegisteredKeyValueStateBackendMetaInfo) {
+ stateSnapshotTransformer = (StateSnapshotTransformer<byte[]>)
+ ((RegisteredKeyValueStateBackendMetaInfo<?, ?>) metaInfo).getSnapshotTransformer();
+ }
+ RocksIterator rocksIterator = db.newIterator(columnFamilyHandle, readOptions);
+ return stateSnapshotTransformer == null ?
+ new RocksIteratorWrapper(rocksIterator) :
+ new TransformingRocksIteratorWrapper(rocksIterator, stateSnapshotTransformer);
}
/**
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
index aa5e93a..176f48c 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java
@@ -24,9 +24,12 @@ import org.apache.flink.api.common.state.State;
import org.apache.flink.api.common.state.StateDescriptor;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.ByteArrayDataInputView;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.RegisteredKeyValueStateBackendMetaInfo;
+import org.apache.flink.runtime.state.StateSnapshotTransformer;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.Preconditions;
@@ -34,12 +37,16 @@ import org.apache.flink.util.Preconditions;
import org.rocksdb.ColumnFamilyHandle;
import org.rocksdb.RocksDBException;
-import java.io.ByteArrayInputStream;
+import javax.annotation.Nullable;
+
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.List;
+import static org.apache.flink.runtime.state.StateSnapshotTransformer.CollectionStateSnapshotTransformer.TransformStrategy.STOP_ON_FIRST_INCLUDED;
+
/**
* {@link ListState} implementation that stores state in RocksDB.
*
@@ -111,25 +118,42 @@ class RocksDBListState<K, N, V>
writeCurrentKeyWithGroupAndNamespace();
byte[] key = keySerializationStream.toByteArray();
byte[] valueBytes = backend.db.get(columnFamily, key);
+ return deserializeList(valueBytes, elementSerializer);
+ } catch (IOException | RocksDBException e) {
+ throw new FlinkRuntimeException("Error while retrieving data from RocksDB", e);
+ }
+ }
- if (valueBytes == null) {
- return null;
- }
+ private static <V> List<V> deserializeList(
+ byte[] valueBytes, TypeSerializer<V> elementSerializer) {
+ if (valueBytes == null) {
+ return null;
+ }
- ByteArrayInputStream bais = new ByteArrayInputStream(valueBytes);
- DataInputViewStreamWrapper in = new DataInputViewStreamWrapper(bais);
+ DataInputViewStreamWrapper in = new ByteArrayDataInputView(valueBytes);
- List<V> result = new ArrayList<>();
- while (in.available() > 0) {
- result.add(elementSerializer.deserialize(in));
+ List<V> result = new ArrayList<>();
+ V next;
+ while ((next = deserializeNextElement(in, elementSerializer)) != null) {
+ result.add(next);
+ }
+ return result;
+ }
+
+ private static <V> V deserializeNextElement(
+ DataInputViewStreamWrapper in, TypeSerializer<V> elementSerializer) {
+ try {
+ if (in.available() > 0) {
+ V element = elementSerializer.deserialize(in);
if (in.available() > 0) {
in.readByte();
}
+ return element;
}
- return result;
- } catch (IOException | RocksDBException e) {
- throw new FlinkRuntimeException("Error while retrieving data from RocksDB", e);
+ } catch (IOException e) {
+ throw new FlinkRuntimeException("Unexpected list element deserialization failure");
}
+ return null;
}
@Override
@@ -203,7 +227,7 @@ class RocksDBListState<K, N, V>
writeCurrentKeyWithGroupAndNamespace();
byte[] key = keySerializationStream.toByteArray();
- byte[] premerge = getPreMergedValue(values);
+ byte[] premerge = getPreMergedValue(values, elementSerializer, keySerializationStream);
if (premerge != null) {
backend.db.put(columnFamily, writeOptions, key, premerge);
} else {
@@ -224,7 +248,7 @@ class RocksDBListState<K, N, V>
writeCurrentKeyWithGroupAndNamespace();
byte[] key = keySerializationStream.toByteArray();
- byte[] premerge = getPreMergedValue(values);
+ byte[] premerge = getPreMergedValue(values, elementSerializer, keySerializationStream);
if (premerge != null) {
backend.db.merge(columnFamily, writeOptions, key, premerge);
} else {
@@ -236,7 +260,10 @@ class RocksDBListState<K, N, V>
}
}
- private byte[] getPreMergedValue(List<V> values) throws IOException {
+ private static <V> byte[] getPreMergedValue(
+ List<V> values,
+ TypeSerializer<V> elementSerializer,
+ ByteArrayOutputStreamWithPos keySerializationStream) throws IOException {
DataOutputViewStreamWrapper out = new DataOutputViewStreamWrapper(keySerializationStream);
keySerializationStream.reset();
@@ -267,4 +294,47 @@ class RocksDBListState<K, N, V>
((ListStateDescriptor<E>) stateDesc).getElementSerializer(),
backend);
}
+
+ static class StateSnapshotTransformerWrapper<T> implements StateSnapshotTransformer<byte[]> {
+ private final StateSnapshotTransformer<T> elementTransformer;
+ private final TypeSerializer<T> elementSerializer;
+ private final ByteArrayOutputStreamWithPos out = new ByteArrayOutputStreamWithPos(128);
+ private final CollectionStateSnapshotTransformer.TransformStrategy transformStrategy;
+
+ StateSnapshotTransformerWrapper(StateSnapshotTransformer<T> elementTransformer, TypeSerializer<T> elementSerializer) {
+ this.elementTransformer = elementTransformer;
+ this.elementSerializer = elementSerializer;
+ this.transformStrategy = elementTransformer instanceof CollectionStateSnapshotTransformer ?
+ ((CollectionStateSnapshotTransformer) elementTransformer).getFilterStrategy() :
+ CollectionStateSnapshotTransformer.TransformStrategy.TRANSFORM_ALL;
+ }
+
+ @Override
+ @Nullable
+ public byte[] filterOrTransform(@Nullable byte[] value) {
+ if (value == null) {
+ return null;
+ }
+ List<T> result = new ArrayList<>();
+ ByteArrayDataInputView in = new ByteArrayDataInputView(value);
+ T next;
+ int prevPosition = 0;
+ while ((next = deserializeNextElement(in, elementSerializer)) != null) {
+ T transformedElement = elementTransformer.filterOrTransform(next);
+ if (transformedElement != null) {
+ if (transformStrategy == STOP_ON_FIRST_INCLUDED) {
+ return Arrays.copyOfRange(value, prevPosition, value.length);
+ } else {
+ result.add(transformedElement);
+ }
+ }
+ prevPosition = in.getPosition();
+ }
+ try {
+ return result.isEmpty() ? null : getPreMergedValue(result, elementSerializer, out);
+ } catch (IOException e) {
+ throw new FlinkRuntimeException("Failed to serialize transformed list", e);
+ }
+ }
+ }
}