You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/17 21:06:51 UTC

[3/3] spark git commit: [CORE] Updates to remote cache reads

[CORE] Updates to remote cache reads

Covered by tests in DistributedSuite

(cherry picked from commit a97001d21757ae214c86371141bd78a376200f66)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7beb3417
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7beb3417
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7beb3417

Branch: refs/heads/branch-2.4
Commit: 7beb3417591b0f9c436d9175f3343ee79795d536
Parents: 80e317b
Author: Imran Rashid <ir...@cloudera.com>
Authored: Wed Aug 22 16:38:28 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Mon Sep 17 15:28:40 2018 -0500

----------------------------------------------------------------------
 .../spark/network/buffer/ManagedBuffer.java     |  5 +-
 .../spark/network/shuffle/DownloadFile.java     | 47 ++++++++++
 .../network/shuffle/DownloadFileManager.java    | 36 ++++++++
 .../shuffle/DownloadFileWritableChannel.java    | 30 +++++++
 .../network/shuffle/ExternalShuffleClient.java  |  4 +-
 .../network/shuffle/OneForOneBlockFetcher.java  | 28 +++---
 .../spark/network/shuffle/ShuffleClient.java    |  4 +-
 .../network/shuffle/SimpleDownloadFile.java     | 91 ++++++++++++++++++++
 .../spark/network/shuffle/TempFileManager.java  | 36 --------
 .../spark/network/BlockTransferService.scala    |  6 +-
 .../netty/NettyBlockTransferService.scala       |  4 +-
 .../org/apache/spark/storage/BlockManager.scala | 78 ++++++++++++++---
 .../org/apache/spark/storage/DiskStore.scala    | 16 ++++
 .../storage/ShuffleBlockFetcherIterator.scala   | 21 +++--
 .../spark/storage/BlockManagerSuite.scala       |  8 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      |  6 +-
 16 files changed, 328 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
index 1861f8d..2d573f5 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -36,7 +36,10 @@ import java.nio.ByteBuffer;
  */
 public abstract class ManagedBuffer {
 
-  /** Number of bytes of the data. */
+  /**
+   * Number of bytes of the data. If this buffer will decrypt for all of the views into the data,
+   * this is the size of the decrypted data.
+   */
   public abstract long size();
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
new file mode 100644
index 0000000..633622b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+
+/**
+ * A handle on the file used when fetching remote data to disk.  Used to ensure the lifecycle of
+ * writing the data, reading it back, and then cleaning it up is followed.  Specific implementations
+ * may also handle encryption.  The data can be read only via DownloadFileWritableChannel,
+ * which ensures data is not read until after the writer is closed.
+ */
+public interface DownloadFile {
+  /**
+   * Delete the file.
+   *
+   * @return  <code>true</code> if and only if the file or directory is
+   *          successfully deleted; <code>false</code> otherwise
+   */
+  boolean delete();
+
+  /**
+   * A channel for writing data to the file.  This special channel allows access to the data for
+   * reading, after the channel is closed, via {@link DownloadFileWritableChannel#closeAndRead()}.
+   */
+  DownloadFileWritableChannel openForWriting() throws IOException;
+
+  /**
+   * The path of the file, intended only for debug purposes.
+   */
+  String path();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
new file mode 100644
index 0000000..c335a17
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A manager to create temp block files used when fetching remote data to reduce the memory usage.
+ * It will clean files when they won't be used any more.
+ */
+public interface DownloadFileManager {
+
+  /** Create a temp block file. */
+  DownloadFile createTempFile(TransportConf transportConf);
+
+  /**
+   * Register a temp file to clean up when it won't be used any more. Return whether the
+   * file is registered successfully. If `false`, the caller should clean up the file by itself.
+   */
+  boolean registerTempFileToClean(DownloadFile file);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
new file mode 100644
index 0000000..dbbbac4
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * A channel for writing data which is fetched to disk, which allows access to the written data only
+ * after the writer has been closed.  Used with DownloadFile and DownloadFileManager.
+ */
+public interface DownloadFileWritableChannel extends WritableByteChannel {
+  ManagedBuffer closeAndRead();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 7ed0b6e..9a2cf0f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -91,7 +91,7 @@ public class ExternalShuffleClient extends ShuffleClient {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      TempFileManager tempFileManager) {
+      DownloadFileManager downloadFileManager) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
@@ -99,7 +99,7 @@ public class ExternalShuffleClient extends ShuffleClient {
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
             new OneForOneBlockFetcher(client, appId, execId,
-              blockIds1, listener1, conf, tempFileManager).start();
+              blockIds1, listener1, conf, downloadFileManager).start();
           };
 
       int maxRetries = conf.maxIORetries();

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 0bc5718..3058702 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,18 +17,13 @@
 
 package org.apache.spark.network.shuffle;
 
-import java.io.File;
-import java.io.FileOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
-import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
@@ -58,7 +53,7 @@ public class OneForOneBlockFetcher {
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
   private final TransportConf transportConf;
-  private final TempFileManager tempFileManager;
+  private final DownloadFileManager downloadFileManager;
 
   private StreamHandle streamHandle = null;
 
@@ -79,14 +74,14 @@ public class OneForOneBlockFetcher {
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      TempFileManager tempFileManager) {
+      DownloadFileManager downloadFileManager) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
-    this.tempFileManager = tempFileManager;
+    this.downloadFileManager = downloadFileManager;
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -125,7 +120,7 @@ public class OneForOneBlockFetcher {
           // Immediately request all chunks -- we expect that the total size of the request is
           // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
           for (int i = 0; i < streamHandle.numChunks; i++) {
-            if (tempFileManager != null) {
+            if (downloadFileManager != null) {
               client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                 new DownloadCallback(i));
             } else {
@@ -159,13 +154,13 @@ public class OneForOneBlockFetcher {
 
   private class DownloadCallback implements StreamCallback {
 
-    private WritableByteChannel channel = null;
-    private File targetFile = null;
+    private DownloadFileWritableChannel channel = null;
+    private DownloadFile targetFile = null;
     private int chunkIndex;
 
     DownloadCallback(int chunkIndex) throws IOException {
-      this.targetFile = tempFileManager.createTempFile();
-      this.channel = Channels.newChannel(new FileOutputStream(targetFile));
+      this.targetFile = downloadFileManager.createTempFile(transportConf);
+      this.channel = targetFile.openForWriting();
       this.chunkIndex = chunkIndex;
     }
 
@@ -178,11 +173,8 @@ public class OneForOneBlockFetcher {
 
     @Override
     public void onComplete(String streamId) throws IOException {
-      channel.close();
-      ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
-        targetFile.length());
-      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
-      if (!tempFileManager.registerTempFileToClean(targetFile)) {
+      listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
+      if (!downloadFileManager.registerTempFileToClean(targetFile)) {
         targetFile.delete();
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index 18b04fe..62b99c4 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -43,7 +43,7 @@ public abstract class ShuffleClient implements Closeable {
    * @param execId the executor id.
    * @param blockIds block ids to fetch.
    * @param listener the listener to receive block fetching status.
-   * @param tempFileManager TempFileManager to create and clean temp files.
+   * @param downloadFileManager DownloadFileManager to create and clean temp files.
    *                        If it's not <code>null</code>, the remote blocks will be streamed
    *                        into temp shuffle files to reduce the memory usage, otherwise,
    *                        they will be kept in memory.
@@ -54,7 +54,7 @@ public abstract class ShuffleClient implements Closeable {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      TempFileManager tempFileManager);
+      DownloadFileManager downloadFileManager);
 
   /**
    * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
new file mode 100644
index 0000000..670612f
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A DownloadFile that does not take any encryption settings into account for reading and
+ * writing data.
+ *
+ * This does *not* mean the data in the file is un-encrypted -- it could be that the data is
+ * already encrypted when it's written, and a subsequent layer is responsible for decrypting.
+ */
+public class SimpleDownloadFile implements DownloadFile {
+
+  private final File file;
+  private final TransportConf transportConf;
+
+  public SimpleDownloadFile(File file, TransportConf transportConf) {
+    this.file = file;
+    this.transportConf = transportConf;
+  }
+
+  @Override
+  public boolean delete() {
+    return file.delete();
+  }
+
+  @Override
+  public DownloadFileWritableChannel openForWriting() throws IOException {
+    return new SimpleDownloadWritableChannel();
+  }
+
+  @Override
+  public String path() {
+    return file.getAbsolutePath();
+  }
+
+  private class SimpleDownloadWritableChannel implements DownloadFileWritableChannel {
+
+    private final WritableByteChannel channel;
+
+    SimpleDownloadWritableChannel() throws FileNotFoundException {
+      channel = Channels.newChannel(new FileOutputStream(file));
+    }
+
+    @Override
+    public ManagedBuffer closeAndRead() {
+      return new FileSegmentManagedBuffer(transportConf, file, 0, file.length());
+    }
+
+    @Override
+    public int write(ByteBuffer src) throws IOException {
+      return channel.write(src);
+    }
+
+    @Override
+    public boolean isOpen() {
+      return channel.isOpen();
+    }
+
+    @Override
+    public void close() throws IOException {
+      channel.close();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
deleted file mode 100644
index 552364d..0000000
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.shuffle;
-
-import java.io.File;
-
-/**
- * A manager to create temp block files to reduce the memory usage and also clean temp
- * files when they won't be used any more.
- */
-public interface TempFileManager {
-
-  /** Create a temp block file. */
-  File createTempFile();
-
-  /**
-   * Register a temp file to clean up when it won't be used any more. Return whether the
-   * file is registered successfully. If `false`, the caller should clean up the file by itself.
-   */
-  boolean registerTempFileToClean(File file);
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 1d8a266..eef8c31 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -26,7 +26,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.ThreadUtils
 
@@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: TempFileManager): Unit
+      tempFileManager: DownloadFileManager): Unit
 
   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -92,7 +92,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       port: Int,
       execId: String,
       blockId: String,
-      tempFileManager: TempFileManager): ManagedBuffer = {
+      tempFileManager: DownloadFileManager): ManagedBuffer = {
     // A monitor for the thread to wait on.
     val result = Promise[ManagedBuffer]()
     fetchBlocks(host, port, execId, Array(blockId),

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1905632..dc55685 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -33,7 +33,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream}
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
@@ -106,7 +106,7 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: TempFileManager): Unit = {
+      tempFileManager: DownloadFileManager): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index f5c69ad..2234146 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,6 +33,7 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
+import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
@@ -43,8 +44,9 @@ import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.StreamCallbackWithID
 import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle._
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
@@ -213,11 +215,11 @@ private[spark] class BlockManager(
 
   private var blockReplicationPolicy: BlockReplicationPolicy = _
 
-  // A TempFileManager used to track all the files of remote blocks which above the
+  // A DownloadFileManager used to track all the files of remote blocks which are above the
   // specified memory threshold. Files will be deleted automatically based on weak reference.
   // Exposed for test
   private[storage] val remoteBlockTempFileManager =
-    new BlockManager.RemoteBlockTempFileManager(this)
+    new BlockManager.RemoteBlockDownloadFileManager(this)
   private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
 
   /**
@@ -1664,23 +1666,28 @@ private[spark] object BlockManager {
     metricRegistry.registerAll(metricSet)
   }
 
-  class RemoteBlockTempFileManager(blockManager: BlockManager)
-      extends TempFileManager with Logging {
+  class RemoteBlockDownloadFileManager(blockManager: BlockManager)
+      extends DownloadFileManager with Logging {
+    // lazy because SparkEnv is set after this
+    lazy val encryptionKey = SparkEnv.get.securityManager.getIOEncryptionKey()
 
-    private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File])
-        extends WeakReference[File](file, referenceQueue) {
-      private val filePath = file.getAbsolutePath
+    private class ReferenceWithCleanup(
+        file: DownloadFile,
+        referenceQueue: JReferenceQueue[DownloadFile]
+        ) extends WeakReference[DownloadFile](file, referenceQueue) {
+
+      val filePath = file.path()
 
       def cleanUp(): Unit = {
         logDebug(s"Clean up file $filePath")
 
-        if (!new File(filePath).delete()) {
+        if (!file.delete()) {
           logDebug(s"Fail to delete file $filePath")
         }
       }
     }
 
-    private val referenceQueue = new JReferenceQueue[File]
+    private val referenceQueue = new JReferenceQueue[DownloadFile]
     private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup](
       new ConcurrentHashMap)
 
@@ -1692,11 +1699,21 @@ private[spark] object BlockManager {
     cleaningThread.setName("RemoteBlock-temp-file-clean-thread")
     cleaningThread.start()
 
-    override def createTempFile(): File = {
-      blockManager.diskBlockManager.createTempLocalBlock()._2
+    override def createTempFile(transportConf: TransportConf): DownloadFile = {
+      val file = blockManager.diskBlockManager.createTempLocalBlock()._2
+      encryptionKey match {
+        case Some(key) =>
+          // encryption is enabled, so when we read the decrypted data off the network, we need to
+          // encrypt it when writing to disk.  Note that the data may have been encrypted when it
+          // was cached on disk on the remote side, but it was already decrypted by now (see
+          // EncryptedBlockData).
+          new EncryptedDownloadFile(file, key)
+        case None =>
+          new SimpleDownloadFile(file, transportConf)
+      }
     }
 
-    override def registerTempFileToClean(file: File): Boolean = {
+    override def registerTempFileToClean(file: DownloadFile): Boolean = {
       referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue))
     }
 
@@ -1724,4 +1741,39 @@ private[spark] object BlockManager {
       }
     }
   }
+
+  /**
+   * A DownloadFile that encrypts data when it is written, and decrypts when it's read.
+   */
+  private class EncryptedDownloadFile(
+      file: File,
+      key: Array[Byte]) extends DownloadFile {
+
+    private val env = SparkEnv.get
+
+    override def delete(): Boolean = file.delete()
+
+    override def openForWriting(): DownloadFileWritableChannel = {
+      new EncryptedDownloadWritableChannel()
+    }
+
+    override def path(): String = file.getAbsolutePath
+
+    private class EncryptedDownloadWritableChannel extends DownloadFileWritableChannel {
+      private val countingOutput: CountingWritableChannel = new CountingWritableChannel(
+        Channels.newChannel(env.serializerManager.wrapForEncryption(new FileOutputStream(file))))
+
+      override def closeAndRead(): ManagedBuffer = {
+        countingOutput.close()
+        val size = countingOutput.getCount
+        new EncryptedManagedBuffer(new EncryptedBlockData(file, size, env.conf, key))
+      }
+
+      override def write(src: ByteBuffer): Int = countingOutput.write(src)
+
+      override def isOpen: Boolean = countingOutput.isOpen()
+
+      override def close(): Unit = countingOutput.close()
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index a820bc7..d88bd71 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils}
 import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.util.Utils
@@ -260,7 +261,22 @@ private class EncryptedBlockData(
         throw e
     }
   }
+}
+
+private class EncryptedManagedBuffer(val blockData: EncryptedBlockData) extends ManagedBuffer {
+  // ManagedBuffer view over an EncryptedBlockData; every accessor below delegates to blockData.
+
+  override def size(): Long = blockData.size  // This is the size of the decrypted data
+
+  override def nioByteBuffer(): ByteBuffer = blockData.toByteBuffer()
+
+  override def convertToNetty(): AnyRef = blockData.toNetty()
+
+  override def createInputStream(): InputStream = blockData.toInputStream()
+
+  override def retain(): ManagedBuffer = this  // no-op: this class keeps no reference count
 
+  override def release(): ManagedBuffer = this  // no-op: nothing is freed or closed here
 }
 
 private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 00d01dd..e534c74 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{File, InputStream, IOException}
+import java.io.{InputStream, IOException}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
@@ -28,7 +28,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.util.TransportConf
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream
@@ -71,7 +72,7 @@ final class ShuffleBlockFetcherIterator(
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
-  extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging {
+  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -150,7 +151,7 @@ final class ShuffleBlockFetcherIterator(
    * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
    */
   @GuardedBy("this")
-  private[this] val shuffleFilesSet = mutable.HashSet[File]()
+  private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
 
   initialize()
 
@@ -164,11 +165,15 @@ final class ShuffleBlockFetcherIterator(
     currentResult = null
   }
 
-  override def createTempFile(): File = {
-    blockManager.diskBlockManager.createTempLocalBlock()._2
+  override def createTempFile(transportConf: TransportConf): DownloadFile = {
+    // we never need to do any encryption or decryption here, regardless of configs, because that
+    // is handled at another layer in the code.  When encryption is enabled, shuffle data is written
+    // to disk encrypted in the first place, and sent over the network still encrypted.
+    new SimpleDownloadFile(
+      blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf)
   }
 
-  override def registerTempFileToClean(file: File): Boolean = synchronized {
+  override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized {
     if (isZombie) {
       false
     } else {
@@ -204,7 +209,7 @@ final class ShuffleBlockFetcherIterator(
     }
     shuffleFilesSet.foreach { file =>
       if (!file.delete()) {
-        logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
+        logWarning("Failed to cleanup shuffle fetch temp file " + file.path())
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index dbee1f6..32d6e8b 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -43,7 +43,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
 import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
@@ -1437,7 +1437,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
   class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
     var numCalls = 0
-    var tempFileManager: TempFileManager = null
+    var tempFileManager: DownloadFileManager = null
 
     override def init(blockDataManager: BlockDataManager): Unit = {}
 
@@ -1447,7 +1447,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         execId: String,
         blockIds: Array[String],
         listener: BlockFetchingListener,
-        tempFileManager: TempFileManager): Unit = {
+        tempFileManager: DownloadFileManager): Unit = {
       listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
     }
 
@@ -1474,7 +1474,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         port: Int,
         execId: String,
         blockId: String,
-        tempFileManager: TempFileManager): ManagedBuffer = {
+        tempFileManager: DownloadFileManager): ManagedBuffer = {
       numCalls += 1
       this.tempFileManager = tempFileManager
       if (numCalls <= maxFailures) {

http://git-wip-us.apache.org/repos/asf/spark/blob/7beb3417/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index a2997db..b268195 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester
 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
@@ -478,12 +478,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
-    var tempFileManager: TempFileManager = null
+    var tempFileManager: DownloadFileManager = null
     when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager]
+          tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
           Future {
             listener.onBlockFetchSuccess(
               ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org