You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:44 UTC

[flink-ml] 06/08: [FLINK-24655][iteration] Skip the repeat round for all-round operator

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 31ffe6c80339163b4bad05a14a8703d20177edb4
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 01:28:27 2021 +0800

    [FLINK-24655][iteration] Skip the repeat round for all-round operator
---
 .../allround/AbstractAllRoundWrapperOperator.java  | 151 ++++++++++++++++++---
 .../OneInputAllRoundWrapperOperatorTest.java       |  70 ++++++++++
 2 files changed, 204 insertions(+), 17 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 2855e38..d3461a1 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -18,28 +18,45 @@
 
 package org.apache.flink.iteration.operator.allround;
 
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
 import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.OutputTag;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
 import java.io.IOException;
+import java.util.Collections;
 
 import static org.apache.flink.iteration.operator.OperatorUtils.processOperatorOrUdfIfSatisfy;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** The base class for the all-round wrapper operators. */
 public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperator<T>>
@@ -49,6 +66,13 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     private final IterationContext iterationContext;
 
+    // --------------- state ---------------------------
+    private int latestEpochWatermark = -1;
+
+    private ListState<Integer> parallelismState;
+
+    private ListState<Integer> latestEpochWatermarkState;
+
     @SuppressWarnings({"unchecked", "rawtypes"})
     public AbstractAllRoundWrapperOperator(
             StreamOperatorParameters<IterationRecord<T>> parameters,
@@ -75,6 +99,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
+        if (epochWatermark <= latestEpochWatermark) {
+            return;
+        }
+        latestEpochWatermark = epochWatermark;
+
         setIterationContextRound(epochWatermark);
         processOperatorOrUdfIfSatisfy(
                 wrappedOperator,
@@ -100,23 +129,42 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
     }
 
     @Override
-    public void open() throws Exception {
-        wrappedOperator.open();
-    }
+    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
+            throws Exception {
+        RecordingStreamTaskStateInitializer recordingStreamTaskStateInitializer =
+                new RecordingStreamTaskStateInitializer(streamTaskStateManager);
+        wrappedOperator.initializeState(recordingStreamTaskStateInitializer);
+        checkState(
+                recordingStreamTaskStateInitializer.lastCreated != null,
+                "The wrapped operator did not create a state context during initializeState().");
 
-    @Override
-    public void finish() throws Exception {
-        wrappedOperator.finish();
-    }
+        OperatorStateStore operatorStateStore =
+                recordingStreamTaskStateInitializer.lastCreated.operatorStateBackend();
 
-    @Override
-    public void close() throws Exception {
-        wrappedOperator.close();
-    }
+        // Only subtask 0 writes the parallelism in snapshotState(); as this is a union list
+        // state, every subtask reads that single value back on restore and can detect a rescale.
+        int currentParallelism =
+                containingTask.getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
+        parallelismState =
+                operatorStateStore.getUnionListState(
+                        new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+                .ifPresent(
+                        oldParallelism ->
+                                checkState(
+                                        oldParallelism == currentParallelism,
+                                        "The all-round wrapper operator is recovered with parallelism changed from "
+                                                + oldParallelism
+                                                + " to "
+                                                + currentParallelism));
 
-    @Override
-    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
-        wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
+        // Restore the last processed epoch watermark so already-seen rounds are skipped.
+        latestEpochWatermarkState =
+                operatorStateStore.getListState(
+                        new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
+                .ifPresent(
+                        oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);
     }
 
     @Override
@@ -126,14 +174,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
             CheckpointOptions checkpointOptions,
             CheckpointStreamFactory storageLocation)
             throws Exception {
+
+        // Always clear the union list state before set value.
+        parallelismState.clear();
+        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
+            parallelismState.update(
+                    Collections.singletonList(
+                            containingTask
+                                    .getEnvironment()
+                                    .getTaskInfo()
+                                    .getNumberOfParallelSubtasks()));
+        }
+        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
+
         return wrappedOperator.snapshotState(
                 checkpointId, timestamp, checkpointOptions, storageLocation);
     }
 
     @Override
-    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
-            throws Exception {
-        wrappedOperator.initializeState(streamTaskStateManager);
+    public void open() throws Exception {
+        wrappedOperator.open(); // this and the lifecycle hooks below delegate to wrappedOperator
+    }
+
+    @Override
+    public void finish() throws Exception {
+        wrappedOperator.finish();
+    }
+
+    @Override
+    public void close() throws Exception {
+        wrappedOperator.close();
+    }
+
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
     }
 
     @Override
@@ -176,6 +251,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
         return wrappedOperator.getCurrentKey();
     }
 
+    @VisibleForTesting
+    int getLatestEpochWatermark() {
+        return latestEpochWatermark; // -1 until an epoch watermark is processed or restored
+    }
+
     private class IterationContext implements IterationListener.Context {
 
         @Override
@@ -183,4 +263,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
             proxyOutput.collect(outputTag, new StreamRecord<>(value));
         }
     }
+
+    private static class RecordingStreamTaskStateInitializer implements StreamTaskStateInitializer {
+        // Delegates to the wrapped initializer and remembers the last context it returned.
+        private final StreamTaskStateInitializer wrapped;
+
+        StreamOperatorStateContext lastCreated; // null until streamOperatorStateContext() is called
+
+        public RecordingStreamTaskStateInitializer(StreamTaskStateInitializer wrapped) {
+            this.wrapped = wrapped;
+        }
+
+        @Override
+        public StreamOperatorStateContext streamOperatorStateContext(
+                @Nonnull OperatorID operatorID,
+                @Nonnull String operatorClassName,
+                @Nonnull ProcessingTimeService processingTimeService,
+                @Nonnull KeyContext keyContext,
+                @Nullable TypeSerializer<?> keySerializer,
+                @Nonnull CloseableRegistry streamTaskCloseableRegistry,
+                @Nonnull MetricGroup metricGroup,
+                double managedMemoryFraction,
+                boolean isUsingCustomRawKeyedState)
+                throws Exception {
+            lastCreated =
+                    wrapped.streamOperatorStateContext(
+                            operatorID,
+                            operatorClassName,
+                            processingTimeService,
+                            keyContext,
+                            keySerializer,
+                            streamTaskCloseableRegistry,
+                            metricGroup,
+                            managedMemoryFraction,
+                            isUsingCustomRawKeyedState);
+            return lastCreated;
+        }
+    }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index f628b65..9ebb975 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -21,12 +21,14 @@ package org.apache.flink.iteration.operator.allround;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.operator.OperatorWrapper;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
@@ -38,7 +40,9 @@ import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
@@ -53,6 +57,7 @@ import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
 public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
@@ -129,6 +134,71 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
         }
     }
 
+    @Test
+    public void testSnapshotAndRestore() throws Exception {
+        StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
+                new RecordingOperatorFactory<>(
+                        SimpleOperatorFactory.of(new LifeCycleTrackingOneInputStreamOperator()),
+                        new AllRoundOperatorWrapper<>());
+        OperatorID operatorId = new OperatorID();
+        // First run: process epoch watermark 5, then take checkpoint 2 and keep its snapshot.
+        TaskStateSnapshot taskStateSnapshot = null;
+        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new,
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .build()) {
+            harness.getTaskStateManager().getWaitForReportLatch().reset();
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, "fake")));
+            harness.getStreamTask()
+                    .triggerCheckpointAsync(
+                            new CheckpointMetaData(2, 1000),
+                            CheckpointOptions.alignedNoTimeout(
+                                    CheckpointType.CHECKPOINT,
+                                    CheckpointStorageLocationReference.getDefault()));
+            harness.processAll();
+
+            harness.getTaskStateManager().getWaitForReportLatch().await();
+            taskStateSnapshot = harness.getTaskStateManager().getLastJobManagerTaskStateSnapshot();
+        }
+        // Second run: restore from that snapshot and check the epoch watermark is recovered.
+        assertNotNull(taskStateSnapshot);
+        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new,
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setTaskStateSnapshot(2, taskStateSnapshot)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .build()) {
+            assertEquals(
+                    5,
+                    ((AbstractAllRoundWrapperOperator<?, ?>) RecordingOperatorFactory.latest)
+                            .getLatestEpochWatermark());
+        }
+    }
+
+    private static class RecordingOperatorFactory<OUT> extends WrapperOperatorFactory<OUT> {
+        static StreamOperator<?> latest = null; // last operator created by any instance, for tests
+
+        public RecordingOperatorFactory(
+                StreamOperatorFactory<OUT> operatorFactory,
+                OperatorWrapper<OUT, IterationRecord<OUT>> wrapper) {
+            super(operatorFactory, wrapper);
+        }
+
+        @Override
+        public <T extends StreamOperator<IterationRecord<OUT>>> T createStreamOperator(
+                StreamOperatorParameters<IterationRecord<OUT>> parameters) {
+            T created = super.createStreamOperator(parameters); // typed result: no unchecked cast
+            latest = created;
+            return created;
+        }
+    }
+
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
             implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {