Posted to commits@flink.apache.org by zh...@apache.org on 2023/01/29 14:11:56 UTC
[flink] branch master updated: [FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.
This is an automated email from the ASF dual-hosted git repository.
zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 02b09eac623 [FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.
02b09eac623 is described below
commit 02b09eac6238fe059b6cb0ddbe791b761962dd14
Author: sunxia <xi...@gmail.com>
AuthorDate: Tue Jan 17 11:28:02 2023 +0800
[FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.
This closes #21695.
---
.../runtime/executiongraph/ExecutionVertex.java | 14 ++
.../adaptivebatch/AdaptiveBatchScheduler.java | 43 ++++++
.../adaptivebatch/AllToAllBlockingResultInfo.java | 23 +++
.../adaptivebatch/BlockingResultInfo.java | 11 ++
.../adaptivebatch/PointwiseBlockingResultInfo.java | 27 ++++
.../ExecutionTimeBasedSlowTaskDetector.java | 93 ++++++++++--
.../adaptivebatch/AdaptiveBatchSchedulerTest.java | 19 +++
.../AllToAllBlockingResultInfoTest.java | 15 ++
...tVertexParallelismAndInputInfosDeciderTest.java | 6 +
.../PointwiseBlockingResultInfoTest.java | 15 ++
.../ExecutionTimeBasedSlowTaskDetectorTest.java | 164 +++++++++++++++++++++
11 files changed, 418 insertions(+), 12 deletions(-)
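In short: the detector no longer ranks candidate tasks by raw execution time but by execution time per input byte, so a task that runs long only because it was handed more data is not flagged as slow. Below is a minimal, self-contained sketch of that weighting and the median-based baseline. It is illustrative only, not code from this commit; the real types appear in the diff below, and the real implementation keeps (time, bytes) pairs rather than precomputed ratios.

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.List;

    // Illustrative sketch of the weighting used for slow task detection.
    public final class WeightedTimeSketch {

        static final class Task {
            final long executionTimeMillis;
            final long inputBytes;

            Task(long executionTimeMillis, long inputBytes) {
                this.executionTimeMillis = executionTimeMillis;
                this.inputBytes = inputBytes;
            }

            // Time per byte; Double.MIN_VALUE guards against zero input bytes,
            // mirroring the guard in ExecutionTimeWithInputBytes.compareTo below.
            double timePerByte() {
                return (double) executionTimeMillis / Math.max(inputBytes, Double.MIN_VALUE);
            }
        }

        public static void main(String[] args) {
            List<Task> finished = new ArrayList<>();
            finished.add(new Task(100, 1_000)); // 0.1 ms/byte
            finished.add(new Task(400, 4_000)); // 0.1 ms/byte
            finished.add(new Task(900, 1_000)); // 0.9 ms/byte

            // T = median weighted time of the first N*R finished tasks (here all
            // three); the baseline is T * M, with M the configured multiplier.
            finished.sort(Comparator.comparingDouble(Task::timePerByte));
            double medianTimePerByte = finished.get(finished.size() / 2).timePerByte();
            double baseline = medianTimePerByte * 1.5; // M = 1.5 for illustration

            // A running task is slow if its weighted time exceeds the baseline and
            // its raw execution time exceeds the configured lower bound.
            System.out.println("baseline = " + baseline + " ms/byte"); // 0.15 ms/byte
        }
    }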
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index bd3a9794db0..5ff589669b4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -60,6 +60,8 @@ import static org.apache.flink.util.Preconditions.checkState;
public class ExecutionVertex
implements AccessExecutionVertex, Archiveable<ArchivedExecutionVertex> {
+ public static final long NUM_BYTES_UNKNOWN = -1;
+
public static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8;
// --------------------------------------------------------------------------------------------
@@ -86,6 +88,8 @@ public class ExecutionVertex
private int nextAttemptNumber;
+ private long inputBytes;
+
/** This field holds the allocation id of the last successful assignment. */
@Nullable private TaskManagerLocation lastAssignedLocation;
@@ -141,6 +145,8 @@ public class ExecutionVertex
this.nextAttemptNumber = initialAttemptCount;
+ this.inputBytes = NUM_BYTES_UNKNOWN;
+
this.timeout = timeout;
this.inputSplits = new ArrayList<>();
@@ -169,6 +175,14 @@ public class ExecutionVertex
.get(subTaskIndex);
}
+ public void setInputBytes(long inputBytes) {
+ this.inputBytes = inputBytes;
+ }
+
+ public long getInputBytes() {
+ return inputBytes;
+ }
+
public Execution getPartitionProducer() {
return currentExecution;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index 5032cc0f77b..e05df56b84f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -32,7 +32,9 @@ import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.JobStatusListener;
import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy;
@@ -76,7 +78,10 @@ import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.stream.Collectors;
+import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.HYBRID_FULL;
+import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.HYBRID_SELECTIVE;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -207,6 +212,16 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
});
}
+ @Override
+ public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
+ List<ExecutionVertex> executionVertices =
+ verticesToDeploy.stream()
+ .map(this::getExecutionVertex)
+ .collect(Collectors.toList());
+ enrichInputBytesForExecutionVertices(executionVertices);
+ super.allocateSlotsAndDeploy(verticesToDeploy);
+ }
+
@Override
protected void resetForNewExecution(final ExecutionVertexID executionVertexId) {
final ExecutionVertex executionVertex = getExecutionVertex(executionVertexId);
@@ -327,6 +342,34 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
return parallelismAndInputInfos;
}
+ private void enrichInputBytesForExecutionVertices(List<ExecutionVertex> executionVertices) {
+ for (ExecutionVertex ev : executionVertices) {
+ List<IntermediateResult> intermediateResults = ev.getJobVertex().getInputs();
+ boolean hasHybridEdge =
+ intermediateResults.stream()
+ .anyMatch(
+ ir ->
+ ir.getResultType() == HYBRID_FULL
+ || ir.getResultType() == HYBRID_SELECTIVE);
+ if (intermediateResults.isEmpty() || hasHybridEdge) {
+ continue;
+ }
+ long inputBytes = 0;
+ for (IntermediateResult intermediateResult : intermediateResults) {
+ ExecutionVertexInputInfo inputInfo =
+ ev.getExecutionVertexInputInfo(intermediateResult.getId());
+ IndexRange partitionIndexRange = inputInfo.getPartitionIndexRange();
+ IndexRange subpartitionIndexRange = inputInfo.getSubpartitionIndexRange();
+ BlockingResultInfo blockingResultInfo =
+ checkNotNull(getBlockingResultInfo(intermediateResult.getId()));
+ inputBytes +=
+ blockingResultInfo.getNumBytesProduced(
+ partitionIndexRange, subpartitionIndexRange);
+ }
+ ev.setInputBytes(inputBytes);
+ }
+ }
+
private void changeJobVertexParallelism(ExecutionJobVertex jobVertex, int parallelism) {
if (jobVertex.isParallelismDecided()) {
return;
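In effect, each vertex being deployed gets the sum of getNumBytesProduced(...) over all blocking results it consumes, restricted to its own partition and subpartition index ranges. As a worked example (using the numbers from the result-info tests further down): a vertex consuming partition range [0, 1] and subpartition range [0, 0] of a two-partition result with per-partition subpartition bytes {32, 64} and {128, 256} ends up with inputBytes = 32 + 128 = 160. Vertices with a hybrid (HYBRID_FULL or HYBRID_SELECTIVE) input edge are skipped and keep NUM_BYTES_UNKNOWN, since hybrid partitions may be consumed before they are fully produced and thus have no final byte count; the slow task detector then falls back to comparing raw execution times for them.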
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
index bae270a608a..9f01a1061e1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.scheduler.adaptivebatch;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -81,6 +82,28 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
}
}
+ @Override
+ public long getNumBytesProduced(
+ IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+ checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+ checkState(
+ partitionIndexRange.getStartIndex() == 0
+ && partitionIndexRange.getEndIndex() == numOfPartitions - 1,
+ "For All-To-All edges, the partition range should always be [0, %s).",
+ numOfPartitions);
+ checkState(
+ subpartitionIndexRange.getEndIndex() < numOfSubpartitions,
+ "Subpartition index %s is out of range.",
+ subpartitionIndexRange.getEndIndex());
+
+ return aggregatedSubpartitionBytes
+ .subList(
+ subpartitionIndexRange.getStartIndex(),
+ subpartitionIndexRange.getEndIndex() + 1)
+ .stream()
+ .reduce(0L, Long::sum);
+ }
+
@Override
public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
// Once all partitions are finished, we can convert the subpartition bytes to aggregated
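A worked example of the new method, using the same numbers as the new AllToAllBlockingResultInfoTest case further down in this commit:

    // Two partitions, two subpartitions each. Once both partitions have
    // reported, the aggregated per-subpartition bytes are
    // [32 + 128, 64 + 256] = [160, 320].
    AllToAllBlockingResultInfo resultInfo =
            new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
    resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
    resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));

    // All-to-all consumers always read the full partition range [0, 1], so the
    // subpartition range [0, 0] selects just the first aggregated entry: 160 bytes.
    long bytes =
            resultInfo.getNumBytesProduced(new IndexRange(0, 1), new IndexRange(0, 0));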
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
index 2eeb2a90a20..e836d993869 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.scheduler.adaptivebatch;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
@@ -39,6 +40,16 @@ public interface BlockingResultInfo extends IntermediateResultInfo {
*/
long getNumBytesProduced();
+ /**
+ * Return the aggregated number of bytes produced within the given partition and
+ * subpartition index ranges.
+ *
+ * @param partitionIndexRange the index range of the consumed partitions.
+ * @param subpartitionIndexRange the index range of the consumed subpartitions.
+ * @return the number of bytes aggregated over the given index ranges.
+ */
+ long getNumBytesProduced(IndexRange partitionIndexRange, IndexRange subpartitionIndexRange);
+
/**
* Record the information of the result partition.
*
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
index 225b653064b..ed993af9d81 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.scheduler.adaptivebatch;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import java.util.Arrays;
@@ -60,4 +61,30 @@ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo {
.flatMapToLong(Arrays::stream)
.reduce(0L, Long::sum);
}
+
+ @Override
+ public long getNumBytesProduced(
+ IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+ long inputBytes = 0;
+ for (int i = partitionIndexRange.getStartIndex();
+ i <= partitionIndexRange.getEndIndex();
+ ++i) {
+ checkState(
+ subpartitionBytesByPartitionIndex.get(i) != null,
+ "Partition index %s is not ready.",
+ i);
+ checkState(
+ subpartitionIndexRange.getEndIndex()
+ < subpartitionBytesByPartitionIndex.get(i).length,
+ "Subpartition end index %s is out of range of partition %s.",
+ subpartitionIndexRange.getEndIndex(),
+ i);
+ for (int j = subpartitionIndexRange.getStartIndex();
+ j <= subpartitionIndexRange.getEndIndex();
+ ++j) {
+ inputBytes += subpartitionBytesByPartitionIndex.get(i)[j];
+ }
+ }
+ return inputBytes;
+ }
}
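And the pointwise counterpart, again with the numbers from the corresponding new test below:

    PointwiseBlockingResultInfo resultInfo =
            new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
    resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
    resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));

    // Pointwise consumers may read a sub-range of partitions: partition range
    // [0, 0] with subpartition range [0, 1] sums 32 + 64 = 96 bytes.
    long bytes =
            resultInfo.getNumBytesProduced(new IndexRange(0, 0), new IndexRange(0, 1));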
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
index 0df536bc2e8..34cb5b47b36 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
@@ -41,6 +41,7 @@ import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
+import static org.apache.flink.runtime.executiongraph.ExecutionVertex.NUM_BYTES_UNKNOWN;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkState;
@@ -118,8 +119,9 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
/**
* Given that the parallelism is N and the ratio is R, define T as the median of the first N*R
- * finished tasks' execution time. The baseline will be T*M, where M is the multiplier. A task
- * will be identified as slow if its execution time is longer than the baseline.
+ * finished tasks' execution time. The baseline will be T*M, where M is the multiplier. Note
+ * that the execution time will be weighted with its input bytes when calculating the median. A
+ * task will be identified as slow if its weighted execution time is longer than the baseline.
*/
@VisibleForTesting
Map<ExecutionVertexID, Collection<ExecutionAttemptID>> findSlowTasks(
@@ -131,7 +133,7 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
final List<ExecutionJobVertex> jobVerticesToCheck = getJobVerticesToCheck(executionGraph);
for (ExecutionJobVertex ejv : jobVerticesToCheck) {
- final long baseline = getBaseline(ejv, currentTimeMillis);
+ final ExecutionTimeWithInputBytes baseline = getBaseline(ejv, currentTimeMillis);
for (ExecutionVertex ev : ejv.getTaskVertices()) {
if (ev.getExecutionState().isTerminal()) {
@@ -168,21 +170,25 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
return (double) finishedCount / executionJobVertex.getTaskVertices().length;
}
- private long getBaseline(
+ private ExecutionTimeWithInputBytes getBaseline(
final ExecutionJobVertex executionJobVertex, final long currentTimeMillis) {
- final long executionTimeMedian =
+ final ExecutionTimeWithInputBytes weightedExecutionTimeMedian =
calculateFinishedTaskExecutionTimeMedian(executionJobVertex, currentTimeMillis);
- return (long) Math.max(baselineLowerBoundMillis, executionTimeMedian * baselineMultiplier);
+ long multipliedBaseline =
+ (long) (weightedExecutionTimeMedian.getExecutionTime() * baselineMultiplier);
+
+ return new ExecutionTimeWithInputBytes(
+ multipliedBaseline, weightedExecutionTimeMedian.getInputBytes());
}
- private long calculateFinishedTaskExecutionTimeMedian(
+ private ExecutionTimeWithInputBytes calculateFinishedTaskExecutionTimeMedian(
final ExecutionJobVertex executionJobVertex, final long currentTime) {
final int baselineExecutionCount =
(int) Math.round(executionJobVertex.getParallelism() * baselineRatio);
if (baselineExecutionCount == 0) {
- return 0;
+ return new ExecutionTimeWithInputBytes(0L, NUM_BYTES_UNKNOWN);
}
final List<Execution> finishedExecutions =
@@ -193,9 +199,9 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
checkState(finishedExecutions.size() >= baselineExecutionCount);
- final List<Long> firstFinishedExecutions =
+ final List<ExecutionTimeWithInputBytes> firstFinishedExecutions =
finishedExecutions.stream()
- .map(e -> getExecutionTime(e, currentTime))
+ .map(e -> getExecutionTimeAndInputBytes(e, currentTime))
.sorted()
.limit(baselineExecutionCount)
.collect(Collectors.toList());
@@ -204,10 +210,18 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
}
private List<ExecutionAttemptID> findExecutionsExceedingBaseline(
- Collection<Execution> executions, long baseline, long currentTimeMillis) {
+ Collection<Execution> executions,
+ ExecutionTimeWithInputBytes baseline,
+ long currentTimeMillis) {
return executions.stream()
.filter(e -> !e.getState().isTerminal() && e.getState() != ExecutionState.CANCELING)
- .filter(e -> getExecutionTime(e, currentTimeMillis) >= baseline)
+ .filter(
+ e -> {
+ ExecutionTimeWithInputBytes timeWithBytes =
+ getExecutionTimeAndInputBytes(e, currentTimeMillis);
+ return timeWithBytes.getExecutionTime() >= baselineLowerBoundMillis
+ && timeWithBytes.compareTo(baseline) >= 0;
+ })
.map(Execution::getAttemptId)
.collect(Collectors.toList());
}
@@ -225,10 +239,65 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
}
}
+ private long getExecutionInputBytes(final Execution execution) {
+ return execution.getVertex().getInputBytes();
+ }
+
+ private ExecutionTimeWithInputBytes getExecutionTimeAndInputBytes(
+ Execution execution, final long currentTime) {
+ long executionTime = getExecutionTime(execution, currentTime);
+ long executionInputBytes = getExecutionInputBytes(execution);
+
+ return new ExecutionTimeWithInputBytes(executionTime, executionInputBytes);
+ }
+
@Override
public void stop() {
if (scheduledDetectionFuture != null) {
scheduledDetectionFuture.cancel(false);
}
}
+
+ /** This class defines the execution time and input bytes for an execution. */
+ @VisibleForTesting
+ static class ExecutionTimeWithInputBytes implements Comparable<ExecutionTimeWithInputBytes> {
+
+ private final long executionTime;
+ private final long inputBytes;
+
+ public ExecutionTimeWithInputBytes(long executionTime, long inputBytes) {
+ this.executionTime = executionTime;
+ this.inputBytes = inputBytes;
+ }
+
+ public long getExecutionTime() {
+ return executionTime;
+ }
+
+ public long getInputBytes() {
+ return inputBytes;
+ }
+
+ @Override
+ public int compareTo(ExecutionTimeWithInputBytes other) {
+ // To keep the comparison stable, the input bytes of both elements must be
+ // either both valid or both UNKNOWN, unless either execution time is 0.
+ // (When baselineRatio is 0, a baseline of 0 execution time will be generated.)
+ if (inputBytes == NUM_BYTES_UNKNOWN || other.getInputBytes() == NUM_BYTES_UNKNOWN) {
+ if (inputBytes == NUM_BYTES_UNKNOWN && other.getInputBytes() == NUM_BYTES_UNKNOWN
+ || executionTime == 0
+ || other.executionTime == 0) {
+ return (int) (executionTime - other.getExecutionTime());
+ } else {
+ throw new IllegalArgumentException(
+ "Both compared elements should be NUM_BYTES_UNKNOWN.");
+ }
+ }
+
+ return Double.compare(
+ (double) executionTime / Math.max(inputBytes, Double.MIN_VALUE),
+ (double) other.getExecutionTime()
+ / Math.max(other.getInputBytes(), Double.MIN_VALUE));
+ }
+ }
}
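The comparison semantics are easiest to see with concrete pairs; these mirror the cases covered by the new ExecutionTimeBasedSlowTaskDetectorTest at the end of this mail:

    // Equal wall-clock time but more input data => smaller weighted time, i.e.
    // the task with 20 input bytes sorts first.
    new ExecutionTimeWithInputBytes(10, 20)
            .compareTo(new ExecutionTimeWithInputBytes(10, 10)); // < 0

    // Both UNKNOWN: falls back to comparing raw execution times.
    new ExecutionTimeWithInputBytes(10, NUM_BYTES_UNKNOWN)
            .compareTo(new ExecutionTimeWithInputBytes(20, NUM_BYTES_UNKNOWN)); // < 0

    // Mixing a known byte count with UNKNOWN is rejected, unless one side has a
    // zero execution time (the degenerate baseline produced when baselineRatio
    // is 0): this call throws IllegalArgumentException.
    new ExecutionTimeWithInputBytes(20, 100)
            .compareTo(new ExecutionTimeWithInputBytes(15, NUM_BYTES_UNKNOWN));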
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index fec551032cb..b343ea6749b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -118,6 +118,9 @@ class AdaptiveBatchSchedulerTest {
// check that the jobGraph is updated
assertThat(sink.getParallelism()).isEqualTo(10);
+
+ // check that the aggregated input data bytes of each ExecutionVertex have been calculated.
+ checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
}
@Test
@@ -146,6 +149,9 @@ class AdaptiveBatchSchedulerTest {
// check that the jobGraph is updated
assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
+
+ // check that the aggregated input data bytes of each ExecutionVertex have been calculated.
+ checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
}
@Test
@@ -280,6 +286,19 @@ class AdaptiveBatchSchedulerTest {
getOnlyElement(jobVertex.getProducedDataSets()).getId());
}
+ private void checkAggregatedInputDataBytesIsCalculated(
+ ExecutionJobVertex sinkExecutionJobVertex) {
+ final ExecutionVertex[] executionVertices = sinkExecutionJobVertex.getTaskVertices();
+ long totalInputBytes = 0;
+ for (ExecutionVertex ev : executionVertices) {
+ long executionInputBytes = ev.getInputBytes();
+ assertThat(executionInputBytes).isNotEqualTo(-1);
+ totalInputBytes += executionInputBytes;
+ }
+
+ assertThat(totalInputBytes).isEqualTo(26_000L);
+ }
+
private void triggerFailedByPartitionNotFound(
SchedulerBase scheduler,
ExecutionVertex producerVertex,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
index f1c632b9ea0..32fc107874c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.scheduler.adaptivebatch;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -39,6 +40,20 @@ class AllToAllBlockingResultInfoTest {
testGetNumBytesProduced(true, 96L);
}
+ @Test
+ void testGetNumBytesProducedWithIndexRange() {
+ AllToAllBlockingResultInfo resultInfo =
+ new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+ resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+ resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+ IndexRange partitionIndexRange = new IndexRange(0, 1);
+ IndexRange subpartitionIndexRange = new IndexRange(0, 0);
+
+ assertThat(resultInfo.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange))
+ .isEqualTo(160L);
+ }
+
@Test
void testGetAggregatedSubpartitionBytes() {
AllToAllBlockingResultInfo resultInfo =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
index 66584db4c6f..778f1c2df23 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
@@ -542,6 +542,12 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
return producedBytes;
}
+ @Override
+ public long getNumBytesProduced(
+ IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
index 556a2d48876..73821e48637 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
@@ -18,6 +18,7 @@
package org.apache.flink.runtime.scheduler.adaptivebatch;
+import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -39,6 +40,20 @@ class PointwiseBlockingResultInfoTest {
assertThat(resultInfo.getNumBytesProduced()).isEqualTo(192L);
}
+ @Test
+ void testGetNumBytesProducedWithIndexRange() {
+ PointwiseBlockingResultInfo resultInfo =
+ new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+ resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+ resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+ IndexRange partitionIndexRange = new IndexRange(0, 0);
+ IndexRange subpartitionIndexRange = new IndexRange(0, 1);
+
+ assertThat(resultInfo.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange))
+ .isEqualTo(96L);
+ }
+
@Test
void testGetBytesWithPartialPartitionInfos() {
PointwiseBlockingResultInfo resultInfo =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
index 620fcbe67f7..b11f86c80d4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
+import org.apache.flink.runtime.scheduler.slowtaskdetector.ExecutionTimeBasedSlowTaskDetector.ExecutionTimeWithInputBytes;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.testutils.TestingUtils;
@@ -43,12 +44,16 @@ import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
import java.time.Duration;
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
/** Tests for {@link ExecutionTimeBasedSlowTaskDetector}. */
class ExecutionTimeBasedSlowTaskDetectorTest {
@@ -200,6 +205,165 @@ class ExecutionTimeBasedSlowTaskDetectorTest {
assertThat(slowTasks).hasSize(2);
}
+ @Test
+ void testBalancedInput() throws Exception {
+ final int parallelism = 3;
+ final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+ final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+ jobVertex2.connectNewDataSetAsInput(
+ jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+ final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+ final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+ createSlowTaskDetector(0.3, 1, 0);
+
+ final ExecutionVertex ev21 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+ ev21.setInputBytes(1024);
+ final ExecutionVertex ev22 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+ ev22.setInputBytes(1024);
+ final ExecutionVertex ev23 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+ ev23.setInputBytes(1024);
+
+ ev23.getCurrentExecutionAttempt().markFinished();
+
+ final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+ slowTaskDetector.findSlowTasks(executionGraph);
+
+ assertThat(slowTasks).hasSize(2);
+ }
+
+ @Test
+ void testBalancedInputWithLargeLowerBound() throws Exception {
+ final int parallelism = 3;
+ final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+ final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+ jobVertex2.connectNewDataSetAsInput(
+ jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+ final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+ final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+ createSlowTaskDetector(0.3, 1, Integer.MAX_VALUE);
+
+ final ExecutionVertex ev21 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+ ev21.setInputBytes(1024);
+ final ExecutionVertex ev22 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+ ev22.setInputBytes(1024);
+ final ExecutionVertex ev23 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+ ev23.setInputBytes(1024);
+
+ ev23.getCurrentExecutionAttempt().markFinished();
+
+ final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+ slowTaskDetector.findSlowTasks(executionGraph);
+
+ assertThat(slowTasks).isEmpty();
+ }
+
+ @Test
+ void testUnbalancedInput() throws Exception {
+ final int parallelism = 3;
+ final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+ final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+ jobVertex2.connectNewDataSetAsInput(
+ jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+ final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+ final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+ createSlowTaskDetector(0.3, 1, 0);
+
+ final ExecutionVertex ev21 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+ ev21.setInputBytes(1024);
+ final ExecutionVertex ev22 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+ ev22.setInputBytes(1_024_000);
+ final ExecutionVertex ev23 =
+ executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+ ev23.setInputBytes(4_096_000);
+
+ Thread.sleep(1000);
+ ev21.getCurrentExecutionAttempt().markFinished();
+
+ final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+ slowTaskDetector.findSlowTasks(executionGraph);
+
+ // no task should be detected as a slow task
+ assertThat(slowTasks).hasSize(0);
+ }
+
+ @Test
+ void testSortedExecutionTimeWithInputBytes() {
+ // executions with unbalanced input bytes
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes1 =
+ new ExecutionTimeWithInputBytes(10, 10);
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes2 =
+ new ExecutionTimeWithInputBytes(10, 20);
+
+ List<ExecutionTimeWithInputBytes> pairList = new ArrayList<>();
+ pairList.add(executionTimeWithInputBytes1);
+ pairList.add(executionTimeWithInputBytes2);
+
+ List<ExecutionTimeWithInputBytes> sortedList =
+ pairList.stream().sorted().collect(Collectors.toList());
+
+ assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes2);
+
+ // executions with balanced input bytes
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes3 =
+ new ExecutionTimeWithInputBytes(20, 10);
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes4 =
+ new ExecutionTimeWithInputBytes(10, 10);
+
+ pairList.clear();
+ pairList.add(executionTimeWithInputBytes3);
+ pairList.add(executionTimeWithInputBytes4);
+
+ sortedList = pairList.stream().sorted().collect(Collectors.toList());
+
+ assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes4);
+
+ // executions with UNKNOWN input bytes
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes5 =
+ new ExecutionTimeWithInputBytes(20, -1);
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes6 =
+ new ExecutionTimeWithInputBytes(10, -1);
+
+ pairList.clear();
+ pairList.add(executionTimeWithInputBytes5);
+ pairList.add(executionTimeWithInputBytes6);
+
+ sortedList = pairList.stream().sorted().collect(Collectors.toList());
+
+ assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes6);
+
+ // executions with 0 input bytes and non-zero execution time
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes7 =
+ new ExecutionTimeWithInputBytes(1, 0);
+
+ assertThat(executionTimeWithInputBytes7.compareTo(executionTimeWithInputBytes1))
+ .isEqualTo(1);
+
+ // executions with 0 input bytes and 0 execution time
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes8 =
+ new ExecutionTimeWithInputBytes(0, 0);
+
+ assertThat(executionTimeWithInputBytes8.compareTo(executionTimeWithInputBytes1))
+ .isEqualTo(-1);
+
+ // executions with assigned input bytes mixed with UNKNOWN input bytes
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes9 =
+ new ExecutionTimeWithInputBytes(20, 100);
+ ExecutionTimeWithInputBytes executionTimeWithInputBytes10 =
+ new ExecutionTimeWithInputBytes(15, -1);
+
+ assertThatThrownBy(
+ () -> executionTimeWithInputBytes9.compareTo(executionTimeWithInputBytes10))
+ .isInstanceOf(IllegalArgumentException.class);
+ }
+
private ExecutionGraph createExecutionGraph(JobVertex... jobVertices) throws Exception {
final JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices);