You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ro...@apache.org on 2022/01/25 07:02:37 UTC

[flink] branch master updated: [FLINK-25524] Fix ChangelogStateBackend.notifyCheckpointComplete

This is an automated email from the ASF dual-hosted git repository.

roman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new d0927dd  [FLINK-25524] Fix ChangelogStateBackend.notifyCheckpointComplete
d0927dd is described below

commit d0927dd41e2f0441e4e5825ff423dd0e903713f3
Author: Roman Khachatryan <kh...@gmail.com>
AuthorDate: Mon Jan 17 17:01:11 2022 +0100

    [FLINK-25524] Fix ChangelogStateBackend.notifyCheckpointComplete
    
    When triggering materialization, the Changelog backend uses a fake
    checkpoint ID to obtain a (materialized) snapshot.
    That same ID must be used when proxying checkpoint
    completion/abortion notifications to the nested backend.
    
    On recovery, the nested backend might read lastCompletedCheckpointID
    from its snapshot (in particular, when enabling changelog), which
    may cause inconsistency.
    
    This change:
    - adds a mapping from checkpoint to materializationID
    - stores materializationID in checkpoint metadata
    - selects max materializationID on recovery (to handle upscaling)
---
 .../metadata/MetadataV2V3SerializerBase.java       |   7 +-
 .../state/CheckpointBoundKeyedStateHandle.java     |  24 ++++
 .../runtime/state/IncrementalKeyedStateHandle.java |   6 +-
 .../changelog/ChangelogStateBackendHandle.java     |  15 ++-
 .../state/ttl/mock/MockKeyedStateBackend.java      |   8 +-
 .../changelog/ChangelogKeyedStateBackend.java      | 116 ++++++++++++++++---
 .../state/changelog/ChangelogStateBackend.java     |  12 +-
 .../changelog/PeriodicMaterializationManager.java  |  12 +-
 .../changelog/ChangelogKeyedStateBackendTest.java  | 124 +++++++++++++++++++++
 .../changelog/ChangelogStateBackendTestUtils.java  |   6 +-
 10 files changed, 301 insertions(+), 29 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/MetadataV2V3SerializerBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/MetadataV2V3SerializerBase.java
index 7062598..25086de 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/MetadataV2V3SerializerBase.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/metadata/MetadataV2V3SerializerBase.java
@@ -336,6 +336,8 @@ public abstract class MetadataV2V3SerializerBase {
                 serializeKeyedStateHandle(k, dos);
             }
 
+            dos.writeLong(handle.getMaterializationID());
+
         } else if (stateHandle instanceof InMemoryChangelogStateHandle) {
             InMemoryChangelogStateHandle handle = (InMemoryChangelogStateHandle) stateHandle;
             dos.writeByte(CHANGELOG_BYTE_INCREMENT_HANDLE);
@@ -441,8 +443,11 @@ public abstract class MetadataV2V3SerializerBase {
             for (int i = 0; i < deltaSize; i++) {
                 delta.add((ChangelogStateHandle) deserializeKeyedStateHandle(dis, context));
             }
+
+            long materializationID = dis.readLong();
+
             return new ChangelogStateBackendHandle.ChangelogStateBackendHandleImpl(
-                    base, delta, keyGroupRange);
+                    base, delta, keyGroupRange, materializationID);
 
         } else if (CHANGELOG_BYTE_INCREMENT_HANDLE == type) {
             int start = dis.readInt();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointBoundKeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointBoundKeyedStateHandle.java
new file mode 100644
index 0000000..6688f02
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/CheckpointBoundKeyedStateHandle.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.runtime.state;
+
+/** {@link KeyedStateHandle} that is bound to a specific checkpoint. */
+public interface CheckpointBoundKeyedStateHandle {
+
+    /** Returns the ID of the checkpoint for which the handle was created or used. */
+    long getCheckpointId();
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
index 6d323e0..b701465 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/IncrementalKeyedStateHandle.java
@@ -24,10 +24,8 @@ import java.util.Set;
 import java.util.UUID;
 
 /** Common interface to all incremental {@link KeyedStateHandle}. */
-public interface IncrementalKeyedStateHandle extends KeyedStateHandle {
-
-    /** Returns the ID of the checkpoint for which the handle was created. */
-    long getCheckpointId();
+public interface IncrementalKeyedStateHandle
+        extends KeyedStateHandle, CheckpointBoundKeyedStateHandle {
 
     /** Returns the identifier of the state backend from which this handle was created. */
     @Nonnull
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/ChangelogStateBackendHandle.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/ChangelogStateBackendHandle.java
index 8b4203d..1e502d0 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/ChangelogStateBackendHandle.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/changelog/ChangelogStateBackendHandle.java
@@ -48,20 +48,25 @@ public interface ChangelogStateBackendHandle extends KeyedStateHandle {
 
     List<ChangelogStateHandle> getNonMaterializedStateHandles();
 
+    long getMaterializationID();
+
     class ChangelogStateBackendHandleImpl implements ChangelogStateBackendHandle {
         private static final long serialVersionUID = 1L;
         private final List<KeyedStateHandle> materialized;
         private final List<ChangelogStateHandle> nonMaterialized;
         private final KeyGroupRange keyGroupRange;
+        private final long materializationID;
 
         public ChangelogStateBackendHandleImpl(
                 List<KeyedStateHandle> materialized,
                 List<ChangelogStateHandle> nonMaterialized,
-                KeyGroupRange keyGroupRange) {
+                KeyGroupRange keyGroupRange,
+                long materializationID) {
             this.materialized = unmodifiableList(materialized);
             this.nonMaterialized = unmodifiableList(nonMaterialized);
             this.keyGroupRange = keyGroupRange;
             checkArgument(keyGroupRange.getNumberOfKeyGroups() > 0);
+            this.materializationID = materializationID;
         }
 
         @Override
@@ -104,7 +109,8 @@ public interface ChangelogStateBackendHandle extends KeyedStateHandle {
                                                     handle.getIntersection(keyGroupRange))
                             .filter(Objects::nonNull)
                             .collect(Collectors.toList());
-            return new ChangelogStateBackendHandleImpl(basePart, deltaPart, intersection);
+            return new ChangelogStateBackendHandleImpl(
+                    basePart, deltaPart, intersection, materializationID);
         }
 
         @Override
@@ -124,6 +130,11 @@ public interface ChangelogStateBackendHandle extends KeyedStateHandle {
         }
 
         @Override
+        public long getMaterializationID() {
+            return materializationID;
+        }
+
+        @Override
         public String toString() {
             return String.format(
                     "keyGroupRange=%s, basePartSize=%d, deltaPartSize=%d",
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
index c7f6ea9..34d6265 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ttl/mock/MockKeyedStateBackend.java
@@ -69,6 +69,8 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
     /** Whether to create empty snapshot ({@link MockKeyedStateHandle} isn't recognized by JM). */
     private final boolean emptySnapshot;
 
+    private long lastCompletedCheckpointID;
+
     private interface StateFactory {
         <N, SV, S extends State, IS extends S> IS createInternalState(
                 TypeSerializer<N> namespaceSerializer, StateDescriptor<S, SV> stateDesc)
@@ -188,7 +190,7 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
 
     @Override
     public void notifyCheckpointComplete(long checkpointId) {
-        // noop
+        lastCompletedCheckpointID = checkpointId;
     }
 
     @Override
@@ -300,6 +302,10 @@ public class MockKeyedStateBackend<K> extends AbstractKeyedStateBackend<K> {
                 0);
     }
 
+    public long getLastCompletedCheckpointID() {
+        return lastCompletedCheckpointID;
+    }
+
     static class MockKeyedStateHandle<K> implements KeyedStateHandle {
         private static final long serialVersionUID = 1L;
 
diff --git a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
index 9fc4f56..54fd7d1 100644
--- a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackend.java
@@ -78,10 +78,14 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.NavigableMap;
 import java.util.NoSuchElementException;
 import java.util.Optional;
+import java.util.Set;
+import java.util.TreeMap;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.RunnableFuture;
@@ -156,7 +160,7 @@ public class ChangelogKeyedStateBackend<K>
 
     private final TtlTimeProvider ttlTimeProvider;
 
-    private final StateChangelogWriter<ChangelogStateHandle> stateChangelogWriter;
+    private final StateChangelogWriter<? extends ChangelogStateHandle> stateChangelogWriter;
 
     private final Closer closer = Closer.create();
 
@@ -203,12 +207,30 @@ public class ChangelogKeyedStateBackend<K>
      */
     private short lastCreatedStateId = -1;
 
+    /** Checkpoint ID mapped to Materialization ID - used to notify nested backend of completion. */
+    private final NavigableMap<Long, Long> materializationIdByCheckpointId = new TreeMap<>();
+    /**
+     * Materialization ID mapped to Checkpoint IDs - used to notify nested backend of abortion.
+     * Entry is removed when:
+     *
+     * <ol>
+     *   <li>some checkpoint of a Set completes (in which case {@link #keyedStateBackend} is {@link
+     *       CheckpointListener#notifyCheckpointComplete(long) notified of completion}).
+     *   <li>a newer checkpoint completes
+     *   <li>all checkpoints of a Set are aborted (in which case {@link #keyedStateBackend} is
+     *       {@link CheckpointListener#notifyCheckpointAborted(long) notified of abortion}).
+     * </ol>
+     */
+    private final Map<Long, Set<Long>> pendingMaterializationConfirmations = new HashMap<>();
+
+    private long lastConfirmedMaterializationId = -1L;
+
     public ChangelogKeyedStateBackend(
             AbstractKeyedStateBackend<K> keyedStateBackend,
             String subtaskName,
             ExecutionConfig executionConfig,
             TtlTimeProvider ttlTimeProvider,
-            StateChangelogWriter<ChangelogStateHandle> stateChangelogWriter,
+            StateChangelogWriter<? extends ChangelogStateHandle> stateChangelogWriter,
             Collection<ChangelogStateBackendHandle> initialState,
             CheckpointStorageWorkerView checkpointStorageWorkerView) {
         this.keyedStateBackend = keyedStateBackend;
@@ -378,6 +400,16 @@ public class ChangelogKeyedStateBackend<K>
 
         ChangelogSnapshotState changelogStateBackendStateCopy = changelogSnapshotState;
 
+        if (changelogStateBackendStateCopy.materializationID > lastConfirmedMaterializationId) {
+            materializationIdByCheckpointId.put(
+                    checkpointId, changelogStateBackendStateCopy.materializationID);
+            pendingMaterializationConfirmations
+                    .computeIfAbsent(
+                            changelogStateBackendStateCopy.materializationID,
+                            ign -> new HashSet<>())
+                    .add(checkpointId);
+        }
+
         return toRunnableFuture(
                 stateChangelogWriter
                         .persist(lastUploadedFrom)
@@ -405,7 +437,8 @@ public class ChangelogKeyedStateBackend<K>
                     new ChangelogStateBackendHandleImpl(
                             changelogStateBackendStateCopy.getMaterializedSnapshot(),
                             prevDeltaCopy,
-                            getKeyGroupRange()));
+                            getKeyGroupRange(),
+                            changelogStateBackendStateCopy.materializationID));
         }
     }
 
@@ -466,7 +499,20 @@ public class ChangelogKeyedStateBackend<K>
             // This might change if the log ownership changes (the method won't likely be needed).
             stateChangelogWriter.confirm(lastUploadedFrom, lastUploadedTo);
         }
-        keyedStateBackend.notifyCheckpointComplete(checkpointId);
+        Long materializationID = materializationIdByCheckpointId.remove(checkpointId);
+        if (materializationID != null) {
+            if (materializationID > lastConfirmedMaterializationId) {
+                keyedStateBackend.notifyCheckpointComplete(materializationID);
+                lastConfirmedMaterializationId = materializationID;
+            }
+            pendingMaterializationConfirmations.remove(materializationID);
+        }
+        // there is a chance that nested backend will miss the abort notification
+        // but there is no other way to cleanup this map
+        Map<Long, Long> olderCheckpoints =
+                materializationIdByCheckpointId.headMap(checkpointId, true);
+        olderCheckpoints.values().forEach(pendingMaterializationConfirmations::remove);
+        olderCheckpoints.clear();
     }
 
     @Override
@@ -478,7 +524,22 @@ public class ChangelogKeyedStateBackend<K>
             // This might change if the log ownership changes (the method won't likely be needed).
             stateChangelogWriter.reset(lastUploadedFrom, lastUploadedTo);
         }
-        keyedStateBackend.notifyCheckpointAborted(checkpointId);
+        Long materializationID = materializationIdByCheckpointId.remove(checkpointId);
+        if (materializationID != null) {
+            Set<Long> checkpoints = pendingMaterializationConfirmations.get(materializationID);
+            checkpoints.remove(checkpointId);
+            if (checkpoints.isEmpty()) {
+                if (materializationID < changelogSnapshotState.materializationID) {
+                    // Notification is not strictly required and will arrive only after the nested
+                    // snapshot has completed. It's also unlikely to be sent because of the
+                    // difference in checkpoint/materialization intervals. But it can still be
+                    // useful
+                    // for some backends.
+                    keyedStateBackend.notifyCheckpointAborted(materializationID);
+                }
+                pendingMaterializationConfirmations.remove(materializationID);
+            }
+        }
     }
 
     // -------- Methods not simply delegating to wrapped state backend ---------
@@ -572,6 +633,7 @@ public class ChangelogKeyedStateBackend<K>
 
     private ChangelogSnapshotState completeRestore(
             Collection<ChangelogStateBackendHandle> stateHandles) {
+        long materializationId = 0L;
 
         List<KeyedStateHandle> materialized = new ArrayList<>();
         List<ChangelogStateHandle> restoredNonMaterialized = new ArrayList<>();
@@ -580,13 +642,17 @@ public class ChangelogKeyedStateBackend<K>
             if (h != null) {
                 materialized.addAll(h.getMaterializedStateHandles());
                 restoredNonMaterialized.addAll(h.getNonMaterializedStateHandles());
+                // choose max materializationID to handle rescaling
+                materializationId = Math.max(materializationId, h.getMaterializationID());
             }
         }
+        this.materializedId = materializationId + 1;
 
         return new ChangelogSnapshotState(
                 materialized,
                 restoredNonMaterialized,
-                stateChangelogWriter.initialSequenceNumber());
+                stateChangelogWriter.initialSequenceNumber(),
+                materializationId);
     }
 
     /**
@@ -611,19 +677,20 @@ public class ChangelogKeyedStateBackend<K>
 
             LOG.info("Starting materialization from {} : {}", lastMaterializedTo, upTo);
 
+            // This ID is not needed for materialization itself, but we are re-using the
+            // streamFactory that is designed for state backend snapshots, which requires a
+            // unique checkpoint ID; so a fake materialization ID is provided here.
+            long materializationID = materializedId++;
+
             MaterializationRunnable materializationRunnable =
                     new MaterializationRunnable(
                             keyedStateBackend.snapshot(
-                                    // This ID is not needed for materialization;
-                                    // But since we are re-using the streamFactory
-                                    // that is designed for state backend snapshot,
-                                    // which requires unique checkpoint ID.
-                                    // A faked materialized Id is provided here.
-                                    // TODO: implement its own streamFactory.
-                                    materializedId++,
+                                    materializationID,
                                     System.currentTimeMillis(),
+                                    // TODO: implement its own streamFactory.
                                     streamFactory,
                                     CHECKPOINT_OPTIONS),
+                            materializationID,
                             upTo);
 
             // log metadata after materialization is triggered
@@ -660,7 +727,10 @@ public class ChangelogKeyedStateBackend<K>
      * mailbox executor.
      */
     public void updateChangelogSnapshotState(
-            SnapshotResult<KeyedStateHandle> materializedSnapshot, SequenceNumber upTo) {
+            SnapshotResult<KeyedStateHandle> materializedSnapshot,
+            long materializationID,
+            SequenceNumber upTo)
+            throws Exception {
 
         LOG.info(
                 "Task {} finishes materialization, updates the snapshotState upTo {} : {}",
@@ -669,7 +739,10 @@ public class ChangelogKeyedStateBackend<K>
                 materializedSnapshot);
         changelogSnapshotState =
                 new ChangelogSnapshotState(
-                        getMaterializedResult(materializedSnapshot), Collections.emptyList(), upTo);
+                        getMaterializedResult(materializedSnapshot),
+                        Collections.emptyList(),
+                        upTo,
+                        materializationID);
 
         stateChangelogWriter.truncate(upTo);
     }
@@ -788,13 +861,18 @@ public class ChangelogKeyedStateBackend<K>
          */
         private final List<ChangelogStateHandle> restoredNonMaterialized;
 
+        /** ID of this materialization corresponding to the nested backend checkpoint ID. */
+        private final long materializationID;
+
         public ChangelogSnapshotState(
                 List<KeyedStateHandle> materializedSnapshot,
                 List<ChangelogStateHandle> restoredNonMaterialized,
-                SequenceNumber materializedTo) {
+                SequenceNumber materializedTo,
+                long materializationID) {
             this.materializedSnapshot = unmodifiableList((materializedSnapshot));
             this.restoredNonMaterialized = unmodifiableList(restoredNonMaterialized);
             this.materializedTo = materializedTo;
+            this.materializationID = materializationID;
         }
 
         public List<KeyedStateHandle> getMaterializedSnapshot() {
@@ -808,10 +886,14 @@ public class ChangelogKeyedStateBackend<K>
         public List<ChangelogStateHandle> getRestoredNonMaterialized() {
             return restoredNonMaterialized;
         }
+
+        public long getMaterializationID() {
+            return materializationID;
+        }
     }
 
     @VisibleForTesting
-    StateChangelogWriter<ChangelogStateHandle> getChangelogWriter() {
+    StateChangelogWriter<? extends ChangelogStateHandle> getChangelogWriter() {
         return stateChangelogWriter;
     }
 }
diff --git a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogStateBackend.java b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogStateBackend.java
index 4a87952..893ac48 100644
--- a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogStateBackend.java
+++ b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/ChangelogStateBackend.java
@@ -29,6 +29,7 @@ import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
+import org.apache.flink.runtime.state.CheckpointBoundKeyedStateHandle;
 import org.apache.flink.runtime.state.CheckpointableKeyedStateBackend;
 import org.apache.flink.runtime.state.ConfigurableStateBackend;
 import org.apache.flink.runtime.state.KeyGroupRange;
@@ -272,7 +273,16 @@ public class ChangelogStateBackend implements DelegatingStateBackend, Configurab
                                         : new ChangelogStateBackendHandleImpl(
                                                 singletonList(keyedStateHandle),
                                                 emptyList(),
-                                                keyedStateHandle.getKeyGroupRange()))
+                                                keyedStateHandle.getKeyGroupRange(),
+                                                getMaterializationID(keyedStateHandle)))
                 .collect(Collectors.toList());
     }
+
+    private long getMaterializationID(KeyedStateHandle keyedStateHandle) {
+        if (keyedStateHandle instanceof CheckpointBoundKeyedStateHandle) {
+            return ((CheckpointBoundKeyedStateHandle) keyedStateHandle).getCheckpointId();
+        } else {
+            return 0L;
+        }
+    }
 }
diff --git a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/PeriodicMaterializationManager.java b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/PeriodicMaterializationManager.java
index 1cbdfd2..9f13112 100644
--- a/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/PeriodicMaterializationManager.java
+++ b/flink-state-backends/flink-statebackend-changelog/src/main/java/org/apache/flink/state/changelog/PeriodicMaterializationManager.java
@@ -131,6 +131,7 @@ class PeriodicMaterializationManager implements Closeable {
                                 () ->
                                         asyncMaterializationPhase(
                                                 runnable.getMaterializationRunnable(),
+                                                runnable.getMaterializationID(),
                                                 runnable.getMaterializedTo()));
                     } else {
                         scheduleNextMaterialization();
@@ -147,6 +148,7 @@ class PeriodicMaterializationManager implements Closeable {
 
     private void asyncMaterializationPhase(
             RunnableFuture<SnapshotResult<KeyedStateHandle>> materializedRunnableFuture,
+            long materializationID,
             SequenceNumber upTo) {
 
         uploadSnapshot(materializedRunnableFuture)
@@ -159,7 +161,7 @@ class PeriodicMaterializationManager implements Closeable {
                                 mailboxExecutor.execute(
                                         () ->
                                                 keyedStateBackend.updateChangelogSnapshotState(
-                                                        snapshotResult, upTo),
+                                                        snapshotResult, materializationID, upTo),
                                         "Task {} update materializedSnapshot up to changelog sequence number: {}",
                                         subtaskName,
                                         upTo);
@@ -255,6 +257,8 @@ class PeriodicMaterializationManager implements Closeable {
     static class MaterializationRunnable {
         private final RunnableFuture<SnapshotResult<KeyedStateHandle>> materializationRunnable;
 
+        private final long materializationID;
+
         /**
          * The {@link SequenceNumber} up to which the state is materialized, exclusive. This
          * indicates the non-materialized part of the current changelog.
@@ -263,9 +267,11 @@ class PeriodicMaterializationManager implements Closeable {
 
         public MaterializationRunnable(
                 RunnableFuture<SnapshotResult<KeyedStateHandle>> materializationRunnable,
+                long materializationID,
                 SequenceNumber materializedTo) {
             this.materializationRunnable = materializationRunnable;
             this.materializedTo = materializedTo;
+            this.materializationID = materializationID;
         }
 
         RunnableFuture<SnapshotResult<KeyedStateHandle>> getMaterializationRunnable() {
@@ -275,5 +281,9 @@ class PeriodicMaterializationManager implements Closeable {
         SequenceNumber getMaterializedTo() {
             return materializedTo;
         }
+
+        public long getMaterializationID() {
+            return materializationID;
+        }
     }
 }
diff --git a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackendTest.java b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackendTest.java
new file mode 100644
index 0000000..c0abc6e
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogKeyedStateBackendTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.changelog;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.query.KvStateRegistry;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.UncompressedStreamCompressionDecorator;
+import org.apache.flink.runtime.state.changelog.SequenceNumber;
+import org.apache.flink.runtime.state.changelog.inmemory.InMemoryStateChangelogStorage;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import org.apache.flink.runtime.state.metrics.LatencyTrackingStateConfig;
+import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
+import org.apache.flink.runtime.state.ttl.mock.MockKeyedStateBackend;
+import org.apache.flink.runtime.state.ttl.mock.MockKeyedStateBackendBuilder;
+import org.apache.flink.state.changelog.ChangelogStateBackendTestUtils.DummyCheckpointingStorageAccess;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+
+import java.util.concurrent.RunnableFuture;
+
+import static java.util.Collections.emptyList;
+import static org.junit.Assert.assertEquals;
+
+/** {@link ChangelogKeyedStateBackend} test. */
+@RunWith(Parameterized.class)
+public class ChangelogKeyedStateBackendTest {
+
+    @Parameterized.Parameters(name = "checkpointID={0}, materializationId={1}")
+    public static Object[][] parameters() {
+        return new Object[][] {
+            {0L, 200L},
+            {200L, 0L},
+        };
+    }
+
+    @Parameter(0)
+    public long checkpointId;
+
+    @Parameter(1)
+    public long materializationId;
+
+    @Test
+    public void testCheckpointConfirmation() throws Exception {
+        MockKeyedStateBackend<Integer> mock = createMock();
+        ChangelogKeyedStateBackend<Integer> changelog = createChangelog(mock);
+        try {
+            changelog.updateChangelogSnapshotState(
+                    SnapshotResult.empty(), materializationId, SequenceNumber.of(Long.MAX_VALUE));
+            checkpoint(changelog, checkpointId).get().discardState();
+
+            changelog.notifyCheckpointComplete(checkpointId);
+            assertEquals(materializationId, mock.getLastCompletedCheckpointID());
+
+        } finally {
+            changelog.close();
+            changelog.dispose();
+        }
+    }
+
+    private MockKeyedStateBackend<Integer> createMock() {
+        return new MockKeyedStateBackendBuilder<>(
+                        new KvStateRegistry().createTaskRegistry(new JobID(), new JobVertexID()),
+                        IntSerializer.INSTANCE,
+                        getClass().getClassLoader(),
+                        1,
+                        KeyGroupRange.EMPTY_KEY_GROUP_RANGE,
+                        new ExecutionConfig(),
+                        TtlTimeProvider.DEFAULT,
+                        LatencyTrackingStateConfig.disabled(),
+                        emptyList(),
+                        UncompressedStreamCompressionDecorator.INSTANCE,
+                        new CloseableRegistry(),
+                        true)
+                .build();
+    }
+
+    private ChangelogKeyedStateBackend<Integer> createChangelog(
+            MockKeyedStateBackend<Integer> mock) {
+        return new ChangelogKeyedStateBackend<>(
+                mock,
+                "test",
+                new ExecutionConfig(),
+                TtlTimeProvider.DEFAULT,
+                new InMemoryStateChangelogStorage()
+                        .createWriter("test", KeyGroupRange.EMPTY_KEY_GROUP_RANGE),
+                emptyList(),
+                new DummyCheckpointingStorageAccess());
+    }
+
+    private RunnableFuture<SnapshotResult<KeyedStateHandle>> checkpoint(
+            ChangelogKeyedStateBackend<Integer> backend, long checkpointId) throws Exception {
+        return backend.snapshot(
+                checkpointId,
+                0L,
+                new MemCheckpointStreamFactory(1000),
+                CheckpointOptions.forCheckpointWithDefaultLocation());
+    }
+}
diff --git a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
index bed4773..c3710dd 100644
--- a/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
+++ b/flink-state-backends/flink-statebackend-changelog/src/test/java/org/apache/flink/state/changelog/ChangelogStateBackendTestUtils.java
@@ -248,7 +248,8 @@ public class ChangelogStateBackendTestUtils {
     private static void materialize(
             ChangelogKeyedStateBackend<Integer> keyedBackend,
             PeriodicMaterializationManager periodicMaterializationManager) {
-        StateChangelogWriter<ChangelogStateHandle> writer = keyedBackend.getChangelogWriter();
+        StateChangelogWriter<? extends ChangelogStateHandle> writer =
+                keyedBackend.getChangelogWriter();
         SequenceNumber sqnBefore = writer.lastAppendedSequenceNumber();
         periodicMaterializationManager.triggerMaterialization();
         assertTrue(
@@ -356,7 +357,8 @@ public class ChangelogStateBackendTestUtils {
                 1);
     }
 
-    static class DummyCheckpointingStorageAccess implements CheckpointStorageAccess {
+    /** Dummy {@link CheckpointStorageAccess}. */
+    public static class DummyCheckpointingStorageAccess implements CheckpointStorageAccess {
 
         DummyCheckpointingStorageAccess() {}