You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original page and is not preserved in this copy.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:44 UTC
[flink-ml] 06/08: [FLINK-24655][iteration] Skip the repeat round for all-round operator
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 31ffe6c80339163b4bad05a14a8703d20177edb4
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 01:28:27 2021 +0800
[FLINK-24655][iteration] Skip the repeat round for all-round operator
---
.../allround/AbstractAllRoundWrapperOperator.java | 151 ++++++++++++++++++---
.../OneInputAllRoundWrapperOperatorTest.java | 70 ++++++++++
2 files changed, 204 insertions(+), 17 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 2855e38..d3461a1 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -18,28 +18,45 @@
package org.apache.flink.iteration.operator.allround;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.metrics.groups.OperatorMetricGroup;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.KeyContext;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
import org.apache.flink.streaming.api.operators.TimestampedCollector;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
import org.apache.flink.util.OutputTag;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
import java.io.IOException;
+import java.util.Collections;
import static org.apache.flink.iteration.operator.OperatorUtils.processOperatorOrUdfIfSatisfy;
+import static org.apache.flink.util.Preconditions.checkState;
/** The base class for the all-round wrapper operators. */
public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperator<T>>
@@ -49,6 +66,13 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
private final IterationContext iterationContext;
+ // --------------- state ---------------------------
+ private int latestEpochWatermark = -1;
+
+ private ListState<Integer> parallelismState;
+
+ private ListState<Integer> latestEpochWatermarkState;
+
@SuppressWarnings({"unchecked", "rawtypes"})
public AbstractAllRoundWrapperOperator(
StreamOperatorParameters<IterationRecord<T>> parameters,
@@ -75,6 +99,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
@Override
public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
+ if (epochWatermark <= latestEpochWatermark) {
+ return;
+ }
+ latestEpochWatermark = epochWatermark;
+
setIterationContextRound(epochWatermark);
processOperatorOrUdfIfSatisfy(
wrappedOperator,
@@ -100,23 +129,42 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
}
@Override
- public void open() throws Exception {
- wrappedOperator.open();
- }
+ public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
+ throws Exception {
+ RecordingStreamTaskStateInitializer recordingStreamTaskStateInitializer =
+ new RecordingStreamTaskStateInitializer(streamTaskStateManager);
+ wrappedOperator.initializeState(recordingStreamTaskStateInitializer);
+ checkState(recordingStreamTaskStateInitializer.lastCreated != null);
- @Override
- public void finish() throws Exception {
- wrappedOperator.finish();
- }
+ OperatorStateStore operatorStateStore =
+ recordingStreamTaskStateInitializer.lastCreated.operatorStateBackend();
- @Override
- public void close() throws Exception {
- wrappedOperator.close();
- }
+ parallelismState =
+ operatorStateStore.getUnionListState(
+ new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+ OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+ .ifPresent(
+ oldParallelism ->
+ checkState(
+ oldParallelism
+ == containingTask
+ .getEnvironment()
+ .getTaskInfo()
+ .getNumberOfParallelSubtasks(),
+ "The all-round wrapper operator is recovered with parallelism changed from "
+ + oldParallelism
+ + " to "
+ + containingTask
+ .getEnvironment()
+ .getTaskInfo()
+ .getNumberOfParallelSubtasks()));
- @Override
- public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
- wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
+ latestEpochWatermarkState =
+ operatorStateStore.getListState(
+ new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
+ OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
+ .ifPresent(
+ oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);
}
@Override
@@ -126,14 +174,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
CheckpointOptions checkpointOptions,
CheckpointStreamFactory storageLocation)
throws Exception {
+
+ // Always clear the union list state before set value.
+ parallelismState.clear();
+ if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
+ parallelismState.update(
+ Collections.singletonList(
+ containingTask
+ .getEnvironment()
+ .getTaskInfo()
+ .getNumberOfParallelSubtasks()));
+ }
+ latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
+
return wrappedOperator.snapshotState(
checkpointId, timestamp, checkpointOptions, storageLocation);
}
@Override
- public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
- throws Exception {
- wrappedOperator.initializeState(streamTaskStateManager);
+ public void open() throws Exception {
+ wrappedOperator.open();
+ }
+
+ @Override
+ public void finish() throws Exception {
+ wrappedOperator.finish();
+ }
+
+ @Override
+ public void close() throws Exception {
+ wrappedOperator.close();
+ }
+
+ @Override
+ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+ wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
}
@Override
@@ -176,6 +251,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
return wrappedOperator.getCurrentKey();
}
+ @VisibleForTesting
+ int getLatestEpochWatermark() {
+ return latestEpochWatermark;
+ }
+
private class IterationContext implements IterationListener.Context {
@Override
@@ -183,4 +263,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
proxyOutput.collect(outputTag, new StreamRecord<>(value));
}
}
+
+ private static class RecordingStreamTaskStateInitializer implements StreamTaskStateInitializer {
+
+ private final StreamTaskStateInitializer wrapped;
+
+ StreamOperatorStateContext lastCreated;
+
+ public RecordingStreamTaskStateInitializer(StreamTaskStateInitializer wrapped) {
+ this.wrapped = wrapped;
+ }
+
+ @Override
+ public StreamOperatorStateContext streamOperatorStateContext(
+ @Nonnull OperatorID operatorID,
+ @Nonnull String s,
+ @Nonnull ProcessingTimeService processingTimeService,
+ @Nonnull KeyContext keyContext,
+ @Nullable TypeSerializer<?> typeSerializer,
+ @Nonnull CloseableRegistry closeableRegistry,
+ @Nonnull MetricGroup metricGroup,
+ double v,
+ boolean b)
+ throws Exception {
+ lastCreated =
+ wrapped.streamOperatorStateContext(
+ operatorID,
+ s,
+ processingTimeService,
+ keyContext,
+ typeSerializer,
+ closeableRegistry,
+ metricGroup,
+ v,
+ b);
+ return lastCreated;
+ }
+ }
}
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index f628b65..9ebb975 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -21,12 +21,14 @@ package org.apache.flink.iteration.operator.allround;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.operator.OperatorWrapper;
import org.apache.flink.iteration.operator.WrapperOperatorFactory;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.io.network.api.EndOfData;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
@@ -38,7 +40,9 @@ import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
@@ -53,6 +57,7 @@ import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
/** Tests the {@link OneInputAllRoundWrapperOperator}. */
public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
@@ -129,6 +134,71 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
}
}
+ @Test
+ public void testSnapshotAndRestore() throws Exception {
+ StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
+ new RecordingOperatorFactory<>(
+ SimpleOperatorFactory.of(new LifeCycleTrackingOneInputStreamOperator()),
+ new AllRoundOperatorWrapper<>());
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot = null;
+ try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+ new StreamTaskMailboxTestHarnessBuilder<>(
+ OneInputStreamTask::new,
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+ .build()) {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+ harness.processElement(
+ new StreamRecord<>(IterationRecord.newEpochWatermark(5, "fake")));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference.getDefault()));
+ harness.processAll();
+
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ taskStateSnapshot = harness.getTaskStateManager().getLastJobManagerTaskStateSnapshot();
+ }
+
+ assertNotNull(taskStateSnapshot);
+ try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+ new StreamTaskMailboxTestHarnessBuilder<>(
+ OneInputStreamTask::new,
+ new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .setTaskStateSnapshot(2, taskStateSnapshot)
+ .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+ .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+ .build()) {
+ assertEquals(
+ 5,
+ ((AbstractAllRoundWrapperOperator) RecordingOperatorFactory.latest)
+ .getLatestEpochWatermark());
+ }
+ }
+
+ private static class RecordingOperatorFactory<OUT> extends WrapperOperatorFactory<OUT> {
+
+ static StreamOperator<?> latest = null;
+
+ public RecordingOperatorFactory(
+ StreamOperatorFactory<OUT> operatorFactory,
+ OperatorWrapper<OUT, IterationRecord<OUT>> wrapper) {
+ super(operatorFactory, wrapper);
+ }
+
+ @Override
+ public <T extends StreamOperator<IterationRecord<OUT>>> T createStreamOperator(
+ StreamOperatorParameters<IterationRecord<OUT>> parameters) {
+ latest = super.createStreamOperator(parameters);
+ return (T) latest;
+ }
+ }
+
private static class LifeCycleTrackingOneInputStreamOperator
extends AbstractStreamOperator<Integer>
implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {