Posted to commits@flink.apache.org by ta...@apache.org on 2022/03/21 03:11:06 UTC

[flink] branch master updated (1fa91ba -> 1bf45b2)

This is an automated email from the ASF dual-hosted git repository.

tangyun pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 1fa91ba  [FLINK-26151]Avoid inprogressfileRecoverable not be clean up after restoring the bucket
     new b6822f2  [FLINK-21321]: change RocksDB rescale to use deleteRange
     new 1bf45b2  [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../state/RocksDBIncrementalCheckpointUtils.java   |  37 +--
 .../RocksDBIncrementalRestoreOperation.java        |   3 +-
 .../RocksDBIncrementalCheckpointUtilsTest.java     |   3 +-
 .../RescaleCheckpointManuallyITCase.java           | 286 +++++++++++++++++++++
 .../flink/test/checkpointing/RescalingITCase.java  | 136 +---------
 .../ResumeCheckpointManuallyITCase.java            |  65 +----
 .../checkpointing/utils/RescalingTestUtils.java    | 162 ++++++++++++
 .../java/org/apache/flink/test/util/TestUtils.java |  22 ++
 8 files changed, 490 insertions(+), 224 deletions(-)
 create mode 100644 flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
 create mode 100644 flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java

[flink] 01/02: [FLINK-21321]: change RocksDB rescale to use deleteRange

Posted by ta...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

tangyun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b6822f293e24263752419e49f8b7910f4e0464a8
Author: Joey Pereira <jo...@pereira.io>
AuthorDate: Tue May 26 08:26:58 2020 -0400

    [FLINK-21321]: change RocksDB rescale to use deleteRange
    
    Previously, the Flink incremental checkpoint restore operation would
    scan and delete individual keys during recovery when rescaling. This
    was done to truncate the ranges of the checkpoint that are no longer
    part of the key range assigned to the worker.

    This scan-and-delete step is now replaced with RocksDB's deleteRange
    operation, which can remove the data cheaply by writing range
    tombstones.
    
    The RocksDB API for DeleteRange is here,
    https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L373
    
    Tombstones are described in further detail here,
    https://rocksdb.org/blog/2018/11/21/delete-range.html
    
    Additionally, this adds a benchmark test, based on
    RocksIncrementalCheckpointRescalingTest, that exercises the modified
    rescaling code.
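    
    For reference, the essence of the change is the RocksJava call sketched
    below. This is a minimal standalone sketch, not code from this commit:
    the database path and the two-byte key-group prefixes are made-up
    placeholders (Flink serializes the real prefixes via
    CompositeKeySerializationUtils).
    
        import org.rocksdb.Options;
        import org.rocksdb.RocksDB;
        import org.rocksdb.RocksDBException;
    
        public class DeleteRangeSketch {
            public static void main(String[] args) throws RocksDBException {
                RocksDB.loadLibrary();
                try (Options options = new Options().setCreateIfMissing(true);
                        RocksDB db = RocksDB.open(options, "/tmp/delete-range-sketch")) {
                    // Hypothetical serialized key-group bounds for the range
                    // that is no longer assigned to this subtask.
                    byte[] beginKeyGroupBytes = new byte[] {0, 0};
                    byte[] endKeyGroupBytes = new byte[] {0, 5};
                    // A single range tombstone covers [begin, end) instead of
                    // issuing one delete per key, so clipping the DB is cheap.
                    db.deleteRange(
                            db.getDefaultColumnFamily(),
                            beginKeyGroupBytes,
                            endKeyGroupBytes);
                }
            }
        }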
---
 .../state/RocksDBIncrementalCheckpointUtils.java   |  37 +---
 .../RocksDBIncrementalRestoreOperation.java        |   3 +-
 .../RocksDBIncrementalCheckpointUtilsTest.java     |   3 +-
 ...ncrementalCheckpointRescalingBenchmarkTest.java | 240 +++++++++++++++++++++
 4 files changed, 251 insertions(+), 32 deletions(-)

diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
index 4f73ff8..467156c 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
@@ -22,7 +22,6 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 
 import org.rocksdb.ColumnFamilyHandle;
-import org.rocksdb.ReadOptions;
 import org.rocksdb.RocksDB;
 import org.rocksdb.RocksDBException;
 
@@ -108,8 +107,7 @@ public class RocksDBIncrementalCheckpointUtils {
             @Nonnull List<ColumnFamilyHandle> columnFamilyHandles,
             @Nonnull KeyGroupRange targetKeyGroupRange,
             @Nonnull KeyGroupRange currentKeyGroupRange,
-            @Nonnegative int keyGroupPrefixBytes,
-            @Nonnegative long writeBatchSize)
+            @Nonnegative int keyGroupPrefixBytes)
             throws RocksDBException {
 
         final byte[] beginKeyGroupBytes = new byte[keyGroupPrefixBytes];
@@ -120,8 +118,7 @@ public class RocksDBIncrementalCheckpointUtils {
                     currentKeyGroupRange.getStartKeyGroup(), beginKeyGroupBytes);
             CompositeKeySerializationUtils.serializeKeyGroup(
                     targetKeyGroupRange.getStartKeyGroup(), endKeyGroupBytes);
-            deleteRange(
-                    db, columnFamilyHandles, beginKeyGroupBytes, endKeyGroupBytes, writeBatchSize);
+            deleteRange(db, columnFamilyHandles, beginKeyGroupBytes, endKeyGroupBytes);
         }
 
         if (currentKeyGroupRange.getEndKeyGroup() > targetKeyGroupRange.getEndKeyGroup()) {
@@ -129,8 +126,7 @@ public class RocksDBIncrementalCheckpointUtils {
                     targetKeyGroupRange.getEndKeyGroup() + 1, beginKeyGroupBytes);
             CompositeKeySerializationUtils.serializeKeyGroup(
                     currentKeyGroupRange.getEndKeyGroup() + 1, endKeyGroupBytes);
-            deleteRange(
-                    db, columnFamilyHandles, beginKeyGroupBytes, endKeyGroupBytes, writeBatchSize);
+            deleteRange(db, columnFamilyHandles, beginKeyGroupBytes, endKeyGroupBytes);
         }
     }
 
@@ -146,30 +142,15 @@ public class RocksDBIncrementalCheckpointUtils {
             RocksDB db,
             List<ColumnFamilyHandle> columnFamilyHandles,
             byte[] beginKeyBytes,
-            byte[] endKeyBytes,
-            @Nonnegative long writeBatchSize)
+            byte[] endKeyBytes)
             throws RocksDBException {
 
         for (ColumnFamilyHandle columnFamilyHandle : columnFamilyHandles) {
-            try (ReadOptions readOptions = new ReadOptions();
-                    RocksIteratorWrapper iteratorWrapper =
-                            RocksDBOperationUtils.getRocksIterator(
-                                    db, columnFamilyHandle, readOptions);
-                    RocksDBWriteBatchWrapper writeBatchWrapper =
-                            new RocksDBWriteBatchWrapper(db, writeBatchSize)) {
-
-                iteratorWrapper.seek(beginKeyBytes);
-
-                while (iteratorWrapper.isValid()) {
-                    final byte[] currentKey = iteratorWrapper.key();
-                    if (beforeThePrefixBytes(currentKey, endKeyBytes)) {
-                        writeBatchWrapper.remove(columnFamilyHandle, currentKey);
-                    } else {
-                        break;
-                    }
-                    iteratorWrapper.next();
-                }
-            }
+            // Using RocksDB's deleteRange will take advantage of delete
+            // tombstones, which mark the range as deleted.
+            //
+            // https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L371
+            db.deleteRange(columnFamilyHandle, beginKeyBytes, endKeyBytes);
         }
     }
 
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
index 657b6f2..d6ec9ae6 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java
@@ -390,8 +390,7 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper
                     this.rocksHandle.getColumnFamilyHandles(),
                     keyGroupRange,
                     initialHandle.getKeyGroupRange(),
-                    keyGroupPrefixBytes,
-                    writeBatchSize);
+                    keyGroupPrefixBytes);
         } catch (RocksDBException e) {
             String errMsg = "Failed to clip DB after initialization.";
             logger.error(errMsg, e);
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtilsTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtilsTest.java
index ff6d854..2de72a7 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtilsTest.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtilsTest.java
@@ -185,8 +185,7 @@ public class RocksDBIncrementalCheckpointUtilsTest extends TestLogger {
                     Collections.singletonList(columnFamilyHandle),
                     targetGroupRange,
                     currentGroupRange,
-                    keyGroupPrefixBytes,
-                    RocksDBConfigurableOptions.WRITE_BATCH_SIZE.defaultValue().getBytes());
+                    keyGroupPrefixBytes);
 
             for (int i = currentGroupRangeStart; i <= currentGroupRangeEnd; ++i) {
                 for (int j = 0; j < 100; ++j) {
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
new file mode 100644
index 0000000..a4e267b
--- /dev/null
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.contrib.streaming.state;
+
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
+import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
+import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
+import org.apache.flink.testutils.junit.RetryOnFailure;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Collection;
+import java.util.List;
+
+/** Test runs the benchmark for incremental checkpoint rescaling. */
+public class RocksIncrementalCheckpointRescalingBenchmarkTest extends TestLogger {
+
+    @Rule public TemporaryFolder rootFolder = new TemporaryFolder();
+
+    private static final int maxParallelism = 10;
+
+    private static final int recordCount = 1_000;
+
+    /** partitionParallelism is the parallelism to use for creating the partitionedSnapshot. */
+    private static final int partitionParallelism = 2;
+
+    /**
+     * repartitionParallelism is the parallelism to use during the test for the repartition step.
+     *
+     * <p>NOTE: To trigger {@link
+     * org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation#restoreWithRescaling(Collection)},
+     * where the improvement code is exercised, the target parallelism must not be divisible by
+     * {@link partitionParallelism}. If this parallelism was instead 4, then there is no rescale.
+     */
+    private static final int repartitionParallelism = 3;
+
+    /** partitionedSnapshot is a partitioned incremental RocksDB snapshot. */
+    private OperatorSubtaskState partitionedSnapshot;
+
+    private KeySelector<Integer, Integer> keySelector = new TestKeySelector();
+
+    /**
+     * The benchmark's preparation will:
+     *
+     * <ol>
+     *   <li>Create a stateful operator and process records to persist state.
+     *   <li>Snapshot the state and re-partition it so the test operates on a partitioned state.
+     * </ol>
+     *
+     * @throws Exception
+     */
+    @Before
+    public void before() throws Exception {
+        OperatorSubtaskState snapshot;
+        // Initialize the test harness with a task parallelism of 1.
+        try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> harness =
+                getHarnessTest(keySelector, maxParallelism, 1, 0)) {
+            // Set the state backend of the harness to RocksDB.
+            harness.setStateBackend(getStateBackend());
+            // Initialize the harness.
+            harness.open();
+            // Push the test records into the operator to trigger state updates.
+            Integer[] records = new Integer[recordCount];
+            for (int i = 0; i < recordCount; i++) {
+                harness.processElement(new StreamRecord<>(i, 1));
+            }
+            // Grab a snapshot of the state.
+            snapshot = harness.snapshot(0, 0);
+        }
+
+        // Now, re-partition to create a partitioned state.
+        KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer>[] partitionedTestHarness =
+                new KeyedOneInputStreamOperatorTestHarness[partitionParallelism];
+        List<KeyGroupRange> keyGroupPartitions =
+                StateAssignmentOperation.createKeyGroupPartitions(
+                        maxParallelism, partitionParallelism);
+        try {
+            for (int i = 0; i < partitionParallelism; i++) {
+                // Initialize, open, and then re-snapshot the two subtasks to create a partitioned
+                // incremental RocksDB snapshot.
+                OperatorSubtaskState subtaskState =
+                        AbstractStreamOperatorTestHarness.repartitionOperatorState(
+                                snapshot, maxParallelism, 1, partitionParallelism, i);
+                KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
+
+                partitionedTestHarness[i] =
+                        getHarnessTest(keySelector, maxParallelism, partitionParallelism, i);
+                partitionedTestHarness[i].setStateBackend(getStateBackend());
+                partitionedTestHarness[i].setup();
+                partitionedTestHarness[i].initializeState(subtaskState);
+                partitionedTestHarness[i].open();
+            }
+
+            partitionedSnapshot =
+                    AbstractStreamOperatorTestHarness.repackageState(
+                            partitionedTestHarness[0].snapshot(1, 2),
+                            partitionedTestHarness[1].snapshot(1, 2));
+
+        } finally {
+            closeHarness(partitionedTestHarness);
+        }
+    }
+
+    @Test(timeout = 1000)
+    @RetryOnFailure(times = 3)
+    public void benchmarkScalingUp() throws Exception {
+        long benchmarkTime = 0;
+
+        // Trigger the incremental re-scaling via restoreWithRescaling by repartitioning it from
+        // parallelism of >1 to a higher parallelism. Time spent during this step includes the cost
+        // of incremental rescaling.
+        List<KeyGroupRange> keyGroupPartitions =
+                StateAssignmentOperation.createKeyGroupPartitions(
+                        maxParallelism, repartitionParallelism);
+
+        long fullStateSize = partitionedSnapshot.getStateSize();
+
+        for (int i = 0; i < repartitionParallelism; i++) {
+            OperatorSubtaskState subtaskState =
+                    AbstractStreamOperatorTestHarness.repartitionOperatorState(
+                            partitionedSnapshot,
+                            maxParallelism,
+                            partitionParallelism,
+                            repartitionParallelism,
+                            i);
+            KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
+
+            try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> subtaskHarness =
+                    getHarnessTest(keySelector, maxParallelism, repartitionParallelism, i)) {
+                RocksDBStateBackend backend = getStateBackend();
+                subtaskHarness.setStateBackend(backend);
+                subtaskHarness.setup();
+
+                // Precisely measure the call-site that triggers the restore operation.
+                long startingTime = System.nanoTime();
+                subtaskHarness.initializeState(subtaskState);
+                benchmarkTime += System.nanoTime() - startingTime;
+            }
+        }
+
+        log.error(
+                "--------------> performance for incremental checkpoint re-scaling <--------------");
+        log.error(
+                "rescale from {} to {} with {} records took: {} nanoseconds",
+                partitionParallelism,
+                repartitionParallelism,
+                recordCount,
+                benchmarkTime);
+    }
+
+    private KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> getHarnessTest(
+            KeySelector<Integer, Integer> keySelector,
+            int maxParallelism,
+            int taskParallelism,
+            int subtaskIdx)
+            throws Exception {
+        return new KeyedOneInputStreamOperatorTestHarness<>(
+                new KeyedProcessOperator<>(new TestKeyedFunction()),
+                keySelector,
+                BasicTypeInfo.INT_TYPE_INFO,
+                maxParallelism,
+                taskParallelism,
+                subtaskIdx);
+    }
+
+    private void closeHarness(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnessArr)
+            throws Exception {
+        for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnessArr) {
+            if (harness != null) {
+                harness.close();
+            }
+        }
+    }
+
+    private RocksDBStateBackend getStateBackend() throws Exception {
+        return new RocksDBStateBackend("file://" + rootFolder.newFolder().getAbsolutePath(), true);
+    }
+
+    /** A simple keyed function for tests. */
+    private class TestKeyedFunction extends KeyedProcessFunction<Integer, Integer, Integer> {
+
+        public ValueStateDescriptor<Integer> stateDescriptor;
+        private ValueState<Integer> counterState;
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+            stateDescriptor = new ValueStateDescriptor<Integer>("counter", Integer.class);
+            counterState = this.getRuntimeContext().getState(stateDescriptor);
+        }
+
+        @Override
+        public void processElement(Integer incomingValue, Context ctx, Collector<Integer> out)
+                throws Exception {
+            Integer oldValue = counterState.value();
+            Integer newValue = oldValue != null ? oldValue + incomingValue : incomingValue;
+            counterState.update(newValue);
+            out.collect(newValue);
+        }
+    }
+
+    /** A simple key selector for tests. */
+    private class TestKeySelector implements KeySelector<Integer, Integer> {
+        @Override
+        public Integer getKey(Integer value) throws Exception {
+            return value;
+        }
+    }
+}

[flink] 02/02: [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint

Posted by ta...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

tangyun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1bf45b25791cc3fad8b7d0d863caa9b0eef9a87b
Author: fredia <fr...@gmail.com>
AuthorDate: Thu Mar 10 11:33:59 2022 +0800

    [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint
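    
    For orientation, the restore path these ITCases exercise boils down to
    retaining an externalized checkpoint and resubmitting the same graph at a
    different parallelism. The sketch below is illustrative only, not the
    committed test code: the pipeline, checkpoint interval, and class name
    are placeholders.
    
        import org.apache.flink.runtime.jobgraph.JobGraph;
        import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
        import org.apache.flink.streaming.api.environment.CheckpointConfig;
        import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
    
        public class RescaleFromCheckpointSketch {
            // Builds a keyed job graph at the requested parallelism; when
            // checkpointPath is non-null, the job restores from it on startup.
            static JobGraph buildGraph(
                    int parallelism, int maxParallelism, String checkpointPath) {
                StreamExecutionEnvironment env =
                        StreamExecutionEnvironment.getExecutionEnvironment();
                env.setParallelism(parallelism);
                env.getConfig().setMaxParallelism(maxParallelism);
                env.enableCheckpointing(100);
                env.getCheckpointConfig()
                        .setExternalizedCheckpointCleanup(
                                CheckpointConfig.ExternalizedCheckpointCleanup
                                        .RETAIN_ON_CANCELLATION);
                // Placeholder keyed pipeline standing in for the test topology.
                env.fromElements(1, 2, 3).keyBy(value -> value).print();
                JobGraph graph = env.getStreamGraph().getJobGraph();
                if (checkpointPath != null) {
                    graph.setSavepointRestoreSettings(
                            SavepointRestoreSettings.forPath(checkpointPath));
                }
                return graph;
            }
        }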
---
 .../state/RocksDBIncrementalCheckpointUtils.java   |   2 +-
 ...ncrementalCheckpointRescalingBenchmarkTest.java | 240 -----------------
 .../RescaleCheckpointManuallyITCase.java           | 286 +++++++++++++++++++++
 .../flink/test/checkpointing/RescalingITCase.java  | 136 +---------
 .../ResumeCheckpointManuallyITCase.java            |  65 +----
 .../checkpointing/utils/RescalingTestUtils.java    | 162 ++++++++++++
 .../java/org/apache/flink/test/util/TestUtils.java |  22 ++
 7 files changed, 480 insertions(+), 433 deletions(-)

diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
index 467156c..23c7867 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
@@ -149,7 +149,7 @@ public class RocksDBIncrementalCheckpointUtils {
             // Using RocksDB's deleteRange will take advantage of delete
             // tombstones, which mark the range as deleted.
             //
-            // https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L371
+            // https://github.com/ververica/frocksdb/blob/FRocksDB-6.20.3/include/rocksdb/db.h#L363-L377
             db.deleteRange(columnFamilyHandle, beginKeyBytes, endKeyBytes);
         }
     }
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
deleted file mode 100644
index a4e267b..0000000
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.contrib.streaming.state;
-
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
-import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
-import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
-import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
-import org.apache.flink.testutils.junit.RetryOnFailure;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.TestLogger;
-
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-
-import java.util.Collection;
-import java.util.List;
-
-/** Test runs the benchmark for incremental checkpoint rescaling. */
-public class RocksIncrementalCheckpointRescalingBenchmarkTest extends TestLogger {
-
-    @Rule public TemporaryFolder rootFolder = new TemporaryFolder();
-
-    private static final int maxParallelism = 10;
-
-    private static final int recordCount = 1_000;
-
-    /** partitionParallelism is the parallelism to use for creating the partitionedSnapshot. */
-    private static final int partitionParallelism = 2;
-
-    /**
-     * repartitionParallelism is the parallelism to use during the test for the repartition step.
-     *
-     * <p>NOTE: To trigger {@link
-     * org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation#restoreWithRescaling(Collection)},
-     * where the improvement code is exercised, the target parallelism must not be divisible by
-     * {@link partitionParallelism}. If this parallelism was instead 4, then there is no rescale.
-     */
-    private static final int repartitionParallelism = 3;
-
-    /** partitionedSnapshot is a partitioned incremental RocksDB snapshot. */
-    private OperatorSubtaskState partitionedSnapshot;
-
-    private KeySelector<Integer, Integer> keySelector = new TestKeySelector();
-
-    /**
-     * The benchmark's preparation will:
-     *
-     * <ol>
-     *   <li>Create a stateful operator and process records to persist state.
-     *   <li>Snapshot the state and re-partition it so the test operates on a partitioned state.
-     * </ol>
-     *
-     * @throws Exception
-     */
-    @Before
-    public void before() throws Exception {
-        OperatorSubtaskState snapshot;
-        // Initialize the test harness with a task parallelism of 1.
-        try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> harness =
-                getHarnessTest(keySelector, maxParallelism, 1, 0)) {
-            // Set the state backend of the harness to RocksDB.
-            harness.setStateBackend(getStateBackend());
-            // Initialize the harness.
-            harness.open();
-            // Push the test records into the operator to trigger state updates.
-            Integer[] records = new Integer[recordCount];
-            for (int i = 0; i < recordCount; i++) {
-                harness.processElement(new StreamRecord<>(i, 1));
-            }
-            // Grab a snapshot of the state.
-            snapshot = harness.snapshot(0, 0);
-        }
-
-        // Now, re-partition to create a partitioned state.
-        KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer>[] partitionedTestHarness =
-                new KeyedOneInputStreamOperatorTestHarness[partitionParallelism];
-        List<KeyGroupRange> keyGroupPartitions =
-                StateAssignmentOperation.createKeyGroupPartitions(
-                        maxParallelism, partitionParallelism);
-        try {
-            for (int i = 0; i < partitionParallelism; i++) {
-                // Initialize, open, and then re-snapshot the two subtasks to create a partitioned
-                // incremental RocksDB snapshot.
-                OperatorSubtaskState subtaskState =
-                        AbstractStreamOperatorTestHarness.repartitionOperatorState(
-                                snapshot, maxParallelism, 1, partitionParallelism, i);
-                KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
-                partitionedTestHarness[i] =
-                        getHarnessTest(keySelector, maxParallelism, partitionParallelism, i);
-                partitionedTestHarness[i].setStateBackend(getStateBackend());
-                partitionedTestHarness[i].setup();
-                partitionedTestHarness[i].initializeState(subtaskState);
-                partitionedTestHarness[i].open();
-            }
-
-            partitionedSnapshot =
-                    AbstractStreamOperatorTestHarness.repackageState(
-                            partitionedTestHarness[0].snapshot(1, 2),
-                            partitionedTestHarness[1].snapshot(1, 2));
-
-        } finally {
-            closeHarness(partitionedTestHarness);
-        }
-    }
-
-    @Test(timeout = 1000)
-    @RetryOnFailure(times = 3)
-    public void benchmarkScalingUp() throws Exception {
-        long benchmarkTime = 0;
-
-        // Trigger the incremental re-scaling via restoreWithRescaling by repartitioning it from
-        // parallelism of >1 to a higher parallelism. Time spent during this step includes the cost
-        // of incremental rescaling.
-        List<KeyGroupRange> keyGroupPartitions =
-                StateAssignmentOperation.createKeyGroupPartitions(
-                        maxParallelism, repartitionParallelism);
-
-        long fullStateSize = partitionedSnapshot.getStateSize();
-
-        for (int i = 0; i < repartitionParallelism; i++) {
-            OperatorSubtaskState subtaskState =
-                    AbstractStreamOperatorTestHarness.repartitionOperatorState(
-                            partitionedSnapshot,
-                            maxParallelism,
-                            partitionParallelism,
-                            repartitionParallelism,
-                            i);
-            KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
-            try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> subtaskHarness =
-                    getHarnessTest(keySelector, maxParallelism, repartitionParallelism, i)) {
-                RocksDBStateBackend backend = getStateBackend();
-                subtaskHarness.setStateBackend(backend);
-                subtaskHarness.setup();
-
-                // Precisely measure the call-site that triggers the restore operation.
-                long startingTime = System.nanoTime();
-                subtaskHarness.initializeState(subtaskState);
-                benchmarkTime += System.nanoTime() - startingTime;
-            }
-        }
-
-        log.error(
-                "--------------> performance for incremental checkpoint re-scaling <--------------");
-        log.error(
-                "rescale from {} to {} with {} records took: {} nanoseconds",
-                partitionParallelism,
-                repartitionParallelism,
-                recordCount,
-                benchmarkTime);
-    }
-
-    private KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> getHarnessTest(
-            KeySelector<Integer, Integer> keySelector,
-            int maxParallelism,
-            int taskParallelism,
-            int subtaskIdx)
-            throws Exception {
-        return new KeyedOneInputStreamOperatorTestHarness<>(
-                new KeyedProcessOperator<>(new TestKeyedFunction()),
-                keySelector,
-                BasicTypeInfo.INT_TYPE_INFO,
-                maxParallelism,
-                taskParallelism,
-                subtaskIdx);
-    }
-
-    private void closeHarness(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnessArr)
-            throws Exception {
-        for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnessArr) {
-            if (harness != null) {
-                harness.close();
-            }
-        }
-    }
-
-    private RocksDBStateBackend getStateBackend() throws Exception {
-        return new RocksDBStateBackend("file://" + rootFolder.newFolder().getAbsolutePath(), true);
-    }
-
-    /** A simple keyed function for tests. */
-    private class TestKeyedFunction extends KeyedProcessFunction<Integer, Integer, Integer> {
-
-        public ValueStateDescriptor<Integer> stateDescriptor;
-        private ValueState<Integer> counterState;
-
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            super.open(parameters);
-            stateDescriptor = new ValueStateDescriptor<Integer>("counter", Integer.class);
-            counterState = this.getRuntimeContext().getState(stateDescriptor);
-        }
-
-        @Override
-        public void processElement(Integer incomingValue, Context ctx, Collector<Integer> out)
-                throws Exception {
-            Integer oldValue = counterState.value();
-            Integer newValue = oldValue != null ? oldValue + incomingValue : incomingValue;
-            counterState.update(newValue);
-            out.collect(newValue);
-        }
-    }
-
-    /** A simple key selector for tests. */
-    private class TestKeySelector implements KeySelector<Integer, Integer> {
-        @Override
-        public Integer getKey(Integer value) throws Exception {
-            return value;
-        }
-    }
-}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
new file mode 100644
index 0000000..8c05afa
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.StateBackendOptions;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.CheckpointConfig;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+
+import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/** Test checkpoint rescaling for incremental rocksdb. */
+public class RescaleCheckpointManuallyITCase extends TestLogger {
+
+    private static final int NUM_TASK_MANAGERS = 2;
+    private static final int SLOTS_PER_TASK_MANAGER = 2;
+
+    private static MiniClusterWithClientResource cluster;
+    private File checkpointDir;
+
+    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    @Before
+    public void setup() throws Exception {
+        Configuration config = new Configuration();
+
+        checkpointDir = temporaryFolder.newFolder();
+
+        config.setString(StateBackendOptions.STATE_BACKEND, "rocksdb");
+        config.setString(
+                CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
+        config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);
+
+        cluster =
+                new MiniClusterWithClientResource(
+                        new MiniClusterResourceConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumberTaskManagers(NUM_TASK_MANAGERS)
+                                .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER)
+                                .build());
+        cluster.before();
+    }
+
+    @Test
+    public void testCheckpointRescalingInKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(false);
+    }
+
+    @Test
+    public void testCheckpointRescalingOutKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(true);
+    }
+
+    /**
+     * Tests that a job with purely keyed state can be restarted from a checkpoint with a different
+     * parallelism.
+     */
+    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
+        final int numberKeys = 42;
+        final int numberElements = 1000;
+        final int numberElements2 = 500;
+        final int parallelism = scaleOut ? 3 : 4;
+        final int parallelism2 = scaleOut ? 4 : 3;
+        final int maxParallelism = 13;
+
+        cluster.before();
+
+        ClusterClient<?> client = cluster.getClusterClient();
+        String checkpointPath =
+                runJobAndGetCheckpoint(
+                        numberKeys,
+                        numberElements,
+                        parallelism,
+                        maxParallelism,
+                        client,
+                        checkpointDir);
+
+        assertNotNull(checkpointPath);
+
+        restoreAndAssert(
+                parallelism2,
+                maxParallelism,
+                maxParallelism,
+                numberKeys,
+                numberElements2,
+                numberElements + numberElements2,
+                client,
+                checkpointPath);
+    }
+
+    private static String runJobAndGetCheckpoint(
+            int numberKeys,
+            int numberElements,
+            int parallelism,
+            int maxParallelism,
+            ClusterClient<?> client,
+            File checkpointDir)
+            throws Exception {
+        try {
+            JobGraph jobGraph =
+                    createJobGraphWithKeyedState(
+                            parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+            NotifyingDefiniteKeySource.sourceLatch = new CountDownLatch(parallelism);
+            client.submitJob(jobGraph).get();
+            NotifyingDefiniteKeySource.sourceLatch.await();
+
+            RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch.await();
+
+            // verify the current state
+            Set<Tuple2<Integer, Integer>> actualResult =
+                    RescalingTestUtils.CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+                expectedResult.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism, keyGroupIndex),
+                                numberElements * key));
+            }
+
+            assertEquals(expectedResult, actualResult);
+            NotifyingDefiniteKeySource.sourceLatch.await();
+
+            TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
+            client.cancel(jobGraph.getJobID()).get();
+            TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);
+            return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
+        } finally {
+            RescalingTestUtils.CollectionSink.clearElementsSet();
+        }
+    }
+
+    private void restoreAndAssert(
+            int restoreParallelism,
+            int restoreMaxParallelism,
+            int maxParallelismBefore,
+            int numberKeys,
+            int numberElements,
+            int numberElementsExpect,
+            ClusterClient<?> client,
+            String restorePath)
+            throws Exception {
+        try {
+
+            JobGraph scaledJobGraph =
+                    createJobGraphWithKeyedState(
+                            restoreParallelism,
+                            restoreMaxParallelism,
+                            numberKeys,
+                            numberElements,
+                            true,
+                            100);
+
+            scaledJobGraph.setSavepointRestoreSettings(
+                    SavepointRestoreSettings.forPath(restorePath));
+
+            submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
+
+            Set<Tuple2<Integer, Integer>> actualResult2 =
+                    RescalingTestUtils.CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex =
+                        KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelismBefore);
+                expectedResult2.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelismBefore, restoreParallelism, keyGroupIndex),
+                                key * numberElementsExpect));
+            }
+            assertEquals(expectedResult2, actualResult2);
+        } finally {
+            RescalingTestUtils.CollectionSink.clearElementsSet();
+        }
+    }
+
+    private static JobGraph createJobGraphWithKeyedState(
+            int parallelism,
+            int maxParallelism,
+            int numberKeys,
+            int numberElements,
+            boolean terminateAfterEmission,
+            int checkpointingInterval) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(parallelism);
+        if (0 < maxParallelism) {
+            env.getConfig().setMaxParallelism(maxParallelism);
+        }
+        env.enableCheckpointing(checkpointingInterval);
+        env.getCheckpointConfig()
+                .setExternalizedCheckpointCleanup(
+                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().setUseSnapshotCompression(true);
+
+        DataStream<Integer> input =
+                env.addSource(
+                                new NotifyingDefiniteKeySource(
+                                        numberKeys, numberElements, terminateAfterEmission))
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) throws Exception {
+                                        return value;
+                                    }
+                                });
+        RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch =
+                new CountDownLatch(numberKeys);
+
+        DataStream<Tuple2<Integer, Integer>> result =
+                input.flatMap(new RescalingTestUtils.SubtaskIndexFlatMapper(numberElements));
+
+        result.addSink(new RescalingTestUtils.CollectionSink<>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class NotifyingDefiniteKeySource extends RescalingTestUtils.DefiniteKeySource {
+        private static final long serialVersionUID = 8120981235081181746L;
+
+        private static CountDownLatch sourceLatch;
+
+        public NotifyingDefiniteKeySource(
+                int numberKeys, int numberElements, boolean terminateAfterEmission) {
+            super(numberKeys, numberElements, terminateAfterEmission);
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            if (sourceLatch != null) {
+                sourceLatch.countDown();
+            }
+            super.run(ctx);
+        }
+    }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index ad43eae..d27d6f4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -19,12 +19,9 @@
 package org.apache.flink.test.checkpointing;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.time.Deadline;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
@@ -50,12 +47,13 @@ import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
-import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.CollectionSink;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.DefiniteKeySource;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.SubtaskIndexFlatMapper;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.concurrent.FutureUtils;
 
@@ -77,7 +75,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
@@ -672,7 +669,7 @@ public class RescalingITCase extends TestLogger {
 
         DataStream<Integer> input =
                 env.addSource(
-                                new SubtaskIndexSource(
+                                new DefiniteKeySource(
                                         numberKeys, numberElements, terminateAfterEmission))
                         .keyBy(
                                 new KeySelector<Integer, Integer>() {
@@ -736,60 +733,7 @@ public class RescalingITCase extends TestLogger {
         return env.getStreamGraph().getJobGraph();
     }
 
-    private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
-
-        private static final long serialVersionUID = -400066323594122516L;
-
-        private final int numberKeys;
-        private final int numberElements;
-        private final boolean terminateAfterEmission;
-
-        protected int counter = 0;
-
-        private boolean running = true;
-
-        SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
-
-            this.numberKeys = numberKeys;
-            this.numberElements = numberElements;
-            this.terminateAfterEmission = terminateAfterEmission;
-        }
-
-        @Override
-        public void run(SourceContext<Integer> ctx) throws Exception {
-            final Object lock = ctx.getCheckpointLock();
-            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
-
-            while (running) {
-
-                if (counter < numberElements) {
-                    synchronized (lock) {
-                        for (int value = subtaskIndex;
-                                value < numberKeys;
-                                value += getRuntimeContext().getNumberOfParallelSubtasks()) {
-
-                            ctx.collect(value);
-                        }
-
-                        counter++;
-                    }
-                } else {
-                    if (terminateAfterEmission) {
-                        running = false;
-                    } else {
-                        Thread.sleep(100);
-                    }
-                }
-            }
-        }
-
-        @Override
-        public void cancel() {
-            running = false;
-        }
-    }
-
-    private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource
+    private static class SubtaskIndexNonPartitionedStateSource extends DefiniteKeySource
             implements ListCheckpointed<Integer> {
 
         private static final long serialVersionUID = 8388073059042040203L;
@@ -814,76 +758,6 @@ public class RescalingITCase extends TestLogger {
         }
     }
 
-    private static class SubtaskIndexFlatMapper
-            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
-            implements CheckpointedFunction {
-
-        private static final long serialVersionUID = 5273172591283191348L;
-
-        private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
-
-        private transient ValueState<Integer> counter;
-        private transient ValueState<Integer> sum;
-
-        private final int numberElements;
-
-        SubtaskIndexFlatMapper(int numberElements) {
-            this.numberElements = numberElements;
-        }
-
-        @Override
-        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
-                throws Exception {
-
-            int count = counter.value() + 1;
-            counter.update(count);
-
-            int s = sum.value() + value;
-            sum.update(s);
-
-            if (count % numberElements == 0) {
-                out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
-                workCompletedLatch.countDown();
-            }
-        }
-
-        @Override
-        public void snapshotState(FunctionSnapshotContext context) throws Exception {
-            // all managed, nothing to do.
-        }
-
-        @Override
-        public void initializeState(FunctionInitializationContext context) throws Exception {
-            counter =
-                    context.getKeyedStateStore()
-                            .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
-            sum =
-                    context.getKeyedStateStore()
-                            .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
-        }
-    }
-
-    private static class CollectionSink<IN> implements SinkFunction<IN> {
-
-        private static Set<Object> elements =
-                Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>());
-
-        private static final long serialVersionUID = -1652452958040267745L;
-
-        public static <IN> Set<IN> getElementsSet() {
-            return (Set<IN>) elements;
-        }
-
-        public static void clearElementsSet() {
-            elements.clear();
-        }
-
-        @Override
-        public void invoke(IN value) throws Exception {
-            elements.add(value);
-        }
-    }
-
     private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
 
         private static final long serialVersionUID = 7512206069681177940L;
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
index eeda1b0..ab5d7f9 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.test.checkpointing;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.api.common.eventtime.AscendingTimestampsWatermarks;
 import org.apache.flink.api.common.eventtime.TimestampAssigner;
 import org.apache.flink.api.common.eventtime.TimestampAssignerSupplier;
@@ -45,6 +43,7 @@ import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindo
 import org.apache.flink.streaming.api.windowing.time.Time;
 import org.apache.flink.test.state.ManualWindowSpeedITCase;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
 import org.apache.flink.util.TestLogger;
 
 import org.apache.curator.test.TestingServer;
@@ -56,12 +55,7 @@ import javax.annotation.Nullable;
 
 import java.io.File;
 import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.stream.Stream;
 
 import static org.junit.Assert.assertNotNull;
 
@@ -306,62 +300,11 @@ public class ResumeCheckpointManuallyITCase extends TestLogger {
         // wait until all sources have been started
         NotifyingInfiniteTupleSource.countDownLatch.await();
 
-        waitUntilExternalizedCheckpointCreated(checkpointDir, initialJobGraph.getJobID());
+        TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
         client.cancel(initialJobGraph.getJobID()).get();
-        waitUntilCanceled(initialJobGraph.getJobID(), client);
+        TestUtils.waitUntilJobCanceled(initialJobGraph.getJobID(), client);
 
-        return getExternalizedCheckpointCheckpointPath(checkpointDir, initialJobGraph.getJobID());
-    }
-
-    private static String getExternalizedCheckpointCheckpointPath(File checkpointDir, JobID jobId)
-            throws IOException {
-        Optional<Path> checkpoint = findExternalizedCheckpoint(checkpointDir, jobId);
-        if (!checkpoint.isPresent()) {
-            throw new AssertionError("No complete checkpoint could be found.");
-        } else {
-            return checkpoint.get().toString();
-        }
-    }
-
-    private static void waitUntilExternalizedCheckpointCreated(File checkpointDir, JobID jobId)
-            throws InterruptedException, IOException {
-        while (true) {
-            Thread.sleep(50);
-            Optional<Path> externalizedCheckpoint =
-                    findExternalizedCheckpoint(checkpointDir, jobId);
-            if (externalizedCheckpoint.isPresent()) {
-                break;
-            }
-        }
-    }
-
-    private static Optional<Path> findExternalizedCheckpoint(File checkpointDir, JobID jobId)
-            throws IOException {
-        try (Stream<Path> checkpoints =
-                Files.list(checkpointDir.toPath().resolve(jobId.toString()))) {
-            return checkpoints
-                    .filter(path -> path.getFileName().toString().startsWith("chk-"))
-                    .filter(
-                            path -> {
-                                try (Stream<Path> checkpointFiles = Files.list(path)) {
-                                    return checkpointFiles.anyMatch(
-                                            child ->
-                                                    child.getFileName()
-                                                            .toString()
-                                                            .contains("meta"));
-                                } catch (IOException ignored) {
-                                    return false;
-                                }
-                            })
-                    .findAny();
-        }
-    }
-
-    private static void waitUntilCanceled(JobID jobId, ClusterClient<?> client)
-            throws ExecutionException, InterruptedException {
-        while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
-            Thread.sleep(50);
-        }
+        return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
     }
 
     private static JobGraph getJobGraph(StateBackend backend, @Nullable String externalCheckpoint) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
new file mode 100644
index 0000000..87214b7
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing.utils;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+/** Test utilities for rescaling. */
+public class RescalingTestUtils {
+
+    /** A parallel source that repeatedly emits a fixed set of keys. */
+    public static class DefiniteKeySource extends RichParallelSourceFunction<Integer> {
+
+        private static final long serialVersionUID = -400066323594122516L;
+
+        private final int numberKeys;
+        private final int numberElements;
+        private final boolean terminateAfterEmission;
+
+        protected int counter = 0;
+
+        private boolean running = true;
+
+        public DefiniteKeySource(
+                int numberKeys, int numberElements, boolean terminateAfterEmission) {
+            this.numberKeys = numberKeys;
+            this.numberElements = numberElements;
+            this.terminateAfterEmission = terminateAfterEmission;
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            final Object lock = ctx.getCheckpointLock();
+            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+            while (running) {
+
+                if (counter < numberElements) {
+                    synchronized (lock) {
+                        for (int value = subtaskIndex;
+                                value < numberKeys;
+                                value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+                            ctx.collect(value);
+                        }
+                        counter++;
+                    }
+                } else {
+                    if (terminateAfterEmission) {
+                        running = false;
+                    } else {
+                        Thread.sleep(100);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+    }
+
+    /** A flatMapper that emits the subtask index together with the accumulated per-key sum. */
+    public static class SubtaskIndexFlatMapper
+            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
+            implements CheckpointedFunction {
+
+        private static final long serialVersionUID = 5273172591283191348L;
+
+        public static CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+        private transient ValueState<Integer> counter;
+        private transient ValueState<Integer> sum;
+
+        private final int numberElements;
+
+        public SubtaskIndexFlatMapper(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
+                throws Exception {
+
+            int count = counter.value() + 1;
+            counter.update(count);
+
+            int s = sum.value() + value;
+            sum.update(s);
+
+            if (count % numberElements == 0) {
+                out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+                workCompletedLatch.countDown();
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            // all managed, nothing to do.
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            counter =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+            sum =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+        }
+    }
+
+    /** Sink for collecting results into a collection. */
+    public static class CollectionSink<IN> implements SinkFunction<IN> {
+
+        private static final Set<Object> elements =
+                Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+        private static final long serialVersionUID = -1652452958040267745L;
+
+        public static <IN> Set<IN> getElementsSet() {
+            return (Set<IN>) elements;
+        }
+
+        public static void clearElementsSet() {
+            elements.clear();
+        }
+
+        @Override
+        public void invoke(IN value) throws Exception {
+            elements.add(value);
+        }
+    }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
index 1856960..bd69b85 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.test.util;
 
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.client.program.ClusterClient;
 import org.apache.flink.core.execution.JobClient;
 import org.apache.flink.runtime.checkpoint.Checkpoints;
@@ -39,6 +41,7 @@ import java.nio.file.Path;
 import java.nio.file.attribute.BasicFileAttributes;
 import java.util.Comparator;
 import java.util.Optional;
+import java.util.concurrent.ExecutionException;
 
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.CHECKPOINT_DIR_PREFIX;
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.METADATA_FILE_NAME;
@@ -139,4 +142,23 @@ public class TestUtils {
             return false; // should never happen
         }
     }
+
+    public static void waitUntilExternalizedCheckpointCreated(File checkpointDir)
+            throws InterruptedException, IOException {
+        while (true) {
+            Thread.sleep(50);
+            Optional<File> externalizedCheckpoint =
+                    getMostRecentCompletedCheckpointMaybe(checkpointDir);
+            if (externalizedCheckpoint.isPresent()) {
+                break;
+            }
+        }
+    }
+
+    public static void waitUntilJobCanceled(JobID jobId, ClusterClient<?> client)
+            throws ExecutionException, InterruptedException {
+        while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
+            Thread.sleep(50);
+        }
+    }
 }
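
For illustration, here is a minimal sketch (not part of this commit) of how the
extracted RescalingTestUtils classes and the new TestUtils helpers might be
combined in a rescaling ITCase. The parallelism, key/element counts, the client
and checkpointDir handles (for example obtained from a MiniClusterWithClientResource
and a temporary folder), and the RocksDB and checkpoint configuration are
assumptions made for the sketch, not part of the change itself:

import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.contrib.streaming.state.EmbeddedRocksDBStateBackend;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.checkpointing.utils.RescalingTestUtils;
import org.apache.flink.test.util.TestUtils;

import java.io.File;

class RescaleSketch {

    /** Runs the test pipeline until an externalized checkpoint exists, then cancels the job. */
    static String runUntilCheckpointAndCancel(ClusterClient<?> client, File checkpointDir)
            throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(2);
        env.enableCheckpointing(100);
        // Incremental RocksDB snapshots, so that restoring with a different
        // parallelism exercises the key-range truncation path.
        env.setStateBackend(new EmbeddedRocksDBStateBackend(true));
        env.getCheckpointConfig().setCheckpointStorage(checkpointDir.toURI());
        env.getCheckpointConfig()
                .enableExternalizedCheckpoints(
                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);

        // 10 keys, 50 elements per key, keep running so checkpoints keep completing.
        env.addSource(new RescalingTestUtils.DefiniteKeySource(10, 50, false))
                .keyBy(
                        new KeySelector<Integer, Integer>() {
                            @Override
                            public Integer getKey(Integer value) {
                                return value;
                            }
                        })
                .flatMap(new RescalingTestUtils.SubtaskIndexFlatMapper(50))
                .addSink(new RescalingTestUtils.CollectionSink<>());

        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
        client.submitJob(jobGraph).get();

        // Poll until a completed externalized checkpoint is on disk, cancel the
        // job, and return the checkpoint path so a second job can restore from
        // it with a different parallelism.
        TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
        client.cancel(jobGraph.getJobID()).get();
        TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);

        return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
    }
}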