Posted to commits@spark.apache.org by mr...@apache.org on 2021/07/17 05:24:27 UTC

[spark] branch branch-3.2 updated: [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new d5022c3  [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
d5022c3 is described below

commit d5022c3c6f73c014b3ec7535c78b1c70a2f03941
Author: yi.wu <yi...@databricks.com>
AuthorDate: Sat Jul 17 00:23:14 2021 -0500

    [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
    
    ### What changes were proposed in this pull request?
    
    This is the initial work to add checksum support for shuffle. It is a piece of https://github.com/apache/spark/pull/32385 and only adds the checksum functionality on the shuffle writer side.
    
    Basically, the idea is to wrap a `MutableCheckedOutputStream`* around the `FileOutputStream` while the shuffle writer generates the shuffle data. The exact wrapping place differs among the shuffle writers due to their different implementations:
    
    * `BypassMergeSortShuffleWriter` - wrap on each partition file
    * `UnsafeShuffleWriter` - wrap on each spill file directly, since the spills don't require aggregation or sorting
    * `SortShuffleWriter` - wrap on the `ShufflePartitionPairsWriter` after merging the spill files, since the records might require aggregation or sorting
    
    \* `MutableCheckedOutputStream` is a variant of `java.util.zip.CheckedOutputStream` which can change the checksum calculator at runtime.
    
    By default we use `Adler32`, which serves the same purpose as CRC-32 but is much faster, to calculate the checksum, the same choice as `Broadcast`'s checksum.
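    
    To make the wrapping concrete, here is a minimal usage sketch (illustration only, not code from this patch) built on the `MutableCheckedOutputStream` added below; `partitionBytes` and the file name are hypothetical:
    
    ```scala
    import java.io.{BufferedOutputStream, FileOutputStream}
    import java.util.zip.Adler32
    
    import org.apache.spark.io.MutableCheckedOutputStream
    
    // Sketch only: every byte that reaches the partition file also updates the checksum.
    val checksum = new Adler32()
    val checked = new MutableCheckedOutputStream(new FileOutputStream("partition-0.data"))
    checked.setChecksum(checksum)
    
    val out = new BufferedOutputStream(checked)
    out.write(partitionBytes)  // partitionBytes: hypothetical Array[Byte] of serialized records
    out.close()                // flushes the buffer through to the file
    
    // The long value stored (one per partition) in the shuffle checksum file:
    val partitionChecksum: Long = checksum.getValue
    ```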
    
    ### Why are the changes needed?
    
    Shuffle data corruption is hard to diagnose. With a checksum written for each partition of the shuffle output, Spark can try to tell whether a corruption was caused by the disk, the network, or something else when a corrupted block is detected (see the umbrella PR linked above).
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds two new confs: `spark.shuffle.checksum.enabled` (default: `true`) and `spark.shuffle.checksum.algorithm` (default: `ADLER32`).
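    
    For illustration, a minimal sketch of setting the new confs (keys and defaults taken from the `config/package.scala` changes below):
    
    ```scala
    import org.apache.spark.SparkConf
    
    // Shuffle checksums are on by default; the algorithm can be switched to CRC32.
    val conf = new SparkConf()
      .set("spark.shuffle.checksum.enabled", "true")    // default: true
      .set("spark.shuffle.checksum.algorithm", "CRC32") // default: ADLER32
    ```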
    
    ### How was this patch tested?
    
    Added unit tests.
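    
    For context, the new tests compare the stored checksums against checksums recomputed from the shuffle data file. The following is a rough sketch of that verification under assumed file layouts (index file = `numPartitions + 1` offsets, checksum file = one long per partition, Adler32 as the default algorithm); it is not the actual `ShuffleChecksumTestHelper` code:
    
    ```scala
    import java.io.{DataInputStream, EOFException, FileInputStream}
    import java.util.zip.Adler32
    
    // Rough sketch: recompute each partition's checksum from the data file, using the
    // offsets in the index file, and compare it with the long stored in the checksum file.
    def verifyChecksums(numPartitions: Int, dataFile: String,
                        indexFile: String, checksumFile: String): Boolean = {
      val index = new DataInputStream(new FileInputStream(indexFile))
      val stored = new DataInputStream(new FileInputStream(checksumFile))
      val data = new FileInputStream(dataFile)
      try {
        var ok = true
        var prevOffset = index.readLong()  // first offset is 0
        for (_ <- 0 until numPartitions) {
          val offset = index.readLong()
          val bytes = new Array[Byte]((offset - prevOffset).toInt)
          prevOffset = offset
          var read = 0
          while (read < bytes.length) {
            val n = data.read(bytes, read, bytes.length - read)
            if (n < 0) throw new EOFException("data file shorter than the index implies")
            read += n
          }
          val recomputed = new Adler32()
          recomputed.update(bytes, 0, bytes.length)
          val expected = stored.readLong()
          ok = ok && recomputed.getValue == expected
        }
        ok
      } finally {
        index.close(); stored.close(); data.close()
      }
    }
    ```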
    
    Closes #32401 from Ngone51/add-checksum-files.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
    (cherry picked from commit 4783fb72aff4a550ca0cc680dcc0d730e3e36dac)
    Signed-off-by: Mridul Muralidharan <mridulatgmail.com>
---
 .../spark/shuffle/api/ShuffleMapOutputWriter.java  |   9 +-
 .../spark/shuffle/api/ShufflePartitionWriter.java  |   6 +-
 .../api/SingleSpillShuffleMapOutputWriter.java     |   5 +-
 .../shuffle/checksum/ShuffleChecksumHelper.java    | 100 +++++++++++
 .../shuffle/sort/BypassMergeSortShuffleWriter.java |  24 ++-
 .../spark/shuffle/sort/ShuffleExternalSorter.java  |  17 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java    |  15 +-
 .../sort/io/LocalDiskShuffleMapOutputWriter.java   |   5 +-
 .../io/LocalDiskSingleSpillMapOutputWriter.java    |   6 +-
 .../org/apache/spark/internal/config/package.scala |  19 ++
 .../spark/io/MutableCheckedOutputStream.scala}     |  39 +++--
 .../spark/shuffle/IndexShuffleBlockResolver.scala  | 195 +++++++++++++++++----
 .../shuffle/ShufflePartitionPairsWriter.scala      |  15 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala     |   2 +-
 .../scala/org/apache/spark/storage/BlockId.scala   |   6 +
 .../spark/storage/DiskBlockObjectWriter.scala      |  27 ++-
 .../spark/util/collection/ExternalSorter.scala     |  12 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java     | 104 +++++++++--
 .../spark/shuffle/ShuffleChecksumTestHelper.scala  |  78 +++++++++
 .../sort/BypassMergeSortShuffleWriterSuite.scala   |  53 +++++-
 .../sort/IndexShuffleBlockResolverSuite.scala      |  24 ++-
 .../shuffle/sort/SortShuffleWriterSuite.scala      |  87 ++++++++-
 .../io/LocalDiskShuffleMapOutputWriterSuite.scala  |   9 +-
 .../util/collection/ExternalSorterSuite.scala      |  14 +-
 project/MimaExcludes.scala                         |  10 +-
 25 files changed, 762 insertions(+), 119 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 0167002..2237ec0 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -59,6 +59,10 @@ public interface ShuffleMapOutputWriter {
    * available to downstream reduce tasks. If this method throws any exception, this module's
    * {@link #abort(Throwable)} method will be invoked before propagating the exception.
    * <p>
+   * Shuffle extensions which care about the cause of shuffle data corruption should store
+   * the checksums properly. When corruption happens, Spark would provide the checksum
+   * of the fetched partition to the shuffle extension to help diagnose the cause of corruption.
+   * <p>
    * This can also close any resources and clean up temporary state if necessary.
    * <p>
    * The returned commit message is a structure with two components:
@@ -68,8 +72,11 @@ public interface ShuffleMapOutputWriter {
    *    for that partition id.
    * <p>
    * 2) An optional metadata blob that can be used by shuffle readers.
+   *
+   * @param checksums The checksum values for each partition (where checksum index is equivalent to
+   *                  partition id) if shuffle checksum enabled. Otherwise, it's empty.
    */
-  MapOutputCommitMessage commitAllPartitions() throws IOException;
+  MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException;
 
   /**
    * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
index 9288751..143cc6c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -49,7 +49,7 @@ public interface ShufflePartitionWriter {
    * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
    * {@link OutputStream#close()} does not close the resource, since it will be reused across
    * partition writes. The underlying resources should be cleaned up in
-   * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+   * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
    * {@link ShuffleMapOutputWriter#abort(Throwable)}.
    */
   OutputStream openStream() throws IOException;
@@ -68,7 +68,7 @@ public interface ShufflePartitionWriter {
    * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
    * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
    * will be reused across partition writes. The underlying resources should be cleaned up in
-   * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+   * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
    * {@link ShuffleMapOutputWriter#abort(Throwable)}.
    * <p>
    * This method is primarily for advanced optimizations where bytes can be copied from the input
@@ -79,7 +79,7 @@ public interface ShufflePartitionWriter {
    * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
    * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
    * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
-   * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
+   * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])}, or
    * {@link ShuffleMapOutputWriter#abort(Throwable)}.
    */
   default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
index cad8dcf..ba3d5a6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter {
   /**
    * Transfer a file that contains the bytes of all the partitions written by this map task.
    */
-  void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+  void transferMapSpillFile(
+      File mapOutputFile,
+      long[] partitionLengths,
+      long[] checksums) throws IOException;
 }
diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
new file mode 100644
index 0000000..a368836
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum;
+
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.Checksum;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.storage.ShuffleChecksumBlockId;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+
+  /** Used when the checksum is disabled for shuffle. */
+  private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+  public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+  public static boolean isShuffleChecksumEnabled(SparkConf conf) {
+    return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED());
+  }
+
+  public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf)
+    throws SparkException {
+    if (!isShuffleChecksumEnabled(conf)) {
+      return EMPTY_CHECKSUM;
+    }
+
+    String checksumAlgo = shuffleChecksumAlgorithm(conf);
+    return getChecksumByAlgorithm(numPartitions, checksumAlgo);
+  }
+
+  private static Checksum[] getChecksumByAlgorithm(int num, String algorithm)
+      throws SparkException {
+    Checksum[] checksums;
+    switch (algorithm) {
+      case "ADLER32":
+        checksums = new Adler32[num];
+        for (int i = 0; i < num; i ++) {
+          checksums[i] = new Adler32();
+        }
+        return checksums;
+
+      case "CRC32":
+        checksums = new CRC32[num];
+        for (int i = 0; i < num; i ++) {
+          checksums[i] = new CRC32();
+        }
+        return checksums;
+
+      default:
+        throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm);
+    }
+  }
+
+  public static long[] getChecksumValues(Checksum[] partitionChecksums) {
+    int numPartitions = partitionChecksums.length;
+    long[] checksumValues = new long[numPartitions];
+    for (int i = 0; i < numPartitions; i ++) {
+      checksumValues[i] = partitionChecksums[i].getValue();
+    }
+    return checksumValues;
+  }
+
+  public static String shuffleChecksumAlgorithm(SparkConf conf) {
+    return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+  }
+
+  public static Checksum getChecksumByFileExtension(String fileName) throws SparkException {
+    int index = fileName.lastIndexOf(".");
+    String algorithm = fileName.substring(index + 1);
+    return getChecksumByAlgorithm(1, algorithm)[0];
+  }
+
+  public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) {
+    // append the shuffle checksum algorithm as the file extension
+    return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf));
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 3dbee1b..3222240 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -23,6 +23,7 @@ import java.io.IOException;
 import java.io.OutputStream;
 import java.nio.channels.FileChannel;
 import java.util.Optional;
+import java.util.zip.Checksum;
 import javax.annotation.Nullable;
 
 import scala.None$;
@@ -38,6 +39,7 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
 import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -49,6 +51,7 @@ import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -93,6 +96,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private FileSegment[] partitionWriterSegments;
   @Nullable private MapStatus mapStatus;
   private long[] partitionLengths;
+  /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
+  private final Checksum[] partitionChecksums;
 
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -107,7 +112,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       long mapId,
       SparkConf conf,
       ShuffleWriteMetricsReporter writeMetrics,
-      ShuffleExecutorComponents shuffleExecutorComponents) {
+      ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -120,6 +125,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleExecutorComponents = shuffleExecutorComponents;
+    this.partitionChecksums =
+      ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
   }
 
   @Override
@@ -129,7 +136,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         .createMapOutputWriter(shuffleId, mapId, numPartitions);
     try {
       if (!records.hasNext()) {
-        partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
+        partitionLengths = mapOutputWriter.commitAllPartitions(
+          ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
         mapStatus = MapStatus$.MODULE$.apply(
           blockManager.shuffleServerId(), partitionLengths, mapId);
         return;
@@ -143,8 +151,12 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
             blockManager.diskBlockManager().createTempShuffleBlock();
         final File file = tempShuffleBlockIdPlusFile._2();
         final BlockId blockId = tempShuffleBlockIdPlusFile._1();
-        partitionWriters[i] =
-            blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+        DiskBlockObjectWriter writer =
+          blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+        if (partitionChecksums.length > 0) {
+          writer.setChecksum(partitionChecksums[i]);
+        }
+        partitionWriters[i] = writer;
       }
       // Creating the file to write to and creating a disk writer both involve interacting with
       // the disk, and can take a long time in aggregate when we open many files, so should be
@@ -218,7 +230,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       }
       partitionWriters = null;
     }
-    return mapOutputWriter.commitAllPartitions().getPartitionLengths();
+    return mapOutputWriter.commitAllPartitions(
+      ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+    ).getPartitionLengths();
   }
 
   private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 833744f..0307027 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -21,7 +21,9 @@ import javax.annotation.Nullable;
 import java.io.File;
 import java.io.IOException;
 import java.util.LinkedList;
+import java.util.zip.Checksum;
 
+import org.apache.spark.SparkException;
 import scala.Tuple2;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -39,6 +41,7 @@ import org.apache.spark.memory.TooLargePageException;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.FileSegment;
@@ -107,6 +110,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   @Nullable private MemoryBlock currentPage = null;
   private long pageCursor = -1;
 
+  // Checksum calculator for each partition. Empty when shuffle checksum disabled.
+  private final Checksum[] partitionChecksums;
+
   ShuffleExternalSorter(
       TaskMemoryManager memoryManager,
       BlockManager blockManager,
@@ -114,7 +120,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetricsReporter writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics) throws SparkException {
     super(memoryManager,
       (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
       memoryManager.getTungstenMemoryMode());
@@ -133,6 +139,12 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     this.peakMemoryUsedBytes = getMemoryUsage();
     this.diskWriteBufferSize =
         (int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
+    this.partitionChecksums =
+      ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
+  }
+
+  public long[] getChecksums() {
+    return ShuffleChecksumHelper.getChecksumValues(partitionChecksums);
   }
 
   /**
@@ -204,6 +216,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
             spillInfo.partitionLengths[currentPartition] = fileSegment.length();
           }
           currentPartition = partition;
+          if (partitionChecksums.length > 0) {
+            writer.setChecksum(partitionChecksums[currentPartition]);
+          }
         }
 
         final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e8f94ba..2659b17 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -57,6 +57,7 @@ import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.ShufflePartitionWriter;
 import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
 import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
@@ -115,7 +116,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       TaskContext taskContext,
       SparkConf sparkConf,
       ShuffleWriteMetricsReporter writeMetrics,
-      ShuffleExecutorComponents shuffleExecutorComponents) {
+      ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
@@ -198,7 +199,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     }
   }
 
-  private void open() {
+  private void open() throws SparkException {
     assert (sorter == null);
     sorter = new ShuffleExternalSorter(
       memoryManager,
@@ -219,10 +220,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
-    sorter = null;
     try {
       partitionLengths = mergeSpills(spills);
     } finally {
+      sorter = null;
       for (SpillInfo spill : spills) {
         if (spill.file.exists() && !spill.file.delete()) {
           logger.error("Error while deleting spill file {}", spill.file.getPath());
@@ -267,7 +268,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     if (spills.length == 0) {
       final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
           .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
-      return mapWriter.commitAllPartitions().getPartitionLengths();
+      return mapWriter.commitAllPartitions(
+        ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
     } else if (spills.length == 1) {
       Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
           shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
@@ -277,7 +279,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         partitionLengths = spills[0].partitionLengths;
         logger.debug("Merge shuffle spills for mapId {} with length {}", mapId,
             partitionLengths.length);
-        maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+        maybeSingleFileWriter.get()
+          .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums());
       } else {
         partitionLengths = mergeSpillsUsingStandardWriter(spills);
       }
@@ -330,7 +333,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       // to be counted as shuffle write, but this will lead to double-counting of the final
       // SpillInfo's bytes.
       writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
-      partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
+      partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
     } catch (Exception e) {
       try {
         mapWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index 0b28626..6c5025d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -98,7 +98,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
   }
 
   @Override
-  public MapOutputCommitMessage commitAllPartitions() throws IOException {
+  public MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException {
     // Check the position after transferTo loop to see if it is in the right position and raise a
     // exception if it is incorrect. The position will not be increased to the expected length
     // after calling transferTo in kernel version 2.6.32. This issue is described at
@@ -115,7 +115,8 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
     File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
     log.debug("Writing shuffle index file for mapId {} with length {}", mapId,
         partitionLengths.length);
-    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+    blockResolver
+      .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, resolvedTmp);
     return MapOutputCommitMessage.of(partitionLengths);
   }
 
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
index c8b4199..6a994b4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -44,12 +44,14 @@ public class LocalDiskSingleSpillMapOutputWriter
   @Override
   public void transferMapSpillFile(
       File mapSpillFile,
-      long[] partitionLengths) throws IOException {
+      long[] partitionLengths,
+      long[] checksums) throws IOException {
     // The map spill file already has the proper format, and it contains all of the partition data.
     // So just transfer it directly to the destination without any merging.
     File outputFile = blockResolver.getDataFile(shuffleId, mapId);
     File tempFile = Utils.tempFileWith(outputFile);
     Files.move(mapSpillFile.toPath(), tempFile.toPath());
-    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+    blockResolver
+      .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, tempFile);
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 613a66d..3ef964f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1368,6 +1368,25 @@ package object config {
         s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.")
       .createWithDefault(4096)
 
+  private[spark] val SHUFFLE_CHECKSUM_ENABLED =
+    ConfigBuilder("spark.shuffle.checksum.enabled")
+      .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " +
+        "its best to tell if shuffle data corruption is caused by network or disk or others.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
+  private[spark] val SHUFFLE_CHECKSUM_ALGORITHM =
+    ConfigBuilder("spark.shuffle.checksum.algorithm")
+      .doc("The algorithm used to calculate the checksum. Currently, it only supports" +
+        " built-in algorithms of JDK.")
+      .version("3.3.0")
+      .stringConf
+      .transform(_.toUpperCase(Locale.ROOT))
+      .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " +
+        "should be either Adler32 or CRC32.")
+      .createWithDefault("ADLER32")
+
   private[spark] val SHUFFLE_COMPRESS =
     ConfigBuilder("spark.shuffle.compress")
       .doc("Whether to compress shuffle output. Compression will use " +
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
similarity index 51%
copy from core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
copy to core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
index cad8dcf..754b4a8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
@@ -15,22 +15,35 @@
  * limitations under the License.
  */
 
-package org.apache.spark.shuffle.api;
+package org.apache.spark.io
 
-import java.io.File;
-import java.io.IOException;
-
-import org.apache.spark.annotation.Private;
+import java.io.OutputStream
+import java.util.zip.Checksum
 
 /**
- * Optional extension for partition writing that is optimized for transferring a single
- * file to the backing store.
+ * A variant of [[java.util.zip.CheckedOutputStream]] which can
+ * change the checksum calculator at runtime.
  */
-@Private
-public interface SingleSpillShuffleMapOutputWriter {
+class MutableCheckedOutputStream(out: OutputStream) extends OutputStream {
+  private var checksum: Checksum = _
+
+  def setChecksum(c: Checksum): Unit = {
+    this.checksum = c
+  }
+
+  override def write(b: Int): Unit = {
+    assert(checksum != null, "Checksum is not set.")
+    checksum.update(b)
+    out.write(b)
+  }
+
+  override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+    assert(checksum != null, "Checksum is not set.")
+    checksum.update(b, off, len)
+    out.write(b, off, len)
+  }
+
+  override def flush(): Unit = out.flush()
 
-  /**
-   * Transfer a file that contains the bytes of all the partitions written by this map task.
-   */
-  void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+  override def close(): Unit = out.close()
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5d1da19..9c50569 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -22,6 +22,8 @@ import java.nio.ByteBuffer
 import java.nio.channels.Channels
 import java.nio.file.Files
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.{SparkConf, SparkEnv, SparkException}
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.io.NioBufferedFileInputStream
@@ -31,6 +33,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta}
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -142,17 +145,18 @@ private[spark] class IndexShuffleBlockResolver(
    */
   def removeDataByMap(shuffleId: Int, mapId: Long): Unit = {
     var file = getDataFile(shuffleId, mapId)
-    if (file.exists()) {
-      if (!file.delete()) {
-        logWarning(s"Error deleting data ${file.getPath()}")
-      }
+    if (file.exists() && !file.delete()) {
+      logWarning(s"Error deleting data ${file.getPath()}")
     }
 
     file = getIndexFile(shuffleId, mapId)
-    if (file.exists()) {
-      if (!file.delete()) {
-        logWarning(s"Error deleting index ${file.getPath()}")
-      }
+    if (file.exists() && !file.delete()) {
+      logWarning(s"Error deleting index ${file.getPath()}")
+    }
+
+    file = getChecksumFile(shuffleId, mapId)
+    if (file.exists() && !file.delete()) {
+      logWarning(s"Error deleting checksum ${file.getPath()}")
     }
   }
 
@@ -303,22 +307,41 @@ private[spark] class IndexShuffleBlockResolver(
 
 
   /**
-   * Write an index file with the offsets of each block, plus a final offset at the end for the
-   * end of the output file. This will be used by getBlockData to figure out where each block
-   * begins and ends.
+   * Commit the data and metadata files as an atomic operation, use the existing ones, or
+   * replace them with new ones. Note that the metadata parameters (`lengths`, `checksums`)
+   * will be updated to match the existing ones if use the existing ones.
+   *
+   * There're two kinds of metadata files:
    *
-   * It will commit the data and index file as an atomic operation, use the existing ones, or
-   * replace them with new ones.
+   * - index file
+   * An index file contains the offsets of each block, plus a final offset at the end
+   * for the end of the output file. It will be used by [[getBlockData]] to figure out
+   * where each block begins and ends.
    *
-   * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
+   * - checksum file (optional)
+   * An checksum file contains the checksum of each block. It will be used to diagnose
+   * the cause when a block is corrupted. Note that empty `checksums` indicate that
+   * checksum is disabled.
    */
-  def writeIndexFileAndCommit(
+  def writeMetadataFileAndCommit(
       shuffleId: Int,
       mapId: Long,
       lengths: Array[Long],
+      checksums: Array[Long],
       dataTmp: File): Unit = {
     val indexFile = getIndexFile(shuffleId, mapId)
     val indexTmp = Utils.tempFileWith(indexFile)
+
+    val checksumEnabled = checksums.nonEmpty
+    val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) {
+      assert(lengths.length == checksums.length,
+        "The size of partition lengths and checksums should be equal")
+      val checksumFile = getChecksumFile(shuffleId, mapId)
+      (Some(checksumFile), Some(Utils.tempFileWith(checksumFile)))
+    } else {
+      (None, None)
+    }
+
     try {
       val dataFile = getDataFile(shuffleId, mapId)
       // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
@@ -329,37 +352,47 @@ private[spark] class IndexShuffleBlockResolver(
           // Another attempt for the same task has already written our map outputs successfully,
           // so just use the existing partition lengths and delete our temporary map outputs.
           System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+          if (checksumEnabled) {
+            val existingChecksums = getChecksums(checksumFileOpt.get, checksums.length)
+            if (existingChecksums != null) {
+              System.arraycopy(existingChecksums, 0, checksums, 0, lengths.length)
+            } else {
+              // It's possible that the previous task attempt succeeded writing the
+              // index file and data file but failed to write the checksum file. In
+              // this case, the current task attempt could write the missing checksum
+              // file by itself.
+              writeMetadataFile(checksums, checksumTmpOpt.get, checksumFileOpt.get, false)
+            }
+          }
           if (dataTmp != null && dataTmp.exists()) {
             dataTmp.delete()
           }
         } else {
           // This is the first successful attempt in writing the map outputs for this task,
           // so override any existing index and data files with the ones we wrote.
-          val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
-          Utils.tryWithSafeFinally {
-            // We take in lengths of each block, need to convert it to offsets.
-            var offset = 0L
-            out.writeLong(offset)
-            for (length <- lengths) {
-              offset += length
-              out.writeLong(offset)
-            }
-          } {
-            out.close()
-          }
 
-          if (indexFile.exists()) {
-            indexFile.delete()
-          }
+          val offsets = lengths.scanLeft(0L)(_ + _)
+          writeMetadataFile(offsets, indexTmp, indexFile, true)
+
           if (dataFile.exists()) {
             dataFile.delete()
           }
-          if (!indexTmp.renameTo(indexFile)) {
-            throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
-          }
           if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
             throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
           }
+
+          // write the checksum file
+          checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) =>
+            try {
+              writeMetadataFile(checksums, checksumTmp, checksumFile, false)
+            } catch {
+              case e: Exception =>
+                // It's not worthwhile to fail here after index file and data file are
+                // already successfully stored since checksum is only a best-effort for
+                // the corner error case.
+                logError("Failed to write checksum file", e)
+            }
+          }
         }
       }
     } finally {
@@ -367,6 +400,63 @@ private[spark] class IndexShuffleBlockResolver(
       if (indexTmp.exists() && !indexTmp.delete()) {
         logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
       }
+      checksumTmpOpt.foreach { checksumTmp =>
+        if (checksumTmp.exists()) {
+          try {
+            if (!checksumTmp.delete()) {
+              logError(s"Failed to delete temporary checksum file " +
+                s"at ${checksumTmp.getAbsolutePath}")
+            }
+          } catch {
+            case e: Exception =>
+              // Unlike index deletion, we won't propagate the error for the checksum file since
+              // checksum is only a best-effort.
+              logError(s"Failed to delete temporary checksum file " +
+                s"at ${checksumTmp.getAbsolutePath}", e)
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Write the metadata file (index or checksum). Metadata values will be firstly write into
+   * the tmp file and the tmp file will be renamed to the target file at the end to avoid dirty
+   * writes.
+   * @param metaValues The metadata values
+   * @param tmpFile The temp file
+   * @param targetFile The target file
+   * @param propagateError Whether to propagate the error for file operation. Unlike index file,
+   *                       checksum is only a best-effort so we won't fail the whole task due to
+   *                       the error from checksum.
+   */
+  private def writeMetadataFile(
+      metaValues: Array[Long],
+      tmpFile: File,
+      targetFile: File,
+      propagateError: Boolean): Unit = {
+    val out = new DataOutputStream(
+      new BufferedOutputStream(
+        new FileOutputStream(tmpFile)
+      )
+    )
+    Utils.tryWithSafeFinally {
+      metaValues.foreach(out.writeLong)
+    } {
+      out.close()
+    }
+
+    if (targetFile.exists()) {
+      targetFile.delete()
+    }
+
+    if (!tmpFile.renameTo(targetFile)) {
+      val errorMsg = s"fail to rename file $tmpFile to $targetFile"
+      if (propagateError) {
+        throw new IOException(errorMsg)
+      } else {
+        logWarning(errorMsg)
+      }
     }
   }
 
@@ -414,6 +504,45 @@ private[spark] class IndexShuffleBlockResolver(
     new MergedBlockMeta(numChunks, chunkBitMaps)
   }
 
+  private[shuffle] def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = {
+    if (!checksumFile.exists()) return null
+    val checksums = new ArrayBuffer[Long]
+    // Read the checksums of blocks
+    var in: DataInputStream = null
+    try {
+      in = new DataInputStream(new NioBufferedFileInputStream(checksumFile))
+      while (checksums.size < blockNum) {
+        checksums += in.readLong()
+      }
+    } catch {
+      case _: IOException | _: EOFException =>
+        return null
+    } finally {
+      in.close()
+    }
+
+    checksums.toArray
+  }
+
+  /**
+   * Get the shuffle checksum file.
+   *
+   * When the dirs parameter is None then use the disk manager's local directories. Otherwise,
+   * read from the specified directories.
+   */
+  def getChecksumFile(
+      shuffleId: Int,
+      mapId: Long,
+      dirs: Option[Array[String]] = None): File = {
+    val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+    val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf)
+    dirs
+      .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName))
+      .getOrElse {
+        blockManager.diskBlockManager.getFile(fileName)
+      }
+  }
+
   override def getBlockData(
       blockId: BlockId,
       dirs: Option[Array[String]]): ManagedBuffer = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index e0affb8..9843ae1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.shuffle
 
 import java.io.{Closeable, IOException, OutputStream}
+import java.util.zip.Checksum
 
+import org.apache.spark.io.MutableCheckedOutputStream
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.api.ShufflePartitionWriter
 import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream}
@@ -34,7 +36,8 @@ private[spark] class ShufflePartitionPairsWriter(
     serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     blockId: BlockId,
-    writeMetrics: ShuffleWriteMetricsReporter)
+    writeMetrics: ShuffleWriteMetricsReporter,
+    checksum: Checksum)
   extends PairsWriter with Closeable {
 
   private var isClosed = false
@@ -44,6 +47,9 @@ private[spark] class ShufflePartitionPairsWriter(
   private var objOut: SerializationStream = _
   private var numRecordsWritten = 0
   private var curNumBytesWritten = 0L
+  // this would be only initialized when checksum != null,
+  // which indicates shuffle checksum is enabled.
+  private var checksumOutputStream: MutableCheckedOutputStream = _
 
   override def write(key: Any, value: Any): Unit = {
     if (isClosed) {
@@ -61,7 +67,12 @@ private[spark] class ShufflePartitionPairsWriter(
     try {
       partitionStream = partitionWriter.openStream
       timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
-      wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
+      if (checksum != null) {
+        checksumOutputStream = new MutableCheckedOutputStream(timeTrackingStream)
+        checksumOutputStream.setChecksum(checksum)
+      }
+      wrappedStream = serializerManager.wrapStream(blockId,
+        if (checksumOutputStream != null) checksumOutputStream else timeTrackingStream)
       objOut = serializerInstance.serializeStream(wrappedStream)
     } catch {
       case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index adbe6ec..3cbf301 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -68,7 +68,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
       dep.shuffleId, mapId, dep.partitioner.numPartitions)
     sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+    partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
   }
 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index dc70a9a..db5862d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -92,6 +92,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
 }
 
+@Since("3.3.0")
+@DeveloperApi
+case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+  override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum"
+}
+
 @Since("3.2.0")
 @DeveloperApi
 case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index e55c0927..f5d8c02 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -19,8 +19,10 @@ package org.apache.spark.storage
 
 import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
 import java.nio.channels.{ClosedByInterruptException, FileChannel}
+import java.util.zip.Checksum
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.io.MutableCheckedOutputStream
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
@@ -77,6 +79,11 @@ private[spark] class DiskBlockObjectWriter(
   private var streamOpen = false
   private var hasBeenClosed = false
 
+  // checksum related
+  private var checksumEnabled = false
+  private var checksumOutputStream: MutableCheckedOutputStream = _
+  private var checksum: Checksum = _
+
   /**
    * Cursors used to represent positions in the file.
    *
@@ -101,12 +108,30 @@ private[spark] class DiskBlockObjectWriter(
    */
   private var numRecordsWritten = 0
 
+  /**
+   * Set the checksum that the checksumOutputStream should use
+   */
+  def setChecksum(checksum: Checksum): Unit = {
+    if (checksumOutputStream == null) {
+      this.checksumEnabled = true
+      this.checksum = checksum
+    } else {
+      checksumOutputStream.setChecksum(checksum)
+    }
+  }
+
   private def initialize(): Unit = {
     fos = new FileOutputStream(file, true)
     channel = fos.getChannel()
     ts = new TimeTrackingOutputStream(writeMetrics, fos)
+    if (checksumEnabled) {
+      assert(this.checksum != null, "Checksum is not set")
+      checksumOutputStream = new MutableCheckedOutputStream(ts)
+      checksumOutputStream.setChecksum(checksum)
+    }
     class ManualCloseBufferedOutputStream
-      extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+      extends BufferedOutputStream(if (checksumEnabled) checksumOutputStream else ts, bufferSize)
+        with ManualCloseOutputStream
     mcs = new ManualCloseBufferedOutputStream
   }
 
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 1913637..dba9e74 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer._
 import org.apache.spark.shuffle.ShufflePartitionPairsWriter
 import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
 import org.apache.spark.util.{CompletionIterator, Utils => TryUtils}
 
@@ -141,6 +142,11 @@ private[spark] class ExternalSorter[K, V, C](
   private val forceSpillFiles = new ArrayBuffer[SpilledFile]
   @volatile private var readingIterator: SpillableIterator = null
 
+  private val partitionChecksums =
+    ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)
+
+  def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
   // user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -762,7 +768,8 @@ private[spark] class ExternalSorter[K, V, C](
             serializerManager,
             serInstance,
             blockId,
-            context.taskMetrics().shuffleWriteMetrics)
+            context.taskMetrics().shuffleWriteMetrics,
+            if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
           while (it.hasNext && it.nextPartition() == partitionId) {
             it.writeNext(partitionPairsWriter)
           }
@@ -786,7 +793,8 @@ private[spark] class ExternalSorter[K, V, C](
             serializerManager,
             serInstance,
             blockId,
-            context.taskMetrics().shuffleWriteMetrics)
+            context.taskMetrics().shuffleWriteMetrics,
+            if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
           if (elements.hasNext) {
             for (elem <- elements) {
               partitionPairsWriter.write(elem._1, elem._2)
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 5666bb3..cca3eb5 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -22,11 +22,11 @@ import java.nio.ByteBuffer;
 import java.nio.file.Files;
 import java.util.*;
 
+import org.apache.spark.*;
+import org.apache.spark.shuffle.ShuffleChecksumTestHelper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
 import org.mockito.stubbing.Answer;
-import scala.Option;
-import scala.Product2;
-import scala.Tuple2;
-import scala.Tuple2$;
+import scala.*;
 import scala.collection.Iterator;
 
 import com.google.common.collect.HashMultiset;
@@ -36,10 +36,6 @@ import org.junit.Test;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.io.CompressionCodec$;
@@ -65,7 +61,7 @@ import static org.junit.Assert.*;
 import static org.mockito.Answers.RETURNS_SMART_NULLS;
 import static org.mockito.Mockito.*;
 
-public class UnsafeShuffleWriterSuite {
+public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
 
   static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
   static final int NUM_PARTITIONS = 4;
@@ -138,7 +134,7 @@ public class UnsafeShuffleWriterSuite {
 
     Answer<?> renameTempAnswer = invocationOnMock -> {
       partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
-      File tmp = (File) invocationOnMock.getArguments()[3];
+      File tmp = (File) invocationOnMock.getArguments()[4];
       if (!mergedOutputFile.delete()) {
         throw new RuntimeException("Failed to delete old merged output file.");
       }
@@ -152,11 +148,13 @@ public class UnsafeShuffleWriterSuite {
 
     doAnswer(renameTempAnswer)
         .when(shuffleBlockResolver)
-        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));
+        .writeMetadataFileAndCommit(
+          anyInt(), anyLong(), any(long[].class), any(long[].class), any(File.class));
 
     doAnswer(renameTempAnswer)
         .when(shuffleBlockResolver)
-        .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));
+        .writeMetadataFileAndCommit(
+          anyInt(), anyLong(), any(long[].class), any(long[].class), eq(null));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
       TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -171,7 +169,14 @@ public class UnsafeShuffleWriterSuite {
     when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
   }
 
-  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled)
+    throws SparkException {
+    return createWriter(transferToEnabled, shuffleBlockResolver);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+    boolean transferToEnabled,
+    IndexShuffleBlockResolver blockResolver) throws SparkException {
     conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
     return new UnsafeShuffleWriter<>(
       blockManager,
@@ -181,7 +186,7 @@ public class UnsafeShuffleWriterSuite {
       taskContext,
       conf,
       taskContext.taskMetrics().shuffleWriteMetrics(),
-      new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
+      new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver));
   }
 
   private void assertSpillFilesWereCleanedUp() {
@@ -219,12 +224,12 @@ public class UnsafeShuffleWriterSuite {
   }
 
   @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException, SparkException {
     createWriter(false).stop(true);
   }
 
   @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException, SparkException {
     createWriter(false).stop(false);
   }
 
@@ -291,6 +296,69 @@ public class UnsafeShuffleWriterSuite {
     assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
   }
 
+  @Test
+  public void writeChecksumFileWithoutSpill() throws Exception {
+    IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+    ShuffleChecksumBlockId checksumBlockId =
+      new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
+    File checksumFile = new File(tempDir,
+      ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+    File dataFile = new File(tempDir, "data");
+    File indexFile = new File(tempDir, "index");
+    when(diskBlockManager.getFile(checksumFile.getName()))
+      .thenReturn(checksumFile);
+    when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
+      .thenReturn(dataFile);
+    when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
+      .thenReturn(indexFile);
+
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
+    for (int i = 0; i < NUM_PARTITIONS; i ++) {
+      dataToWrite.add(new Tuple2<>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer1 = createWriter(true, blockResolver);
+    writer1.write(dataToWrite.iterator());
+    writer1.stop(true);
+    assertTrue(checksumFile.exists());
+    assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
+    compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+  }
+
+  @Test
+  public void writeChecksumFileWithSpill() throws Exception {
+    IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+    ShuffleChecksumBlockId checksumBlockId =
+      new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
+    File checksumFile =
+      new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+    File dataFile = new File(tempDir, "data");
+    File indexFile = new File(tempDir, "index");
+    when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile);
+    when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
+      .thenReturn(dataFile);
+    when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
+      .thenReturn(indexFile);
+
+    final UnsafeShuffleWriter<Object, Object> writer1 = createWriter(true, blockResolver);
+    writer1.insertRecordIntoSorter(new Tuple2<>(0, 0));
+    writer1.forceSorterToSpill();
+    writer1.insertRecordIntoSorter(new Tuple2<>(1, 0));
+    writer1.insertRecordIntoSorter(new Tuple2<>(2, 0));
+    writer1.forceSorterToSpill();
+    writer1.insertRecordIntoSorter(new Tuple2<>(0, 1));
+    writer1.insertRecordIntoSorter(new Tuple2<>(3, 0));
+    writer1.forceSorterToSpill();
+    writer1.insertRecordIntoSorter(new Tuple2<>(1, 1));
+    writer1.forceSorterToSpill();
+    writer1.insertRecordIntoSorter(new Tuple2<>(0, 2));
+    writer1.forceSorterToSpill();
+    writer1.closeAndWriteOutput();
+    assertTrue(checksumFile.exists());
+    assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
+    compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+  }
+
   private void testMergingSpills(
       final boolean transferToEnabled,
       String compressionCodecName,
@@ -317,7 +385,7 @@ public class UnsafeShuffleWriterSuite {
 
   private void testMergingSpills(
       boolean transferToEnabled,
-      boolean encrypted) throws IOException {
+      boolean encrypted) throws IOException, SparkException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -515,7 +583,7 @@ public class UnsafeShuffleWriterSuite {
   }
 
   @Test
-  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, SparkException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
     writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
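
[Editorial note] The two checksum tests added above assert that the checksum file holds exactly one long (8 bytes) per reduce partition, i.e. length == 8 * NUM_PARTITIONS. A minimal, self-contained sketch of what that layout implies for a reader; the helper name and signature below are illustrative only and are not part of the resolver's API:

    import java.io.{DataInputStream, File, FileInputStream}

    // Illustrative helper (not part of this patch): read one 8-byte checksum per
    // partition back from a checksum file laid out as consecutive longs.
    def readPartitionChecksums(checksumFile: File, numPartitions: Int): Array[Long] = {
      val in = new DataInputStream(new FileInputStream(checksumFile))
      try {
        Array.fill(numPartitions)(in.readLong())
      } finally {
        in.close()
      }
    }
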
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
new file mode 100644
index 0000000..a8f2c40
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{DataInputStream, File, FileInputStream}
+import java.util.zip.CheckedInputStream
+
+import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
+
+trait ShuffleChecksumTestHelper {
+
+  /**
+   * Ensure that the checksum values are consistent between write and read side.
+   */
+  def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = {
+    assert(checksum.exists(), "Checksum file doesn't exist")
+    assert(data.exists(), "Data file doesn't exist")
+    assert(index.exists(), "Index file doesn't exist")
+
+    var checksumIn: DataInputStream = null
+    val expectChecksums = Array.ofDim[Long](numPartition)
+    try {
+      checksumIn = new DataInputStream(new FileInputStream(checksum))
+      (0 until numPartition).foreach(i => expectChecksums(i) = checksumIn.readLong())
+    } finally {
+      if (checksumIn != null) {
+        checksumIn.close()
+      }
+    }
+
+    var dataIn: FileInputStream = null
+    var indexIn: DataInputStream = null
+    var checkedIn: CheckedInputStream = null
+    try {
+      dataIn = new FileInputStream(data)
+      indexIn = new DataInputStream(new FileInputStream(index))
+      var prevOffset = indexIn.readLong
+      (0 until numPartition).foreach { i =>
+        val curOffset = indexIn.readLong
+        val limit = (curOffset - prevOffset).toInt
+        val bytes = new Array[Byte](limit)
+        val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName)
+        checkedIn = new CheckedInputStream(
+          new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal)
+        checkedIn.read(bytes, 0, limit)
+        prevOffset = curOffset
+        // checksum must be consistent at both write and read sides
+        assert(checkedIn.getChecksum.getValue == expectChecksums(i))
+      }
+    } finally {
+      if (dataIn != null) {
+        dataIn.close()
+      }
+      if (indexIn != null) {
+        indexIn.close()
+      }
+      if (checkedIn != null) {
+        checkedIn.close()
+      }
+    }
+  }
+}
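
[Editorial note] compareChecksums above recomputes each partition's checksum over the byte range delimited by consecutive index-file offsets and compares it with the stored value. For arbitrary input streams a single read(bytes, 0, limit) call is not guaranteed to fill the buffer; a hedged sketch of a fully-draining variant under the same CheckedInputStream setup (illustrative only, not the committed helper):

    import java.util.zip.CheckedInputStream

    // Illustrative sketch: keep reading until the limited range is exhausted, so the
    // running checksum covers the whole partition even if read() returns short counts.
    def drainFully(in: CheckedInputStream, limit: Int): Unit = {
      val buf = new Array[Byte](8192)
      var remaining = limit
      var n = 0
      while (remaining > 0 && n != -1) {
        n = in.read(buf, 0, math.min(buf.length, remaining))
        if (n > 0) remaining -= n
      }
    }
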
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 7fd0bf6..39eef97 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,13 +33,17 @@ import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
 import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
 import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
-class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+class BypassMergeSortShuffleWriterSuite
+  extends SparkFunSuite
+    with BeforeAndAfterEach
+    with ShuffleChecksumTestHelper {
 
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
@@ -76,10 +80,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
 
-    when(blockResolver.writeIndexFileAndCommit(
-      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+    when(blockResolver.writeMetadataFileAndCommit(
+      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
-        val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
         if (tmp != null) {
           outputFile.delete
           tmp.renameTo(outputFile)
@@ -236,4 +240,43 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     writer.stop( /* success = */ false)
     assert(temporaryFilesCreated.count(_.exists()) === 0)
   }
+
+  test("write checksum file") {
+    val blockResolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val shuffleId = shuffleHandle.shuffleId
+    val mapId = 0
+    val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0)
+    val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0)
+    val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0)
+    val checksumFile = new File(tempDir,
+      ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf))
+    val dataFile = new File(tempDir, dataBlockId.name)
+    val indexFile = new File(tempDir, indexBlockId.name)
+    reset(diskBlockManager)
+    when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile)
+    when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile)
+    when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile)
+    when(diskBlockManager.createTempShuffleBlock())
+      .thenAnswer { _ =>
+        val blockId = new TempShuffleBlockId(UUID.randomUUID)
+        val file = new File(tempDir, blockId.name)
+        temporaryFilesCreated += file
+        (blockId, file)
+      }
+
+    val numPartition = shuffleHandle.dependency.partitioner.numPartitions
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      blockManager,
+      shuffleHandle,
+      mapId,
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics,
+      new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver))
+
+    writer.write(Iterator((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)))
+    writer.stop( /* success = */ true)
+    assert(checksumFile.exists())
+    assert(checksumFile.length() === 8 * numPartition)
+    compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+  }
 }
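
[Editorial note] The "write checksum file" test above exercises the per-partition-file checksum wrapping done by BypassMergeSortShuffleWriter. A minimal standalone sketch of the underlying idea using the plain java.util.zip classes; the patch itself uses a mutable variant (MutableCheckedOutputStream), so this is illustrative only:

    import java.io.ByteArrayOutputStream
    import java.util.zip.{Adler32, CheckedOutputStream}

    // Illustrative sketch: maintain a running Adler32 checksum while partition bytes
    // are written, then capture the final value for the checksum file entry.
    val checksum = new Adler32()
    val out = new CheckedOutputStream(new ByteArrayOutputStream(), checksum)
    out.write(Array[Byte](1, 2, 3))
    out.close()
    val partitionChecksum: Long = checksum.getValue
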
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 5955d44..49c079c 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -28,6 +28,7 @@ import org.roaringbitmap.RoaringBitmap
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -49,6 +50,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
       (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
+    when(diskBlockManager.getFile(any[String])).thenAnswer(
+      (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
     when(diskBlockManager.getMergedShuffleFile(
       any[BlockId], any[Option[Array[String]]])).thenAnswer(
       (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
@@ -77,7 +80,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths, Array.empty, dataTmp)
 
     val indexFile = new File(tempDir.getAbsolutePath, idxName)
     val dataFile = resolver.getDataFile(shuffleId, mapId)
@@ -97,7 +100,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out2.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths2, Array.empty, dataTmp2)
 
     assert(indexFile.length() === (lengths.length + 1) * 8)
     assert(lengths2.toSeq === lengths.toSeq)
@@ -136,7 +139,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
     } {
       out3.close()
     }
-    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+    resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths3, Array.empty, dataTmp3)
     assert(indexFile.length() === (lengths3.length + 1) * 8)
     assert(lengths3.toSeq != lengths.toSeq)
     assert(dataFile.exists())
@@ -248,4 +251,19 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
       outIndex.close()
     }
   }
+
+  test("write checksum file") {
+    val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+    val dataTmp = File.createTempFile("shuffle", null, tempDir)
+    val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+    resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp)
+    val checksumFile = resolver.getChecksumFile(0, 0)
+    assert(checksumFile.exists())
+    val checksumFileName = checksumFile.toString
+    val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1)
+    assert(checksumAlgo === conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
+    val checksumsFromFile = resolver.getChecksums(checksumFile, 10)
+    assert(checksumsInMemory === checksumsFromFile)
+  }
 }
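
[Editorial note] The new resolver test relies on the checksum algorithm being encoded as the checksum file's extension. A hedged sketch of that derivation, mirroring the substring logic in the test; the example file name below is made up for illustration and not verified against the actual naming scheme:

    // Illustrative only: the algorithm name is whatever follows the last dot.
    def algorithmFromFileName(fileName: String): String =
      fileName.substring(fileName.lastIndexOf(".") + 1)

    // e.g. algorithmFromFileName("shuffle_0_0_0.checksum.ADLER32") == "ADLER32"
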
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 4c679fd..e345736 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -20,19 +20,25 @@ package org.apache.spark.shuffle.sort
 import org.mockito.{Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.Mockito._
+import org.scalatest.PrivateMethodTester
 import org.scalatest.matchers.must.Matchers
 
-import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite}
+import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite}
 import org.apache.spark.memory.MemoryTestingUtils
 import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
 import org.apache.spark.shuffle.api.ShuffleExecutorComponents
 import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
 import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.ExternalSorter
 
-
-class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers {
+class SortShuffleWriterSuite
+  extends SparkFunSuite
+    with SharedSparkContext
+    with Matchers
+    with PrivateMethodTester
+    with ShuffleChecksumTestHelper {
 
   @Mock(answer = RETURNS_SMART_NULLS)
   private var blockManager: BlockManager = _
@@ -44,13 +50,14 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
   private val serializer = new JavaSerializer(conf)
   private var shuffleExecutorComponents: ShuffleExecutorComponents = _
 
+  private val partitioner = new Partitioner() {
+    def numPartitions = numMaps
+    def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
+  }
+
   override def beforeEach(): Unit = {
     super.beforeEach()
     MockitoAnnotations.openMocks(this).close()
-    val partitioner = new Partitioner() {
-      def numPartitions = numMaps
-      def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
-    }
     shuffleHandle = {
       val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
       when(dependency.partitioner).thenReturn(partitioner)
@@ -103,4 +110,68 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
     assert(dataFile.length() === writeMetrics.bytesWritten)
     assert(records.size === writeMetrics.recordsWritten)
   }
+
+  Seq((true, false, false),
+    (true, true, false),
+    (true, false, true),
+    (true, true, true),
+    (false, false, false),
+    (false, true, false),
+    (false, false, true),
+    (false, true, true)).foreach { case (doSpill, doAgg, doOrder) =>
+    test(s"write checksum file (spill=$doSpill, aggregator=$doAgg, order=$doOrder)") {
+      val aggregator = if (doAgg) {
+        Some(Aggregator[Int, Int, Int](
+          v => v,
+          (c, v) => c + v,
+          (c1, c2) => c1 + c2))
+      } else None
+      val order = if (doOrder) {
+        Some(new Ordering[Int] {
+          override def compare(x: Int, y: Int): Int = x - y
+        })
+      } else None
+
+      val shuffleHandle = {
+        val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+        when(dependency.partitioner).thenReturn(partitioner)
+        when(dependency.serializer).thenReturn(serializer)
+        when(dependency.aggregator).thenReturn(aggregator)
+        when(dependency.keyOrdering).thenReturn(order)
+        new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency)
+      }
+
+      // FIXME: this can affect other tests (if any) after this set of tests
+      //  since `sc` is global.
+      sc.stop()
+      conf.set("spark.shuffle.spill.numElementsForceSpillThreshold",
+        if (doSpill) "0" else Int.MaxValue.toString)
+      conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+      val localSC = new SparkContext("local[4]", "test", conf)
+      val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+      val context = MemoryTestingUtils.fakeTaskContext(localSC.env)
+      val records = List[(Int, Int)](
+        (0, 1), (1, 2), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5), (3, 5), (4, 6))
+      val numPartition = shuffleHandle.dependency.partitioner.numPartitions
+      val writer = new SortShuffleWriter[Int, Int, Int](
+        shuffleHandle,
+        mapId = 0,
+        context,
+        new LocalDiskShuffleExecutorComponents(
+          conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
+      writer.write(records.toIterator)
+      val sorterMethod = PrivateMethod[ExternalSorter[_, _, _]](Symbol("sorter"))
+      val sorter = writer.invokePrivate(sorterMethod())
+      val expectSpillSize = if (doSpill) records.size else 0
+      assert(sorter.numSpills === expectSpillSize)
+      writer.stop(success = true)
+      val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0)
+      assert(checksumFile.exists())
+      assert(checksumFile.length() === 8 * numPartition)
+      val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0)
+      val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0)
+      compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+      localSC.stop()
+    }
+  }
 }
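
[Editorial note] The explicit eight-element Seq above enumerates every combination of spill, aggregator, and ordering. An equivalent way to generate the same matrix, shown only as a sketch; the committed test keeps the explicit list:

    // Same 2 x 2 x 2 matrix as the explicit Seq in the test above.
    val combos: Seq[(Boolean, Boolean, Boolean)] = for {
      doSpill <- Seq(true, false)
      doAgg   <- Seq(true, false)
      doOrder <- Seq(true, false)
    } yield (doSpill, doAgg, doOrder)
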
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index ef5c615..35d9b4a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -74,11 +74,11 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
       .set("spark.app.id", "example.spark.app")
       .set("spark.shuffle.unsafe.file.output.buffer", "16k")
     when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
-    when(blockResolver.writeIndexFileAndCommit(
-      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+    when(blockResolver.writeMetadataFileAndCommit(
+      anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
       .thenAnswer { invocationOnMock =>
         partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
-        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        val tmp: File = invocationOnMock.getArguments()(4).asInstanceOf[File]
         if (tmp != null) {
           mergedOutputFile.delete()
           tmp.renameTo(mergedOutputFile)
@@ -136,7 +136,8 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
   }
 
   private def verifyWrittenRecords(): Unit = {
-    val committedLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+    val committedLengths =
+      mapOutputWriter.commitAllPartitions(Array.empty[Long]).getPartitionLengths
     assert(partitionSizesInMergedFile === partitionLengths)
     assert(committedLengths === partitionLengths)
     assert(mergedOutputFile.length() === partitionLengths.sum)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index a6de64b..7bec961 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -500,15 +500,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
         intercept[SparkException] {
           data.reduceByKey(_ + _).count()
         }
-        // After the shuffle, there should be only 2 files on disk: the output of task 1 and
-        // its index. All other files (map 2's output and intermediate merge files) should
-        // have been deleted.
-        assert(diskBlockManager.getAllFiles().length === 2)
+        // After the shuffle, there should be only 3 files on disk: the output of task 1 and
+        // its index and checksum. All other files (map 2's output and intermediate merge files)
+        // should have been deleted.
+        assert(diskBlockManager.getAllFiles().length === 3)
       } else {
         assert(data.reduceByKey(_ + _).count() === size)
-        // After the shuffle, there should be only 4 files on disk: the output of both tasks
-        // and their indices. All intermediate merge files should have been deleted.
-        assert(diskBlockManager.getAllFiles().length === 4)
+        // After the shuffle, there should be only 6 files on disk: the output of both tasks
+        // and their indices/checksums. All intermediate merge files should have been deleted.
+        assert(diskBlockManager.getAllFiles().length === 6)
       }
     }
   }
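
[Editorial note] The adjusted counts follow from the checksum file now sitting alongside each map output's data and index files. A worked check of the arithmetic behind the updated assertions (comments only restate what the test expects):

    // data + index + checksum per surviving map output
    val filesPerMapOutput = 3
    assert(1 * filesPerMapOutput == 3) // fetch-failure branch: only task 1's output survives
    assert(2 * filesPerMapOutput == 6) // success branch: both tasks' outputs survive
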
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3b93a34..dba74ac 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -61,7 +61,15 @@ object MimaExcludes {
     ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.WritablePartitionedIterator"),
 
     // [SPARK-35757][CORE] Add bitwise AND operation and functionality for intersecting bloom filters
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace"),
+
+    // [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskSingleSpillMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions")
   )
 
   // Exclude rules for 3.1.x
