Posted to commits@spark.apache.org by mr...@apache.org on 2021/08/02 14:59:07 UTC

[spark] branch master updated: [SPARK-36206][CORE] Support shuffle data corruption diagnosis via shuffle checksum

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a98d919  [SPARK-36206][CORE] Support shuffle data corruption diagnosis via shuffle checksum
a98d919 is described below

commit a98d919da470abaf2e99060f99007a5373032fe1
Author: yi.wu <yi...@databricks.com>
AuthorDate: Mon Aug 2 09:58:36 2021 -0500

    [SPARK-36206][CORE] Support shuffle data corruption diagnosis via shuffle checksum
    
    ### What changes were proposed in this pull request?
    
    This PR adds support for diagnosing shuffle data corruption. The diagnosis mechanism works like this:
    The shuffle reader calculates the checksum (c1) for the corrupted shuffle block and sends it to the server where the block is stored. The server reads back the checksum (c2) stored in the checksum file and recalculates the checksum (c3) for the corresponding shuffle block. Then, if c2 != c3, we suspect the corruption is caused by a disk issue. Otherwise, if c1 != c3, we suspect the corruption is caused by a network issue. Otherwise, the checksum verification passes.
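
    A minimal, illustrative sketch of the decision rule above (it only mirrors the comparison logic; the committed implementation is ShuffleChecksumHelper.diagnoseCorruption in this diff):

    ```java
    import org.apache.spark.network.shuffle.checksum.Cause;

    // Sketch only: c1 = reader-side checksum, c2 = writer-side checksum read from the
    // checksum file, c3 = checksum recalculated from the stored block during diagnosis.
    class DiagnosisSketch {
      static Cause classify(long c1, long c2, long c3) {
        if (c2 != c3) return Cause.DISK_ISSUE;      // on-disk data no longer matches what the writer recorded
        if (c1 != c3) return Cause.NETWORK_ISSUE;   // data was intact on disk but corrupted in transit
        return Cause.CHECKSUM_VERIFY_PASS;          // all three agree; the checksums cannot explain the corruption
      }
    }
    ```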
    
    After the shuffle reader receives the diagnosis response, it takes an action based on the type of cause. Only in the case of a network issue do we retry the fetch; otherwise, we throw the fetch failure directly. Also note that if the corruption happens inside BufferReleasingInputStream, the reducer throws the fetch failure immediately no matter what the cause is, since the data has already been partially consumed by downstream RDDs. If corruption happens again after the retry, the reducer throws the fetch failure directly as well.
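
    A rough sketch of that reader-side reaction, assuming hypothetical retryFetch/reportFetchFailure placeholders (the committed logic lives in ShuffleBlockFetcherIterator):

    ```java
    import org.apache.spark.network.shuffle.BlockStoreClient;
    import org.apache.spark.network.shuffle.checksum.Cause;

    // Sketch only; retryFetch and reportFetchFailure are hypothetical placeholders.
    class ReaderReactionSketch {
      static void onCorruptedBlock(BlockStoreClient client, String host, int port, String execId,
          int shuffleId, long mapId, int reduceId, long readerChecksum, String algorithm) {
        Cause cause = client.diagnoseCorruption(
          host, port, execId, shuffleId, mapId, reduceId, readerChecksum, algorithm);
        if (cause == Cause.NETWORK_ISSUE) {
          retryFetch(shuffleId, mapId, reduceId);           // only a network issue earns a retry
        } else {
          reportFetchFailure(shuffleId, mapId, reduceId);   // any other cause fails the fetch immediately
        }
      }
      private static void retryFetch(int shuffleId, long mapId, int reduceId) { }
      private static void reportFetchFailure(int shuffleId, long mapId, int reduceId) { }
    }
    ```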
    
    Please check out https://github.com/apache/spark/pull/32385 to see the complete proposal of the shuffle checksum project.
    
    ### Why are the changes needed?
    
    Shuffle data corruption is a long-standing issue in Spark. For example, in SPARK-18105, people continually report corruption issues. However, data corruption is difficult to reproduce in most cases, and it is even harder to tell the root cause: we don't know whether it's a Spark issue or not. With diagnosis support for shuffle corruption, Spark itself can at least distinguish between disk and network causes, which is very important for users.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, users can learn the cause of shuffle corruption after this change.
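
    For reference, a small example of the two configuration knobs this change wires up (both keys appear in the diff below; checksum calculation is enabled by default and the algorithm accepts ADLER32 or CRC32):

    ```java
    import org.apache.spark.SparkConf;

    public class ShuffleChecksumConfExample {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf()
          .set("spark.shuffle.checksum.enabled", "true")      // on by default
          .set("spark.shuffle.checksum.algorithm", "CRC32");  // ADLER32 or CRC32
      }
    }
    ```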
    
    ### How was this patch tested?
    
    Added tests.
    
    Closes #33451 from Ngone51/SPARK-36206.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../spark/network/shuffle/BlockStoreClient.java    |  45 +++++-
 .../network/shuffle/ExternalBlockHandler.java      |  13 +-
 .../network/shuffle/ExternalBlockStoreClient.java  |  22 +--
 .../shuffle/ExternalShuffleBlockResolver.java      |  25 ++++
 .../spark/network/shuffle/checksum/Cause.java      |  25 ++++
 .../shuffle/checksum/ShuffleChecksumHelper.java    | 158 +++++++++++++++++++++
 .../shuffle/protocol/BlockTransferMessage.java     |   4 +-
 .../network/shuffle/protocol/CorruptionCause.java  |  74 ++++++++++
 .../shuffle/protocol/DiagnoseCorruption.java       | 131 +++++++++++++++++
 .../network/shuffle/ExternalBlockHandlerSuite.java | 115 ++++++++++++++-
 .../shuffle/checksum/ShuffleChecksumHelper.java    | 100 -------------
 .../shuffle/checksum/ShuffleChecksumSupport.java   |  45 ++++++
 .../shuffle/sort/BypassMergeSortShuffleWriter.java |  15 +-
 .../spark/shuffle/sort/ShuffleExternalSorter.java  |   9 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java    |   2 +-
 .../org/apache/spark/internal/config/package.scala |   4 +-
 .../apache/spark/network/BlockDataManager.scala    |   9 ++
 .../spark/network/netty/NettyBlockRpcServer.scala  |   7 +
 .../network/netty/NettyBlockTransferService.scala  |   3 +-
 .../spark/shuffle/BlockStoreShuffleReader.scala    |   2 +
 .../spark/shuffle/IndexShuffleBlockResolver.scala  |  14 +-
 .../scala/org/apache/spark/storage/BlockId.scala   |   2 +-
 .../org/apache/spark/storage/BlockManager.scala    |  25 +++-
 .../storage/ShuffleBlockFetcherIterator.scala      | 129 +++++++++++++++--
 .../spark/util/collection/ExternalSorter.scala     |   9 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java     |  25 ++--
 .../test/scala/org/apache/spark/ShuffleSuite.scala |  37 ++++-
 .../spark/shuffle/ShuffleChecksumTestHelper.scala  |  11 +-
 .../sort/BypassMergeSortShuffleWriterSuite.scala   |  13 +-
 .../sort/IndexShuffleBlockResolverSuite.scala      |   2 +-
 .../shuffle/sort/SortShuffleWriterSuite.scala      |   6 +-
 .../storage/ShuffleBlockFetcherIteratorSuite.scala |  82 ++++++++++-
 32 files changed, 973 insertions(+), 190 deletions(-)

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
index 8298846..6dc5fd5 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
@@ -33,9 +33,9 @@ import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientFactory;
-import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.GetLocalDirsForExecutors;
-import org.apache.spark.network.shuffle.protocol.LocalDirsForExecutors;
+import org.apache.spark.network.shuffle.checksum.Cause;
+import org.apache.spark.network.shuffle.protocol.*;
+import org.apache.spark.network.util.TransportConf;
 
 /**
  * Provides an interface for reading both shuffle files and RDD blocks, either from an Executor
@@ -46,6 +46,45 @@ public abstract class BlockStoreClient implements Closeable {
 
   protected volatile TransportClientFactory clientFactory;
   protected String appId;
+  protected TransportConf transportConf;
+
+  /**
+   * Send the diagnosis request for the corrupted shuffle block to the server.
+   *
+   * @param host the host of the remote node.
+   * @param port the port of the remote node.
+   * @param execId the executor id.
+   * @param shuffleId the shuffleId of the corrupted shuffle block
+   * @param mapId the mapId of the corrupted shuffle block
+   * @param reduceId the reduceId of the corrupted shuffle block
+   * @param checksum the shuffle checksum calculated at the client side for the corrupted
+   *                 shuffle block
+   * @return The cause of the shuffle block corruption
+   */
+  public Cause diagnoseCorruption(
+      String host,
+      int port,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksum,
+      String algorithm) {
+    try {
+      TransportClient client = clientFactory.createClient(host, port);
+      ByteBuffer response = client.sendRpcSync(
+        new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm)
+          .toByteBuffer(),
+        transportConf.connectionTimeoutMs()
+      );
+      CorruptionCause cause =
+        (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response);
+      return cause.cause;
+    } catch (Exception e) {
+      logger.warn("Failed to get the corruption cause.");
+      return Cause.UNKNOWN_ISSUE;
+    }
+  }
 
   /**
    * Fetch a sequence of blocks from a remote node asynchronously,
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index cfabcd5..71741f2 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -17,7 +17,6 @@
 
 package org.apache.spark.network.shuffle;
 
-import com.google.common.base.Preconditions;
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -35,8 +34,9 @@ import com.codahale.metrics.MetricSet;
 import com.codahale.metrics.RatioGauge;
 import com.codahale.metrics.Timer;
 import com.codahale.metrics.Counter;
-import com.google.common.collect.Sets;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Sets;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -49,6 +49,7 @@ import org.apache.spark.network.protocol.MergedBlockMetaRequest;
 import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.shuffle.checksum.Cause;
 import org.apache.spark.network.shuffle.protocol.*;
 import org.apache.spark.network.util.TimerWithCustomTimeUnit;
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -223,6 +224,14 @@ public class ExternalBlockHandler extends RpcHandler
       } finally {
         responseDelayContext.stop();
       }
+    } else if (msgObj instanceof DiagnoseCorruption) {
+      DiagnoseCorruption msg = (DiagnoseCorruption) msgObj;
+      checkAuth(client, msg.appId);
+      Cause cause = blockManager.diagnoseShuffleBlockCorruption(
+        msg.appId, msg.execId, msg.shuffleId, msg.mapId, msg.reduceId, msg.checksum, msg.algorithm);
+      // In any case of error, diagnoseShuffleBlockCorruption should return UNKNOWN_ISSUE,
+      // so it should always reply with success.
+      callback.onSuccess(new CorruptionCause(cause).toByteBuffer());
     } else {
       throw new UnsupportedOperationException("Unexpected message: " + msgObj);
     }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
index eb2d118..826402c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
@@ -49,7 +49,6 @@ import org.apache.spark.network.util.TransportConf;
 public class ExternalBlockStoreClient extends BlockStoreClient {
   private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler();
 
-  private final TransportConf conf;
   private final boolean authEnabled;
   private final SecretKeyHolder secretKeyHolder;
   private final long registrationTimeoutMs;
@@ -63,7 +62,7 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
       SecretKeyHolder secretKeyHolder,
       boolean authEnabled,
       long registrationTimeoutMs) {
-    this.conf = conf;
+    this.transportConf = conf;
     this.secretKeyHolder = secretKeyHolder;
     this.authEnabled = authEnabled;
     this.registrationTimeoutMs = registrationTimeoutMs;
@@ -75,10 +74,11 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
    */
   public void init(String appId) {
     this.appId = appId;
-    TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true);
+    TransportContext context = new TransportContext(
+      transportConf, new NoOpRpcHandler(), true, true);
     List<TransportClientBootstrap> bootstraps = Lists.newArrayList();
     if (authEnabled) {
-      bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder));
+      bootstraps.add(new AuthClientBootstrap(transportConf, appId, secretKeyHolder));
     }
     clientFactory = context.createClientFactory(bootstraps);
   }
@@ -94,7 +94,7 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
-      int maxRetries = conf.maxIORetries();
+      int maxRetries = transportConf.maxIORetries();
       RetryingBlockTransferor.BlockTransferStarter blockFetchStarter =
           (inputBlockId, inputListener) -> {
             // Unless this client is closed.
@@ -103,7 +103,7 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
                 "Expecting a BlockFetchingListener, but got " + inputListener.getClass();
               TransportClient client = clientFactory.createClient(host, port, maxRetries > 0);
               new OneForOneBlockFetcher(client, appId, execId, inputBlockId,
-                (BlockFetchingListener) inputListener, conf, downloadFileManager).start();
+                (BlockFetchingListener) inputListener, transportConf, downloadFileManager).start();
             } else {
               logger.info("This clientFactory was closed. Skipping further block fetch retries.");
             }
@@ -112,7 +112,7 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
       if (maxRetries > 0) {
         // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
         // a bug in this code. We should remove the if statement once we're sure of the stability.
-        new RetryingBlockTransferor(conf, blockFetchStarter, blockIds, listener).start();
+        new RetryingBlockTransferor(transportConf, blockFetchStarter, blockIds, listener).start();
       } else {
         blockFetchStarter.createAndStart(blockIds, listener);
       }
@@ -146,16 +146,16 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
               assert inputListener instanceof BlockPushingListener :
                 "Expecting a BlockPushingListener, but got " + inputListener.getClass();
               TransportClient client = clientFactory.createClient(host, port);
-              new OneForOneBlockPusher(client, appId, conf.appAttemptId(), inputBlockId,
+              new OneForOneBlockPusher(client, appId, transportConf.appAttemptId(), inputBlockId,
                 (BlockPushingListener) inputListener, buffersWithId).start();
             } else {
               logger.info("This clientFactory was closed. Skipping further block push retries.");
             }
           };
-      int maxRetries = conf.maxIORetries();
+      int maxRetries = transportConf.maxIORetries();
       if (maxRetries > 0) {
         new RetryingBlockTransferor(
-          conf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start();
+          transportConf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start();
       } else {
         blockPushStarter.createAndStart(blockIds, listener);
       }
@@ -178,7 +178,7 @@ public class ExternalBlockStoreClient extends BlockStoreClient {
     try {
       TransportClient client = clientFactory.createClient(host, port);
       ByteBuffer finalizeShuffleMerge =
-        new FinalizeShuffleMerge(appId, conf.appAttemptId(), shuffleId,
+        new FinalizeShuffleMerge(appId, transportConf.appAttemptId(), shuffleId,
           shuffleMergeId).toByteBuffer();
       client.sendRpc(finalizeShuffleMerge, new RpcResponseCallback() {
         @Override
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 493edd2..73d4e6c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -45,6 +45,8 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.shuffle.checksum.Cause;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.util.LevelDBProvider;
 import org.apache.spark.network.util.LevelDBProvider.StoreVersion;
@@ -374,6 +376,29 @@ public class ExternalShuffleBlockResolver {
       .collect(Collectors.toMap(Pair::getKey, Pair::getValue));
   }
 
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums
+   */
+  public Cause diagnoseShuffleBlockCorruption(
+      String appId,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksumByReader,
+      String algorithm) {
+    ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
+    // This should be in sync with IndexShuffleBlockResolver.getChecksumFile
+    String fileName = "shuffle_" + shuffleId + "_" + mapId + "_0.checksum." + algorithm;
+    File checksumFile = ExecutorDiskUtils.getFile(
+      executor.localDirs,
+      executor.subDirsPerLocalDir,
+      fileName);
+    ManagedBuffer data = getBlockData(appId, execId, shuffleId, mapId, reduceId);
+    return ShuffleChecksumHelper.diagnoseCorruption(
+      algorithm, checksumFile, reduceId, data, checksumByReader);
+  }
+
   /** Simply encodes an executor's full ID, which is appId + execId. */
   public static class AppExecId {
     public final String appId;
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java
new file mode 100644
index 0000000..d316737
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/Cause.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.checksum;
+
+/**
+ * The cause of shuffle data corruption.
+ */
+public enum Cause {
+  DISK_ISSUE, NETWORK_ISSUE, UNKNOWN_ISSUE, CHECKSUM_VERIFY_PASS, UNSUPPORTED_CHECKSUM_ALGORITHM
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java
new file mode 100644
index 0000000..f332f74
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/checksum/ShuffleChecksumHelper.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.checksum;
+
+import java.io.*;
+import java.util.concurrent.TimeUnit;
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
+
+import com.google.common.io.ByteStreams;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+  private static final Logger logger =
+    LoggerFactory.getLogger(ShuffleChecksumHelper.class);
+
+  public static final int CHECKSUM_CALCULATION_BUFFER = 8192;
+  public static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+  public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+  public static Checksum[] createPartitionChecksums(int numPartitions, String algorithm) {
+    return getChecksumsByAlgorithm(numPartitions, algorithm);
+  }
+
+  private static Checksum[] getChecksumsByAlgorithm(int num, String algorithm) {
+    Checksum[] checksums;
+    switch (algorithm) {
+      case "ADLER32":
+        checksums = new Adler32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new Adler32();
+        }
+        return checksums;
+
+      case "CRC32":
+        checksums = new CRC32[num];
+        for (int i = 0; i < num; i++) {
+          checksums[i] = new CRC32();
+        }
+        return checksums;
+
+      default:
+        throw new UnsupportedOperationException(
+          "Unsupported shuffle checksum algorithm: " + algorithm);
+    }
+  }
+
+  public static Checksum getChecksumByAlgorithm(String algorithm) {
+    return getChecksumsByAlgorithm(1, algorithm)[0];
+  }
+
+  public static String getChecksumFileName(String blockName, String algorithm) {
+    // append the shuffle checksum algorithm as the file extension
+    return String.format("%s.%s", blockName, algorithm);
+  }
+
+  private static long readChecksumByReduceId(File checksumFile, int reduceId) throws IOException {
+    try (DataInputStream in = new DataInputStream(new FileInputStream(checksumFile))) {
+      ByteStreams.skipFully(in, reduceId * 8);
+      return in.readLong();
+    }
+  }
+
+  private static long calculateChecksumForPartition(
+      ManagedBuffer partitionData,
+      Checksum checksumAlgo) throws IOException {
+    InputStream in = partitionData.createInputStream();
+    byte[] buffer = new byte[CHECKSUM_CALCULATION_BUFFER];
+    try (CheckedInputStream checksumIn = new CheckedInputStream(in, checksumAlgo)) {
+      while (checksumIn.read(buffer, 0, CHECKSUM_CALCULATION_BUFFER) != -1) {}
+      return checksumAlgo.getValue();
+    }
+  }
+
+  /**
+   * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums.
+   *
+   * There are 3 different kinds of checksums for the same shuffle partition:
+   *   - checksum (c1) that is calculated by the shuffle data reader
+   *   - checksum (c2) that is calculated by the shuffle data writer and stored in the checksum file
+   *   - checksum (c3) that is recalculated during diagnosis
+   *
+   * And the diagnosis mechanism works like this:
+   * If c2 != c3, we suspect the corruption is caused by DISK_ISSUE. Otherwise, if c1 != c3,
+   * we suspect the corruption is caused by NETWORK_ISSUE. Otherwise, the cause is
+   * CHECKSUM_VERIFY_PASS. In case of any other failure, the cause is UNKNOWN_ISSUE.
+   *
+   * @param algorithm The checksum algorithm that is used for calculating checksum value
+   *                  of partitionData
+   * @param checksumFile The checksum file that was written by the shuffle writer
+   * @param reduceId The reduceId of the shuffle block
+   * @param partitionData The partition data of the shuffle block
+   * @param checksumByReader The checksum value that was calculated by the shuffle data reader
+   * @return The cause of data corruption
+   */
+  public static Cause diagnoseCorruption(
+      String algorithm,
+      File checksumFile,
+      int reduceId,
+      ManagedBuffer partitionData,
+      long checksumByReader) {
+    Cause cause;
+    try {
+      long diagnoseStartNs = System.nanoTime();
+      // Try to get the checksum instance before reading the checksum file so that
+      // `UnsupportedOperationException` can be thrown first before `FileNotFoundException`
+      // when the checksum algorithm isn't supported.
+      Checksum checksumAlgo = getChecksumByAlgorithm(algorithm);
+      long checksumByWriter = readChecksumByReduceId(checksumFile, reduceId);
+      long checksumByReCalculation = calculateChecksumForPartition(partitionData, checksumAlgo);
+      long duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - diagnoseStartNs);
+      logger.info("Shuffle corruption diagnosis took {} ms, checksum file {}",
+        duration, checksumFile.getAbsolutePath());
+      if (checksumByWriter != checksumByReCalculation) {
+        cause = Cause.DISK_ISSUE;
+      } else if (checksumByWriter != checksumByReader) {
+        cause = Cause.NETWORK_ISSUE;
+      } else {
+        cause = Cause.CHECKSUM_VERIFY_PASS;
+      }
+    } catch (UnsupportedOperationException e) {
+      cause = Cause.UNSUPPORTED_CHECKSUM_ALGORITHM;
+    } catch (FileNotFoundException e) {
+      // Even if the checksum is enabled, the checksum file may not exist if an error occurs during writing.
+      logger.warn("Checksum file " + checksumFile.getName() + " doesn't exist");
+      cause = Cause.UNKNOWN_ISSUE;
+    } catch (Exception e) {
+      logger.warn("Unable to diagnose shuffle block corruption", e);
+      cause = Cause.UNKNOWN_ISSUE;
+    }
+    return cause;
+  }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index a55a6cf..453791d 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -49,7 +49,7 @@ public abstract class BlockTransferMessage implements Encodable {
     HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
     FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
     PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14),
-    FETCH_SHUFFLE_BLOCK_CHUNKS(15);
+    FETCH_SHUFFLE_BLOCK_CHUNKS(15), DIAGNOSE_CORRUPTION(16), CORRUPTION_CAUSE(17);
 
     private final byte id;
 
@@ -84,6 +84,8 @@ public abstract class BlockTransferMessage implements Encodable {
         case 13: return FinalizeShuffleMerge.decode(buf);
         case 14: return MergeStatuses.decode(buf);
         case 15: return FetchShuffleBlockChunks.decode(buf);
+        case 16: return DiagnoseCorruption.decode(buf);
+        case 17: return CorruptionCause.decode(buf);
         default: throw new IllegalArgumentException("Unknown message type: " + type);
       }
     }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java
new file mode 100644
index 0000000..5690eee
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/CorruptionCause.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+import org.apache.spark.network.shuffle.checksum.Cause;
+
+/** Response to the {@link DiagnoseCorruption} */
+public class CorruptionCause extends BlockTransferMessage {
+  public Cause cause;
+
+  public CorruptionCause(Cause cause) {
+    this.cause = cause;
+  }
+
+  @Override
+  protected Type type() {
+    return Type.CORRUPTION_CAUSE;
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("cause", cause)
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    CorruptionCause that = (CorruptionCause) o;
+    return cause == that.cause;
+  }
+
+  @Override
+  public int hashCode() {
+    return cause.hashCode();
+  }
+
+  @Override
+  public int encodedLength() {
+    return 1; /* encoded length of cause */
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeByte(cause.ordinal());
+  }
+
+  public static CorruptionCause decode(ByteBuf buf) {
+    int ordinal = buf.readByte();
+    return new CorruptionCause(Cause.values()[ordinal]);
+  }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java
new file mode 100644
index 0000000..620b5ad
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/DiagnoseCorruption.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+import org.apache.spark.network.protocol.Encoders;
+
+/** Request to get the cause of a corrupted block. Returns {@link CorruptionCause} */
+public class DiagnoseCorruption extends BlockTransferMessage {
+  public final String appId;
+  public final String execId;
+  public final int shuffleId;
+  public final long mapId;
+  public final int reduceId;
+  public final long checksum;
+  public final String algorithm;
+
+  public DiagnoseCorruption(
+      String appId,
+      String execId,
+      int shuffleId,
+      long mapId,
+      int reduceId,
+      long checksum,
+      String algorithm) {
+    this.appId = appId;
+    this.execId = execId;
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.reduceId = reduceId;
+    this.checksum = checksum;
+    this.algorithm = algorithm;
+  }
+
+  @Override
+  protected Type type() {
+    return Type.DIAGNOSE_CORRUPTION;
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("appId", appId)
+      .append("execId", execId)
+      .append("shuffleId", shuffleId)
+      .append("mapId", mapId)
+      .append("reduceId", reduceId)
+      .append("checksum", checksum)
+      .append("algorithm", algorithm)
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    DiagnoseCorruption that = (DiagnoseCorruption) o;
+
+    if (checksum != that.checksum) return false;
+    if (shuffleId != that.shuffleId) return false;
+    if (mapId != that.mapId) return false;
+    if (reduceId != that.reduceId) return false;
+    if (!algorithm.equals(that.algorithm)) return false;
+    if (!appId.equals(that.appId)) return false;
+    if (!execId.equals(that.execId)) return false;
+    return true;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = appId.hashCode();
+    result = 31 * result + execId.hashCode();
+    result = 31 * result + Integer.hashCode(shuffleId);
+    result = 31 * result + Long.hashCode(mapId);
+    result = 31 * result + Integer.hashCode(reduceId);
+    result = 31 * result + Long.hashCode(checksum);
+    result = 31 * result + algorithm.hashCode();
+    return result;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(appId)
+      + Encoders.Strings.encodedLength(execId)
+      + 4 /* encoded length of shuffleId */
+      + 8 /* encoded length of mapId */
+      + 4 /* encoded length of reduceId */
+      + 8 /* encoded length of checksum */
+      + Encoders.Strings.encodedLength(algorithm); /* encoded length of algorithm */
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    Encoders.Strings.encode(buf, execId);
+    buf.writeInt(shuffleId);
+    buf.writeLong(mapId);
+    buf.writeInt(reduceId);
+    buf.writeLong(checksum);
+    Encoders.Strings.encode(buf, algorithm);
+  }
+
+  public static DiagnoseCorruption decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    String execId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    long mapId = buf.readLong();
+    int reduceId = buf.readInt();
+    long checksum = buf.readLong();
+    String algorithm = Encoders.Strings.decode(buf);
+    return new DiagnoseCorruption(appId, execId, shuffleId, mapId, reduceId, checksum, algorithm);
+  }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index 9e0b3c6..d45cbd5 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -17,14 +17,18 @@
 
 package org.apache.spark.network.shuffle;
 
-import java.io.IOException;
+import java.io.*;
 import java.nio.ByteBuffer;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.zip.CheckedInputStream;
+import java.util.zip.Checksum;
 
 import com.codahale.metrics.Meter;
 import com.codahale.metrics.Metric;
 import com.codahale.metrics.Timer;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Files;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
@@ -42,7 +46,11 @@ import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.protocol.MergedBlockMetaRequest;
 import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.shuffle.checksum.Cause;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.CorruptionCause;
+import org.apache.spark.network.shuffle.protocol.DiagnoseCorruption;
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
 import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
 import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks;
@@ -108,6 +116,111 @@ public class ExternalBlockHandlerSuite {
     verifyOpenBlockLatencyMetrics(2, 2);
   }
 
+  private void checkDiagnosisResult(
+      String algorithm,
+      Cause expectedCaused) throws IOException {
+    String appId = "app0";
+    String execId = "execId";
+    int shuffleId = 0;
+    long mapId = 0;
+    int reduceId = 0;
+
+    // prepare the checksum file
+    File tmpDir = Files.createTempDir();
+    File checksumFile = new File(tmpDir,
+      "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum." + algorithm);
+    DataOutputStream out = new DataOutputStream(new FileOutputStream(checksumFile));
+    long checksumByReader = 0L;
+    if (expectedCaused != Cause.UNSUPPORTED_CHECKSUM_ALGORITHM) {
+      Checksum checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm);
+      CheckedInputStream checkedIn = new CheckedInputStream(
+        blockMarkers[0].createInputStream(), checksum);
+      byte[] buffer = new byte[10];
+      ByteStreams.readFully(checkedIn, buffer, 0, (int) blockMarkers[0].size());
+      long checksumByWriter = checkedIn.getChecksum().getValue();
+
+      switch (expectedCaused) {
+        // when checksumByWriter != checksumRecalculated
+        case DISK_ISSUE:
+          out.writeLong(checksumByWriter - 1);
+          checksumByReader = checksumByWriter;
+          break;
+
+        // when checksumByWriter == checksumRecalculated and checksumByReader != checksumByWriter
+        case NETWORK_ISSUE:
+          out.writeLong(checksumByWriter);
+          checksumByReader = checksumByWriter - 1;
+          break;
+
+        case UNKNOWN_ISSUE:
+          // write an int instead of a long to corrupt the checksum file
+          out.writeInt(0);
+          checksumByReader = checksumByWriter;
+          break;
+
+        default:
+          out.writeLong(checksumByWriter);
+          checksumByReader = checksumByWriter;
+      }
+    }
+    out.close();
+
+    when(blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId))
+      .thenReturn(blockMarkers[0]);
+    Cause actualCause = ShuffleChecksumHelper.diagnoseCorruption(algorithm, checksumFile, reduceId,
+      blockResolver.getBlockData(appId, execId, shuffleId, mapId, reduceId), checksumByReader);
+    when(blockResolver
+      .diagnoseShuffleBlockCorruption(
+        appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm))
+      .thenReturn(actualCause);
+
+    when(client.getClientId()).thenReturn(appId);
+    RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+    DiagnoseCorruption diagnoseMsg = new DiagnoseCorruption(
+      appId, execId, shuffleId, mapId, reduceId, checksumByReader, algorithm);
+    handler.receive(client, diagnoseMsg.toByteBuffer(), callback);
+
+    ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
+    verify(callback, times(1)).onSuccess(response.capture());
+    verify(callback, never()).onFailure(any());
+
+    CorruptionCause cause =
+      (CorruptionCause) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
+    assertEquals(expectedCaused, cause.cause);
+    tmpDir.delete();
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisDiskIssue() throws IOException {
+    checkDiagnosisResult( "ADLER32", Cause.DISK_ISSUE);
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisNetworkIssue() throws IOException {
+    checkDiagnosisResult("ADLER32", Cause.NETWORK_ISSUE);
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisUnknownIssue() throws IOException {
+    checkDiagnosisResult("ADLER32", Cause.UNKNOWN_ISSUE);
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisChecksumVerifyPass() throws IOException {
+    checkDiagnosisResult("ADLER32", Cause.CHECKSUM_VERIFY_PASS);
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisUnSupportedAlgorithm() throws IOException {
+    checkDiagnosisResult("XXX", Cause.UNSUPPORTED_CHECKSUM_ALGORITHM);
+  }
+
+  @Test
+  public void testShuffleCorruptionDiagnosisCRC32() throws IOException {
+    checkDiagnosisResult("CRC32", Cause.CHECKSUM_VERIFY_PASS);
+  }
+
   @Test
   public void testFetchShuffleBlocks() {
     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]);
diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
deleted file mode 100644
index a368836..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.checksum;
-
-import java.util.zip.Adler32;
-import java.util.zip.CRC32;
-import java.util.zip.Checksum;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.SparkException;
-import org.apache.spark.annotation.Private;
-import org.apache.spark.internal.config.package$;
-import org.apache.spark.storage.ShuffleChecksumBlockId;
-
-/**
- * A set of utility functions for the shuffle checksum.
- */
-@Private
-public class ShuffleChecksumHelper {
-
-  /** Used when the checksum is disabled for shuffle. */
-  private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
-  public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
-
-  public static boolean isShuffleChecksumEnabled(SparkConf conf) {
-    return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED());
-  }
-
-  public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf)
-    throws SparkException {
-    if (!isShuffleChecksumEnabled(conf)) {
-      return EMPTY_CHECKSUM;
-    }
-
-    String checksumAlgo = shuffleChecksumAlgorithm(conf);
-    return getChecksumByAlgorithm(numPartitions, checksumAlgo);
-  }
-
-  private static Checksum[] getChecksumByAlgorithm(int num, String algorithm)
-      throws SparkException {
-    Checksum[] checksums;
-    switch (algorithm) {
-      case "ADLER32":
-        checksums = new Adler32[num];
-        for (int i = 0; i < num; i ++) {
-          checksums[i] = new Adler32();
-        }
-        return checksums;
-
-      case "CRC32":
-        checksums = new CRC32[num];
-        for (int i = 0; i < num; i ++) {
-          checksums[i] = new CRC32();
-        }
-        return checksums;
-
-      default:
-        throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm);
-    }
-  }
-
-  public static long[] getChecksumValues(Checksum[] partitionChecksums) {
-    int numPartitions = partitionChecksums.length;
-    long[] checksumValues = new long[numPartitions];
-    for (int i = 0; i < numPartitions; i ++) {
-      checksumValues[i] = partitionChecksums[i].getValue();
-    }
-    return checksumValues;
-  }
-
-  public static String shuffleChecksumAlgorithm(SparkConf conf) {
-    return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
-  }
-
-  public static Checksum getChecksumByFileExtension(String fileName) throws SparkException {
-    int index = fileName.lastIndexOf(".");
-    String algorithm = fileName.substring(index + 1);
-    return getChecksumByAlgorithm(1, algorithm)[0];
-  }
-
-  public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) {
-    // append the shuffle checksum algorithm as the file extension
-    return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf));
-  }
-}
diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java
new file mode 100644
index 0000000..4f7c3f2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumSupport.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum;
+
+import java.util.zip.Checksum;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
+
+public interface ShuffleChecksumSupport {
+
+  default Checksum[] createPartitionChecksums(int numPartitions, SparkConf conf) {
+    if ((boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED())) {
+      String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+      return ShuffleChecksumHelper.createPartitionChecksums(numPartitions, checksumAlgorithm);
+    } else {
+      return ShuffleChecksumHelper.EMPTY_CHECKSUM;
+    }
+  }
+
+  default long[] getChecksumValues(Checksum[] partitionChecksums) {
+    int numPartitions = partitionChecksums.length;
+    long[] checksumValues = new long[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      checksumValues[i] = partitionChecksums[i].getValue();
+    }
+    return checksumValues;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 3222240..9a5ac6f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -40,10 +40,12 @@ import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkException;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport;
 import org.apache.spark.internal.config.package$;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -51,7 +53,6 @@ import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.ShuffleWriter;
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -76,7 +77,9 @@ import org.apache.spark.util.Utils;
  * <p>
  * There have been proposals to completely remove this code path; see SPARK-6026 for details.
  */
-final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V>
+  extends ShuffleWriter<K, V>
+  implements ShuffleChecksumSupport {
 
   private static final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
 
@@ -125,8 +128,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleExecutorComponents = shuffleExecutorComponents;
-    this.partitionChecksums =
-      ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
+    this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
   }
 
   @Override
@@ -230,9 +232,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       }
       partitionWriters = null;
     }
-    return mapOutputWriter.commitAllPartitions(
-      ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
-    ).getPartitionLengths();
+    return mapOutputWriter.commitAllPartitions(getChecksumValues(partitionChecksums))
+      .getPartitionLengths();
   }
 
   private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 0307027..a82f691 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -41,7 +41,7 @@ import org.apache.spark.memory.TooLargePageException;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.FileSegment;
@@ -68,7 +68,7 @@ import org.apache.spark.util.Utils;
  * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
  * specialized merge procedure that avoids extra serialization/deserialization.
  */
-final class ShuffleExternalSorter extends MemoryConsumer {
+final class ShuffleExternalSorter extends MemoryConsumer implements ShuffleChecksumSupport {
 
   private static final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
 
@@ -139,12 +139,11 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     this.peakMemoryUsedBytes = getMemoryUsage();
     this.diskWriteBufferSize =
         (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
-    this.partitionChecksums =
-      ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
+    this.partitionChecksums = createPartitionChecksums(numPartitions, conf);
   }
 
   public long[] getChecksums() {
-    return ShuffleChecksumHelper.getChecksumValues(partitionChecksums);
+    return getChecksumValues(partitionChecksums);
   }
 
   /**
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 2659b17..b1779a1 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -45,6 +45,7 @@ import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
@@ -57,7 +58,6 @@ import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 39c526c..60ba3aa 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1372,7 +1372,7 @@ package object config {
     ConfigBuilder("spark.shuffle.checksum.enabled")
       .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " +
         "its best to tell if shuffle data corruption is caused by network or disk or others.")
-      .version("3.3.0")
+      .version("3.2.0")
       .booleanConf
       .createWithDefault(true)
 
@@ -1380,7 +1380,7 @@ package object config {
     ConfigBuilder("spark.shuffle.checksum.algorithm")
       .doc("The algorithm used to calculate the checksum. Currently, it only supports" +
         " built-in algorithms of JDK.")
-      .version("3.3.0")
+      .version("3.2.0")
       .stringConf
       .transform(_.toUpperCase(Locale.ROOT))
       .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " +
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index cafb39e..8917734 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -22,12 +22,21 @@ import scala.reflect.ClassTag
 import org.apache.spark.TaskContext
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.StreamCallbackWithID
+import org.apache.spark.network.shuffle.checksum.Cause
 import org.apache.spark.storage.{BlockId, StorageLevel}
 
 private[spark]
 trait BlockDataManager {
 
   /**
+   * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums
+   */
+  def diagnoseShuffleBlockCorruption(
+      blockId: BlockId,
+      checksumByReader: Long,
+      algorithm: String): Cause
+
+  /**
    * Get the local directories that used by BlockManager to save the blocks to disk
    */
   def getLocalDiskDirs: Array[String]
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 5f831dc..81c878d 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -133,6 +133,13 @@ class NettyBlockRpcServer(
               Map(actualExecId -> blockManager.getLocalDiskDirs).asJava).toByteBuffer)
           }
         }
+
+      case diagnose: DiagnoseCorruption =>
+        val cause = blockManager.diagnoseShuffleBlockCorruption(
+          ShuffleBlockId(diagnose.shuffleId, diagnose.mapId, diagnose.reduceId),
+          diagnose.checksum,
+          diagnose.algorithm)
+        responseContext.onSuccess(new CorruptionCause(cause).toByteBuffer)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 4e0beea..6da0cb4 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -61,7 +61,6 @@ private[spark] class NettyBlockTransferService(
   // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
   private val serializer = new JavaSerializer(conf)
   private val authEnabled = securityManager.isAuthenticationEnabled()
-  private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
 
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
@@ -70,6 +69,7 @@ private[spark] class NettyBlockTransferService(
     val rpcHandler = new NettyBlockRpcServer(conf.getAppId, serializer, blockDataManager)
     var serverBootstrap: Option[TransportServerBootstrap] = None
     var clientBootstrap: Option[TransportClientBootstrap] = None
+    this.transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numCores)
     if (authEnabled) {
       serverBootstrap = Some(new AuthServerBootstrap(transportConf, securityManager))
       clientBootstrap = Some(new AuthClientBootstrap(transportConf, conf.getAppId, securityManager))
@@ -78,6 +78,7 @@ private[spark] class NettyBlockTransferService(
     clientFactory = transportContext.createClientFactory(clientBootstrap.toSeq.asJava)
     server = createServer(serverBootstrap.toList)
     appId = conf.getAppId
+
     logger.info(s"Server created on $hostName:${server.getPort}")
   }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 818aa2e..df06b07 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -83,6 +83,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.SHUFFLE_MAX_ATTEMPTS_ON_NETTY_OOM),
       SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
       SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ENABLED),
+      SparkEnv.get.conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM),
       readMetrics,
       fetchContinuousBlocksInBatch).toCompletionIterator
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 07928f8..7454a74 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -31,9 +31,9 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta}
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -157,7 +157,7 @@ private[spark] class IndexShuffleBlockResolver(
       logWarning(s"Error deleting index ${file.getPath()}")
     }
 
-    file = getChecksumFile(shuffleId, mapId)
+    file = getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
     if (file.exists() && !file.delete()) {
       logWarning(s"Error deleting checksum ${file.getPath()}")
     }
@@ -339,7 +339,8 @@ private[spark] class IndexShuffleBlockResolver(
     val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) {
       assert(lengths.length == checksums.length,
         "The size of partition lengths and checksums should be equal")
-      val checksumFile = getChecksumFile(shuffleId, mapId)
+      val checksumFile =
+        getChecksumFile(shuffleId, mapId, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
       (Some(checksumFile), Some(Utils.tempFileWith(checksumFile)))
     } else {
       (None, None)
@@ -540,14 +541,13 @@ private[spark] class IndexShuffleBlockResolver(
   def getChecksumFile(
       shuffleId: Int,
       mapId: Long,
+      algorithm: String,
       dirs: Option[Array[String]] = None): File = {
     val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
-    val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf)
+    val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId.name, algorithm)
     dirs
       .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName))
-      .getOrElse {
-        blockManager.diskBlockManager.getFile(fileName)
-      }
+      .getOrElse(blockManager.diskBlockManager.getFile(fileName))
   }
 
   override def getBlockData(
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index ce53f08..e450129 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -94,7 +94,7 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
 }
 
-@Since("3.3.0")
+@Since("3.2.0")
 @DeveloperApi
 case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum"
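(Illustration, under the assumption -- suggested by the file-suffix check in
IndexShuffleBlockResolverSuite further down -- that ShuffleChecksumHelper.getChecksumFileName
simply appends the algorithm name as an extra extension to the block id name above.)

    // Hypothetical example of the resulting names; the exact format is owned by
    // ShuffleChecksumHelper.getChecksumFileName.
    val blockId = ShuffleChecksumBlockId(shuffleId = 0, mapId = 5L, reduceId = 0)
    blockId.name  // "shuffle_0_5_0.checksum"
    // ShuffleChecksumHelper.getChecksumFileName(blockId.name, "ADLER32")
    //   would then presumably yield "shuffle_0_5_0.checksum.ADLER32"
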
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index b81b3b6..4c646b2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -49,12 +49,13 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.{MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter}
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, MigratableResolver, ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.storage.BlockManagerMessages.{DecommissionBlockManager, ReplicateBlock}
 import org.apache.spark.storage.memory._
 import org.apache.spark.unsafe.Platform
@@ -283,6 +284,28 @@ private[spark] class BlockManager(
   override def getLocalDiskDirs: Array[String] = diskBlockManager.localDirsString
 
   /**
+   * Diagnose the possible cause of the shuffle data corruption by verifying the shuffle checksums
+   *
+   * @param blockId The blockId of the corrupted shuffle block
+   * @param checksumByReader The checksum value of the corrupted block
+   * @param algorithm The checksum algorithm that is used when calculating the checksum value
+   */
+  override def diagnoseShuffleBlockCorruption(
+      blockId: BlockId,
+      checksumByReader: Long,
+      algorithm: String): Cause = {
+    assert(blockId.isInstanceOf[ShuffleBlockId],
+      s"Corruption diagnosis only supports shuffle blocks for now, but got $blockId")
+    val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId]
+    val resolver = shuffleManager.shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver]
+    val checksumFile =
+      resolver.getChecksumFile(shuffleBlock.shuffleId, shuffleBlock.mapId, algorithm)
+    val reduceId = shuffleBlock.reduceId
+    ShuffleChecksumHelper.diagnoseCorruption(
+      algorithm, checksumFile, reduceId, resolver.getBlockData(shuffleBlock), checksumByReader)
+  }
+
+  /**
    * Abstraction for storing blocks from bytes, whether they start in memory or on disk.
    *
    * @param blockSize the decrypted size of the block
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index fd87f5e..3eb8acd 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -21,6 +21,7 @@ import java.io.{InputStream, IOException}
 import java.nio.channels.ClosedByInterruptException
 import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.zip.CheckedInputStream
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -36,6 +37,7 @@ import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
 import org.apache.spark.network.util.{NettyUtils, TransportConf}
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
 import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
@@ -69,7 +71,11 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
  * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
  * @param maxAttemptsOnNettyOOM The max number of a block could retry due to Netty OOM before
  *                              throwing the fetch failure.
- * @param detectCorrupt whether to detect any corruption in fetched blocks.
+ * @param detectCorrupt         whether to detect any corruption in fetched blocks.
+ * @param checksumEnabled whether the shuffle checksum is enabled. When enabled, Spark will try to
+ *                        diagnose the cause of the block corruption.
+ * @param checksumAlgorithm the checksum algorithm that is used when calculating the checksum value
+ *                         for the block data.
  * @param shuffleMetrics used to report shuffle metrics.
  * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server
  *                     side supports.
@@ -89,6 +95,8 @@ final class ShuffleBlockFetcherIterator(
     maxAttemptsOnNettyOOM: Int,
     detectCorrupt: Boolean,
     detectCorruptUseExtraMemory: Boolean,
+    checksumEnabled: Boolean,
+    checksumAlgorithm: String,
     shuffleMetrics: ShuffleReadMetricsReporter,
     doBatchFetch: Boolean)
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
@@ -732,6 +740,8 @@ final class ShuffleBlockFetcherIterator(
 
     var result: FetchResult = null
     var input: InputStream = null
+    // This is only initialized when the shuffle checksum is enabled.
+    var checkedIn: CheckedInputStream = null
     var streamCompressedOrEncrypted: Boolean = false
     // Take the next fetched result and try to decompress it to detect data corruption,
     // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
@@ -787,7 +797,14 @@ final class ShuffleBlockFetcherIterator(
           }
 
           val in = try {
-            buf.createInputStream()
+            var bufIn = buf.createInputStream()
+            if (checksumEnabled) {
+              val checksum = ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+              checkedIn = new CheckedInputStream(bufIn, checksum)
+              checkedIn
+            } else {
+              bufIn
+            }
           } catch {
             // The exception could only be throwed by local shuffle block
             case e: IOException =>
@@ -822,8 +839,15 @@ final class ShuffleBlockFetcherIterator(
               }
             } catch {
               case e: IOException =>
-                buf.release()
+                // When shuffle checksum is enabled, for a block that is corrupted twice,
+                // we'd calculate the checksum of the block by consuming the remaining data
+                // in the buf. So, we should release the buf later.
+                if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
+                  buf.release()
+                }
+
                 if (blockId.isShuffleChunk) {
+                  // TODO (SPARK-36284): Add shuffle checksum support for push-based shuffle
                   // Retrying a corrupt block may result again in a corrupt block. For shuffle
                   // chunks, we opt to fallback on the original shuffle blocks that belong to that
                   // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt
@@ -834,17 +858,27 @@ final class ShuffleBlockFetcherIterator(
                   pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
                   // Set result to null to trigger another iteration of the while loop.
                   result = null
-                } else {
-                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                    || corruptedBlocks.contains(blockId)) {
-                    throwFetchFailedException(blockId, mapIndex, address, e)
+                } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                } else if (corruptedBlocks.contains(blockId)) {
+                  // This is the second time this block has been detected as corrupted
+                  if (checksumEnabled) {
+                    // Diagnose the cause of data corruption if shuffle checksum is enabled
+                    val diagnosisResponse = diagnoseCorruption(checkedIn, address, blockId)
+                    buf.release()
+                    logError(diagnosisResponse)
+                    throwFetchFailedException(
+                      blockId, mapIndex, address, e, Some(diagnosisResponse))
                   } else {
-                    logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                    corruptedBlocks += blockId
-                    fetchRequests += FetchRequest(
-                      address, Array(FetchBlockInfo(blockId, size, mapIndex)))
-                    result = null
+                    throwFetchFailedException(blockId, mapIndex, address, e)
                   }
+                } else {
+                  // This is the first time this block has been detected as corrupted
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
                 }
             } finally {
               if (blockId.isShuffleChunk) {
@@ -975,7 +1009,66 @@ final class ShuffleBlockFetcherIterator(
         currentResult.mapIndex,
         currentResult.address,
         detectCorrupt && streamCompressedOrEncrypted,
-        currentResult.isNetworkReqDone))
+        currentResult.isNetworkReqDone,
+        Option(checkedIn)))
+  }
+
+  /**
+   * Get the suspected corruption cause for the corrupted block. It should only be invoked
+   * when checksum is enabled and corruption was detected at least once.
+   *
+   * This will first consume the rest of the stream of the corrupted block to calculate the
+   * checksum of the block. Then, it will issue a synchronous RPC call, carrying that
+   * checksum, to ask the server (where the corrupted block was fetched from) to diagnose the
+   * cause of the corruption and return it.
+   *
+   * Any exception raised during the process will result in [[Cause.UNKNOWN_ISSUE]] as the
+   * corruption cause, since corruption diagnosis is only a best effort.
+   *
+   * @param checkedIn the [[CheckedInputStream]] which is used to calculate the checksum.
+   * @param address the address where the corrupted block is fetched from.
+   * @param blockId the blockId of the corrupted block.
+   * @return The corruption diagnosis response for different causes.
+   */
+  private[storage] def diagnoseCorruption(
+      checkedIn: CheckedInputStream,
+      address: BlockManagerId,
+      blockId: BlockId): String = {
+    logInfo("Start corruption diagnosis.")
+    val startTimeNs = System.nanoTime()
+    assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId")
+    val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId]
+    val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER)
+    // consume the remaining data to calculate the checksum
+    var cause: Cause = null
+    try {
+      while (checkedIn.read(buffer) != -1) {}
+      val checksum = checkedIn.getChecksum.getValue
+      cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId,
+        shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum,
+        checksumAlgorithm)
+    } catch {
+      case e: Exception =>
+        logWarning("Unable to diagnose the corruption cause of the corrupted block", e)
+        cause = Cause.UNKNOWN_ISSUE
+    }
+    val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
+    val diagnosisResponse = cause match {
+      case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM =>
+        s"Block $blockId is corrupted but corruption diagnosis failed due to " +
+          s"unsupported checksum algorithm: $checksumAlgorithm"
+
+      case Cause.CHECKSUM_VERIFY_PASS =>
+        s"Block $blockId is corrupted but checksum verification passed"
+
+      case Cause.UNKNOWN_ISSUE =>
+        s"Block $blockId is corrupted but the cause is unknown"
+
+      case otherCause =>
+        s"Block $blockId is corrupted due to $otherCause"
+    }
+    logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse")
+    diagnosisResponse
   }
 
   def toCompletionIterator: Iterator[(BlockId, InputStream)] = {
@@ -1158,7 +1251,8 @@ private class BufferReleasingInputStream(
     private val mapIndex: Int,
     private val address: BlockManagerId,
     private val detectCorruption: Boolean,
-    private val isNetworkReqDone: Boolean)
+    private val isNetworkReqDone: Boolean,
+    private val checkedInOpt: Option[CheckedInputStream])
   extends InputStream {
   private[this] var closed = false
 
@@ -1207,8 +1301,13 @@ private class BufferReleasingInputStream(
       block
     } catch {
       case e: IOException if detectCorruption =>
+        val diagnosisResponse = checkedInOpt.map { checkedIn =>
+          iterator.diagnoseCorruption(checkedIn, address, blockId)
+        }
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, mapIndex, address, e)
+        // We never retry the block, whatever the cause is, since the block has been
+        // partially consumed by downstream RDDs.
+        iterator.throwFetchFailedException(blockId, mapIndex, address, e, diagnosisResponse)
     }
   }
 }
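
(As a standalone sketch of the CheckedInputStream pattern that diagnoseCorruption relies on
above -- plain JDK API, nothing Spark-specific assumed: the checksum accumulates as bytes pass
through the wrapper, so draining whatever is left of the stream yields the checksum of every
byte the reader has seen for that block.)

    import java.io.ByteArrayInputStream
    import java.util.zip.{Adler32, CheckedInputStream}

    // Wrap any input stream; every byte read through the wrapper updates the checksum.
    val data = Array.tabulate[Byte](1024)(i => (i % 7).toByte)
    val checkedIn = new CheckedInputStream(new ByteArrayInputStream(data), new Adler32())

    // Drain the stream with a fixed-size buffer (diagnoseCorruption does the same),
    // then read back the accumulated checksum value to send to the server.
    val buffer = new Array[Byte](256)
    while (checkedIn.read(buffer) != -1) {}
    val checksumByReader = checkedIn.getChecksum.getValue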
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c63e196..eda408a 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
 import org.apache.spark.shuffle.ShufflePartitionPairsWriter
 import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
+import org.apache.spark.shuffle.checksum.ShuffleChecksumSupport
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.util.{CompletionIterator, Utils => TryUtils}
 
@@ -97,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C](
     ordering: Option[Ordering[K]] = None,
     serializer: Serializer = SparkEnv.get.serializer)
   extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager())
-  with Logging {
+  with Logging with ShuffleChecksumSupport {
 
   private val conf = SparkEnv.get.conf
 
@@ -142,10 +142,9 @@ private[spark] class ExternalSorter[K, V, C](
   private val forceSpillFiles = new ArrayBuffer[SpilledFile]
   @volatile private var readingIterator: SpillableIterator = null
 
-  private val partitionChecksums =
-    ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)
+  private val partitionChecksums = createPartitionChecksums(numPartitions, conf)
 
-  def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+  def getChecksums: Array[Long] = getChecksumValues(partitionChecksums)
 
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 63220ed..87f9ab3 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -23,8 +23,8 @@ import java.nio.file.Files;
 import java.util.*;
 
 import org.apache.spark.*;
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.shuffle.ShuffleChecksumTestHelper;
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.mockito.stubbing.Answer;
 import scala.*;
 import scala.collection.Iterator;
@@ -38,11 +38,11 @@ import org.mockito.MockitoAnnotations;
 
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.internal.config.package$;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.LZ4CompressionCodec;
 import org.apache.spark.io.LZFCompressionCodec;
 import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.internal.config.package$;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
@@ -301,12 +301,13 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
     IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
     ShuffleChecksumBlockId checksumBlockId =
       new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
-    File checksumFile = new File(tempDir,
-      ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+    String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+    String checksumFileName = ShuffleChecksumHelper.getChecksumFileName(
+      checksumBlockId.name(), checksumAlgorithm);
+    File checksumFile = new File(tempDir, checksumFileName);
     File dataFile = new File(tempDir, "data");
     File indexFile = new File(tempDir, "index");
-    when(diskBlockManager.getFile(checksumFile.getName()))
-      .thenReturn(checksumFile);
+    when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile);
     when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
       .thenReturn(dataFile);
     when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
@@ -322,7 +323,7 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
     writer1.stop(true);
     assertTrue(checksumFile.exists());
     assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
-    compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+    compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile);
   }
 
   @Test
@@ -330,11 +331,13 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
     IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
     ShuffleChecksumBlockId checksumBlockId =
       new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
-    File checksumFile =
-      new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+    String checksumAlgorithm = conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+    String checksumFileName = ShuffleChecksumHelper.getChecksumFileName(
+      checksumBlockId.name(), checksumAlgorithm);
+    File checksumFile = new File(tempDir, checksumFileName);
     File dataFile = new File(tempDir, "data");
     File indexFile = new File(tempDir, "index");
-    when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile);
+    when(diskBlockManager.getFile(checksumFileName)).thenReturn(checksumFile);
     when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
       .thenReturn(dataFile);
     when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
@@ -356,7 +359,7 @@ public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
     writer1.closeAndWriteOutput();
     assertTrue(checksumFile.exists());
     assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
-    compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+    compareChecksums(NUM_PARTITIONS, checksumAlgorithm, checksumFile, dataFile, indexFile);
   }
 
   private void testMergingSpills(
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index fd75d91..c1a964c 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import java.io.File
+import java.io.{File, RandomAccessFile}
 import java.util.{Locale, Properties}
 import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService }
 
@@ -447,6 +447,41 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalRootDi
       }
     }
   }
+
+  test("SPARK-36206: shuffle checksum detect disk corruption") {
+    val newConf = conf.clone
+      .set(config.SHUFFLE_CHECKSUM_ENABLED, true)
+      .set(TEST_NO_STAGE_RETRY, false)
+      .set("spark.stage.maxConsecutiveAttempts", "1")
+    sc = new SparkContext("local-cluster[2, 1, 2048]", "test", newConf)
+    val rdd = sc.parallelize(1 to 10, 2).map((_, 1)).reduceByKey(_ + _)
+    // materialize the shuffle map outputs
+    rdd.count()
+
+    sc.parallelize(1 to 10, 2).barrier().mapPartitions { iter =>
+      var dataFile = SparkEnv.get.blockManager
+        .diskBlockManager.getFile(ShuffleDataBlockId(0, 0, 0))
+      if (!dataFile.exists()) {
+        dataFile = SparkEnv.get.blockManager
+          .diskBlockManager.getFile(ShuffleDataBlockId(0, 1, 0))
+      }
+
+      if (dataFile.exists()) {
+        val f = new RandomAccessFile(dataFile, "rw")
+        // corrupt the shuffle data files by writing some arbitrary bytes
+        f.seek(0)
+        f.write(Array[Byte](12))
+        f.close()
+      }
+      BarrierTaskContext.get().barrier()
+      iter
+    }.collect()
+
+    val e = intercept[SparkException] {
+      rdd.count()
+    }
+    assert(e.getMessage.contains("corrupted due to DISK_ISSUE"))
+  }
 }
 
 /**
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
index a8f2c40..3db2f77 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
@@ -20,15 +20,20 @@ package org.apache.spark.shuffle
 import java.io.{DataInputStream, File, FileInputStream}
 import java.util.zip.CheckedInputStream
 
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.network.util.LimitedInputStream
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 
 trait ShuffleChecksumTestHelper {
 
   /**
    * Ensure that the checksum values are consistent between write and read side.
    */
-  def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = {
+  def compareChecksums(
+      numPartition: Int,
+      algorithm: String,
+      checksum: File,
+      data: File,
+      index: File): Unit = {
     assert(checksum.exists(), "Checksum file doesn't exist")
     assert(data.exists(), "Data file doesn't exist")
     assert(index.exists(), "Index file doesn't exist")
@@ -55,7 +60,7 @@ trait ShuffleChecksumTestHelper {
         val curOffset = indexIn.readLong
         val limit = (curOffset - prevOffset).toInt
         val bytes = new Array[Byte](limit)
-        val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName)
+        val checksumCal = ShuffleChecksumHelper.getChecksumByAlgorithm(algorithm)
         checkedIn = new CheckedInputStream(
           new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal)
         checkedIn.read(bytes, 0, limit)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 39eef97..38ed702 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -31,11 +31,12 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.internal.config
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.network.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
-import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -248,12 +249,14 @@ class BypassMergeSortShuffleWriterSuite
     val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0)
     val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0)
     val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0)
-    val checksumFile = new File(tempDir,
-      ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf))
+    val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)
+    val checksumFileName = ShuffleChecksumHelper.getChecksumFileName(
+      checksumBlockId.name, checksumAlgorithm)
+    val checksumFile = new File(tempDir, checksumFileName)
     val dataFile = new File(tempDir, dataBlockId.name)
     val indexFile = new File(tempDir, indexBlockId.name)
     reset(diskBlockManager)
-    when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile)
+    when(diskBlockManager.getFile(checksumFileName)).thenAnswer(_ => checksumFile)
     when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile)
     when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile)
     when(diskBlockManager.createTempShuffleBlock())
@@ -277,6 +280,6 @@ class BypassMergeSortShuffleWriterSuite
     writer.stop( /* success = */ true)
     assert(checksumFile.exists())
     assert(checksumFile.length() === 8 * numPartition)
-    compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+    compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index abe2b56..21704b1 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -262,7 +262,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
     val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
     resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp)
-    val checksumFile = resolver.getChecksumFile(0, 0)
+    val checksumFile = resolver.getChecksumFile(0, 0, conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
     assert(checksumFile.exists())
     val checksumFileName = checksumFile.toString
     val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index e345736..6c13c7c 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.PrivateMethodTester
 import org.scalatest.matchers.must.Matchers
 
 import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite}
+import org.apache.spark.internal.config
 import org.apache.spark.memory.MemoryTestingUtils
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
@@ -165,12 +166,13 @@ class SortShuffleWriterSuite
       val expectSpillSize = if (doSpill) records.size else 0
       assert(sorter.numSpills === expectSpillSize)
       writer.stop(success = true)
-      val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0)
+      val checksumAlgorithm = conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM)
+      val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0, checksumAlgorithm)
       assert(checksumFile.exists())
       assert(checksumFile.length() === 8 * numPartition)
       val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0)
       val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0)
-      compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+      compareChecksums(numPartition, checksumAlgorithm, checksumFile, dataFile, indexFile)
       localSC.stop()
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index c22e1d0..8ed0098 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -21,11 +21,13 @@ import java.io._
 import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
+import java.util.zip.CheckedInputStream
 
 import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
+import com.google.common.io.ByteStreams
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.log4j.Level
 import org.mockito.ArgumentMatchers.{any, eq => meq}
@@ -157,14 +159,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
     verify(buffer, times(0)).release()
     val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
-
-    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
+    var in = wrappedInputStream.invokePrivate(delegateAccess())
+    if (in.isInstanceOf[CheckedInputStream]) {
+      val underlyingInputField = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
+      underlyingInputField.setAccessible(true)
+      in = underlyingInputField.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream]
+    }
+    verify(in, times(0)).close()
     wrappedInputStream.close()
     verify(buffer, times(1)).release()
-    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+    verify(in, times(1)).close()
     wrappedInputStream.close() // close should be idempotent
     verify(buffer, times(1)).release()
-    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+    verify(in, times(1)).close()
   }
 
   // scalastyle:off argcount
@@ -180,6 +187,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       maxAttemptsOnNettyOOM: Int = 10,
       detectCorrupt: Boolean = true,
       detectCorruptUseExtraMemory: Boolean = true,
+      checksumEnabled: Boolean = true,
+      checksumAlgorithm: String = "ADLER32",
       shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
       doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
     val tContext = taskContext.getOrElse(TaskContext.empty())
@@ -197,6 +206,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       maxAttemptsOnNettyOOM,
       detectCorrupt,
       detectCorruptUseExtraMemory,
+      checksumEnabled,
+      checksumAlgorithm,
       shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),
       doBatchFetch)
   }
@@ -213,6 +224,69 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq
   }
 
+  test("SPARK-36206: diagnose the block when it's corrupted twice") {
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()
+    )
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      listener.onBlockFetchSuccess(ShuffleBlockId(0, 0, 0).toString, mockCorruptBuffer())
+    }
+
+    val logAppender = new LogAppender("diagnose corruption")
+    withLogAppender(logAppender) {
+      val iterator = createShuffleBlockIteratorWithDefaults(
+        Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+        streamWrapperLimitSize = Some(100)
+      )
+      intercept[FetchFailedException](iterator.next())
+      // The block will be fetched twice due to retry
+      verify(transfer, times(2))
+        .fetchBlocks(any(), any(), any(), any(), any(), any())
+      // only diagnose once
+      assert(logAppender.loggingEvents.count(
+        _.getRenderedMessage.contains("Start corruption diagnosis")) === 1)
+    }
+  }
+
+  test("SPARK-36206: diagnose the block when it's corrupted " +
+    "inside BufferReleasingInputStream") {
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()
+    )
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      listener.onBlockFetchSuccess(
+        ShuffleBlockId(0, 0, 0).toString,
+        mockCorruptBuffer(100, 50))
+    }
+
+    val logAppender = new LogAppender("diagnose corruption")
+    withLogAppender(logAppender) {
+      val iterator = createShuffleBlockIteratorWithDefaults(
+        Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+        streamWrapperLimitSize = Some(100),
+        maxBytesInFlight = 100
+      )
+      intercept[FetchFailedException] {
+        val inputStream = iterator.next()._2
+        // Consume the data to trigger the corruption
+        ByteStreams.readFully(inputStream, new Array[Byte](100))
+      }
+      // The block will be fetched only once because the corruption can't be detected
+      // within the first maxBytesInFlight/3 bytes of the data.
+      verify(transfer, times(1))
+        .fetchBlocks(any(), any(), any(), any(), any(), any())
+      // only diagnose once
+      assert(logAppender.loggingEvents.exists(
+        _.getRenderedMessage.contains("Start corruption diagnosis")))
+    }
+  }
+
   test("successful 3 local + 4 host local + 2 remote reads") {
     val blockManager = createMockBlockManager()
     val localBmId = blockManager.blockManagerId
