Posted to commits@spark.apache.org by mr...@apache.org on 2021/07/17 05:24:27 UTC
[spark] branch branch-3.2 updated: [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new d5022c3 [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
d5022c3 is described below
commit d5022c3c6f73c014b3ec7535c78b1c70a2f03941
Author: yi.wu <yi...@databricks.com>
AuthorDate: Sat Jul 17 00:23:14 2021 -0500
[SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
### What changes were proposed in this pull request?
This is the initial work to add checksum support for shuffle. It is a piece of https://github.com/apache/spark/pull/32385, and this PR only adds the checksum functionality on the shuffle writer side.
Basically, the idea is to wrap a `MutableCheckedOutputStream`* around the `FileOutputStream` while the shuffle writer generates the shuffle data. The exact wrapping place differs among the shuffle writers because of their different implementations:
* `BypassMergeSortShuffleWriter` - wraps each partition file
* `UnsafeShuffleWriter` - wraps each spill file directly, since spills don't require aggregation or sorting
* `SortShuffleWriter` - wraps the `ShufflePartitionPairsWriter` after spill files are merged, since the data might require aggregation or sorting
\* `MutableCheckedOutputStream` is a variant of `java.util.zip.CheckedOutputStream` which can change the checksum calculator at runtime.
We use `Adler32` to calculate the checksum, the same choice as `Broadcast`'s checksum; it serves the same role as CRC-32 but is much faster to compute.
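As a rough illustration of the wrapping idea, here is a minimal sketch using the JDK's `CheckedOutputStream` rather than Spark's own classes; the in-memory stream and the data are made up:

```scala
import java.io.ByteArrayOutputStream
import java.util.zip.{Adler32, CheckedOutputStream}

// Every byte written through the checked stream updates the running checksum.
// The final value is what gets persisted per partition in the checksum file.
val partitionOut = new ByteArrayOutputStream()
val checksum = new Adler32()
val checked = new CheckedOutputStream(partitionOut, checksum)

val bytes = "partition data".getBytes("UTF-8")
checked.write(bytes, 0, bytes.length)
checked.flush()

val checksumValue: Long = checksum.getValue
```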
### Why are the changes needed?
Shuffle data corruption is hard to diagnose today. Recording per-partition checksums at write time lets Spark later compare checksums and help tell whether a corruption was caused by the network, the disk, or something else.
### Does this PR introduce _any_ user-facing change?
Yes, added new configs: `spark.shuffle.checksum.enabled` and `spark.shuffle.checksum.algorithm`.
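For example, based on the config entries added in `package.scala` in the diff below, the feature could be configured like this (values are illustrative):

```scala
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.checksum.enabled", "true")    // default: true
  .set("spark.shuffle.checksum.algorithm", "CRC32") // default: ADLER32; ADLER32 and CRC32 are supported
```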
### How was this patch tested?
Added unit tests.
Closes #32401 from Ngone51/add-checksum-files.
Authored-by: yi.wu <yi...@databricks.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
(cherry picked from commit 4783fb72aff4a550ca0cc680dcc0d730e3e36dac)
Signed-off-by: Mridul Muralidharan <mridulatgmail.com>
---
.../spark/shuffle/api/ShuffleMapOutputWriter.java | 9 +-
.../spark/shuffle/api/ShufflePartitionWriter.java | 6 +-
.../api/SingleSpillShuffleMapOutputWriter.java | 5 +-
.../shuffle/checksum/ShuffleChecksumHelper.java | 100 +++++++++++
.../shuffle/sort/BypassMergeSortShuffleWriter.java | 24 ++-
.../spark/shuffle/sort/ShuffleExternalSorter.java | 17 +-
.../spark/shuffle/sort/UnsafeShuffleWriter.java | 15 +-
.../sort/io/LocalDiskShuffleMapOutputWriter.java | 5 +-
.../io/LocalDiskSingleSpillMapOutputWriter.java | 6 +-
.../org/apache/spark/internal/config/package.scala | 19 ++
.../spark/io/MutableCheckedOutputStream.scala} | 39 +++--
.../spark/shuffle/IndexShuffleBlockResolver.scala | 195 +++++++++++++++++----
.../shuffle/ShufflePartitionPairsWriter.scala | 15 +-
.../spark/shuffle/sort/SortShuffleWriter.scala | 2 +-
.../scala/org/apache/spark/storage/BlockId.scala | 6 +
.../spark/storage/DiskBlockObjectWriter.scala | 27 ++-
.../spark/util/collection/ExternalSorter.scala | 12 +-
.../shuffle/sort/UnsafeShuffleWriterSuite.java | 104 +++++++++--
.../spark/shuffle/ShuffleChecksumTestHelper.scala | 78 +++++++++
.../sort/BypassMergeSortShuffleWriterSuite.scala | 53 +++++-
.../sort/IndexShuffleBlockResolverSuite.scala | 24 ++-
.../shuffle/sort/SortShuffleWriterSuite.scala | 87 ++++++++-
.../io/LocalDiskShuffleMapOutputWriterSuite.scala | 9 +-
.../util/collection/ExternalSorterSuite.scala | 14 +-
project/MimaExcludes.scala | 10 +-
25 files changed, 762 insertions(+), 119 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
index 0167002..2237ec0 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -59,6 +59,10 @@ public interface ShuffleMapOutputWriter {
* available to downstream reduce tasks. If this method throws any exception, this module's
* {@link #abort(Throwable)} method will be invoked before propagating the exception.
* <p>
+ * Shuffle extensions which care about the cause of shuffle data corruption should store
+ * the checksums properly. When corruption happens, Spark would provide the checksum
+ * of the fetched partition to the shuffle extension to help diagnose the cause of corruption.
+ * <p>
* This can also close any resources and clean up temporary state if necessary.
* <p>
* The returned commit message is a structure with two components:
@@ -68,8 +72,11 @@ public interface ShuffleMapOutputWriter {
* for that partition id.
* <p>
* 2) An optional metadata blob that can be used by shuffle readers.
+ *
+ * @param checksums The checksum values for each partition (where checksum index is equivalent to
+ * partition id) if shuffle checksum enabled. Otherwise, it's empty.
*/
- MapOutputCommitMessage commitAllPartitions() throws IOException;
+ MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException;
/**
* Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
index 9288751..143cc6c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -49,7 +49,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link OutputStream#close()} does not close the resource, since it will be reused across
* partition writes. The underlying resources should be cleaned up in
- * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
OutputStream openStream() throws IOException;
@@ -68,7 +68,7 @@ public interface ShufflePartitionWriter {
* by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
* {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
* will be reused across partition writes. The underlying resources should be cleaned up in
- * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])} and
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
* <p>
* This method is primarily for advanced optimizations where bytes can be copied from the input
@@ -79,7 +79,7 @@ public interface ShufflePartitionWriter {
* Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
* underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
* that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
- * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
+ * {@link ShuffleMapOutputWriter#commitAllPartitions(long[])}, or
* {@link ShuffleMapOutputWriter#abort(Throwable)}.
*/
default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
index cad8dcf..ba3d5a6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
@@ -32,5 +32,8 @@ public interface SingleSpillShuffleMapOutputWriter {
/**
* Transfer a file that contains the bytes of all the partitions written by this map task.
*/
- void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+ void transferMapSpillFile(
+ File mapOutputFile,
+ long[] partitionLengths,
+ long[] checksums) throws IOException;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
new file mode 100644
index 0000000..a368836
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/checksum/ShuffleChecksumHelper.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.checksum;
+
+import java.util.zip.Adler32;
+import java.util.zip.CRC32;
+import java.util.zip.Checksum;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.storage.ShuffleChecksumBlockId;
+
+/**
+ * A set of utility functions for the shuffle checksum.
+ */
+@Private
+public class ShuffleChecksumHelper {
+
+ /** Used when the checksum is disabled for shuffle. */
+ private static final Checksum[] EMPTY_CHECKSUM = new Checksum[0];
+ public static final long[] EMPTY_CHECKSUM_VALUE = new long[0];
+
+ public static boolean isShuffleChecksumEnabled(SparkConf conf) {
+ return (boolean) conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ENABLED());
+ }
+
+ public static Checksum[] createPartitionChecksumsIfEnabled(int numPartitions, SparkConf conf)
+ throws SparkException {
+ if (!isShuffleChecksumEnabled(conf)) {
+ return EMPTY_CHECKSUM;
+ }
+
+ String checksumAlgo = shuffleChecksumAlgorithm(conf);
+ return getChecksumByAlgorithm(numPartitions, checksumAlgo);
+ }
+
+ private static Checksum[] getChecksumByAlgorithm(int num, String algorithm)
+ throws SparkException {
+ Checksum[] checksums;
+ switch (algorithm) {
+ case "ADLER32":
+ checksums = new Adler32[num];
+ for (int i = 0; i < num; i ++) {
+ checksums[i] = new Adler32();
+ }
+ return checksums;
+
+ case "CRC32":
+ checksums = new CRC32[num];
+ for (int i = 0; i < num; i ++) {
+ checksums[i] = new CRC32();
+ }
+ return checksums;
+
+ default:
+ throw new SparkException("Unsupported shuffle checksum algorithm: " + algorithm);
+ }
+ }
+
+ public static long[] getChecksumValues(Checksum[] partitionChecksums) {
+ int numPartitions = partitionChecksums.length;
+ long[] checksumValues = new long[numPartitions];
+ for (int i = 0; i < numPartitions; i ++) {
+ checksumValues[i] = partitionChecksums[i].getValue();
+ }
+ return checksumValues;
+ }
+
+ public static String shuffleChecksumAlgorithm(SparkConf conf) {
+ return conf.get(package$.MODULE$.SHUFFLE_CHECKSUM_ALGORITHM());
+ }
+
+ public static Checksum getChecksumByFileExtension(String fileName) throws SparkException {
+ int index = fileName.lastIndexOf(".");
+ String algorithm = fileName.substring(index + 1);
+ return getChecksumByAlgorithm(1, algorithm)[0];
+ }
+
+ public static String getChecksumFileName(ShuffleChecksumBlockId blockId, SparkConf conf) {
+ // append the shuffle checksum algorithm as the file extension
+ return String.format("%s.%s", blockId.name(), shuffleChecksumAlgorithm(conf));
+ }
+}
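A hedged usage sketch of the helper defined above, showing how a writer obtains per-partition checksums and turns them into the `long[]` passed to `commitAllPartitions`; the surrounding setup (conf, partition count, data) is illustrative only:

```scala
import java.util.zip.Checksum

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper

val conf = new SparkConf()
val numPartitions = 4

// Empty array when spark.shuffle.checksum.enabled is false.
val checksums: Array[Checksum] =
  ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)

if (checksums.nonEmpty) {
  val bytes = "bytes for partition 0".getBytes("UTF-8")
  checksums(0).update(bytes, 0, bytes.length)
}

// These long values are what commitAllPartitions(long[] checksums) receives.
val values: Array[Long] = ShuffleChecksumHelper.getChecksumValues(checksums)
```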
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 3dbee1b..3222240 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -23,6 +23,7 @@ import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.FileChannel;
import java.util.Optional;
+import java.util.zip.Checksum;
import javax.annotation.Nullable;
import scala.None$;
@@ -38,6 +39,7 @@ import org.slf4j.LoggerFactory;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.SparkException;
import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
@@ -49,6 +51,7 @@ import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -93,6 +96,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private FileSegment[] partitionWriterSegments;
@Nullable private MapStatus mapStatus;
private long[] partitionLengths;
+ /** Checksum calculator for each partition. Empty when shuffle checksum disabled. */
+ private final Checksum[] partitionChecksums;
/**
* Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -107,7 +112,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
long mapId,
SparkConf conf,
ShuffleWriteMetricsReporter writeMetrics,
- ShuffleExecutorComponents shuffleExecutorComponents) {
+ ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -120,6 +125,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
this.shuffleExecutorComponents = shuffleExecutorComponents;
+ this.partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
}
@Override
@@ -129,7 +136,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
- partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths();
+ partitionLengths = mapOutputWriter.commitAllPartitions(
+ ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
mapStatus = MapStatus$.MODULE$.apply(
blockManager.shuffleServerId(), partitionLengths, mapId);
return;
@@ -143,8 +151,12 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
- partitionWriters[i] =
- blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ DiskBlockObjectWriter writer =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ if (partitionChecksums.length > 0) {
+ writer.setChecksum(partitionChecksums[i]);
+ }
+ partitionWriters[i] = writer;
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -218,7 +230,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
partitionWriters = null;
}
- return mapOutputWriter.commitAllPartitions().getPartitionLengths();
+ return mapOutputWriter.commitAllPartitions(
+ ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+ ).getPartitionLengths();
}
private void writePartitionedDataWithChannel(
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 833744f..0307027 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -21,7 +21,9 @@ import javax.annotation.Nullable;
import java.io.File;
import java.io.IOException;
import java.util.LinkedList;
+import java.util.zip.Checksum;
+import org.apache.spark.SparkException;
import scala.Tuple2;
import com.google.common.annotations.VisibleForTesting;
@@ -39,6 +41,7 @@ import org.apache.spark.memory.TooLargePageException;
import org.apache.spark.serializer.DummySerializerInstance;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.DiskBlockObjectWriter;
import org.apache.spark.storage.FileSegment;
@@ -107,6 +110,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
@Nullable private MemoryBlock currentPage = null;
private long pageCursor = -1;
+ // Checksum calculator for each partition. Empty when shuffle checksum disabled.
+ private final Checksum[] partitionChecksums;
+
ShuffleExternalSorter(
TaskMemoryManager memoryManager,
BlockManager blockManager,
@@ -114,7 +120,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
int initialSize,
int numPartitions,
SparkConf conf,
- ShuffleWriteMetricsReporter writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics) throws SparkException {
super(memoryManager,
(int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
memoryManager.getTungstenMemoryMode());
@@ -133,6 +139,12 @@ final class ShuffleExternalSorter extends MemoryConsumer {
this.peakMemoryUsedBytes = getMemoryUsage();
this.diskWriteBufferSize =
(int) (long) conf.get(package$.MODULE$.SHUFFLE_DISK_WRITE_BUFFER_SIZE());
+ this.partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf);
+ }
+
+ public long[] getChecksums() {
+ return ShuffleChecksumHelper.getChecksumValues(partitionChecksums);
}
/**
@@ -204,6 +216,9 @@ final class ShuffleExternalSorter extends MemoryConsumer {
spillInfo.partitionLengths[currentPartition] = fileSegment.length();
}
currentPartition = partition;
+ if (partitionChecksums.length > 0) {
+ writer.setChecksum(partitionChecksums[currentPartition]);
+ }
}
final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e8f94ba..2659b17 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -57,6 +57,7 @@ import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.ShufflePartitionWriter;
import org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter;
import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
@@ -115,7 +116,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
TaskContext taskContext,
SparkConf sparkConf,
ShuffleWriteMetricsReporter writeMetrics,
- ShuffleExecutorComponents shuffleExecutorComponents) {
+ ShuffleExecutorComponents shuffleExecutorComponents) throws SparkException {
final int numPartitions = handle.dependency().partitioner().numPartitions();
if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
throw new IllegalArgumentException(
@@ -198,7 +199,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
}
- private void open() {
+ private void open() throws SparkException {
assert (sorter == null);
sorter = new ShuffleExternalSorter(
memoryManager,
@@ -219,10 +220,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
serBuffer = null;
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
- sorter = null;
try {
partitionLengths = mergeSpills(spills);
} finally {
+ sorter = null;
for (SpillInfo spill : spills) {
if (spill.file.exists() && !spill.file.delete()) {
logger.error("Error while deleting spill file {}", spill.file.getPath());
@@ -267,7 +268,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
if (spills.length == 0) {
final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents
.createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
- return mapWriter.commitAllPartitions().getPartitionLengths();
+ return mapWriter.commitAllPartitions(
+ ShuffleChecksumHelper.EMPTY_CHECKSUM_VALUE).getPartitionLengths();
} else if (spills.length == 1) {
Optional<SingleSpillShuffleMapOutputWriter> maybeSingleFileWriter =
shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId);
@@ -277,7 +279,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
partitionLengths = spills[0].partitionLengths;
logger.debug("Merge shuffle spills for mapId {} with length {}", mapId,
partitionLengths.length);
- maybeSingleFileWriter.get().transferMapSpillFile(spills[0].file, partitionLengths);
+ maybeSingleFileWriter.get()
+ .transferMapSpillFile(spills[0].file, partitionLengths, sorter.getChecksums());
} else {
partitionLengths = mergeSpillsUsingStandardWriter(spills);
}
@@ -330,7 +333,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
// to be counted as shuffle write, but this will lead to double-counting of the final
// SpillInfo's bytes.
writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
- partitionLengths = mapWriter.commitAllPartitions().getPartitionLengths();
+ partitionLengths = mapWriter.commitAllPartitions(sorter.getChecksums()).getPartitionLengths();
} catch (Exception e) {
try {
mapWriter.abort(e);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
index 0b28626..6c5025d 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -98,7 +98,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
}
@Override
- public MapOutputCommitMessage commitAllPartitions() throws IOException {
+ public MapOutputCommitMessage commitAllPartitions(long[] checksums) throws IOException {
// Check the position after transferTo loop to see if it is in the right position and raise a
// exception if it is incorrect. The position will not be increased to the expected length
// after calling transferTo in kernel version 2.6.32. This issue is described at
@@ -115,7 +115,8 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
log.debug("Writing shuffle index file for mapId {} with length {}", mapId,
partitionLengths.length);
- blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+ blockResolver
+ .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, resolvedTmp);
return MapOutputCommitMessage.of(partitionLengths);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
index c8b4199..6a994b4 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java
@@ -44,12 +44,14 @@ public class LocalDiskSingleSpillMapOutputWriter
@Override
public void transferMapSpillFile(
File mapSpillFile,
- long[] partitionLengths) throws IOException {
+ long[] partitionLengths,
+ long[] checksums) throws IOException {
// The map spill file already has the proper format, and it contains all of the partition data.
// So just transfer it directly to the destination without any merging.
File outputFile = blockResolver.getDataFile(shuffleId, mapId);
File tempFile = Utils.tempFileWith(outputFile);
Files.move(mapSpillFile.toPath(), tempFile.toPath());
- blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tempFile);
+ blockResolver
+ .writeMetadataFileAndCommit(shuffleId, mapId, partitionLengths, checksums, tempFile);
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 613a66d..3ef964f 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1368,6 +1368,25 @@ package object config {
s"The buffer size must be greater than 0 and less than or equal to ${Int.MaxValue}.")
.createWithDefault(4096)
+ private[spark] val SHUFFLE_CHECKSUM_ENABLED =
+ ConfigBuilder("spark.shuffle.checksum.enabled")
+ .doc("Whether to calculate the checksum of shuffle output. If enabled, Spark will try " +
+ "its best to tell if shuffle data corruption is caused by network or disk or others.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ private[spark] val SHUFFLE_CHECKSUM_ALGORITHM =
+ ConfigBuilder("spark.shuffle.checksum.algorithm")
+ .doc("The algorithm used to calculate the checksum. Currently, it only supports" +
+ " built-in algorithms of JDK.")
+ .version("3.3.0")
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValue(Set("ADLER32", "CRC32").contains, "Shuffle checksum algorithm " +
+ "should be either Adler32 or CRC32.")
+ .createWithDefault("ADLER32")
+
private[spark] val SHUFFLE_COMPRESS =
ConfigBuilder("spark.shuffle.compress")
.doc("Whether to compress shuffle output. Compression will use " +
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
similarity index 51%
copy from core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
copy to core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
index cad8dcf..754b4a8 100644
--- a/core/src/main/java/org/apache/spark/shuffle/api/SingleSpillShuffleMapOutputWriter.java
+++ b/core/src/main/scala/org/apache/spark/io/MutableCheckedOutputStream.scala
@@ -15,22 +15,35 @@
* limitations under the License.
*/
-package org.apache.spark.shuffle.api;
+package org.apache.spark.io
-import java.io.File;
-import java.io.IOException;
-
-import org.apache.spark.annotation.Private;
+import java.io.OutputStream
+import java.util.zip.Checksum
/**
- * Optional extension for partition writing that is optimized for transferring a single
- * file to the backing store.
+ * A variant of [[java.util.zip.CheckedOutputStream]] which can
+ * change the checksum calculator at runtime.
*/
-@Private
-public interface SingleSpillShuffleMapOutputWriter {
+class MutableCheckedOutputStream(out: OutputStream) extends OutputStream {
+ private var checksum: Checksum = _
+
+ def setChecksum(c: Checksum): Unit = {
+ this.checksum = c
+ }
+
+ override def write(b: Int): Unit = {
+ assert(checksum != null, "Checksum is not set.")
+ checksum.update(b)
+ out.write(b)
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ assert(checksum != null, "Checksum is not set.")
+ checksum.update(b, off, len)
+ out.write(b, off, len)
+ }
+
+ override def flush(): Unit = out.flush()
- /**
- * Transfer a file that contains the bytes of all the partitions written by this map task.
- */
- void transferMapSpillFile(File mapOutputFile, long[] partitionLengths) throws IOException;
+ override def close(): Unit = out.close()
}
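The mutable variant exists because some writers keep a single underlying stream open while the current partition changes, so the checksum calculator must be swappable at partition boundaries. A small usage sketch, assuming the class above and an in-memory stream for illustration:

```scala
import java.io.ByteArrayOutputStream
import java.util.zip.Adler32

import org.apache.spark.io.MutableCheckedOutputStream

val out = new ByteArrayOutputStream()
val checked = new MutableCheckedOutputStream(out)
val perPartition = Array.fill(2)(new Adler32())

checked.setChecksum(perPartition(0))           // partition 0 starts
checked.write("partition 0".getBytes("UTF-8"))
checked.setChecksum(perPartition(1))           // switch calculator at the partition boundary
checked.write("partition 1".getBytes("UTF-8"))
```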
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5d1da19..9c50569 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -22,6 +22,8 @@ import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.file.Files
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.{SparkConf, SparkEnv, SparkException}
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.io.NioBufferedFileInputStream
@@ -31,6 +33,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.{ExecutorDiskUtils, MergedBlockMeta}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -142,17 +145,18 @@ private[spark] class IndexShuffleBlockResolver(
*/
def removeDataByMap(shuffleId: Int, mapId: Long): Unit = {
var file = getDataFile(shuffleId, mapId)
- if (file.exists()) {
- if (!file.delete()) {
- logWarning(s"Error deleting data ${file.getPath()}")
- }
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting data ${file.getPath()}")
}
file = getIndexFile(shuffleId, mapId)
- if (file.exists()) {
- if (!file.delete()) {
- logWarning(s"Error deleting index ${file.getPath()}")
- }
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting index ${file.getPath()}")
+ }
+
+ file = getChecksumFile(shuffleId, mapId)
+ if (file.exists() && !file.delete()) {
+ logWarning(s"Error deleting checksum ${file.getPath()}")
}
}
@@ -303,22 +307,41 @@ private[spark] class IndexShuffleBlockResolver(
/**
- * Write an index file with the offsets of each block, plus a final offset at the end for the
- * end of the output file. This will be used by getBlockData to figure out where each block
- * begins and ends.
+ * Commit the data and metadata files as an atomic operation, use the existing ones, or
+ * replace them with new ones. Note that the metadata parameters (`lengths`, `checksums`)
+ * will be updated to match the existing ones if use the existing ones.
+ *
+ * There're two kinds of metadata files:
*
- * It will commit the data and index file as an atomic operation, use the existing ones, or
- * replace them with new ones.
+ * - index file
+ * An index file contains the offsets of each block, plus a final offset at the end
+ * for the end of the output file. It will be used by [[getBlockData]] to figure out
+ * where each block begins and ends.
*
- * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
+ * - checksum file (optional)
+ * An checksum file contains the checksum of each block. It will be used to diagnose
+ * the cause when a block is corrupted. Note that empty `checksums` indicate that
+ * checksum is disabled.
*/
- def writeIndexFileAndCommit(
+ def writeMetadataFileAndCommit(
shuffleId: Int,
mapId: Long,
lengths: Array[Long],
+ checksums: Array[Long],
dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
val indexTmp = Utils.tempFileWith(indexFile)
+
+ val checksumEnabled = checksums.nonEmpty
+ val (checksumFileOpt, checksumTmpOpt) = if (checksumEnabled) {
+ assert(lengths.length == checksums.length,
+ "The size of partition lengths and checksums should be equal")
+ val checksumFile = getChecksumFile(shuffleId, mapId)
+ (Some(checksumFile), Some(Utils.tempFileWith(checksumFile)))
+ } else {
+ (None, None)
+ }
+
try {
val dataFile = getDataFile(shuffleId, mapId)
// There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
@@ -329,37 +352,47 @@ private[spark] class IndexShuffleBlockResolver(
// Another attempt for the same task has already written our map outputs successfully,
// so just use the existing partition lengths and delete our temporary map outputs.
System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+ if (checksumEnabled) {
+ val existingChecksums = getChecksums(checksumFileOpt.get, checksums.length)
+ if (existingChecksums != null) {
+ System.arraycopy(existingChecksums, 0, checksums, 0, lengths.length)
+ } else {
+ // It's possible that the previous task attempt succeeded writing the
+ // index file and data file but failed to write the checksum file. In
+ // this case, the current task attempt could write the missing checksum
+ // file by itself.
+ writeMetadataFile(checksums, checksumTmpOpt.get, checksumFileOpt.get, false)
+ }
+ }
if (dataTmp != null && dataTmp.exists()) {
dataTmp.delete()
}
} else {
// This is the first successful attempt in writing the map outputs for this task,
// so override any existing index and data files with the ones we wrote.
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
- Utils.tryWithSafeFinally {
- // We take in lengths of each block, need to convert it to offsets.
- var offset = 0L
- out.writeLong(offset)
- for (length <- lengths) {
- offset += length
- out.writeLong(offset)
- }
- } {
- out.close()
- }
- if (indexFile.exists()) {
- indexFile.delete()
- }
+ val offsets = lengths.scanLeft(0L)(_ + _)
+ writeMetadataFile(offsets, indexTmp, indexFile, true)
+
if (dataFile.exists()) {
dataFile.delete()
}
- if (!indexTmp.renameTo(indexFile)) {
- throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
- }
if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
}
+
+ // write the checksum file
+ checksumTmpOpt.zip(checksumFileOpt).foreach { case (checksumTmp, checksumFile) =>
+ try {
+ writeMetadataFile(checksums, checksumTmp, checksumFile, false)
+ } catch {
+ case e: Exception =>
+ // It's not worthwhile to fail here after index file and data file are
+ // already successfully stored since checksum is only a best-effort for
+ // the corner error case.
+ logError("Failed to write checksum file", e)
+ }
+ }
}
}
} finally {
@@ -367,6 +400,63 @@ private[spark] class IndexShuffleBlockResolver(
if (indexTmp.exists() && !indexTmp.delete()) {
logError(s"Failed to delete temporary index file at ${indexTmp.getAbsolutePath}")
}
+ checksumTmpOpt.foreach { checksumTmp =>
+ if (checksumTmp.exists()) {
+ try {
+ if (!checksumTmp.delete()) {
+ logError(s"Failed to delete temporary checksum file " +
+ s"at ${checksumTmp.getAbsolutePath}")
+ }
+ } catch {
+ case e: Exception =>
+ // Unlike index deletion, we won't propagate the error for the checksum file since
+ // checksum is only a best-effort.
+ logError(s"Failed to delete temporary checksum file " +
+ s"at ${checksumTmp.getAbsolutePath}", e)
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Write the metadata file (index or checksum). Metadata values will be firstly write into
+ * the tmp file and the tmp file will be renamed to the target file at the end to avoid dirty
+ * writes.
+ * @param metaValues The metadata values
+ * @param tmpFile The temp file
+ * @param targetFile The target file
+ * @param propagateError Whether to propagate the error for file operation. Unlike index file,
+ * checksum is only a best-effort so we won't fail the whole task due to
+ * the error from checksum.
+ */
+ private def writeMetadataFile(
+ metaValues: Array[Long],
+ tmpFile: File,
+ targetFile: File,
+ propagateError: Boolean): Unit = {
+ val out = new DataOutputStream(
+ new BufferedOutputStream(
+ new FileOutputStream(tmpFile)
+ )
+ )
+ Utils.tryWithSafeFinally {
+ metaValues.foreach(out.writeLong)
+ } {
+ out.close()
+ }
+
+ if (targetFile.exists()) {
+ targetFile.delete()
+ }
+
+ if (!tmpFile.renameTo(targetFile)) {
+ val errorMsg = s"fail to rename file $tmpFile to $targetFile"
+ if (propagateError) {
+ throw new IOException(errorMsg)
+ } else {
+ logWarning(errorMsg)
+ }
}
}
@@ -414,6 +504,45 @@ private[spark] class IndexShuffleBlockResolver(
new MergedBlockMeta(numChunks, chunkBitMaps)
}
+ private[shuffle] def getChecksums(checksumFile: File, blockNum: Int): Array[Long] = {
+ if (!checksumFile.exists()) return null
+ val checksums = new ArrayBuffer[Long]
+ // Read the checksums of blocks
+ var in: DataInputStream = null
+ try {
+ in = new DataInputStream(new NioBufferedFileInputStream(checksumFile))
+ while (checksums.size < blockNum) {
+ checksums += in.readLong()
+ }
+ } catch {
+ case _: IOException | _: EOFException =>
+ return null
+ } finally {
+ in.close()
+ }
+
+ checksums.toArray
+ }
+
+ /**
+ * Get the shuffle checksum file.
+ *
+ * When the dirs parameter is None then use the disk manager's local directories. Otherwise,
+ * read from the specified directories.
+ */
+ def getChecksumFile(
+ shuffleId: Int,
+ mapId: Long,
+ dirs: Option[Array[String]] = None): File = {
+ val blockId = ShuffleChecksumBlockId(shuffleId, mapId, NOOP_REDUCE_ID)
+ val fileName = ShuffleChecksumHelper.getChecksumFileName(blockId, conf)
+ dirs
+ .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, fileName))
+ .getOrElse {
+ blockManager.diskBlockManager.getFile(fileName)
+ }
+ }
+
override def getBlockData(
blockId: BlockId,
dirs: Option[Array[String]]): ManagedBuffer = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
index e0affb8..9843ae1 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionPairsWriter.scala
@@ -18,7 +18,9 @@
package org.apache.spark.shuffle
import java.io.{Closeable, IOException, OutputStream}
+import java.util.zip.Checksum
+import org.apache.spark.io.MutableCheckedOutputStream
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.api.ShufflePartitionWriter
import org.apache.spark.storage.{BlockId, TimeTrackingOutputStream}
@@ -34,7 +36,8 @@ private[spark] class ShufflePartitionPairsWriter(
serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
blockId: BlockId,
- writeMetrics: ShuffleWriteMetricsReporter)
+ writeMetrics: ShuffleWriteMetricsReporter,
+ checksum: Checksum)
extends PairsWriter with Closeable {
private var isClosed = false
@@ -44,6 +47,9 @@ private[spark] class ShufflePartitionPairsWriter(
private var objOut: SerializationStream = _
private var numRecordsWritten = 0
private var curNumBytesWritten = 0L
+ // this would be only initialized when checksum != null,
+ // which indicates shuffle checksum is enabled.
+ private var checksumOutputStream: MutableCheckedOutputStream = _
override def write(key: Any, value: Any): Unit = {
if (isClosed) {
@@ -61,7 +67,12 @@ private[spark] class ShufflePartitionPairsWriter(
try {
partitionStream = partitionWriter.openStream
timeTrackingStream = new TimeTrackingOutputStream(writeMetrics, partitionStream)
- wrappedStream = serializerManager.wrapStream(blockId, timeTrackingStream)
+ if (checksum != null) {
+ checksumOutputStream = new MutableCheckedOutputStream(timeTrackingStream)
+ checksumOutputStream.setChecksum(checksum)
+ }
+ wrappedStream = serializerManager.wrapStream(blockId,
+ if (checksumOutputStream != null) checksumOutputStream else timeTrackingStream)
objOut = serializerInstance.serializeStream(wrappedStream)
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index adbe6ec..3cbf301 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -68,7 +68,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+ partitionLengths = mapOutputWriter.commitAllPartitions(sorter.getChecksums).getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index dc70a9a..db5862d 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -92,6 +92,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
}
+@Since("3.3.0")
+@DeveloperApi
+case class ShuffleChecksumBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+ override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".checksum"
+}
+
@Since("3.2.0")
@DeveloperApi
case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index e55c0927..f5d8c02 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -19,8 +19,10 @@ package org.apache.spark.storage
import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
import java.nio.channels.{ClosedByInterruptException, FileChannel}
+import java.util.zip.Checksum
import org.apache.spark.internal.Logging
+import org.apache.spark.io.MutableCheckedOutputStream
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
@@ -77,6 +79,11 @@ private[spark] class DiskBlockObjectWriter(
private var streamOpen = false
private var hasBeenClosed = false
+ // checksum related
+ private var checksumEnabled = false
+ private var checksumOutputStream: MutableCheckedOutputStream = _
+ private var checksum: Checksum = _
+
/**
* Cursors used to represent positions in the file.
*
@@ -101,12 +108,30 @@ private[spark] class DiskBlockObjectWriter(
*/
private var numRecordsWritten = 0
+ /**
+ * Set the checksum that the checksumOutputStream should use
+ */
+ def setChecksum(checksum: Checksum): Unit = {
+ if (checksumOutputStream == null) {
+ this.checksumEnabled = true
+ this.checksum = checksum
+ } else {
+ checksumOutputStream.setChecksum(checksum)
+ }
+ }
+
private def initialize(): Unit = {
fos = new FileOutputStream(file, true)
channel = fos.getChannel()
ts = new TimeTrackingOutputStream(writeMetrics, fos)
+ if (checksumEnabled) {
+ assert(this.checksum != null, "Checksum is not set")
+ checksumOutputStream = new MutableCheckedOutputStream(ts)
+ checksumOutputStream.setChecksum(checksum)
+ }
class ManualCloseBufferedOutputStream
- extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+ extends BufferedOutputStream(if (checksumEnabled) checksumOutputStream else ts, bufferSize)
+ with ManualCloseOutputStream
mcs = new ManualCloseBufferedOutputStream
}
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 1913637..dba9e74 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
import org.apache.spark.shuffle.ShufflePartitionPairsWriter
import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter}
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}
import org.apache.spark.util.{CompletionIterator, Utils => TryUtils}
@@ -141,6 +142,11 @@ private[spark] class ExternalSorter[K, V, C](
private val forceSpillFiles = new ArrayBuffer[SpilledFile]
@volatile private var readingIterator: SpillableIterator = null
+ private val partitionChecksums =
+ ShuffleChecksumHelper.createPartitionChecksumsIfEnabled(numPartitions, conf)
+
+ def getChecksums: Array[Long] = ShuffleChecksumHelper.getChecksumValues(partitionChecksums)
+
// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -762,7 +768,8 @@ private[spark] class ExternalSorter[K, V, C](
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics)
+ context.taskMetrics().shuffleWriteMetrics,
+ if (partitionChecksums.nonEmpty) partitionChecksums(partitionId) else null)
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(partitionPairsWriter)
}
@@ -786,7 +793,8 @@ private[spark] class ExternalSorter[K, V, C](
serializerManager,
serInstance,
blockId,
- context.taskMetrics().shuffleWriteMetrics)
+ context.taskMetrics().shuffleWriteMetrics,
+ if (partitionChecksums.nonEmpty) partitionChecksums(id) else null)
if (elements.hasNext) {
for (elem <- elements) {
partitionPairsWriter.write(elem._1, elem._2)
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 5666bb3..cca3eb5 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -22,11 +22,11 @@ import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.util.*;
+import org.apache.spark.*;
+import org.apache.spark.shuffle.ShuffleChecksumTestHelper;
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper;
import org.mockito.stubbing.Answer;
-import scala.Option;
-import scala.Product2;
-import scala.Tuple2;
-import scala.Tuple2$;
+import scala.*;
import scala.collection.Iterator;
import com.google.common.collect.HashMultiset;
@@ -36,10 +36,6 @@ import org.junit.Test;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.io.CompressionCodec$;
@@ -65,7 +61,7 @@ import static org.junit.Assert.*;
import static org.mockito.Answers.RETURNS_SMART_NULLS;
import static org.mockito.Mockito.*;
-public class UnsafeShuffleWriterSuite {
+public class UnsafeShuffleWriterSuite implements ShuffleChecksumTestHelper {
static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
static final int NUM_PARTITIONS = 4;
@@ -138,7 +134,7 @@ public class UnsafeShuffleWriterSuite {
Answer<?> renameTempAnswer = invocationOnMock -> {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
- File tmp = (File) invocationOnMock.getArguments()[3];
+ File tmp = (File) invocationOnMock.getArguments()[4];
if (!mergedOutputFile.delete()) {
throw new RuntimeException("Failed to delete old merged output file.");
}
@@ -152,11 +148,13 @@ public class UnsafeShuffleWriterSuite {
doAnswer(renameTempAnswer)
.when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class));
+ .writeMetadataFileAndCommit(
+ anyInt(), anyLong(), any(long[].class), any(long[].class), any(File.class));
doAnswer(renameTempAnswer)
.when(shuffleBlockResolver)
- .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null));
+ .writeMetadataFileAndCommit(
+ anyInt(), anyLong(), any(long[].class), any(long[].class), eq(null));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -171,7 +169,14 @@ public class UnsafeShuffleWriterSuite {
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
}
- private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+ private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled)
+ throws SparkException {
+ return createWriter(transferToEnabled, shuffleBlockResolver);
+ }
+
+ private UnsafeShuffleWriter<Object, Object> createWriter(
+ boolean transferToEnabled,
+ IndexShuffleBlockResolver blockResolver) throws SparkException {
conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
return new UnsafeShuffleWriter<>(
blockManager,
@@ -181,7 +186,7 @@ public class UnsafeShuffleWriterSuite {
taskContext,
conf,
taskContext.taskMetrics().shuffleWriteMetrics(),
- new LocalDiskShuffleExecutorComponents(conf, blockManager, shuffleBlockResolver));
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver));
}
private void assertSpillFilesWereCleanedUp() {
@@ -219,12 +224,12 @@ public class UnsafeShuffleWriterSuite {
}
@Test(expected=IllegalStateException.class)
- public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+ public void mustCallWriteBeforeSuccessfulStop() throws IOException, SparkException {
createWriter(false).stop(true);
}
@Test
- public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+ public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException, SparkException {
createWriter(false).stop(false);
}
@@ -291,6 +296,69 @@ public class UnsafeShuffleWriterSuite {
assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.bytesWritten());
}
+ @Test
+ public void writeChecksumFileWithoutSpill() throws Exception {
+ IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+ ShuffleChecksumBlockId checksumBlockId =
+ new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
+ File checksumFile = new File(tempDir,
+ ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+ File dataFile = new File(tempDir, "data");
+ File indexFile = new File(tempDir, "index");
+ when(diskBlockManager.getFile(checksumFile.getName()))
+ .thenReturn(checksumFile);
+ when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
+ .thenReturn(dataFile);
+ when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
+ .thenReturn(indexFile);
+
+ // In this example, each partition should have exactly one record:
+ final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
+ for (int i = 0; i < NUM_PARTITIONS; i ++) {
+ dataToWrite.add(new Tuple2<>(i, i));
+ }
+ final UnsafeShuffleWriter<Object, Object> writer1 = createWriter(true, blockResolver);
+ writer1.write(dataToWrite.iterator());
+ writer1.stop(true);
+ assertTrue(checksumFile.exists());
+ assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
+ compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+ }
+
+ @Test
+ public void writeChecksumFileWithSpill() throws Exception {
+ IndexShuffleBlockResolver blockResolver = new IndexShuffleBlockResolver(conf, blockManager);
+ ShuffleChecksumBlockId checksumBlockId =
+ new ShuffleChecksumBlockId(0, 0, IndexShuffleBlockResolver.NOOP_REDUCE_ID());
+ File checksumFile =
+ new File(tempDir, ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf));
+ File dataFile = new File(tempDir, "data");
+ File indexFile = new File(tempDir, "index");
+ when(diskBlockManager.getFile(eq(checksumFile.getName()))).thenReturn(checksumFile);
+ when(diskBlockManager.getFile(new ShuffleDataBlockId(shuffleDep.shuffleId(), 0, 0)))
+ .thenReturn(dataFile);
+ when(diskBlockManager.getFile(new ShuffleIndexBlockId(shuffleDep.shuffleId(), 0, 0)))
+ .thenReturn(indexFile);
+
+ final UnsafeShuffleWriter<Object, Object> writer1 = createWriter(true, blockResolver);
+ writer1.insertRecordIntoSorter(new Tuple2<>(0, 0));
+ writer1.forceSorterToSpill();
+ writer1.insertRecordIntoSorter(new Tuple2<>(1, 0));
+ writer1.insertRecordIntoSorter(new Tuple2<>(2, 0));
+ writer1.forceSorterToSpill();
+ writer1.insertRecordIntoSorter(new Tuple2<>(0, 1));
+ writer1.insertRecordIntoSorter(new Tuple2<>(3, 0));
+ writer1.forceSorterToSpill();
+ writer1.insertRecordIntoSorter(new Tuple2<>(1, 1));
+ writer1.forceSorterToSpill();
+ writer1.insertRecordIntoSorter(new Tuple2<>(0, 2));
+ writer1.forceSorterToSpill();
+ writer1.closeAndWriteOutput();
+ assertTrue(checksumFile.exists());
+ assertEquals(checksumFile.length(), 8 * NUM_PARTITIONS);
+ compareChecksums(NUM_PARTITIONS, checksumFile, dataFile, indexFile);
+ }
+
private void testMergingSpills(
final boolean transferToEnabled,
String compressionCodecName,
@@ -317,7 +385,7 @@ public class UnsafeShuffleWriterSuite {
private void testMergingSpills(
boolean transferToEnabled,
- boolean encrypted) throws IOException {
+ boolean encrypted) throws IOException, SparkException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -515,7 +583,7 @@ public class UnsafeShuffleWriterSuite {
}
@Test
- public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+ public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException, SparkException {
final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
writer.insertRecordIntoSorter(new Tuple2<>(1, 1));
writer.insertRecordIntoSorter(new Tuple2<>(2, 2));
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
new file mode 100644
index 0000000..a8f2c40
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleChecksumTestHelper.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{DataInputStream, File, FileInputStream}
+import java.util.zip.CheckedInputStream
+
+import org.apache.spark.network.util.LimitedInputStream
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
+
+trait ShuffleChecksumTestHelper {
+
+ /**
+ * Ensure that the checksum values are consistent between write and read side.
+ */
+ def compareChecksums(numPartition: Int, checksum: File, data: File, index: File): Unit = {
+ assert(checksum.exists(), "Checksum file doesn't exist")
+ assert(data.exists(), "Data file doesn't exist")
+ assert(index.exists(), "Index file doesn't exist")
+
+ var checksumIn: DataInputStream = null
+ val expectChecksums = Array.ofDim[Long](numPartition)
+ try {
+ checksumIn = new DataInputStream(new FileInputStream(checksum))
+ (0 until numPartition).foreach(i => expectChecksums(i) = checksumIn.readLong())
+ } finally {
+ if (checksumIn != null) {
+ checksumIn.close()
+ }
+ }
+
+ var dataIn: FileInputStream = null
+ var indexIn: DataInputStream = null
+ var checkedIn: CheckedInputStream = null
+ try {
+ dataIn = new FileInputStream(data)
+ indexIn = new DataInputStream(new FileInputStream(index))
+ var prevOffset = indexIn.readLong
+ (0 until numPartition).foreach { i =>
+ val curOffset = indexIn.readLong
+ val limit = (curOffset - prevOffset).toInt
+ val bytes = new Array[Byte](limit)
+ val checksumCal = ShuffleChecksumHelper.getChecksumByFileExtension(checksum.getName)
+ checkedIn = new CheckedInputStream(
+ new LimitedInputStream(dataIn, curOffset - prevOffset), checksumCal)
+ checkedIn.read(bytes, 0, limit)
+ prevOffset = curOffset
+ // the checksum must be consistent on both the write and read sides
+ assert(checkedIn.getChecksum.getValue == expectChecksums(i))
+ }
+ } finally {
+ if (dataIn != null) {
+ dataIn.close()
+ }
+ if (indexIn != null) {
+ indexIn.close()
+ }
+ if (checkedIn != null) {
+ checkedIn.close()
+ }
+ }
+ }
+}
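The helper above only verifies checksums. For orientation, here is a minimal standalone sketch of the write-side layout it expects (this is not the writer code from this patch): each partition's bytes pass through a JDK `CheckedOutputStream`, and one 8-byte value per partition is appended to the checksum file, which is why the suites assert a checksum file length of `8 * numPartitions`. `Adler32` is an assumption; the concrete algorithm comes from the `SHUFFLE_CHECKSUM_ALGORITHM` configuration exercised in `IndexShuffleBlockResolverSuite` below.

import java.io.{DataOutputStream, FileOutputStream}
import java.util.zip.{Adler32, CheckedOutputStream}

object WriteSideChecksumSketch {
  // Writes all partitions into one data file and one long per partition into
  // the checksum file, mirroring the layout compareChecksums reads back.
  def writePartitionsWithChecksums(
      partitions: Seq[Array[Byte]],
      dataPath: String,
      checksumPath: String): Unit = {
    val dataOut = new FileOutputStream(dataPath)
    val checksumOut = new DataOutputStream(new FileOutputStream(checksumPath))
    try {
      partitions.foreach { bytes =>
        // Wrap the shared data stream so this partition's bytes update a fresh checksum.
        val checked = new CheckedOutputStream(dataOut, new Adler32)
        checked.write(bytes)
        checked.flush()
        checksumOut.writeLong(checked.getChecksum.getValue)
      }
    } finally {
      dataOut.close()
      checksumOut.close()
    }
  }
}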
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 7fd0bf6..39eef97 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,13 +33,17 @@ import org.apache.spark._
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.checksum.ShuffleChecksumHelper
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
import org.apache.spark.storage._
import org.apache.spark.util.Utils
-class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+class BypassMergeSortShuffleWriterSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with ShuffleChecksumTestHelper {
@Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
@Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
@@ -76,10 +80,10 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
- when(blockResolver.writeIndexFileAndCommit(
- anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+ when(blockResolver.writeMetadataFileAndCommit(
+ anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
.thenAnswer { invocationOnMock =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
outputFile.delete
tmp.renameTo(outputFile)
@@ -236,4 +240,43 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
+
+ test("write checksum file") {
+ val blockResolver = new IndexShuffleBlockResolver(conf, blockManager)
+ val shuffleId = shuffleHandle.shuffleId
+ val mapId = 0
+ val checksumBlockId = ShuffleChecksumBlockId(shuffleId, mapId, 0)
+ val dataBlockId = ShuffleDataBlockId(shuffleId, mapId, 0)
+ val indexBlockId = ShuffleIndexBlockId(shuffleId, mapId, 0)
+ val checksumFile = new File(tempDir,
+ ShuffleChecksumHelper.getChecksumFileName(checksumBlockId, conf))
+ val dataFile = new File(tempDir, dataBlockId.name)
+ val indexFile = new File(tempDir, indexBlockId.name)
+ reset(diskBlockManager)
+ when(diskBlockManager.getFile(checksumFile.getName)).thenAnswer(_ => checksumFile)
+ when(diskBlockManager.getFile(dataBlockId)).thenAnswer(_ => dataFile)
+ when(diskBlockManager.getFile(indexBlockId)).thenAnswer(_ => indexFile)
+ when(diskBlockManager.createTempShuffleBlock())
+ .thenAnswer { _ =>
+ val blockId = new TempShuffleBlockId(UUID.randomUUID)
+ val file = new File(tempDir, blockId.name)
+ temporaryFilesCreated += file
+ (blockId, file)
+ }
+
+ val numPartition = shuffleHandle.dependency.partitioner.numPartitions
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ blockManager,
+ shuffleHandle,
+ mapId,
+ conf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ new LocalDiskShuffleExecutorComponents(conf, blockManager, blockResolver))
+
+ writer.write(Iterator((0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6)))
+ writer.stop( /* success = */ true)
+ assert(checksumFile.exists())
+ assert(checksumFile.length() === 8 * numPartition)
+ compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+ }
}
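The mock of `writeMetadataFileAndCommit` above only renames the temporary data file; if a test also wanted to inspect the checksum array the writer hands over, a Mockito `ArgumentCaptor` could capture it. A self-contained sketch against a hypothetical `MetadataWriter` trait (a stand-in with the same argument shape, not the real `IndexShuffleBlockResolver` API):

import java.io.File
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.{any, anyInt, anyLong}
import org.mockito.Mockito.{mock, verify}

// Hypothetical stand-in mirroring the argument shape of the mocked method above.
trait MetadataWriter {
  def writeMetadataFileAndCommit(
      shuffleId: Int,
      mapId: Long,
      partitionLengths: Array[Long],
      checksums: Array[Long],
      dataTmp: File): Unit
}

object CaptureChecksumsSketch {
  def main(args: Array[String]): Unit = {
    val writer = mock(classOf[MetadataWriter])
    writer.writeMetadataFileAndCommit(0, 0L, Array(1L, 2L), Array(10L, 20L), null)

    // Capture the checksums argument while matching the rest loosely.
    val captor: ArgumentCaptor[Array[Long]] = ArgumentCaptor.forClass(classOf[Array[Long]])
    verify(writer).writeMetadataFileAndCommit(
      anyInt, anyLong, any(classOf[Array[Long]]), captor.capture(), any(classOf[File]))
    assert(captor.getValue.sameElements(Array(10L, 20L)))
  }
}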
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
index 5955d44..49c079c 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -28,6 +28,7 @@ import org.roaringbitmap.RoaringBitmap
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.{MapOutputTracker, SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleBlockInfo}
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -49,6 +50,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
(invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
+ when(diskBlockManager.getFile(any[String])).thenAnswer(
+ (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
when(diskBlockManager.getMergedShuffleFile(
any[BlockId], any[Option[Array[String]]])).thenAnswer(
(invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString))
@@ -77,7 +80,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
} {
out.close()
}
- resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+ resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths, Array.empty, dataTmp)
val indexFile = new File(tempDir.getAbsolutePath, idxName)
val dataFile = resolver.getDataFile(shuffleId, mapId)
@@ -97,7 +100,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
} {
out2.close()
}
- resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2)
+ resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths2, Array.empty, dataTmp2)
assert(indexFile.length() === (lengths.length + 1) * 8)
assert(lengths2.toSeq === lengths.toSeq)
@@ -136,7 +139,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
} {
out3.close()
}
- resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3)
+ resolver.writeMetadataFileAndCommit(shuffleId, mapId, lengths3, Array.empty, dataTmp3)
assert(indexFile.length() === (lengths3.length + 1) * 8)
assert(lengths3.toSeq != lengths.toSeq)
assert(dataFile.exists())
@@ -248,4 +251,19 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
outIndex.close()
}
}
+
+ test("write checksum file") {
+ val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+ val dataTmp = File.createTempFile("shuffle", null, tempDir)
+ val indexInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+ val checksumsInMemory = Array[Long](0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
+ resolver.writeMetadataFileAndCommit(0, 0, indexInMemory, checksumsInMemory, dataTmp)
+ val checksumFile = resolver.getChecksumFile(0, 0)
+ assert(checksumFile.exists())
+ val checksumFileName = checksumFile.toString
+ val checksumAlgo = checksumFileName.substring(checksumFileName.lastIndexOf(".") + 1)
+ assert(checksumAlgo === conf.get(config.SHUFFLE_CHECKSUM_ALGORITHM))
+ val checksumsFromFile = resolver.getChecksums(checksumFile, 10)
+ assert(checksumsInMemory === checksumsFromFile)
+ }
}
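The new test reads the algorithm back out of the checksum file name by taking everything after the last `.`, which is also what `ShuffleChecksumHelper.getChecksumByFileExtension` keys on in the helper trait above. A hedged sketch of what that extension-to-`Checksum` mapping could look like; the algorithm names (`ADLER32`, `CRC32`) are assumptions here, not a statement of exactly what `SHUFFLE_CHECKSUM_ALGORITHM` accepts:

import java.util.zip.{Adler32, CRC32, Checksum}

object ChecksumExtensionSketch {
  // Illustrative: maps whatever follows the last '.' in the checksum file name
  // to a JDK Checksum implementation.
  def checksumForFileName(fileName: String): Checksum = {
    val extension = fileName.substring(fileName.lastIndexOf('.') + 1)
    extension.toUpperCase match {
      case "ADLER32" => new Adler32
      case "CRC32" => new CRC32
      case other =>
        throw new IllegalArgumentException(s"Unsupported checksum algorithm: $other")
    }
  }
}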
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
index 4c679fd..e345736 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -20,19 +20,25 @@ package org.apache.spark.shuffle.sort
import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.Mockito._
+import org.scalatest.PrivateMethodTester
import org.scalatest.matchers.must.Matchers
-import org.apache.spark.{Partitioner, SharedSparkContext, ShuffleDependency, SparkFunSuite}
+import org.apache.spark.{Aggregator, DebugFilesystem, Partitioner, SharedSparkContext, ShuffleDependency, SparkContext, SparkFunSuite}
import org.apache.spark.memory.MemoryTestingUtils
import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleChecksumTestHelper}
import org.apache.spark.shuffle.api.ShuffleExecutorComponents
import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.ExternalSorter
-
-class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with Matchers {
+class SortShuffleWriterSuite
+ extends SparkFunSuite
+ with SharedSparkContext
+ with Matchers
+ with PrivateMethodTester
+ with ShuffleChecksumTestHelper {
@Mock(answer = RETURNS_SMART_NULLS)
private var blockManager: BlockManager = _
@@ -44,13 +50,14 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
private val serializer = new JavaSerializer(conf)
private var shuffleExecutorComponents: ShuffleExecutorComponents = _
+ private val partitioner = new Partitioner() {
+ def numPartitions = numMaps
+ def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
+ }
+
override def beforeEach(): Unit = {
super.beforeEach()
MockitoAnnotations.openMocks(this).close()
- val partitioner = new Partitioner() {
- def numPartitions = numMaps
- def getPartition(key: Any) = Utils.nonNegativeMod(key.hashCode, numPartitions)
- }
shuffleHandle = {
val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
when(dependency.partitioner).thenReturn(partitioner)
@@ -103,4 +110,68 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with
assert(dataFile.length() === writeMetrics.bytesWritten)
assert(records.size === writeMetrics.recordsWritten)
}
+
+ Seq((true, false, false),
+ (true, true, false),
+ (true, false, true),
+ (true, true, true),
+ (false, false, false),
+ (false, true, false),
+ (false, false, true),
+ (false, true, true)).foreach { case (doSpill, doAgg, doOrder) =>
+ test(s"write checksum file (spill=$doSpill, aggregator=$doAgg, order=$doOrder)") {
+ val aggregator = if (doAgg) {
+ Some(Aggregator[Int, Int, Int](
+ v => v,
+ (c, v) => c + v,
+ (c1, c2) => c1 + c2))
+ } else None
+ val order = if (doOrder) {
+ Some(new Ordering[Int] {
+ override def compare(x: Int, y: Int): Int = x - y
+ })
+ } else None
+
+ val shuffleHandle = {
+ val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+ when(dependency.partitioner).thenReturn(partitioner)
+ when(dependency.serializer).thenReturn(serializer)
+ when(dependency.aggregator).thenReturn(aggregator)
+ when(dependency.keyOrdering).thenReturn(order)
+ new BaseShuffleHandle[Int, Int, Int](shuffleId, dependency)
+ }
+
+ // FIXME: this can affect other tests (if any) after this set of tests
+ // since `sc` is global.
+ sc.stop()
+ conf.set("spark.shuffle.spill.numElementsForceSpillThreshold",
+ if (doSpill) "0" else Int.MaxValue.toString)
+ conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+ val localSC = new SparkContext("local[4]", "test", conf)
+ val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
+ val context = MemoryTestingUtils.fakeTaskContext(localSC.env)
+ val records = List[(Int, Int)](
+ (0, 1), (1, 2), (0, 2), (1, 3), (2, 3), (3, 4), (4, 5), (3, 5), (4, 6))
+ val numPartition = shuffleHandle.dependency.partitioner.numPartitions
+ val writer = new SortShuffleWriter[Int, Int, Int](
+ shuffleHandle,
+ mapId = 0,
+ context,
+ new LocalDiskShuffleExecutorComponents(
+ conf, shuffleBlockResolver._blockManager, shuffleBlockResolver))
+ writer.write(records.toIterator)
+ val sorterMethod = PrivateMethod[ExternalSorter[_, _, _]](Symbol("sorter"))
+ val sorter = writer.invokePrivate(sorterMethod())
+ val expectSpillSize = if (doSpill) records.size else 0
+ assert(sorter.numSpills === expectSpillSize)
+ writer.stop(success = true)
+ val checksumFile = shuffleBlockResolver.getChecksumFile(shuffleId, 0)
+ assert(checksumFile.exists())
+ assert(checksumFile.length() === 8 * numPartition)
+ val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 0)
+ val indexFile = shuffleBlockResolver.getIndexFile(shuffleId, 0)
+ compareChecksums(numPartition, checksumFile, dataFile, indexFile)
+ localSC.stop()
+ }
+ }
}
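The spill assertion above reaches into the writer's private `sorter` member with scalatest's `PrivateMethodTester`. For readers unfamiliar with that pattern, a minimal self-contained sketch (the `Counter`/`current` names are made up for illustration):

import org.scalatest.PrivateMethodTester

object PrivateAccessSketch extends PrivateMethodTester {
  class Counter {
    private def current: Int = 42
  }

  def main(args: Array[String]): Unit = {
    // Describe the private member by name, then invoke it reflectively.
    val currentMethod = PrivateMethod[Int](Symbol("current"))
    val value = new Counter().invokePrivate(currentMethod())
    assert(value == 42)
  }
}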
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
index ef5c615..35d9b4a 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -74,11 +74,11 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
.set("spark.app.id", "example.spark.app")
.set("spark.shuffle.unsafe.file.output.buffer", "16k")
when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile)
- when(blockResolver.writeIndexFileAndCommit(
- anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File])))
+ when(blockResolver.writeMetadataFileAndCommit(
+ anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[Array[Long]]), any(classOf[File])))
.thenAnswer { invocationOnMock =>
partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
- val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ val tmp: File = invocationOnMock.getArguments()(4).asInstanceOf[File]
if (tmp != null) {
mergedOutputFile.delete()
tmp.renameTo(mergedOutputFile)
@@ -136,7 +136,8 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA
}
private def verifyWrittenRecords(): Unit = {
- val committedLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+ val committedLengths =
+ mapOutputWriter.commitAllPartitions(Array.empty[Long]).getPartitionLengths
assert(partitionSizesInMergedFile === partitionLengths)
assert(committedLengths === partitionLengths)
assert(mergedOutputFile.length() === partitionLengths.sum)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index a6de64b..7bec961 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -500,15 +500,15 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
intercept[SparkException] {
data.reduceByKey(_ + _).count()
}
- // After the shuffle, there should be only 2 files on disk: the output of task 1 and
- // its index. All other files (map 2's output and intermediate merge files) should
- // have been deleted.
- assert(diskBlockManager.getAllFiles().length === 2)
+ // After the shuffle, there should be only 3 files on disk: the output of task 1 and
+ // its index and checksum. All other files (map 2's output and intermediate merge files)
+ // should have been deleted.
+ assert(diskBlockManager.getAllFiles().length === 3)
} else {
assert(data.reduceByKey(_ + _).count() === size)
- // After the shuffle, there should be only 4 files on disk: the output of both tasks
- // and their indices. All intermediate merge files should have been deleted.
- assert(diskBlockManager.getAllFiles().length === 4)
+ // After the shuffle, there should be only 6 files on disk: the output of both tasks
+ // and their indices/checksums. All intermediate merge files should have been deleted.
+ assert(diskBlockManager.getAllFiles().length === 6)
}
}
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3b93a34..dba74ac 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -61,7 +61,15 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.WritablePartitionedIterator"),
// [SPARK-35757][CORE] Add bitwise AND operation and functionality for intersecting bloom filters
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.BloomFilter.intersectInPlace"),
+
+ // [SPARK-35276][CORE] Calculate checksum for shuffle data and write as checksum file
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.io.LocalDiskSingleSpillMapOutputWriter.transferMapSpillFile"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.SingleSpillShuffleMapOutputWriter.transferMapSpillFile"),
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions")
)
// Exclude rules for 3.1.x