Posted to commits@flink.apache.org by zh...@apache.org on 2023/01/29 14:11:56 UTC

[flink] branch master updated: [FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.

This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 02b09eac623 [FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.
02b09eac623 is described below

commit 02b09eac6238fe059b6cb0ddbe791b761962dd14
Author: sunxia <xi...@gmail.com>
AuthorDate: Tue Jan 17 11:28:02 2023 +0800

    [FLINK-30707][runtime] Let speculative execution take input data amount into account when detecting slow tasks.
    
    This closes #21695.
---
 .../runtime/executiongraph/ExecutionVertex.java    |  14 ++
 .../adaptivebatch/AdaptiveBatchScheduler.java      |  43 ++++++
 .../adaptivebatch/AllToAllBlockingResultInfo.java  |  23 +++
 .../adaptivebatch/BlockingResultInfo.java          |  11 ++
 .../adaptivebatch/PointwiseBlockingResultInfo.java |  27 ++++
 .../ExecutionTimeBasedSlowTaskDetector.java        |  93 ++++++++++--
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |  19 +++
 .../AllToAllBlockingResultInfoTest.java            |  15 ++
 ...tVertexParallelismAndInputInfosDeciderTest.java |   6 +
 .../PointwiseBlockingResultInfoTest.java           |  15 ++
 .../ExecutionTimeBasedSlowTaskDetectorTest.java    | 164 +++++++++++++++++++++
 11 files changed, 418 insertions(+), 12 deletions(-)
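
At a high level, the change normalizes each task's execution time by the number of
input bytes it consumes, so a task that runs long simply because it reads more data
is no longer misclassified as a straggler. A minimal sketch of the weighting idea
follows; the class and method names are illustrative, not the committed API:

    /** Illustrative sketch: compare time-per-byte ratios instead of raw times. */
    public class WeightedSlownessSketch {

        // Guard against zero input bytes; the committed code guards similarly
        // with Math.max(inputBytes, Double.MIN_VALUE).
        static double timePerByte(long executionTimeMillis, long inputBytes) {
            return (double) executionTimeMillis / Math.max(inputBytes, 1L);
        }

        public static void main(String[] args) {
            // Two tasks with the same runtime but different input sizes:
            // the second is faster per byte, and thus the "less slow" one.
            System.out.println(timePerByte(10, 10)); // 1.0
            System.out.println(timePerByte(10, 20)); // 0.5
        }
    }

The figures mirror the testSortedExecutionTimeWithInputBytes case added below: with
equal execution times, the task with more input bytes sorts first.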

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index bd3a9794db0..5ff589669b4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -60,6 +60,8 @@ import static org.apache.flink.util.Preconditions.checkState;
 public class ExecutionVertex
         implements AccessExecutionVertex, Archiveable<ArchivedExecutionVertex> {
 
+    public static final long NUM_BYTES_UNKNOWN = -1;
+
     public static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8;
 
     // --------------------------------------------------------------------------------------------
@@ -86,6 +88,8 @@ public class ExecutionVertex
 
     private int nextAttemptNumber;
 
+    private long inputBytes;
+
     /** This field holds the allocation id of the last successful assignment. */
     @Nullable private TaskManagerLocation lastAssignedLocation;
 
@@ -141,6 +145,8 @@ public class ExecutionVertex
 
         this.nextAttemptNumber = initialAttemptCount;
 
+        this.inputBytes = NUM_BYTES_UNKNOWN;
+
         this.timeout = timeout;
         this.inputSplits = new ArrayList<>();
 
@@ -169,6 +175,14 @@ public class ExecutionVertex
                 .get(subTaskIndex);
     }
 
+    public void setInputBytes(long inputBytes) {
+        this.inputBytes = inputBytes;
+    }
+
+    public long getInputBytes() {
+        return inputBytes;
+    }
+
     public Execution getPartitionProducer() {
         return currentExecution;
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
index 5032cc0f77b..e05df56b84f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchScheduler.java
@@ -32,7 +32,9 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.IOMetrics;
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.executiongraph.MarkPartitionFinishedStrategy;
@@ -76,7 +78,10 @@ import java.util.Optional;
 import java.util.concurrent.Executor;
 import java.util.function.Consumer;
 import java.util.function.Function;
+import java.util.stream.Collectors;
 
+import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.HYBRID_FULL;
+import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.HYBRID_SELECTIVE;
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
@@ -207,6 +212,16 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
                 });
     }
 
+    @Override
+    public void allocateSlotsAndDeploy(final List<ExecutionVertexID> verticesToDeploy) {
+        List<ExecutionVertex> executionVertices =
+                verticesToDeploy.stream()
+                        .map(this::getExecutionVertex)
+                        .collect(Collectors.toList());
+        enrichInputBytesForExecutionVertices(executionVertices);
+        super.allocateSlotsAndDeploy(verticesToDeploy);
+    }
+
     @Override
     protected void resetForNewExecution(final ExecutionVertexID executionVertexId) {
         final ExecutionVertex executionVertex = getExecutionVertex(executionVertexId);
@@ -327,6 +342,34 @@ public class AdaptiveBatchScheduler extends DefaultScheduler {
         return parallelismAndInputInfos;
     }
 
+    private void enrichInputBytesForExecutionVertices(List<ExecutionVertex> executionVertices) {
+        for (ExecutionVertex ev : executionVertices) {
+            List<IntermediateResult> intermediateResults = ev.getJobVertex().getInputs();
+            boolean hasHybridEdge =
+                    intermediateResults.stream()
+                            .anyMatch(
+                                    ir ->
+                                            ir.getResultType() == HYBRID_FULL
+                                                    || ir.getResultType() == HYBRID_SELECTIVE);
+            if (intermediateResults.isEmpty() || hasHybridEdge) {
+                continue;
+            }
+            long inputBytes = 0;
+            for (IntermediateResult intermediateResult : intermediateResults) {
+                ExecutionVertexInputInfo inputInfo =
+                        ev.getExecutionVertexInputInfo(intermediateResult.getId());
+                IndexRange partitionIndexRange = inputInfo.getPartitionIndexRange();
+                IndexRange subpartitionIndexRange = inputInfo.getSubpartitionIndexRange();
+                BlockingResultInfo blockingResultInfo =
+                        checkNotNull(getBlockingResultInfo(intermediateResult.getId()));
+                inputBytes +=
+                        blockingResultInfo.getNumBytesProduced(
+                                partitionIndexRange, subpartitionIndexRange);
+            }
+            ev.setInputBytes(inputBytes);
+        }
+    }
+
     private void changeJobVertexParallelism(ExecutionJobVertex jobVertex, int parallelism) {
         if (jobVertex.isParallelismDecided()) {
             return;
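
Note that enrichInputBytesForExecutionVertices runs just before slot allocation and
deployment, and it skips vertices with hybrid input edges, presumably because hybrid
results can be consumed while still being produced, so their byte counts are not yet
final; such vertices keep the NUM_BYTES_UNKNOWN sentinel. A minimal sketch of the
aggregation step, with a hypothetical stand-in for the result-info lookup:

    import java.util.List;

    /** Illustrative sketch of the per-vertex aggregation; types are stand-ins. */
    public class InputBytesSketch {

        /** Stand-in for BlockingResultInfo#getNumBytesProduced(IndexRange, IndexRange). */
        interface ConsumedResult {
            long numBytesProduced();
        }

        // Sum the bytes of every consumed blocking result; the scheduler stores
        // the total on the vertex via ExecutionVertex#setInputBytes.
        static long aggregateInputBytes(List<ConsumedResult> consumedResults) {
            long total = 0;
            for (ConsumedResult result : consumedResults) {
                total += result.numBytesProduced();
            }
            return total;
        }

        public static void main(String[] args) {
            List<ConsumedResult> inputs = List.of(() -> 160L, () -> 320L);
            System.out.println(aggregateInputBytes(inputs)); // 480
        }
    }
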
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
index bae270a608a..9f01a1061e1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfo.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
@@ -81,6 +82,28 @@ public class AllToAllBlockingResultInfo extends AbstractBlockingResultInfo {
         }
     }
 
+    @Override
+    public long getNumBytesProduced(
+            IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+        checkState(aggregatedSubpartitionBytes != null, "Not all partition infos are ready");
+        checkState(
+                partitionIndexRange.getStartIndex() == 0
+                        && partitionIndexRange.getEndIndex() == numOfPartitions - 1,
+                "For All-To-All edges, the partition range should always be [0, %s).",
+                numOfPartitions);
+        checkState(
+                subpartitionIndexRange.getEndIndex() < numOfSubpartitions,
+                "Subpartition index %s is out of range.",
+                subpartitionIndexRange.getEndIndex());
+
+        return aggregatedSubpartitionBytes
+                .subList(
+                        subpartitionIndexRange.getStartIndex(),
+                        subpartitionIndexRange.getEndIndex() + 1)
+                .stream()
+                .reduce(0L, Long::sum);
+    }
+
     @Override
     public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {
         // Once all partitions are finished, we can convert the subpartition bytes to aggregated
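
For all-to-all results the bytes are already aggregated per subpartition across all
partitions, so the range lookup reduces to summing a sublist. A self-contained worked
example using the figures from the new AllToAllBlockingResultInfoTest below:

    import java.util.Arrays;
    import java.util.List;

    /** Worked example of the all-to-all range lookup. */
    public class AllToAllBytesSketch {
        public static void main(String[] args) {
            // Two partitions report subpartition bytes {32, 64} and {128, 256};
            // aggregating across partitions yields [32 + 128, 64 + 256] = [160, 320].
            List<Long> aggregatedSubpartitionBytes = Arrays.asList(160L, 320L);
            // Subpartition index range [0, 0] then sums to 160, as the test expects.
            long bytes =
                    aggregatedSubpartitionBytes.subList(0, 1).stream().reduce(0L, Long::sum);
            System.out.println(bytes); // 160
        }
    }
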
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
index 2eeb2a90a20..e836d993869 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/BlockingResultInfo.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.IntermediateResultInfo;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 
@@ -39,6 +40,16 @@ public interface BlockingResultInfo extends IntermediateResultInfo {
      */
     long getNumBytesProduced();
 
+    /**
+     * Returns the aggregated number of bytes within the given partition and subpartition index
+     * ranges.
+     *
+     * @param partitionIndexRange the index range of the consumed partitions.
+     * @param subpartitionIndexRange the index range of the consumed subpartitions.
+     * @return the aggregated number of bytes within the given index ranges.
+     */
+    long getNumBytesProduced(IndexRange partitionIndexRange, IndexRange subpartitionIndexRange);
+
     /**
      * Record the information of the result partition.
      *
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
index 225b653064b..ed993af9d81 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfo.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 import java.util.Arrays;
@@ -60,4 +61,30 @@ public class PointwiseBlockingResultInfo extends AbstractBlockingResultInfo {
                 .flatMapToLong(Arrays::stream)
                 .reduce(0L, Long::sum);
     }
+
+    @Override
+    public long getNumBytesProduced(
+            IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+        long inputBytes = 0;
+        for (int i = partitionIndexRange.getStartIndex();
+                i <= partitionIndexRange.getEndIndex();
+                ++i) {
+            checkState(
+                    subpartitionBytesByPartitionIndex.get(i) != null,
+                    "Partition index %s is not ready.",
+                    i);
+            checkState(
+                    subpartitionIndexRange.getEndIndex()
+                            < subpartitionBytesByPartitionIndex.get(i).length,
+                    "Subpartition end index %s is out of range of partition %s.",
+                    subpartitionIndexRange.getEndIndex(),
+                    i);
+            for (int j = subpartitionIndexRange.getStartIndex();
+                    j <= subpartitionIndexRange.getEndIndex();
+                    ++j) {
+                inputBytes += subpartitionBytesByPartitionIndex.get(i)[j];
+            }
+        }
+        return inputBytes;
+    }
 }
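
The pointwise variant keeps a per-partition array of subpartition bytes rather than one
aggregated list, so the lookup is a nested loop over both index ranges. A self-contained
worked example using the figures from the new PointwiseBlockingResultInfoTest below:

    /** Worked example of the pointwise range lookup. */
    public class PointwiseBytesSketch {
        public static void main(String[] args) {
            // Partition 0 reports subpartition bytes {32, 64}; partition 1 reports {128, 256}.
            long[][] subpartitionBytesByPartitionIndex = {{32L, 64L}, {128L, 256L}};
            // Partition range [0, 0] and subpartition range [0, 1]:
            long bytes = 0;
            for (int i = 0; i <= 0; i++) {
                for (int j = 0; j <= 1; j++) {
                    bytes += subpartitionBytesByPartitionIndex[i][j];
                }
            }
            System.out.println(bytes); // 32 + 64 = 96, as the test expects
        }
    }
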
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
index 0df536bc2e8..34cb5b47b36 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetector.java
@@ -41,6 +41,7 @@ import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.runtime.executiongraph.ExecutionVertex.NUM_BYTES_UNKNOWN;
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -118,8 +119,9 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
 
     /**
      * Given that the parallelism is N and the ratio is R, define T as the median of the first N*R
-     * finished tasks' execution time. The baseline will be T*M, where M is the multiplier. A task
-     * will be identified as slow if its execution time is longer than the baseline.
+     * finished tasks' execution time. The baseline will be T*M, where M is the multiplier. Note
+     * that each task's execution time is weighted by its input bytes when calculating the median.
+     * A task is identified as slow if its weighted execution time is longer than the baseline.
      */
     @VisibleForTesting
     Map<ExecutionVertexID, Collection<ExecutionAttemptID>> findSlowTasks(
@@ -131,7 +133,7 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
         final List<ExecutionJobVertex> jobVerticesToCheck = getJobVerticesToCheck(executionGraph);
 
         for (ExecutionJobVertex ejv : jobVerticesToCheck) {
-            final long baseline = getBaseline(ejv, currentTimeMillis);
+            final ExecutionTimeWithInputBytes baseline = getBaseline(ejv, currentTimeMillis);
 
             for (ExecutionVertex ev : ejv.getTaskVertices()) {
                 if (ev.getExecutionState().isTerminal()) {
@@ -168,21 +170,25 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
         return (double) finishedCount / executionJobVertex.getTaskVertices().length;
     }
 
-    private long getBaseline(
+    private ExecutionTimeWithInputBytes getBaseline(
             final ExecutionJobVertex executionJobVertex, final long currentTimeMillis) {
-        final long executionTimeMedian =
+        final ExecutionTimeWithInputBytes weightedExecutionTimeMedian =
                 calculateFinishedTaskExecutionTimeMedian(executionJobVertex, currentTimeMillis);
-        return (long) Math.max(baselineLowerBoundMillis, executionTimeMedian * baselineMultiplier);
+        long multipliedBaseline =
+                (long) (weightedExecutionTimeMedian.getExecutionTime() * baselineMultiplier);
+
+        return new ExecutionTimeWithInputBytes(
+                multipliedBaseline, weightedExecutionTimeMedian.getInputBytes());
     }
 
-    private long calculateFinishedTaskExecutionTimeMedian(
+    private ExecutionTimeWithInputBytes calculateFinishedTaskExecutionTimeMedian(
             final ExecutionJobVertex executionJobVertex, final long currentTime) {
 
         final int baselineExecutionCount =
                 (int) Math.round(executionJobVertex.getParallelism() * baselineRatio);
 
         if (baselineExecutionCount == 0) {
-            return 0;
+            return new ExecutionTimeWithInputBytes(0L, NUM_BYTES_UNKNOWN);
         }
 
         final List<Execution> finishedExecutions =
@@ -193,9 +199,9 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
 
         checkState(finishedExecutions.size() >= baselineExecutionCount);
 
-        final List<Long> firstFinishedExecutions =
+        final List<ExecutionTimeWithInputBytes> firstFinishedExecutions =
                 finishedExecutions.stream()
-                        .map(e -> getExecutionTime(e, currentTime))
+                        .map(e -> getExecutionTimeAndInputBytes(e, currentTime))
                         .sorted()
                         .limit(baselineExecutionCount)
                         .collect(Collectors.toList());
@@ -204,10 +210,18 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
     }
 
     private List<ExecutionAttemptID> findExecutionsExceedingBaseline(
-            Collection<Execution> executions, long baseline, long currentTimeMillis) {
+            Collection<Execution> executions,
+            ExecutionTimeWithInputBytes baseline,
+            long currentTimeMillis) {
         return executions.stream()
                 .filter(e -> !e.getState().isTerminal() && e.getState() != ExecutionState.CANCELING)
-                .filter(e -> getExecutionTime(e, currentTimeMillis) >= baseline)
+                .filter(
+                        e -> {
+                            ExecutionTimeWithInputBytes timeWithBytes =
+                                    getExecutionTimeAndInputBytes(e, currentTimeMillis);
+                            return timeWithBytes.getExecutionTime() >= baselineLowerBoundMillis
+                                    && timeWithBytes.compareTo(baseline) >= 0;
+                        })
                 .map(Execution::getAttemptId)
                 .collect(Collectors.toList());
     }
@@ -225,10 +239,65 @@ public class ExecutionTimeBasedSlowTaskDetector implements SlowTaskDetector {
         }
     }
 
+    private long getExecutionInputBytes(final Execution execution) {
+        return execution.getVertex().getInputBytes();
+    }
+
+    private ExecutionTimeWithInputBytes getExecutionTimeAndInputBytes(
+            Execution execution, final long currentTime) {
+        long executionTime = getExecutionTime(execution, currentTime);
+        long executionInputBytes = getExecutionInputBytes(execution);
+
+        return new ExecutionTimeWithInputBytes(executionTime, executionInputBytes);
+    }
+
     @Override
     public void stop() {
         if (scheduledDetectionFuture != null) {
             scheduledDetectionFuture.cancel(false);
         }
     }
+
+    /** This class defines the execution time and input bytes for an execution. */
+    @VisibleForTesting
+    static class ExecutionTimeWithInputBytes implements Comparable<ExecutionTimeWithInputBytes> {
+
+        private final long executionTime;
+        private final long inputBytes;
+
+        public ExecutionTimeWithInputBytes(long executionTime, long inputBytes) {
+            this.executionTime = executionTime;
+            this.inputBytes = inputBytes;
+        }
+
+        public long getExecutionTime() {
+            return executionTime;
+        }
+
+        public long getInputBytes() {
+            return inputBytes;
+        }
+
+        @Override
+        public int compareTo(ExecutionTimeWithInputBytes other) {
+            // To keep the comparison stable, the input bytes of both elements must be either
+            // both valid or both UNKNOWN, unless either execution time is 0. (When baselineRatio
+            // is 0, a baseline with an execution time of 0 will be generated.)
+            if (inputBytes == NUM_BYTES_UNKNOWN || other.getInputBytes() == NUM_BYTES_UNKNOWN) {
+                if (inputBytes == NUM_BYTES_UNKNOWN && other.getInputBytes() == NUM_BYTES_UNKNOWN
+                        || executionTime == 0
+                        || other.executionTime == 0) {
+                    return (int) (executionTime - other.getExecutionTime());
+                } else {
+                    throw new IllegalArgumentException(
+                            "Both compared elements should be NUM_BYTES_UNKNOWN.");
+                }
+            }
+
+            return Double.compare(
+                    (double) executionTime / Math.max(inputBytes, Double.MIN_VALUE),
+                    (double) other.getExecutionTime()
+                            / Math.max(other.getInputBytes(), Double.MIN_VALUE));
+        }
+    }
 }
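
The comparator above orders executions by time per byte, falling back to a raw-time
comparison when both byte counts are UNKNOWN or when a zero-time baseline is involved
(the baselineRatio == 0 case). The following self-contained sketch reproduces the core
ordering under simplified rules: it drops the mixed valid/UNKNOWN rejection and the
zero-time special case, and its fallback uses Long.compare, which avoids the
int-overflow risk of casting a long subtraction to int:

    /** Simplified sketch of the ordering; not a drop-in for the committed compareTo. */
    public class TimePerByteSketch implements Comparable<TimePerByteSketch> {

        static final long NUM_BYTES_UNKNOWN = -1;

        final long executionTimeMillis;
        final long inputBytes;

        TimePerByteSketch(long executionTimeMillis, long inputBytes) {
            this.executionTimeMillis = executionTimeMillis;
            this.inputBytes = inputBytes;
        }

        @Override
        public int compareTo(TimePerByteSketch other) {
            if (inputBytes == NUM_BYTES_UNKNOWN || other.inputBytes == NUM_BYTES_UNKNOWN) {
                // Overflow-safe fallback to raw execution time.
                return Long.compare(executionTimeMillis, other.executionTimeMillis);
            }
            return Double.compare(
                    (double) executionTimeMillis / Math.max(inputBytes, 1L),
                    (double) other.executionTimeMillis / Math.max(other.inputBytes, 1L));
        }

        public static void main(String[] args) {
            // (10 ms, 20 bytes) compares below (10 ms, 10 bytes): 0.5 < 1.0 per byte.
            System.out.println(
                    new TimePerByteSketch(10, 20).compareTo(new TimePerByteSketch(10, 10))); // < 0
        }
    }
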
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index fec551032cb..b343ea6749b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -118,6 +118,9 @@ class AdaptiveBatchSchedulerTest {
 
         // check that the jobGraph is updated
         assertThat(sink.getParallelism()).isEqualTo(10);
+
+        // check that the aggregated input bytes of each ExecutionVertex have been calculated.
+        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
     }
 
     @Test
@@ -146,6 +149,9 @@ class AdaptiveBatchSchedulerTest {
 
         // check that the jobGraph is updated
         assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
+
+        // check that the aggregated input bytes of each ExecutionVertex have been calculated.
+        checkAggregatedInputDataBytesIsCalculated(sinkExecutionJobVertex);
     }
 
     @Test
@@ -280,6 +286,19 @@ class AdaptiveBatchSchedulerTest {
                 getOnlyElement(jobVertex.getProducedDataSets()).getId());
     }
 
+    private void checkAggregatedInputDataBytesIsCalculated(
+            ExecutionJobVertex sinkExecutionJobVertex) {
+        final ExecutionVertex[] executionVertices = sinkExecutionJobVertex.getTaskVertices();
+        long totalInputBytes = 0;
+        for (ExecutionVertex ev : executionVertices) {
+            long executionInputBytes = ev.getInputBytes();
+            assertThat(executionInputBytes).isNotEqualTo(-1);
+            totalInputBytes += executionInputBytes;
+        }
+
+        assertThat(totalInputBytes).isEqualTo(26_000L);
+    }
+
     private void triggerFailedByPartitionNotFound(
             SchedulerBase scheduler,
             ExecutionVertex producerVertex,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
index f1c632b9ea0..32fc107874c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AllToAllBlockingResultInfoTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
@@ -39,6 +40,20 @@ class AllToAllBlockingResultInfoTest {
         testGetNumBytesProduced(true, 96L);
     }
 
+    @Test
+    void testGetNumBytesProducedWithIndexRange() {
+        AllToAllBlockingResultInfo resultInfo =
+                new AllToAllBlockingResultInfo(new IntermediateDataSetID(), 2, 2, false);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+        IndexRange partitionIndexRange = new IndexRange(0, 1);
+        IndexRange subpartitionIndexRange = new IndexRange(0, 0);
+
+        assertThat(resultInfo.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange))
+                .isEqualTo(160L);
+    }
+
     @Test
     void testGetAggregatedSubpartitionBytes() {
         AllToAllBlockingResultInfo resultInfo =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
index 66584db4c6f..778f1c2df23 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/DefaultVertexParallelismAndInputInfosDeciderTest.java
@@ -542,6 +542,12 @@ class DefaultVertexParallelismAndInputInfosDeciderTest {
             return producedBytes;
         }
 
+        @Override
+        public long getNumBytesProduced(
+                IndexRange partitionIndexRange, IndexRange subpartitionIndexRange) {
+            throw new UnsupportedOperationException();
+        }
+
         @Override
         public void recordPartitionInfo(int partitionIndex, ResultPartitionBytes partitionBytes) {}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
index 556a2d48876..73821e48637 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/PointwiseBlockingResultInfoTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler.adaptivebatch;
 
+import org.apache.flink.runtime.executiongraph.IndexRange;
 import org.apache.flink.runtime.executiongraph.ResultPartitionBytes;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
@@ -39,6 +40,20 @@ class PointwiseBlockingResultInfoTest {
         assertThat(resultInfo.getNumBytesProduced()).isEqualTo(192L);
     }
 
+    @Test
+    void testGetNumBytesProducedWithIndexRange() {
+        PointwiseBlockingResultInfo resultInfo =
+                new PointwiseBlockingResultInfo(new IntermediateDataSetID(), 2, 2);
+        resultInfo.recordPartitionInfo(0, new ResultPartitionBytes(new long[] {32L, 64L}));
+        resultInfo.recordPartitionInfo(1, new ResultPartitionBytes(new long[] {128L, 256L}));
+
+        IndexRange partitionIndexRange = new IndexRange(0, 0);
+        IndexRange subpartitionIndexRange = new IndexRange(0, 1);
+
+        assertThat(resultInfo.getNumBytesProduced(partitionIndexRange, subpartitionIndexRange))
+                .isEqualTo(96L);
+    }
+
     @Test
     void testGetBytesWithPartialPartitionInfos() {
         PointwiseBlockingResultInfo resultInfo =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
index 620fcbe67f7..b11f86c80d4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/slowtaskdetector/ExecutionTimeBasedSlowTaskDetectorTest.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
+import org.apache.flink.runtime.scheduler.slowtaskdetector.ExecutionTimeBasedSlowTaskDetector.ExecutionTimeWithInputBytes;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
@@ -43,12 +44,16 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.time.Duration;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Tests for {@link ExecutionTimeBasedSlowTaskDetector}. */
 class ExecutionTimeBasedSlowTaskDetectorTest {
@@ -200,6 +205,165 @@ class ExecutionTimeBasedSlowTaskDetectorTest {
         assertThat(slowTasks).hasSize(2);
     }
 
+    @Test
+    void testBalancedInput() throws Exception {
+        final int parallelism = 3;
+        final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+        final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+        jobVertex2.connectNewDataSetAsInput(
+                jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+        final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+                createSlowTaskDetector(0.3, 1, 0);
+
+        final ExecutionVertex ev21 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+        ev21.setInputBytes(1024);
+        final ExecutionVertex ev22 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+        ev22.setInputBytes(1024);
+        final ExecutionVertex ev23 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+        ev23.setInputBytes(1024);
+
+        ev23.getCurrentExecutionAttempt().markFinished();
+
+        final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+                slowTaskDetector.findSlowTasks(executionGraph);
+
+        assertThat(slowTasks).hasSize(2);
+    }
+
+    @Test
+    void testBalancedInputWithLargeLowerBound() throws Exception {
+        final int parallelism = 3;
+        final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+        final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+        jobVertex2.connectNewDataSetAsInput(
+                jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+        final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+                createSlowTaskDetector(0.3, 1, Integer.MAX_VALUE);
+
+        final ExecutionVertex ev21 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+        ev21.setInputBytes(1024);
+        final ExecutionVertex ev22 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+        ev22.setInputBytes(1024);
+        final ExecutionVertex ev23 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+        ev23.setInputBytes(1024);
+
+        ev23.getCurrentExecutionAttempt().markFinished();
+
+        final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+                slowTaskDetector.findSlowTasks(executionGraph);
+
+        assertThat(slowTasks).isEmpty();
+    }
+
+    @Test
+    void testUnbalancedInput() throws Exception {
+        final int parallelism = 3;
+        final JobVertex jobVertex1 = createNoOpVertex(parallelism);
+        final JobVertex jobVertex2 = createNoOpVertex(parallelism);
+        jobVertex2.connectNewDataSetAsInput(
+                jobVertex1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        final ExecutionGraph executionGraph = createExecutionGraph(jobVertex1, jobVertex2);
+        final ExecutionTimeBasedSlowTaskDetector slowTaskDetector =
+                createSlowTaskDetector(0.3, 1, 0);
+
+        final ExecutionVertex ev21 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[0];
+        ev21.setInputBytes(1024);
+        final ExecutionVertex ev22 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[1];
+        ev22.setInputBytes(1_024_000);
+        final ExecutionVertex ev23 =
+                executionGraph.getJobVertex(jobVertex2.getID()).getTaskVertices()[2];
+        ev23.setInputBytes(4_096_000);
+
+        Thread.sleep(1000);
+        ev21.getCurrentExecutionAttempt().markFinished();
+
+        final Map<ExecutionVertexID, Collection<ExecutionAttemptID>> slowTasks =
+                slowTaskDetector.findSlowTasks(executionGraph);
+
+        // no task will be detected as slow: the remaining tasks consume far more input bytes
+        assertThat(slowTasks).hasSize(0);
+    }
+
+    @Test
+    void testSortedExecutionTimeWithInputBytes() {
+        // executions with unbalanced input bytes
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes1 =
+                new ExecutionTimeWithInputBytes(10, 10);
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes2 =
+                new ExecutionTimeWithInputBytes(10, 20);
+
+        List<ExecutionTimeWithInputBytes> pairList = new ArrayList<>();
+        pairList.add(executionTimeWithInputBytes1);
+        pairList.add(executionTimeWithInputBytes2);
+
+        List<ExecutionTimeWithInputBytes> sortedList =
+                pairList.stream().sorted().collect(Collectors.toList());
+
+        assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes2);
+
+        // executions with balanced input bytes
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes3 =
+                new ExecutionTimeWithInputBytes(20, 10);
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes4 =
+                new ExecutionTimeWithInputBytes(10, 10);
+
+        pairList.clear();
+        pairList.add(executionTimeWithInputBytes3);
+        pairList.add(executionTimeWithInputBytes4);
+
+        sortedList = pairList.stream().sorted().collect(Collectors.toList());
+
+        assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes4);
+
+        // executions with UNKNOWN input bytes
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes5 =
+                new ExecutionTimeWithInputBytes(20, -1);
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes6 =
+                new ExecutionTimeWithInputBytes(10, -1);
+
+        pairList.clear();
+        pairList.add(executionTimeWithInputBytes5);
+        pairList.add(executionTimeWithInputBytes6);
+
+        sortedList = pairList.stream().sorted().collect(Collectors.toList());
+
+        assertThat(sortedList.get(0)).isEqualTo(executionTimeWithInputBytes6);
+
+        // executions with 0 input bytes and non-zero execution time
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes7 =
+                new ExecutionTimeWithInputBytes(1, 0);
+
+        assertThat(executionTimeWithInputBytes7.compareTo(executionTimeWithInputBytes1))
+                .isEqualTo(1);
+
+        // executions with 0 input bytes and 0 execution time
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes8 =
+                new ExecutionTimeWithInputBytes(0, 0);
+
+        assertThat(executionTimeWithInputBytes8.compareTo(executionTimeWithInputBytes1))
+                .isEqualTo(-1);
+
+        // executions with assigned input bytes mixed with UNKNOWN input bytes
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes9 =
+                new ExecutionTimeWithInputBytes(20, 100);
+        ExecutionTimeWithInputBytes executionTimeWithInputBytes10 =
+                new ExecutionTimeWithInputBytes(15, -1);
+
+        assertThatThrownBy(
+                        () -> executionTimeWithInputBytes9.compareTo(executionTimeWithInputBytes10))
+                .isInstanceOf(IllegalArgumentException.class);
+    }
+
     private ExecutionGraph createExecutionGraph(JobVertex... jobVertices) throws Exception {
         final JobGraph jobGraph = JobGraphTestUtils.streamingJobGraph(jobVertices);