You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:41 UTC
[flink-ml] 03/08: [FLINK-24655][iteration] Make head operator
aligned with coordinator for each checkpoint
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 5896237262fc1327ab465dbbaeceaa8402c23c09
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Oct 4 21:20:49 2021 +0800
[FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint
---
.../flink/iteration/operator/HeadOperator.java | 63 +++++++++--
.../operator/HeadOperatorCheckpointAligner.java | 104 +++++++++++++++++
.../coordinator/HeadOperatorCoordinator.java | 36 +++---
.../coordinator/SharedProgressAligner.java | 107 ++++++++++++++----
.../coordinator/SharedProgressAlignerListener.java | 30 +++++
.../operator/event/CoordinatorCheckpointEvent.java | 59 ++++++++++
.../flink/iteration/operator/HeadOperatorTest.java | 117 +++++++++++++++++++-
.../coordinator/SharedProgressAlignerTest.java | 123 ++++++++++++++-------
8 files changed, 554 insertions(+), 85 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index d7a9a54..6e5a7b4 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -23,9 +23,11 @@ import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.operators.MailboxExecutor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.broadcast.BroadcastOutput;
import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
@@ -36,6 +38,7 @@ import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
@@ -49,6 +52,7 @@ import org.apache.flink.streaming.api.operators.Output;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.Collector;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.OutputTag;
@@ -59,8 +63,23 @@ import java.util.concurrent.Executor;
import static org.apache.flink.util.Preconditions.checkState;
/**
- * The head operator unions the initialized variable stream and the feedback stream, and synchronize
- * the epoch watermark (round).
+ * The head operator unions the initialized variable stream and the feedback stream, synchronizes
+ * the epoch watermark (round), and takes care of the checkpoints.
+ *
+ * <p>Specifically, for checkpoints the head operator needs to
+ *
+ * <ul>
+ * <li>Ensure exactly-once processing of elements.
+ * <li>Ensure exactly-once calls to {@link IterationListener#onEpochWatermarkIncremented(int,
+ * IterationListener.Context, Collector)}.
+ * </ul>
+ *
+ * <p>To achieve the first goal, the head operator also needs to include in the snapshot the
+ * records from the feedback edge that arrive between the alignment and the received barrier. To
+ * achieve the second goal, the head operator also waits for the notification from the
+ * OperatorCoordinator in addition to the task inputs. This ensures that the
+ * {@link GloballyAlignedEvent} does not interleave with the epoch watermarks and that all tasks
+ * inside the iteration are notified with the same epochs, which facilitates future rescaling.
*/
public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
implements OneInputStreamOperator<IterationRecord<?>, IterationRecord<?>>,
@@ -91,6 +110,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
private HeadOperatorRecordProcessor recordProcessor;
+ private HeadOperatorCheckpointAligner checkpointAligner;
+
public HeadOperator(
IterationID iterationId,
int feedbackIndex,
@@ -129,6 +150,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
status = HeadOperatorStatus.RUNNING;
recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+ checkpointAligner = new HeadOperatorCheckpointAligner();
+
// Here we register a mail
registerFeedbackConsumer(
(Runnable runnable) -> {
@@ -139,6 +162,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
}
@Override
+ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+ super.prepareSnapshotPreBarrier(checkpointId);
+
+ checkpointAligner.waitTillCoordinatorNotified(checkpointId, mailboxExecutor::yield);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ checkpointAligner
+ .onStateSnapshot(context.getCheckpointId())
+ .forEach(this::processGloballyAlignedEvent);
+ }
+
+ @Override
public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
recordProcessor.processElement(element);
}
@@ -155,12 +193,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
@Override
public void handleOperatorEvent(OperatorEvent operatorEvent) {
if (operatorEvent instanceof GloballyAlignedEvent) {
- boolean shouldTerminate =
- recordProcessor.onGloballyAligned((GloballyAlignedEvent) operatorEvent);
- if (shouldTerminate) {
- status = HeadOperatorStatus.TERMINATING;
- recordProcessor = new TerminatingHeadOperatorRecordProcessor();
- }
+ checkpointAligner
+ .checkHoldingGloballyAlignedEvent((GloballyAlignedEvent) operatorEvent)
+ .ifPresent(this::processGloballyAlignedEvent);
+ } else if (operatorEvent instanceof CoordinatorCheckpointEvent) {
+ checkpointAligner.coordinatorNotify((CoordinatorCheckpointEvent) operatorEvent);
+ } else {
+ throw new FlinkRuntimeException("Unsupported operator event: " + operatorEvent);
+ }
+ }
+
+ private void processGloballyAlignedEvent(GloballyAlignedEvent globallyAlignedEvent) {
+ boolean shouldTerminate = recordProcessor.onGloballyAligned(globallyAlignedEvent);
+ if (shouldTerminate) {
+ status = HeadOperatorStatus.TERMINATING;
+ recordProcessor = new TerminatingHeadOperatorRecordProcessor();
}
}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
new file mode 100644
index 0000000..b4f4215
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.util.function.RunnableWithException;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.TreeMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Aligns the checkpoint barrier from the task inputs and the checkpoint event from the coordinator.
+ * In addition, it holds the other operator events received after the checkpoint event until the
+ * state is snapshotted.
+ */
+class HeadOperatorCheckpointAligner {
+
+ private final TreeMap<Long, CheckpointAlignment> checkpointAlignmments;
+
+ private long latestCheckpointFromCoordinator;
+
+ HeadOperatorCheckpointAligner() {
+ this.checkpointAlignmments = new TreeMap<>();
+ }
+
+ void waitTillCoordinatorNotified(long checkpointId, RunnableWithException defaultAction)
+ throws Exception {
+ CheckpointAlignment checkpointAlignment =
+ checkpointAlignmments.computeIfAbsent(
+ checkpointId, ignored -> new CheckpointAlignment(true, false));
+ while (!checkpointAlignment.notifiedFromCoordinator) {
+ defaultAction.run();
+ }
+ checkpointAlignment.notifiedFromChannels = true;
+ }
+
+ void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
+ checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
+ latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+ CheckpointAlignment checkpointAlignment =
+ checkpointAlignmments.computeIfAbsent(
+ checkpointEvent.getCheckpointId(),
+ ignored -> new CheckpointAlignment(false, true));
+ checkpointAlignment.notifiedFromCoordinator = true;
+ }
+
+ Optional<GloballyAlignedEvent> checkHoldingGloballyAlignedEvent(
+ GloballyAlignedEvent globallyAlignedEvent) {
+ CheckpointAlignment checkpointAlignment =
+ checkpointAlignmments.get(latestCheckpointFromCoordinator);
+ if (checkpointAlignment != null && !checkpointAlignment.notifiedFromChannels) {
+ checkpointAlignment.pendingGlobalEvents.add(globallyAlignedEvent);
+ return Optional.empty();
+ }
+
+ return Optional.of(globallyAlignedEvent);
+ }
+
+ List<GloballyAlignedEvent> onStateSnapshot(long checkpointId) {
+ CheckpointAlignment checkpointAlignment = checkpointAlignmments.remove(checkpointId);
+ checkState(
+ checkpointAlignment.notifiedFromCoordinator
+ && checkpointAlignment.notifiedFromChannels,
+ "Checkpoint " + checkpointId + " is not fully aligned");
+ return checkpointAlignment.pendingGlobalEvents;
+ }
+
+ private static class CheckpointAlignment {
+
+ final List<GloballyAlignedEvent> pendingGlobalEvents;
+
+ boolean notifiedFromChannels;
+
+ boolean notifiedFromCoordinator;
+
+ public CheckpointAlignment(boolean notifiedFromChannels, boolean notifiedFromCoordinator) {
+ this.pendingGlobalEvents = new ArrayList<>();
+
+ this.notifiedFromChannels = notifiedFromChannels;
+ this.notifiedFromCoordinator = notifiedFromCoordinator;
+ }
+ }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index bde35a2..2c01c33 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.operator.HeadOperator;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -37,7 +38,7 @@ import java.util.concurrent.Executors;
* SharedProgressAligner} when received aligned event from the operator, and emit the globally
* aligned event back after one round is globally aligned.
*/
-public class HeadOperatorCoordinator implements OperatorCoordinator {
+public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgressAlignerListener {
private final Context context;
@@ -50,18 +51,24 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
this.sharedProgressAligner = Objects.requireNonNull(sharedProgressAligner);
this.subtaskGateways = new SubtaskGateway[context.currentParallelism()];
- sharedProgressAligner.registerAlignedConsumer(context.getOperatorId(), this::onAligned);
+ sharedProgressAligner.registerAlignedListener(context.getOperatorId(), this);
}
@Override
public void start() {}
@Override
- public void subtaskReady(int i, SubtaskGateway subtaskGateway) {
- this.subtaskGateways[i] = subtaskGateway;
+ public void subtaskReady(int subtaskIndex, SubtaskGateway subtaskGateway) {
+ this.subtaskGateways[subtaskIndex] = subtaskGateway;
}
@Override
+ public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+
+ @Override
+ public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+
+ @Override
public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
if (operatorEvent instanceof SubtaskAlignedEvent) {
sharedProgressAligner.reportSubtaskProgress(
@@ -71,6 +78,11 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
}
}
+ @Override
+ public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
+ sharedProgressAligner.requestCheckpoint(l, context.currentParallelism(), completableFuture);
+ }
+
public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
for (int i = 0; i < context.currentParallelism(); ++i) {
subtaskGateways[i].sendEvent(globallyAlignedEvent);
@@ -78,25 +90,21 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
}
@Override
- public void close() {
- sharedProgressAligner.unregisterConsumer(context.getOperatorId());
+ public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+ for (int i = 0; i < context.currentParallelism(); ++i) {
+ subtaskGateways[i].sendEvent(coordinatorCheckpointEvent);
+ }
}
@Override
- public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
- completableFuture.complete(new byte[0]);
+ public void close() {
+ sharedProgressAligner.unregisterListener(context.getOperatorId());
}
@Override
public void notifyCheckpointComplete(long l) {}
@Override
- public void resetToCheckpoint(long l, @Nullable byte[] bytes) {}
-
- @Override
- public void subtaskFailed(int i, @Nullable Throwable throwable) {}
-
- @Override
public void subtaskReset(int i, long l) {}
/** The factory of {@link HeadOperatorCoordinator}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 6d03a5e..9c2fd43 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -31,12 +32,14 @@ import org.apache.flink.util.function.ThrowingRunnable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
-import java.util.function.Consumer;
import java.util.function.Supplier;
import static org.apache.flink.util.Preconditions.checkState;
@@ -44,7 +47,7 @@ import static org.apache.flink.util.Preconditions.checkState;
/**
* The progress aligner shared between multiple {@link HeadOperatorCoordinator}. It maintains the
* information for each round, once one round is aligned, it would notify all the register
- * consumers.
+ * listeners.
*/
public class SharedProgressAligner {
@@ -63,7 +66,9 @@ public class SharedProgressAligner {
private final Map<Integer, EpochStatus> statusByEpoch;
- private final Map<OperatorID, Consumer<GloballyAlignedEvent>> alignedConsumers;
+ private final Map<OperatorID, SharedProgressAlignerListener> listeners;
+
+ private final Map<Long, CheckpointStatus> checkpointStatuses;
public static SharedProgressAligner getOrCreate(
IterationID iterationId,
@@ -93,29 +98,28 @@ public class SharedProgressAligner {
this.executor = Objects.requireNonNull(executor);
this.statusByEpoch = new HashMap<>();
- this.alignedConsumers = new HashMap<>();
+ this.listeners = new HashMap<>();
+ this.checkpointStatuses = new HashMap<>();
}
- public void registerAlignedConsumer(
- OperatorID operatorID, Consumer<GloballyAlignedEvent> alignedConsumer) {
+ public void registerAlignedListener(
+ OperatorID operatorID, SharedProgressAlignerListener alignedConsumer) {
runInEventLoop(
- () -> this.alignedConsumers.put(operatorID, alignedConsumer),
- "Register consumer %s",
+ () -> this.listeners.put(operatorID, alignedConsumer),
+ "Register listeners %s",
operatorID.toHexString());
}
- public void unregisterConsumer(OperatorID operatorID) {
- synchronized (this) {
- runInEventLoop(
- () -> {
- this.alignedConsumers.remove(operatorID);
- if (alignedConsumers.isEmpty()) {
- instances.remove(iterationId);
- }
- },
- "Unregister consumer %s",
- operatorID.toHexString());
- }
+ public void unregisterListener(OperatorID operatorID) {
+ runInEventLoop(
+ () -> {
+ this.listeners.remove(operatorID);
+ if (listeners.isEmpty()) {
+ instances.remove(iterationId);
+ }
+ },
+ "Unregister listeners %s",
+ operatorID.toHexString());
}
public void reportSubtaskProgress(
@@ -137,8 +141,8 @@ public class SharedProgressAligner {
GloballyAlignedEvent globallyAlignedEvent =
new GloballyAlignedEvent(
subtaskAlignedEvent.getEpoch(), roundStatus.isTerminated());
- for (Consumer<GloballyAlignedEvent> consumer : alignedConsumers.values()) {
- consumer.accept(globallyAlignedEvent);
+ for (SharedProgressAlignerListener listeners : listeners.values()) {
+ listeners.onAligned(globallyAlignedEvent);
}
}
},
@@ -147,6 +151,37 @@ public class SharedProgressAligner {
subtaskIndex);
}
+ public void requestCheckpoint(
+ long checkpointId,
+ int operatorParallelism,
+ CompletableFuture<byte[]> snapshotStateFuture) {
+ runInEventLoop(
+ () -> {
+ CheckpointStatus checkpointStatus =
+ checkpointStatuses.computeIfAbsent(
+ checkpointId,
+ ignored -> new CheckpointStatus(totalHeadParallelism));
+ boolean aligned =
+ checkpointStatus.notify(operatorParallelism, snapshotStateFuture);
+ if (aligned) {
+ CoordinatorCheckpointEvent checkpointEvent =
+ new CoordinatorCheckpointEvent(checkpointId);
+ for (SharedProgressAlignerListener listener : listeners.values()) {
+ listener.onCheckpointAligned(checkpointEvent);
+ }
+
+ for (CompletableFuture<byte[]> stateFuture :
+ checkpointStatus.getStateFutures()) {
+ stateFuture.complete(new byte[0]);
+ }
+
+ checkpointStatuses.remove(checkpointId);
+ }
+ },
+ "Coordinator report checkpoint %d",
+ checkpointId);
+ }
+
private void runInEventLoop(
ThrowingRunnable<Throwable> action,
String actionName,
@@ -170,8 +205,8 @@ public class SharedProgressAligner {
}
@VisibleForTesting
- int getNumberConsumers() {
- return alignedConsumers.size();
+ int getNumberListeners() {
+ return listeners.size();
}
private static class EpochStatus {
@@ -224,4 +259,28 @@ public class SharedProgressAligner {
return totalRecord == 0 || (hasCriteriaStream && totalCriteriaRecord == 0);
}
}
+
+ private static class CheckpointStatus {
+
+ private final long totalHeadParallelism;
+
+ private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<>();
+
+ private int notifiedCoordinatorParallelism;
+
+ private CheckpointStatus(long totalHeadParallelism) {
+ this.totalHeadParallelism = totalHeadParallelism;
+ }
+
+ public boolean notify(int parallelism, CompletableFuture<byte[]> stateFuture) {
+ stateFutures.add(stateFuture);
+ notifiedCoordinatorParallelism += parallelism;
+
+ return notifiedCoordinatorParallelism == totalHeadParallelism;
+ }
+
+ public List<CompletableFuture<byte[]>> getStateFutures() {
+ return stateFutures;
+ }
+ }
}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
new file mode 100644
index 0000000..6b687dd
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.coordinator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+
+/** The listener of the {@link SharedProgressAligner}. */
+public interface SharedProgressAlignerListener {
+
+ void onAligned(GloballyAlignedEvent globallyAlignedEvent);
+
+ void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent);
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
new file mode 100644
index 0000000..9c97581
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.event;
+
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+
+import java.util.Objects;
+
+/** Event indicating that the coordinator has received the request for a checkpoint. */
+public class CoordinatorCheckpointEvent implements OperatorEvent {
+
+ private final long checkpointId;
+
+ public CoordinatorCheckpointEvent(long checkpointId) {
+ this.checkpointId = checkpointId;
+ }
+
+ public long getCheckpointId() {
+ return checkpointId;
+ }
+
+ @Override
+ public String toString() {
+ return "CoordinatorCheckpointEvent{" + "checkpointId=" + checkpointId + '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof CoordinatorCheckpointEvent)) {
+ return false;
+ }
+ CoordinatorCheckpointEvent that = (CoordinatorCheckpointEvent) o;
+ return checkpointId == that.checkpointId;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(checkpointId);
+ }
+}
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index 3ded30a..e6d1e26 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -21,15 +21,20 @@ package org.apache.flink.iteration.operator;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.io.network.api.EndOfData;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -152,7 +157,7 @@ public class HeadOperatorTest extends TestLogger {
while (RecordingHeadOperatorFactory.latestHeadOperator
.getStatus()
- == HeadOperator.HeadOperatorStatus.RUNNING) ;
+ == HeadOperator.HeadOperatorStatus.RUNNING) {}
putFeedbackRecords(
iterationId,
IterationRecord.newEpochWatermark(
@@ -197,6 +202,116 @@ public class HeadOperatorTest extends TestLogger {
});
}
+ @Test(timeout = 60000)
+ public void testHoldCheckpointTillCoordinatorNotified() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ CompletableFuture<Void> coordinatorResult =
+ CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ // Slightly postpone the notification
+ Thread.sleep(2000);
+
+ harness.getStreamTask()
+ .dispatchOperatorEvent(
+ operatorId,
+ new SerializedValue<>(
+ new GloballyAlignedEvent(
+ 5, false)));
+ harness.getStreamTask()
+ .dispatchOperatorEvent(
+ operatorId,
+ new SerializedValue<>(
+ new CoordinatorCheckpointEvent(
+ 5)));
+ return null;
+ } catch (Throwable e) {
+ RecordingHeadOperatorFactory.latestHeadOperator
+ .getMailboxExecutor()
+ .execute(
+ () -> {
+ throw e;
+ },
+ "poison mail");
+ throw new CompletionException(e);
+ }
+ });
+
+ CheckpointBarrier barrier =
+ new CheckpointBarrier(
+ 5,
+ 5000,
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference.getDefault()));
+ harness.processEvent(barrier);
+
+ // There should be no exception
+ coordinatorResult.get();
+
+ // If the task did not hold, it would likely snapshot the state before receiving
+ // the globally aligned event.
+ assertEquals(
+ Arrays.asList(
+ new StreamRecord<>(
+ IterationRecord.newEpochWatermark(
+ 5,
+ OperatorUtils.getUniqueSenderId(operatorId, 0)),
+ 0),
+ barrier),
+ new ArrayList<>(harness.getOutput()));
+ return null;
+ });
+ }
+
+ @Test(timeout = 60000)
+ public void testPostponeGloballyAlignedEventsAfterSnapshot() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getStreamTask()
+ .dispatchOperatorEvent(
+ operatorId,
+ new SerializedValue<>(new CoordinatorCheckpointEvent(5)));
+ harness.getStreamTask()
+ .dispatchOperatorEvent(
+ operatorId,
+ new SerializedValue<>(new GloballyAlignedEvent(5, false)));
+ CheckpointBarrier barrier =
+ new CheckpointBarrier(
+ 5,
+ 5000,
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference.getDefault()));
+ harness.processEvent(barrier);
+ harness.processAll();
+
+ assertEquals(
+ Arrays.asList(
+ barrier,
+ new StreamRecord<>(
+ IterationRecord.newEpochWatermark(
+ 5,
+ OperatorUtils.getUniqueSenderId(operatorId, 0)),
+ 0)),
+ new ArrayList<>(harness.getOutput()));
+ return null;
+ });
+ }
+
private <T> T createHarnessAndRun(
IterationID iterationId,
OperatorID operatorId,
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
index 89ecba6..350f7bf 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
@@ -19,6 +19,7 @@
package org.apache.flink.iteration.operator.coordinator;
import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -32,7 +33,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
-import java.util.function.Consumer;
+import java.util.concurrent.CompletableFuture;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@@ -64,22 +65,22 @@ public class SharedProgressAlignerTest extends TestLogger {
}
@Test
- public void testRegisterAndUnregisterConsumers() {
+ public void testRegisterAndUnregisterListeners() {
IterationID iterationId = new IterationID();
List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
- List<Consumer<GloballyAlignedEvent>> consumers =
- Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+ List<SharedProgressAlignerListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
SharedProgressAligner aligner =
- initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), consumers);
+ initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), listeners);
- assertEquals(2, aligner.getNumberConsumers());
+ assertEquals(2, aligner.getNumberListeners());
- aligner.unregisterConsumer(operatorIds.get(0));
- assertEquals(1, aligner.getNumberConsumers());
+ aligner.unregisterListener(operatorIds.get(0));
+ assertEquals(1, aligner.getNumberListeners());
assertTrue(SharedProgressAligner.getInstances().containsKey(iterationId));
- aligner.unregisterConsumer(operatorIds.get(1));
- assertEquals(0, aligner.getNumberConsumers());
+ aligner.unregisterListener(operatorIds.get(1));
+ assertEquals(0, aligner.getNumberListeners());
assertFalse(SharedProgressAligner.getInstances().containsKey(iterationId));
}
@@ -88,10 +89,10 @@ public class SharedProgressAlignerTest extends TestLogger {
IterationID iterationId = new IterationID();
List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
List<Integer> parallelisms = Arrays.asList(2, 3);
- List<RecordingConsumer> consumers =
- Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+ List<RecordingListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
SharedProgressAligner aligner =
- initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+ initializeAligner(iterationId, operatorIds, parallelisms, listeners);
for (int i = 0; i < operatorIds.size(); ++i) {
for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -100,8 +101,8 @@ public class SharedProgressAlignerTest extends TestLogger {
}
}
- checkRecordingConsumers(
- Collections.singletonList(new GloballyAlignedEvent(2, false)), consumers);
+ this.checkGloballyAlignedEvents(
+ Collections.singletonList(new GloballyAlignedEvent(2, false)), listeners);
}
@Test
@@ -109,10 +110,10 @@ public class SharedProgressAlignerTest extends TestLogger {
IterationID iterationId = new IterationID();
List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
List<Integer> parallelisms = Arrays.asList(2, 3);
- List<RecordingConsumer> consumers =
- Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+ List<RecordingListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
SharedProgressAligner aligner =
- initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+ initializeAligner(iterationId, operatorIds, parallelisms, listeners);
for (int i = 0; i < operatorIds.size(); ++i) {
for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -121,8 +122,8 @@ public class SharedProgressAlignerTest extends TestLogger {
}
}
- checkRecordingConsumers(
- Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+ this.checkGloballyAlignedEvents(
+ Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
}
@Test
@@ -130,10 +131,10 @@ public class SharedProgressAlignerTest extends TestLogger {
IterationID iterationId = new IterationID();
List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
List<Integer> parallelisms = Arrays.asList(2, 3);
- List<RecordingConsumer> consumers =
- Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+ List<RecordingListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
SharedProgressAligner aligner =
- initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+ initializeAligner(iterationId, operatorIds, parallelisms, listeners);
for (int i = 0; i < operatorIds.size(); ++i) {
for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -142,8 +143,8 @@ public class SharedProgressAlignerTest extends TestLogger {
}
}
- checkRecordingConsumers(
- Collections.singletonList(new GloballyAlignedEvent(0, false)), consumers);
+ this.checkGloballyAlignedEvents(
+ Collections.singletonList(new GloballyAlignedEvent(0, false)), listeners);
}
@Test
@@ -151,10 +152,10 @@ public class SharedProgressAlignerTest extends TestLogger {
IterationID iterationId = new IterationID();
List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
List<Integer> parallelisms = Arrays.asList(2, 3);
- List<RecordingConsumer> consumers =
- Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+ List<RecordingListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
SharedProgressAligner aligner =
- initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+ initializeAligner(iterationId, operatorIds, parallelisms, listeners);
for (int i = 0; i < operatorIds.size(); ++i) {
for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -164,15 +165,46 @@ public class SharedProgressAlignerTest extends TestLogger {
}
}
- checkRecordingConsumers(
- Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+ this.checkGloballyAlignedEvents(
+ Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
+ }
+
+ @Test
+ public void testSendEventsBeforeCompleteCheckpoint() {
+ IterationID iterationId = new IterationID();
+ List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
+ List<Integer> parallelisms = Arrays.asList(2, 3);
+ List<RecordingListener> listeners =
+ Arrays.asList(new RecordingListener(), new RecordingListener());
+ SharedProgressAligner aligner =
+ initializeAligner(iterationId, operatorIds, parallelisms, listeners);
+
+ List<CompletableFuture<byte[]>> firstCheckpointStateFutures =
+ Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+ for (int i = 0; i < operatorIds.size(); ++i) {
+ // Each operator requests checkpoint 1 with its own parallelism and state future
+ aligner.requestCheckpoint(1, parallelisms.get(i), firstCheckpointStateFutures.get(i));
+ }
+
+ List<CompletableFuture<byte[]>> secondCheckpointStateFutures =
+ Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+ for (int i = 0; i < operatorIds.size(); ++i) {
+ // Each operator requests checkpoint 2 right after checkpoint 1
+ aligner.requestCheckpoint(2, parallelisms.get(i), secondCheckpointStateFutures.get(i));
+ }
+
+ firstCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+ secondCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+ checkCoordinatorCheckpointEvents(
+ Arrays.asList(new CoordinatorCheckpointEvent(1), new CoordinatorCheckpointEvent(2)),
+ listeners);
}
private SharedProgressAligner initializeAligner(
IterationID iterationId,
List<OperatorID> operatorIds,
List<Integer> parallelisms,
- List<? extends Consumer<GloballyAlignedEvent>> consumers) {
+ List<? extends SharedProgressAlignerListener> listeners) {
SharedProgressAligner aligner =
SharedProgressAligner.getOrCreate(
@@ -181,28 +213,43 @@ public class SharedProgressAlignerTest extends TestLogger {
new MockOperatorCoordinatorContext(operatorIds.get(0), parallelisms.get(0)),
DirectScheduledExecutorService::new);
- for (int i = 0; i < consumers.size(); ++i) {
- aligner.registerAlignedConsumer(operatorIds.get(i), consumers.get(i));
+ for (int i = 0; i < listeners.size(); ++i) {
+ aligner.registerAlignedListener(operatorIds.get(i), listeners.get(i));
}
return aligner;
}
- private void checkRecordingConsumers(
+ private void checkGloballyAlignedEvents(
List<GloballyAlignedEvent> expectedGloballyAlignedEvents,
- List<RecordingConsumer> consumers) {
- for (RecordingConsumer consumer : consumers) {
+ List<RecordingListener> listeners) {
+ for (RecordingListener consumer : listeners) {
assertEquals(expectedGloballyAlignedEvents, consumer.globallyAlignedEvents);
}
}
- private static class RecordingConsumer implements Consumer<GloballyAlignedEvent> {
+ private void checkCoordinatorCheckpointEvents(
+ List<CoordinatorCheckpointEvent> expectedCheckpointEvents,
+ List<RecordingListener> listeners) {
+ for (RecordingListener listener : listeners) { // same events for every listener
+ assertEquals(expectedCheckpointEvents, listener.checkpointEvents);
+ }
+ }
+
+ private static class RecordingListener implements SharedProgressAlignerListener {
final List<GloballyAlignedEvent> globallyAlignedEvents = new ArrayList<>();
+ final List<CoordinatorCheckpointEvent> checkpointEvents = new ArrayList<>();
+
@Override
- public void accept(GloballyAlignedEvent globallyAlignedEvent) {
+ public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
globallyAlignedEvents.add(globallyAlignedEvent);
}
+
+ @Override
+ public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+ checkpointEvents.add(coordinatorCheckpointEvent);
+ }
}
}