Posted to commits@flink.apache.org by ta...@apache.org on 2022/03/21 03:11:08 UTC

[flink] 02/02: [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint

This is an automated email from the ASF dual-hosted git repository.

tangyun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 1bf45b25791cc3fad8b7d0d863caa9b0eef9a87b
Author: fredia <fr...@gmail.com>
AuthorDate: Thu Mar 10 11:33:59 2022 +0800

    [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint
---
 .../state/RocksDBIncrementalCheckpointUtils.java   |   2 +-
 ...ncrementalCheckpointRescalingBenchmarkTest.java | 240 -----------------
 .../RescaleCheckpointManuallyITCase.java           | 286 +++++++++++++++++++++
 .../flink/test/checkpointing/RescalingITCase.java  | 136 +---------
 .../ResumeCheckpointManuallyITCase.java            |  65 +----
 .../checkpointing/utils/RescalingTestUtils.java    | 162 ++++++++++++
 .../java/org/apache/flink/test/util/TestUtils.java |  22 ++
 7 files changed, 480 insertions(+), 433 deletions(-)

diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
index 467156c..23c7867 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
@@ -149,7 +149,7 @@ public class RocksDBIncrementalCheckpointUtils {
             // Using RocksDB's deleteRange will take advantage of delete
             // tombstones, which mark the range as deleted.
             //
-            // https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L371
+            // https://github.com/ververica/frocksdb/blob/FRocksDB-6.20.3/include/rocksdb/db.h#L363-L377
             db.deleteRange(columnFamilyHandle, beginKeyBytes, endKeyBytes);
         }
     }
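
For context, the comment above is the crux of this hunk: deleteRange writes a
single range tombstone covering [beginKey, endKey) rather than one delete
marker per key, so clipping a key-group range stays cheap no matter how many
keys it covers. A minimal standalone sketch against the plain RocksDB Java API
(the database path and key bytes are illustrative, not taken from this patch):

    import org.rocksdb.ColumnFamilyHandle;
    import org.rocksdb.Options;
    import org.rocksdb.RocksDB;
    import org.rocksdb.RocksDBException;

    public class DeleteRangeSketch {
        public static void main(String[] args) throws RocksDBException {
            RocksDB.loadLibrary();
            try (Options options = new Options().setCreateIfMissing(true);
                    RocksDB db = RocksDB.open(options, "/tmp/delete-range-sketch")) {
                ColumnFamilyHandle defaultCf = db.getDefaultColumnFamily();
                byte[] beginKey = {0x00, 0x01}; // inclusive lower bound
                byte[] endKey = {0x00, 0x02};   // exclusive upper bound
                // One range tombstone marks every key in [beginKey, endKey) as
                // deleted; no per-key tombstones are written.
                db.deleteRange(defaultCf, beginKey, endKey);
            }
        }
    }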
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
deleted file mode 100644
index a4e267b..0000000
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.contrib.streaming.state;
-
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
-import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
-import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
-import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
-import org.apache.flink.testutils.junit.RetryOnFailure;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.TestLogger;
-
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-
-import java.util.Collection;
-import java.util.List;
-
-/** Runs the benchmark for incremental checkpoint rescaling. */
-public class RocksIncrementalCheckpointRescalingBenchmarkTest extends TestLogger {
-
-    @Rule public TemporaryFolder rootFolder = new TemporaryFolder();
-
-    private static final int maxParallelism = 10;
-
-    private static final int recordCount = 1_000;
-
-    /** partitionParallelism is the parallelism to use for creating the partitionedSnapshot. */
-    private static final int partitionParallelism = 2;
-
-    /**
-     * repartitionParallelism is the parallelism to use during the test for the repartition step.
-     *
-     * <p>NOTE: To trigger {@link
-     * org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation#restoreWithRescaling(Collection)},
-     * where the improvement code is exercised, the target parallelism must not be divisible by
-     * {@link #partitionParallelism}. If this parallelism were instead 4, no rescaling would occur.
-     */
-    private static final int repartitionParallelism = 3;
-
-    /** partitionedSnapshot is a partitioned incremental RocksDB snapshot. */
-    private OperatorSubtaskState partitionedSnapshot;
-
-    private KeySelector<Integer, Integer> keySelector = new TestKeySelector();
-
-    /**
-     * The benchmark's preparation will:
-     *
-     * <ol>
-     *   <li>Create a stateful operator and process records to persist state.
-     *   <li>Snapshot the state and re-partition it so the test operates on a partitioned state.
-     * </ol>
-     *
-     * @throws Exception
-     */
-    @Before
-    public void before() throws Exception {
-        OperatorSubtaskState snapshot;
-        // Initialize the test harness with a task parallelism of 1.
-        try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> harness =
-                getHarnessTest(keySelector, maxParallelism, 1, 0)) {
-            // Set the state backend of the harness to RocksDB.
-            harness.setStateBackend(getStateBackend());
-            // Initialize the harness.
-            harness.open();
-            // Push the test records into the operator to trigger state updates.
-            Integer[] records = new Integer[recordCount];
-            for (int i = 0; i < recordCount; i++) {
-                harness.processElement(new StreamRecord<>(i, 1));
-            }
-            // Grab a snapshot of the state.
-            snapshot = harness.snapshot(0, 0);
-        }
-
-        // Now, re-partition to create a partitioned state.
-        KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer>[] partitionedTestHarness =
-                new KeyedOneInputStreamOperatorTestHarness[partitionParallelism];
-        List<KeyGroupRange> keyGroupPartitions =
-                StateAssignmentOperation.createKeyGroupPartitions(
-                        maxParallelism, partitionParallelism);
-        try {
-            for (int i = 0; i < partitionParallelism; i++) {
-                // Initialize, open, and then re-snapshot the two subtasks to create a partitioned
-                // incremental RocksDB snapshot.
-                OperatorSubtaskState subtaskState =
-                        AbstractStreamOperatorTestHarness.repartitionOperatorState(
-                                snapshot, maxParallelism, 1, partitionParallelism, i);
-                KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
-                partitionedTestHarness[i] =
-                        getHarnessTest(keySelector, maxParallelism, partitionParallelism, i);
-                partitionedTestHarness[i].setStateBackend(getStateBackend());
-                partitionedTestHarness[i].setup();
-                partitionedTestHarness[i].initializeState(subtaskState);
-                partitionedTestHarness[i].open();
-            }
-
-            partitionedSnapshot =
-                    AbstractStreamOperatorTestHarness.repackageState(
-                            partitionedTestHarness[0].snapshot(1, 2),
-                            partitionedTestHarness[1].snapshot(1, 2));
-
-        } finally {
-            closeHarness(partitionedTestHarness);
-        }
-    }
-
-    @Test(timeout = 1000)
-    @RetryOnFailure(times = 3)
-    public void benchmarkScalingUp() throws Exception {
-        long benchmarkTime = 0;
-
-        // Trigger incremental rescaling via restoreWithRescaling by repartitioning the snapshot
-        // from a parallelism greater than 1 to a higher, misaligned parallelism. Time spent during
-        // this step includes the cost of incremental rescaling.
-        List<KeyGroupRange> keyGroupPartitions =
-                StateAssignmentOperation.createKeyGroupPartitions(
-                        maxParallelism, repartitionParallelism);
-
-        long fullStateSize = partitionedSnapshot.getStateSize();
-
-        for (int i = 0; i < repartitionParallelism; i++) {
-            OperatorSubtaskState subtaskState =
-                    AbstractStreamOperatorTestHarness.repartitionOperatorState(
-                            partitionedSnapshot,
-                            maxParallelism,
-                            partitionParallelism,
-                            repartitionParallelism,
-                            i);
-            KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
-            try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> subtaskHarness =
-                    getHarnessTest(keySelector, maxParallelism, repartitionParallelism, i)) {
-                RocksDBStateBackend backend = getStateBackend();
-                subtaskHarness.setStateBackend(backend);
-                subtaskHarness.setup();
-
-                // Precisely measure the call-site that triggers the restore operation.
-                long startingTime = System.nanoTime();
-                subtaskHarness.initializeState(subtaskState);
-                benchmarkTime += System.nanoTime() - startingTime;
-            }
-        }
-
-        log.error(
-                "--------------> performance for incremental checkpoint re-scaling <--------------");
-        log.error(
-                "rescale from {} to {} with {} records took: {} nanoseconds",
-                partitionParallelism,
-                repartitionParallelism,
-                recordCount,
-                benchmarkTime);
-    }
-
-    private KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> getHarnessTest(
-            KeySelector<Integer, Integer> keySelector,
-            int maxParallelism,
-            int taskParallelism,
-            int subtaskIdx)
-            throws Exception {
-        return new KeyedOneInputStreamOperatorTestHarness<>(
-                new KeyedProcessOperator<>(new TestKeyedFunction()),
-                keySelector,
-                BasicTypeInfo.INT_TYPE_INFO,
-                maxParallelism,
-                taskParallelism,
-                subtaskIdx);
-    }
-
-    private void closeHarness(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnessArr)
-            throws Exception {
-        for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnessArr) {
-            if (harness != null) {
-                harness.close();
-            }
-        }
-    }
-
-    private RocksDBStateBackend getStateBackend() throws Exception {
-        return new RocksDBStateBackend("file://" + rootFolder.newFolder().getAbsolutePath(), true);
-    }
-
-    /** A simple keyed function for tests. */
-    private class TestKeyedFunction extends KeyedProcessFunction<Integer, Integer, Integer> {
-
-        public ValueStateDescriptor<Integer> stateDescriptor;
-        private ValueState<Integer> counterState;
-
-        @Override
-        public void open(Configuration parameters) throws Exception {
-            super.open(parameters);
-            stateDescriptor = new ValueStateDescriptor<Integer>("counter", Integer.class);
-            counterState = this.getRuntimeContext().getState(stateDescriptor);
-        }
-
-        @Override
-        public void processElement(Integer incomingValue, Context ctx, Collector<Integer> out)
-                throws Exception {
-            Integer oldValue = counterState.value();
-            Integer newValue = oldValue != null ? oldValue + incomingValue : incomingValue;
-            counterState.update(newValue);
-            out.collect(newValue);
-        }
-    }
-
-    /** A simple key selector for tests. */
-    private class TestKeySelector implements KeySelector<Integer, Integer> {
-        @Override
-        public Integer getKey(Integer value) throws Exception {
-            return value;
-        }
-    }
-}
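
The divisibility NOTE in the deleted javadoc above comes down to key-group
alignment: with maxParallelism = 10, two subtasks own key groups [0, 4] and
[5, 9], while three subtasks own [0, 3], [4, 6] and [7, 9]. The new range
[4, 6] straddles both old ranges, so that subtask restores from two state
handles and restoreWithRescaling is guaranteed to run. A standalone sketch
(illustration only, not part of the patch) that prints the ranges using
Flink's own assignment logic:

    import org.apache.flink.runtime.state.KeyGroupRange;
    import org.apache.flink.runtime.state.KeyGroupRangeAssignment;

    public class KeyGroupAlignmentSketch {
        public static void main(String[] args) {
            int maxParallelism = 10;
            for (int parallelism : new int[] {2, 3}) {
                for (int subtask = 0; subtask < parallelism; subtask++) {
                    // Same computation Flink uses to assign key groups to subtasks.
                    KeyGroupRange range =
                            KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                                    maxParallelism, parallelism, subtask);
                    System.out.printf(
                            "parallelism=%d subtask=%d -> key groups [%d, %d]%n",
                            parallelism, subtask,
                            range.getStartKeyGroup(), range.getEndKeyGroup());
                }
            }
        }
    }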
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
new file mode 100644
index 0000000..8c05afa
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.StateBackendOptions;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.CheckpointConfig;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+
+import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/** Tests checkpoint rescaling for incremental RocksDB. */
+public class RescaleCheckpointManuallyITCase extends TestLogger {
+
+    private static final int NUM_TASK_MANAGERS = 2;
+    private static final int SLOTS_PER_TASK_MANAGER = 2;
+
+    private static MiniClusterWithClientResource cluster;
+    private File checkpointDir;
+
+    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    @Before
+    public void setup() throws Exception {
+        Configuration config = new Configuration();
+
+        checkpointDir = temporaryFolder.newFolder();
+
+        config.setString(StateBackendOptions.STATE_BACKEND, "rocksdb");
+        config.setString(
+                CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
+        config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);
+
+        cluster =
+                new MiniClusterWithClientResource(
+                        new MiniClusterResourceConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumberTaskManagers(NUM_TASK_MANAGERS)
+                                .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER)
+                                .build());
+        cluster.before();
+    }
+
+    @Test
+    public void testCheckpointRescalingInKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(false);
+    }
+
+    @Test
+    public void testCheckpointRescalingOutKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(true);
+    }
+
+    /**
+     * Tests that a job with purely keyed state can be restarted from a checkpoint with a different
+     * parallelism.
+     */
+    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
+        final int numberKeys = 42;
+        final int numberElements = 1000;
+        final int numberElements2 = 500;
+        final int parallelism = scaleOut ? 3 : 4;
+        final int parallelism2 = scaleOut ? 4 : 3;
+        final int maxParallelism = 13;
+
+        cluster.before();
+
+        ClusterClient<?> client = cluster.getClusterClient();
+        String checkpointPath =
+                runJobAndGetCheckpoint(
+                        numberKeys,
+                        numberElements,
+                        parallelism,
+                        maxParallelism,
+                        client,
+                        checkpointDir);
+
+        assertNotNull(checkpointPath);
+
+        restoreAndAssert(
+                parallelism2,
+                maxParallelism,
+                maxParallelism,
+                numberKeys,
+                numberElements2,
+                numberElements + numberElements2,
+                client,
+                checkpointPath);
+    }
+
+    private static String runJobAndGetCheckpoint(
+            int numberKeys,
+            int numberElements,
+            int parallelism,
+            int maxParallelism,
+            ClusterClient<?> client,
+            File checkpointDir)
+            throws Exception {
+        try {
+            JobGraph jobGraph =
+                    createJobGraphWithKeyedState(
+                            parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+            NotifyingDefiniteKeySource.sourceLatch = new CountDownLatch(parallelism);
+            client.submitJob(jobGraph).get();
+            NotifyingDefiniteKeySource.sourceLatch.await();
+
+            RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch.await();
+
+            // verify the current state
+            Set<Tuple2<Integer, Integer>> actualResult =
+                    RescalingTestUtils.CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+                expectedResult.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism, keyGroupIndex),
+                                numberElements * key));
+            }
+
+            assertEquals(expectedResult, actualResult);
+            NotifyingDefiniteKeySource.sourceLatch.await();
+
+            TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
+            client.cancel(jobGraph.getJobID()).get();
+            TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);
+            return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
+        } finally {
+            RescalingTestUtils.CollectionSink.clearElementsSet();
+        }
+    }
+
+    private void restoreAndAssert(
+            int restoreParallelism,
+            int restoreMaxParallelism,
+            int maxParallelismBefore,
+            int numberKeys,
+            int numberElements,
+            int numberElementsExpect,
+            ClusterClient<?> client,
+            String restorePath)
+            throws Exception {
+        try {
+
+            JobGraph scaledJobGraph =
+                    createJobGraphWithKeyedState(
+                            restoreParallelism,
+                            restoreMaxParallelism,
+                            numberKeys,
+                            numberElements,
+                            true,
+                            100);
+
+            scaledJobGraph.setSavepointRestoreSettings(
+                    SavepointRestoreSettings.forPath(restorePath));
+
+            submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
+
+            Set<Tuple2<Integer, Integer>> actualResult2 =
+                    RescalingTestUtils.CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex =
+                        KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelismBefore);
+                expectedResult2.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelismBefore, restoreParallelism, keyGroupIndex),
+                                key * numberElementsExpect));
+            }
+            assertEquals(expectedResult2, actualResult2);
+        } finally {
+            RescalingTestUtils.CollectionSink.clearElementsSet();
+        }
+    }
+
+    private static JobGraph createJobGraphWithKeyedState(
+            int parallelism,
+            int maxParallelism,
+            int numberKeys,
+            int numberElements,
+            boolean terminateAfterEmission,
+            int checkpointingInterval) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(parallelism);
+        if (0 < maxParallelism) {
+            env.getConfig().setMaxParallelism(maxParallelism);
+        }
+        env.enableCheckpointing(checkpointingInterval);
+        env.getCheckpointConfig()
+                .setExternalizedCheckpointCleanup(
+                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().setUseSnapshotCompression(true);
+
+        DataStream<Integer> input =
+                env.addSource(
+                                new NotifyingDefiniteKeySource(
+                                        numberKeys, numberElements, terminateAfterEmission))
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) throws Exception {
+                                        return value;
+                                    }
+                                });
+        RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch =
+                new CountDownLatch(numberKeys);
+
+        DataStream<Tuple2<Integer, Integer>> result =
+                input.flatMap(new RescalingTestUtils.SubtaskIndexFlatMapper(numberElements));
+
+        result.addSink(new RescalingTestUtils.CollectionSink<>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class NotifyingDefiniteKeySource extends RescalingTestUtils.DefiniteKeySource {
+        private static final long serialVersionUID = 8120981235081181746L;
+
+        private static CountDownLatch sourceLatch;
+
+        public NotifyingDefiniteKeySource(
+                int numberKeys, int numberElements, boolean terminateAfterEmission) {
+            super(numberKeys, numberElements, terminateAfterEmission);
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            if (sourceLatch != null) {
+                sourceLatch.countDown();
+            }
+            super.run(ctx);
+        }
+    }
+}
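
The restore half of this test leans on the fact that a retained (externalized)
checkpoint path is handed to a new job exactly like a savepoint path. Boiled
down, and with env, client, checkpointPath and newParallelism assumed to be
set up as in the test above, the pattern is a sketch like:

    import org.apache.flink.client.program.ClusterClient;
    import org.apache.flink.runtime.jobgraph.JobGraph;
    import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

    public class RestoreSketch {
        // Sketch: resubmit a job at a different parallelism, restoring from a
        // retained checkpoint via the same mechanism used for savepoints.
        static void resubmit(
                StreamExecutionEnvironment env,
                ClusterClient<?> client,
                String checkpointPath,
                int newParallelism)
                throws Exception {
            env.setParallelism(newParallelism);
            JobGraph jobGraph = env.getStreamGraph().getJobGraph();
            jobGraph.setSavepointRestoreSettings(
                    SavepointRestoreSettings.forPath(checkpointPath));
            client.submitJob(jobGraph).get();
        }
    }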
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index ad43eae..d27d6f4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -19,12 +19,9 @@
 package org.apache.flink.test.checkpointing;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.time.Deadline;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
@@ -50,12 +47,13 @@ import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
-import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.CollectionSink;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.DefiniteKeySource;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.SubtaskIndexFlatMapper;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.concurrent.FutureUtils;
 
@@ -77,7 +75,6 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
@@ -672,7 +669,7 @@ public class RescalingITCase extends TestLogger {
 
         DataStream<Integer> input =
                 env.addSource(
-                                new SubtaskIndexSource(
+                                new DefiniteKeySource(
                                         numberKeys, numberElements, terminateAfterEmission))
                         .keyBy(
                                 new KeySelector<Integer, Integer>() {
@@ -736,60 +733,7 @@ public class RescalingITCase extends TestLogger {
         return env.getStreamGraph().getJobGraph();
     }
 
-    private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
-
-        private static final long serialVersionUID = -400066323594122516L;
-
-        private final int numberKeys;
-        private final int numberElements;
-        private final boolean terminateAfterEmission;
-
-        protected int counter = 0;
-
-        private boolean running = true;
-
-        SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
-
-            this.numberKeys = numberKeys;
-            this.numberElements = numberElements;
-            this.terminateAfterEmission = terminateAfterEmission;
-        }
-
-        @Override
-        public void run(SourceContext<Integer> ctx) throws Exception {
-            final Object lock = ctx.getCheckpointLock();
-            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
-
-            while (running) {
-
-                if (counter < numberElements) {
-                    synchronized (lock) {
-                        for (int value = subtaskIndex;
-                                value < numberKeys;
-                                value += getRuntimeContext().getNumberOfParallelSubtasks()) {
-
-                            ctx.collect(value);
-                        }
-
-                        counter++;
-                    }
-                } else {
-                    if (terminateAfterEmission) {
-                        running = false;
-                    } else {
-                        Thread.sleep(100);
-                    }
-                }
-            }
-        }
-
-        @Override
-        public void cancel() {
-            running = false;
-        }
-    }
-
-    private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource
+    private static class SubtaskIndexNonPartitionedStateSource extends DefiniteKeySource
             implements ListCheckpointed<Integer> {
 
         private static final long serialVersionUID = 8388073059042040203L;
@@ -814,76 +758,6 @@ public class RescalingITCase extends TestLogger {
         }
     }
 
-    private static class SubtaskIndexFlatMapper
-            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
-            implements CheckpointedFunction {
-
-        private static final long serialVersionUID = 5273172591283191348L;
-
-        private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
-
-        private transient ValueState<Integer> counter;
-        private transient ValueState<Integer> sum;
-
-        private final int numberElements;
-
-        SubtaskIndexFlatMapper(int numberElements) {
-            this.numberElements = numberElements;
-        }
-
-        @Override
-        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
-                throws Exception {
-
-            int count = counter.value() + 1;
-            counter.update(count);
-
-            int s = sum.value() + value;
-            sum.update(s);
-
-            if (count % numberElements == 0) {
-                out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
-                workCompletedLatch.countDown();
-            }
-        }
-
-        @Override
-        public void snapshotState(FunctionSnapshotContext context) throws Exception {
-            // all managed, nothing to do.
-        }
-
-        @Override
-        public void initializeState(FunctionInitializationContext context) throws Exception {
-            counter =
-                    context.getKeyedStateStore()
-                            .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
-            sum =
-                    context.getKeyedStateStore()
-                            .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
-        }
-    }
-
-    private static class CollectionSink<IN> implements SinkFunction<IN> {
-
-        private static Set<Object> elements =
-                Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>());
-
-        private static final long serialVersionUID = -1652452958040267745L;
-
-        public static <IN> Set<IN> getElementsSet() {
-            return (Set<IN>) elements;
-        }
-
-        public static void clearElementsSet() {
-            elements.clear();
-        }
-
-        @Override
-        public void invoke(IN value) throws Exception {
-            elements.add(value);
-        }
-    }
-
     private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
 
         private static final long serialVersionUID = 7512206069681177940L;
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
index eeda1b0..ab5d7f9 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.test.checkpointing;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.api.common.eventtime.AscendingTimestampsWatermarks;
 import org.apache.flink.api.common.eventtime.TimestampAssigner;
 import org.apache.flink.api.common.eventtime.TimestampAssignerSupplier;
@@ -45,6 +43,7 @@ import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindo
 import org.apache.flink.streaming.api.windowing.time.Time;
 import org.apache.flink.test.state.ManualWindowSpeedITCase;
 import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
 import org.apache.flink.util.TestLogger;
 
 import org.apache.curator.test.TestingServer;
@@ -56,12 +55,7 @@ import javax.annotation.Nullable;
 
 import java.io.File;
 import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.Optional;
 import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.stream.Stream;
 
 import static org.junit.Assert.assertNotNull;
 
@@ -306,62 +300,11 @@ public class ResumeCheckpointManuallyITCase extends TestLogger {
         // wait until all sources have been started
         NotifyingInfiniteTupleSource.countDownLatch.await();
 
-        waitUntilExternalizedCheckpointCreated(checkpointDir, initialJobGraph.getJobID());
+        TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
         client.cancel(initialJobGraph.getJobID()).get();
-        waitUntilCanceled(initialJobGraph.getJobID(), client);
+        TestUtils.waitUntilJobCanceled(initialJobGraph.getJobID(), client);
 
-        return getExternalizedCheckpointCheckpointPath(checkpointDir, initialJobGraph.getJobID());
-    }
-
-    private static String getExternalizedCheckpointCheckpointPath(File checkpointDir, JobID jobId)
-            throws IOException {
-        Optional<Path> checkpoint = findExternalizedCheckpoint(checkpointDir, jobId);
-        if (!checkpoint.isPresent()) {
-            throw new AssertionError("No complete checkpoint could be found.");
-        } else {
-            return checkpoint.get().toString();
-        }
-    }
-
-    private static void waitUntilExternalizedCheckpointCreated(File checkpointDir, JobID jobId)
-            throws InterruptedException, IOException {
-        while (true) {
-            Thread.sleep(50);
-            Optional<Path> externalizedCheckpoint =
-                    findExternalizedCheckpoint(checkpointDir, jobId);
-            if (externalizedCheckpoint.isPresent()) {
-                break;
-            }
-        }
-    }
-
-    private static Optional<Path> findExternalizedCheckpoint(File checkpointDir, JobID jobId)
-            throws IOException {
-        try (Stream<Path> checkpoints =
-                Files.list(checkpointDir.toPath().resolve(jobId.toString()))) {
-            return checkpoints
-                    .filter(path -> path.getFileName().toString().startsWith("chk-"))
-                    .filter(
-                            path -> {
-                                try (Stream<Path> checkpointFiles = Files.list(path)) {
-                                    return checkpointFiles.anyMatch(
-                                            child ->
-                                                    child.getFileName()
-                                                            .toString()
-                                                            .contains("meta"));
-                                } catch (IOException ignored) {
-                                    return false;
-                                }
-                            })
-                    .findAny();
-        }
-    }
-
-    private static void waitUntilCanceled(JobID jobId, ClusterClient<?> client)
-            throws ExecutionException, InterruptedException {
-        while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
-            Thread.sleep(50);
-        }
+        return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
     }
 
     private static JobGraph getJobGraph(StateBackend backend, @Nullable String externalCheckpoint) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
new file mode 100644
index 0000000..87214b7
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing.utils;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+/** Test utilities for rescaling. */
+public class RescalingTestUtils {
+
+    /** A parallel source that emits a definite set of keys. */
+    public static class DefiniteKeySource extends RichParallelSourceFunction<Integer> {
+
+        private static final long serialVersionUID = -400066323594122516L;
+
+        private final int numberKeys;
+        private final int numberElements;
+        private final boolean terminateAfterEmission;
+
+        protected int counter = 0;
+
+        private boolean running = true;
+
+        public DefiniteKeySource(
+                int numberKeys, int numberElements, boolean terminateAfterEmission) {
+            this.numberKeys = numberKeys;
+            this.numberElements = numberElements;
+            this.terminateAfterEmission = terminateAfterEmission;
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            final Object lock = ctx.getCheckpointLock();
+            final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+            while (running) {
+
+                if (counter < numberElements) {
+                    synchronized (lock) {
+                        for (int value = subtaskIndex;
+                                value < numberKeys;
+                                value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+                            ctx.collect(value);
+                        }
+                        counter++;
+                    }
+                } else {
+                    if (terminateAfterEmission) {
+                        running = false;
+                    } else {
+                        Thread.sleep(100);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+    }
+
+    /** A flatMap function that emits the subtask index together with the accumulated sum. */
+    public static class SubtaskIndexFlatMapper
+            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
+            implements CheckpointedFunction {
+
+        private static final long serialVersionUID = 5273172591283191348L;
+
+        public static CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+        private transient ValueState<Integer> counter;
+        private transient ValueState<Integer> sum;
+
+        private final int numberElements;
+
+        public SubtaskIndexFlatMapper(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
+                throws Exception {
+
+            int count = counter.value() + 1;
+            counter.update(count);
+
+            int s = sum.value() + value;
+            sum.update(s);
+
+            if (count % numberElements == 0) {
+                out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+                workCompletedLatch.countDown();
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+            // all managed, nothing to do.
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            counter =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+            sum =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+        }
+    }
+
+    /** Sink for collecting results into a collection. */
+    public static class CollectionSink<IN> implements SinkFunction<IN> {
+
+        private static final Set<Object> elements =
+                Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+        private static final long serialVersionUID = -1652452958040267745L;
+
+        public static <IN> Set<IN> getElementsSet() {
+            return (Set<IN>) elements;
+        }
+
+        public static void clearElementsSet() {
+            elements.clear();
+        }
+
+        @Override
+        public void invoke(IN value) throws Exception {
+            elements.add(value);
+        }
+    }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
index 1856960..bd69b85 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
@@ -18,6 +18,8 @@
 
 package org.apache.flink.test.util;
 
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.client.program.ClusterClient;
 import org.apache.flink.core.execution.JobClient;
 import org.apache.flink.runtime.checkpoint.Checkpoints;
@@ -39,6 +41,7 @@ import java.nio.file.Path;
 import java.nio.file.attribute.BasicFileAttributes;
 import java.util.Comparator;
 import java.util.Optional;
+import java.util.concurrent.ExecutionException;
 
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.CHECKPOINT_DIR_PREFIX;
 import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.METADATA_FILE_NAME;
@@ -139,4 +142,23 @@ public class TestUtils {
             return false; // should never happen
         }
     }
+
+    public static void waitUntilExternalizedCheckpointCreated(File checkpointDir)
+            throws InterruptedException, IOException {
+        while (true) {
+            Thread.sleep(50);
+            Optional<File> externalizedCheckpoint =
+                    getMostRecentCompletedCheckpointMaybe(checkpointDir);
+            if (externalizedCheckpoint.isPresent()) {
+                break;
+            }
+        }
+    }
+
+    public static void waitUntilJobCanceled(JobID jobId, ClusterClient<?> client)
+            throws ExecutionException, InterruptedException {
+        while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
+            Thread.sleep(50);
+        }
+    }
 }
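
Taken together with the existing getMostRecentCompletedCheckpoint, the two new
helpers support the cancel-and-capture idiom both ITCases now share, roughly
(mirroring the refactored call sites above):

    // Wait for a completed externalized checkpoint, cancel the job, then
    // capture the checkpoint path to restore the rescaled job from.
    TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
    client.cancel(jobGraph.getJobID()).get();
    TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);
    String restorePath =
            TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();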