You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:46 UTC
[flink-ml] 08/08: [FLINK-24655][iteration] Add ITCase for the checkpoint and failover
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit b9ee412b0951d13ff5cf3d610a818a35e503a949
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Nov 1 23:56:09 2021 +0800
[FLINK-24655][iteration] Add ITCase for the checkpoint and failover
This closes #17.
---
.../iteration/BoundedAllRoundCheckpointTest.java | 196 +++++++++++++++++++++
.../iteration/UnboundedStreamIterationITCase.java | 5 +-
.../flink/test/iteration/operators/FailingMap.java | 45 +++++
.../operators/ReduceAllRoundProcessFunction.java | 55 +++++-
.../test/iteration/operators/SequenceSource.java | 40 ++++-
.../TwoInputReduceAllRoundProcessFunction.java | 16 +-
6 files changed, 347 insertions(+), 10 deletions(-)
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
new file mode 100644
index 0000000..d53e334
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.test.iteration.operators.EpochRecord;
+import org.apache.flink.test.iteration.operators.FailingMap;
+import org.apache.flink.test.iteration.operators.IncrementEpochMap;
+import org.apache.flink.test.iteration.operators.OutputRecord;
+import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.test.iteration.UnboundedStreamIterationITCase.createMiniClusterConfiguration;
+import static org.junit.Assert.assertEquals;
+
+/** Runs the iteration job with an injected failure and checks the sum reported per round. */
+@RunWith(Parameterized.class)
+public class BoundedAllRoundCheckpointTest extends TestLogger {
+
+ @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+
+ private SharedReference<List<OutputRecord<Integer>>> result; // records collected by the sink
+
+ @Parameterized.Parameter(0)
+ public int failoverCount; // record count handed to FailingMap as its failure threshold
+
+ @Parameterized.Parameter(1)
+ public boolean sync; // forwarded to TwoInputReduceAllRoundProcessFunction
+
+ @Parameterized.Parameters(name = "failoverCount = {0}, sync = {1}")
+ public static Collection<Object[]> params() {
+ int[] failoverCounts = {1000, 4000, 8000, 15900};
+ boolean[] syncs = {true, false};
+
+ List<Object[]> result = new ArrayList<>(); // every (failoverCount, sync) combination
+ for (int failoverCount : failoverCounts) {
+ for (boolean sync : syncs) {
+ result.add(new Object[] {failoverCount, sync});
+ }
+ }
+
+ return result;
+ }
+
+ @Before
+ public void setup() {
+ result = sharedObjects.add(new ArrayList<>()); // fresh result list before each test
+ }
+
+ @Test
+ public void testFailoverAndRestore() throws Exception {
+ try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
+ miniCluster.start(); // configuration above asks for 2 task managers with 2 slots
+
+ // Create the test job: 4 sources, 1000 records per source, maxRound = 4
+ JobGraph jobGraph =
+ createVariableAndConstantJobGraph(
+ 4, 1000, false, 0, sync, 4, failoverCount, new CollectSink(result));
+ miniCluster.executeJobBlocking(jobGraph);
+
+ Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>(); // round -> (count, last value)
+ for (OutputRecord<Integer> output : result.get()) {
+ Tuple2<Integer, Integer> state =
+ roundsStat.computeIfAbsent(
+ output.getRound(), ignored -> new Tuple2<>(0, 0));
+ state.f0++;
+ state.f1 = output.getValue(); // keep only the latest value seen for the round
+ }
+
+ // Rounds 0 ~ 4 plus the termination information
+ assertEquals(6, roundsStat.size());
+ for (int i = 0; i <= 4; ++i) {
+ // Only the final sum of each round can be checked here, since the number of
+ // records is not deterministic.
+ assertEquals(4 * (0 + 999) * 1000 / 2, (int) roundsStat.get(i).f1); // 4 sources x sum(0..999)
+ }
+ }
+ }
+
+ static JobGraph createVariableAndConstantJobGraph(
+ int numSources,
+ int numRecordsPerSource,
+ boolean holdSource,
+ int period,
+ boolean sync,
+ int maxRound,
+ int failoverCount,
+ SinkFunction<OutputRecord<Integer>> sinkFunction) {
+ StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment(
+ new Configuration() {
+ {
+ this.set(
+ ExecutionCheckpointingOptions
+ .ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
+ true);
+ }
+ });
+ env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE); // interval of 500 ms
+ env.setParallelism(1); // default; the sources and the epoch map override it below
+ DataStream<EpochRecord> variableSource =
+ env.addSource(new DraftExecutionEnvironment.EmptySource<EpochRecord>() {})
+ .setParallelism(numSources)
+ .name("Variable");
+ DataStream<EpochRecord> constSource =
+ env.addSource(new SequenceSource(numRecordsPerSource, holdSource, period))
+ .setParallelism(numSources)
+ .name("Constant");
+ DataStreamList outputs =
+ Iterations.iterateUnboundedStreams(
+ DataStreamList.of(variableSource),
+ DataStreamList.of(constSource),
+ (variableStreams, dataStreams) -> {
+ SingleOutputStreamOperator<EpochRecord> reducer =
+ variableStreams
+ .<EpochRecord>get(0)
+ .connect(dataStreams.<EpochRecord>get(0))
+ .process(
+ new TwoInputReduceAllRoundProcessFunction(
+ sync, maxRound));
+ DataStream<EpochRecord> failedMap =
+ reducer.map(new FailingMap(failoverCount)); // injects the artificial failure
+ return new IterationBodyResult(
+ DataStreamList.of(
+ failedMap
+ .map(new IncrementEpochMap())
+ .setParallelism(numSources)),
+ DataStreamList.of(
+ reducer.getSideOutput(
+ new OutputTag<OutputRecord<Integer>>(
+ "output") {})));
+ });
+ outputs.<OutputRecord<Integer>>get(0).addSink(sinkFunction);
+
+ return env.getStreamGraph().getJobGraph();
+ }
+
+ private static class CollectSink implements SinkFunction<OutputRecord<Integer>> { // test-only sink
+
+ private final SharedReference<List<OutputRecord<Integer>>> result;
+
+ private CollectSink(SharedReference<List<OutputRecord<Integer>>> result) {
+ this.result = result;
+ }
+
+ @Override
+ public void invoke(OutputRecord<Integer> value, Context context) throws Exception {
+ result.get().add(value); // append to the list shared with the test method
+ }
+ }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index df084eb..6d80f23 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.minicluster.MiniCluster;
import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.test.iteration.operators.CollectSink;
import org.apache.flink.test.iteration.operators.EpochRecord;
@@ -150,9 +151,11 @@ public class UnboundedStreamIterationITCase extends TestLogger {
assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
}
- static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
+ public static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
Configuration configuration = new Configuration();
configuration.set(RestOptions.BIND_PORT, "18081-19091");
+ configuration.set(
+ ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
return new MiniClusterConfiguration.Builder()
.setConfiguration(configuration)
.setNumTaskManagers(numTm)
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
new file mode 100644
index 0000000..6c38d32
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration.operators;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+
+/** Map function that forwards records and throws on subtask 0 during its first attempt. */
+public class FailingMap extends RichMapFunction<EpochRecord, EpochRecord> {
+
+ private final int failingCount; // number of records after which the failure is raised
+
+ private int count; // records seen so far by this instance
+
+ public FailingMap(int failingCount) {
+ this.failingCount = failingCount;
+ }
+
+ @Override
+ public EpochRecord map(EpochRecord value) throws Exception {
+ count++;
+ if (getRuntimeContext().getIndexOfThisSubtask() == 0 // only the first subtask fails
+ && getRuntimeContext().getAttemptNumber() == 0 // and only while the attempt number is 0
+ && count >= failingCount) {
+ throw new RuntimeException("Artificial Exception");
+ }
+
+ return value; // records pass through unchanged
+ }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
index dfce6e6..18d04d5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
@@ -18,16 +18,26 @@
package org.apache.flink.test.iteration.operators;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.util.OutputTag;
import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.function.BiConsumer;
/**
@@ -35,7 +45,7 @@ import java.util.function.BiConsumer;
* the received numbers to the next operator.
*/
public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord, EpochRecord>
- implements IterationListener<EpochRecord> {
+ implements IterationListener<EpochRecord>, CheckpointedFunction {
private final boolean sync;
@@ -47,17 +57,54 @@ public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord,
private transient OutputTag<OutputRecord<Integer>> outputTag;
+ private transient ListState<Map<Integer, Integer>> sumByEpochsState;
+
+ private transient ListState<EpochRecord> cachedRecordsState;
+
public ReduceAllRoundProcessFunction(boolean sync, int maxRound) {
this.sync = sync;
this.maxRound = maxRound;
}
@Override
- public void open(Configuration parameters) throws Exception {
- super.open(parameters);
+ public void initializeState(FunctionInitializationContext functionInitializationContext)
+ throws Exception {
sumByEpochs = new HashMap<>();
cachedRecords = new ArrayList<>();
- outputTag = new OutputTag<OutputRecord<Integer>>("output") {};
+
+ sumByEpochsState = // operator list state registered under the name "test"
+ functionInitializationContext
+ .getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "test",
+ new MapTypeInfo<>(
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO)));
+ Optional<Map<Integer, Integer>> old =
+ OperatorStateUtils.getUniqueElement(sumByEpochsState, "test");
+ old.ifPresent(v -> sumByEpochs.putAll(v)); // copy the restored sums, if any, into memory
+
+ cachedRecordsState = // operator list state registered under the name "cache"
+ functionInitializationContext
+ .getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("cache", EpochRecord.class));
+ cachedRecordsState.get().forEach(v -> cachedRecords.add(v)); // refill the in-memory cache
+ }
+
+ @Override // copies the in-memory sums and cached records into the two list states
+ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+ sumByEpochsState.clear();
+ sumByEpochsState.update(Collections.singletonList(new HashMap<>(sumByEpochs))); // single map copy
+
+ cachedRecordsState.clear();
+ cachedRecordsState.addAll(cachedRecords); // one state element per cached record
+ }
+
+ @Override
+ public void open(Configuration parameters) throws Exception {
+ super.open(parameters);
+ this.outputTag = new OutputTag<OutputRecord<Integer>>("output") {}; // transient field, rebuilt here
}
@Override
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
index 4054cf6..ed564b0 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
@@ -18,10 +18,19 @@
package org.apache.flink.test.iteration.operators;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
-/** A source emitting the continuous int sequences. */
-public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
+import java.util.Collections;
+
+/** A checkpointed source emitting the continuous int sequence starting from 0. */
+public class SequenceSource extends RichParallelSourceFunction<EpochRecord>
+ implements CheckpointedFunction {
private final int maxValue;
@@ -31,6 +40,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
private volatile boolean canceled;
+ private int next; // next value to emit; set from nextState in initializeState
+
+ private transient ListState<Integer> nextState; // runtime state handle, assigned in initializeState
+
public SequenceSource(int maxValue, boolean holdAfterMaxValue, int period) {
this.maxValue = maxValue;
this.holdAfterMaxValue = holdAfterMaxValue;
@@ -38,9 +51,22 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
}
@Override
+ public void initializeState(FunctionInitializationContext functionInitializationContext)
+ throws Exception {
+ nextState = // operator list state registered under the name "next"
+ functionInitializationContext
+ .getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("next", Integer.class));
+ next = OperatorStateUtils.getUniqueElement(nextState, "next").orElse(0); // 0 when nothing stored
+ }
+
+ @Override
public void run(SourceContext<EpochRecord> ctx) throws Exception {
- for (int i = 0; i < maxValue && !canceled; ++i) {
- ctx.collect(new EpochRecord(0, i));
+ while (next < maxValue && !canceled) {
+ synchronized (ctx.getCheckpointLock()) {
+ ctx.collect(new EpochRecord(0, next++));
+ }
+
if (period > 0) {
Thread.sleep(period);
}
@@ -57,4 +83,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
public void cancel() {
canceled = true;
}
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+ nextState.clear();
+ nextState.update(Collections.singletonList(next)); // store the next value to emit as the only element
+ }
}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
index 35e6876..1fea9a5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
@@ -20,6 +20,9 @@ package org.apache.flink.test.iteration.operators;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
import org.apache.flink.util.Collector;
@@ -28,7 +31,7 @@ import org.apache.flink.util.Collector;
*/
public class TwoInputReduceAllRoundProcessFunction
extends CoProcessFunction<EpochRecord, EpochRecord, EpochRecord>
- implements IterationListener<EpochRecord> {
+ implements IterationListener<EpochRecord>, CheckpointedFunction {
private final ReduceAllRoundProcessFunction reduceAllRoundProcessFunction;
@@ -78,4 +81,15 @@ public class TwoInputReduceAllRoundProcessFunction
// Processing the first round of messages.
reduceAllRoundProcessFunction.processRecord(record, ctx::output, out);
}
+
+ @Override
+ public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+ reduceAllRoundProcessFunction.snapshotState(functionSnapshotContext); // delegate to the wrapped function
+ }
+
+ @Override
+ public void initializeState(FunctionInitializationContext functionInitializationContext)
+ throws Exception {
+ reduceAllRoundProcessFunction.initializeState(functionInitializationContext); // delegate to the wrapped function
+ }
}