You are viewing a plain-text version of this content. The canonical link for it ("here") was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@flink.apache.org by xt...@apache.org on 2023/01/05 11:24:36 UTC

[flink] 01/03: [FLINK-30185][rest][refactor] Distinguish which ExecutionAttemptID each ThreadInfoSample belongs to

This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b9ffa7ac1ff5db3e9e9189d242149a01e543256d
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Wed Dec 21 17:05:08 2022 +0800

    [FLINK-30185][rest][refactor] Distinguish which ExecutionAttemptID each ThreadInfoSample belongs to
---
 .../runtime/messages/TaskThreadInfoResponse.java   |  8 ++-
 .../flink/runtime/messages/ThreadInfoSample.java   | 15 ++--
 .../flink/runtime/taskexecutor/TaskExecutor.java   | 16 +++--
 .../taskexecutor/ThreadInfoSampleService.java      | 49 +++++++------
 .../org/apache/flink/runtime/util/JvmUtils.java    |  6 +-
 .../threadinfo/JobVertexThreadInfoStats.java       | 10 +--
 .../threadinfo/ThreadInfoRequestCoordinator.java   | 15 ++--
 .../taskexecutor/ThreadInfoSampleServiceTest.java  | 82 ++++++++++++++++------
 .../threadinfo/JobVertexThreadInfoTrackerTest.java | 48 +++++--------
 .../ThreadInfoRequestCoordinatorTest.java          | 30 +++++---
 10 files changed, 170 insertions(+), 109 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/TaskThreadInfoResponse.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/TaskThreadInfoResponse.java
index fd2987c1edc..525131f9d0c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/TaskThreadInfoResponse.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/TaskThreadInfoResponse.java
@@ -18,24 +18,26 @@
 
 package org.apache.flink.runtime.messages;
 
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
 import java.util.Collection;
+import java.util.Map;
 
 /** Response to the request to collect thread details samples. */
 public class TaskThreadInfoResponse implements Serializable {
 
     private static final long serialVersionUID = -4786454630050578031L;
 
-    private final Collection<ThreadInfoSample> samples;
+    private final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samples;
 
     /**
      * Creates a response to the request to collect thread details samples.
      *
      * @param samples Thread info samples.
      */
-    public TaskThreadInfoResponse(Collection<ThreadInfoSample> samples) {
+    public TaskThreadInfoResponse(Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samples) {
         this.samples = Preconditions.checkNotNull(samples);
     }
 
@@ -44,7 +46,7 @@ public class TaskThreadInfoResponse implements Serializable {
      *
      * @return A collection of thread info samples for a particular execution attempt (Task)
      */
-    public Collection<ThreadInfoSample> getSamples() {
+    public Map<ExecutionAttemptID, Collection<ThreadInfoSample>> getSamples() {
         return samples;
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/ThreadInfoSample.java b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/ThreadInfoSample.java
index 2de541a10b9..e13d948c0e1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/messages/ThreadInfoSample.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/messages/ThreadInfoSample.java
@@ -23,6 +23,7 @@ import javax.annotation.Nullable;
 import java.io.Serializable;
 import java.lang.management.ThreadInfo;
 import java.util.Collection;
+import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
@@ -63,13 +64,15 @@ public class ThreadInfoSample implements Serializable {
      * @param threadInfos the collection of {@link ThreadInfo}.
      * @return the collection of the corresponding {@link ThreadInfoSample}s.
      */
-    public static Collection<ThreadInfoSample> from(Collection<ThreadInfo> threadInfos) {
+    public static Map<Long, ThreadInfoSample> from(Collection<ThreadInfo> threadInfos) {
         return threadInfos.stream()
-                .map(
-                        threadInfo ->
-                                new ThreadInfoSample(
-                                        threadInfo.getThreadState(), threadInfo.getStackTrace()))
-                .collect(Collectors.toList());
+                .collect(
+                        Collectors.toMap(
+                                ThreadInfo::getThreadId,
+                                threadInfo ->
+                                        new ThreadInfoSample(
+                                                threadInfo.getThreadState(),
+                                                threadInfo.getStackTrace())));
     }
 
     public Thread.State getThreadState() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
index bd8941622f0..218ecdaab46 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
@@ -571,11 +571,17 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
             }
         }
 
-        Collection<SampleableTask> sampleableTasks =
-                tasks.stream().map(SampleableTaskAdapter::fromTask).collect(Collectors.toList());
-
-        final CompletableFuture<Collection<ThreadInfoSample>> stackTracesFuture =
-                threadInfoSampleService.requestThreadInfoSamples(sampleableTasks, requestParams);
+        Map<Long, ExecutionAttemptID> sampleableTasks =
+                tasks.stream()
+                        .collect(
+                                Collectors.toMap(
+                                        task -> task.getExecutingThread().getId(),
+                                        Task::getExecutionId));
+
+        final CompletableFuture<Map<ExecutionAttemptID, Collection<ThreadInfoSample>>>
+                stackTracesFuture =
+                        threadInfoSampleService.requestThreadInfoSamples(
+                                sampleableTasks, requestParams);
 
         return stackTracesFuture.thenApply(TaskThreadInfoResponse::new);
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleService.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleService.java
index 69c7861d76c..7e29ba394e4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleService.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleService.java
@@ -19,6 +19,7 @@
 
 package org.apache.flink.runtime.taskexecutor;
 
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.messages.ThreadInfoSample;
 import org.apache.flink.runtime.util.JvmUtils;
 import org.apache.flink.runtime.webmonitor.threadinfo.ThreadInfoSamplesRequest;
@@ -28,6 +29,8 @@ import java.io.IOException;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
@@ -49,52 +52,55 @@ class ThreadInfoSampleService implements Closeable {
      * Returns a future that completes with a given number of thread info samples for a set of task
      * threads.
      *
-     * @param tasks The tasks to be sampled.
+     * @param threads the map key is thread id, the map value is the ExecutionAttemptID.
      * @param requestParams Parameters of the sampling request.
      * @return A future containing the stack trace samples.
      */
-    public CompletableFuture<Collection<ThreadInfoSample>> requestThreadInfoSamples(
-            final Collection<? extends SampleableTask> tasks,
-            final ThreadInfoSamplesRequest requestParams) {
-        checkNotNull(tasks, "task must not be null");
+    public CompletableFuture<Map<ExecutionAttemptID, Collection<ThreadInfoSample>>>
+            requestThreadInfoSamples(
+                    Map<Long, ExecutionAttemptID> threads,
+                    final ThreadInfoSamplesRequest requestParams) {
+        checkNotNull(threads, "threads must not be null");
         checkNotNull(requestParams, "requestParams must not be null");
 
-        CompletableFuture<Collection<ThreadInfoSample>> resultFuture = new CompletableFuture<>();
+        CompletableFuture<Map<ExecutionAttemptID, Collection<ThreadInfoSample>>> resultFuture =
+                new CompletableFuture<>();
         scheduledExecutor.execute(
                 () ->
                         requestThreadInfoSamples(
-                                tasks,
+                                threads,
                                 requestParams.getNumSamples(),
                                 requestParams.getDelayBetweenSamples(),
                                 requestParams.getMaxStackTraceDepth(),
-                                new ArrayList<>(requestParams.getNumSamples()),
+                                new HashMap<>(threads.size()),
                                 resultFuture));
         return resultFuture;
     }
 
     private void requestThreadInfoSamples(
-            final Collection<? extends SampleableTask> tasks,
+            Map<Long, ExecutionAttemptID> threads,
             final int numSamples,
             final Duration delayBetweenSamples,
             final int maxStackTraceDepth,
-            final Collection<ThreadInfoSample> currentTraces,
-            final CompletableFuture<Collection<ThreadInfoSample>> resultFuture) {
+            final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> currentTraces,
+            final CompletableFuture<Map<ExecutionAttemptID, Collection<ThreadInfoSample>>>
+                    resultFuture) {
 
-        final Collection<Long> threadIds =
-                tasks.stream()
-                        .map(t -> t.getExecutingThread().getId())
-                        .collect(Collectors.toList());
-
-        final Collection<ThreadInfoSample> threadInfoSample =
-                JvmUtils.createThreadInfoSample(threadIds, maxStackTraceDepth);
+        final Map<Long, ThreadInfoSample> threadInfoSample =
+                JvmUtils.createThreadInfoSample(threads.keySet(), maxStackTraceDepth);
 
         if (!threadInfoSample.isEmpty()) {
-            currentTraces.addAll(threadInfoSample);
+            for (Map.Entry<Long, ThreadInfoSample> entry : threadInfoSample.entrySet()) {
+                ExecutionAttemptID executionAttemptID = threads.get(entry.getKey());
+                Collection<ThreadInfoSample> threadInfoSamples =
+                        currentTraces.computeIfAbsent(executionAttemptID, key -> new ArrayList<>());
+                threadInfoSamples.add(entry.getValue());
+            }
             if (numSamples > 1) {
                 scheduledExecutor.schedule(
                         () ->
                                 requestThreadInfoSamples(
-                                        tasks,
+                                        threads,
                                         numSamples - 1,
                                         delayBetweenSamples,
                                         maxStackTraceDepth,
@@ -111,8 +117,7 @@ class ThreadInfoSampleService implements Closeable {
             resultFuture.complete(currentTraces);
         } else {
             final String ids =
-                    tasks.stream()
-                            .map(SampleableTask::getExecutionId)
+                    threads.values().stream()
                             .map(e -> e == null ? "unknown" : e.toString())
                             .collect(Collectors.joining(", ", "[", "]"));
             resultFuture.completeExceptionally(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/JvmUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/JvmUtils.java
index 128fed7f8bb..89a8f8fe0cc 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/util/JvmUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/JvmUtils.java
@@ -29,6 +29,7 @@ import java.lang.management.ThreadMXBean;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -70,9 +71,10 @@ public final class JvmUtils {
      *
      * @param threadIds The IDs of the threads to create the thread dump for.
      * @param maxStackTraceDepth The maximum number of entries in the stack trace to be collected.
-     * @return The thread information for the requested thread IDs.
+     * @return The map key is the thread id, the map value is the thread information for the
+     *     requested thread IDs.
      */
-    public static Collection<ThreadInfoSample> createThreadInfoSample(
+    public static Map<Long, ThreadInfoSample> createThreadInfoSample(
             Collection<Long> threadIds, int maxStackTraceDepth) {
         ThreadMXBean threadMxBean = ManagementFactory.getThreadMXBean();
         long[] threadIdsArray = threadIds.stream().mapToLong(l -> l).toArray();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoStats.java b/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoStats.java
index 3299bdc9df5..f035bebc67b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoStats.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoStats.java
@@ -22,8 +22,6 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.messages.ThreadInfoSample;
 import org.apache.flink.runtime.webmonitor.stats.Statistics;
 
-import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableSet;
-
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
@@ -46,8 +44,7 @@ public class JobVertexThreadInfoStats implements Statistics {
     private final long endTime;
 
     /** Map of thread info samples by execution ID. */
-    private final Map<ImmutableSet<ExecutionAttemptID>, Collection<ThreadInfoSample>>
-            samplesBySubtask;
+    private final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samplesBySubtask;
 
     /**
      * Creates a thread details sample.
@@ -61,7 +58,7 @@ public class JobVertexThreadInfoStats implements Statistics {
             int requestId,
             long startTime,
             long endTime,
-            Map<ImmutableSet<ExecutionAttemptID>, Collection<ThreadInfoSample>> samplesBySubtask) {
+            Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samplesBySubtask) {
 
         checkArgument(requestId >= 0, "Negative request ID");
         checkArgument(startTime >= 0, "Negative start time");
@@ -106,8 +103,7 @@ public class JobVertexThreadInfoStats implements Statistics {
      *
      * @return Map of thread info samples by task (execution ID)
      */
-    public Map<ImmutableSet<ExecutionAttemptID>, Collection<ThreadInfoSample>>
-            getSamplesBySubtask() {
+    public Map<ExecutionAttemptID, Collection<ThreadInfoSample>> getSamplesBySubtask() {
         return samplesBySubtask;
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinator.java
index 60550e5ed44..eb0929124d5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinator.java
@@ -30,6 +30,7 @@ import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableSet;
 
 import java.time.Duration;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
@@ -41,7 +42,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
 /** A coordinator for triggering and collecting thread info stats of running job vertex subtasks. */
 public class ThreadInfoRequestCoordinator
         extends TaskStatsRequestCoordinator<
-                Collection<ThreadInfoSample>, JobVertexThreadInfoStats> {
+                Map<ExecutionAttemptID, Collection<ThreadInfoSample>>, JobVertexThreadInfoStats> {
 
     /**
      * Creates a new coordinator for the job.
@@ -159,7 +160,9 @@ public class ThreadInfoRequestCoordinator
     // ------------------------------------------------------------------------
 
     private static class PendingThreadInfoRequest
-            extends PendingStatsRequest<Collection<ThreadInfoSample>, JobVertexThreadInfoStats> {
+            extends PendingStatsRequest<
+                    Map<ExecutionAttemptID, Collection<ThreadInfoSample>>,
+                    JobVertexThreadInfoStats> {
 
         PendingThreadInfoRequest(
                 int requestId, Collection<? extends Set<ExecutionAttemptID>> tasksToCollect) {
@@ -168,8 +171,12 @@ public class ThreadInfoRequestCoordinator
 
         @Override
         protected JobVertexThreadInfoStats assembleCompleteStats(long endTime) {
-            return new JobVertexThreadInfoStats(
-                    requestId, startTime, endTime, statsResultByTaskGroup);
+            HashMap<ExecutionAttemptID, Collection<ThreadInfoSample>> samples = new HashMap<>();
+            for (Map<ExecutionAttemptID, Collection<ThreadInfoSample>> map :
+                    statsResultByTaskGroup.values()) {
+                samples.putAll(map);
+            }
+            return new JobVertexThreadInfoStats(requestId, startTime, endTime, samples);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleServiceTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleServiceTest.java
index 8928ba25812..6647f46e716 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleServiceTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/ThreadInfoSampleServiceTest.java
@@ -31,10 +31,12 @@ import org.junit.jupiter.api.Test;
 import java.time.Duration;
 import java.util.Collection;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.taskexecutor.IdleTestTask.executeWithTerminationGuarantee;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -75,17 +77,27 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
                     tasks.add(new IdleTestTask());
                     tasks.add(new IdleTestTask());
                     Thread.sleep(2000);
-                    final Collection<ThreadInfoSample> threadInfoSamples =
+
+                    Map<Long, ExecutionAttemptID> threads =
+                            tasks.stream()
+                                    .collect(
+                                            Collectors.toMap(
+                                                    task -> task.getExecutingThread().getId(),
+                                                    IdleTestTask::getExecutionId));
+                    final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> threadInfoSamples =
                             threadInfoSampleService
-                                    .requestThreadInfoSamples(tasks, requestParams)
+                                    .requestThreadInfoSamples(threads, requestParams)
                                     .get();
 
-                    assertThat(threadInfoSamples).hasSize(NUMBER_OF_SAMPLES * 2);
-
-                    for (ThreadInfoSample sample : threadInfoSamples) {
-                        StackTraceElement[] traces = sample.getStackTrace();
-                        assertThat(traces).hasSizeLessThanOrEqualTo(MAX_STACK_TRACK_DEPTH);
+                    int count = 0;
+                    for (Collection<ThreadInfoSample> samples : threadInfoSamples.values()) {
+                        for (ThreadInfoSample sample : samples) {
+                            count++;
+                            StackTraceElement[] traces = sample.getStackTrace();
+                            assertThat(traces).hasSizeLessThanOrEqualTo(MAX_STACK_TRACK_DEPTH);
+                        }
                     }
+                    assertThat(count).isEqualTo(NUMBER_OF_SAMPLES * 2);
                 },
                 tasks);
     }
@@ -97,15 +109,21 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
         executeWithTerminationGuarantee(
                 () -> {
                     tasks.add(new IdleTestTask());
-                    final Collection<ThreadInfoSample> threadInfoSamples1 =
+                    Map<Long, ExecutionAttemptID> threads =
+                            tasks.stream()
+                                    .collect(
+                                            Collectors.toMap(
+                                                    task -> task.getExecutingThread().getId(),
+                                                    IdleTestTask::getExecutionId));
+                    final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> threadInfoSamples1 =
                             threadInfoSampleService
-                                    .requestThreadInfoSamples(tasks, requestParams)
+                                    .requestThreadInfoSamples(threads, requestParams)
                                     .get();
 
-                    final Collection<ThreadInfoSample> threadInfoSamples2 =
+                    final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> threadInfoSamples2 =
                             threadInfoSampleService
                                     .requestThreadInfoSamples(
-                                            tasks,
+                                            threads,
                                             new ThreadInfoSamplesRequest(
                                                     1,
                                                     NUMBER_OF_SAMPLES,
@@ -113,13 +131,17 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
                                                     MAX_STACK_TRACK_DEPTH - 6))
                                     .get();
 
-                    for (ThreadInfoSample sample : threadInfoSamples1) {
-                        assertThat(sample.getStackTrace())
-                                .hasSizeLessThanOrEqualTo(MAX_STACK_TRACK_DEPTH);
+                    for (Collection<ThreadInfoSample> samples : threadInfoSamples1.values()) {
+                        for (ThreadInfoSample sample : samples) {
+                            assertThat(sample.getStackTrace())
+                                    .hasSizeLessThanOrEqualTo(MAX_STACK_TRACK_DEPTH);
+                        }
                     }
 
-                    for (ThreadInfoSample sample : threadInfoSamples2) {
-                        assertThat(sample.getStackTrace()).hasSize(MAX_STACK_TRACK_DEPTH - 6);
+                    for (Collection<ThreadInfoSample> samples : threadInfoSamples2.values()) {
+                        for (ThreadInfoSample sample : samples) {
+                            assertThat(sample.getStackTrace()).hasSize(MAX_STACK_TRACK_DEPTH - 6);
+                        }
                     }
                 },
                 tasks);
@@ -134,8 +156,18 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
                                 executeWithTerminationGuarantee(
                                         () -> {
                                             tasks.add(new IdleTestTask());
+
+                                            Map<Long, ExecutionAttemptID> threads =
+                                                    tasks.stream()
+                                                            .collect(
+                                                                    Collectors.toMap(
+                                                                            task ->
+                                                                                    task.getExecutingThread()
+                                                                                            .getId(),
+                                                                            IdleTestTask
+                                                                                    ::getExecutionId));
                                             threadInfoSampleService.requestThreadInfoSamples(
-                                                    tasks,
+                                                    threads,
                                                     new ThreadInfoSamplesRequest(
                                                             1,
                                                             -1,
@@ -153,8 +185,16 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
             throws ExecutionException, InterruptedException {
         Set<SampleableTask> tasks = new HashSet<>();
         tasks.add(new NotRunningTask());
-        final CompletableFuture<Collection<ThreadInfoSample>> sampleFuture =
-                threadInfoSampleService.requestThreadInfoSamples(tasks, requestParams);
+
+        Map<Long, ExecutionAttemptID> threads =
+                tasks.stream()
+                        .collect(
+                                Collectors.toMap(
+                                        task -> task.getExecutingThread().getId(),
+                                        SampleableTask::getExecutionId));
+        final CompletableFuture<Map<ExecutionAttemptID, Collection<ThreadInfoSample>>>
+                sampleFuture =
+                        threadInfoSampleService.requestThreadInfoSamples(threads, requestParams);
 
         assertThat(sampleFuture).failsWithin(Duration.ofSeconds(10));
         assertThat(sampleFuture.handle((ignored, e) -> e).get())
@@ -163,12 +203,14 @@ public class ThreadInfoSampleServiceTest extends TestLogger {
 
     private static class NotRunningTask implements SampleableTask {
 
+        private final ExecutionAttemptID executionId = ExecutionAttemptID.randomId();
+
         public Thread getExecutingThread() {
             return new Thread();
         }
 
         public ExecutionAttemptID getExecutionId() {
-            return null;
+            return executionId;
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoTrackerTest.java
index f7d4ebfd288..d45aff28279 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/JobVertexThreadInfoTrackerTest.java
@@ -58,8 +58,6 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
@@ -74,6 +72,7 @@ import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createScheduler;
+import static org.apache.flink.util.Preconditions.checkState;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThat;
 
 /** Tests for the {@link JobVertexThreadInfoTracker}. */
@@ -90,7 +89,6 @@ public class JobVertexThreadInfoTrackerTest extends TestLogger {
                     .collect(Collectors.toSet());
     private static final JobID JOB_ID = new JobID();
 
-    private static ThreadInfoSample threadInfoSample;
     private static JobVertexThreadInfoStats threadInfoStatsDefaultSample;
 
     private static final Duration CLEAN_UP_INTERVAL = Duration.ofSeconds(60);
@@ -110,13 +108,7 @@ public class JobVertexThreadInfoTrackerTest extends TestLogger {
         // Time gap determines endTime of stats, which controls if the "refresh" is triggered:
         // now >= stats.getEndTime() + statsRefreshInterval
         // Using a small gap to be able to test cache updates without much delay.
-        threadInfoSample =
-                JvmUtils.createThreadInfoSample(
-                                Thread.currentThread().getId(), MAX_STACK_TRACE_DEPTH)
-                        .get();
-        threadInfoStatsDefaultSample =
-                createThreadInfoStats(
-                        REQUEST_ID, SMALL_TIME_GAP, Collections.singletonList(threadInfoSample));
+        threadInfoStatsDefaultSample = createThreadInfoStats(REQUEST_ID, SMALL_TIME_GAP);
         executor = Executors.newScheduledThreadPool(1);
     }
 
@@ -136,8 +128,7 @@ public class JobVertexThreadInfoTrackerTest extends TestLogger {
     /** Tests that cached result is reused within refresh interval. */
     @Test
     public void testCachedStatsNotUpdatedWithinRefreshInterval() throws Exception {
-        final JobVertexThreadInfoStats unusedThreadInfoStats =
-                createThreadInfoStats(1, TIME_GAP, null);
+        final JobVertexThreadInfoStats unusedThreadInfoStats = createThreadInfoStats(1, TIME_GAP);
 
         final JobVertexThreadInfoTracker<JobVertexThreadInfoStats> tracker =
                 createThreadInfoTracker(
@@ -162,10 +153,9 @@ public class JobVertexThreadInfoTrackerTest extends TestLogger {
                 createThreadInfoStats(
                         Instant.now().minus(10, ChronoUnit.SECONDS),
                         REQUEST_ID,
-                        Duration.ofMillis(5),
-                        Collections.singletonList(threadInfoSample));
+                        Duration.ofMillis(5));
         final JobVertexThreadInfoStats threadInfoStatsAfterRefresh =
-                createThreadInfoStats(1, TIME_GAP, Collections.singletonList(threadInfoSample));
+                createThreadInfoStats(1, TIME_GAP);
 
         // register a CountDownLatch with the cache so we can await refresh of the entry
         CountDownLatch cacheRefreshed = new CountDownLatch(1);
@@ -321,32 +311,47 @@ public class JobVertexThreadInfoTrackerTest extends TestLogger {
                 .build();
     }
 
-    private static JobVertexThreadInfoStats createThreadInfoStats(
-            int requestId, Duration timeGap, List<ThreadInfoSample> threadInfoSamples) {
-        return createThreadInfoStats(Instant.now(), requestId, timeGap, threadInfoSamples);
+    /**
+     * Creates {@link JobVertexThreadInfoStats} whose time window starts now.
+     *
+     * @param requestId the request ID stored in the stats
+     * @param timeGap the duration between the start and the end time of the stats
+     * @return stats holding one thread info sample per task vertex
+     */
+    private static JobVertexThreadInfoStats createThreadInfoStats(int requestId, Duration timeGap) {
+        return createThreadInfoStats(Instant.now(), requestId, timeGap);
     }
 
+    /**
+     * Creates {@link JobVertexThreadInfoStats} covering all {@code TASK_VERTICES}.
+     *
+     * <p>A single thread info sample of the calling thread is captured once and shared by the
+     * current execution attempt of every task vertex.
+     *
+     * @param startTime the start time of the stats
+     * @param requestId the request ID stored in the stats
+     * @param timeGap the duration between {@code startTime} and the end time of the stats
+     * @return stats mapping each task's current execution attempt to the captured sample
+     */
     private static JobVertexThreadInfoStats createThreadInfoStats(
-            Instant startTime,
-            int requestId,
-            Duration timeGap,
-            List<ThreadInfoSample> threadInfoSamples) {
+            Instant startTime, int requestId, Duration timeGap) {
         Instant endTime = startTime.plus(timeGap);
 
-        final Map<ImmutableSet<ExecutionAttemptID>, Collection<ThreadInfoSample>>
-                threadInfoRatiosByTask = new HashMap<>();
+        // The sample does not depend on the vertex, so capture it once outside the loop.
+        final Optional<ThreadInfoSample> threadInfoSample =
+                JvmUtils.createThreadInfoSample(
+                        Thread.currentThread().getId(), MAX_STACK_TRACE_DEPTH);
+        checkState(threadInfoSample.isPresent(), "The threadInfoSample should not be empty.");
+        final Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samples = new HashMap<>();
 
         for (ExecutionVertex vertex : TASK_VERTICES) {
-            Set<ExecutionAttemptID> attemptIds = new HashSet<>();
-            attemptIds.add(vertex.getCurrentExecutionAttempt().getAttemptId());
-            threadInfoRatiosByTask.put(ImmutableSet.copyOf(attemptIds), threadInfoSamples);
+            samples.put(
+                    vertex.getCurrentExecutionAttempt().getAttemptId(),
+                    Collections.singletonList(threadInfoSample.get()));
         }
 
         return new JobVertexThreadInfoStats(
-                requestId,
-                startTime.toEpochMilli(),
-                endTime.toEpochMilli(),
-                threadInfoRatiosByTask);
+                requestId, startTime.toEpochMilli(), endTime.toEpochMilli(), samples);
     }
 
     private static ExecutionJobVertex createExecutionJobVertex() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinatorTest.java
index caaf52bbf39..147195ea859 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/webmonitor/threadinfo/ThreadInfoRequestCoordinatorTest.java
@@ -38,6 +38,7 @@ import org.junit.jupiter.api.Test;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -115,7 +116,7 @@ public class ThreadInfoRequestCoordinatorTest extends TestLogger {
         // verify the request result
         assertThat(threadInfoStats.getRequestId()).isEqualTo(0);
 
-        Map<ImmutableSet<ExecutionAttemptID>, Collection<ThreadInfoSample>> samplesBySubtask =
+        Map<ExecutionAttemptID, Collection<ThreadInfoSample>> samplesBySubtask =
                 threadInfoStats.getSamplesBySubtask();
 
         for (Collection<ThreadInfoSample> result : samplesBySubtask.values()) {
@@ -237,15 +238,26 @@ public class ThreadInfoRequestCoordinatorTest extends TestLogger {
                         () -> {
                             tasks.add(new IdleTestTask());
                             tasks.add(new IdleTestTask());
-                            //                            Thread.sleep(100);
-                            List<Long> threadIds =
+                            Map<Long, ExecutionAttemptID> threads =
                                     tasks.stream()
-                                            .map(t -> t.getExecutingThread().getId())
-                                            .collect(Collectors.toList());
-                            Collection<ThreadInfoSample> threadInfoSample =
-                                    JvmUtils.createThreadInfoSample(threadIds, 100);
-                            responseFuture.complete(
-                                    new TaskThreadInfoResponse(new ArrayList<>(threadInfoSample)));
+                                            .collect(
+                                                    Collectors.toMap(
+                                                            task ->
+                                                                    task.getExecutingThread()
+                                                                            .getId(),
+                                                            IdleTestTask::getExecutionId));
+
+                            Map<ExecutionAttemptID, Collection<ThreadInfoSample>> threadInfoSample =
+                                    JvmUtils.createThreadInfoSample(threads.keySet(), 100)
+                                            .entrySet().stream()
+                                            .collect(
+                                                    Collectors.toMap(
+                                                            entry -> threads.get(entry.getKey()),
+                                                            entry ->
+                                                                    Collections.singletonList(
+                                                                            entry.getValue())));
+
+                            responseFuture.complete(new TaskThreadInfoResponse(threadInfoSample));
                         },
                         tasks);