Posted to commits@spark.apache.org by va...@apache.org on 2019/07/30 21:18:02 UTC
[spark] branch master updated: [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
This is an automated email from the ASF dual-hosted git repository.
vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new abef84a [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
abef84a is described below
commit abef84a868e9e15f346eea315bbab0ec8ac8e389
Author: mcheah <mc...@palantir.com>
AuthorDate: Tue Jul 30 14:17:30 2019 -0700
[SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
## What changes were proposed in this pull request?
As part of the shuffle storage API proposed in SPARK-25299, this introduces an API for persisting shuffle data in arbitrary storage systems.
This patch introduces several concepts:
* `ShuffleDataIO`, which is the root of the entire plugin tree that will be proposed over the course of the shuffle API project.
* `ShuffleExecutorComponents` - the subset of plugins for managing shuffle-related components for each executor. This will in turn instantiate shuffle readers and writers.
* `ShuffleMapOutputWriter` interface - instantiated once per map task. This provides child `ShufflePartitionWriter` instances for persisting the bytes for each partition in the map task.
The default implementation of these plugins exactly mirrors what the existing shuffle writing code does - namely, writing the data to local disk and writing an index file. For now, only `BypassMergeSortShuffleWriter` uses the new APIs; follow-up PRs will adopt them in `SortShuffleWriter` and `UnsafeShuffleWriter`, which is left as future work to minimize the review surface area.
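For a sense of the plugin surface, a skeleton third-party implementation would wire the tree together roughly as follows (all `MyXxx` names are hypothetical and not part of this patch; the `SparkConf` constructor mirrors the shape used by the default `LocalDiskShuffleDataIO`):

    // Hypothetical plugin skeleton: ShuffleDataIO is the root and hands out
    // the executor-side components, which in turn create per-map-task writers.
    public class MyShuffleDataIO implements ShuffleDataIO {
      private final SparkConf conf;

      public MyShuffleDataIO(SparkConf conf) {
        this.conf = conf;
      }

      @Override
      public ShuffleExecutorComponents executor() {
        return new MyShuffleExecutorComponents(conf);  // hypothetical class
      }
    }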
## How was this patch tested?
New unit tests were added. Micro-benchmarks indicate there's no slowdown in the affected code paths.
Closes #25007 from mccheah/spark-shuffle-writer-refactor.
Lead-authored-by: mcheah <mc...@palantir.com>
Co-authored-by: mccheah <mc...@palantir.com>
Signed-off-by: Marcelo Vanzin <va...@cloudera.com>
---
.../apache/spark/shuffle/api/ShuffleDataIO.java | 49 ++++
.../shuffle/api/ShuffleExecutorComponents.java | 55 +++++
.../spark/shuffle/api/ShuffleMapOutputWriter.java | 71 ++++++
.../spark/shuffle/api/ShufflePartitionWriter.java | 98 ++++++++
.../shuffle/api/WritableByteChannelWrapper.java | 42 ++++
.../shuffle/sort/BypassMergeSortShuffleWriter.java | 173 +++++++++-----
.../shuffle/sort/io/LocalDiskShuffleDataIO.java | 40 ++++
.../io/LocalDiskShuffleExecutorComponents.java | 71 ++++++
.../sort/io/LocalDiskShuffleMapOutputWriter.java | 261 +++++++++++++++++++++
.../org/apache/spark/internal/config/package.scala | 7 +
.../spark/shuffle/sort/SortShuffleManager.scala | 25 +-
.../main/scala/org/apache/spark/util/Utils.scala | 30 ++-
.../test/scala/org/apache/spark/ShuffleSuite.scala | 16 +-
.../sort/BypassMergeSortShuffleWriterSuite.scala | 149 +++++++-----
.../io/LocalDiskShuffleMapOutputWriterSuite.scala | 147 ++++++++++++
15 files changed, 1087 insertions(+), 147 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
new file mode 100644
index 0000000..e9e50ec
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for plugging in modules for storing and reading temporary shuffle data.
+ * <p>
+ * This is the root of a plugin system for storing shuffle bytes to arbitrary storage
+ * backends in the sort-based shuffle algorithm implemented by the
+ * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is
+ * needed instead of sort-based shuffle, one should implement
+ * {@link org.apache.spark.shuffle.ShuffleManager} instead.
+ * <p>
+ * A single instance of this module is loaded per process in the Spark application.
+ * The default implementation reads and writes shuffle data from the local disks of
+ * the executor, and is the implementation of shuffle file storage that has remained
+ * consistent throughout most of Spark's history.
+ * <p>
+ * Alternative implementations of shuffle data storage can be loaded via setting
+ * <code>spark.shuffle.sort.io.plugin.class</code>.
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleDataIO {
+
+ /**
+ * Called once on executor processes to bootstrap the shuffle data storage modules that
+ * are only invoked on the executors.
+ */
+ ShuffleExecutorComponents executor();
+}
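A custom implementation is enabled through the new configuration key mentioned in the javadoc above; for example (the plugin class name is hypothetical):

    SparkConf conf = new SparkConf()
        .set("spark.shuffle.sort.io.plugin.class", "com.example.MyShuffleDataIO");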
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
new file mode 100644
index 0000000..70c112b
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for building shuffle support for Executors.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleExecutorComponents {
+
+ /**
+ * Called once per executor to bootstrap this module with state that is specific to
+ * that executor, specifically the application ID and executor ID.
+ */
+ void initializeExecutor(String appId, String execId);
+
+ /**
+ * Called once per map task to create a writer that will be responsible for persisting all the
+ * partitioned bytes written by that map task.
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of
+ * @param mapId Within the shuffle, the identifier of the map task
+ * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
+ * with the same (shuffleId, mapId) pair can be distinguished by the
+ * different values of mapTaskAttemptId.
+ * @param numPartitions The number of partitions that will be written by the map task. Some of
+   *                      these partitions may be empty.
+ */
+ ShuffleMapOutputWriter createMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ long mapTaskAttemptId,
+ int numPartitions) throws IOException;
+}
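Note the implied lifecycle: `initializeExecutor` runs once when the executor bootstraps the plugin, and only afterwards are map output writers requested. A sketch of the expected call order (identifiers such as `appId` and `execId` are supplied by Spark's executor bootstrap):

    ShuffleExecutorComponents components = dataIO.executor();
    components.initializeExecutor(appId, execId);
    // later, once per map task:
    ShuffleMapOutputWriter writer =
        components.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);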
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
new file mode 100644
index 0000000..45a593c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * A top-level writer that returns child writers for persisting the output of a map task,
+ * and then commits all of the writes as one atomic operation.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShuffleMapOutputWriter {
+
+ /**
+ * Creates a writer that can open an output stream to persist bytes targeted for a given reduce
+ * partition id.
+ * <p>
+ * The chunk corresponds to bytes in the given reduce partition. This will not be called twice
+ * for the same partition within any given map task. The partition identifier will be in the
+ * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was
+ * provided upon the creation of this map output writer via
+ * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}.
+ * <p>
+ * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each
+ * call to this method will be called with a reducePartitionId that is strictly greater than
+ * the reducePartitionIds given to any previous call to this method. This method is not
+ * guaranteed to be called for every partition id in the above described range. In particular,
+ * no guarantees are made as to whether or not this method will be called for empty partitions.
+ */
+ ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException;
+
+ /**
+ * Commits the writes done by all partition writers returned by all calls to this object's
+ * {@link #getPartitionWriter(int)}.
+ * <p>
+ * This should ensure that the writes conducted by this module's partition writers are
+ * available to downstream reduce tasks. If this method throws any exception, this module's
+ * {@link #abort(Throwable)} method will be invoked before propagating the exception.
+ * <p>
+ * This can also close any resources and clean up temporary state if necessary.
+ */
+ void commitAllPartitions() throws IOException;
+
+ /**
+ * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}.
+ * <p>
+ * This should invalidate the results of writing bytes. This can also close any resources and
+ * clean up temporary state if necessary.
+ */
+ void abort(Throwable error) throws IOException;
+}
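Putting the contract together, a caller is expected to drive the writer like the following sketch, which condenses what `BypassMergeSortShuffleWriter` does later in this patch (`bytesFor` is a hypothetical helper producing the serialized records of one partition):

    ShuffleMapOutputWriter writer =
        components.createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
    try {
      for (int p = 0; p < numPartitions; p++) {
        // partition ids are requested in strictly increasing order
        ShufflePartitionWriter partWriter = writer.getPartitionWriter(p);
        try (OutputStream out = partWriter.openStream()) {
          out.write(bytesFor(p));
        }
      }
      writer.commitAllPartitions();  // publish all partitions atomically
    } catch (Exception e) {
      writer.abort(e);  // invalidate partial output before rethrowing
      throw e;
    }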
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
new file mode 100644
index 0000000..9288751
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.IOException;
+import java.util.Optional;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ * An interface for opening streams to persist partition bytes to a backing data store.
+ * <p>
+ * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle
+ * block.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface ShufflePartitionWriter {
+
+ /**
+ * Open and return an {@link OutputStream} that can write bytes to the underlying
+ * data store.
+ * <p>
+ * This method will only be called once on this partition writer in the map task, to write the
+ * bytes to the partition. The output stream will only be used to write the bytes for this
+ * partition. The map task closes this output stream upon writing all the bytes for this
+ * block, or if the write fails for any reason.
+ * <p>
+ * Implementations that intend on combining the bytes for all the partitions written by this
+ * map task should reuse the same OutputStream instance across all the partition writers provided
+ * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
+ * {@link OutputStream#close()} does not close the resource, since it will be reused across
+ * partition writes. The underlying resources should be cleaned up in
+ * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+ */
+ OutputStream openStream() throws IOException;
+
+ /**
+ * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from
+ * input byte channels to the underlying shuffle data store.
+ * <p>
+ * This method will only be called once on this partition writer in the map task, to write the
+ * bytes to the partition. The channel will only be used to write the bytes for this
+ * partition. The map task closes this channel upon writing all the bytes for this
+ * block, or if the write fails for any reason.
+ * <p>
+ * Implementations that intend on combining the bytes for all the partitions written by this
+ * map task should reuse the same channel instance across all the partition writers provided
+ * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that
+ * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel
+ * will be reused across partition writes. The underlying resources should be cleaned up in
+ * {@link ShuffleMapOutputWriter#commitAllPartitions()} and
+ * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+ * <p>
+ * This method is primarily for advanced optimizations where bytes can be copied from the input
+ * spill files to the output channel without copying data into memory. If such optimizations are
+ * not supported, the implementation should return {@link Optional#empty()}. By default, the
+ * implementation returns {@link Optional#empty()}.
+ * <p>
+ * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the
+ * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure
+ * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()},
+ * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or
+ * {@link ShuffleMapOutputWriter#abort(Throwable)}.
+ */
+ default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
+ return Optional.empty();
+ }
+
+ /**
+ * Returns the number of bytes written either by this writer's output stream opened by
+ * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}.
+ * <p>
+ * This can be different from the number of bytes given by the caller. For example, the
+ * stream might compress or encrypt the bytes before persisting the data to the backing
+ * data store.
+ */
+ long getNumBytesWritten();
+}
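To make the byte-accounting contract concrete, a trivial in-memory implementation (hypothetical, purely illustrative) could buffer one partition's bytes and report the count; `getNumBytesWritten` reports what was actually persisted, which may differ from the caller's input if the stream compresses or encrypts:

    // Hypothetical: buffers one partition's bytes in memory.
    class InMemoryPartitionWriter implements ShufflePartitionWriter {
      private final ByteArrayOutputStream buffer = new ByteArrayOutputStream();

      @Override
      public OutputStream openStream() {
        return buffer;  // close() on a ByteArrayOutputStream is a no-op
      }

      @Override
      public long getNumBytesWritten() {
        return buffer.size();
      }
    }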
diff --git a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java
new file mode 100644
index 0000000..a204903
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.api;
+
+import java.io.Closeable;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.spark.annotation.Private;
+
+/**
+ * :: Private ::
+ *
+ * A thin wrapper around a {@link WritableByteChannel}.
+ * <p>
+ * This is primarily provided for the local disk shuffle implementation to provide a
+ * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes.
+ *
+ * @since 3.0.0
+ */
+@Private
+public interface WritableByteChannelWrapper extends Closeable {
+
+ /**
+ * The underlying channel to write bytes into.
+ */
+ WritableByteChannel channel();
+}
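An implementation that shares one channel across all of a map task's partitions would return a wrapper whose `close()` deliberately leaves the channel open, as the `ShufflePartitionWriter` javadoc above requires (hypothetical sketch):

    // Hypothetical: the parent ShuffleMapOutputWriter owns the channel and
    // closes it in commitAllPartitions() or abort(), not here.
    class SharedChannelWrapper implements WritableByteChannelWrapper {
      private final WritableByteChannel delegate;

      SharedChannelWrapper(WritableByteChannel delegate) {
        this.delegate = delegate;
      }

      @Override
      public WritableByteChannel channel() {
        return delegate;
      }

      @Override
      public void close() {
        // intentionally empty: the channel is reused across partition writes
      }
    }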
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 32b4467..3ccee70 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -19,8 +19,10 @@ package org.apache.spark.shuffle.sort;
import java.io.File;
import java.io.FileInputStream;
-import java.io.FileOutputStream;
import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.util.Optional;
import javax.annotation.Nullable;
import scala.None$;
@@ -34,16 +36,19 @@ import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import org.apache.spark.internal.config.package$;
import org.apache.spark.Partitioner;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.internal.config.package$;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -81,8 +86,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
+ private final long mapTaskAttemptId;
private final Serializer serializer;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
+ private final ShuffleExecutorComponents shuffleExecutorComponents;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@@ -99,74 +105,82 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
BypassMergeSortShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle<K, V> handle,
int mapId,
+ long mapTaskAttemptId,
SparkConf conf,
- ShuffleWriteMetricsReporter writeMetrics) {
+ ShuffleWriteMetricsReporter writeMetrics,
+ ShuffleExecutorComponents shuffleExecutorComponents) {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
this.blockManager = blockManager;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.mapId = mapId;
+ this.mapTaskAttemptId = mapTaskAttemptId;
this.shuffleId = dep.shuffleId();
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
- this.shuffleBlockResolver = shuffleBlockResolver;
+ this.shuffleExecutorComponents = shuffleExecutorComponents;
}
@Override
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
- if (!records.hasNext()) {
- partitionLengths = new long[numPartitions];
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
- return;
- }
- final SerializerInstance serInstance = serializer.newInstance();
- final long openStartTime = System.nanoTime();
- partitionWriters = new DiskBlockObjectWriter[numPartitions];
- partitionWriterSegments = new FileSegment[numPartitions];
- for (int i = 0; i < numPartitions; i++) {
- final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
- blockManager.diskBlockManager().createTempShuffleBlock();
- final File file = tempShuffleBlockIdPlusFile._2();
- final BlockId blockId = tempShuffleBlockIdPlusFile._1();
- partitionWriters[i] =
- blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
- }
- // Creating the file to write to and creating a disk writer both involve interacting with
- // the disk, and can take a long time in aggregate when we open many files, so should be
- // included in the shuffle write time.
- writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
-
- while (records.hasNext()) {
- final Product2<K, V> record = records.next();
- final K key = record._1();
- partitionWriters[partitioner.getPartition(key)].write(key, record._2());
- }
+ ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents
+ .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions);
+ try {
+ if (!records.hasNext()) {
+ partitionLengths = new long[numPartitions];
+ mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(
+ blockManager.shuffleServerId(),
+ partitionLengths);
+ return;
+ }
+ final SerializerInstance serInstance = serializer.newInstance();
+ final long openStartTime = System.nanoTime();
+ partitionWriters = new DiskBlockObjectWriter[numPartitions];
+ partitionWriterSegments = new FileSegment[numPartitions];
+ for (int i = 0; i < numPartitions; i++) {
+ final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+ blockManager.diskBlockManager().createTempShuffleBlock();
+ final File file = tempShuffleBlockIdPlusFile._2();
+ final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+ partitionWriters[i] =
+ blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
+ }
+ // Creating the file to write to and creating a disk writer both involve interacting with
+ // the disk, and can take a long time in aggregate when we open many files, so should be
+ // included in the shuffle write time.
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime);
- for (int i = 0; i < numPartitions; i++) {
- try (DiskBlockObjectWriter writer = partitionWriters[i]) {
- partitionWriterSegments[i] = writer.commitAndGet();
+ while (records.hasNext()) {
+ final Product2<K, V> record = records.next();
+ final K key = record._1();
+ partitionWriters[partitioner.getPartition(key)].write(key, record._2());
}
- }
- File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
- File tmp = Utils.tempFileWith(output);
- try {
- partitionLengths = writePartitionedFile(tmp);
- shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
- } finally {
- if (tmp.exists() && !tmp.delete()) {
- logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+ for (int i = 0; i < numPartitions; i++) {
+ try (DiskBlockObjectWriter writer = partitionWriters[i]) {
+ partitionWriterSegments[i] = writer.commitAndGet();
+ }
+ }
+
+ partitionLengths = writePartitionedData(mapOutputWriter);
+ mapOutputWriter.commitAllPartitions();
+ mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+ } catch (Exception e) {
+ try {
+ mapOutputWriter.abort(e);
+ } catch (Exception e2) {
+ logger.error("Failed to abort the writer after failing to write map output.", e2);
+ e.addSuppressed(e2);
}
+ throw e;
}
- mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@VisibleForTesting
@@ -179,43 +193,80 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
*/
- private long[] writePartitionedFile(File outputFile) throws IOException {
+ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException {
// Track location of the partition starts in the output file
final long[] lengths = new long[numPartitions];
if (partitionWriters == null) {
// We were passed an empty iterator
return lengths;
}
-
- final FileOutputStream out = new FileOutputStream(outputFile, true);
final long writeStartTime = System.nanoTime();
- boolean threwException = true;
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
+ ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
if (file.exists()) {
- final FileInputStream in = new FileInputStream(file);
- boolean copyThrewException = true;
- try {
- lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
- copyThrewException = false;
- } finally {
- Closeables.close(in, copyThrewException);
+ if (transferToEnabled) {
+ // Using WritableByteChannelWrapper to make resource closing consistent between
+ // this implementation and UnsafeShuffleWriter.
+ Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper();
+ if (maybeOutputChannel.isPresent()) {
+ writePartitionedDataWithChannel(file, maybeOutputChannel.get());
+ } else {
+ writePartitionedDataWithStream(file, writer);
+ }
+ } else {
+ writePartitionedDataWithStream(file, writer);
}
if (!file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
}
+ lengths[i] = writer.getNumBytesWritten();
}
- threwException = false;
} finally {
- Closeables.close(out, threwException);
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
partitionWriters = null;
return lengths;
}
+ private void writePartitionedDataWithChannel(
+ File file,
+ WritableByteChannelWrapper outputChannel) throws IOException {
+ boolean copyThrewException = true;
+ try {
+ FileInputStream in = new FileInputStream(file);
+ try (FileChannel inputChannel = in.getChannel()) {
+ Utils.copyFileStreamNIO(
+ inputChannel, outputChannel.channel(), 0L, inputChannel.size());
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ } finally {
+ Closeables.close(outputChannel, copyThrewException);
+ }
+ }
+
+ private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer)
+ throws IOException {
+ boolean copyThrewException = true;
+ FileInputStream in = new FileInputStream(file);
+ OutputStream outputStream;
+ try {
+ outputStream = writer.openStream();
+ try {
+ Utils.copyStream(in, outputStream, false, false);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(outputStream, copyThrewException);
+ }
+ } finally {
+ Closeables.close(in, copyThrewException);
+ }
+ }
+
@Override
public Option<MapStatus> stop(boolean success) {
if (stopping) {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
new file mode 100644
index 0000000..cabcb17
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleDataIO;
+
+/**
+ * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle
+ * storage and index file functionality that has historically been used from Spark 2.4 and earlier.
+ */
+public class LocalDiskShuffleDataIO implements ShuffleDataIO {
+
+ private final SparkConf sparkConf;
+
+ public LocalDiskShuffleDataIO(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @Override
+ public ShuffleExecutorComponents executor() {
+ return new LocalDiskShuffleExecutorComponents(sparkConf);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
new file mode 100644
index 0000000..02eb710
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import com.google.common.annotations.VisibleForTesting;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkEnv;
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.storage.BlockManager;
+
+public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents {
+
+ private final SparkConf sparkConf;
+ private BlockManager blockManager;
+ private IndexShuffleBlockResolver blockResolver;
+
+ public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) {
+ this.sparkConf = sparkConf;
+ }
+
+ @VisibleForTesting
+ public LocalDiskShuffleExecutorComponents(
+ SparkConf sparkConf,
+ BlockManager blockManager,
+ IndexShuffleBlockResolver blockResolver) {
+ this.sparkConf = sparkConf;
+ this.blockManager = blockManager;
+ this.blockResolver = blockResolver;
+ }
+
+ @Override
+ public void initializeExecutor(String appId, String execId) {
+ blockManager = SparkEnv.get().blockManager();
+ if (blockManager == null) {
+ throw new IllegalStateException("No blockManager available from the SparkEnv.");
+ }
+ blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager);
+ }
+
+ @Override
+ public ShuffleMapOutputWriter createMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ long mapTaskAttemptId,
+ int numPartitions) {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.");
+ }
+ return new LocalDiskShuffleMapOutputWriter(
+ shuffleId, mapId, numPartitions, blockResolver, sparkConf);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
new file mode 100644
index 0000000..add4634
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+
+import java.util.Optional;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.api.ShuffleMapOutputWriter;
+import org.apache.spark.shuffle.api.ShufflePartitionWriter;
+import org.apache.spark.shuffle.api.WritableByteChannelWrapper;
+import org.apache.spark.internal.config.package$;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.util.Utils;
+
+/**
+ * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle
+ * persisting shuffle data to local disk alongside index files, identical to Spark's historic
+ * canonical shuffle storage mechanism.
+ */
+public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter {
+
+ private static final Logger log =
+ LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class);
+
+ private final int shuffleId;
+ private final int mapId;
+ private final IndexShuffleBlockResolver blockResolver;
+ private final long[] partitionLengths;
+ private final int bufferSize;
+ private int lastPartitionId = -1;
+ private long currChannelPosition;
+
+ private final File outputFile;
+ private File outputTempFile;
+ private FileOutputStream outputFileStream;
+ private FileChannel outputFileChannel;
+ private BufferedOutputStream outputBufferedFileStream;
+
+ public LocalDiskShuffleMapOutputWriter(
+ int shuffleId,
+ int mapId,
+ int numPartitions,
+ IndexShuffleBlockResolver blockResolver,
+ SparkConf sparkConf) {
+ this.shuffleId = shuffleId;
+ this.mapId = mapId;
+ this.blockResolver = blockResolver;
+ this.bufferSize =
+ (int) (long) sparkConf.get(
+ package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
+ this.partitionLengths = new long[numPartitions];
+ this.outputFile = blockResolver.getDataFile(shuffleId, mapId);
+ this.outputTempFile = null;
+ }
+
+ @Override
+ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException {
+ if (reducePartitionId <= lastPartitionId) {
+ throw new IllegalArgumentException("Partitions should be requested in increasing order.");
+ }
+ lastPartitionId = reducePartitionId;
+ if (outputTempFile == null) {
+ outputTempFile = Utils.tempFileWith(outputFile);
+ }
+ if (outputFileChannel != null) {
+ currChannelPosition = outputFileChannel.position();
+ } else {
+ currChannelPosition = 0L;
+ }
+ return new LocalDiskShufflePartitionWriter(reducePartitionId);
+ }
+
+ @Override
+ public void commitAllPartitions() throws IOException {
+ cleanUp();
+ File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
+ blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
+ }
+
+ @Override
+ public void abort(Throwable error) throws IOException {
+ cleanUp();
+ if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) {
+ log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
+ }
+ }
+
+ private void cleanUp() throws IOException {
+ if (outputBufferedFileStream != null) {
+ outputBufferedFileStream.close();
+ }
+ if (outputFileChannel != null) {
+ outputFileChannel.close();
+ }
+ if (outputFileStream != null) {
+ outputFileStream.close();
+ }
+ }
+
+ private void initStream() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ }
+ if (outputBufferedFileStream == null) {
+ outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize);
+ }
+ }
+
+ private void initChannel() throws IOException {
+ if (outputFileStream == null) {
+ outputFileStream = new FileOutputStream(outputTempFile, true);
+ }
+ if (outputFileChannel == null) {
+ outputFileChannel = outputFileStream.getChannel();
+ }
+ }
+
+ private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter {
+
+ private final int partitionId;
+ private PartitionWriterStream partStream = null;
+ private PartitionWriterChannel partChannel = null;
+
+ private LocalDiskShufflePartitionWriter(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ @Override
+ public OutputStream openStream() throws IOException {
+ if (partStream == null) {
+ if (outputFileChannel != null) {
+ throw new IllegalStateException("Requested an output channel for a previous write but" +
+ " now an output stream has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initStream();
+ partStream = new PartitionWriterStream(partitionId);
+ }
+ return partStream;
+ }
+
+ @Override
+ public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException {
+ if (partChannel == null) {
+ if (partStream != null) {
+ throw new IllegalStateException("Requested an output stream for a previous write but" +
+ " now an output channel has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initChannel();
+ partChannel = new PartitionWriterChannel(partitionId);
+ }
+ return Optional.of(partChannel);
+ }
+
+ @Override
+ public long getNumBytesWritten() {
+ if (partChannel != null) {
+ try {
+ return partChannel.getCount();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ } else if (partStream != null) {
+ return partStream.getCount();
+ } else {
+ // Assume an empty partition if stream and channel are never created
+ return 0;
+ }
+ }
+ }
+
+ private class PartitionWriterStream extends OutputStream {
+ private final int partitionId;
+ private int count = 0;
+ private boolean isClosed = false;
+
+ PartitionWriterStream(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ public int getCount() {
+ return count;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(b);
+ count++;
+ }
+
+ @Override
+ public void write(byte[] buf, int pos, int length) throws IOException {
+ verifyNotClosed();
+ outputBufferedFileStream.write(buf, pos, length);
+ count += length;
+ }
+
+ @Override
+ public void close() {
+ isClosed = true;
+ partitionLengths[partitionId] = count;
+ }
+
+ private void verifyNotClosed() {
+ if (isClosed) {
+ throw new IllegalStateException("Attempting to write to a closed block output stream.");
+ }
+ }
+ }
+
+ private class PartitionWriterChannel implements WritableByteChannelWrapper {
+
+ private final int partitionId;
+
+ PartitionWriterChannel(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
+ public long getCount() throws IOException {
+ long writtenPosition = outputFileChannel.position();
+ return writtenPosition - currChannelPosition;
+ }
+
+ @Override
+ public WritableByteChannel channel() {
+ return outputFileChannel;
+ }
+
+ @Override
+ public void close() throws IOException {
+ partitionLengths[partitionId] = getCount();
+ }
+ }
+}
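As this implementation shows, a partition writer supports exactly one of the two write modes at a time; requesting a stream after a channel (or vice versa) throws `IllegalStateException`. The caller-side selection, condensed from the `BypassMergeSortShuffleWriter` changes above (`spillFile` is a stand-in name; in the patch the channel path is additionally gated on `spark.file.transferTo`):

    ShufflePartitionWriter partWriter = mapOutputWriter.getPartitionWriter(i);
    Optional<WritableByteChannelWrapper> maybeChannel = partWriter.openChannelWrapper();
    if (maybeChannel.isPresent()) {
      try (FileChannel in = new FileInputStream(spillFile).getChannel();
           WritableByteChannelWrapper out = maybeChannel.get()) {
        Utils.copyFileStreamNIO(in, out.channel(), 0L, in.size());
      }
    } else {
      try (InputStream in = new FileInputStream(spillFile);
           OutputStream out = partWriter.openStream()) {
        Utils.copyStream(in, out, false, false);
      }
    }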
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index f2b88fe..cda3b57 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -24,6 +24,7 @@ import org.apache.spark.metrics.GarbageCollectionMetrics
import org.apache.spark.network.shuffle.Constants
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode}
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy}
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.Utils
@@ -811,6 +812,12 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val SHUFFLE_IO_PLUGIN_CLASS =
+ ConfigBuilder("spark.shuffle.sort.io.plugin.class")
+ .doc("Name of the class to use for shuffle IO.")
+ .stringConf
+ .createWithDefault(classOf[LocalDiskShuffleDataIO].getName)
+
private[spark] val SHUFFLE_FILE_BUFFER_SIZE =
ConfigBuilder("spark.shuffle.file.buffer")
.doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " +
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b59fa8e..17719f5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents}
+import org.apache.spark.util.Utils
/**
* In sort-based shuffle, incoming records are sorted according to their target partition ids, then
@@ -68,6 +70,8 @@ import org.apache.spark.shuffle._
*/
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+ import SortShuffleManager._
+
if (!conf.getBoolean("spark.shuffle.spill", true)) {
logWarning(
"spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." +
@@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
*/
private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+ private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf)
+
override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
/**
@@ -134,7 +140,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+ shuffleBlockResolver,
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
@@ -144,11 +150,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
+ context.taskAttemptId(),
env.conf,
- metrics)
+ metrics,
+ shuffleExecutorComponents)
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
}
@@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging {
true
}
}
+
+ private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = {
+ val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS)
+ val maybeIO = Utils.loadExtensions(
+ classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf)
+ require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses")
+ val executorComponents = maybeIO.head.executor()
+ executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId)
+ executorComponents
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 80d70a1..24042db 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel}
+import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.SecureRandom
@@ -394,10 +394,14 @@ private[spark] object Utils extends Logging {
def copyFileStreamNIO(
input: FileChannel,
- output: FileChannel,
+ output: WritableByteChannel,
startPosition: Long,
bytesToCopy: Long): Unit = {
- val initialPos = output.position()
+ val outputInitialState = output match {
+ case outputFileChannel: FileChannel =>
+ Some((outputFileChannel.position(), outputFileChannel))
+ case _ => None
+ }
var count = 0L
// In case transferTo method transferred less data than we have required.
while (count < bytesToCopy) {
@@ -412,15 +416,17 @@ private[spark] object Utils extends Logging {
// kernel version 2.6.32, this issue can be seen in
// https://bugs.openjdk.java.net/browse/JDK-7052359
// This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
- val finalPos = output.position()
- val expectedPos = initialPos + bytesToCopy
- assert(finalPos == expectedPos,
- s"""
- |Current position $finalPos do not equal to expected position $expectedPos
- |after transferTo, please check your kernel version to see if it is 2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO feature.
- """.stripMargin)
+ outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+ val finalPos = outputFileChannel.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(finalPos == expectedPos,
+ s"""
+ |Current position $finalPos do not equal to expected position $expectedPos
+ |after transferTo, please check your kernel version to see if it is 2.6.32,
+ |this is a kernel bug which will lead to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
+ }
}
/**
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 8b1084a..923c9c9 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -383,13 +383,18 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
// simultaneously, and everything is still OK
def writeAndClose(
- writer: ShuffleWriter[Int, Int])(
- iter: Iterator[(Int, Int)]): Option[MapStatus] = {
- val files = writer.write(iter)
- writer.stop(true)
+ writer: ShuffleWriter[Int, Int],
+ taskContext: TaskContext)(
+ iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+ try {
+ val files = writer.write(iter)
+ writer.stop(true)
+ } finally {
+ TaskContext.unset()
+ }
}
val interleaver = new InterleaveIterators(
- data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+ data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2))
val (mapOutput1, mapOutput2) = interleaver.run()
// check that we can read the map output and it has the right data
@@ -407,6 +412,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
+ TaskContext.unset()
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index fc1422d..b9f81fa 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -27,13 +27,15 @@ import org.mockito.{Mock, MockitoAnnotations}
import org.mockito.Answers.RETURNS_SMART_NULLS
import org.mockito.ArgumentMatchers.{any, anyInt}
import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
import org.scalatest.BeforeAndAfterEach
import org.apache.spark._
import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.shuffle.api.ShuffleExecutorComponents
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents
import org.apache.spark.storage._
import org.apache.spark.util.Utils
@@ -48,68 +50,82 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
private var taskMetrics: TaskMetrics = _
private var tempDir: File = _
private var outputFile: File = _
+ private var shuffleExecutorComponents: ShuffleExecutorComponents = _
private val conf: SparkConf = new SparkConf(loadDefaults = false)
+ .set("spark.app.id", "sampleApp")
private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
override def beforeEach(): Unit = {
super.beforeEach()
+ MockitoAnnotations.initMocks(this)
tempDir = Utils.createTempDir()
outputFile = File.createTempFile("shuffle", null, tempDir)
taskMetrics = new TaskMetrics
- MockitoAnnotations.initMocks(this)
shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
shuffleId = 0,
numMaps = 2,
dependency = dependency
)
+ val memoryManager = new TestMemoryManager(conf)
+ val taskMemoryManager = new TaskMemoryManager(memoryManager, 0)
when(dependency.partitioner).thenReturn(new HashPartitioner(7))
when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
- doAnswer { (invocationOnMock: InvocationOnMock) =>
- val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
- if (tmp != null) {
- outputFile.delete
- tmp.renameTo(outputFile)
- }
- null
- }.when(blockResolver)
- .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+ when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager)
+
+ when(blockResolver.writeIndexFileAndCommit(
+ anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+ .thenAnswer { invocationOnMock =>
+ val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ outputFile.delete
+ tmp.renameTo(outputFile)
+ }
+ null
+ }
+
when(blockManager.getDiskWriter(
any[BlockId],
any[File],
any[SerializerInstance],
anyInt(),
- any[ShuffleWriteMetrics]
- )).thenAnswer((invocation: InvocationOnMock) => {
- val args = invocation.getArguments
- val manager = new SerializerManager(new JavaSerializer(conf), conf)
- new DiskBlockObjectWriter(
- args(1).asInstanceOf[File],
- manager,
- args(2).asInstanceOf[SerializerInstance],
- args(3).asInstanceOf[Int],
- syncWrites = false,
- args(4).asInstanceOf[ShuffleWriteMetrics],
- blockId = args(0).asInstanceOf[BlockId]
- )
- })
- when(diskBlockManager.createTempShuffleBlock()).thenAnswer((_: InvocationOnMock) => {
- val blockId = new TempShuffleBlockId(UUID.randomUUID)
- val file = new File(tempDir, blockId.name)
- blockIdToFileMap.put(blockId, file)
- temporaryFilesCreated += file
- (blockId, file)
- })
- when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) =>
+ any[ShuffleWriteMetrics]))
+ .thenAnswer { invocation =>
+ val args = invocation.getArguments
+ val manager = new SerializerManager(new JavaSerializer(conf), conf)
+ new DiskBlockObjectWriter(
+ args(1).asInstanceOf[File],
+ manager,
+ args(2).asInstanceOf[SerializerInstance],
+ args(3).asInstanceOf[Int],
+ syncWrites = false,
+ args(4).asInstanceOf[ShuffleWriteMetrics],
+ blockId = args(0).asInstanceOf[BlockId])
+ }
+
+ when(diskBlockManager.createTempShuffleBlock())
+ .thenAnswer { _ =>
+ val blockId = new TempShuffleBlockId(UUID.randomUUID)
+ val file = new File(tempDir, blockId.name)
+ blockIdToFileMap.put(blockId, file)
+ temporaryFilesCreated += file
+ (blockId, file)
+ }
+
+ when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation =>
blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId])
}
+
+ shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents(
+ conf, blockManager, blockResolver)
}
override def afterEach(): Unit = {
+ TaskContext.unset()
try {
Utils.deleteRecursively(tempDir)
blockIdToFileMap.clear()
@@ -122,12 +138,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("write empty iterator") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
+ 0L, // MapTaskAttemptId
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
- )
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleExecutorComponents)
+
writer.write(Iterator.empty)
writer.stop( /* success = */ true)
assert(writer.getPartitionLengths.sum === 0)
@@ -141,28 +158,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
assert(taskMetrics.memoryBytesSpilled === 0)
}
- test("write with some empty partitions") {
- def records: Iterator[(Int, Int)] =
- Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
- val writer = new BypassMergeSortShuffleWriter[Int, Int](
- blockManager,
- blockResolver,
- shuffleHandle,
- 0, // MapId
- conf,
- taskContext.taskMetrics().shuffleWriteMetrics
- )
- writer.write(records)
- writer.stop( /* success = */ true)
- assert(temporaryFilesCreated.nonEmpty)
- assert(writer.getPartitionLengths.sum === outputFile.length())
- assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
- assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
- val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
- assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
- assert(shuffleWriteMetrics.recordsWritten === records.length)
- assert(taskMetrics.diskBytesSpilled === 0)
- assert(taskMetrics.memoryBytesSpilled === 0)
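+ // Run the empty-partitions scenario with transferTo both enabled and
+ // disabled, so the NIO-channel and stream-based copy paths are both covered.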
+ Seq(true, false).foreach { transferTo =>
+ test(s"write with some empty partitions - transferTo $transferTo") {
+ val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString)
+ def records: Iterator[(Int, Int)] =
+ Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+ val writer = new BypassMergeSortShuffleWriter[Int, Int](
+ blockManager,
+ shuffleHandle,
+ 0, // MapId
+ 0L,
+ transferConf,
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleExecutorComponents)
+ writer.write(records)
+ writer.stop( /* success = */ true)
+ assert(temporaryFilesCreated.nonEmpty)
+ assert(writer.getPartitionLengths.sum === outputFile.length())
+ assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files
+ assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted
+ val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics
+ assert(shuffleWriteMetrics.bytesWritten === outputFile.length())
+ assert(shuffleWriteMetrics.recordsWritten === records.length)
+ assert(taskMetrics.diskBytesSpilled === 0)
+ assert(taskMetrics.memoryBytesSpilled === 0)
+ }
}
test("only generate temp shuffle file for non-empty partition") {
@@ -181,12 +201,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
+ 0L,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
- )
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleExecutorComponents)
intercept[SparkException] {
writer.write(records)
@@ -203,12 +223,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
test("cleanup of intermediate files after errors") {
val writer = new BypassMergeSortShuffleWriter[Int, Int](
blockManager,
- blockResolver,
shuffleHandle,
0, // MapId
+ 0L,
conf,
- taskContext.taskMetrics().shuffleWriteMetrics
- )
+ taskContext.taskMetrics().shuffleWriteMetrics,
+ shuffleExecutorComponents)
intercept[SparkException] {
writer.write((0 until 100000).iterator.map(i => {
if (i == 99990) {
@@ -221,5 +241,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
writer.stop( /* success = */ false)
assert(temporaryFilesCreated.count(_.exists()) === 0)
}
-
}
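Sketch (not part of the patch): the construction path the suite above now uses, with only names that appear in this diff. The setup values (conf, blockManager, blockResolver, shuffleHandle, taskContext) are assumed to exist as in beforeEach(); the trailing write/stop calls mirror the "write empty iterator" test.

    // Default local-disk implementation of the ShuffleExecutorComponents
    // plugin, built once and handed to every writer.
    val components = new LocalDiskShuffleExecutorComponents(
      conf, blockManager, blockResolver)

    // The blockResolver constructor argument is gone; the map task attempt id
    // (0L) is new, and the plugin components come in last.
    val writer = new BypassMergeSortShuffleWriter[Int, Int](
      blockManager,
      shuffleHandle,
      0,  // map id
      0L, // map task attempt id
      conf,
      taskContext.taskMetrics().shuffleWriteMetrics,
      components)

    writer.write(Iterator.empty)
    writer.stop( /* success = */ true)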
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
new file mode 100644
index 0000000..5693b98
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort.io
+
+import java.io.{File, FileInputStream}
+import java.nio.channels.FileChannel
+import java.nio.file.Files
+import java.util.Arrays
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.{any, anyInt}
+import org.mockito.Mock
+import org.mockito.Mockito.when
+import org.mockito.MockitoAnnotations
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.util.Utils
+
+class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS)
+ private var blockResolver: IndexShuffleBlockResolver = _
+
+ private val NUM_PARTITIONS = 4
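+ // Partition 3 is deliberately empty to cover the zero-length partition case.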
+ private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p =>
+ if (p == 3) {
+ Array.emptyByteArray
+ } else {
+ (0 to p * 10).map(_ + p).map(_.toByte).toArray
+ }
+ }.toArray
+
+ private val partitionLengths = data.map(_.length)
+
+ private var tempFile: File = _
+ private var mergedOutputFile: File = _
+ private var tempDir: File = _
+ private var partitionSizesInMergedFile: Array[Long] = _
+ private var conf: SparkConf = _
+ private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _
+
+ override def afterEach(): Unit = {
+ try {
+ Utils.deleteRecursively(tempDir)
+ } finally {
+ super.afterEach()
+ }
+ }
+
+ override def beforeEach(): Unit = {
+ MockitoAnnotations.initMocks(this)
+ tempDir = Utils.createTempDir()
+ mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir)
+ tempFile = File.createTempFile("tempfile", "", tempDir)
+ partitionSizesInMergedFile = null
+ conf = new SparkConf()
+ .set("spark.app.id", "example.spark.app")
+ .set("spark.shuffle.unsafe.file.output.buffer", "16k")
+ when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile)
+ when(blockResolver.writeIndexFileAndCommit(
+ anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])))
+ .thenAnswer { invocationOnMock =>
+ partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]]
+ val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ mergedOutputFile.delete()
+ tmp.renameTo(mergedOutputFile)
+ }
+ null
+ }
+ mapOutputWriter = new LocalDiskShuffleMapOutputWriter(
+ 0,
+ 0,
+ NUM_PARTITIONS,
+ blockResolver,
+ conf)
+ }
+
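+ // Writes each partition through the writer's OutputStream; once closed, the
+ // local-disk stream must reject further writes with IllegalStateException.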
+ test("writing to an outputstream") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val stream = writer.openStream()
+ data(p).foreach { i => stream.write(i) }
+ stream.close()
+ intercept[IllegalStateException] {
+ stream.write(p)
+ }
+ assert(writer.getNumBytesWritten === data(p).length)
+ }
+ verifyWrittenRecords()
+ }
+
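+ // openChannelWrapper() returns an Optional: implementations without a
+ // WritableByteChannel leave it empty and callers fall back to openStream().
+ // The local-disk writer is expected to expose a FileChannel-backed wrapper.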
+ test("writing to a channel") {
+ (0 until NUM_PARTITIONS).foreach { p =>
+ val writer = mapOutputWriter.getPartitionWriter(p)
+ val outputTempFile = File.createTempFile("channelTemp", "", tempDir)
+ Files.write(outputTempFile.toPath, data(p))
+ Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput =>
+ Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper =>
+ assert(channelWrapper.channel().isInstanceOf[FileChannel],
+ "Underlying channel should be a file channel")
+ Utils.copyFileStreamNIO(
+ tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length)
+ }
+ }
+ assert(writer.getNumBytesWritten === data(p).length,
+ s"Partition $p does not have the correct number of bytes.")
+ }
+ verifyWrittenRecords()
+ }
+
+ private def readRecordsFromFile(): Array[Array[Byte]] = {
+ val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath)
+ (0 until NUM_PARTITIONS).map { part =>
+ val startOffset = data.slice(0, part).map(_.length).sum
+ val partitionSize = data(part).length
+ Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize)
+ }.toArray
+ }
+
+ private def verifyWrittenRecords(): Unit = {
+ mapOutputWriter.commitAllPartitions()
+ assert(partitionSizesInMergedFile === partitionLengths)
+ assert(mergedOutputFile.length() === partitionLengths.sum)
+ assert(data === readRecordsFromFile())
+ }
+}
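Sketch (not part of the patch): the consumption contract the two tests above exercise, condensed into one loop. `bytes` is a hypothetical per-partition payload (Array[Array[Byte]]); `mapOutputWriter` is assumed to be built as in beforeEach(). A single FileChannel write is assumed to drain the buffer, which is sufficient for a sketch.

    (0 until NUM_PARTITIONS).foreach { p =>
      val writer = mapOutputWriter.getPartitionWriter(p)
      val maybeChannel = writer.openChannelWrapper()
      if (maybeChannel.isPresent) {
        // Channel path: taken when the implementation exposes a
        // WritableByteChannel, e.g. for transferTo-style copies.
        Utils.tryWithResource(maybeChannel.get) { wrapper =>
          wrapper.channel().write(java.nio.ByteBuffer.wrap(bytes(p)))
        }
      } else {
        // Stream path: always available as the fallback.
        Utils.tryWithResource(writer.openStream()) { out =>
          out.write(bytes(p))
        }
      }
    }
    // Commit once every partition has been written, as in verifyWrittenRecords().
    mapOutputWriter.commitAllPartitions()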