Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:43 UTC
[flink-ml] 05/08: [FLINK-24655][iteration] Support the checkpoints for the iteration
This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 4b47e5df79b35336c4271f9dce1cbac9781451ef
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 00:39:49 2021 +0800
[FLINK-24655][iteration] Support the checkpoints for the iteration
---
.../flink/iteration/checkpoint/Checkpoints.java | 77 +++-
.../iteration/checkpoint/CheckpointsBroker.java | 57 +++
.../flink/iteration/operator/HeadOperator.java | 249 ++++++++++-
.../operator/HeadOperatorCheckpointAligner.java | 26 ++
.../iteration/operator/HeadOperatorFactory.java | 8 +
...dOperatorState.java => OperatorStateUtils.java} | 26 +-
.../flink/iteration/operator/TailOperator.java | 27 ++
.../coordinator/HeadOperatorCoordinator.java | 10 +-
.../coordinator/SharedProgressAligner.java | 28 ++
.../headprocessor/HeadOperatorRecordProcessor.java | 7 +-
.../operator/headprocessor/HeadOperatorState.java | 32 +-
.../RegularHeadOperatorRecordProcessor.java | 110 ++++-
.../TerminatingHeadOperatorRecordProcessor.java | 7 +-
.../flink/iteration/IterationConstructionTest.java | 38 ++
.../flink/iteration/operator/HeadOperatorTest.java | 496 ++++++++++++++++++++-
15 files changed, 1154 insertions(+), 44 deletions(-)
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
index 03420f8..edbfeba 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
@@ -19,11 +19,13 @@
package org.apache.flink.iteration.checkpoint;
import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.core.fs.Path;
import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.ResourceGuard;
import org.apache.flink.util.function.SupplierWithException;
@@ -33,6 +35,9 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.SortedMap;
import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
/** Maintains the pending checkpoints. */
public class Checkpoints<T> implements AutoCloseable {
@@ -43,7 +48,10 @@ public class Checkpoints<T> implements AutoCloseable {
private final FileSystem fileSystem;
private final SupplierWithException<Path, IOException> pathSupplier;
- private final TreeMap<Long, PendingCheckpoint> uncompletedCheckpoints = new TreeMap<>();
+ private final ConcurrentHashMap<Long, Tuple2<PendingCheckpoint, Boolean>>
+ uncompletedCheckpoints = new ConcurrentHashMap<>();
+
+ private final TreeMap<Long, PendingCheckpoint> sortedUncompletedCheckpoints = new TreeMap<>();
public Checkpoints(
TypeSerializer<T> typeSerializer,
@@ -51,27 +59,72 @@ public class Checkpoints<T> implements AutoCloseable {
SupplierWithException<Path, IOException> pathSupplier) {
this.typeSerializer = typeSerializer;
this.fileSystem = fileSystem;
+ checkState(!fileSystem.isDistributedFS(), "Currently only local fs is supported");
this.pathSupplier = pathSupplier;
}
+ public TypeSerializer<T> getTypeSerializer() {
+ return typeSerializer;
+ }
+
+ public FileSystem getFileSystem() {
+ return fileSystem;
+ }
+
+ public SupplierWithException<Path, IOException> getPathSupplier() {
+ return pathSupplier;
+ }
+
public void startLogging(long checkpointId, OperatorStateCheckpointOutputStream outputStream)
throws IOException {
- DataCacheWriter<T> dataCacheWriter =
- new DataCacheWriter<>(typeSerializer, fileSystem, pathSupplier);
- ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
- uncompletedCheckpoints.put(
- checkpointId, new PendingCheckpoint(dataCacheWriter, outputStream, snapshotLease));
+ Tuple2<PendingCheckpoint, Boolean> possibleCheckpoint =
+ uncompletedCheckpoints.computeIfAbsent(
+ checkpointId,
+ ignored -> {
+ try {
+ DataCacheWriter<T> dataCacheWriter =
+ new DataCacheWriter<>(
+ typeSerializer, fileSystem, pathSupplier);
+ ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
+ return new Tuple2<>(
+ new PendingCheckpoint(
+ dataCacheWriter, outputStream, snapshotLease),
+ false);
+ } catch (IOException e) {
+ throw new FlinkRuntimeException(e);
+ }
+ });
+
+ // If the checkpoint was already aborted before logging started, return.
+ if (possibleCheckpoint.f1) {
+ return;
+ }
+
+ sortedUncompletedCheckpoints.put(checkpointId, possibleCheckpoint.f0);
+ }
+
+ public void abort(long checkpointId) {
+ uncompletedCheckpoints.compute(
+ checkpointId,
+ (k, v) -> {
+ if (v == null) {
+ return new Tuple2<>(null, true);
+ } else {
+ v.f0.snapshotLease.close();
+ return new Tuple2<>(v.f0, true);
+ }
+ });
}
public void append(T element) throws IOException {
- for (PendingCheckpoint pendingCheckpoint : uncompletedCheckpoints.values()) {
+ for (PendingCheckpoint pendingCheckpoint : sortedUncompletedCheckpoints.values()) {
pendingCheckpoint.dataCacheWriter.addRecord(element);
}
}
public void commitCheckpointsUntil(long checkpointId) {
SortedMap<Long, PendingCheckpoint> completedCheckpoints =
- uncompletedCheckpoints.headMap(checkpointId, true);
+ sortedUncompletedCheckpoints.headMap(checkpointId, true);
completedCheckpoints
.values()
.forEach(
@@ -86,9 +139,13 @@ public class Checkpoints<T> implements AutoCloseable {
.getFinishSegments());
pendingCheckpoint.checkpointOutputStream.startNewPartition();
snapshot.writeTo(pendingCheckpoint.checkpointOutputStream);
+
+ // Directly clean up all the files since we are using the local fs.
+ // TODO: support the remote fs.
pendingCheckpoint.dataCacheWriter.cleanup();
} catch (Exception e) {
LOG.error("Failed to commit checkpoint until " + checkpointId, e);
+ throw new FlinkRuntimeException(e);
} finally {
pendingCheckpoint.snapshotLease.close();
}
@@ -99,7 +156,7 @@ public class Checkpoints<T> implements AutoCloseable {
@Override
public void close() {
- uncompletedCheckpoints.forEach(
+ sortedUncompletedCheckpoints.forEach(
(checkpointId, pendingCheckpoint) -> {
pendingCheckpoint.snapshotLease.close();
try {
@@ -108,10 +165,12 @@ public class Checkpoints<T> implements AutoCloseable {
LOG.error("Failed to cleanup " + checkpointId, e);
}
});
+ sortedUncompletedCheckpoints.clear();
uncompletedCheckpoints.clear();
}
private class PendingCheckpoint {
+
final DataCacheWriter<T> dataCacheWriter;
final OperatorStateCheckpointOutputStream checkpointOutputStream;
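
The interplay between startLogging() and abort() above is the subtle part: startLogging() runs on the head operator's task thread while abort() is driven from the tail operator, so the Boolean inside the Tuple2 records whether the checkpoint was aborted before logging ever started. Below is a minimal, self-contained sketch of the same protocol using plain Java collections in place of the Flink types; the PendingLog class and its method names are invented for illustration only.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.SortedMap;
    import java.util.TreeMap;

    public class PendingLog<T> {
        // Aborted flag per checkpoint; the real class keeps it next to the
        // PendingCheckpoint in a ConcurrentHashMap so the tail operator's
        // thread can flip it without taking the head operator's lock.
        private final Map<Long, Boolean> aborted = new HashMap<>();
        private final TreeMap<Long, List<T>> pending = new TreeMap<>();

        public synchronized void startLogging(long checkpointId) {
            // A checkpoint aborted before logging started must stay a no-op.
            if (Boolean.TRUE.equals(aborted.get(checkpointId))) {
                return;
            }
            aborted.put(checkpointId, false);
            pending.put(checkpointId, new ArrayList<>());
        }

        public synchronized void append(T element) {
            // Every in-flight feedback record is logged into all pending checkpoints.
            for (List<T> log : pending.values()) {
                log.add(element);
            }
        }

        public synchronized List<T> commitUntil(long checkpointId) {
            // Commit and drop every checkpoint with id <= checkpointId,
            // mirroring sortedUncompletedCheckpoints.headMap(checkpointId, true).
            SortedMap<Long, List<T>> done = pending.headMap(checkpointId, true);
            List<T> committed = new ArrayList<>();
            done.values().forEach(committed::addAll);
            done.clear();
            return committed;
        }

        public synchronized void abort(long checkpointId) {
            aborted.put(checkpointId, true);
            pending.remove(checkpointId);
        }
    }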
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
new file mode 100644
index 0000000..e969c10
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.checkpoint;
+
+import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
+
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Hands off the {@link Checkpoints} from the head operator to the tail operator so that the tail
+ * operator can decrease the reference count of the raw state when checkpoints are aborted. We
+ * cannot count on the head operator since it would be blocked on closing the raw state when
+ * aborting the checkpoint, which also looks like a bug.
+ */
+public class CheckpointsBroker {
+
+ private static final CheckpointsBroker INSTANCE = new CheckpointsBroker();
+
+ private final ConcurrentHashMap<SubtaskFeedbackKey<?>, Checkpoints<?>> checkpointManagers =
+ new ConcurrentHashMap<>();
+
+ public static CheckpointsBroker get() {
+ return INSTANCE;
+ }
+
+ public <V> void setCheckpoints(SubtaskFeedbackKey<V> key, Checkpoints<V> checkpoints) {
+ checkpointManagers.put(key, checkpoints);
+ }
+
+ @SuppressWarnings({"unchecked"})
+ public <V> Checkpoints<V> getCheckpoints(SubtaskFeedbackKey<V> key) {
+ Objects.requireNonNull(key);
+ return (Checkpoints<V>) Objects.requireNonNull(checkpointManagers.get(key));
+ }
+
+ @SuppressWarnings("resource")
+ void removeChannel(SubtaskFeedbackKey<?> key) {
+ checkpointManagers.remove(key);
+ }
+}
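
The broker is keyed by the same SubtaskFeedbackKey that identifies the feedback channel, so the head and tail operators of one subtask attempt resolve the identical Checkpoints instance. The sketch below condenses the two call sites introduced by this commit; iterationId, feedbackIndex, subtaskIndex, attemptNumber and checkpoints are assumed to be in scope as in the surrounding operators.

    // Head operator, initializeState(): publish the Checkpoints instance.
    CheckpointsBroker.get()
            .setCheckpoints(
                    OperatorUtils.<IterationRecord<?>>createFeedbackKey(iterationId, feedbackIndex)
                            .withSubTaskIndex(subtaskIndex, attemptNumber),
                    checkpoints);

    // Tail operator, notifyCheckpointAborted(): resolve the same instance and abort.
    Checkpoints<?> sameInstance =
            CheckpointsBroker.get()
                    .getCheckpoints(
                            OperatorUtils.createFeedbackKey(iterationId, feedbackIndex)
                                    .withSubTaskIndex(subtaskIndex, attemptNumber));
    sameInstance.abort(checkpointId);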
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index 6e5a7b4..5306d2b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -21,23 +21,46 @@ package org.apache.flink.iteration.operator;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.Path;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.IterationListener;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.broadcast.BroadcastOutput;
import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
+import org.apache.flink.iteration.operator.headprocessor.HeadOperatorState;
import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
import org.apache.flink.iteration.operator.headprocessor.TerminatingHeadOperatorRecordProcessor;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.iteration.utils.ReflectionUtils;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferConsumerWithPartialRecordLength;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.PrioritizedDeque;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
import org.apache.flink.runtime.operators.coordination.OperatorEvent;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
@@ -57,6 +80,9 @@ import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.OutputTag;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.Objects;
import java.util.concurrent.Executor;
@@ -112,6 +138,16 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
private HeadOperatorCheckpointAligner checkpointAligner;
+ // ------------- states -------------------
+
+ private ListState<Integer> parallelismState;
+
+ private ListState<Integer> statusState;
+
+ private ListState<HeadOperatorState> processorState;
+
+ private Checkpoints<IterationRecord<?>> checkpoints;
+
public HeadOperator(
IterationID iterationId,
int feedbackIndex,
@@ -146,12 +182,87 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
public void initializeState(StateInitializationContext context) throws Exception {
super.initializeState(context);
+ parallelismState =
+ context.getOperatorStateStore()
+ .getUnionListState(
+ new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+ OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+ .ifPresent(
+ oldParallelism ->
+ checkState(
+ oldParallelism
+ == getRuntimeContext()
+ .getNumberOfParallelSubtasks(),
+ "The head operator is recovered with parallelism changed from "
+ + oldParallelism
+ + " to "
+ + getRuntimeContext()
+ .getNumberOfParallelSubtasks()));
+
+ // Initialize the status and the record processor.
processorContext = new ContextImpl();
- status = HeadOperatorStatus.RUNNING;
- recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+ statusState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("status", Integer.class));
+ status =
+ HeadOperatorStatus.values()[
+ OperatorStateUtils.getUniqueElement(statusState, "status").orElse(0)];
+ if (status == HeadOperatorStatus.RUNNING) {
+ recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+ } else {
+ recordProcessor = new TerminatingHeadOperatorRecordProcessor();
+ }
+
+ // Recover the processor state if it exists.
+ processorState =
+ context.getOperatorStateStore()
+ .getListState(
+ new ListStateDescriptor<>(
+ "processorState", HeadOperatorState.class));
+ OperatorStateUtils.getUniqueElement(processorState, "processorState")
+ .ifPresent(
+ headOperatorState ->
+ recordProcessor.initializeState(
+ headOperatorState, context.getRawOperatorStateInputs()));
checkpointAligner = new HeadOperatorCheckpointAligner();
+ // Initialize the checkpoints
+ Path dataCachePath =
+ OperatorUtils.getDataCachePath(
+ getRuntimeContext().getTaskManagerRuntimeInfo().getConfiguration(),
+ getContainingTask()
+ .getEnvironment()
+ .getIOManager()
+ .getSpillingDirectoriesPaths());
+ this.checkpoints =
+ new Checkpoints<>(
+ config.getTypeSerializerOut(getClass().getClassLoader()),
+ dataCachePath.getFileSystem(),
+ OperatorUtils.createDataCacheFileGenerator(
+ dataCachePath, "header-cp", getOperatorConfig().getOperatorID()));
+ CheckpointsBroker.get()
+ .setCheckpoints(
+ OperatorUtils.<IterationRecord<?>>createFeedbackKey(
+ iterationId, feedbackIndex)
+ .withSubTaskIndex(
+ getRuntimeContext().getIndexOfThisSubtask(),
+ getRuntimeContext().getAttemptNumber()),
+ checkpoints);
+
+ try {
+ for (StatePartitionStreamProvider rawStateInput : context.getRawOperatorStateInputs()) {
+ DataCacheSnapshot.replay(
+ rawStateInput.getStream(),
+ checkpoints.getTypeSerializer(),
+ checkpoints.getFileSystem(),
+ (record) ->
+ recordProcessor.processFeedbackElement(new StreamRecord<>(record)));
+ }
+ } catch (Exception e) {
+ throw new FlinkRuntimeException("Failed to replay the records", e);
+ }
+
// Here we register a mail
registerFeedbackConsumer(
(Runnable runnable) -> {
@@ -171,18 +282,54 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
@Override
public void snapshotState(StateSnapshotContext context) throws Exception {
super.snapshotState(context);
+
+ // Always clear the union list state before setting the value.
+ parallelismState.clear();
+ if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
+ parallelismState.update(
+ Collections.singletonList(getRuntimeContext().getNumberOfParallelSubtasks()));
+ }
+ statusState.update(Collections.singletonList(status.ordinal()));
+
+ HeadOperatorState currentProcessorState = recordProcessor.snapshotState();
+ if (currentProcessorState != null) {
+ processorState.update(Collections.singletonList(currentProcessorState));
+ } else {
+ processorState.clear();
+ }
+
+ if (status == HeadOperatorStatus.RUNNING) {
+ checkpoints.startLogging(
+ context.getCheckpointId(), context.getRawOperatorStateOutput());
+ }
+
checkpointAligner
.onStateSnapshot(context.getCheckpointId())
.forEach(this::processGloballyAlignedEvent);
}
@Override
+ public void notifyCheckpointAborted(long checkpointId) throws Exception {
+ super.notifyCheckpointAborted(checkpointId);
+
+ checkpointAligner
+ .onCheckpointAborted(checkpointId)
+ .forEach(this::processGloballyAlignedEvent);
+ }
+
+ @Override
public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
recordProcessor.processElement(element);
}
@Override
public void processFeedback(StreamRecord<IterationRecord<?>> iterationRecord) throws Exception {
+ if (iterationRecord.getValue().getType() == IterationRecord.Type.BARRIER) {
+ checkpoints.commitCheckpointsUntil(iterationRecord.getValue().getCheckpointId());
+ return;
+ }
+
+ checkpoints.append(iterationRecord.getValue());
boolean terminated = recordProcessor.processFeedbackElement(iterationRecord);
if (terminated) {
checkState(status == HeadOperatorStatus.TERMINATING);
@@ -213,13 +360,65 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
@Override
public void endInput() throws Exception {
- recordProcessor.processElement(
- new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+ if (status == HeadOperatorStatus.RUNNING) {
+ recordProcessor.processElement(
+ new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+ }
+
+ // Since we choose to block here, we could not continue to process the barriers received
+ // from the task inputs, which would prevent the preceding tasks from finishing since
+ // they need to complete their final checkpoint. As a temporary solution, we check the
+ // input channels and trigger all the checkpoints until we see the EndOfPartitionEvent.
+ checkState(getContainingTask().getEnvironment().getAllInputGates().length == 1);
+ checkState(
+ getContainingTask()
+ .getEnvironment()
+ .getAllInputGates()[0]
+ .getNumberOfInputChannels()
+ == 1);
+ InputChannel inputChannel =
+ getContainingTask().getEnvironment().getAllInputGates()[0].getChannel(0);
+
+ boolean endOfPartitionReceived = false;
+ long lastTriggerCheckpointId = 0;
+ while (!endOfPartitionReceived && status != HeadOperatorStatus.TERMINATED) {
+ mailboxExecutor.tryYield();
+ Thread.sleep(200);
+
+ List<AbstractEvent> events = parseInputChannelEvents(inputChannel);
+
+ for (AbstractEvent event : events) {
+ if (event instanceof CheckpointBarrier) {
+ CheckpointBarrier barrier = (CheckpointBarrier) event;
+ if (barrier.getId() > lastTriggerCheckpointId) {
+ getContainingTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(
+ barrier.getId(), barrier.getTimestamp()),
+ barrier.getCheckpointOptions());
+ lastTriggerCheckpointId = barrier.getId();
+ }
+
+ } else if (event instanceof EndOfPartitionEvent) {
+ endOfPartitionReceived = true;
+ }
+ }
+ }
+
+ // From here on we can step into the normal loop.
while (status != HeadOperatorStatus.TERMINATED) {
mailboxExecutor.yield();
}
}
+ @Override
+ public void close() throws Exception {
+ if (checkpoints != null) {
+ checkpoints.close();
+ }
+ }
+
private void registerFeedbackConsumer(Executor mailboxExecutor) {
int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
int attemptNum = getRuntimeContext().getAttemptNumber();
@@ -232,6 +431,48 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
OperatorUtils.registerFeedbackConsumer(channel, this, mailboxExecutor);
}
+ private List<AbstractEvent> parseInputChannelEvents(InputChannel inputChannel)
+ throws Exception {
+ List<AbstractEvent> events = new ArrayList<>();
+ if (inputChannel instanceof RemoteInputChannel) {
+ Class<?> seqBufferClass =
+ Class.forName(
+ "org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel$SequenceBuffer");
+ PrioritizedDeque<?> queue =
+ ReflectionUtils.getFieldValue(
+ inputChannel, RemoteInputChannel.class, "receivedBuffers");
+ for (Object sequenceBuffer : queue) {
+ Buffer buffer =
+ ReflectionUtils.getFieldValue(sequenceBuffer, seqBufferClass, "buffer");
+ if (!buffer.isBuffer()) {
+ events.add(EventSerializer.fromBuffer(buffer, getClass().getClassLoader()));
+ }
+ }
+ } else if (inputChannel instanceof LocalInputChannel) {
+ PipelinedSubpartitionView subpartitionView =
+ ReflectionUtils.getFieldValue(
+ inputChannel, LocalInputChannel.class, "subpartitionView");
+ PipelinedSubpartition pipelinedSubpartition =
+ ReflectionUtils.getFieldValue(
+ subpartitionView, PipelinedSubpartitionView.class, "parent");
+ PrioritizedDeque<BufferConsumerWithPartialRecordLength> queue =
+ ReflectionUtils.getFieldValue(
+ pipelinedSubpartition, PipelinedSubpartition.class, "buffers");
+ for (BufferConsumerWithPartialRecordLength bufferConsumer : queue) {
+ if (!bufferConsumer.getBufferConsumer().isBuffer()) {
+ events.add(
+ EventSerializer.fromBuffer(
+ bufferConsumer.getBufferConsumer().copy().build(),
+ getClass().getClassLoader()));
+ }
+ }
+ } else {
+ LOG.warn("Unknown input channel type: " + inputChannel);
+ }
+
+ return events;
+ }
+
@VisibleForTesting
public OperatorEventGateway getOperatorEventGateway() {
return operatorEventGateway;
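
Stripped of the reflection plumbing, the endInput() workaround above is a plain poll loop: keep yielding to the mailbox, peek at the events queued in the input channel, re-trigger any barrier newer than the last one triggered, and stop once the EndOfPartitionEvent arrives. A simplified, self-contained model with hypothetical stand-in event types:

    import java.util.Queue;

    public class BarrierDrainSketch {
        interface Event {}

        static final class Barrier implements Event {
            final long id;
            Barrier(long id) { this.id = id; }
        }

        static final class EndOfPartition implements Event {}

        static void drainUntilEndOfPartition(Queue<Event> channelEvents)
                throws InterruptedException {
            long lastTriggeredCheckpointId = 0;
            boolean endOfPartitionReceived = false;
            while (!endOfPartitionReceived) {
                Thread.sleep(200); // mirrors the polling pause in endInput()
                for (Event event : channelEvents) { // peek without consuming
                    if (event instanceof Barrier) {
                        Barrier barrier = (Barrier) event;
                        if (barrier.id > lastTriggeredCheckpointId) {
                            // The real operator calls triggerCheckpointAsync(...) here.
                            lastTriggeredCheckpointId = barrier.id;
                        }
                    } else if (event instanceof EndOfPartition) {
                        endOfPartitionReceived = true;
                    }
                }
            }
        }
    }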
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
index b4f4215..9b83a7d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -24,6 +24,7 @@ import org.apache.flink.util.function.RunnableWithException;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
import java.util.TreeMap;
@@ -40,6 +41,8 @@ class HeadOperatorCheckpointAligner {
private long latestCheckpointFromCoordinator;
+ private long latestAbortedCheckpoint;
+
HeadOperatorCheckpointAligner() {
this.checkpointAlignmments = new TreeMap<>();
}
@@ -58,6 +61,13 @@ class HeadOperatorCheckpointAligner {
void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+
+ // Do nothing if a later checkpoint has been aborted. In this case there should be
+ // no notification from the task side.
+ if (latestCheckpointFromCoordinator <= latestAbortedCheckpoint) {
+ return;
+ }
+
CheckpointAlignment checkpointAlignment =
checkpointAlignmments.computeIfAbsent(
checkpointEvent.getCheckpointId(),
@@ -86,6 +96,22 @@ class HeadOperatorCheckpointAligner {
return checkpointAlignment.pendingGlobalEvents;
}
+ List<GloballyAlignedEvent> onCheckpointAborted(long checkpointId) {
+ // Here we need to abort all the checkpoints <= notified checkpoint id.
+ checkState(checkpointId > latestAbortedCheckpoint);
+ latestAbortedCheckpoint = checkpointId;
+
+ Map<Long, CheckpointAlignment> abortedAlignments =
+ checkpointAlignmments.headMap(latestAbortedCheckpoint, true);
+ List<GloballyAlignedEvent> events = new ArrayList<>();
+ abortedAlignments
+ .values()
+ .forEach(alignment -> events.addAll(alignment.pendingGlobalEvents));
+ abortedAlignments.clear();
+
+ return events;
+ }
+
private static class CheckpointAlignment {
final List<GloballyAlignedEvent> pendingGlobalEvents;
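
onCheckpointAborted() drains every alignment at or below the aborted id in a single pass. This works because TreeMap.headMap(key, true) returns a live view of the map: clearing the view also removes those entries from the underlying TreeMap. A quick stand-alone illustration:

    import java.util.SortedMap;
    import java.util.TreeMap;

    public class HeadMapDemo {
        public static void main(String[] args) {
            TreeMap<Long, String> alignments = new TreeMap<>();
            alignments.put(1L, "cp-1");
            alignments.put(2L, "cp-2");
            alignments.put(3L, "cp-3");

            // Inclusive head view of everything <= 2, as in onCheckpointAborted().
            SortedMap<Long, String> aborted = alignments.headMap(2L, true);
            aborted.clear(); // the view is live, so cp-1 and cp-2 leave the map

            System.out.println(alignments); // prints {3=cp-3}
        }
    }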
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
index 2bd1e38..5040506 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
import org.apache.flink.streaming.api.operators.CoordinatedOperatorFactory;
import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -132,4 +133,11 @@ public class HeadOperatorFactory extends AbstractStreamOperatorFactory<Iteration
// We need it to be a yielding operator factory to disable chaining,
// but we cannot use the given mailbox here since it has bugs.
}
+
+ @Override
+ public ChainingStrategy getChainingStrategy() {
+ // We cannot allow the head operator to chain with the previous operator because of
+ // the special treatment in endInput.
+ return ChainingStrategy.HEAD;
+ }
}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
similarity index 51%
copy from flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
copy to flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
index 861f481..fc4cd07 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
@@ -16,7 +16,27 @@
* limitations under the License.
*/
-package org.apache.flink.iteration.operator.headprocessor;
+package org.apache.flink.iteration.operator;
-/** The state entry for the head operator. */
-public class HeadOperatorState {}
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utility to deal with the states inside the operator. */
+public class OperatorStateUtils {
+
+ public static <T> Optional<T> getUniqueElement(ListState<T> listState, String stateName)
+ throws Exception {
+ Iterator<T> iterator = listState.get().iterator();
+ if (!iterator.hasNext()) {
+ return Optional.empty();
+ }
+
+ T result = iterator.next();
+ checkState(!iterator.hasNext(), "The state " + stateName + " has more than one element");
+ return Optional.of(result);
+ }
+}
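
getUniqueElement() captures the recurring pattern in initializeState() above: each list state is expected to hold at most one element on restore, and more than one indicates a corrupted snapshot. For example, the status recovery in this commit reads (an absent element means a fresh start, i.e. RUNNING):

    status =
            HeadOperatorStatus.values()[
                    OperatorStateUtils.getUniqueElement(statusState, "status").orElse(0)];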
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
index a58155e..2e26142 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
@@ -20,6 +20,8 @@ package org.apache.flink.iteration.operator;
import org.apache.flink.iteration.IterationID;
import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
@@ -90,6 +92,31 @@ public class TailOperator extends AbstractStreamOperator<Void>
recordConsumer.accept(streamRecord);
}
+ @Override
+ public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+ super.prepareSnapshotPreBarrier(checkpointId);
+ channel.put(new StreamRecord<>(IterationRecord.newBarrier(checkpointId)));
+ }
+
+ @Override
+ public void notifyCheckpointAborted(long checkpointId) throws Exception {
+ super.notifyCheckpointAborted(checkpointId);
+
+ // TODO: Unfortunately, we have to rely on the tail operator to help
+ // abort the checkpoint since the task thread of the head operator
+ // might get blocked due to not being able to close the raw state files.
+ // We will try to fix this on the Flink side in the future.
+ SubtaskFeedbackKey<?> key =
+ OperatorUtils.createFeedbackKey(iterationId, feedbackIndex)
+ .withSubTaskIndex(
+ getRuntimeContext().getIndexOfThisSubtask(),
+ getRuntimeContext().getAttemptNumber());
+ Checkpoints<?> checkpoints = CheckpointsBroker.get().getCheckpoints(key);
+ if (checkpoints != null) {
+ checkpoints.abort(checkpointId);
+ }
+ }
+
private void processIfObjectReuseEnabled(StreamRecord<IterationRecord<?>> record) {
// Since the record would be reused, we have to clone a new one
IterationRecord<?> cloned = record.getValue().clone();
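
Together, TailOperator.prepareSnapshotPreBarrier() and HeadOperator.processFeedback() form the snapshot protocol for in-flight feedback records: the tail injects a BARRIER record into the feedback channel before taking its own snapshot, and everything the head reads from the feedback edge before that BARRIER returns is logged into the pending checkpoint. A self-contained model of the round trip, with a plain queue standing in for the feedback channel and strings for IterationRecord:

    import java.util.ArrayDeque;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.Queue;

    public class FeedbackSnapshotSketch {
        private static final String BARRIER = "BARRIER"; // stands in for IterationRecord.newBarrier(id)

        public static void main(String[] args) {
            Queue<String> feedbackChannel = new ArrayDeque<>();
            List<String> checkpointedRecords = new ArrayList<>();

            // Records fed back before the tail snapshots belong to this checkpoint.
            feedbackChannel.add("record-1");
            feedbackChannel.add("record-2");
            // Tail operator, prepareSnapshotPreBarrier(): mark the checkpoint boundary.
            feedbackChannel.add(BARRIER);
            // Anything fed back afterwards belongs to the next checkpoint.
            feedbackChannel.add("record-3");

            // Head operator, processFeedback(): log until the barrier comes back around.
            String record;
            while ((record = feedbackChannel.poll()) != null && !record.equals(BARRIER)) {
                checkpointedRecords.add(record); // corresponds to checkpoints.append(...)
            }
            // checkpoints.commitCheckpointsUntil(checkpointId) would run here.
            System.out.println(checkpointedRecords); // prints [record-1, record-2]
        }
    }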
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index 2c01c33..f4d6e2d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -63,10 +63,16 @@ public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgr
}
@Override
- public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+ public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {
+ for (int i = 0; i < context.currentParallelism(); ++i) {
+ sharedProgressAligner.removeProgressInfo(context.getOperatorId());
+ }
+ }
@Override
- public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+ public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {
+ sharedProgressAligner.removeProgressInfo(context.getOperatorId(), subtaskIndex);
+ }
@Override
public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 9c2fd43..4c2204d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -182,6 +182,24 @@ public class SharedProgressAligner {
checkpointId);
}
+ public void removeProgressInfo(OperatorID operatorId) {
+ runInEventLoop(
+ () -> statusByEpoch.values().forEach(status -> status.remove(operatorId)),
+ "remove the progress information for {}",
+ operatorId);
+ }
+
+ public void removeProgressInfo(OperatorID operatorId, int subtaskIndex) {
+ runInEventLoop(
+ () ->
+ statusByEpoch
+ .values()
+ .forEach(status -> status.remove(operatorId, subtaskIndex)),
+ "remove the progress information for {}-{}",
+ operatorId,
+ subtaskIndex);
+ }
+
private void runInEventLoop(
ThrowingRunnable<Throwable> action,
String actionName,
@@ -234,6 +252,16 @@ public class SharedProgressAligner {
return reportedSubtasks.size() == totalHeadParallelism;
}
+ public void remove(OperatorID operatorID) {
+ reportedSubtasks
+ .entrySet()
+ .removeIf(entry -> entry.getKey().getOperatorId().equals(operatorID));
+ }
+
+ public void remove(OperatorID operatorID, int subtaskIndex) {
+ reportedSubtasks.remove(new OperatorInstanceID(subtaskIndex, operatorID));
+ }
+
public boolean isTerminated() {
checkState(
reportedSubtasks.size() == totalHeadParallelism,
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
index 78132c3..26d397d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
@@ -22,14 +22,18 @@ import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.HeadOperator;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.OutputTag;
+import javax.annotation.Nullable;
+
/** The component to actually deal with the event received in the {@link HeadOperator}. */
public interface HeadOperatorRecordProcessor {
- void initializeState(HeadOperatorState headOperatorState) throws Exception;
+ void initializeState(
+ HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates);
void processElement(StreamRecord<IterationRecord<?>> record);
@@ -37,6 +41,7 @@ public interface HeadOperatorRecordProcessor {
boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent);
+ @Nullable
HeadOperatorState snapshotState();
/** The context for {@link HeadOperatorRecordProcessor}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
index 861f481..996d929 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
@@ -18,5 +18,35 @@
package org.apache.flink.iteration.operator.headprocessor;
+import java.util.Map;
+
/** The state entry for the head operator. */
-public class HeadOperatorState {}
+public class HeadOperatorState {
+
+ private Map<Integer, Long> numFeedbackRecordsEachRound;
+
+ private int latestRoundAligned;
+
+ private int latestRoundGloballyAligned;
+
+ public HeadOperatorState(
+ Map<Integer, Long> numFeedbackRecordsEachRound,
+ int latestRoundAligned,
+ int latestRoundGloballyAligned) {
+ this.numFeedbackRecordsEachRound = numFeedbackRecordsEachRound;
+ this.latestRoundAligned = latestRoundAligned;
+ this.latestRoundGloballyAligned = latestRoundGloballyAligned;
+ }
+
+ public Map<Integer, Long> getNumFeedbackRecordsEachRound() {
+ return numFeedbackRecordsEachRound;
+ }
+
+ public int getLatestRoundAligned() {
+ return latestRoundAligned;
+ }
+
+ public int getLatestRoundGloballyAligned() {
+ return latestRoundGloballyAligned;
+ }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
index f1a6b0f..107a233 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.OperatorUtils;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.slf4j.Logger;
@@ -30,6 +31,9 @@ import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
/**
* Processes the event before we received the terminated global aligned event from the coordinator.
*/
@@ -40,26 +44,46 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
private final Context headOperatorContext;
- private final StreamRecord<IterationRecord<?>> reusable;
-
private final Map<Integer, Long> numFeedbackRecordsPerEpoch;
private final String senderId;
+ private int latestRoundAligned;
+
+ private int latestRoundGloballyAligned;
+
public RegularHeadOperatorRecordProcessor(Context headOperatorContext) {
this.headOperatorContext = headOperatorContext;
- this.reusable = new StreamRecord<>(null);
this.numFeedbackRecordsPerEpoch = new HashMap<>();
this.senderId =
OperatorUtils.getUniqueSenderId(
headOperatorContext.getStreamConfig().getOperatorID(),
headOperatorContext.getTaskInfo().getIndexOfThisSubtask());
+
+ this.latestRoundAligned = -1;
+ this.latestRoundGloballyAligned = -1;
}
@Override
- public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+ public void initializeState(
+ HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates) {
+ checkArgument(headOperatorState != null, "The initialized state should not be null");
+
+ numFeedbackRecordsPerEpoch.putAll(headOperatorState.getNumFeedbackRecordsEachRound());
+ latestRoundAligned = headOperatorState.getLatestRoundAligned();
+ latestRoundGloballyAligned = headOperatorState.getLatestRoundGloballyAligned();
+
+ // If the only round not fully aligned is round 0, then wait till endOfInput in
+ // case the input is changed.
+ if (!(latestRoundAligned == 0 && latestRoundGloballyAligned == -1)) {
+ for (int i = latestRoundGloballyAligned + 1; i <= latestRoundAligned; ++i) {
+ headOperatorContext.updateEpochToCoordinator(
+ i, numFeedbackRecordsPerEpoch.getOrDefault(i, 0L));
+ }
+ }
+ }
@Override
public void processElement(StreamRecord<IterationRecord<?>> element) {
@@ -81,22 +105,34 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
@Override
public boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent) {
LOG.info("Received global event {}", globallyAlignedEvent);
-
- reusable.replace(
- IterationRecord.newEpochWatermark(
- globallyAlignedEvent.isTerminated()
- ? Integer.MAX_VALUE
- : globallyAlignedEvent.getEpoch(),
- senderId),
- 0);
- headOperatorContext.broadcastOutput(reusable);
-
+ checkState(
+ (globallyAlignedEvent.getEpoch() == 0 && latestRoundGloballyAligned == 0)
+ || globallyAlignedEvent.getEpoch() > latestRoundGloballyAligned,
+ String.format(
+ "Receive unexpected global aligned event, latest = %d, this one = %d",
+ latestRoundGloballyAligned, globallyAlignedEvent.getEpoch()));
+
+ StreamRecord<IterationRecord<?>> record =
+ new StreamRecord<>(
+ IterationRecord.newEpochWatermark(
+ globallyAlignedEvent.isTerminated()
+ ? Integer.MAX_VALUE
+ : globallyAlignedEvent.getEpoch(),
+ senderId),
+ 0);
+ headOperatorContext.broadcastOutput(record);
+
+ latestRoundGloballyAligned =
+ Math.max(globallyAlignedEvent.getEpoch(), latestRoundGloballyAligned);
return globallyAlignedEvent.isTerminated();
}
@Override
public HeadOperatorState snapshotState() {
- return new HeadOperatorState();
+ return new HeadOperatorState(
+ new HashMap<>(numFeedbackRecordsPerEpoch),
+ latestRoundAligned,
+ latestRoundGloballyAligned);
}
@VisibleForTesting
@@ -104,18 +140,50 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
return numFeedbackRecordsPerEpoch;
}
+ @VisibleForTesting
+ public int getLatestRoundAligned() {
+ return latestRoundAligned;
+ }
+
+ @VisibleForTesting
+ public int getLatestRoundGloballyAligned() {
+ return latestRoundGloballyAligned;
+ }
+
private void processRecord(StreamRecord<IterationRecord<?>> iterationRecord) {
switch (iterationRecord.getValue().getType()) {
case RECORD:
- reusable.replace(iterationRecord.getValue(), iterationRecord.getTimestamp());
- headOperatorContext.output(reusable);
+ headOperatorContext.output(iterationRecord);
break;
case EPOCH_WATERMARK:
LOG.info("Head Received epoch watermark {}", iterationRecord.getValue().getEpoch());
- headOperatorContext.updateEpochToCoordinator(
- iterationRecord.getValue().getEpoch(),
- numFeedbackRecordsPerEpoch.getOrDefault(
- iterationRecord.getValue().getEpoch(), 0L));
+
+ boolean needNotifyCoordinator = false;
+ if (iterationRecord.getValue().getEpoch() == 0) {
+ if (latestRoundAligned <= 0) {
+ needNotifyCoordinator = true;
+ }
+ } else {
+ checkState(
+ iterationRecord.getValue().getEpoch() > latestRoundAligned,
+ String.format(
+ "Unexpected epoch watermark: latest = %d, this one = %d",
+ latestRoundAligned, iterationRecord.getValue().getEpoch()));
+ headOperatorContext.updateEpochToCoordinator(
+ iterationRecord.getValue().getEpoch(),
+ numFeedbackRecordsPerEpoch.getOrDefault(
+ iterationRecord.getValue().getEpoch(), 0L));
+ }
+
+ if (needNotifyCoordinator) {
+ headOperatorContext.updateEpochToCoordinator(
+ iterationRecord.getValue().getEpoch(),
+ numFeedbackRecordsPerEpoch.getOrDefault(
+ iterationRecord.getValue().getEpoch(), 0L));
+ }
+
+ latestRoundAligned =
+ Math.max(iterationRecord.getValue().getEpoch(), latestRoundAligned);
break;
}
}
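
The two counters restored in initializeState() encode where alignment stopped: every round in the range (latestRoundGloballyAligned, latestRoundAligned] was reported to the coordinator locally but never globally acknowledged, so those reports must be re-sent after recovery. A small sketch of the replay rule, with a hypothetical reporter callback standing in for updateEpochToCoordinator:

    import java.util.Map;
    import java.util.function.ObjLongConsumer;

    public class AlignmentReplaySketch {
        /** Re-report rounds that were aligned locally but not yet globally aligned. */
        static void replayAlignments(
                int latestRoundAligned,
                int latestRoundGloballyAligned,
                Map<Integer, Long> numFeedbackRecordsPerEpoch,
                ObjLongConsumer<Integer> reportToCoordinator) {
            // Special case from initializeState(): if only round 0 is pending, wait for
            // endOfInput instead, since the input may have changed across the restart.
            if (latestRoundAligned == 0 && latestRoundGloballyAligned == -1) {
                return;
            }
            for (int round = latestRoundGloballyAligned + 1; round <= latestRoundAligned; round++) {
                reportToCoordinator.accept(
                        round, numFeedbackRecordsPerEpoch.getOrDefault(round, 0L));
            }
        }
    }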
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
index c5377e5..f18bfa3 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.headprocessor;
import org.apache.flink.iteration.IterationRecord;
import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.FlinkRuntimeException;
@@ -30,7 +31,9 @@ import org.apache.flink.util.FlinkRuntimeException;
public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
@Override
- public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+ public void initializeState(
+ HeadOperatorState headOperatorState,
+ Iterable<StatePartitionStreamProvider> rawStates) {}
@Override
public void processElement(StreamRecord<IterationRecord<?>> record) {
@@ -55,6 +58,6 @@ public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecor
@Override
public HeadOperatorState snapshotState() {
- return new HeadOperatorState();
+ return null;
}
}
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index 9fa0d28..3cd4bf2 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -79,6 +79,44 @@ public class IterationConstructionTest extends TestLogger {
}
@Test
+ public void testNotChainingHeadOperator() {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(4);
+ DataStream<Integer> variableSource =
+ env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
+ .name("Variable")
+ .map(x -> x)
+ .name("map")
+ .setParallelism(2);
+ DataStreamList result =
+ Iterations.iterateUnboundedStreams(
+ DataStreamList.of(variableSource),
+ DataStreamList.of(),
+ ((variableStreams, dataStreams) ->
+ new IterationBodyResult(variableStreams, dataStreams)));
+
+ JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+ List<String> expectedVertexNames =
+ Arrays.asList(
+ /* 0 */ "Source: Variable",
+ /* 1 */ "map -> input-map",
+ /* 2 */ "head-map",
+ /* 3 */ "tail-head-map");
+ List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2);
+
+ List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+ assertEquals(
+ expectedVertexNames,
+ vertices.stream().map(JobVertex::getName).collect(Collectors.toList()));
+ assertEquals(
+ expectedParallelisms,
+ vertices.stream().map(JobVertex::getParallelism).collect(Collectors.toList()));
+ assertNotNull(vertices.get(2).getCoLocationGroup());
+ assertSame(vertices.get(2).getCoLocationGroup(), vertices.get(3).getCoLocationGroup());
+ }
+
+ @Test
public void testUnboundedIteration() {
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
DataStream<Integer> variableSource1 =
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index e6d1e26..5603cc9 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
@@ -43,6 +44,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
import org.apache.flink.util.FlinkException;
import org.apache.flink.util.SerializedValue;
import org.apache.flink.util.TestLogger;
@@ -56,7 +58,10 @@ import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@@ -157,7 +162,10 @@ public class HeadOperatorTest extends TestLogger {
while (RecordingHeadOperatorFactory.latestHeadOperator
.getStatus()
- == HeadOperator.HeadOperatorStatus.RUNNING) {}
+ == HeadOperator.HeadOperatorStatus.RUNNING) {
+ Thread.sleep(500);
+ }
+
putFeedbackRecords(
iterationId,
IterationRecord.newEpochWatermark(
@@ -312,6 +320,452 @@ public class HeadOperatorTest extends TestLogger {
});
}
+ @Test
+ public void testSnapshotAndRestoreBeforeRoundZeroFinish() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ harness.processElement(
+ new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+ harness.processAll();
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.RUNNING,
+ Collections.emptyList(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ -1,
+ -1);
+ return null;
+ });
+ }
+
+ @Test
+ public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneNotAligned() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ harness.processElement(
+ new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+ // Simulates endOfInput, but does not block the main thread.
+ harness.processElement(
+ new StreamRecord<>(
+ IterationRecord.newEpochWatermark(0, "fake")));
+ harness.processAll();
+
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+ harness.processAll();
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.RUNNING,
+ Collections.emptyList(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ 0,
+ -1);
+
+ // Simulates endOfInput, but does not block the main thread.
+ harness.processElement(
+ new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+ assertEquals(
+ Collections.singletonList(new SubtaskAlignedEvent(0, 0, false)),
+ new ArrayList<>(
+ ((RecordingOperatorEventGateway)
+ RecordingHeadOperatorFactory.latestHeadOperator
+ .getOperatorEventGateway())
+ .operatorEvents));
+ return null;
+ });
+ }
+
+ @Test
+ public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneAligned() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ harness.processElement(
+ new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+ // Simulates endOfInput, but does not block the main thread.
+ harness.processElement(
+ new StreamRecord<>(
+ IterationRecord.newEpochWatermark(0, "fake")));
+ dispatchOperatorEvent(
+ harness, operatorId, new GloballyAlignedEvent(0, false));
+ putFeedbackRecords(
+ iterationId, IterationRecord.newRecord(100, 1), null);
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(1, "tail"),
+ null);
+ harness.processAll();
+
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+ harness.processAll();
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.RUNNING,
+ Collections.emptyList(),
+ Collections.singletonList(new SubtaskAlignedEvent(1, 1, false)),
+ Collections.singletonMap(1, 1L),
+ 1,
+ 0);
+ return null;
+ });
+ }
+
+ @Test
+ public void testSnapshotAndRestoreWithFeedbackRecords() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(4, "tail"),
+ null);
+ dispatchOperatorEvent(
+ harness, operatorId, new GloballyAlignedEvent(4, false));
+ harness.processAll();
+
+ putFeedbackRecords(
+ iterationId, IterationRecord.newRecord(100, 5), null);
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ harness.processAll();
+
+ putFeedbackRecords(
+ iterationId, IterationRecord.newRecord(101, 5), null);
+ putFeedbackRecords(
+ iterationId, IterationRecord.newRecord(102, 5), null);
+ putFeedbackRecords(
+ iterationId, IterationRecord.newRecord(103, 6), null);
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(5, "tail"),
+ null);
+ putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+ harness.processAll();
+
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.RUNNING,
+ Arrays.asList(
+ new StreamRecord<>(IterationRecord.newRecord(101, 5)),
+ new StreamRecord<>(IterationRecord.newRecord(102, 5)),
+ new StreamRecord<>(IterationRecord.newRecord(103, 6))),
+ /* The one before checkpoint and the two after checkpoint */
+ Collections.singletonList(new SubtaskAlignedEvent(5, 3, false)),
+ new HashMap<Integer, Long>() {
+ {
+ this.put(5, 3L);
+ this.put(6, 1L);
+ }
+ },
+ 5,
+ 4);
+ return null;
+ });
+ }
+
+ @Test
+ public void testCheckpointBeforeTerminated() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(4, "tail"),
+ null);
+ dispatchOperatorEvent(
+ harness, operatorId, new GloballyAlignedEvent(4, false));
+ harness.processAll();
+
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(5, "tail"),
+ null);
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ harness.processAll();
+
+ dispatchOperatorEvent(
+ harness, operatorId, new GloballyAlignedEvent(5, true));
+ harness.processAll();
+
+ putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+ harness.processAll();
+
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.RUNNING,
+ Collections.emptyList(),
+ Collections.singletonList(new SubtaskAlignedEvent(5, 0, false)),
+ Collections.emptyMap(),
+ 5,
+ 4);
+ return null;
+ });
+ }
+
+ @Test
+ public void testCheckpointAfterTerminating() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ TaskStateSnapshot taskStateSnapshot =
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(5, "tail"),
+ null);
+ dispatchOperatorEvent(
+ harness, operatorId, new GloballyAlignedEvent(5, true));
+ harness.processAll();
+
+ dispatchOperatorEvent(
+ harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference
+ .getDefault()));
+ harness.processAll();
+
+ harness.getTaskStateManager().getWaitForReportLatch().await();
+ return harness.getTaskStateManager()
+ .getLastJobManagerTaskStateSnapshot();
+ });
+ assertNotNull(taskStateSnapshot);
+ cleanupFeedbackChannel(iterationId);
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ taskStateSnapshot,
+ harness -> {
+ checkRestoredOperatorState(
+ harness,
+ HeadOperator.HeadOperatorStatus.TERMINATING,
+ Collections.emptyList(),
+ Collections.emptyList(),
+ Collections.emptyMap(),
+ -1,
+ -1);
+
+ putFeedbackRecords(
+ iterationId,
+ IterationRecord.newEpochWatermark(Integer.MAX_VALUE + 1, "tail"),
+ null);
+ harness.processEvent(EndOfData.INSTANCE);
+ harness.finishProcessing();
+
+ return null;
+ });
+ }
+
+ @Test(timeout = 20000)
+ public void testTailAbortPendingCheckpointIfHeadBlocked() throws Exception {
+ IterationID iterationId = new IterationID();
+ OperatorID operatorId = new OperatorID();
+
+ createHarnessAndRun(
+ iterationId,
+ operatorId,
+ null,
+ harness -> {
+ harness.processElement(new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+ dispatchOperatorEvent(harness, operatorId, new CoordinatorCheckpointEvent(2));
+ harness.getStreamTask()
+ .triggerCheckpointAsync(
+ new CheckpointMetaData(2, 1000),
+ CheckpointOptions.alignedNoTimeout(
+ CheckpointType.CHECKPOINT,
+ CheckpointStorageLocationReference.getDefault()));
+ harness.processAll();
+
+ putFeedbackRecords(iterationId, IterationRecord.newRecord(100, 1), null);
+ harness.processAll();
+
+ // Simulates the tail operator helping to abort the checkpoint.
+ CompletableFuture<Void> supplier =
+ CompletableFuture.supplyAsync(
+ () -> {
+ try {
+ // Slightly postpone the execution till the head
+ // operator gets blocked.
+ Thread.sleep(2000);
+
+ OneInputStreamOperatorTestHarness<
+ IterationRecord<?>, Void>
+ testHarness =
+ new OneInputStreamOperatorTestHarness<>(
+ new TailOperator(
+ iterationId, 0));
+ testHarness.open();
+
+ testHarness.getOperator().notifyCheckpointAborted(2);
+ } catch (Exception e) {
+ throw new CompletionException(e);
+ }
+
+ return null;
+ });
+
+ harness.getStreamTask().notifyCheckpointAbortAsync(2, 0);
+ harness.processAll();
+
+ supplier.get();
+
+ return null;
+ });
+ }
+
private <T> T createHarnessAndRun(
IterationID iterationId,
OperatorID operatorId,
@@ -371,6 +825,46 @@ public class HeadOperatorTest extends TestLogger {
: new StreamRecord<>(record, timestamp));
}
+ private static void checkRestoredOperatorState(
+ StreamTaskMailboxTestHarness<?> harness,
+ HeadOperator.HeadOperatorStatus expectedStatus,
+ List<Object> expectedOutput,
+ List<OperatorEvent> expectedOperatorEvents,
+ Map<Integer, Long> expectedNumFeedbackRecords,
+ int expectedLastAligned,
+ int expectedLastGloballyAligned) {
+ HeadOperator headOperator = RecordingHeadOperatorFactory.latestHeadOperator;
+ assertEquals(expectedStatus, headOperator.getStatus());
+ assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
+ RecordingOperatorEventGateway eventGateway =
+ (RecordingOperatorEventGateway) headOperator.getOperatorEventGateway();
+ assertEquals(expectedOperatorEvents, new ArrayList<>(eventGateway.operatorEvents));
+
+ if (expectedStatus == HeadOperator.HeadOperatorStatus.RUNNING) {
+ RegularHeadOperatorRecordProcessor recordProcessor =
+ (RegularHeadOperatorRecordProcessor) headOperator.getRecordProcessor();
+ assertEquals(
+ expectedNumFeedbackRecords, recordProcessor.getNumFeedbackRecordsPerEpoch());
+ assertEquals(expectedLastAligned, recordProcessor.getLatestRoundAligned());
+ assertEquals(
+ expectedLastGloballyAligned, recordProcessor.getLatestRoundGloballyAligned());
+ }
+ }
+
+ /**
+ * We have to manually clean up the feedback channel since we are not able to set the
+ * attempt number.
+ */
+ private static void cleanupFeedbackChannel(IterationID iterationId) {
+ FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
+ FeedbackChannelBroker.get()
+ .getChannel(
+ OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
+ iterationId, 0)
+ .withSubTaskIndex(0, 0));
+ feedbackChannel.close();
+ }
+
private static class RecordingOperatorEventGateway implements OperatorEventGateway {
final BlockingQueue<OperatorEvent> operatorEvents = new LinkedBlockingQueue<>();