You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:46 UTC

[flink-ml] 08/08: [FLINK-24655][iteration] Add ITCase for the checkpoint and failover

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit b9ee412b0951d13ff5cf3d610a818a35e503a949
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Nov 1 23:56:09 2021 +0800

    [FLINK-24655][iteration] Add ITCase for the checkpoint and failover
    
    This closes #17.
---
 .../iteration/BoundedAllRoundCheckpointTest.java   | 196 +++++++++++++++++++++
 .../iteration/UnboundedStreamIterationITCase.java  |   5 +-
 .../flink/test/iteration/operators/FailingMap.java |  45 +++++
 .../operators/ReduceAllRoundProcessFunction.java   |  55 +++++-
 .../test/iteration/operators/SequenceSource.java   |  40 ++++-
 .../TwoInputReduceAllRoundProcessFunction.java     |  16 +-
 6 files changed, 347 insertions(+), 10 deletions(-)

diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
new file mode 100644
index 0000000..d53e334
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.test.iteration.operators.EpochRecord;
+import org.apache.flink.test.iteration.operators.FailingMap;
+import org.apache.flink.test.iteration.operators.IncrementEpochMap;
+import org.apache.flink.test.iteration.operators.OutputRecord;
+import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.test.iteration.UnboundedStreamIterationITCase.createMiniClusterConfiguration;
+import static org.junit.Assert.assertEquals;
+
+/** Tests failover and restore from checkpoints for all-round iterations over bounded inputs. */
+@RunWith(Parameterized.class)
+public class BoundedAllRoundCheckpointTest extends TestLogger {
+
+    @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+
+    private SharedReference<List<OutputRecord<Integer>>> result;
+
+    @Parameterized.Parameter(0)
+    public int failoverCount;
+
+    @Parameterized.Parameter(1)
+    public boolean sync;
+
+    @Parameterized.Parameters(name = "failoverCount = {0}, sync = {1}")
+    public static Collection<Object[]> params() {
+        int[] failoverCounts = {1000, 4000, 8000, 15900};
+        boolean[] syncs = {true, false};
+
+        List<Object[]> result = new ArrayList<>();
+        for (int failoverCount : failoverCounts) {
+            for (boolean sync : syncs) {
+                result.add(new Object[] {failoverCount, sync});
+            }
+        }
+
+        return result;
+    }
+
+    @Before
+    public void setup() {
+        result = sharedObjects.add(new ArrayList<>());
+    }
+
+    @Test
+    public void testFailoverAndRestore() throws Exception {
+        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
+            miniCluster.start();
+
+            // Create the test job
+            JobGraph jobGraph =
+                    createVariableAndConstantJobGraph(
+                            4, 1000, false, 0, sync, 4, failoverCount, new CollectSink(result));
+            miniCluster.executeJobBlocking(jobGraph);
+
+            Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>();
+            for (OutputRecord<Integer> output : result.get()) {
+                Tuple2<Integer, Integer> state =
+                        roundsStat.computeIfAbsent(
+                                output.getRound(), ignored -> new Tuple2<>(0, 0));
+                state.f0++;
+                state.f1 = output.getValue();
+            }
+
+            // Rounds 0 to 4, plus one extra entry for the termination record.
+            assertEquals(6, roundsStat.size());
+            for (int i = 0; i <= 4; ++i) {
+                // Only the final sum of each round can be checked: records may be replayed after
+                // the failover, so the number of records per round is not deterministic.
+                assertEquals(4 * (0 + 999) * 1000 / 2, (int) roundsStat.get(i).f1);
+            }
+        }
+    }
+
+    static JobGraph createVariableAndConstantJobGraph(
+            int numSources,
+            int numRecordsPerSource,
+            boolean holdSource,
+            int period,
+            boolean sync,
+            int maxRound,
+            int failoverCount,
+            SinkFunction<OutputRecord<Integer>> sinkFunction) {
+        StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(
+                        new Configuration() {
+                            {
+                                this.set(
+                                        ExecutionCheckpointingOptions
+                                                .ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
+                                        true);
+                            }
+                        });
+        env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE);
+        env.setParallelism(1);
+        DataStream<EpochRecord> variableSource =
+                env.addSource(new DraftExecutionEnvironment.EmptySource<EpochRecord>() {})
+                        .setParallelism(numSources)
+                        .name("Variable");
+        DataStream<EpochRecord> constSource =
+                env.addSource(new SequenceSource(numRecordsPerSource, holdSource, period))
+                        .setParallelism(numSources)
+                        .name("Constant");
+        DataStreamList outputs =
+                Iterations.iterateUnboundedStreams(
+                        DataStreamList.of(variableSource),
+                        DataStreamList.of(constSource),
+                        (variableStreams, dataStreams) -> {
+                            SingleOutputStreamOperator<EpochRecord> reducer =
+                                    variableStreams
+                                            .<EpochRecord>get(0)
+                                            .connect(dataStreams.<EpochRecord>get(0))
+                                            .process(
+                                                    new TwoInputReduceAllRoundProcessFunction(
+                                                            sync, maxRound));
+                            DataStream<EpochRecord> failedMap =
+                                    reducer.map(new FailingMap(failoverCount));
+                            return new IterationBodyResult(
+                                    DataStreamList.of(
+                                            failedMap
+                                                    .map(new IncrementEpochMap())
+                                                    .setParallelism(numSources)),
+                                    DataStreamList.of(
+                                            reducer.getSideOutput(
+                                                    new OutputTag<OutputRecord<Integer>>(
+                                                            "output") {})));
+                        });
+        outputs.<OutputRecord<Integer>>get(0).addSink(sinkFunction);
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class CollectSink implements SinkFunction<OutputRecord<Integer>> {
+
+        private final SharedReference<List<OutputRecord<Integer>>> result;
+
+        private CollectSink(SharedReference<List<OutputRecord<Integer>>> result) {
+            this.result = result;
+        }
+
+        @Override
+        public void invoke(OutputRecord<Integer> value, Context context) throws Exception {
+            result.get().add(value);
+        }
+    }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index df084eb..6d80f23 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.minicluster.MiniCluster;
 import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
@@ -150,9 +151,11 @@ public class UnboundedStreamIterationITCase extends TestLogger {
         assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
     }
 
-    static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
+    public static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
         Configuration configuration = new Configuration();
         configuration.set(RestOptions.BIND_PORT, "18081-19091");
+        configuration.set(
+                ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
         return new MiniClusterConfiguration.Builder()
                 .setConfiguration(configuration)
                 .setNumTaskManagers(numTm)
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
new file mode 100644
index 0000000..6c38d32
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration.operators;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+
+/** Map function that fails the first attempt of subtask 0 when it maps its {@code failingCount}-th record. */
+public class FailingMap extends RichMapFunction<EpochRecord, EpochRecord> {
+
+    private final int failingCount;
+
+    private int count;
+
+    public FailingMap(int failingCount) {
+        this.failingCount = failingCount;
+    }
+
+    @Override
+    public EpochRecord map(EpochRecord value) throws Exception {
+        count++;
+        if (getRuntimeContext().getIndexOfThisSubtask() == 0
+                && getRuntimeContext().getAttemptNumber() == 0
+                && count >= failingCount) {
+            throw new RuntimeException("Artificial Exception");
+        }
+
+        return value;
+    }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
index dfce6e6..18d04d5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
@@ -18,16 +18,26 @@
 
 package org.apache.flink.test.iteration.operators;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.ProcessFunction;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.OutputTag;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.function.BiConsumer;
 
 /**
@@ -35,7 +45,7 @@ import java.util.function.BiConsumer;
  * the received numbers to the next operator.
  */
 public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord, EpochRecord>
-        implements IterationListener<EpochRecord> {
+        implements IterationListener<EpochRecord>, CheckpointedFunction {
 
     private final boolean sync;
 
@@ -47,17 +57,54 @@ public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord,
 
     private transient OutputTag<OutputRecord<Integer>> outputTag;
 
+    private transient ListState<Map<Integer, Integer>> sumByEpochsState;
+
+    private transient ListState<EpochRecord> cachedRecordsState;
+
     public ReduceAllRoundProcessFunction(boolean sync, int maxRound) {
         this.sync = sync;
         this.maxRound = maxRound;
     }
 
     @Override
-    public void open(Configuration parameters) throws Exception {
-        super.open(parameters);
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
         sumByEpochs = new HashMap<>();
         cachedRecords = new ArrayList<>();
-        outputTag = new OutputTag<OutputRecord<Integer>>("output") {};
+
+        sumByEpochsState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "test",
+                                        new MapTypeInfo<>(
+                                                BasicTypeInfo.INT_TYPE_INFO,
+                                                BasicTypeInfo.INT_TYPE_INFO)));
+        Optional<Map<Integer, Integer>> old =
+                OperatorStateUtils.getUniqueElement(sumByEpochsState, "test");
+        old.ifPresent(v -> sumByEpochs.putAll(v));
+
+        cachedRecordsState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("cache", EpochRecord.class));
+        cachedRecordsState.get().forEach(v -> cachedRecords.add(v));
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        sumByEpochsState.clear();
+        sumByEpochsState.update(Collections.singletonList(new HashMap<>(sumByEpochs)));
+
+        cachedRecordsState.clear();
+        cachedRecordsState.addAll(cachedRecords);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        this.outputTag = new OutputTag<OutputRecord<Integer>>("output") {};
     }
 
     @Override
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
index 4054cf6..ed564b0 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
@@ -18,10 +18,19 @@
 
 package org.apache.flink.test.iteration.operators;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 
-/** A source emitting the continuous int sequences. */
-public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
+import java.util.Collections;
+
+/** A source emitting a continuous int sequence, resuming from the checkpointed position. */
+public class SequenceSource extends RichParallelSourceFunction<EpochRecord>
+        implements CheckpointedFunction {
 
     private final int maxValue;
 
@@ -31,6 +40,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
 
     private volatile boolean canceled;
 
+    private int next;
+
+    private ListState<Integer> nextState;
+
     public SequenceSource(int maxValue, boolean holdAfterMaxValue, int period) {
         this.maxValue = maxValue;
         this.holdAfterMaxValue = holdAfterMaxValue;
@@ -38,9 +51,22 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
     }
 
     @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        nextState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("next", Integer.class));
+        next = OperatorStateUtils.getUniqueElement(nextState, "next").orElse(0);
+    }
+
+    @Override
     public void run(SourceContext<EpochRecord> ctx) throws Exception {
-        for (int i = 0; i < maxValue && !canceled; ++i) {
-            ctx.collect(new EpochRecord(0, i));
+        while (next < maxValue && !canceled) {
+            synchronized (ctx.getCheckpointLock()) {
+                ctx.collect(new EpochRecord(0, next++));
+            }
+
             if (period > 0) {
                 Thread.sleep(period);
             }
@@ -57,4 +83,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
     public void cancel() {
         canceled = true;
     }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        nextState.clear();
+        nextState.update(Collections.singletonList(next));
+    }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
index 35e6876..1fea9a5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
@@ -20,6 +20,9 @@ package org.apache.flink.test.iteration.operators;
 
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
 import org.apache.flink.util.Collector;
 
@@ -28,7 +31,7 @@ import org.apache.flink.util.Collector;
  */
 public class TwoInputReduceAllRoundProcessFunction
         extends CoProcessFunction<EpochRecord, EpochRecord, EpochRecord>
-        implements IterationListener<EpochRecord> {
+        implements IterationListener<EpochRecord>, CheckpointedFunction {
 
     private final ReduceAllRoundProcessFunction reduceAllRoundProcessFunction;
 
@@ -78,4 +81,15 @@ public class TwoInputReduceAllRoundProcessFunction
         // Processing the first round of messages.
         reduceAllRoundProcessFunction.processRecord(record, ctx::output, out);
     }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        reduceAllRoundProcessFunction.snapshotState(functionSnapshotContext);
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        reduceAllRoundProcessFunction.initializeState(functionInitializationContext);
+    }
 }