You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by va...@apache.org on 2019/07/30 21:18:02 UTC

[spark] branch master updated: [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API

This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new abef84a  [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
abef84a is described below

commit abef84a868e9e15f346eea315bbab0ec8ac8e389
Author: mcheah <mc...@palantir.com>
AuthorDate: Tue Jul 30 14:17:30 2019 -0700

    [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
    
    ## What changes were proposed in this pull request?
    
    As part of the shuffle storage API proposed in SPARK-25299, this introduces an API for persisting shuffle data in arbitrary storage systems.
    
    This patch introduces several concepts:
    * `ShuffleDataIO`, which is the root of the entire plugin tree that will be proposed over the course of the shuffle API project.
    * `ShuffleExecutorComponents` - the subset of plugins for managing shuffle-related components for each executor. This will in turn instantiate shuffle readers and writers.
    * `ShuffleMapOutputWriter` interface - instantiated once per map task. This provides child `ShufflePartitionWriter` instances for persisting the bytes for each partition in the map task.
    
    The default implementation of these plugins exactly mirrors what the existing shuffle writing code does - namely, writing the data to local disk and writing an index file. In this patch, only `BypassMergeSortShuffleWriter` uses the new APIs. Adopting them in `SortShuffleWriter` and `UnsafeShuffleWriter` is left to follow-up PRs to minimize the review surface area.
    
    ## How was this patch tested?
    
    New unit tests were added. Micro-benchmarks indicate there's no slowdown in the affected code paths.
    
    Closes #25007 from mccheah/spark-shuffle-writer-refactor.
    
    Lead-authored-by: mcheah <mc...@palantir.com>
    Co-authored-by: mccheah <mc...@palantir.com>
    Signed-off-by: Marcelo Vanzin <va...@cloudera.com>
---
 .../apache/spark/shuffle/api/ShuffleDataIO.java    |  49 ++++
 .../shuffle/api/ShuffleExecutorComponents.java     |  55 +++++
 .../spark/shuffle/api/ShuffleMapOutputWriter.java  |  71 ++++++
 .../spark/shuffle/api/ShufflePartitionWriter.java  |  98 ++++++++
 .../shuffle/api/WritableByteChannelWrapper.java    |  42 ++++
 .../shuffle/sort/BypassMergeSortShuffleWriter.java | 173 +++++++++-----
 .../shuffle/sort/io/LocalDiskShuffleDataIO.java    |  40 ++++
 .../io/LocalDiskShuffleExecutorComponents.java     |  71 ++++++
 .../sort/io/LocalDiskShuffleMapOutputWriter.java   | 261 +++++++++++++++++++++
 .../org/apache/spark/internal/config/package.scala |   7 +
 .../spark/shuffle/sort/SortShuffleManager.scala    |  25 +-
 .../main/scala/org/apache/spark/util/Utils.scala   |  30 ++-
 .../test/scala/org/apache/spark/ShuffleSuite.scala |  16 +-
 .../sort/BypassMergeSortShuffleWriterSuite.scala   | 149 +++++++-----
 .../io/LocalDiskShuffleMapOutputWriterSuite.scala  | 147 ++++++++++++
 15 files changed, 1087 insertions(+), 147 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
new file mode 100644
index 0000000..e9e50ec
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for plugging in modules for storing and reading temporary shuffle data.
+ * <p>
+ * This is the root of a plugin system for storing shuffle bytes to arbitrary storage
+ * backends in the sort-based shuffle algorithm implemented by the
+ * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is
+ * needed instead of sort-based shuffle, one should implement
+ * {@link org.apache.spark.shuffle.ShuffleManager} instead.
+ * <p>
+ * A single instance of this module is loaded per process in the Spark application.
+ * The default implementation reads and writes shuffle data from the local disks of
+ * the executor, and is the implementation of shuffle file storage that has remained
+ * consistent throughout most of Spark's history.
+ * <p>
+ * Alternative implementations of shuffle data storage can be loaded via setting
+ * <code>spark.shuffle.sort.io.plugin.class</code>.
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleDataIO {
+
+  /**
+   * Called once on executor processes to bootstrap the shuffle data storage modules that
+   * are only invoked on the executors.
+   */
+  ShuffleExecutorComponents executor();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
new file mode 100644
index 0000000..70c112b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for building shuffle support for Executors.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleExecutorComponents {
+
+  /**
+   * Called once per executor to bootstrap this module with state that is specific to
+   * that executor, specifically the application ID and executor ID.
+   */
+  void initializeExecutor(String appId, String execId);
+
+  /**
+   * Called once per map task to create a writer that will be responsible for persisting all the
+   * partitioned bytes written by that map task.
+   * @param shuffleId Unique identifier for the shuffle the map task is a part of
+   * @param mapId Within the shuffle, the identifier of the map task
+   * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
+   *                         with the same (shuffleId, mapId) pair can be distinguished by the
+   *                         different values of mapTaskAttemptId.
+   * @param numPartitions The number of partitions that will be written by the map task. Some of
+   *                      these partitions may be empty.
+   */
+  ShuffleMapOutputWriter createMapOutputWriter(
+      int shuffleId,
+      int mapId,
+      long mapTaskAttemptId,
+      int numPartitions) throws IOException;
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
new file mode 100644
index 0000000..45a593c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * A top-level writer that returns child writers for persisting the output of a map task,
+ * and then commits all of the writes as one atomic operation.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleMapOutputWriter {
+
+  /**
+   * Creates a writer that can open an output stream to persist bytes targeted for a given reduce
+   * partition id.
+   * <p>
+   * The chunk corresponds to bytes in the given reduce partition. This will not be called twice
+   * for the same partition within any given map task. The partition identifier will be in the
+   * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was
+   * provided upon the creation of this map output writer via
+   * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}.
+   * <p>
+   * This method will be invoked with monotonically increasing reducePartitionIds; each call
+   * to this method will pass a reducePartitionId that is strictly greater than the
+   * reducePartitionIds given to any previous call to this method. This method is not
+   * guaranteed to be called for every partition id in the above described range. In particular,
+   * no guarantees are made as to whether or not this method will be called for empty partitions.
+   */
+  ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException;
+
+  /**
+   * Commits the writes done by all partition writers returned by all calls to this object's
+   * {@link #getPartitionWriter(int)}.
+   * <p>
+   * This should ensure that the writes conducted by this module's partition writers are
+   * available to downstream reduce tasks. If this method throws any exception, this module's
+   * {@link #abort(Throwable)} method will be invoked before propagating the exception.
+   * <p>
+   * This can also close any resources and clean up temporary state if necessary.
+   */
+  void commitAllPartitions() throws IOException;
+
+  /**
+   * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
+   * <p>
+   * This should invalidate the results of writing bytes. This can also close any resources and
+   * clean up temporary state if necessary.
+   */
+  void abort(Throwable error) throws IOException;
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
new file mode 100644
index 0000000..9288751
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+import java.util.Optional;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for opening streams to persist partition bytes to a backing data store.
+ * <p>
+ * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle
+ * block.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShufflePartitionWriter {
+
+  /**
+   * Open and return an {@link OutputStream} that can write bytes to the underlying
+   * data store.
+   * <p>
+   * This method will only be called once on this partition writer in the map task, to write the
+   * bytes to the partition. The output stream will only be used to write the bytes for this
+   * partition. The map task closes this output stream upon writing all the bytes for this
+   * block, or if the write fails for any reason.
+   * <p>
+   * Implementations that intend to combine the bytes for all the partitions written by this
+   * map task should reuse the same OutputStream instance across all the partition writers provided
+   * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
+   * {@link OutputStream#close()} does not close the resource, since it will be reused across
+   * partition writes. The underlying resources should be cleaned up in
+   * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+   * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+   */
+  OutputStream openStream() throws IOException;
+
+  /**
+   * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from
+   * input byte channels to the underlying shuffle data store.
+   * <p>
+   * This method will only be called once on this partition writer in the map task, to write the
+   * bytes to the partition. The channel will only be used to write the bytes for this
+   * partition. The map task closes this channel upon writing all the bytes for this
+   * block, or if the write fails for any reason.
+   * <p>
+   * Implementations that intend to combine the bytes for all the partitions written by this
+   * map task should reuse the same channel instance across all the partition writers provided
+   * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
+   * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
+   * will be reused across partition writes. The underlying resources should be cleaned up in
+   * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+   * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+   * <p>
+   * This method is primarily for advanced optimizations where bytes can be copied from the input
+   * spill files to the output channel without copying data into memory. If such optimizations are
+   * not supported, the implementation should return {@link Optional#empty()}. By default, the
+   * implementation returns {@link Optional#empty()}.
+   * <p>
+   * Note that the map task closes the returned {@link WritableByteChannelWrapper} itself, but not the
+   * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
+   * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
+   * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
+   * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+   */
+  default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
+    return Optional.empty();
+  }
+
+  /**
+   * Returns the number of bytes written either by this writer's output stream opened by
+   * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}.
+   * <p>
+   * This can be different from the number of bytes given by the caller. For example, the
+   * stream might compress or encrypt the bytes before persisting the data to the backing
+   * data store.
+   */
+  long getNumBytesWritten();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java
new file mode 100644
index 0000000..a204903
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.Closeable;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ *
+ * A thin wrapper around a {@link WritableByteChannel}.
+ * <p>
+ * This is primarily provided for the local disk shuffle implementation to provide a
+ * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface WritableByteChannelWrapper extends Closeable {
+
+  /**
+   * The underlying channel to write bytes into.
+   */
+  WritableByteChannel channel();
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 32b4467..3ccee70 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -19,8 +19,10 @@ package org.apache.spark.shuffle.sort;
 
 import java.io.File;
 import java.io.FileInputStream;
-import java.io.FileOutputStream;
 import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.util.Optional;
 import javax.annotation.Nullable;
 
 import scala.None$;
@@ -34,16 +36,19 @@ import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.internal.config.package$;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.internal.config.package$;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
@@ -81,8 +86,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
+  private final long mapTaskAttemptId;
   private final Serializer serializer;
-  private final IndexShuffleBlockResolver shuffleBlockResolver;
+  private final ShuffleExecutorComponents shuffleExecutorComponents;
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
@@ -99,74 +105,82 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   BypassMergeSortShuffleWriter(
       BlockManager blockManager,
-      IndexShuffleBlockResolver shuffleBlockResolver,
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
+      long mapTaskAttemptId,
       SparkConf conf,
-      ShuffleWriteMetricsReporter writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics,
+      ShuffleExecutorComponents shuffleExecutorComponents) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
     this.blockManager = blockManager;
     final ShuffleDependency<K, V, V> dep = handle.dependency();
     this.mapId = mapId;
+    this.mapTaskAttemptId = mapTaskAttemptId;
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
     this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
-    this.shuffleBlockResolver = shuffleBlockResolver;
+    this.shuffleExecutorComponents = shuffleExecutorComponents;
   }
 
   @Override
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
-    if (!records.hasNext()) {
-      partitionLengths = new long[numPartitions];
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
-      return;
-    }
-    final SerializerInstance serInstance = serializer.newInstance();
-    final long openStartTime = System.nanoTime();
-    partitionWriters = new DiskBlockObjectWriter[numPartitions];
-    partitionWriterSegments = new FileSegment[numPartitions];
-    for (int i = 0; i < numPartitions; i++) {
-      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
-        blockManager.diskBlockManager().createTempShuffleBlock();
-      final File file = tempShuffleBlockIdPlusFile._2();
-      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
-      partitionWriters[i] =
-        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
-    }
-    // Creating the file to write to and creating a disk writer both involve interacting with
-    // the disk, and can take a long time in aggregate when we open many files, so should be
-    // included in the shuffle write time.
-    writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
-    }
+    ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents
+        .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
+    try {
+      if (!records.hasNext()) {
+        partitionLengths = new long[numPartitions];
+        mapOutputWriter.commitAllPartitions();
+        mapStatus = MapStatus$.MODULE$.apply(
+            blockManager.shuffleServerId(),
+            partitionLengths);
+        return;
+      }
+      final SerializerInstance serInstance = serializer.newInstance();
+      final long openStartTime = System.nanoTime();
+      partitionWriters = new DiskBlockObjectWriter[numPartitions];
+      partitionWriterSegments = new FileSegment[numPartitions];
+      for (int i = 0; i < numPartitions; i++) {
+        final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+            blockManager.diskBlockManager().createTempShuffleBlock();
+        final File file = tempShuffleBlockIdPlusFile._2();
+        final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+        partitionWriters[i] =
+            blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+      }
+      // Creating the file to write to and creating a disk writer both involve interacting with
+      // the disk, and can take a long time in aggregate when we open many files, so should be
+      // included in the shuffle write time.
+      writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
 
-    for (int i = 0; i < numPartitions; i++) {
-      try (DiskBlockObjectWriter writer = partitionWriters[i]) {
-        partitionWriterSegments[i] = writer.commitAndGet();
+      while (records.hasNext()) {
+        final Product2<K, V> record = records.next();
+        final K key = record._1();
+        partitionWriters[partitioner.getPartition(key)].write(key, record._2());
       }
-    }
 
-    File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    File tmp = Utils.tempFileWith(output);
-    try {
-      partitionLengths = writePartitionedFile(tmp);
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+      for (int i = 0; i < numPartitions; i++) {
+        try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+          partitionWriterSegments[i] = writer.commitAndGet();
+        }
+      }
+
+      partitionLengths = writePartitionedData(mapOutputWriter);
+      mapOutputWriter.commitAllPartitions();
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+    } catch (Exception e) {
+      try {
+        mapOutputWriter.abort(e);
+      } catch (Exception e2) {
+        logger.error("Failed to abort the writer after failing to write map output.", e2);
+        e.addSuppressed(e2);
       }
+      throw e;
     }
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
   @VisibleForTesting
@@ -179,43 +193,80 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
    *
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
    */
-  private long[] writePartitionedFile(File outputFile) throws IOException {
+  private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
       // We were passed an empty iterator
       return lengths;
     }
-
-    final FileOutputStream out = new FileOutputStream(outputFile, true);
     final long writeStartTime = System.nanoTime();
-    boolean threwException = true;
     try {
       for (int i = 0; i < numPartitions; i++) {
         final File file = partitionWriterSegments[i].file();
+        ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
         if (file.exists()) {
-          final FileInputStream in = new FileInputStream(file);
-          boolean copyThrewException = true;
-          try {
-            lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
-            copyThrewException = false;
-          } finally {
-            Closeables.close(in, copyThrewException);
+          if (transferToEnabled) {
+            // Using WritableByteChannelWrapper to make resource closing consistent between
+            // this implementation and UnsafeShuffleWriter.
+            Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
+            if (maybeOutputChannel.isPresent()) {
+              writePartitionedDataWithChannel(file, maybeOutputChannel.get());
+            } else {
+              writePartitionedDataWithStream(file, writer);
+            }
+          } else {
+            writePartitionedDataWithStream(file, writer);
           }
           if (!file.delete()) {
             logger.error("Unable to delete file for partition {}", i);
           }
         }
+        lengths[i] = writer.getNumBytesWritten();
       }
-      threwException = false;
     } finally {
-      Closeables.close(out, threwException);
       writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
     }
     partitionWriters = null;
     return lengths;
   }
 
+  private void writePartitionedDataWithChannel(
+      File file,
+      WritableByteChannelWrapper outputChannel) throws IOException {
+    boolean copyThrewException = true;
+    try {
+      FileInputStream in = new FileInputStream(file);
+      try (FileChannel inputChannel = in.getChannel()) {
+        Utils.copyFileStreamNIO(
+            inputChannel, outputChannel.channel(), 0L, inputChannel.size());
+        copyThrewException = false;
+      } finally {
+        Closeables.close(in, copyThrewException);
+      }
+    } finally {
+      Closeables.close(outputChannel, copyThrewException);
+    }
+  }
+
+  private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer)
+      throws IOException {
+    boolean copyThrewException = true;
+    FileInputStream in = new FileInputStream(file);
+    OutputStream outputStream;
+    try {
+      outputStream = writer.openStream();
+      try {
+        Utils.copyStream(in, outputStream, false, false);
+        copyThrewException = false;
+      } finally {
+        Closeables.close(outputStream, copyThrewException);
+      }
+    } finally {
+      Closeables.close(in, copyThrewException);
+    }
+  }
+
   @Override
   public Option<MapStatus> stop(boolean success) {
     if (stopping) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
new file mode 100644
index 0000000..cabcb17
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleDataIO;
+
+/**
+ * A {@link ShuffleDataIO} plugin that keeps shuffle output on the executor's local disks, using
+ * one data file plus an index file per map task: the same storage layout Spark used through
+ * version 2.4.
+ */
+public class LocalDiskShuffleDataIO implements ShuffleDataIO {
+
+  /** Configuration passed to every executor component set created by this plugin. */
+  private final SparkConf conf;
+
+  /**
+   * @param sparkConf the Spark configuration of the application
+   */
+  public LocalDiskShuffleDataIO(SparkConf sparkConf) {
+    conf = sparkConf;
+  }
+
+  /**
+   * Creates a new set of executor-side components. The returned object must have
+   * {@code initializeExecutor} called on it before it can create map output writers.
+   */
+  @Override
+  public ShuffleExecutorComponents executor() {
+    ShuffleExecutorComponents components = new LocalDiskShuffleExecutorComponents(conf);
+    return components;
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
new file mode 100644
index 0000000..02eb710
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.BlockManager;
+
+/**
+ * Executor-side components of the local-disk shuffle plugin. Map output writers created here
+ * store their data through an {@link IndexShuffleBlockResolver}.
+ */
+public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents {
+
+  private final SparkConf sparkConf;
+  // Set either by initializeExecutor() or directly by the test-only constructor.
+  private BlockManager blockManager;
+  private IndexShuffleBlockResolver blockResolver;
+
+  /**
+   * Creates uninitialized components; {@link #initializeExecutor(String, String)} must be
+   * called before any writer is requested.
+   */
+  public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) {
+    this.sparkConf = sparkConf;
+  }
+
+  /**
+   * Creates components that are immediately usable without consulting {@link SparkEnv}, so
+   * tests can inject their own block manager and resolver.
+   */
+  @VisibleForTesting
+  public LocalDiskShuffleExecutorComponents(
+      SparkConf sparkConf,
+      BlockManager blockManager,
+      IndexShuffleBlockResolver blockResolver) {
+    this.sparkConf = sparkConf;
+    this.blockManager = blockManager;
+    this.blockResolver = blockResolver;
+  }
+
+  /**
+   * Binds these components to the block manager of the current {@link SparkEnv} and creates
+   * the block resolver used by all writers. {@code appId} and {@code execId} are not used by
+   * this implementation.
+   *
+   * @throws IllegalStateException if there is no SparkEnv or it has no block manager
+   */
+  @Override
+  public void initializeExecutor(String appId, String execId) {
+    SparkEnv env = SparkEnv.get();
+    // Report a missing environment clearly instead of failing with a NullPointerException.
+    if (env == null) {
+      throw new IllegalStateException("No SparkEnv available to initialize the executor.");
+    }
+    blockManager = env.blockManager();
+    if (blockManager == null) {
+      throw new IllegalStateException("No blockManager available from the SparkEnv.");
+    }
+    blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+  }
+
+  /**
+   * Creates the writer for one map task's output. {@code mapTaskAttemptId} is not used by
+   * this implementation.
+   *
+   * @throws IllegalStateException if these components have not been initialized
+   */
+  @Override
+  public ShuffleMapOutputWriter createMapOutputWriter(
+      int shuffleId,
+      int mapId,
+      long mapTaskAttemptId,
+      int numPartitions) {
+    if (blockResolver == null) {
+      throw new IllegalStateException(
+          "Executor components must be initialized before getting writers.");
+    }
+    return new LocalDiskShuffleMapOutputWriter(
+        shuffleId, mapId, numPartitions, blockResolver, sparkConf);
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
new file mode 100644
index 0000000..add4634
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+
+import java.util.Optional;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.util.Utils;
+
+/**
+ * Implementation of {@link ShuffleMapOutputWriter} that persists shuffle data to local disk
+ * alongside an index file, identical to Spark's historic canonical shuffle storage mechanism.
+ *
+ * <p>All partitions are appended to a single temporary data file. On
+ * {@link #commitAllPartitions()}, that file and the per-partition lengths are handed to the
+ * {@link IndexShuffleBlockResolver}. Partition writers must be requested in strictly
+ * increasing partition order. A single map output must be written either entirely through
+ * streams or entirely through channels.
+ */
+public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
+
+  private static final Logger log =
+    LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class);
+
+  private final int shuffleId;
+  private final int mapId;
+  private final IndexShuffleBlockResolver blockResolver;
+  // Bytes written per reduce partition; entries for partitions never written remain 0.
+  private final long[] partitionLengths;
+  private final int bufferSize;
+  private int lastPartitionId = -1;
+  // Position of the shared output channel when the most recent partition writer was created;
+  // a channel-based partition's length is measured from here.
+  private long currChannelPosition;
+
+  private final File outputFile;
+  // Temporary data file, chosen on the first getPartitionWriter() call.
+  private File outputTempFile;
+  private FileOutputStream outputFileStream;
+  private FileChannel outputFileChannel;
+  private BufferedOutputStream outputBufferedFileStream;
+
+  /**
+   * @param shuffleId id of the shuffle being written
+   * @param mapId id of the map task whose output is being written
+   * @param numPartitions number of reduce partitions in the output
+   * @param blockResolver resolves the final data file and commits the index file
+   * @param sparkConf source of {@code SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE}, which is
+   *                  multiplied by 1024 to size the stream buffer in bytes
+   */
+  public LocalDiskShuffleMapOutputWriter(
+      int shuffleId,
+      int mapId,
+      int numPartitions,
+      IndexShuffleBlockResolver blockResolver,
+      SparkConf sparkConf) {
+    this.shuffleId = shuffleId;
+    this.mapId = mapId;
+    this.blockResolver = blockResolver;
+    this.bufferSize =
+      (int) (long) sparkConf.get(
+        package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
+    this.partitionLengths = new long[numPartitions];
+    this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
+    this.outputTempFile = null;
+  }
+
+  /**
+   * Returns a writer for {@code reducePartitionId}. Partitions that are skipped are recorded
+   * with a length of zero.
+   *
+   * @throws IllegalArgumentException if {@code reducePartitionId} is not greater than every
+   *                                  partition id requested before
+   */
+  @Override
+  public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
+    if (reducePartitionId <= lastPartitionId) {
+      throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+    }
+    lastPartitionId = reducePartitionId;
+    if (outputTempFile == null) {
+      outputTempFile = Utils.tempFileWith(outputFile);
+    }
+    if (outputFileChannel != null) {
+      currChannelPosition = outputFileChannel.position();
+    } else {
+      currChannelPosition = 0L;
+    }
+    return new LocalDiskShufflePartitionWriter(reducePartitionId);
+  }
+
+  /**
+   * Closes the temporary data file and asks the block resolver to write the index file and
+   * commit the data. If no data file was ever created, {@code null} is passed instead.
+   */
+  @Override
+  public void commitAllPartitions() throws IOException {
+    cleanUp();
+    File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
+    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+  }
+
+  /**
+   * Closes and deletes the temporary data file. Deletion is attempted even if closing fails;
+   * a failed deletion is only logged.
+   */
+  @Override
+  public void abort(Throwable error) throws IOException {
+    try {
+      cleanUp();
+    } finally {
+      if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) {
+        log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
+      }
+    }
+  }
+
+  // Closes every open handle on the temporary data file. The nested finally blocks ensure that
+  // a failure in one close (e.g. while flushing the buffer) does not leave the others open.
+  private void cleanUp() throws IOException {
+    try {
+      if (outputBufferedFileStream != null) {
+        outputBufferedFileStream.close();
+      }
+    } finally {
+      try {
+        if (outputFileChannel != null) {
+          outputFileChannel.close();
+        }
+      } finally {
+        if (outputFileStream != null) {
+          outputFileStream.close();
+        }
+      }
+    }
+  }
+
+  // Opens the temp file in append mode (once) and wraps it in a buffered stream (once).
+  private void initStream() throws IOException {
+    if (outputFileStream == null) {
+      outputFileStream = new FileOutputStream(outputTempFile, true);
+    }
+    if (outputBufferedFileStream == null) {
+      outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize);
+    }
+  }
+
+  // Opens the temp file in append mode (once) and obtains its channel (once).
+  private void initChannel() throws IOException {
+    if (outputFileStream == null) {
+      outputFileStream = new FileOutputStream(outputTempFile, true);
+    }
+    if (outputFileChannel == null) {
+      outputFileChannel = outputFileStream.getChannel();
+    }
+  }
+
+  /** Writer for one reduce partition; it opens either a stream or a channel, never both. */
+  private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter {
+
+    private final int partitionId;
+    private PartitionWriterStream partStream = null;
+    private PartitionWriterChannel partChannel = null;
+
+    private LocalDiskShufflePartitionWriter(int partitionId) {
+      this.partitionId = partitionId;
+    }
+
+    /**
+     * Returns this partition's output stream, creating it on the first call.
+     *
+     * @throws IllegalStateException if a channel was already used for this map output
+     */
+    @Override
+    public OutputStream openStream() throws IOException {
+      if (partStream == null) {
+        if (outputFileChannel != null) {
+          throw new IllegalStateException("Requested an output channel for a previous write but" +
+              " now an output stream has been requested. Should not be using both channels" +
+              " and streams to write.");
+        }
+        initStream();
+        partStream = new PartitionWriterStream(partitionId);
+      }
+      return partStream;
+    }
+
+    /**
+     * Returns this partition's channel wrapper, creating it on the first call.
+     *
+     * @throws IllegalStateException if a stream was already used for this map output
+     */
+    @Override
+    public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
+      if (partChannel == null) {
+        // Check the shared buffered stream, not only this partition's stream. Bytes still
+        // buffered by an earlier partition could reach the file after this channel's writes,
+        // and currChannelPosition would be 0 because no channel existed when the writer was
+        // created.
+        if (outputBufferedFileStream != null) {
+          throw new IllegalStateException("Requested an output stream for a previous write but" +
+              " now an output channel has been requested. Should not be using both channels" +
+              " and streams to write.");
+        }
+        initChannel();
+        partChannel = new PartitionWriterChannel(partitionId);
+      }
+      return Optional.of(partChannel);
+    }
+
+    /**
+     * Returns the bytes written for this partition, or 0 if neither a stream nor a channel
+     * was opened.
+     */
+    @Override
+    public long getNumBytesWritten() {
+      if (partChannel != null) {
+        try {
+          return partChannel.getCount();
+        } catch (IOException e) {
+          throw new RuntimeException(e);
+        }
+      } else if (partStream != null) {
+        return partStream.getCount();
+      } else {
+        // Assume an empty partition if stream and channel are never created
+        return 0;
+      }
+    }
+  }
+
+  /**
+   * Stream for one partition. It forwards writes to the shared buffered file stream and counts
+   * the bytes it forwards. Closing it records that count but leaves the shared stream open.
+   */
+  private class PartitionWriterStream extends OutputStream {
+    private final int partitionId;
+    // A long, not an int: a single partition can exceed Integer.MAX_VALUE bytes.
+    private long count = 0;
+    private boolean isClosed = false;
+
+    PartitionWriterStream(int partitionId) {
+      this.partitionId = partitionId;
+    }
+
+    public long getCount() {
+      return count;
+    }
+
+    @Override
+    public void write(int b) throws IOException {
+      verifyNotClosed();
+      outputBufferedFileStream.write(b);
+      count++;
+    }
+
+    @Override
+    public void write(byte[] buf, int pos, int length) throws IOException {
+      verifyNotClosed();
+      outputBufferedFileStream.write(buf, pos, length);
+      count += length;
+    }
+
+    @Override
+    public void close() {
+      isClosed = true;
+      partitionLengths[partitionId] = count;
+    }
+
+    private void verifyNotClosed() {
+      if (isClosed) {
+        throw new IllegalStateException("Attempting to write to a closed block output stream.");
+      }
+    }
+  }
+
+  /**
+   * Exposes the shared output file channel for one partition. The partition's length is the
+   * distance the channel position has advanced since this partition's writer was created.
+   */
+  private class PartitionWriterChannel implements WritableByteChannelWrapper {
+
+    private final int partitionId;
+
+    PartitionWriterChannel(int partitionId) {
+      this.partitionId = partitionId;
+    }
+
+    // Only meaningful until the next partition writer is requested, because that call
+    // resets currChannelPosition.
+    public long getCount() throws IOException {
+      long writtenPosition = outputFileChannel.position();
+      return writtenPosition - currChannelPosition;
+    }
+
+    @Override
+    public WritableByteChannel channel() {
+      return outputFileChannel;
+    }
+
+    // Records this partition's length; the shared channel itself stays open.
+    @Override
+    public void close() throws IOException {
+      partitionLengths[partitionId] = getCount();
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index f2b88fe..cda3b57 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.metrics.GarbageCollectionMetrics
 import org.apache.spark.network.shuffle.Constants
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode}
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
 import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy}
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.util.Utils
@@ -811,6 +812,12 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  // Fully-qualified name of the ShuffleDataIO implementation that SortShuffleManager loads via
+  // Utils.loadExtensions; defaults to the local-disk implementation.
+  private[spark] val SHUFFLE_IO_PLUGIN_CLASS =
+    ConfigBuilder("spark.shuffle.sort.io.plugin.class")
+      .doc("Name of the class to use for shuffle IO.")
+      .stringConf
+      .createWithDefault(classOf[LocalDiskShuffleDataIO].getName)
+
   private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
     ConfigBuilder("spark.shuffle.file.buffer")
       .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b59fa8e..17719f5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort
 import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents}
+import org.apache.spark.util.Utils
 
 /**
  * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -68,6 +70,8 @@ import org.apache.spark.shuffle._
  */
 private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
+  import SortShuffleManager._
+
   if (!conf.getBoolean("spark.shuffle.spill", true)) {
     logWarning(
       "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
@@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
    */
   private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
 
+  // Loaded on first use rather than at construction. Loading reads SparkEnv.get.executorId,
+  // which presumably is not available yet when this manager is constructed (confirm against
+  // SparkEnv initialization order).
+  private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
   override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /**
@@ -134,7 +140,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
         new UnsafeShuffleWriter(
           env.blockManager,
-          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          shuffleBlockResolver,
           context.taskMemoryManager(),
           unsafeShuffleHandle,
           mapId,
@@ -144,11 +150,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
           env.blockManager,
-          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
           bypassMergeSortHandle,
           mapId,
+          context.taskAttemptId(),
           env.conf,
-          metrics)
+          metrics,
+          shuffleExecutorComponents)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
         new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
     }
@@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging {
       true
     }
   }
+
+  /**
+   * Instantiates the [[ShuffleDataIO]] plugin named by `spark.shuffle.sort.io.plugin.class`.
+   * Returns its executor components, already initialized with this application's id and the
+   * current executor's id.
+   *
+   * @throws IllegalArgumentException if loading did not yield exactly one plugin instance
+   */
+  private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+    val pluginClassName = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS)
+    val loadedPlugins =
+      Utils.loadExtensions(classOf[ShuffleDataIO], Seq(pluginClassName), conf)
+    require(loadedPlugins.size == 1, s"Failed to load plugins of type $pluginClassName")
+    val components = loadedPlugins.head.executor()
+    val appId = conf.getAppId
+    val executorId = SparkEnv.get.executorId
+    components.initializeExecutor(appId, executorId)
+    components
+  }
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 80d70a1..24042db 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
 import java.math.{MathContext, RoundingMode}
 import java.net._
 import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel}
+import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
 import java.security.SecureRandom
@@ -394,10 +394,14 @@ private[spark] object Utils extends Logging {
 
   def copyFileStreamNIO(
       input: FileChannel,
-      output: FileChannel,
+      output: WritableByteChannel,
       startPosition: Long,
       bytesToCopy: Long): Unit = {
-    val initialPos = output.position()
+    val outputInitialState = output match {
+      case outputFileChannel: FileChannel =>
+        Some((outputFileChannel.position(), outputFileChannel))
+      case _ => None
+    }
     var count = 0L
     // In case transferTo method transferred less data than we have required.
     while (count < bytesToCopy) {
@@ -412,15 +416,17 @@ private[spark] object Utils extends Logging {
     // kernel version 2.6.32, this issue can be seen in
     // https://bugs.openjdk.java.net/browse/JDK-7052359
     // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
-    val finalPos = output.position()
-    val expectedPos = initialPos + bytesToCopy
-    assert(finalPos == expectedPos,
-      s"""
-         |Current position $finalPos do not equal to expected position $expectedPos
-         |after transferTo, please check your kernel version to see if it is 2.6.32,
-         |this is a kernel bug which will lead to unexpected behavior when using transferTo.
-         |You can set spark.file.transferTo = false to disable this NIO feature.
-           """.stripMargin)
+    outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+      val finalPos = outputFileChannel.position()
+      val expectedPos = initialPos + bytesToCopy
+      assert(finalPos == expectedPos,
+        s"""
+           |Current position $finalPos do not equal to expected position $expectedPos
+           |after transferTo, please check your kernel version to see if it is 2.6.32,
+           |this is a kernel bug which will lead to unexpected behavior when using transferTo.
+           |You can set spark.file.transferTo = false to disable this NIO feature.
+         """.stripMargin)
+    }
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 8b1084a..923c9c9 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -383,13 +383,18 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     // simultaneously, and everything is still OK
 
     def writeAndClose(
-      writer: ShuffleWriter[Int, Int])(
-      iter: Iterator[(Int, Int)]): Option[MapStatus] = {
-      val files = writer.write(iter)
-      writer.stop(true)
+        writer: ShuffleWriter[Int, Int],
+        taskContext: TaskContext)(
+        iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+      // Writes `iter` with `writer`, then stops it successfully, running under `taskContext`.
+      // TaskContext is set per thread (hence the unset below), and InterleaveIterators
+      // presumably drives each writer from its own thread. So the context is installed here
+      // and always cleared afterwards.
+      TaskContext.setTaskContext(taskContext)
+      try {
+        writer.write(iter)
+        writer.stop(true)
+      } finally {
+        TaskContext.unset()
+      }
     }
     val interleaver = new InterleaveIterators(
-      data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+      data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2))
     val (mapOutput1, mapOutput2) = interleaver.run()
 
     // check that we can read the map output and it has the right data
@@ -407,6 +412,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
       1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
+    TaskContext.unset()
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index fc1422d..b9f81fa 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -27,13 +27,15 @@ import org.mockito.{Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.ArgumentMatchers.{any, anyInt}
 import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
 import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -48,68 +50,82 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
   private var taskMetrics: TaskMetrics = _
   private var tempDir: File = _
   private var outputFile: File = _
+  private var shuffleExecutorComponents: ShuffleExecutorComponents = _
   private val conf: SparkConf = new SparkConf(loadDefaults = false)
+    .set("spark.app.id", "sampleApp")
   private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
   private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
   private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
 
   override def beforeEach(): Unit = {
     super.beforeEach()
+    // Initialize the @Mock-annotated fields before any of them is stubbed below.
+    MockitoAnnotations.initMocks(this)
     tempDir = Utils.createTempDir()
     outputFile = File.createTempFile("shuffle", null, tempDir)
     taskMetrics = new TaskMetrics
-    MockitoAnnotations.initMocks(this)
     shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
       shuffleId = 0,
       numMaps = 2,
       dependency = dependency
     )
+    // Real task memory manager, returned by the mocked task context (stubbed below).
+    val memoryManager = new TestMemoryManager(conf)
+    val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
     when(dependency.partitioner).thenReturn(new HashPartitioner(7))
     when(dependency.serializer).thenReturn(new JavaSerializer(conf))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
     when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
-    doAnswer { (invocationOnMock: InvocationOnMock) =>
-      val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
-      if (tmp != null) {
-        outputFile.delete
-        tmp.renameTo(outputFile)
-      }
-      null
-    }.when(blockResolver)
-      .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
+
+    // Emulate the resolver's commit: the temp data file (the 4th argument) replaces
+    // outputFile, so tests can inspect the committed output there.
+    when(blockResolver.writeIndexFileAndCommit(
+      anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+      .thenAnswer { invocationOnMock =>
+        val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        if (tmp != null) {
+          outputFile.delete
+          tmp.renameTo(outputFile)
+        }
+        null
+      }
+
+    // Hand out real DiskBlockObjectWriters so partition data is actually written to disk.
     when(blockManager.getDiskWriter(
       any[BlockId],
       any[File],
       any[SerializerInstance],
       anyInt(),
-      any[ShuffleWriteMetrics]
-    )).thenAnswer((invocation: InvocationOnMock) => {
-      val args = invocation.getArguments
-      val manager = new SerializerManager(new JavaSerializer(conf), conf)
-      new DiskBlockObjectWriter(
-        args(1).asInstanceOf[File],
-        manager,
-        args(2).asInstanceOf[SerializerInstance],
-        args(3).asInstanceOf[Int],
-        syncWrites = false,
-        args(4).asInstanceOf[ShuffleWriteMetrics],
-        blockId = args(0).asInstanceOf[BlockId]
-      )
-    })
-    when(diskBlockManager.createTempShuffleBlock()).thenAnswer((_: InvocationOnMock) => {
-      val blockId = new TempShuffleBlockId(UUID.randomUUID)
-      val file = new File(tempDir, blockId.name)
-      blockIdToFileMap.put(blockId, file)
-      temporaryFilesCreated += file
-      (blockId, file)
-    })
-    when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) =>
+      any[ShuffleWriteMetrics]))
+      .thenAnswer { invocation =>
+        val args = invocation.getArguments
+        val manager = new SerializerManager(new JavaSerializer(conf), conf)
+        new DiskBlockObjectWriter(
+          args(1).asInstanceOf[File],
+          manager,
+          args(2).asInstanceOf[SerializerInstance],
+          args(3).asInstanceOf[Int],
+          syncWrites = false,
+          args(4).asInstanceOf[ShuffleWriteMetrics],
+          blockId = args(0).asInstanceOf[BlockId])
+      }
+
+    // Temp shuffle files live in tempDir and are recorded so tests can check their cleanup.
+    when(diskBlockManager.createTempShuffleBlock())
+      .thenAnswer { _ =>
+        val blockId = new TempShuffleBlockId(UUID.randomUUID)
+        val file = new File(tempDir, blockId.name)
+        blockIdToFileMap.put(blockId, file)
+        temporaryFilesCreated += file
+        (blockId, file)
+      }
+
+    when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation =>
       blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
     }
+
+    // Real local-disk shuffle IO wired to the mocked block manager and resolver.
+    shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
+      conf, blockManager, blockResolver)
   }
 
   override def afterEach(): Unit = {
+    TaskContext.unset()
     try {
       Utils.deleteRecursively(tempDir)
       blockIdToFileMap.clear()
@@ -122,12 +138,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
   test("write empty iterator") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
       blockManager,
-      blockResolver,
       shuffleHandle,
       0, // MapId
+      0L, // MapTaskAttemptId
       conf,
-      taskContext.taskMetrics().shuffleWriteMetrics
-    )
+      taskContext.taskMetrics().shuffleWriteMetrics,
+      shuffleExecutorComponents)
+
     writer.write(Iterator.empty)
     writer.stop( /* success = */ true)
     assert(writer.getPartitionLengths.sum === 0)
@@ -141,28 +158,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     assert(taskMetrics.memoryBytesSpilled === 0)
   }
 
-  test("write with some empty partitions") {
-    def records: Iterator[(Int, Int)] =
-      Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-    val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      blockManager,
-      blockResolver,
-      shuffleHandle,
-      0, // MapId
-      conf,
-      taskContext.taskMetrics().shuffleWriteMetrics
-    )
-    writer.write(records)
-    writer.stop( /* success = */ true)
-    assert(temporaryFilesCreated.nonEmpty)
-    assert(writer.getPartitionLengths.sum === outputFile.length())
-    assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
-    assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
-    val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
-    assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
-    assert(shuffleWriteMetrics.recordsWritten === records.length)
-    assert(taskMetrics.diskBytesSpilled === 0)
-    assert(taskMetrics.memoryBytesSpilled === 0)
+  for (transferTo <- Seq(true, false)) {
+    test(s"write with some empty partitions - transferTo $transferTo") {
+      // Keys 1, 5 and 2 hash to partitions 1, 5 and 2 of 7; the remaining 4 stay empty.
+      val confWithTransfer = conf.clone.set("spark.file.transferTo", transferTo.toString)
+      def records: Iterator[(Int, Int)] =
+        Iterator((1, 1), (5, 5)) ++ Iterator.fill(100000)((2, 2))
+      val writer = new BypassMergeSortShuffleWriter[Int, Int](
+        blockManager,
+        shuffleHandle,
+        0, // MapId
+        0L, // MapTaskAttemptId
+        confWithTransfer,
+        taskContext.taskMetrics().shuffleWriteMetrics,
+        shuffleExecutorComponents)
+      writer.write(records)
+      writer.stop( /* success = */ true)
+
+      assert(temporaryFilesCreated.nonEmpty)
+      val lengths = writer.getPartitionLengths
+      assert(lengths.sum === outputFile.length())
+      assert(lengths.count(_ == 0L) === 4) // should be 4 zero length files
+      assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted
+      val writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+      assert(writeMetrics.bytesWritten === outputFile.length())
+      assert(writeMetrics.recordsWritten === records.length)
+      assert(taskMetrics.diskBytesSpilled === 0)
+      assert(taskMetrics.memoryBytesSpilled === 0)
+    }
+  }
 
   test("only generate temp shuffle file for non-empty partition") {
@@ -181,12 +201,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
 
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
       blockManager,
-      blockResolver,
       shuffleHandle,
       0, // MapId
+      0L,
       conf,
-      taskContext.taskMetrics().shuffleWriteMetrics
-    )
+      taskContext.taskMetrics().shuffleWriteMetrics,
+      shuffleExecutorComponents)
 
     intercept[SparkException] {
       writer.write(records)
@@ -203,12 +223,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
   test("cleanup of intermediate files after errors") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
       blockManager,
-      blockResolver,
       shuffleHandle,
       0, // MapId
+      0L,
       conf,
-      taskContext.taskMetrics().shuffleWriteMetrics
-    )
+      taskContext.taskMetrics().shuffleWriteMetrics,
+      shuffleExecutorComponents)
     intercept[SparkException] {
       writer.write((0 until 100000).iterator.map(i => {
         if (i == 99990) {
@@ -221,5 +241,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     writer.stop( /* success = */ false)
     assert(temporaryFilesCreated.count(_.exists()) === 0)
   }
-
 }
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
new file mode 100644
index 0000000..5693b98
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io
+
+import java.io.{File, FileInputStream}
+import java.nio.channels.FileChannel
+import java.nio.file.Files
+import java.util.Arrays
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.mockito.Mock
+import org.mockito.Mockito.when
+import org.mockito.MockitoAnnotations
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.util.Utils
+
+/**
+ * Tests for [[LocalDiskShuffleMapOutputWriter]]. Each test writes every partition's bytes
+ * through a partition writer (as a stream or as a channel), commits the map output, and then
+ * checks the merged data file and the partition lengths handed to the mocked index commit.
+ */
+class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  // Mocked so the suite controls where the data file lives and can capture the
+  // partition lengths passed to writeIndexFileAndCommit; re-created in beforeEach.
+  @Mock(answer = RETURNS_SMART_NULLS)
+  private var blockResolver: IndexShuffleBlockResolver = _
+
+  private val NUM_PARTITIONS = 4
+
+  // Per-partition payload: partition p holds the bytes p, p + 1, ..., 11 * p
+  // (10 * p + 1 bytes), except partition 3, which is deliberately left empty.
+  private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p =>
+    if (p == 3) {
+      Array.emptyByteArray
+    } else {
+      (0 to p * 10).map(_ + p).map(_.toByte).toArray
+    }
+  }.toArray
+
+  private val partitionLengths = data.map(_.length)
+
+  private var mergedOutputFile: File = _
+  private var tempDir: File = _
+  // Lengths captured by the mocked writeIndexFileAndCommit; null until a commit happens.
+  private var partitionSizesInMergedFile: Array[Long] = _
+  private var conf: SparkConf = _
+  private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _
+
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def beforeEach(): Unit = {
+    MockitoAnnotations.initMocks(this)
+    tempDir = Utils.createTempDir()
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir)
+    partitionSizesInMergedFile = null
+    conf = new SparkConf()
+      .set("spark.app.id", "example.spark.app")
+      .set("spark.shuffle.unsafe.file.output.buffer", "16k")
+    when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile)
+    // Capture the committed partition lengths and, if a temporary data file is passed, move
+    // it onto mergedOutputFile so the tests read back exactly what was committed.
+    when(blockResolver.writeIndexFileAndCommit(
+      anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+      .thenAnswer { invocationOnMock =>
+        partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
+        val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+        if (tmp != null) {
+          mergedOutputFile.delete()
+          tmp.renameTo(mergedOutputFile)
+        }
+        null
+      }
+    // NOTE(review): the leading 0, 0 are presumably the shuffle id and map id; confirm
+    // against the LocalDiskShuffleMapOutputWriter constructor.
+    mapOutputWriter = new LocalDiskShuffleMapOutputWriter(
+      0,
+      0,
+      NUM_PARTITIONS,
+      blockResolver,
+      conf)
+  }
+
+  test("writing to an outputstream") {
+    (0 until NUM_PARTITIONS).foreach { p =>
+      val writer = mapOutputWriter.getPartitionWriter(p)
+      val stream = writer.openStream()
+      data(p).foreach { i => stream.write(i) }
+      stream.close()
+      // A closed partition stream must reject further writes.
+      intercept[IllegalStateException] {
+        stream.write(p)
+      }
+      assert(writer.getNumBytesWritten === data(p).length)
+    }
+    verifyWrittenRecords()
+  }
+
+  test("writing to a channel") {
+    (0 until NUM_PARTITIONS).foreach { p =>
+      val writer = mapOutputWriter.getPartitionWriter(p)
+      // Stage the partition's bytes in a scratch file to copy them channel-to-channel.
+      val outputTempFile = File.createTempFile("channelTemp", "", tempDir)
+      Files.write(outputTempFile.toPath, data(p))
+      // tryWithResource closes both the source stream and the channel wrapper, so no file
+      // handle stays open when afterEach deletes tempDir.
+      Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput =>
+        Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper =>
+          assert(channelWrapper.channel().isInstanceOf[FileChannel],
+            "Underlying channel should be a file channel")
+          Utils.copyFileStreamNIO(
+            tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length)
+        }
+      }
+      assert(writer.getNumBytesWritten === data(p).length,
+        s"Partition $p does not have the correct number of bytes.")
+    }
+    verifyWrittenRecords()
+  }
+
+  /** Reads the merged data file and splits it into per-partition byte arrays. */
+  private def readRecordsFromFile(): Array[Array[Byte]] = {
+    val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath)
+    // startOffsets(p) = total size of partitions 0 until p, i.e. where partition p starts.
+    val startOffsets = partitionLengths.scanLeft(0)(_ + _)
+    (0 until NUM_PARTITIONS).map { part =>
+      val start = startOffsets(part)
+      Arrays.copyOfRange(mergedOutputBytes, start, start + partitionLengths(part))
+    }.toArray
+  }
+
+  /**
+   * Commits all partitions, then checks the committed partition lengths, the merged file size
+   * and the merged file contents against the expected per-partition data.
+   */
+  private def verifyWrittenRecords(): Unit = {
+    mapOutputWriter.commitAllPartitions()
+    assert(partitionSizesInMergedFile === partitionLengths)
+    assert(mergedOutputFile.length() === partitionLengths.sum)
+    assert(data === readRecordsFromFile())
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org