You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by va...@apache.org on 2016/11/30 22:10:38 UTC

spark git commit: [SPARK-18546][CORE] Fix merging shuffle spills when using encryption.

Repository: spark
Updated Branches:
  refs/heads/master f135b70fd -> 93e9d880b


[SPARK-18546][CORE] Fix merging shuffle spills when using encryption.

The problem exists because encrypted partition data from different spill
files cannot simply be concatenated: each spilled partition segment starts
with its own initialization vector (IV) used to set up encryption, while the
final merged file should contain a single IV for each merged partition;
otherwise iterating over the records becomes really hard.

To fix that, UnsafeShuffleWriter now decrypts each spill's partition data
when merging and re-encrypts it, so that the merged file contains a single
initialization vector at the start of each partition's data.

Because this cannot be done with the fast transferTo path, UnsafeShuffleWriter
falls back to merging with file streams when encryption is enabled. A hybrid
approach may be possible with encryption (using an intermediate direct buffer
when reading from the files and encrypting the data), but that's better left
for a separate patch.

As part of the change I made DiskBlockObjectWriter take a SerializerManager
instead of a "wrap stream" closure, since that makes it easier to test the
code without having to mock SerializerManager functionality.

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write
side and ExternalAppendOnlyMapSuite for integration), and by running some
apps that failed without the fix.

Author: Marcelo Vanzin <va...@cloudera.com>

Closes #15982 from vanzin/SPARK-18546.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/93e9d880
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/93e9d880
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/93e9d880

Branch: refs/heads/master
Commit: 93e9d880bf8a144112d74a6897af4e36fcfa5807
Parents: f135b70
Author: Marcelo Vanzin <va...@cloudera.com>
Authored: Wed Nov 30 14:10:32 2016 -0800
Committer: Marcelo Vanzin <va...@cloudera.com>
Committed: Wed Nov 30 14:10:32 2016 -0800

----------------------------------------------------------------------
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |  48 +++++----
 .../spark/serializer/SerializerManager.scala    |   6 +-
 .../org/apache/spark/storage/BlockManager.scala |   5 +-
 .../spark/storage/DiskBlockObjectWriter.scala   |   6 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  | 100 +++++++++++++------
 .../map/AbstractBytesToBytesMapSuite.java       |  11 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  21 ++--
 .../BypassMergeSortShuffleWriterSuite.scala     |   5 +-
 .../storage/DiskBlockObjectWriterSuite.scala    |  54 ++++------
 .../collection/ExternalAppendOnlyMapSuite.scala |   8 +-
 10 files changed, 145 insertions(+), 119 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f235c43..8a17718 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@ import org.apache.spark.annotation.Private;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
+import org.apache.commons.io.output.CloseShieldOutputStream;
+import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -264,6 +266,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
+    final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
     try {
       if (spills.length == 0) {
         new FileOutputStream(outputFile).close(); // Create an empty file
@@ -289,7 +292,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
           // Compression is disabled or we are using an IO compression codec that supports
           // decompression of concatenated compressed streams, so we can perform a fast spill merge
           // that doesn't need to interpret the spilled bytes.
-          if (transferToEnabled) {
+          if (transferToEnabled && !encryptionEnabled) {
             logger.debug("Using transferTo-based fast merge");
             partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
           } else {
@@ -320,9 +323,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   /**
    * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
    * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
-   * cases where the IO compression codec does not support concatenation of compressed data, or in
-   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
-   * kernel bugs.
+   * cases where the IO compression codec does not support concatenation of compressed data, when
+   * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in
+   * order to work around kernel bugs.
    *
    * @param spills the spills to merge.
    * @param outputFile the file to write the merged data to.
@@ -337,7 +340,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
+
+    // Use a counting output stream to avoid having to close the underlying file and ask
+    // the file system for its size after each partition is written.
+    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(
+      new FileOutputStream(outputFile));
 
     boolean threwException = true;
     try {
@@ -345,34 +352,35 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        final long initialFileLength = mergedFileOutputStream.getByteCount();
+        // Shield the underlying output stream from close() calls, so that we can close the higher
+        // level streams to make sure all data is really flushed and internal state is cleaned.
+        OutputStream partitionOutput = new CloseShieldOutputStream(
+          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
+        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
         if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
         }
-
         for (int i = 0; i < spills.length; i++) {
           final long partitionLengthInSpill = spills[i].partitionLengths[partition];
           if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = null;
-            boolean innerThrewException = true;
+            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+              partitionLengthInSpill, false);
             try {
-              partitionInputStream =
-                  new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+              partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+                partitionInputStream);
               if (compressionCodec != null) {
                 partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
               }
-              ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-              innerThrewException = false;
+              ByteStreams.copy(partitionInputStream, partitionOutput);
             } finally {
-              Closeables.close(partitionInputStream, innerThrewException);
+              partitionInputStream.close();
             }
           }
         }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+        partitionOutput.flush();
+        partitionOutput.close();
+        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
       }
       threwException = false;
     } finally {

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 7371f88..686305e 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -75,6 +75,8 @@ private[spark] class SerializerManager(
    * loaded yet. */
   private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
 
+  def encryptionEnabled: Boolean = encryptionKey.isDefined
+
   def canUseKryo(ct: ClassTag[_]): Boolean = {
     primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
   }
@@ -129,7 +131,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an input stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: InputStream): InputStream = {
+  def wrapForEncryption(s: InputStream): InputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) }
       .getOrElse(s)
@@ -138,7 +140,7 @@ private[spark] class SerializerManager(
   /**
    * Wrap an output stream for encryption if shuffle encryption is enabled
    */
-  private[this] def wrapForEncryption(s: OutputStream): OutputStream = {
+  def wrapForEncryption(s: OutputStream): OutputStream = {
     encryptionKey
       .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) }
       .getOrElse(s)

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 982b833..04521c9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -62,7 +62,7 @@ private[spark] class BlockManager(
     executorId: String,
     rpcEnv: RpcEnv,
     val master: BlockManagerMaster,
-    serializerManager: SerializerManager,
+    val serializerManager: SerializerManager,
     val conf: SparkConf,
     memoryManager: MemoryManager,
     mapOutputTracker: MapOutputTracker,
@@ -745,9 +745,8 @@ private[spark] class BlockManager(
       serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
-    val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream,
+    new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a499827..3cb12fc 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -22,7 +22,7 @@ import java.nio.channels.FileChannel
 
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
-import org.apache.spark.serializer.{SerializationStream, SerializerInstance}
+import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
 import org.apache.spark.util.Utils
 
 /**
@@ -37,9 +37,9 @@ import org.apache.spark.util.Utils
  */
 private[spark] class DiskBlockObjectWriter(
     val file: File,
+    serializerManager: SerializerManager,
     serializerInstance: SerializerInstance,
     bufferSize: Int,
-    wrapStream: OutputStream => OutputStream,
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
@@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter(
       initialized = true
     }
 
-    bs = wrapStream(mcs)
+    bs = serializerManager.wrapStream(blockId, mcs)
     objOut = serializerInstance.serializeStream(bs)
     streamOpen = true
     this

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a96cd82..088b681 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -26,11 +26,9 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.Tuple2$;
 import scala.collection.Iterator;
-import scala.runtime.AbstractFunction1;
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Iterators;
-import com.google.common.io.ByteStreams;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -53,6 +51,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TestMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.storage.*;
@@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite {
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   SparkConf conf;
   final Serializer serializer = new KryoSerializer(new SparkConf());
-  final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf());
   TaskMetrics taskMetrics;
 
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
@@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite {
   @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
   @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
 
-  private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
   @After
   public void tearDown() {
     Utils.deleteRecursively(tempDir);
@@ -121,6 +108,11 @@ public class UnsafeShuffleWriterSuite {
     memoryManager = new TestMemoryManager(conf);
     taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
 
+    // Some tests will override this manager because they change the configuration. This is a
+    // default for tests that don't need a specific one.
+    SerializerManager manager = new SerializerManager(serializer, conf);
+    when(blockManager.serializerManager()).thenReturn(manager);
+
     when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
     when(blockManager.getDiskWriter(
       any(BlockId.class),
@@ -131,12 +123,11 @@ public class UnsafeShuffleWriterSuite {
       @Override
       public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
         Object[] args = invocationOnMock.getArguments();
-
         return new DiskBlockObjectWriter(
           (File) args[1],
+          blockManager.serializerManager(),
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]
@@ -201,9 +192,10 @@ public class UnsafeShuffleWriterSuite {
     for (int i = 0; i < NUM_PARTITITONS; i++) {
       final long partitionSize = partitionSizesInMergedFile[i];
       if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
+        FileInputStream fin = new FileInputStream(mergedOutputFile);
+        fin.getChannel().position(startOffset);
+        InputStream in = new LimitedInputStream(fin, partitionSize);
+        in = blockManager.serializerManager().wrapForEncryption(in);
         if (conf.getBoolean("spark.shuffle.compress", true)) {
           in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
         }
@@ -294,14 +286,32 @@ public class UnsafeShuffleWriterSuite {
   }
 
   private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
+      final boolean transferToEnabled,
+      String compressionCodecName,
+      boolean encrypt) throws Exception {
     if (compressionCodecName != null) {
       conf.set("spark.shuffle.compress", "true");
       conf.set("spark.io.compression.codec", compressionCodecName);
     } else {
       conf.set("spark.shuffle.compress", "false");
     }
+    conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt);
+
+    SerializerManager manager;
+    if (encrypt) {
+      manager = new SerializerManager(serializer, conf,
+        Option.apply(CryptoStreamUtils.createKey(conf)));
+    } else {
+      manager = new SerializerManager(serializer, conf);
+    }
+
+    when(blockManager.serializerManager()).thenReturn(manager);
+    testMergingSpills(transferToEnabled, encrypt);
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      boolean encrypted) throws IOException {
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
     for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
@@ -324,6 +334,7 @@ public class UnsafeShuffleWriterSuite {
     for (long size: partitionSizesInMergedFile) {
       sumOfPartitionSizes += size;
     }
+
     assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
 
     assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile()));
@@ -338,42 +349,72 @@ public class UnsafeShuffleWriterSuite {
 
   @Test
   public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
+    testMergingSpills(true, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
+    testMergingSpills(false, LZFCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+    testMergingSpills(true, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+    testMergingSpills(false, SnappyCompressionCodec.class.getName(), false);
   }
 
   @Test
   public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
+    testMergingSpills(true, null, false);
   }
 
   @Test
   public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
+    testMergingSpills(false, null, false);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryption() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception {
+    conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false");
+    testMergingSpills(false, LZ4CompressionCodec.class.getName(), true);
+  }
+
+  @Test
+  public void mergeSpillsWithEncryptionAndNoCompression() throws Exception {
+    // This should actually be translated to a "file stream merge" internally, just have the
+    // test to make sure that it's the case.
+    testMergingSpills(true, null, true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception {
+    testMergingSpills(false, null, true);
   }
 
   @Test
@@ -531,4 +572,5 @@ public class UnsafeShuffleWriterSuite {
       writer.stop(false);
     }
   }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 33709b4..2656814 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -19,13 +19,11 @@ package org.apache.spark.unsafe.map;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.nio.ByteBuffer;
 import java.util.*;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Assert;
@@ -75,13 +73,6 @@ public abstract class AbstractBytesToBytesMapSuite {
   @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
   @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
 
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
-
   @Before
   public void setup() {
     memoryManager =
@@ -120,9 +111,9 @@ public abstract class AbstractBytesToBytesMapSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index a9cf8ff..fbbe530 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -19,14 +19,12 @@ package org.apache.spark.util.collection.unsafe.sort;
 
 import java.io.File;
 import java.io.IOException;
-import java.io.OutputStream;
 import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.UUID;
 
 import scala.Tuple2;
 import scala.Tuple2$;
-import scala.runtime.AbstractFunction1;
 
 import org.junit.After;
 import org.junit.Before;
@@ -57,13 +55,15 @@ import static org.mockito.Mockito.*;
 
 public class UnsafeExternalSorterSuite {
 
+  private final SparkConf conf = new SparkConf();
+
   final LinkedList<File> spillFilesCreated = new LinkedList<>();
   final TestMemoryManager memoryManager =
-    new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false"));
+    new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false"));
   final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0);
   final SerializerManager serializerManager = new SerializerManager(
-    new JavaSerializer(new SparkConf()),
-    new SparkConf().set("spark.shuffle.spill.compress", "false"));
+    new JavaSerializer(conf),
+    conf.clone().set("spark.shuffle.spill.compress", "false"));
   // Use integer comparison for comparing prefixes (which are partition ids, in this case)
   final PrefixComparator prefixComparator = PrefixComparators.LONG;
   // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so
@@ -86,14 +86,7 @@ public class UnsafeExternalSorterSuite {
 
   protected boolean shouldUseRadixSort() { return false; }
 
-  private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m");
-
-  private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      return stream;
-    }
-  }
+  private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m");
 
   @Before
   public void setUp() {
@@ -126,9 +119,9 @@ public class UnsafeExternalSorterSuite {
 
         return new DiskBlockObjectWriter(
           (File) args[1],
+          serializerManager,
           (SerializerInstance) args[2],
           (Integer) args[3],
-          new WrapStream(),
           false,
           (ShuffleWriteMetrics) args[4],
           (BlockId) args[0]

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 4429416..85ccb33 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics}
-import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.IndexShuffleBlockResolver
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
@@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     )).thenAnswer(new Answer[DiskBlockObjectWriter] {
       override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = {
         val args = invocation.getArguments
+        val manager = new SerializerManager(new JavaSerializer(conf), conf)
         new DiskBlockObjectWriter(
           args(1).asInstanceOf[File],
+          manager,
           args(2).asInstanceOf[SerializerInstance],
           args(3).asInstanceOf[Int],
-          wrapStream = identity,
           syncWrites = false,
           args(4).asInstanceOf[ShuffleWriteMetrics],
           blockId = args(0).asInstanceOf[BlockId]

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index 684e978..bfb3ac4 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.util.Utils
 
 class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
@@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
-  test("verify write metrics") {
+  private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = {
     val file = new File(tempDir, "somefile")
+    val conf = new SparkConf()
+    val serializerManager = new SerializerManager(new JavaSerializer(conf), conf)
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+      file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true,
+      writeMetrics)
+    (writer, file, writeMetrics)
+  }
+
+  test("verify write metrics") {
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("verify write metrics on revert") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
@@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("Reopening a closed block writer") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
 
     writer.open()
     writer.close()
@@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() after commit() should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
 
     writer.write(Long.box(20), Long.box(30))
     val firstSegment = writer.commitAndGet()
@@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("revertPartialWritesAndClose() should be idempotent") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, file, writeMetrics) = createWriter()
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
@@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("commit() and close() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    val (writer, _, _) = createWriter()
     val segment = writer.commitAndGet()
     writer.close()
     assert(segment.length === 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/93e9d880/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 5141e36..7f08382 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark._
+import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.memory.MemoryTestingUtils
 
@@ -230,14 +231,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
+  test("spilling with compression and encryption") {
+    testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true)
+  }
+
   /**
    * Test spilling through simple aggregations and cogroups.
    * If a compression codec is provided, use it. Otherwise, do not compress spills.
    */
-  private def testSimpleSpilling(codec: Option[String] = None): Unit = {
+  private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = {
     val size = 1000
     val conf = createSparkConf(loadDefaults = true, codec)  // Load defaults for Spark home
     conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString)
+    conf.set(IO_ENCRYPTION_ENABLED, encrypt)
     sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
 
     assertSpilled(sc, "reduceByKey") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org