You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by ro...@apache.org on 2022/09/08 14:30:44 UTC

[incubator-uniffle] branch master updated: [Improvement][AQE] Avoid calling getShuffleResult multiple times (#190)

This is an automated email from the ASF dual-hosted git repository.

roryqi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 6aa43794 [Improvement][AQE] Avoid calling getShuffleResult multiple times (#190)
6aa43794 is described below

commit 6aa43794c1edbd73b3c3937e440cb6fb9db15e86
Author: Xianming Lei <31...@users.noreply.github.com>
AuthorDate: Thu Sep 8 22:30:39 2022 +0800

    [Improvement][AQE] Avoid calling getShuffleResult multiple times (#190)
    
    ###What changes were proposed in this pull request?
    For issue #136  , When we use AQE, we may call shuffleWriteClient.getShuffleResult multiple times. But if both partition 1 and partition 2 are on the server A, we call getShuffleResult(partition 1) to get data from server A, and then we call getShuffleResult(partition 2) to get data from server A, it's not necessary. We can call getShuffleResult(partition 1, partition 2) instead.
    
    ###Why are the changes needed?
    Improve getShuffleResult
    
    ###Does this PR introduce any user-facing change?
    No
    
    ###How was this patch tested?
    Added UT
    
    Co-authored-by: leixianming <le...@didiglobal.com>
---
 .../hadoop/mapred/SortWriteBufferManagerTest.java  |   6 +
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  |   6 +
 .../apache/spark/shuffle/RssShuffleManager.java    |  28 +--
 .../uniffle/client/api/ShuffleWriteClient.java     |   3 +
 .../client/impl/ShuffleWriteClientImpl.java        |  41 ++++
 .../org/apache/uniffle/common/util/RssUtils.java   |  39 ++++
 .../apache/uniffle/common/util/RssUtilsTest.java   |  52 +++++
 integration-test/spark3/pom.xml                    |   7 +
 .../test/GetShuffleReportForMultiPartTest.java     | 235 +++++++++++++++++++++
 .../uniffle/client/api/ShuffleServerClient.java    |   3 +
 .../client/impl/grpc/ShuffleServerGrpcClient.java  |  35 +++
 .../RssGetShuffleResultForMultiPartRequest.java    |  44 ++++
 proto/src/main/proto/Rss.proto                     |  13 ++
 .../uniffle/server/ShuffleServerGrpcService.java   |  44 +++-
 .../apache/uniffle/server/ShuffleTaskManager.java  |  59 +++---
 .../server/MockedShuffleServerGrpcService.java     |  35 +++
 .../uniffle/server/ShuffleTaskManagerTest.java     | 117 +++++++++-
 17 files changed, 722 insertions(+), 45 deletions(-)

diff --git a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 3574c9c3..97c58472 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -326,6 +326,12 @@ public class SortWriteBufferManagerTest {
       return null;
     }
 
+    @Override
+    public Roaring64NavigableMap getShuffleResultForMultiPart(String clientType, Map<ShuffleServerInfo,
+        Set<Integer>> serverToPartitions, String appId, int shuffleId) {
+      return null;
+    }
+
     @Override
     public void close() {
 
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index 531f5405..5f9713b2 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -416,6 +416,12 @@ public class FetcherTest {
       return null;
     }
 
+    @Override
+    public Roaring64NavigableMap getShuffleResultForMultiPart(String clientType, Map<ShuffleServerInfo,
+        Set<Integer>> serverToPartitions, String appId, int shuffleId) {
+      return null;
+    }
+
     @Override
     public void close() {
 
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index fbbef743..6f04ccff 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.shuffle;
 
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -27,6 +26,7 @@ import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
@@ -406,18 +406,18 @@ public class RssShuffleManager implements ShuffleManager {
       readBufferSize = Integer.MAX_VALUE;
     }
     int shuffleId = rssShuffleHandle.getShuffleId();
-    Map<Integer, List<ShuffleServerInfo>> partitionToServers =  rssShuffleHandle.getPartitionToServers();
-    Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks = new HashMap<>();
-    for (int partition = startPartition; partition < endPartition; partition++) {
-      long start = System.currentTimeMillis();
-      Roaring64NavigableMap blockIdBitmap = shuffleWriteClient.getShuffleResult(
-          clientType, Sets.newHashSet(partitionToServers.get(partition)),
-          rssShuffleHandle.getAppId(), shuffleId, partition);
-      partitionToExpectBlocks.put(partition, blockIdBitmap);
-      LOG.info("Get shuffle blockId cost " + (System.currentTimeMillis() - start) + " ms, and get "
-          + blockIdBitmap.getLongCardinality() + " blockIds for shuffleId[" + shuffleId + "], partitionId["
-          + partition + "]");
-    }
+    Map<Integer, List<ShuffleServerInfo>> allPartitionToServers = rssShuffleHandle.getPartitionToServers();
+    Map<Integer, List<ShuffleServerInfo>> requirePartitionToServers = allPartitionToServers.entrySet()
+        .stream().filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition)
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    Map<ShuffleServerInfo, Set<Integer>> serverToPartitions = RssUtils.generateServerToPartitions(
+        requirePartitionToServers);
+    long start = System.currentTimeMillis();
+    Roaring64NavigableMap blockIdBitmap = shuffleWriteClient.getShuffleResultForMultiPart(
+        clientType, serverToPartitions, rssShuffleHandle.getAppId(), shuffleId);
+    LOG.info("Get shuffle blockId cost " + (System.currentTimeMillis() - start) + " ms, and get "
+        + blockIdBitmap.getLongCardinality() + " blockIds for shuffleId[" + shuffleId + "], startPartition["
+        + start + "], endPartition[" + endPartition + "]");
 
     ShuffleReadMetrics readMetrics;
     if (metrics != null) {
@@ -445,7 +445,7 @@ public class RssShuffleManager implements ShuffleManager {
         storageType,
         (int) readBufferSize,
         partitionNum,
-        partitionToExpectBlocks,
+        RssUtils.generatePartitionToBitmap(blockIdBitmap, startPartition, endPartition),
         taskIdBitmap,
         readMetrics);
   }
diff --git a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
index d5981c46..f507a96a 100644
--- a/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
+++ b/client/src/main/java/org/apache/uniffle/client/api/ShuffleWriteClient.java
@@ -65,5 +65,8 @@ public interface ShuffleWriteClient {
   Roaring64NavigableMap getShuffleResult(String clientType, Set<ShuffleServerInfo> shuffleServerInfoSet,
       String appId, int shuffleId, int partitionId);
 
+  Roaring64NavigableMap getShuffleResultForMultiPart(String clientType,
+      Map<ShuffleServerInfo, Set<Integer>> serverToPartitions, String appId, int shuffleId);
+
   void close();
 }
diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
index bd194ddf..7be84d4d 100644
--- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
+++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java
@@ -48,6 +48,7 @@ import org.apache.uniffle.client.request.RssFetchClientConfRequest;
 import org.apache.uniffle.client.request.RssFetchRemoteStorageRequest;
 import org.apache.uniffle.client.request.RssFinishShuffleRequest;
 import org.apache.uniffle.client.request.RssGetShuffleAssignmentsRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
 import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
 import org.apache.uniffle.client.request.RssRegisterShuffleRequest;
 import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
@@ -495,6 +496,46 @@ public class ShuffleWriteClientImpl implements ShuffleWriteClient {
     return blockIdBitmap;
   }
 
+  @Override
+  public Roaring64NavigableMap getShuffleResultForMultiPart(String clientType,
+      Map<ShuffleServerInfo, Set<Integer>> serverToPartitions, String appId, int shuffleId) {
+    Map<Integer, Integer> partitionReadSuccess = Maps.newHashMap();
+    Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf();
+    for (Map.Entry<ShuffleServerInfo, Set<Integer>> entry : serverToPartitions.entrySet()) {
+      ShuffleServerInfo shuffleServerInfo = entry.getKey();
+      Set<Integer> requestPartitions = Sets.newHashSet();
+      for (Integer partitionId : entry.getValue()) {
+        partitionReadSuccess.putIfAbsent(partitionId, 0);
+        if (partitionReadSuccess.get(partitionId) < replicaRead) {
+          requestPartitions.add(partitionId);
+        }
+      }
+      RssGetShuffleResultForMultiPartRequest request = new RssGetShuffleResultForMultiPartRequest(
+          appId, shuffleId, requestPartitions);
+      try {
+        RssGetShuffleResultResponse response =
+            getShuffleServerClient(shuffleServerInfo).getShuffleResultForMultiPart(request);
+        if (response.getStatusCode() == ResponseStatusCode.SUCCESS) {
+          // merge into blockIds from multiple servers.
+          Roaring64NavigableMap blockIdBitmapOfServer = response.getBlockIdBitmap();
+          blockIdBitmap.or(blockIdBitmapOfServer);
+          for (Integer partitionId : requestPartitions) {
+            Integer oldVal = partitionReadSuccess.get(partitionId);
+            partitionReadSuccess.put(partitionId, oldVal + 1);
+          }
+        }
+      } catch (Exception e) {
+        LOG.warn("Get shuffle result is failed from " + shuffleServerInfo + " for appId[" + appId
+            + "], shuffleId[" + shuffleId + "], requestPartitions" + requestPartitions);
+      }
+    }
+    boolean isSuccessful = partitionReadSuccess.entrySet().stream().allMatch(x -> x.getValue() >= replicaRead);
+    if (!isSuccessful) {
+      throw new RssException("Get shuffle result is failed for appId[" + appId + "], shuffleId[" + shuffleId + "]");
+    }
+    return blockIdBitmap;
+  }
+
   @Override
   public void sendAppHeartbeat(String appId, long timeoutMs) {
     RssAppHeartBeatRequest request = new RssAppHeartBeatRequest(appId, timeoutMs);
diff --git a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
index 1224d8a4..896711bf 100644
--- a/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/util/RssUtils.java
@@ -35,11 +35,15 @@ import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.util.Enumeration;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Properties;
+import java.util.Set;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 import com.google.common.net.InetAddresses;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -48,6 +52,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataSegment;
 import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.ShuffleServerInfo;
 
 public class RssUtils {
 
@@ -292,4 +297,38 @@ public class RssUtils {
     }
     return hostName.replaceAll("[\\.-]", "_");
   }
+
+  public static Map<Integer, Roaring64NavigableMap> generatePartitionToBitmap(
+      Roaring64NavigableMap shuffleBitmap, int startPartition, int endPartition) {
+    Map<Integer, Roaring64NavigableMap> result = Maps.newHashMap();
+    for (int partitionId = startPartition; partitionId < endPartition; partitionId++) {
+      result.putIfAbsent(partitionId, Roaring64NavigableMap.bitmapOf());
+    }
+    Iterator<Long> it = shuffleBitmap.iterator();
+    long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
+    while (it.hasNext()) {
+      Long blockId = it.next();
+      int partitionId = Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
+      if (partitionId >= startPartition && partitionId < endPartition) {
+        result.get(partitionId).add(blockId);
+      }
+    }
+    return result;
+  }
+
+  public static Map<ShuffleServerInfo, Set<Integer>> generateServerToPartitions(
+      Map<Integer, List<ShuffleServerInfo>> partitionToServers) {
+    Map<ShuffleServerInfo, Set<Integer>> serverToPartitions = Maps.newHashMap();
+    for (Map.Entry<Integer, List<ShuffleServerInfo>> entry : partitionToServers.entrySet()) {
+      int partitionId = entry.getKey();
+      for (ShuffleServerInfo serverInfo : entry.getValue()) {
+        if (!serverToPartitions.containsKey(serverInfo)) {
+          serverToPartitions.put(serverInfo, Sets.newHashSet(partitionId));
+        } else {
+          serverToPartitions.get(serverInfo).add(partitionId);
+        }
+      }
+    }
+    return serverToPartitions;
+  }
 }
diff --git a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java
index 97e732f2..a5211174 100644
--- a/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java
+++ b/common/src/test/java/org/apache/uniffle/common/util/RssUtilsTest.java
@@ -27,14 +27,18 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Random;
+import java.util.Set;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 
 import org.apache.uniffle.common.BufferSegment;
 import org.apache.uniffle.common.ShuffleDataSegment;
 import org.apache.uniffle.common.ShuffleIndexResult;
+import org.apache.uniffle.common.ShuffleServerInfo;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -190,6 +194,54 @@ public class RssUtilsTest {
     assertEquals(testStr, extsObjs.get(0).get());
   }
 
+  @Test
+  public void testShuffleBitmapToPartitionBitmap() {
+    Roaring64NavigableMap partition1Bitmap = Roaring64NavigableMap.bitmapOf(
+        getBlockId(0, 0, 0),
+        getBlockId(0, 0, 1),
+        getBlockId(0, 1, 0),
+        getBlockId(0, 1, 1));
+    Roaring64NavigableMap partition2Bitmap = Roaring64NavigableMap.bitmapOf(
+        getBlockId(1, 0, 0),
+        getBlockId(1, 0, 1),
+        getBlockId(1, 1, 0),
+        getBlockId(1, 1, 1));
+    Roaring64NavigableMap shuffleBitmap = Roaring64NavigableMap.bitmapOf();
+    shuffleBitmap.or(partition1Bitmap);
+    shuffleBitmap.or(partition2Bitmap);
+    assertEquals(8, shuffleBitmap.getLongCardinality());
+    Map<Integer, Roaring64NavigableMap> toPartitionBitmap =
+        RssUtils.generatePartitionToBitmap(shuffleBitmap, 0, 2);
+    assertEquals(2, toPartitionBitmap.size());
+    assertEquals(partition1Bitmap, toPartitionBitmap.get(0));
+    assertEquals(partition2Bitmap, toPartitionBitmap.get(1));
+  }
+
+  @Test
+  public void testGenerateServerToPartitions() {
+    Map<Integer, List<ShuffleServerInfo>> partitionToServers = Maps.newHashMap();
+    ShuffleServerInfo server1 = new ShuffleServerInfo("server1", "0.0.0.1", 100);
+    ShuffleServerInfo server2 = new ShuffleServerInfo("server2", "0.0.0.2", 200);
+    ShuffleServerInfo server3 = new ShuffleServerInfo("server3", "0.0.0.3", 300);
+    ShuffleServerInfo server4 = new ShuffleServerInfo("server4", "0.0.0.4", 400);
+    partitionToServers.put(1, Lists.newArrayList(server1, server2));
+    partitionToServers.put(2, Lists.newArrayList(server3, server4));
+    partitionToServers.put(3, Lists.newArrayList(server1, server2));
+    partitionToServers.put(4, Lists.newArrayList(server3, server4));
+    Map<ShuffleServerInfo, Set<Integer>> serverToPartitions = RssUtils.generateServerToPartitions(partitionToServers);
+    assertEquals(4, serverToPartitions.size());
+    assertEquals(serverToPartitions.get(server1), Sets.newHashSet(1, 3));
+    assertEquals(serverToPartitions.get(server2), Sets.newHashSet(1, 3));
+    assertEquals(serverToPartitions.get(server3), Sets.newHashSet(2, 4));
+    assertEquals(serverToPartitions.get(server4), Sets.newHashSet(2, 4));
+  }
+
+  // Copy from ClientUtils
+  private Long getBlockId(long partitionId, long taskAttemptId, long atomicInt) {
+    return (atomicInt << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH))
+        + (partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH) + taskAttemptId;
+  }
+
   interface RssUtilTestDummy {
     String get();
   }
diff --git a/integration-test/spark3/pom.xml b/integration-test/spark3/pom.xml
index 00755220..63a3966e 100644
--- a/integration-test/spark3/pom.xml
+++ b/integration-test/spark3/pom.xml
@@ -73,6 +73,13 @@
             <artifactId>shuffle-server</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.apache.uniffle</groupId>
+            <artifactId>shuffle-server</artifactId>
+            <type>test-jar</type>
+            <version>${project.version}</version>
+            <scope>test</scope>
+        </dependency>
         <dependency>
             <groupId>org.apache.uniffle</groupId>
             <artifactId>coordinator</artifactId>
diff --git a/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
new file mode 100644
index 00000000..f719c96d
--- /dev/null
+++ b/integration-test/spark3/src/test/java/org/apache/uniffle/test/GetShuffleReportForMultiPartTest.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.File;
+import java.nio.file.Files;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.RssShuffleHandle;
+import org.apache.spark.shuffle.RssShuffleManager;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
+import org.apache.spark.shuffle.ShuffleReader;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec;
+import org.apache.spark.sql.execution.joins.SortMergeJoinExec;
+import org.apache.spark.sql.functions;
+import org.apache.spark.sql.internal.SQLConf;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+import org.roaringbitmap.longlong.Roaring64NavigableMap;
+
+import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.MockedGrpcServer;
+import org.apache.uniffle.server.MockedShuffleServerGrpcService;
+import org.apache.uniffle.server.ShuffleServer;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class GetShuffleReportForMultiPartTest extends SparkIntegrationTestBase {
+
+  private static final int replicateWrite = 3;
+  private static final int replicateRead = 2;
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(CoordinatorConf.COORDINATOR_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
+    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), StorageType.MEMORY_LOCALFILE_HDFS.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    createCoordinatorServer(coordinatorConf);
+    // Create multi shuffle servers
+    createShuffleServers();
+    startServers();
+  }
+
+  private static void createShuffleServers() throws Exception {
+    for (int i = 0; i < 4; i++) {
+      // Copy from IntegrationTestBase#getShuffleServerConf
+      File dataFolder = Files.createTempDirectory("rssdata" + i).toFile();
+      ShuffleServerConf serverConf = new ShuffleServerConf();
+      dataFolder.deleteOnExit();
+      serverConf.setInteger("rss.rpc.server.port", SHUFFLE_SERVER_PORT + i);
+      serverConf.setString("rss.storage.type", StorageType.MEMORY_LOCALFILE_HDFS.name());
+      serverConf.setString("rss.storage.basePath", dataFolder.getAbsolutePath());
+      serverConf.setString("rss.server.buffer.capacity", "671088640");
+      serverConf.setString("rss.server.memory.shuffle.highWaterMark", "50.0");
+      serverConf.setString("rss.server.memory.shuffle.lowWaterMark", "0.0");
+      serverConf.setString("rss.server.read.buffer.capacity", "335544320");
+      serverConf.setString("rss.coordinator.quorum", COORDINATOR_QUORUM);
+      serverConf.setString("rss.server.heartbeat.delay", "1000");
+      serverConf.setString("rss.server.heartbeat.interval", "1000");
+      serverConf.setInteger("rss.jetty.http.port", 18080 + i);
+      serverConf.setInteger("rss.jetty.corePool.size", 64);
+      serverConf.setInteger("rss.rpc.executor.size", 10);
+      serverConf.setString("rss.server.hadoop.dfs.replication", "2");
+      serverConf.setLong("rss.server.disk.capacity", 10L * 1024L * 1024L * 1024L);
+      serverConf.setBoolean("rss.server.health.check.enable", false);
+      createMockedShuffleServer(serverConf);
+    }
+    enableRecordGetShuffleResult();
+  }
+
+  private static void enableRecordGetShuffleResult() {
+    for (ShuffleServer shuffleServer : shuffleServers) {
+      ((MockedGrpcServer) shuffleServer.getServer()).getService()
+          .enableRecordGetShuffleResult();
+    }
+  }
+
+  @Override
+  public void updateCommonSparkConf(SparkConf sparkConf) {
+    sparkConf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "true");
+    sparkConf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD().key(), "-1");
+    sparkConf.set(SQLConf.COALESCE_PARTITIONS_MIN_PARTITION_NUM().key(), "1");
+    sparkConf.set(SQLConf.SHUFFLE_PARTITIONS().key(), "100");
+    sparkConf.set(SQLConf.SKEW_JOIN_SKEWED_PARTITION_THRESHOLD().key(), "800");
+    sparkConf.set(SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES().key(), "800");
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), "HDFS");
+    sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + "rss/test");
+  }
+
+  @Override
+  public void updateSparkConfWithRss(SparkConf sparkConf) {
+    super.updateSparkConfWithRss(sparkConf);
+    // Add multi replica conf
+    sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA.key(), String.valueOf(replicateWrite));
+    sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_WRITE.key(), String.valueOf(replicateWrite));
+    sparkConf.set(RssSparkConfig.RSS_DATA_REPLICA_READ.key(), String.valueOf(replicateRead));
+
+    sparkConf.set("spark.shuffle.manager",
+        "org.apache.uniffle.test.GetShuffleReportForMultiPartTest$RssShuffleManagerWrapper");
+  }
+
+  @Test
+  public void resultCompareTest() throws Exception {
+    run();
+  }
+
+  @Override
+  Map runTest(SparkSession spark, String fileName) throws Exception {
+    Thread.sleep(4000);
+    Map<Integer, String> map = Maps.newHashMap();
+    Dataset<Row> df2 = spark.range(0, 1000, 1, 10)
+        .select(functions.when(functions.col("id").$less(250), 249)
+            .otherwise(functions.col("id")).as("key2"), functions.col("id").as("value2"));
+    Dataset<Row> df1 = spark.range(0, 1000, 1, 10)
+        .select(functions.when(functions.col("id").$less(250), 249)
+            .when(functions.col("id").$greater(750), 1000)
+                .otherwise(functions.col("id")).as("key1"), functions.col("id").as("value2"));
+    Dataset<Row> df3 = df1.join(df2, df1.col("key1").equalTo(df2.col("key2")));
+
+    List<String> result = Lists.newArrayList();
+    assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=false"));
+    df3.collectAsList().forEach(row -> {
+      result.add(row.json());
+    });
+    assertTrue(df3.queryExecution().executedPlan().toString().startsWith("AdaptiveSparkPlan isFinalPlan=true"));
+    AdaptiveSparkPlanExec plan = (AdaptiveSparkPlanExec) df3.queryExecution().executedPlan();
+    SortMergeJoinExec joinExec = (SortMergeJoinExec) plan.executedPlan().children().iterator().next();
+    assertTrue(joinExec.isSkewJoin());
+    result.sort(new Comparator<String>() {
+      @Override
+      public int compare(String o1, String o2) {
+        return o1.compareTo(o2);
+      }
+    });
+    int i = 0;
+    for (String str : result) {
+      map.put(i, str);
+      i++;
+    }
+    SparkConf conf = spark.sparkContext().conf();
+    if (!conf.get("spark.shuffle.manager", "").isEmpty()) {
+      RssShuffleManagerWrapper mockRssShuffleManager =
+          (RssShuffleManagerWrapper) spark.sparkContext().env().shuffleManager();
+      int expectRequestNum = mockRssShuffleManager.getShuffleIdToPartitionNum().values().stream()
+          .mapToInt(x -> x.get()).sum();
+      // Validate getShuffleResultForMultiPart is correct before return result
+      validateRequestCount(expectRequestNum * replicateRead);
+    }
+    return map;
+  }
+
+  public void validateRequestCount(int expectRequestNum) {
+    for (ShuffleServer shuffleServer : shuffleServers) {
+      MockedShuffleServerGrpcService service = ((MockedGrpcServer) shuffleServer.getServer()).getService();
+      Map<String, Map<Integer, AtomicInteger>> serviceRequestCount = service.getShuffleIdToPartitionRequest();
+      int requestNum = serviceRequestCount.entrySet().stream().flatMap(x -> x.getValue().values()
+           .stream()).mapToInt(AtomicInteger::get).sum();
+      expectRequestNum -= requestNum;
+    }
+    assertEquals(0, expectRequestNum);
+  }
+
+  public static class RssShuffleManagerWrapper extends RssShuffleManager {
+
+    // shuffleId -> partShouldRequestNum
+    Map<Integer, AtomicInteger> shuffleToPartShouldRequestNum = Maps.newConcurrentMap();
+
+    public RssShuffleManagerWrapper(SparkConf conf, boolean isDriver) {
+      super(conf, isDriver);
+    }
+
+    @Override
+    public <K, C> ShuffleReader<K, C> getReaderImpl(
+        ShuffleHandle handle,
+        int startMapIndex,
+        int endMapIndex,
+        int startPartition,
+        int endPartition,
+        TaskContext context,
+        ShuffleReadMetricsReporter metrics,
+        Roaring64NavigableMap taskIdBitmap) {
+      int shuffleId = handle.shuffleId();
+      RssShuffleHandle rssShuffleHandle = (RssShuffleHandle) handle;
+      Map<Integer, List<ShuffleServerInfo>> allPartitionToServers = rssShuffleHandle.getPartitionToServers();
+      int partitionNum = (int) allPartitionToServers.entrySet().stream()
+                                   .filter(x -> x.getKey() >= startPartition && x.getKey() < endPartition).count();
+      AtomicInteger partShouldRequestNum = shuffleToPartShouldRequestNum.computeIfAbsent(shuffleId,
+          x -> new AtomicInteger(0));
+      partShouldRequestNum.addAndGet(partitionNum);
+      return super.getReaderImpl(handle, startMapIndex, endMapIndex, startPartition, endPartition,
+          context, metrics, taskIdBitmap);
+    }
+
+    public Map<Integer, AtomicInteger> getShuffleIdToPartitionNum() {
+      return shuffleToPartShouldRequestNum;
+    }
+  }
+}
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java
index d636c34b..f9873c8b 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/api/ShuffleServerClient.java
@@ -22,6 +22,7 @@ import org.apache.uniffle.client.request.RssFinishShuffleRequest;
 import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest;
 import org.apache.uniffle.client.request.RssGetShuffleDataRequest;
 import org.apache.uniffle.client.request.RssGetShuffleIndexRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
 import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
 import org.apache.uniffle.client.request.RssRegisterShuffleRequest;
 import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
@@ -54,6 +55,8 @@ public interface ShuffleServerClient {
 
   RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest request);
 
+  RssGetShuffleResultResponse getShuffleResultForMultiPart(RssGetShuffleResultForMultiPartRequest request);
+
   RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest request);
 
   RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request);
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
index bd27b663..bfebd622 100644
--- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
+++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java
@@ -34,6 +34,7 @@ import org.apache.uniffle.client.request.RssFinishShuffleRequest;
 import org.apache.uniffle.client.request.RssGetInMemoryShuffleDataRequest;
 import org.apache.uniffle.client.request.RssGetShuffleDataRequest;
 import org.apache.uniffle.client.request.RssGetShuffleIndexRequest;
+import org.apache.uniffle.client.request.RssGetShuffleResultForMultiPartRequest;
 import org.apache.uniffle.client.request.RssGetShuffleResultRequest;
 import org.apache.uniffle.client.request.RssRegisterShuffleRequest;
 import org.apache.uniffle.client.request.RssReportShuffleResultRequest;
@@ -67,6 +68,8 @@ import org.apache.uniffle.proto.RssProtos.GetLocalShuffleIndexRequest;
 import org.apache.uniffle.proto.RssProtos.GetLocalShuffleIndexResponse;
 import org.apache.uniffle.proto.RssProtos.GetMemoryShuffleDataRequest;
 import org.apache.uniffle.proto.RssProtos.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartRequest;
+import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse;
 import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds;
@@ -449,6 +452,38 @@ public class ShuffleServerGrpcClient extends GrpcClient implements ShuffleServer
     return response;
   }
 
+  @Override
+  public RssGetShuffleResultResponse getShuffleResultForMultiPart(RssGetShuffleResultForMultiPartRequest request) {
+    GetShuffleResultForMultiPartRequest rpcRequest = GetShuffleResultForMultiPartRequest
+        .newBuilder()
+        .setAppId(request.getAppId())
+        .setShuffleId(request.getShuffleId())
+        .addAllPartitions(request.getPartitions())
+        .build();
+    GetShuffleResultForMultiPartResponse rpcResponse = getBlockingStub().getShuffleResultForMultiPart(rpcRequest);
+    StatusCode statusCode = rpcResponse.getStatus();
+
+    RssGetShuffleResultResponse response;
+    switch (statusCode) {
+      case SUCCESS:
+        try {
+          response = new RssGetShuffleResultResponse(ResponseStatusCode.SUCCESS,
+              rpcResponse.getSerializedBitmap().toByteArray());
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+        break;
+      default:
+        String msg = "Can't get shuffle result from " + host + ":" + port
+            + " for [appId=" + request.getAppId() + ", shuffleId=" + request.getShuffleId()
+            + "], errorMsg:" + rpcResponse.getRetMsg();
+        LOG.error(msg);
+        throw new RssException(msg);
+    }
+
+    return response;
+  }
+
   @Override
   public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request) {
     GetLocalShuffleDataRequest rpcRequest = GetLocalShuffleDataRequest
diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java
new file mode 100644
index 00000000..068d7b1d
--- /dev/null
+++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleResultForMultiPartRequest.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.client.request;
+
+import java.util.Set;
+
+public class RssGetShuffleResultForMultiPartRequest {
+  private String appId;
+  private int shuffleId;
+  private Set<Integer> partitions;
+
+  public RssGetShuffleResultForMultiPartRequest(String appId, int shuffleId, Set<Integer> partitions) {
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.partitions = partitions;
+  }
+
+  public String getAppId() {
+    return appId;
+  }
+
+  public int getShuffleId() {
+    return shuffleId;
+  }
+
+  public Set<Integer> getPartitions() {
+    return partitions;
+  }
+}
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index a9032ca2..26cd8b51 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -32,6 +32,7 @@ service ShuffleServer {
   rpc commitShuffleTask (ShuffleCommitRequest) returns (ShuffleCommitResponse);
   rpc reportShuffleResult (ReportShuffleResultRequest) returns (ReportShuffleResultResponse);
   rpc getShuffleResult (GetShuffleResultRequest) returns (GetShuffleResultResponse);
+  rpc getShuffleResultForMultiPart (GetShuffleResultForMultiPartRequest) returns (GetShuffleResultForMultiPartResponse);
   rpc finishShuffle (FinishShuffleRequest) returns (FinishShuffleResponse);
   rpc requireBuffer (RequireBufferRequest) returns (RequireBufferResponse);
   rpc appHeartbeat(AppHeartBeatRequest) returns (AppHeartBeatResponse);
@@ -141,6 +142,18 @@ message GetShuffleResultResponse {
   bytes serializedBitmap = 3;
 }
 
+message GetShuffleResultForMultiPartRequest {
+  string appId = 1;
+  int32 shuffleId = 2;
+  repeated int32 partitions = 3;
+}
+
+message GetShuffleResultForMultiPartResponse {
+  StatusCode status = 1;
+  string retMsg = 2;
+  bytes serializedBitmap = 3;
+}
+
 message ShufflePartitionRange {
   int32 start = 1;
   int32 end = 2;
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index fb17c36c..0501deca 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -23,6 +23,7 @@ import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 import com.google.protobuf.ByteString;
 import com.google.protobuf.UnsafeByteOperations;
 import io.grpc.Context;
@@ -50,6 +51,8 @@ import org.apache.uniffle.proto.RssProtos.GetLocalShuffleIndexRequest;
 import org.apache.uniffle.proto.RssProtos.GetLocalShuffleIndexResponse;
 import org.apache.uniffle.proto.RssProtos.GetMemoryShuffleDataRequest;
 import org.apache.uniffle.proto.RssProtos.GetMemoryShuffleDataResponse;
+import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartRequest;
+import org.apache.uniffle.proto.RssProtos.GetShuffleResultForMultiPartResponse;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultRequest;
 import org.apache.uniffle.proto.RssProtos.GetShuffleResultResponse;
 import org.apache.uniffle.proto.RssProtos.PartitionToBlockIds;
@@ -366,7 +369,7 @@ public class ShuffleServerGrpcService extends ShuffleServerImplBase {
 
     try {
       serializedBlockIds = shuffleServer.getShuffleTaskManager().getFinishedBlockIds(
-          appId, shuffleId, partitionId);
+          appId, shuffleId, Sets.newHashSet(partitionId));
       if (serializedBlockIds == null) {
         status = StatusCode.INTERNAL_ERROR;
         msg = "Can't get shuffle result for " + requestInfo;
@@ -389,6 +392,45 @@ public class ShuffleServerGrpcService extends ShuffleServerImplBase {
     responseObserver.onCompleted();
   }
 
+  @Override
+  public void getShuffleResultForMultiPart(GetShuffleResultForMultiPartRequest request,
+      StreamObserver<GetShuffleResultForMultiPartResponse> responseObserver) {
+    String appId = request.getAppId();
+    int shuffleId = request.getShuffleId();
+    List<Integer> partitionsList = request.getPartitionsList();
+
+    StatusCode status = StatusCode.SUCCESS;
+    String msg = "OK";
+    GetShuffleResultForMultiPartResponse reply;
+    byte[] serializedBlockIds = null;
+    String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], partitions" + partitionsList;
+    ByteString serializedBlockIdsBytes = ByteString.EMPTY;
+
+    try {
+      serializedBlockIds = shuffleServer.getShuffleTaskManager().getFinishedBlockIds(
+          appId, shuffleId, Sets.newHashSet(partitionsList));
+      if (serializedBlockIds == null) {
+        status = StatusCode.INTERNAL_ERROR;
+        msg = "Can't get shuffle result for " + requestInfo;
+        LOG.warn(msg);
+      } else {
+        serializedBlockIdsBytes = UnsafeByteOperations.unsafeWrap(serializedBlockIds);
+      }
+    } catch (Exception e) {
+      status = StatusCode.INTERNAL_ERROR;
+      msg = e.getMessage();
+      LOG.error("Error happened when get shuffle result for {}", requestInfo, e);
+    }
+
+    reply = GetShuffleResultForMultiPartResponse.newBuilder()
+        .setStatus(valueOf(status))
+        .setRetMsg(msg)
+        .setSerializedBitmap(serializedBlockIdsBytes)
+        .build();
+    responseObserver.onNext(reply);
+    responseObserver.onCompleted();
+  }
+
   @Override
   public void getLocalShuffleData(GetLocalShuffleDataRequest request,
       StreamObserver<GetLocalShuffleDataResponse> responseObserver) {
diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 648f83f9..a26006ae 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.server;
 
 import java.io.IOException;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -263,52 +264,56 @@ public class ShuffleTaskManager {
     return requireId;
   }
 
-  public byte[] getFinishedBlockIds(
-      String appId, Integer shuffleId, Integer partitionId) throws IOException {
+  public byte[] getFinishedBlockIds(String appId, Integer shuffleId, Set<Integer> partitions) throws IOException {
     refreshAppId(appId);
-    Storage storage = storageManager.selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId));
-    // update shuffle's timestamp that was recently read.
-    storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
-
+    for (int partitionId : partitions) {
+      Storage storage = storageManager.selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId));
+      // update shuffle's timestamp that was recently read.
+      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    }
     Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = partitionsToBlockIds.get(appId);
     if (shuffleIdToPartitions == null) {
       return null;
     }
+
     Roaring64NavigableMap[] blockIds = shuffleIdToPartitions.get(shuffleId);
     if (blockIds == null) {
       return new byte[]{};
     }
-    Roaring64NavigableMap bitmap = blockIds[partitionId % blockIds.length];
-    if (bitmap == null) {
-      return new byte[]{};
+    Map<Integer, Set<Integer>> bitmapIndexToPartitions = Maps.newHashMap();
+    for (int partitionId : partitions) {
+      int bitmapIndex = partitionId % blockIds.length;
+      if (bitmapIndexToPartitions.containsKey(bitmapIndex)) {
+        bitmapIndexToPartitions.get(bitmapIndex).add(partitionId);
+      } else {
+        HashSet<Integer> newHashSet = Sets.newHashSet(partitionId);
+        bitmapIndexToPartitions.put(bitmapIndex, newHashSet);
+      }
     }
 
-    if (partitionId > Constants.MAX_PARTITION_ID) {
-      throw new RuntimeException("Get invalid partitionId[" + partitionId
-          + "] which greater than " + Constants.MAX_PARTITION_ID);
+    Roaring64NavigableMap res = Roaring64NavigableMap.bitmapOf();
+    for (Map.Entry<Integer, Set<Integer>> entry : bitmapIndexToPartitions.entrySet()) {
+      Set<Integer> requestPartitions = entry.getValue();
+      Roaring64NavigableMap bitmap = blockIds[entry.getKey()];
+      getBlockIdsByPartitionId(requestPartitions, bitmap, res);
     }
-
-    return RssUtils.serializeBitMap(getBlockIdsByPartitionId(partitionId, bitmap));
+    return RssUtils.serializeBitMap(res);
   }
 
-  // partitionId is passed as long to calculate minValue/maxValue
-  protected Roaring64NavigableMap getBlockIdsByPartitionId(long partitionId, Roaring64NavigableMap bitmap) {
-    Roaring64NavigableMap result = Roaring64NavigableMap.bitmapOf();
+
+  // filter the specific partition blockId in the bitmap to the resultBitmap
+  protected Roaring64NavigableMap getBlockIdsByPartitionId(Set<Integer> requestPartitions,
+      Roaring64NavigableMap bitmap, Roaring64NavigableMap resultBitmap) {
     LongIterator iter = bitmap.getLongIterator();
-    long minValue = partitionId << Constants.TASK_ATTEMPT_ID_MAX_LENGTH;
-    long maxValue = Long.MAX_VALUE;
-    if (partitionId < Constants.MAX_PARTITION_ID) {
-      maxValue = (partitionId + 1) << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH);
-    }
-    long mask = (1L << (Constants.TASK_ATTEMPT_ID_MAX_LENGTH + Constants.PARTITION_ID_MAX_LENGTH)) - 1;
+    long mask = (1L << Constants.PARTITION_ID_MAX_LENGTH) - 1;
     while (iter.hasNext()) {
       long blockId = iter.next();
-      long partitionAndTask = blockId & mask;
-      if (partitionAndTask >= minValue && partitionAndTask < maxValue) {
-        result.addLong(blockId);
+      int partitionId = Math.toIntExact((blockId >> Constants.TASK_ATTEMPT_ID_MAX_LENGTH) & mask);
+      if (requestPartitions.contains(partitionId)) {
+        resultBitmap.addLong(blockId);
       }
     }
-    return result;
+    return resultBitmap;
   }
 
   public ShuffleDataResult getInMemoryShuffleData(
diff --git a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
index 02b7c29a..a603ff62 100644
--- a/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
+++ b/server/src/test/java/org/apache/uniffle/server/MockedShuffleServerGrpcService.java
@@ -17,8 +17,12 @@
 
 package org.apache.uniffle.server;
 
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 
+import com.google.common.collect.Maps;
 import com.google.common.util.concurrent.Uninterruptibles;
 import io.grpc.stub.StreamObserver;
 import org.slf4j.Logger;
@@ -31,12 +35,21 @@ public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService {
 
   private static final Logger LOG = LoggerFactory.getLogger(MockedShuffleServerGrpcService.class);
 
+  // appId -> shuffleId -> partitionRequestNum
+  private Map<String, Map<Integer, AtomicInteger>> appToPartitionRequest = Maps.newConcurrentMap();
+
   private long mockedTimeout = -1L;
 
+  private boolean recordGetShuffleResult = false;
+
   public void enableMockedTimeout(long timeout) {
     mockedTimeout = timeout;
   }
 
+  public void enableRecordGetShuffleResult() {
+    recordGetShuffleResult = true;
+  }
+
   public void disableMockedTimeout() {
     mockedTimeout = -1;
   }
@@ -74,4 +87,26 @@ public class MockedShuffleServerGrpcService extends ShuffleServerGrpcService {
     }
     super.getShuffleResult(request, responseObserver);
   }
+
+  @Override
+  public void getShuffleResultForMultiPart(RssProtos.GetShuffleResultForMultiPartRequest request,
+      StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
+    if (mockedTimeout > 0) {
+      LOG.info("Add a mocked timeout on getShuffleResultForMultiPart");
+      Uninterruptibles.sleepUninterruptibly(mockedTimeout, TimeUnit.MILLISECONDS);
+    }
+    if (recordGetShuffleResult) {
+      List<Integer> requestPartitions = request.getPartitionsList();
+      Map<Integer, AtomicInteger> shuffleIdToPartitionRequestNum = appToPartitionRequest.computeIfAbsent(
+          request.getAppId(), x -> Maps.newConcurrentMap());
+      AtomicInteger partitionRequestNum = shuffleIdToPartitionRequestNum.computeIfAbsent(
+          request.getShuffleId(), x -> new AtomicInteger(0));
+      partitionRequestNum.addAndGet(requestPartitions.size());
+    }
+    super.getShuffleResultForMultiPart(request, responseObserver);
+  }
+
+  public Map<String, Map<Integer, AtomicInteger>> getShuffleIdToPartitionRequest() {
+    return appToPartitionRequest;
+  }
 }
diff --git a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
index 571b3c5d..2c8f54a3 100644
--- a/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
+++ b/server/src/test/java/org/apache/uniffle/server/ShuffleTaskManagerTest.java
@@ -25,6 +25,7 @@ import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
 import com.google.common.collect.RangeMap;
 import com.google.common.collect.Sets;
 import org.apache.commons.lang3.StringUtils;
@@ -41,6 +42,7 @@ import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.Constants;
+import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
 import org.apache.uniffle.server.buffer.ShuffleBuffer;
 import org.apache.uniffle.server.buffer.ShuffleBufferManager;
@@ -346,20 +348,129 @@ public class ShuffleTaskManagerTest extends HdfsTestBase {
       }
     }
     Roaring64NavigableMap resultBlockIds = shuffleTaskManager.getBlockIdsByPartitionId(
-        expectedPartitionId, bitmapBlockIds);
+        Sets.newHashSet(expectedPartitionId), bitmapBlockIds, Roaring64NavigableMap.bitmapOf());
     assertEquals(expectedBlockIds, resultBlockIds);
 
     bitmapBlockIds.addLong(getBlockId(0, 0, 0));
-    resultBlockIds = shuffleTaskManager.getBlockIdsByPartitionId(0, bitmapBlockIds);
+    resultBlockIds = shuffleTaskManager.getBlockIdsByPartitionId(Sets.newHashSet(0), bitmapBlockIds,
+        Roaring64NavigableMap.bitmapOf());
     assertEquals(Roaring64NavigableMap.bitmapOf(0L), resultBlockIds);
 
     long expectedBlockId = getBlockId(
         Constants.MAX_PARTITION_ID, Constants.MAX_TASK_ATTEMPT_ID, Constants.MAX_SEQUENCE_NO);
     bitmapBlockIds.addLong(expectedBlockId);
-    resultBlockIds = shuffleTaskManager.getBlockIdsByPartitionId(Constants.MAX_PARTITION_ID, bitmapBlockIds);
+    resultBlockIds = shuffleTaskManager.getBlockIdsByPartitionId(Sets.newHashSet(Math.toIntExact(
+        Constants.MAX_PARTITION_ID)), bitmapBlockIds, Roaring64NavigableMap.bitmapOf());
     assertEquals(Roaring64NavigableMap.bitmapOf(expectedBlockId), resultBlockIds);
   }
 
+  @Test
+  public void getBlockIdsByMultiPartitionTest() {
+    ShuffleServerConf conf = new ShuffleServerConf();
+    ShuffleTaskManager shuffleTaskManager = new ShuffleTaskManager(
+        conf, null, null, null);
+
+    Roaring64NavigableMap expectedBlockIds = Roaring64NavigableMap.bitmapOf();
+    int startPartition = 3;
+    int endPartition = 5;
+    Roaring64NavigableMap bitmapBlockIds = Roaring64NavigableMap.bitmapOf();
+    for (int taskId = 1; taskId < 10; taskId++) {
+      for (int partitionId = 1; partitionId < 10; partitionId++) {
+        for (int i = 0; i < 2; i++) {
+          long blockId = getBlockId(partitionId, taskId, i);
+          bitmapBlockIds.addLong(blockId);
+          if (partitionId >= startPartition && partitionId <= endPartition) {
+            expectedBlockIds.addLong(blockId);
+          }
+        }
+      }
+    }
+    Set<Integer> requestPartitions = Sets.newHashSet();
+    Set<Integer> allPartitions = Sets.newHashSet();
+    for (int partitionId = 1; partitionId < 10; partitionId++) {
+      allPartitions.add(partitionId);
+      if (partitionId >= startPartition && partitionId <= endPartition) {
+        requestPartitions.add(partitionId);
+      }
+    }
+
+    Roaring64NavigableMap resultBlockIds =
+        shuffleTaskManager.getBlockIdsByPartitionId(requestPartitions, bitmapBlockIds,
+            Roaring64NavigableMap.bitmapOf());
+    assertEquals(expectedBlockIds, resultBlockIds);
+    assertEquals(bitmapBlockIds, shuffleTaskManager.getBlockIdsByPartitionId(allPartitions, bitmapBlockIds,
+        Roaring64NavigableMap.bitmapOf()));
+  }
+
+  @Test
+  public void testGetFinishedBlockIds() throws Exception {
+    ShuffleServerConf conf = new ShuffleServerConf();
+    String storageBasePath = HDFS_URI + "rss/test";
+    String appId = "test_app";
+    final int shuffleId = 1;
+    final int bitNum = 3;
+    final int partitionNum = 10;
+    final int taskNum = 10;
+    final int blocksPerTask = 2;
+    conf.set(ShuffleServerConf.RPC_SERVER_PORT, 1234);
+    conf.set(ShuffleServerConf.RSS_COORDINATOR_QUORUM, "localhost:9527");
+    conf.set(ShuffleServerConf.JETTY_HTTP_PORT, 12345);
+    conf.set(ShuffleServerConf.JETTY_CORE_POOL_SIZE, 64);
+    conf.set(ShuffleServerConf.SERVER_BUFFER_CAPACITY, 128L);
+    conf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE, 50.0);
+    conf.set(ShuffleServerConf.SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE, 0.0);
+    conf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, Arrays.asList(storageBasePath));
+    conf.set(ShuffleServerConf.RSS_STORAGE_TYPE, StorageType.HDFS.name());
+    conf.set(ShuffleServerConf.SERVER_COMMIT_TIMEOUT, 10000L);
+    conf.set(ShuffleServerConf.SERVER_APP_EXPIRED_WITHOUT_HEARTBEAT, 2000L);
+    conf.set(ShuffleServerConf.HEALTH_CHECK_ENABLE, false);
+
+    ShuffleServer shuffleServer = new ShuffleServer(conf);
+    ShuffleBufferManager shuffleBufferManager = shuffleServer.getShuffleBufferManager();
+    ShuffleFlushManager shuffleFlushManager = shuffleServer.getShuffleFlushManager();
+    StorageManager storageManager = shuffleServer.getStorageManager();
+    ShuffleTaskManager shuffleTaskManager = new ShuffleTaskManager(conf, shuffleFlushManager,
+        shuffleBufferManager, storageManager);
+
+    int startPartition = 6;
+    int endPartition = 9;
+    Roaring64NavigableMap expectedBlockIds = Roaring64NavigableMap.bitmapOf();
+    Map<Integer, long[]>  blockIdsToReport = Maps.newHashMap();
+
+    for (int partitionId = 0; partitionId < partitionNum; partitionId++) {
+      shuffleTaskManager.registerShuffle(
+          appId,
+          shuffleId,
+          Lists.newArrayList(new PartitionRange(partitionId, partitionId)),
+          new RemoteStorageInfo(storageBasePath),
+          StringUtils.EMPTY
+      );
+      long[] blockIds = new long[taskNum * blocksPerTask];
+      for (int taskId = 0; taskId < taskNum; taskId++) {
+        for (int i = 0; i < blocksPerTask; i++) {
+          long blockId = getBlockId(partitionId, taskId, i);
+          blockIds[taskId * blocksPerTask + i] = blockId;
+        }
+      }
+      blockIdsToReport.putIfAbsent(partitionId, blockIds);
+      if (partitionId >= startPartition) {
+        expectedBlockIds.add(blockIds);
+      }
+    }
+    assertEquals((endPartition - startPartition + 1) * taskNum *  blocksPerTask,
+        expectedBlockIds.getLongCardinality());
+
+    shuffleTaskManager.addFinishedBlockIds(appId, shuffleId, blockIdsToReport, bitNum);
+    Set<Integer> requestPartitions = Sets.newHashSet();
+    for (int partitionId = startPartition; partitionId <= endPartition; partitionId++) {
+      requestPartitions.add(partitionId);
+    }
+    byte[] serializeBitMap =
+        shuffleTaskManager.getFinishedBlockIds(appId, shuffleId, requestPartitions);
+    Roaring64NavigableMap resBlockIds = RssUtils.deserializeBitMap(serializeBitMap);
+    assertEquals(expectedBlockIds, resBlockIds);
+  }
+
   // copy from ClientUtils
   private Long getBlockId(long partitionId, long taskAttemptId, long atomicInt) {
     return (atomicInt << (Constants.PARTITION_ID_MAX_LENGTH + Constants.TASK_ATTEMPT_ID_MAX_LENGTH))