You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ta...@apache.org on 2022/03/21 03:11:08 UTC
[flink] 02/02: [FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint
This is an automated email from the ASF dual-hosted git repository.
tangyun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 1bf45b25791cc3fad8b7d0d863caa9b0eef9a87b
Author: fredia <fr...@gmail.com>
AuthorDate: Thu Mar 10 11:33:59 2022 +0800
[FLINK-21321][Runtime/StateBackends] Add ITCases for rescaling from checkpoint
---
.../state/RocksDBIncrementalCheckpointUtils.java | 2 +-
...ncrementalCheckpointRescalingBenchmarkTest.java | 240 -----------------
.../RescaleCheckpointManuallyITCase.java | 286 +++++++++++++++++++++
.../flink/test/checkpointing/RescalingITCase.java | 136 +---------
.../ResumeCheckpointManuallyITCase.java | 65 +----
.../checkpointing/utils/RescalingTestUtils.java | 162 ++++++++++++
.../java/org/apache/flink/test/util/TestUtils.java | 22 ++
7 files changed, 480 insertions(+), 433 deletions(-)
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
index 467156c..23c7867 100644
--- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
+++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java
@@ -149,7 +149,7 @@ public class RocksDBIncrementalCheckpointUtils {
// Using RocksDB's deleteRange will take advantage of delete
// tombstones, which mark the range as deleted.
//
- // https://github.com/facebook/rocksdb/blob/bcd32560dd5898956b9d24553c2bb3c1b1d2319f/include/rocksdb/db.h#L357-L371
+ // https://github.com/ververica/frocksdb/blob/FRocksDB-6.20.3/include/rocksdb/db.h#L363-L377
db.deleteRange(columnFamilyHandle, beginKeyBytes, endKeyBytes);
}
}
diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
deleted file mode 100644
index a4e267b..0000000
--- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksIncrementalCheckpointRescalingBenchmarkTest.java
+++ /dev/null
@@ -1,240 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.contrib.streaming.state;
-
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
-import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.configuration.Configuration;
-import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
-import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
-import org.apache.flink.runtime.state.KeyGroupRange;
-import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
-import org.apache.flink.streaming.api.operators.KeyedProcessOperator;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
-import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
-import org.apache.flink.testutils.junit.RetryOnFailure;
-import org.apache.flink.util.Collector;
-import org.apache.flink.util.TestLogger;
-
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-
-import java.util.Collection;
-import java.util.List;
-
-/** Test runs the benchmark for incremental checkpoint rescaling. */
-public class RocksIncrementalCheckpointRescalingBenchmarkTest extends TestLogger {
-
- @Rule public TemporaryFolder rootFolder = new TemporaryFolder();
-
- private static final int maxParallelism = 10;
-
- private static final int recordCount = 1_000;
-
- /** partitionParallelism is the parallelism to use for creating the partitionedSnapshot. */
- private static final int partitionParallelism = 2;
-
- /**
- * repartitionParallelism is the parallelism to use during the test for the repartition step.
- *
- * <p>NOTE: To trigger {@link
- * org.apache.flink.contrib.streaming.state.restore.RocksDBIncrementalRestoreOperation#restoreWithRescaling(Collection)},
- * where the improvement code is exercised, the target parallelism must not be divisible by
- * {@link partitionParallelism}. If this parallelism was instead 4, then there is no rescale.
- */
- private static final int repartitionParallelism = 3;
-
- /** partitionedSnapshot is a partitioned incremental RocksDB snapshot. */
- private OperatorSubtaskState partitionedSnapshot;
-
- private KeySelector<Integer, Integer> keySelector = new TestKeySelector();
-
- /**
- * The benchmark's preparation will:
- *
- * <ol>
- * <li>Create a stateful operator and process records to persist state.
- * <li>Snapshot the state and re-partition it so the test operates on a partitioned state.
- * </ol>
- *
- * @throws Exception
- */
- @Before
- public void before() throws Exception {
- OperatorSubtaskState snapshot;
- // Initialize the test harness with a task parallelism of 1.
- try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> harness =
- getHarnessTest(keySelector, maxParallelism, 1, 0)) {
- // Set the state backend of the harness to RocksDB.
- harness.setStateBackend(getStateBackend());
- // Initialize the harness.
- harness.open();
- // Push the test records into the operator to trigger state updates.
- Integer[] records = new Integer[recordCount];
- for (int i = 0; i < recordCount; i++) {
- harness.processElement(new StreamRecord<>(i, 1));
- }
- // Grab a snapshot of the state.
- snapshot = harness.snapshot(0, 0);
- }
-
- // Now, re-partition to create a partitioned state.
- KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer>[] partitionedTestHarness =
- new KeyedOneInputStreamOperatorTestHarness[partitionParallelism];
- List<KeyGroupRange> keyGroupPartitions =
- StateAssignmentOperation.createKeyGroupPartitions(
- maxParallelism, partitionParallelism);
- try {
- for (int i = 0; i < partitionParallelism; i++) {
- // Initialize, open, and then re-snapshot the two subtasks to create a partitioned
- // incremental RocksDB snapshot.
- OperatorSubtaskState subtaskState =
- AbstractStreamOperatorTestHarness.repartitionOperatorState(
- snapshot, maxParallelism, 1, partitionParallelism, i);
- KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
- partitionedTestHarness[i] =
- getHarnessTest(keySelector, maxParallelism, partitionParallelism, i);
- partitionedTestHarness[i].setStateBackend(getStateBackend());
- partitionedTestHarness[i].setup();
- partitionedTestHarness[i].initializeState(subtaskState);
- partitionedTestHarness[i].open();
- }
-
- partitionedSnapshot =
- AbstractStreamOperatorTestHarness.repackageState(
- partitionedTestHarness[0].snapshot(1, 2),
- partitionedTestHarness[1].snapshot(1, 2));
-
- } finally {
- closeHarness(partitionedTestHarness);
- }
- }
-
- @Test(timeout = 1000)
- @RetryOnFailure(times = 3)
- public void benchmarkScalingUp() throws Exception {
- long benchmarkTime = 0;
-
- // Trigger the incremental re-scaling via restoreWithRescaling by repartitioning it from
- // parallelism of >1 to a higher parallelism. Time spent during this step includes the cost
- // of incremental rescaling.
- List<KeyGroupRange> keyGroupPartitions =
- StateAssignmentOperation.createKeyGroupPartitions(
- maxParallelism, repartitionParallelism);
-
- long fullStateSize = partitionedSnapshot.getStateSize();
-
- for (int i = 0; i < repartitionParallelism; i++) {
- OperatorSubtaskState subtaskState =
- AbstractStreamOperatorTestHarness.repartitionOperatorState(
- partitionedSnapshot,
- maxParallelism,
- partitionParallelism,
- repartitionParallelism,
- i);
- KeyGroupRange localKeyGroupRange20 = keyGroupPartitions.get(i);
-
- try (KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> subtaskHarness =
- getHarnessTest(keySelector, maxParallelism, repartitionParallelism, i)) {
- RocksDBStateBackend backend = getStateBackend();
- subtaskHarness.setStateBackend(backend);
- subtaskHarness.setup();
-
- // Precisely measure the call-site that triggers the restore operation.
- long startingTime = System.nanoTime();
- subtaskHarness.initializeState(subtaskState);
- benchmarkTime += System.nanoTime() - startingTime;
- }
- }
-
- log.error(
- "--------------> performance for incremental checkpoint re-scaling <--------------");
- log.error(
- "rescale from {} to {} with {} records took: {} nanoseconds",
- partitionParallelism,
- repartitionParallelism,
- recordCount,
- benchmarkTime);
- }
-
- private KeyedOneInputStreamOperatorTestHarness<Integer, Integer, Integer> getHarnessTest(
- KeySelector<Integer, Integer> keySelector,
- int maxParallelism,
- int taskParallelism,
- int subtaskIdx)
- throws Exception {
- return new KeyedOneInputStreamOperatorTestHarness<>(
- new KeyedProcessOperator<>(new TestKeyedFunction()),
- keySelector,
- BasicTypeInfo.INT_TYPE_INFO,
- maxParallelism,
- taskParallelism,
- subtaskIdx);
- }
-
- private void closeHarness(KeyedOneInputStreamOperatorTestHarness<?, ?, ?>[] harnessArr)
- throws Exception {
- for (KeyedOneInputStreamOperatorTestHarness<?, ?, ?> harness : harnessArr) {
- if (harness != null) {
- harness.close();
- }
- }
- }
-
- private RocksDBStateBackend getStateBackend() throws Exception {
- return new RocksDBStateBackend("file://" + rootFolder.newFolder().getAbsolutePath(), true);
- }
-
- /** A simple keyed function for tests. */
- private class TestKeyedFunction extends KeyedProcessFunction<Integer, Integer, Integer> {
-
- public ValueStateDescriptor<Integer> stateDescriptor;
- private ValueState<Integer> counterState;
-
- @Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
- stateDescriptor = new ValueStateDescriptor<Integer>("counter", Integer.class);
- counterState = this.getRuntimeContext().getState(stateDescriptor);
- }
-
- @Override
- public void processElement(Integer incomingValue, Context ctx, Collector<Integer> out)
- throws Exception {
- Integer oldValue = counterState.value();
- Integer newValue = oldValue != null ? oldValue + incomingValue : incomingValue;
- counterState.update(newValue);
- out.collect(newValue);
- }
- }
-
- /** A simple key selector for tests. */
- private class TestKeySelector implements KeySelector<Integer, Integer> {
- @Override
- public Integer getKey(Integer value) throws Exception {
- return value;
- }
- }
-}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
new file mode 100644
index 0000000..8c05afa
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase.java
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.StateBackendOptions;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.CheckpointConfig;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+
+import static org.apache.flink.test.util.TestUtils.submitJobAndWaitForResult;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/** Test checkpoint rescaling for incremental rocksdb. */
+public class RescaleCheckpointManuallyITCase extends TestLogger {
+
+ private static final int NUM_TASK_MANAGERS = 2;
+ private static final int SLOTS_PER_TASK_MANAGER = 2;
+
+ private static MiniClusterWithClientResource cluster;
+ private File checkpointDir;
+
+ @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+ @Before
+ public void setup() throws Exception {
+ Configuration config = new Configuration();
+
+ checkpointDir = temporaryFolder.newFolder();
+
+ config.setString(StateBackendOptions.STATE_BACKEND, "rocksdb");
+ config.setString(
+ CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
+ config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);
+
+ cluster =
+ new MiniClusterWithClientResource(
+ new MiniClusterResourceConfiguration.Builder()
+ .setConfiguration(config)
+ .setNumberTaskManagers(NUM_TASK_MANAGERS)
+ .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER)
+ .build());
+ cluster.before();
+ }
+
+ @Test
+ public void testCheckpointRescalingInKeyedState() throws Exception {
+ testCheckpointRescalingKeyedState(false);
+ }
+
+ @Test
+ public void testCheckpointRescalingOutKeyedState() throws Exception {
+ testCheckpointRescalingKeyedState(true);
+ }
+
+ /**
+ * Tests that a job with purely keyed state can be restarted from a checkpoint with a different
+ * parallelism.
+ */
+ public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
+ final int numberKeys = 42;
+ final int numberElements = 1000;
+ final int numberElements2 = 500;
+ final int parallelism = scaleOut ? 3 : 4;
+ final int parallelism2 = scaleOut ? 4 : 3;
+ final int maxParallelism = 13;
+
+ cluster.before();
+
+ ClusterClient<?> client = cluster.getClusterClient();
+ String checkpointPath =
+ runJobAndGetCheckpoint(
+ numberKeys,
+ numberElements,
+ parallelism,
+ maxParallelism,
+ client,
+ checkpointDir);
+
+ assertNotNull(checkpointPath);
+
+ restoreAndAssert(
+ parallelism2,
+ maxParallelism,
+ maxParallelism,
+ numberKeys,
+ numberElements2,
+ numberElements + numberElements2,
+ client,
+ checkpointPath);
+ }
+
+ private static String runJobAndGetCheckpoint(
+ int numberKeys,
+ int numberElements,
+ int parallelism,
+ int maxParallelism,
+ ClusterClient<?> client,
+ File checkpointDir)
+ throws Exception {
+ try {
+ JobGraph jobGraph =
+ createJobGraphWithKeyedState(
+ parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+ NotifyingDefiniteKeySource.sourceLatch = new CountDownLatch(parallelism);
+ client.submitJob(jobGraph).get();
+ NotifyingDefiniteKeySource.sourceLatch.await();
+
+ RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch.await();
+
+ // verify the current state
+ Set<Tuple2<Integer, Integer>> actualResult =
+ RescalingTestUtils.CollectionSink.getElementsSet();
+
+ Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+ for (int key = 0; key < numberKeys; key++) {
+ int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+ expectedResult.add(
+ Tuple2.of(
+ KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+ maxParallelism, parallelism, keyGroupIndex),
+ numberElements * key));
+ }
+
+ assertEquals(expectedResult, actualResult);
+ NotifyingDefiniteKeySource.sourceLatch.await();
+
+ TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
+ client.cancel(jobGraph.getJobID()).get();
+ TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);
+ return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
+ } finally {
+ RescalingTestUtils.CollectionSink.clearElementsSet();
+ }
+ }
+
+ private void restoreAndAssert(
+ int restoreParallelism,
+ int restoreMaxParallelism,
+ int maxParallelismBefore,
+ int numberKeys,
+ int numberElements,
+ int numberElementsExpect,
+ ClusterClient<?> client,
+ String restorePath)
+ throws Exception {
+ try {
+
+ JobGraph scaledJobGraph =
+ createJobGraphWithKeyedState(
+ restoreParallelism,
+ restoreMaxParallelism,
+ numberKeys,
+ numberElements,
+ true,
+ 100);
+
+ scaledJobGraph.setSavepointRestoreSettings(
+ SavepointRestoreSettings.forPath(restorePath));
+
+ submitJobAndWaitForResult(client, scaledJobGraph, getClass().getClassLoader());
+
+ Set<Tuple2<Integer, Integer>> actualResult2 =
+ RescalingTestUtils.CollectionSink.getElementsSet();
+
+ Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+ for (int key = 0; key < numberKeys; key++) {
+ int keyGroupIndex =
+ KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelismBefore);
+ expectedResult2.add(
+ Tuple2.of(
+ KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+ maxParallelismBefore, restoreParallelism, keyGroupIndex),
+ key * numberElementsExpect));
+ }
+ assertEquals(expectedResult2, actualResult2);
+ } finally {
+ RescalingTestUtils.CollectionSink.clearElementsSet();
+ }
+ }
+
+ private static JobGraph createJobGraphWithKeyedState(
+ int parallelism,
+ int maxParallelism,
+ int numberKeys,
+ int numberElements,
+ boolean terminateAfterEmission,
+ int checkpointingInterval) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(parallelism);
+ if (0 < maxParallelism) {
+ env.getConfig().setMaxParallelism(maxParallelism);
+ }
+ env.enableCheckpointing(checkpointingInterval);
+ env.getCheckpointConfig()
+ .setExternalizedCheckpointCleanup(
+ CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ env.getConfig().setUseSnapshotCompression(true);
+
+ DataStream<Integer> input =
+ env.addSource(
+ new NotifyingDefiniteKeySource(
+ numberKeys, numberElements, terminateAfterEmission))
+ .keyBy(
+ new KeySelector<Integer, Integer>() {
+ private static final long serialVersionUID =
+ -7952298871120320940L;
+
+ @Override
+ public Integer getKey(Integer value) throws Exception {
+ return value;
+ }
+ });
+ RescalingTestUtils.SubtaskIndexFlatMapper.workCompletedLatch =
+ new CountDownLatch(numberKeys);
+
+ DataStream<Tuple2<Integer, Integer>> result =
+ input.flatMap(new RescalingTestUtils.SubtaskIndexFlatMapper(numberElements));
+
+ result.addSink(new RescalingTestUtils.CollectionSink<>());
+
+ return env.getStreamGraph().getJobGraph();
+ }
+
+ private static class NotifyingDefiniteKeySource extends RescalingTestUtils.DefiniteKeySource {
+ private static final long serialVersionUID = 8120981235081181746L;
+
+ private static CountDownLatch sourceLatch;
+
+ public NotifyingDefiniteKeySource(
+ int numberKeys, int numberElements, boolean terminateAfterEmission) {
+ super(numberKeys, numberElements, terminateAfterEmission);
+ }
+
+ @Override
+ public void run(SourceContext<Integer> ctx) throws Exception {
+ if (sourceLatch != null) {
+ sourceLatch.countDown();
+ }
+ super.run(ctx);
+ }
+ }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index ad43eae..d27d6f4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -19,12 +19,9 @@
package org.apache.flink.test.checkpointing;
import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
-import org.apache.flink.api.common.state.ValueState;
-import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
@@ -50,12 +47,13 @@ import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
-import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.CollectionSink;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.DefiniteKeySource;
+import org.apache.flink.test.checkpointing.utils.RescalingTestUtils.SubtaskIndexFlatMapper;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.FutureUtils;
@@ -77,7 +75,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
@@ -672,7 +669,7 @@ public class RescalingITCase extends TestLogger {
DataStream<Integer> input =
env.addSource(
- new SubtaskIndexSource(
+ new DefiniteKeySource(
numberKeys, numberElements, terminateAfterEmission))
.keyBy(
new KeySelector<Integer, Integer>() {
@@ -736,60 +733,7 @@ public class RescalingITCase extends TestLogger {
return env.getStreamGraph().getJobGraph();
}
- private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
-
- private static final long serialVersionUID = -400066323594122516L;
-
- private final int numberKeys;
- private final int numberElements;
- private final boolean terminateAfterEmission;
-
- protected int counter = 0;
-
- private boolean running = true;
-
- SubtaskIndexSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
-
- this.numberKeys = numberKeys;
- this.numberElements = numberElements;
- this.terminateAfterEmission = terminateAfterEmission;
- }
-
- @Override
- public void run(SourceContext<Integer> ctx) throws Exception {
- final Object lock = ctx.getCheckpointLock();
- final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
-
- while (running) {
-
- if (counter < numberElements) {
- synchronized (lock) {
- for (int value = subtaskIndex;
- value < numberKeys;
- value += getRuntimeContext().getNumberOfParallelSubtasks()) {
-
- ctx.collect(value);
- }
-
- counter++;
- }
- } else {
- if (terminateAfterEmission) {
- running = false;
- } else {
- Thread.sleep(100);
- }
- }
- }
- }
-
- @Override
- public void cancel() {
- running = false;
- }
- }
-
- private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource
+ private static class SubtaskIndexNonPartitionedStateSource extends DefiniteKeySource
implements ListCheckpointed<Integer> {
private static final long serialVersionUID = 8388073059042040203L;
@@ -814,76 +758,6 @@ public class RescalingITCase extends TestLogger {
}
}
- private static class SubtaskIndexFlatMapper
- extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
- implements CheckpointedFunction {
-
- private static final long serialVersionUID = 5273172591283191348L;
-
- private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
-
- private transient ValueState<Integer> counter;
- private transient ValueState<Integer> sum;
-
- private final int numberElements;
-
- SubtaskIndexFlatMapper(int numberElements) {
- this.numberElements = numberElements;
- }
-
- @Override
- public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
- throws Exception {
-
- int count = counter.value() + 1;
- counter.update(count);
-
- int s = sum.value() + value;
- sum.update(s);
-
- if (count % numberElements == 0) {
- out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
- workCompletedLatch.countDown();
- }
- }
-
- @Override
- public void snapshotState(FunctionSnapshotContext context) throws Exception {
- // all managed, nothing to do.
- }
-
- @Override
- public void initializeState(FunctionInitializationContext context) throws Exception {
- counter =
- context.getKeyedStateStore()
- .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
- sum =
- context.getKeyedStateStore()
- .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
- }
- }
-
- private static class CollectionSink<IN> implements SinkFunction<IN> {
-
- private static Set<Object> elements =
- Collections.newSetFromMap(new ConcurrentHashMap<Object, Boolean>());
-
- private static final long serialVersionUID = -1652452958040267745L;
-
- public static <IN> Set<IN> getElementsSet() {
- return (Set<IN>) elements;
- }
-
- public static void clearElementsSet() {
- elements.clear();
- }
-
- @Override
- public void invoke(IN value) throws Exception {
- elements.add(value);
- }
- }
-
private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
private static final long serialVersionUID = 7512206069681177940L;
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
index eeda1b0..ab5d7f9 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/ResumeCheckpointManuallyITCase.java
@@ -18,8 +18,6 @@
package org.apache.flink.test.checkpointing;
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.JobStatus;
import org.apache.flink.api.common.eventtime.AscendingTimestampsWatermarks;
import org.apache.flink.api.common.eventtime.TimestampAssigner;
import org.apache.flink.api.common.eventtime.TimestampAssignerSupplier;
@@ -45,6 +43,7 @@ import org.apache.flink.streaming.api.windowing.assigners.TumblingEventTimeWindo
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.test.state.ManualWindowSpeedITCase;
import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.test.util.TestUtils;
import org.apache.flink.util.TestLogger;
import org.apache.curator.test.TestingServer;
@@ -56,12 +55,7 @@ import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.Optional;
import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.ExecutionException;
-import java.util.stream.Stream;
import static org.junit.Assert.assertNotNull;
@@ -306,62 +300,11 @@ public class ResumeCheckpointManuallyITCase extends TestLogger {
// wait until all sources have been started
NotifyingInfiniteTupleSource.countDownLatch.await();
- waitUntilExternalizedCheckpointCreated(checkpointDir, initialJobGraph.getJobID());
+ TestUtils.waitUntilExternalizedCheckpointCreated(checkpointDir);
client.cancel(initialJobGraph.getJobID()).get();
- waitUntilCanceled(initialJobGraph.getJobID(), client);
+ TestUtils.waitUntilJobCanceled(initialJobGraph.getJobID(), client);
- return getExternalizedCheckpointCheckpointPath(checkpointDir, initialJobGraph.getJobID());
- }
-
- private static String getExternalizedCheckpointCheckpointPath(File checkpointDir, JobID jobId)
- throws IOException {
- Optional<Path> checkpoint = findExternalizedCheckpoint(checkpointDir, jobId);
- if (!checkpoint.isPresent()) {
- throw new AssertionError("No complete checkpoint could be found.");
- } else {
- return checkpoint.get().toString();
- }
- }
-
- private static void waitUntilExternalizedCheckpointCreated(File checkpointDir, JobID jobId)
- throws InterruptedException, IOException {
- while (true) {
- Thread.sleep(50);
- Optional<Path> externalizedCheckpoint =
- findExternalizedCheckpoint(checkpointDir, jobId);
- if (externalizedCheckpoint.isPresent()) {
- break;
- }
- }
- }
-
- private static Optional<Path> findExternalizedCheckpoint(File checkpointDir, JobID jobId)
- throws IOException {
- try (Stream<Path> checkpoints =
- Files.list(checkpointDir.toPath().resolve(jobId.toString()))) {
- return checkpoints
- .filter(path -> path.getFileName().toString().startsWith("chk-"))
- .filter(
- path -> {
- try (Stream<Path> checkpointFiles = Files.list(path)) {
- return checkpointFiles.anyMatch(
- child ->
- child.getFileName()
- .toString()
- .contains("meta"));
- } catch (IOException ignored) {
- return false;
- }
- })
- .findAny();
- }
- }
-
- private static void waitUntilCanceled(JobID jobId, ClusterClient<?> client)
- throws ExecutionException, InterruptedException {
- while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
- Thread.sleep(50);
- }
+ return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
}
private static JobGraph getJobGraph(StateBackend backend, @Nullable String externalCheckpoint) {
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
new file mode 100644
index 0000000..87214b7
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/utils/RescalingTestUtils.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing.utils;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.util.Collector;
+
+import java.util.Collections;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+
+/** Test utilities for rescaling. */
+public class RescalingTestUtils {
+
+ /** A parallel source with definite keys. */
+ public static class DefiniteKeySource extends RichParallelSourceFunction<Integer> {
+
+ private static final long serialVersionUID = -400066323594122516L;
+
+ private final int numberKeys;
+ private final int numberElements;
+ private final boolean terminateAfterEmission;
+
+ protected int counter = 0;
+
+ private boolean running = true;
+
+ public DefiniteKeySource(
+ int numberKeys, int numberElements, boolean terminateAfterEmission) {
+ this.numberKeys = numberKeys;
+ this.numberElements = numberElements;
+ this.terminateAfterEmission = terminateAfterEmission;
+ }
+
+ @Override
+ public void run(SourceContext<Integer> ctx) throws Exception {
+ final Object lock = ctx.getCheckpointLock();
+ final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+ while (running) {
+
+ if (counter < numberElements) {
+ synchronized (lock) {
+ for (int value = subtaskIndex;
+ value < numberKeys;
+ value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+ ctx.collect(value);
+ }
+ counter++;
+ }
+ } else {
+ if (terminateAfterEmission) {
+ running = false;
+ } else {
+ Thread.sleep(100);
+ }
+ }
+ }
+ }
+
+ @Override
+ public void cancel() {
+ running = false;
+ }
+ }
+
+ /** A flatMapper with the index of subtask. */
+ public static class SubtaskIndexFlatMapper
+ extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
+ implements CheckpointedFunction {
+
+ private static final long serialVersionUID = 5273172591283191348L;
+
+ public static CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+ private transient ValueState<Integer> counter;
+ private transient ValueState<Integer> sum;
+
+ private final int numberElements;
+
+ public SubtaskIndexFlatMapper(int numberElements) {
+ this.numberElements = numberElements;
+ }
+
+ @Override
+ public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
+ throws Exception {
+
+ int count = counter.value() + 1;
+ counter.update(count);
+
+ int s = sum.value() + value;
+ sum.update(s);
+
+ if (count % numberElements == 0) {
+ out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+ workCompletedLatch.countDown();
+ }
+ }
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext context) throws Exception {
+ // all managed, nothing to do.
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext context) throws Exception {
+ counter =
+ context.getKeyedStateStore()
+ .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+ sum =
+ context.getKeyedStateStore()
+ .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+ }
+ }
+
+ /** Sink for collecting results into a collection. */
+ public static class CollectionSink<IN> implements SinkFunction<IN> {
+
+ private static final Set<Object> elements =
+ Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+ private static final long serialVersionUID = -1652452958040267745L;
+
+ public static <IN> Set<IN> getElementsSet() {
+ return (Set<IN>) elements;
+ }
+
+ public static void clearElementsSet() {
+ elements.clear();
+ }
+
+ @Override
+ public void invoke(IN value) throws Exception {
+ elements.add(value);
+ }
+ }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
index 1856960..bd69b85 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/util/TestUtils.java
@@ -18,6 +18,8 @@
package org.apache.flink.test.util;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.JobStatus;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.core.execution.JobClient;
import org.apache.flink.runtime.checkpoint.Checkpoints;
@@ -39,6 +41,7 @@ import java.nio.file.Path;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Comparator;
import java.util.Optional;
+import java.util.concurrent.ExecutionException;
import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.CHECKPOINT_DIR_PREFIX;
import static org.apache.flink.runtime.state.filesystem.AbstractFsCheckpointStorageAccess.METADATA_FILE_NAME;
@@ -139,4 +142,23 @@ public class TestUtils {
return false; // should never happen
}
}
+
+ public static void waitUntilExternalizedCheckpointCreated(File checkpointDir)
+ throws InterruptedException, IOException {
+ while (true) {
+ Thread.sleep(50);
+ Optional<File> externalizedCheckpoint =
+ getMostRecentCompletedCheckpointMaybe(checkpointDir);
+ if (externalizedCheckpoint.isPresent()) {
+ break;
+ }
+ }
+ }
+
+ public static void waitUntilJobCanceled(JobID jobId, ClusterClient<?> client)
+ throws ExecutionException, InterruptedException {
+ while (client.getJobStatus(jobId).get() != JobStatus.CANCELED) {
+ Thread.sleep(50);
+ }
+ }
}