You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ir...@apache.org on 2018/09/13 14:20:39 UTC

[1/4] spark git commit: [CORE] Updates to remote cache reads

Repository: spark
Updated Branches:
  refs/heads/branch-2.3 9ac9f36c4 -> 575fea120


[CORE] Updates to remote cache reads

Covered by tests in DistributedSuite


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/575fea12
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/575fea12
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/575fea12

Branch: refs/heads/branch-2.3
Commit: 575fea120e25249716e3f680396580c5f9e26b5b
Parents: 6d742d1
Author: Imran Rashid <ir...@cloudera.com>
Authored: Wed Aug 22 16:38:28 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500

----------------------------------------------------------------------
 .../spark/network/buffer/ManagedBuffer.java     |  5 +-
 .../spark/network/shuffle/DownloadFile.java     | 47 ++++++++++
 .../network/shuffle/DownloadFileManager.java    | 36 ++++++++
 .../shuffle/DownloadFileWritableChannel.java    | 31 +++++++
 .../network/shuffle/ExternalShuffleClient.java  |  4 +-
 .../network/shuffle/OneForOneBlockFetcher.java  | 28 +++---
 .../spark/network/shuffle/ShuffleClient.java    |  4 +-
 .../network/shuffle/SimpleDownloadFile.java     | 91 ++++++++++++++++++++
 .../spark/network/shuffle/TempFileManager.java  | 36 --------
 .../spark/network/BlockTransferService.scala    |  6 +-
 .../netty/NettyBlockTransferService.scala       |  4 +-
 .../org/apache/spark/storage/BlockManager.scala | 78 ++++++++++++++---
 .../org/apache/spark/storage/DiskStore.scala    | 16 ++++
 .../storage/ShuffleBlockFetcherIterator.scala   | 21 +++--
 .../spark/storage/BlockManagerSuite.scala       |  8 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      |  6 +-
 16 files changed, 329 insertions(+), 92 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
----------------------------------------------------------------------
diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
index 1861f8d..2d573f5 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -36,7 +36,10 @@ import java.nio.ByteBuffer;
  */
 public abstract class ManagedBuffer {
 
-  /** Number of bytes of the data. */
+  /**
+   * Number of bytes of the data. If this buffer decrypts the data for all of the views into it,
+   * this is the size of the decrypted data.
+   */
   public abstract long size();
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
new file mode 100644
index 0000000..76c2d02
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFile.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+
+/**
+ * A handle on the file used when fetching remote data to disk.  Used to ensure the lifecycle of
+ * writing the data, reading it back, and then cleaning it up is followed.  Specific implementations
+ * may also handle encryption.  The data can be read only via DownloadFileWritableChannel,
+ * which ensures data is not read until after the writer is closed.
+ */
+public interface DownloadFile {
+  /**
+   * Delete the file.
+   *
+   * @return  <code>true</code> if and only if the file or directory is
+   *          successfully deleted; <code>false</code> otherwise
+   */
+  public boolean delete();
+
+  /**
+   * A channel for writing data to the file.  This special channel allows access to the data for
+   * reading, after the channel is closed, via {@link DownloadFileWritableChannel#closeAndRead()}.
+   */
+  public DownloadFileWritableChannel openForWriting() throws IOException;
+
+  /**
+   * The path of the file, intended only for debug purposes.
+   */
+  public String path();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
new file mode 100644
index 0000000..c335a17
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileManager.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A manager to create temp block files used when fetching remote data to reduce the memory usage.
+ * It will clean files when they won't be used any more.
+ */
+public interface DownloadFileManager {
+
+  /** Create a temp block file. */
+  DownloadFile createTempFile(TransportConf transportConf);
+
+  /**
+   * Register a temp file to clean up when it won't be used any more. Return whether the
+   * file is registered successfully. If `false`, the caller should clean up the file by itself.
+   */
+  boolean registerTempFileToClean(DownloadFile file);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
new file mode 100644
index 0000000..acf194c
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/DownloadFileWritableChannel.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+import java.io.OutputStream;
+import java.nio.channels.WritableByteChannel;
+
+/**
+ * A channel for writing data which is fetched to disk, which allows access to the written data only
+ * after the writer has been closed.  Used with DownloadFile and DownloadFileManager.
+ */
+public interface DownloadFileWritableChannel extends WritableByteChannel {
+  public ManagedBuffer closeAndRead();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index 7ed0b6e..9a2cf0f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -91,7 +91,7 @@ public class ExternalShuffleClient extends ShuffleClient {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      TempFileManager tempFileManager) {
+      DownloadFileManager downloadFileManager) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
@@ -99,7 +99,7 @@ public class ExternalShuffleClient extends ShuffleClient {
           (blockIds1, listener1) -> {
             TransportClient client = clientFactory.createClient(host, port);
             new OneForOneBlockFetcher(client, appId, execId,
-              blockIds1, listener1, conf, tempFileManager).start();
+              blockIds1, listener1, conf, downloadFileManager).start();
           };
 
       int maxRetries = conf.maxIORetries();

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 0bc5718..3058702 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -17,18 +17,13 @@
 
 package org.apache.spark.network.shuffle;
 
-import java.io.File;
-import java.io.FileOutputStream;
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.nio.channels.Channels;
-import java.nio.channels.WritableByteChannel;
 import java.util.Arrays;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
 import org.apache.spark.network.buffer.ManagedBuffer;
 import org.apache.spark.network.client.ChunkReceivedCallback;
 import org.apache.spark.network.client.RpcResponseCallback;
@@ -58,7 +53,7 @@ public class OneForOneBlockFetcher {
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
   private final TransportConf transportConf;
-  private final TempFileManager tempFileManager;
+  private final DownloadFileManager downloadFileManager;
 
   private StreamHandle streamHandle = null;
 
@@ -79,14 +74,14 @@ public class OneForOneBlockFetcher {
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      TempFileManager tempFileManager) {
+      DownloadFileManager downloadFileManager) {
     this.client = client;
     this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
-    this.tempFileManager = tempFileManager;
+    this.downloadFileManager = downloadFileManager;
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -125,7 +120,7 @@ public class OneForOneBlockFetcher {
           // Immediately request all chunks -- we expect that the total size of the request is
           // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
           for (int i = 0; i < streamHandle.numChunks; i++) {
-            if (tempFileManager != null) {
+            if (downloadFileManager != null) {
               client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i),
                 new DownloadCallback(i));
             } else {
@@ -159,13 +154,13 @@ public class OneForOneBlockFetcher {
 
   private class DownloadCallback implements StreamCallback {
 
-    private WritableByteChannel channel = null;
-    private File targetFile = null;
+    private DownloadFileWritableChannel channel = null;
+    private DownloadFile targetFile = null;
     private int chunkIndex;
 
     DownloadCallback(int chunkIndex) throws IOException {
-      this.targetFile = tempFileManager.createTempFile();
-      this.channel = Channels.newChannel(new FileOutputStream(targetFile));
+      this.targetFile = downloadFileManager.createTempFile(transportConf);
+      this.channel = targetFile.openForWriting();
       this.chunkIndex = chunkIndex;
     }
 
@@ -178,11 +173,8 @@ public class OneForOneBlockFetcher {
 
     @Override
     public void onComplete(String streamId) throws IOException {
-      channel.close();
-      ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0,
-        targetFile.length());
-      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
-      if (!tempFileManager.registerTempFileToClean(targetFile)) {
+      listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
+      if (!downloadFileManager.registerTempFileToClean(targetFile)) {
         targetFile.delete();
       }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index 18b04fe..62b99c4 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -43,7 +43,7 @@ public abstract class ShuffleClient implements Closeable {
    * @param execId the executor id.
    * @param blockIds block ids to fetch.
    * @param listener the listener to receive block fetching status.
-   * @param tempFileManager TempFileManager to create and clean temp files.
+   * @param downloadFileManager DownloadFileManager to create and clean temp files.
    *                        If it's not <code>null</code>, the remote blocks will be streamed
    *                        into temp shuffle files to reduce the memory usage, otherwise,
    *                        they will be kept in memory.
@@ -54,7 +54,7 @@ public abstract class ShuffleClient implements Closeable {
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      TempFileManager tempFileManager);
+      DownloadFileManager downloadFileManager);
 
   /**
    * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
new file mode 100644
index 0000000..670612f
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/SimpleDownloadFile.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.network.shuffle;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.channels.WritableByteChannel;
+
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer;
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * A DownloadFile that does not take any encryption settings into account for reading and
+ * writing data.
+ *
+ * This does *not* mean the data in the file is un-encrypted -- it could be that the data is
+ * already encrypted when it's written, and a subsequent layer is responsible for decrypting.
+ */
+public class SimpleDownloadFile implements DownloadFile {
+
+  private final File file;
+  private final TransportConf transportConf;
+
+  public SimpleDownloadFile(File file, TransportConf transportConf) {
+    this.file = file;
+    this.transportConf = transportConf;
+  }
+
+  @Override
+  public boolean delete() {
+    return file.delete();
+  }
+
+  @Override
+  public DownloadFileWritableChannel openForWriting() throws IOException {
+    return new SimpleDownloadWritableChannel();
+  }
+
+  @Override
+  public String path() {
+    return file.getAbsolutePath();
+  }
+
+  private class SimpleDownloadWritableChannel implements DownloadFileWritableChannel {
+
+    private final WritableByteChannel channel;
+
+    SimpleDownloadWritableChannel() throws FileNotFoundException {
+      channel = Channels.newChannel(new FileOutputStream(file));
+    }
+
+    @Override
+    public ManagedBuffer closeAndRead() {
+      return new FileSegmentManagedBuffer(transportConf, file, 0, file.length());
+    }
+
+    @Override
+    public int write(ByteBuffer src) throws IOException {
+      return channel.write(src);
+    }
+
+    @Override
+    public boolean isOpen() {
+      return channel.isOpen();
+    }
+
+    @Override
+    public void close() throws IOException {
+      channel.close();
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
----------------------------------------------------------------------
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
deleted file mode 100644
index 552364d..0000000
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java
+++ /dev/null
@@ -1,36 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.shuffle;
-
-import java.io.File;
-
-/**
- * A manager to create temp block files to reduce the memory usage and also clean temp
- * files when they won't be used any more.
- */
-public interface TempFileManager {
-
-  /** Create a temp block file. */
-  File createTempFile();
-
-  /**
-   * Register a temp file to clean up when it won't be used any more. Return whether the
-   * file is registered successfully. If `false`, the caller should clean up the file by itself.
-   */
-  boolean registerTempFileToClean(File file);
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 1d8a266..eef8c31 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -26,7 +26,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ShuffleClient}
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.ThreadUtils
 
@@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: TempFileManager): Unit
+      tempFileManager: DownloadFileManager): Unit
 
   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -92,7 +92,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       port: Int,
       execId: String,
       blockId: String,
-      tempFileManager: TempFileManager): ManagedBuffer = {
+      tempFileManager: DownloadFileManager): ManagedBuffer = {
     // A monitor for the thread to wait on.
     val result = Promise[ManagedBuffer]()
     fetchBlocks(host, port, execId, Array(blockId),

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b7d8c35..1c48ad6 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory}
 import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.UploadBlock
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.serializer.JavaSerializer
@@ -105,7 +105,7 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: TempFileManager): Unit = {
+      tempFileManager: DownloadFileManager): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index df1a4be..027c31f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,6 +33,7 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
+import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
 import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
@@ -42,8 +43,9 @@ import org.apache.spark.metrics.source.Source
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.netty.SparkTransportConf
-import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle._
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
+import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
@@ -206,11 +208,11 @@ private[spark] class BlockManager(
 
   private var blockReplicationPolicy: BlockReplicationPolicy = _
 
-  // A TempFileManager used to track all the files of remote blocks which above the
+  // A DownloadFileManager used to track all the files of remote blocks which are above the
   // specified memory threshold. Files will be deleted automatically based on weak reference.
   // Exposed for test
   private[storage] val remoteBlockTempFileManager =
-    new BlockManager.RemoteBlockTempFileManager(this)
+    new BlockManager.RemoteBlockDownloadFileManager(this)
   private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
 
   /**
@@ -1582,23 +1584,28 @@ private[spark] object BlockManager {
     metricRegistry.registerAll(metricSet)
   }
 
-  class RemoteBlockTempFileManager(blockManager: BlockManager)
-      extends TempFileManager with Logging {
+  class RemoteBlockDownloadFileManager(blockManager: BlockManager)
+      extends DownloadFileManager with Logging {
+    // lazy because SparkEnv is set after this
+    lazy val encryptionKey = SparkEnv.get.securityManager.getIOEncryptionKey()
 
-    private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File])
-        extends WeakReference[File](file, referenceQueue) {
-      private val filePath = file.getAbsolutePath
+    private class ReferenceWithCleanup(
+        file: DownloadFile,
+        referenceQueue: JReferenceQueue[DownloadFile]
+        ) extends WeakReference[DownloadFile](file, referenceQueue) {
+
+      val filePath = file.path()
 
       def cleanUp(): Unit = {
         logDebug(s"Clean up file $filePath")
 
-        if (!new File(filePath).delete()) {
+        if (!file.delete()) {
           logDebug(s"Fail to delete file $filePath")
         }
       }
     }
 
-    private val referenceQueue = new JReferenceQueue[File]
+    private val referenceQueue = new JReferenceQueue[DownloadFile]
     private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup](
       new ConcurrentHashMap)
 
@@ -1610,11 +1617,21 @@ private[spark] object BlockManager {
     cleaningThread.setName("RemoteBlock-temp-file-clean-thread")
     cleaningThread.start()
 
-    override def createTempFile(): File = {
-      blockManager.diskBlockManager.createTempLocalBlock()._2
+    override def createTempFile(transportConf: TransportConf): DownloadFile = {
+      val file = blockManager.diskBlockManager.createTempLocalBlock()._2
+      encryptionKey match {
+        case Some(key) =>
+          // encryption is enabled, so when we read the decrypted data off the network, we need to
+          // encrypt it when writing to disk.  Note that the data may have been encrypted when it
+          // was cached on disk on the remote side, but it was already decrypted by now (see
+          // EncryptedBlockData).
+          new EncryptedDownloadFile(file, key)
+        case None =>
+          new SimpleDownloadFile(file, transportConf)
+      }
     }
 
-    override def registerTempFileToClean(file: File): Boolean = {
+    override def registerTempFileToClean(file: DownloadFile): Boolean = {
       referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue))
     }
 
@@ -1642,4 +1659,39 @@ private[spark] object BlockManager {
       }
     }
   }
+
+  /**
+   * A DownloadFile that encrypts data when it is written, and decrypts when it's read.
+   */
+  private class EncryptedDownloadFile(
+      file: File,
+      key: Array[Byte]) extends DownloadFile {
+
+    private val env = SparkEnv.get
+
+    override def delete(): Boolean = file.delete()
+
+    override def openForWriting(): DownloadFileWritableChannel = {
+      new EncryptedDownloadWritableChannel()
+    }
+
+    override def path(): String = file.getAbsolutePath
+
+    private class EncryptedDownloadWritableChannel extends DownloadFileWritableChannel {
+      private val countingOutput: CountingWritableChannel = new CountingWritableChannel(
+        Channels.newChannel(env.serializerManager.wrapForEncryption(new FileOutputStream(file))))
+
+      override def closeAndRead(): ManagedBuffer = {
+        countingOutput.close()
+        val size = countingOutput.getCount
+        new EncryptedManagedBuffer(new EncryptedBlockData(file, size, env.conf, key))
+      }
+
+      override def write(src: ByteBuffer): Int = countingOutput.write(src)
+
+      override def isOpen: Boolean = countingOutput.isOpen()
+
+      override def close(): Unit = countingOutput.close()
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 39249d4..e61cfe8 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -30,6 +30,7 @@ import io.netty.channel.DefaultFileRegion
 
 import org.apache.spark.{SecurityManager, SparkConf}
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils}
 import org.apache.spark.security.CryptoStreamUtils
 import org.apache.spark.util.Utils
@@ -261,7 +262,22 @@ private class EncryptedBlockData(
         throw e
     }
   }
+}
+
+private class EncryptedManagedBuffer(val blockData: EncryptedBlockData) extends ManagedBuffer {
+
+  // This is the size of the decrypted data
+  override def size(): Long = blockData.size
+
+  override def nioByteBuffer(): ByteBuffer = blockData.toByteBuffer()
+
+  override def convertToNetty(): AnyRef = blockData.toNetty()
+
+  override def createInputStream(): InputStream = blockData.toInputStream()
+
+  override def retain(): ManagedBuffer = this
 
+  override def release(): ManagedBuffer = this
 }
 
 private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index dd9df74..94def4d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.{File, InputStream, IOException}
+import java.io.{InputStream, IOException}
 import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
@@ -28,7 +28,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.util.TransportConf
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
 import org.apache.spark.util.io.ChunkedByteBufferOutputStream
@@ -69,7 +70,7 @@ final class ShuffleBlockFetcherIterator(
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean)
-  extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging {
+  extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -148,7 +149,7 @@ final class ShuffleBlockFetcherIterator(
    * deleted when cleanup. This is a layer of defensiveness against disk file leaks.
    */
   @GuardedBy("this")
-  private[this] val shuffleFilesSet = mutable.HashSet[File]()
+  private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
 
   initialize()
 
@@ -162,11 +163,15 @@ final class ShuffleBlockFetcherIterator(
     currentResult = null
   }
 
-  override def createTempFile(): File = {
-    blockManager.diskBlockManager.createTempLocalBlock()._2
+  override def createTempFile(transportConf: TransportConf): DownloadFile = {
+    // we never need to do any encryption or decryption here, regardless of configs, because that
+    // is handled at another layer in the code.  When encryption is enabled, shuffle data is written
+    // to disk encrypted in the first place, and sent over the network still encrypted.
+    new SimpleDownloadFile(
+      blockManager.diskBlockManager.createTempLocalBlock()._2, transportConf)
   }
 
-  override def registerTempFileToClean(file: File): Boolean = synchronized {
+  override def registerTempFileToClean(file: DownloadFile): Boolean = synchronized {
     if (isZombie) {
       false
     } else {
@@ -202,7 +207,7 @@ final class ShuffleBlockFetcherIterator(
     }
     shuffleFilesSet.foreach { file =>
       if (!file.delete()) {
-        logWarning("Failed to cleanup shuffle fetch temp file " + file.getAbsolutePath())
+        logWarning("Failed to cleanup shuffle fetch temp file " + file.path())
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 629eed4..4d2168f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -44,7 +44,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
 import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
 import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.LiveListenerBus
@@ -1403,7 +1403,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
   class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
     var numCalls = 0
-    var tempFileManager: TempFileManager = null
+    var tempFileManager: DownloadFileManager = null
 
     override def init(blockDataManager: BlockDataManager): Unit = {}
 
@@ -1413,7 +1413,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         execId: String,
         blockIds: Array[String],
         listener: BlockFetchingListener,
-        tempFileManager: TempFileManager): Unit = {
+        tempFileManager: DownloadFileManager): Unit = {
       listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
     }
 
@@ -1440,7 +1440,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         port: Int,
         execId: String,
         blockId: String,
-        tempFileManager: TempFileManager): ManagedBuffer = {
+        tempFileManager: DownloadFileManager): ManagedBuffer = {
       numCalls += 1
       this.tempFileManager = tempFileManager
       if (numCalls <= maxFailures) {

http://git-wip-us.apache.org/repos/asf/spark/blob/575fea12/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 692ae3b..24244f9 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester
 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
@@ -482,12 +482,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     val transfer = mock(classOf[BlockTransferService])
-    var tempFileManager: TempFileManager = null
+    var tempFileManager: DownloadFileManager = null
     when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-          tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager]
+          tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
           Future {
             listener.onBlockFetchSuccess(
               ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[3/4] spark git commit: [PYSPARK][SQL] Updates to RowQueue

Posted by ir...@apache.org.
[PYSPARK][SQL] Updates to RowQueue

Tested with updates to RowQueueSuite


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6d742d1b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6d742d1b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6d742d1b

Branch: refs/heads/branch-2.3
Commit: 6d742d1bd71aa3803dce91a830b37284cb18cf70
Parents: 09dd34c
Author: Imran Rashid <ir...@cloudera.com>
Authored: Thu Sep 6 12:11:47 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500

----------------------------------------------------------------------
 .../spark/sql/execution/python/RowQueue.scala   | 27 ++++++++++++++-----
 .../sql/execution/python/RowQueueSuite.scala    | 28 +++++++++++++++-----
 2 files changed, 41 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index e2fa6e7..d2820ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -21,9 +21,10 @@ import java.io._
 
 import com.google.common.io.Closeables
 
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.io.NioBufferedFileInputStream
 import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
+import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
@@ -108,9 +109,13 @@ private[python] abstract class InMemoryRowQueue(val page: MemoryBlock, numFields
  * A RowQueue that is backed by a file on disk. This queue will stop accepting new rows once any
  * reader has begun reading from the queue.
  */
-private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueue {
-  private var out = new DataOutputStream(
-    new BufferedOutputStream(new FileOutputStream(file.toString)))
+private[python] case class DiskRowQueue(
+    file: File,
+    fields: Int,
+    serMgr: SerializerManager) extends RowQueue {
+
+  private var out = new DataOutputStream(serMgr.wrapForEncryption(
+    new BufferedOutputStream(new FileOutputStream(file.toString))))
   private var unreadBytes = 0L
 
   private var in: DataInputStream = _
@@ -131,7 +136,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
     if (out != null) {
       out.close()
       out = null
-      in = new DataInputStream(new NioBufferedFileInputStream(file))
+      in = new DataInputStream(serMgr.wrapForEncryption(
+        new NioBufferedFileInputStream(file)))
     }
 
     if (unreadBytes > 0) {
@@ -166,7 +172,8 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
 private[python] case class HybridRowQueue(
     memManager: TaskMemoryManager,
     tempDir: File,
-    numFields: Int)
+    numFields: Int,
+    serMgr: SerializerManager)
   extends MemoryConsumer(memManager) with RowQueue {
 
   // Each buffer should have at least one row
@@ -212,7 +219,7 @@ private[python] case class HybridRowQueue(
   }
 
   private def createDiskQueue(): RowQueue = {
-    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields)
+    DiskRowQueue(File.createTempFile("buffer", "", tempDir), numFields, serMgr)
   }
 
   private def createNewQueue(required: Long): RowQueue = {
@@ -279,3 +286,9 @@ private[python] case class HybridRowQueue(
     }
   }
 }
+
+private[python] object HybridRowQueue {
+  def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
+    HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/6d742d1b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
index ffda33c..1ec9986 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/RowQueueSuite.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.execution.python
 import java.io.File
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.memory.{MemoryManager, TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.internal.config._
+import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager}
+import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
+import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.unsafe.memory.MemoryBlock
 import org.apache.spark.util.Utils
 
-class RowQueueSuite extends SparkFunSuite {
+class RowQueueSuite extends SparkFunSuite with EncryptionFunSuite {
 
   test("in-memory queue") {
     val page = MemoryBlock.fromLongArray(new Array[Long](1<<10))
@@ -53,10 +56,20 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("disk queue") {
+  private def createSerializerManager(conf: SparkConf): SerializerManager = {
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+    new SerializerManager(new JavaSerializer(conf), conf, ioEncryptionKey)
+  }
+
+  encryptionTest("disk queue") { conf =>
+    val serManager = createSerializerManager(conf)
     val dir = Utils.createTempDir().getCanonicalFile
     dir.mkdirs()
-    val queue = DiskRowQueue(new File(dir, "buffer"), 1)
+    val queue = DiskRowQueue(new File(dir, "buffer"), 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = 1000
@@ -81,11 +94,12 @@ class RowQueueSuite extends SparkFunSuite {
     queue.close()
   }
 
-  test("hybrid queue") {
-    val mem = new TestMemoryManager(new SparkConf())
+  encryptionTest("hybrid queue") { conf =>
+    val serManager = createSerializerManager(conf)
+    val mem = new TestMemoryManager(conf)
     mem.limit(4<<10)
     val taskM = new TaskMemoryManager(mem, 0)
-    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1)
+    val queue = HybridRowQueue(taskM, Utils.createTempDir().getCanonicalFile, 1, serManager)
     val row = new UnsafeRow(1)
     row.pointTo(new Array[Byte](16), 16)
     val n = (4<<10) / 16 * 3


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/4] spark git commit: [PYSPARK] Updates to pyspark broadcast

Posted by ir...@apache.org.
[PYSPARK] Updates to pyspark broadcast


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/09dd34cb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/09dd34cb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/09dd34cb

Branch: refs/heads/branch-2.3
Commit: 09dd34cb1706f2477a89174d6a1a0f17ed5b0a65
Parents: a2a54a5
Author: Imran Rashid <ir...@cloudera.com>
Authored: Mon Aug 13 21:35:34 2018 -0500
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 297 ++++++++++++++++---
 .../apache/spark/api/python/PythonRunner.scala  |  52 +++-
 .../spark/api/python/PythonRDDSuite.scala       |  23 +-
 dev/sparktestsupport/modules.py                 |   2 +
 python/pyspark/broadcast.py                     |  58 +++-
 python/pyspark/context.py                       |  49 ++-
 python/pyspark/serializers.py                   |  58 ++++
 python/pyspark/test_broadcast.py                | 126 ++++++++
 python/pyspark/test_serializers.py              |  90 ++++++
 python/pyspark/worker.py                        |  24 +-
 10 files changed, 695 insertions(+), 84 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 8bc0ff7..5e6bd96 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,8 +24,10 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
 import scala.language.existentials
-import scala.util.control.NonFatal
+import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
@@ -37,6 +39,7 @@ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util._
@@ -168,27 +171,34 @@ private[spark] object PythonRDD extends Logging {
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
-    val file = new DataInputStream(new FileInputStream(filename))
+    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  }
+
+  def readRDDFromInputStream(
+      sc: SparkContext,
+      in: InputStream,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    val din = new DataInputStream(in)
     try {
       val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
-          val length = file.readInt()
+          val length = din.readInt()
           val obj = new Array[Byte](length)
-          file.readFully(obj)
+          din.readFully(obj)
           objs += obj
         }
       } catch {
         case eof: EOFException => // No-op
       }
-      JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
+      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
     } finally {
-      file.close()
+      din.close()
     }
   }
 
-  def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
-    sc.broadcast(new PythonBroadcast(path))
+  def setupBroadcast(path: String): PythonBroadcast = {
+    new PythonBroadcast(path)
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -398,34 +408,15 @@ private[spark] object PythonRDD extends Logging {
    *         data collected from this job, and the secret for authentication.
    */
   def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run() {
-        try {
-          val sock = serverSocket.accept()
-          authHelper.authClient(sock)
-
-          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          Utils.tryWithSafeFinally {
-            writeIteratorToStream(items, out)
-          } {
-            out.close()
-            sock.close()
-          }
-        } catch {
-          case NonFatal(e) =>
-            logError(s"Error while sending iterator", e)
-        } finally {
-          serverSocket.close()
-        }
+    val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
+      val out = new DataOutputStream(new BufferedOutputStream(s.getOutputStream()))
+      Utils.tryWithSafeFinally {
+        writeIteratorToStream(items, out)
+      } {
+        out.close()
       }
-    }.start()
-
-    Array(serverSocket.getLocalPort, authHelper.secret)
+    }
+    Array(port, secret)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -643,13 +634,11 @@ private[spark] class PythonAccumulatorV2(
   }
 }
 
-/**
- * A Wrapper for Python Broadcast, which is written into disk by Python. It also will
- * write the data into disk after deserialization, then Python can read it from disks.
- */
 // scalastyle:off no.finalize
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
-  with Logging {
+    with Logging {
+
+  private var encryptionServer: PythonServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -692,5 +681,233 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
     }
     super.finalize()
   }
+
+  def setupEncryptionServer(): Array[Any] = {
+    encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+      override def handleConnection(sock: Socket): Unit = {
+        val env = SparkEnv.get
+        val in = sock.getInputStream()
+        val dir = new File(Utils.getLocalDir(env.conf))
+        val file = File.createTempFile("broadcast", "", dir)
+        path = file.getAbsolutePath
+        val out = env.serializerManager.wrapForEncryption(new FileOutputStream(path))
+        DechunkedInputStream.dechunkAndCopyToOutput(in, out)
+      }
+    }
+    Array(encryptionServer.port, encryptionServer.secret)
+  }
+
+  def waitTillDataReceived(): Unit = encryptionServer.getResult()
 }
 // scalastyle:on no.finalize
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending broadcast data.
+ * Tested from python tests.
+ */
+private[spark] class DechunkedInputStream(wrapped: InputStream) extends InputStream with Logging {
+  private val din = new DataInputStream(wrapped)
+  private var remainingInChunk = din.readInt()
+
+  override def read(): Int = {
+    val into = new Array[Byte](1)
+    val n = read(into, 0, 1)
+    if (n == -1) {
+      -1
+    } else {
+      // if you just cast a byte to an int, then anything > 127 is negative, which is interpreted
+      // as an EOF
+      into(0) & 0xFF
+    }
+  }
+
+  // Copies up to len bytes into dest at off; returns the count, or -1 after the end marker.
+  override def read(dest: Array[Byte], off: Int, len: Int): Int = {
+    if (remainingInChunk == -1) return -1
+    var destSpace = len
+    var destPos = off
+    while (destSpace > 0 && remainingInChunk != -1) {
+      val toCopy = math.min(remainingInChunk, destSpace)
+      // readFully fails with EOFException on truncation; read() would return -1 as a byte count
+      din.readFully(dest, destPos, toCopy)
+      destPos += toCopy
+      destSpace -= toCopy
+      remainingInChunk -= toCopy
+      if (remainingInChunk == 0) {
+        remainingInChunk = din.readInt()
+      }
+    }
+    assert(destSpace == 0 || remainingInChunk == -1)
+    return destPos - off
+  }
+
+  override def close(): Unit = wrapped.close()
+}
+
+/**
+ * The inverse of pyspark's ChunkedStream for sending data of unknown size.
+ *
+ * We might be serializing a really large object from python -- we don't want
+ * python to buffer the whole thing in memory, nor can it write to a file,
+ * so we don't know the length in advance.  So python writes it in chunks, each chunk
+ * preceded by a length, till we get a "length" of -1 which serves as EOF.
+ *
+ * Tested from python tests.
+ */
+private[spark] object DechunkedInputStream {
+
+  /**
+   * Dechunks the input, copies to output, and closes both input and the output safely.
+   */
+  def dechunkAndCopyToOutput(chunked: InputStream, out: OutputStream): Unit = {
+    val dechunked = new DechunkedInputStream(chunked)
+    Utils.tryWithSafeFinally {
+      Utils.copyStream(dechunked, out)
+    } {
+      JavaUtils.closeQuietly(out)
+      JavaUtils.closeQuietly(dechunked)
+    }
+  }
+}
+
+/**
+ * Creates a server in the jvm to communicate with python for handling one batch of data, with
+ * authentication and error handling.
+ */
+private[spark] abstract class PythonServer[T](
+    authHelper: SocketAuthHelper,
+    threadName: String) {
+
+  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+  def this(threadName: String) = this(SparkEnv.get, threadName)
+
+  // Created before the server thread is started below, because that thread completes it.
+  val promise = Promise[T]()
+
+  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+    promise.complete(Try(handleConnection(sock)))
+  }
+
+  /**
+   * Handle a connection which has already been authenticated.  Any error from this function
+   * will clean up this connection and the entire server, and get propagated to [[getResult]].
+   */
+  def handleConnection(sock: Socket): T
+
+  /**
+   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result.  If
+   * handleConnection throws an exception, this will throw an exception which includes the original
+   * exception as a cause.
+   */
+  def getResult(): T = getResult(Duration.Inf)
+
+  // Same as getResult(), but the wait for handleConnection is bounded by `wait`.
+  def getResult(wait: Duration): T = {
+    ThreadUtils.awaitResult(promise.future, wait)
+  }
+
+}
+
+private[spark] object PythonServer {
+
+  /**
+   * Create a socket server and run user function on the socket in a background thread.
+   *
+   * The socket server can only accept one connection, or close if no connection
+   * in 15 seconds.
+   *
+   * The thread will terminate after the supplied user function, or if there are any exceptions.
+   *
+   * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
+   *
+   * @return The port number of a local socket and the secret for authentication.
+   */
+  def setupOneConnectionServer(
+      authHelper: SocketAuthHelper,
+      threadName: String)
+      (func: Socket => Unit): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          func(sock)
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
+  }
+}
+
+/**
+ * Sends decrypted broadcast data to python worker.  See [[PythonRunner]] for entire protocol.
+ */
+private[spark] class EncryptedPythonBroadcastServer(
+    val env: SparkEnv,
+    val idsAndFiles: Seq[(Long, String)])
+    extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+
+  override def handleConnection(socket: Socket): Unit = {
+    val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
+    var socketIn: InputStream = null
+    // send the broadcast id, then the decrypted data.  We don't need to send the length, the
+    // python pickle module just needs a stream.
+    Utils.tryWithSafeFinally {
+      (idsAndFiles).foreach { case (id, path) =>
+        out.writeLong(id)
+        val in = env.serializerManager.wrapForEncryption(new FileInputStream(path))
+        Utils.tryWithSafeFinally {
+          Utils.copyStream(in, out, false)
+        } {
+          in.close()
+        }
+      }
+      logTrace("waiting for python to accept broadcast data over socket")
+      out.flush()
+      socketIn = socket.getInputStream()
+      socketIn.read()
+      logTrace("done serving broadcast data")
+    } {
+      JavaUtils.closeQuietly(socketIn)
+      JavaUtils.closeQuietly(out)
+    }
+  }
+
+  def waitTillBroadcastDataSent(): Unit = {
+    getResult()
+  }
+}
+
+/**
+ * Helper for making RDD[Array[Byte]] from some python data, by reading the data from python
+ * over a socket.  This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] abstract class PythonRDDServer
+    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+
+  def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+    val in = sock.getInputStream()
+    val dechunkedInput: InputStream = new DechunkedInputStream(in)
+    streamToRDD(dechunkedInput)
+  }
+
+  protected def streamToRDD(input: InputStream): RDD[Array[Byte]]
+
+}
+
+private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
+    extends PythonRDDServer {
+
+  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
+    PythonRDD.readRDDFromInputStream(sc, input, parallelism)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 719ce5b..754a654 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -193,19 +193,51 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         val newBids = broadcastVars.map(_.id).toSet
         // number of different broadcasts
         val toRemove = oldBids.diff(newBids)
-        val cnt = toRemove.size + newBids.diff(oldBids).size
+        val addedBids = newBids.diff(oldBids)
+        val cnt = toRemove.size + addedBids.size
+        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+        dataOut.writeBoolean(needsDecryptionServer)
         dataOut.writeInt(cnt)
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(- bid - 1)  // bid >= 0
-          oldBids.remove(bid)
+        def sendBidsToRemove(): Unit = {
+          for (bid <- toRemove) {
+            // remove the broadcast from worker
+            dataOut.writeLong(-bid - 1) // bid >= 0
+            oldBids.remove(bid)
+          }
         }
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
+        if (needsDecryptionServer) {
+          // if there is encryption, we setup a server which reads the encrypted files, and sends
+          // the decrypted data to python
+          val idsAndFiles = broadcastVars.flatMap { broadcast =>
+            if (oldBids.contains(broadcast.id)) {
+              None
+            } else {
+              Some((broadcast.id, broadcast.value.path))
+            }
+          }
+          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+          dataOut.writeInt(server.port)
+          logTrace(s"broadcast decryption server setup on ${server.port}")
+          PythonRDD.writeUTF(server.secret, dataOut)
+          sendBidsToRemove()
+          idsAndFiles.foreach { case (id, _) =>
             // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
+            dataOut.writeLong(id)
+            oldBids.add(id)
+          }
+          dataOut.flush()
+          logTrace("waiting for python to read decrypted broadcast data from server")
+          server.waitTillBroadcastDataSent()
+          logTrace("done sending decrypted data to python")
+        } else {
+          sendBidsToRemove()
+          for (broadcast <- broadcastVars) {
+            if (!oldBids.contains(broadcast.id)) {
+              // send new broadcast
+              dataOut.writeLong(broadcast.id)
+              PythonRDD.writeUTF(broadcast.value.path, dataOut)
+              oldBids.add(broadcast.id)
+            }
           }
         }
         dataOut.flush()

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 05b4e67..6f9b583 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.spark.api.python
 
 import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.net.{InetAddress, Socket}
 import java.nio.charset.StandardCharsets
 
-import org.apache.spark.SparkFunSuite
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.security.SocketAuthHelper
 
 class PythonRDDSuite extends SparkFunSuite {
 
@@ -44,4 +48,21 @@ class PythonRDDSuite extends SparkFunSuite {
       ("a".getBytes(StandardCharsets.UTF_8), null),
       (null, "b".getBytes(StandardCharsets.UTF_8))), buffer)
   }
+
+  // Connects and authenticates to a server whose handler always throws, then checks that
+  // getResult() fails with an exception whose cause carries the handler's message.
+  // NOTE(review): `client` is never explicitly closed in this test.
+  test("python server error handling") {
+    val authHelper = new SocketAuthHelper(new SparkConf())
+    val errorServer = new ExceptionPythonServer(authHelper)
+    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
+    authHelper.authToServer(client)
+    // Wait at most one second for the result; the call is expected to throw.
+    val ex = intercept[Exception] { errorServer.getResult(Duration(1, "second")) }
+    assert(ex.getCause().getMessage().contains("exception within handleConnection"))
+  }
+
+  // Test-only server whose connection handler unconditionally throws.
+  class ExceptionPythonServer(authHelper: SocketAuthHelper)
+      extends PythonServer[Unit](authHelper, "error-server") {
+
+    override def handleConnection(sock: Socket): Unit = {
+      throw new Exception("exception within handleConnection")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/dev/sparktestsupport/modules.py
----------------------------------------------------------------------
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index b900f0b..d0bff13 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -377,6 +377,8 @@ pyspark_core = Module(
         "pyspark.profiler",
         "pyspark.shuffle",
         "pyspark.tests",
+        "pyspark.test_broadcast",
+        "pyspark.test_serializers",
         "pyspark.util",
     ]
 )

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 02fc515..3f1298e 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,13 +15,16 @@
 # limitations under the License.
 #
 
+import gc
 import os
+import socket
 import sys
-import gc
 from tempfile import NamedTemporaryFile
 import threading
 
 from pyspark.cloudpickle import print_exec
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import ChunkedStream
 from pyspark.util import _exception_message
 
 if sys.version < '3':
@@ -64,19 +67,43 @@ class Broadcast(object):
     >>> large_broadcast = sc.broadcast(range(10000))
     """
 
-    def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
+    def __init__(self, sc=None, value=None, pickle_registry=None, path=None,
+                 sock_file=None):
         """
         Should not be called directly by users -- use L{SparkContext.broadcast()}
         instead.
         """
         if sc is not None:
+            # we're on the driver.  We want the pickled data to end up in a file (maybe encrypted)
             f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
-            self._path = self.dump(value, f)
-            self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+            self._path = f.name
+            python_broadcast = sc._jvm.PythonRDD.setupBroadcast(self._path)
+            if sc._encryption_enabled:
+                # with encryption, we ask the jvm to do the encryption for us, we send it data
+                # over a socket
+                port, auth_secret = python_broadcast.setupEncryptionServer()
+                (encryption_sock_file, _) = local_connect_and_auth(port, auth_secret)
+                broadcast_out = ChunkedStream(encryption_sock_file, 8192)
+            else:
+                # no encryption, we can just write pickled data directly to the file from python
+                broadcast_out = f
+            self.dump(value, broadcast_out)
+            if sc._encryption_enabled:
+                python_broadcast.waitTillDataReceived()
+            self._jbroadcast = sc._jsc.broadcast(python_broadcast)
             self._pickle_registry = pickle_registry
         else:
+            # we're on an executor
             self._jbroadcast = None
-            self._path = path
+            if sock_file is not None:
+                # the jvm is doing decryption for us.  Read the value
+                # immediately from the sock_file
+                self._value = self.load(sock_file)
+            else:
+                # the jvm just dumps the pickled data in path -- we'll unpickle lazily when
+                # the value is requested
+                assert(path is not None)
+                self._path = path
 
     def dump(self, value, f):
         try:
@@ -89,24 +116,25 @@ class Broadcast(object):
             print_exec(sys.stderr)
             raise pickle.PicklingError(msg)
         f.close()
-        return f.name
 
-    def load(self, path):
+    def load_from_path(self, path):
         with open(path, 'rb', 1 << 20) as f:
-            # pickle.load() may create lots of objects, disable GC
-            # temporary for better performance
-            gc.disable()
-            try:
-                return pickle.load(f)
-            finally:
-                gc.enable()
+            return self.load(f)
+
+    def load(self, file):
+        # "file" could also be a socket
+        gc.disable()
+        try:
+            return pickle.load(file)
+        finally:
+            gc.enable()
 
     @property
     def value(self):
         """ Return the broadcasted value
         """
         if not hasattr(self, "_value") and self._path is not None:
-            self._value = self.load(self._path)
+            self._value = self.load_from_path(self._path)
         return self._value
 
     def unpersist(self, blocking=False):

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 667d0a3..2aac7ba 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,9 +33,9 @@ from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
-from pyspark.java_gateway import launch_gateway
+from pyspark.java_gateway import launch_gateway, local_connect_and_auth
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
-    PairDeserializer, AutoBatchedSerializer, NoOpSerializer
+    PairDeserializer, AutoBatchedSerializer, NoOpSerializer, ChunkedStream
 from pyspark.storagelevel import StorageLevel
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.traceback_utils import CallSite, first_spark_call
@@ -189,6 +189,13 @@ class SparkContext(object):
         self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
+        # If encryption is enabled, we need to setup a server in the jvm to read broadcast
+        # data via a socket.
+        # scala's mangled names w/ $ in them require special treatment.
+        encryption_conf = self._jvm.org.apache.spark.internal.config.__getattr__("package$")\
+            .__getattr__("MODULE$").IO_ENCRYPTION_ENABLED()
+        self._encryption_enabled = self._jsc.sc().conf().get(encryption_conf)
+
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
 
@@ -499,19 +506,31 @@ class SparkContext(object):
 
     def _serialize_to_jvm(self, data, parallelism, serializer):
         """
-        Calling the Java parallelize() method with an ArrayList is too slow,
-        because it sends O(n) Py4J commands.  As an alternative, serialized
-        objects are written to a file and loaded through textFile().
-        """
-        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
-        try:
-            serializer.dump_stream(data, tempFile)
-            tempFile.close()
-            readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
-            return readRDDFromFile(self._jsc, tempFile.name, parallelism)
-        finally:
-            # readRDDFromFile eagerily reads the file so we can delete right after.
-            os.unlink(tempFile.name)
+        Using py4j to send a large dataset to the jvm is really slow, so we use either a file
+        or a socket if we have encryption enabled.
+        """
+        if self._encryption_enabled:
+            # with encryption, we open a server in java and send the data directly
+            server = self._jvm.PythonParallelizeServer(self._jsc.sc(), parallelism)
+            (sock_file, _) = local_connect_and_auth(server.port(), server.secret())
+            chunked_out = ChunkedStream(sock_file, 8192)
+            serializer.dump_stream(data, chunked_out)
+            chunked_out.close()
+            # this call will block until the server has read all the data and processed it (or
+            # throws an exception)
+            return server.getResult()
+        else:
+            # without encryption, we serialize to a file, and we read the file in java and
+            # parallelize from there.
+            tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
+            try:
+                serializer.dump_stream(data, tempFile)
+                tempFile.close()
+                readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+                return readRDDFromFile(self._jsc, tempFile.name, parallelism)
+            finally:
+                # we eagerly read the file so we can delete right after.
+                os.unlink(tempFile.name)
 
     def pickleFile(self, name, minPartitions=None):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 52a7afe..3927d21 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -697,11 +697,69 @@ def write_int(value, stream):
     stream.write(struct.pack("!i", value))
 
 
+def read_bool(stream):
+    """
+    Read a single byte from `stream` and unpack it as a boolean (struct format "!?").
+
+    :param stream: object with a read(n) method returning bytes
+    :return: the unpacked boolean
+    :raises EOFError: if the read returns no data
+    """
+    # NOTE: despite its name, `length` holds the one raw byte that was read, not a length.
+    length = stream.read(1)
+    if not length:
+        raise EOFError
+    return struct.unpack("!?", length)[0]
+
+
 def write_with_length(obj, stream):
     write_int(len(obj), stream)
     stream.write(obj)
 
 
+class ChunkedStream(object):
+
+    """
+    This file-like object takes a stream of data, of unknown length, and breaks it into fixed
+    length frames.  The intended use case is serializing large data and sending it immediately over
+    a socket -- we do not want to buffer the entire data before sending it, but the receiving end
+    needs to know whether or not there is more data coming.
+
+    It works by buffering the incoming data in some fixed-size chunks.  If the buffer is full, it
+    first sends the buffer size, then the data.  This repeats as long as there is more data to send.
+    When this is closed, it sends the length of whatever data is in the buffer, then that data, and
+    finally a "length" of -1 to indicate the stream has completed.
+    """
+
+    def __init__(self, wrapped, buffer_size):
+        self.buffer_size = buffer_size
+        self.buffer = bytearray(buffer_size)
+        self.current_pos = 0
+        self.wrapped = wrapped
+
+    def write(self, bytes):
+        byte_pos = 0
+        byte_remaining = len(bytes)
+        while byte_remaining > 0:
+            new_pos = byte_remaining + self.current_pos
+            if new_pos < self.buffer_size:
+                # just put it in our buffer
+                self.buffer[self.current_pos:new_pos] = bytes[byte_pos:]
+                self.current_pos = new_pos
+                byte_remaining = 0
+            else:
+                # fill the buffer, send the length then the contents, and start filling again
+                space_left = self.buffer_size - self.current_pos
+                new_byte_pos = byte_pos + space_left
+                self.buffer[self.current_pos:self.buffer_size] = bytes[byte_pos:new_byte_pos]
+                write_int(self.buffer_size, self.wrapped)
+                self.wrapped.write(self.buffer)
+                byte_remaining -= space_left
+                byte_pos = new_byte_pos
+                self.current_pos = 0
+
+    def close(self):
+        # if there is anything left in the buffer, write it out first
+        if self.current_pos > 0:
+            write_int(self.current_pos, self.wrapped)
+            self.wrapped.write(self.buffer[:self.current_pos])
+        # -1 length indicates to the receiving end that we're done.
+        write_int(-1, self.wrapped)
+        self.wrapped.close()
+
+
 if __name__ == '__main__':
     import doctest
     (failure_count, test_count) = doctest.testmod()

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/test_broadcast.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_broadcast.py b/python/pyspark/test_broadcast.py
new file mode 100644
index 0000000..ce7ca83
--- /dev/null
+++ b/python/pyspark/test_broadcast.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import random
+import tempfile
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark.broadcast import Broadcast
+from pyspark.conf import SparkConf
+from pyspark.context import SparkContext
+from pyspark.java_gateway import launch_gateway
+from pyspark.serializers import ChunkedStream
+
+
+class BroadcastTest(unittest.TestCase):
+    """
+    Runs small jobs on a local-cluster SparkContext and checks that broadcast values read on
+    the executors match what the driver broadcast, with and without I/O encryption enabled.
+    """
+
+    def tearDown(self):
+        # Each test creates its own SparkContext; stop it so the next test can create a new one.
+        if getattr(self, "sc", None) is not None:
+            self.sc.stop()
+            self.sc = None
+
+    def _test_encryption_helper(self, vs):
+        """
+        Creates a broadcast variable for each value in vs, and runs a simple job to make sure the
+        value is the same when it's read in the executors.  Also makes sure there are no task
+        failures.
+        """
+        bs = [self.sc.broadcast(value=v) for v in vs]
+        exec_values = self.sc.parallelize(range(2)).map(lambda x: [b.value for b in bs]).collect()
+        for ev in exec_values:
+            self.assertEqual(ev, vs)
+        # make sure there are no task failures
+        status = self.sc.statusTracker()
+        for jid in status.getJobIdsForGroup():
+            for sid in status.getJobInfo(jid).stageIds:
+                stage_info = status.getStageInfo(sid)
+                self.assertEqual(0, stage_info.numFailedTasks)
+
+    def _test_multiple_broadcasts(self, *extra_confs):
+        """
+        Test broadcast variables make it OK to the executors.  Tests multiple broadcast variables,
+        and also multiple jobs.
+
+        :param extra_confs: (key, value) pairs set on the SparkConf before the context is created
+        """
+        conf = SparkConf()
+        for key, value in extra_confs:
+            conf.set(key, value)
+        conf.setMaster("local-cluster[2,1,1024]")
+        self.sc = SparkContext(conf=conf)
+        # two rounds: a single broadcast first, then several at once
+        self._test_encryption_helper([5])
+        self._test_encryption_helper([5, 10, 20])
+
+    def test_broadcast_with_encryption(self):
+        self._test_multiple_broadcasts(("spark.io.encryption.enabled", "true"))
+
+    def test_broadcast_no_encryption(self):
+        self._test_multiple_broadcasts()
+
+
+class BroadcastFrameProtocolTest(unittest.TestCase):
+    """
+    Round-trips data through the python ChunkedStream writer and the JVM's
+    DechunkedInputStream reader, checking the bytes come back unchanged.
+    """
+
+    @classmethod
+    def setUpClass(cls):
+        gateway = launch_gateway(SparkConf())
+        cls._jvm = gateway.jvm
+        cls.longMessage = True
+        # fixed seed so the random payloads are reproducible between runs
+        random.seed(42)
+
+    def _test_chunked_stream(self, data, py_buf_size):
+        """
+        Write `data` to a temp file through ChunkedStream using `py_buf_size` as the buffer
+        size, have the JVM copy it to a second temp file via
+        DechunkedInputStream.dechunkAndCopyToOutput, and assert the second file equals `data`.
+        """
+        # write data using the chunked protocol from python.
+        chunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file = tempfile.NamedTemporaryFile(delete=False)
+        dechunked_file.close()
+        try:
+            out = ChunkedStream(chunked_file, py_buf_size)
+            out.write(data)
+            out.close()
+            # now try to read it in java
+            # NOTE(review): jin/jout are never closed here; whether dechunkAndCopyToOutput
+            # closes them is not visible from this file -- confirm.
+            jin = self._jvm.java.io.FileInputStream(chunked_file.name)
+            jout = self._jvm.java.io.FileOutputStream(dechunked_file.name)
+            self._jvm.DechunkedInputStream.dechunkAndCopyToOutput(jin, jout)
+            # java should have decoded it back to the original data
+            self.assertEqual(len(data), os.stat(dechunked_file.name).st_size)
+            with open(dechunked_file.name, "rb") as f:
+                # compare byte by byte so a failure reports the first differing index
+                byte = f.read(1)
+                idx = 0
+                while byte:
+                    self.assertEqual(data[idx], bytearray(byte)[0], msg="idx = " + str(idx))
+                    byte = f.read(1)
+                    idx += 1
+        finally:
+            # both temp files were created with delete=False, so remove them by hand
+            os.unlink(chunked_file.name)
+            os.unlink(dechunked_file.name)
+
+    def test_chunked_stream(self):
+        # Covers payloads both smaller and larger than the buffer size.
+        def random_bytes(n):
+            return bytearray(random.getrandbits(8) for _ in range(n))
+        for data_length in [1, 10, 100, 10000]:
+            for buffer_length in [1, 2, 5, 8192]:
+                self._test_chunked_stream(random_bytes(data_length), buffer_length)
+
+if __name__ == '__main__':
+    from pyspark.test_broadcast import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+    else:
+        unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/test_serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/test_serializers.py b/python/pyspark/test_serializers.py
new file mode 100644
index 0000000..5064e9f
--- /dev/null
+++ b/python/pyspark/test_serializers.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import math
+import struct
+import sys
+import unittest
+
+try:
+    import xmlrunner
+except ImportError:
+    xmlrunner = None
+
+from pyspark import serializers
+
+
+def read_int(b):
+    """Unpack the 4 bytes in `b` as a big-endian signed int (struct format "!i")."""
+    return struct.unpack("!i", b)[0]
+
+
+def write_int(i):
+    """Pack `i` as a 4-byte big-endian signed int (struct format "!i") and return the bytes."""
+    return struct.pack("!i", i)
+
+
+class SerializersTest(unittest.TestCase):
+
+    def test_chunked_stream(self):
+        """
+        Writes prefixes of a fixed byte sequence through ChunkedStream with a range of buffer
+        sizes and checks the exact framing of the output: a length before each chunk, the
+        chunk contents, and a trailing -1.
+        """
+        original_bytes = bytearray(range(100))
+        for data_length in [1, 10, 100]:
+            for buffer_length in [1, 2, 3, 5, 20, 99, 100, 101, 500]:
+                dest = ByteArrayOutput()
+                stream_out = serializers.ChunkedStream(dest, buffer_length)
+                stream_out.write(original_bytes[:data_length])
+                stream_out.close()
+                num_chunks = int(math.ceil(float(data_length) / buffer_length))
+                # length for each chunk, and a final -1 at the very end
+                exp_size = (num_chunks + 1) * 4 + data_length
+                self.assertEqual(len(dest.buffer), exp_size)
+                dest_pos = 0
+                data_pos = 0
+                for chunk_idx in range(num_chunks):
+                    chunk_length = read_int(dest.buffer[dest_pos:(dest_pos + 4)])
+                    if chunk_idx == num_chunks - 1:
+                        # the last chunk holds the remainder, or a full buffer when the
+                        # data length is an exact multiple of the buffer length
+                        exp_length = data_length % buffer_length
+                        if exp_length == 0:
+                            exp_length = buffer_length
+                    else:
+                        exp_length = buffer_length
+                    self.assertEqual(chunk_length, exp_length)
+                    dest_pos += 4
+                    dest_chunk = dest.buffer[dest_pos:dest_pos + chunk_length]
+                    orig_chunk = original_bytes[data_pos:data_pos + chunk_length]
+                    self.assertEqual(dest_chunk, orig_chunk)
+                    dest_pos += chunk_length
+                    data_pos += chunk_length
+                # ends with a -1
+                self.assertEqual(dest.buffer[-4:], write_int(-1))
+
+
+class ByteArrayOutput(object):
+    """Minimal in-memory sink: write() appends to `buffer`, close() does nothing."""
+
+    def __init__(self):
+        self.buffer = bytearray()
+
+    def write(self, b):
+        self.buffer += b
+
+    def close(self):
+        # nothing to release; the collected bytes stay available in self.buffer
+        pass
+
+if __name__ == '__main__':
+    from pyspark.test_serializers import *
+    if xmlrunner:
+        unittest.main(testRunner=xmlrunner.XMLTestRunner(output='target/test-reports'))
+    else:
+        unittest.main()

http://git-wip-us.apache.org/repos/asf/spark/blob/09dd34cb/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 812f4b2..942a7f3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@ from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
-from pyspark.serializers import write_with_length, write_int, read_long, \
+from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
@@ -206,16 +206,34 @@ def main(infile, outfile):
             importlib.invalidate_caches()
 
         # fetch names and values of broadcast variables
+        needs_broadcast_decryption_server = read_bool(infile)
         num_broadcast_variables = read_int(infile)
+        if needs_broadcast_decryption_server:
+            # read the decrypted data from a server in the jvm
+            port = read_int(infile)
+            auth_secret = utf8_deserializer.loads(infile)
+            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
         for _ in range(num_broadcast_variables):
             bid = read_long(infile)
             if bid >= 0:
-                path = utf8_deserializer.loads(infile)
-                _broadcastRegistry[bid] = Broadcast(path=path)
+                if needs_broadcast_decryption_server:
+                    read_bid = read_long(broadcast_sock_file)
+                    assert(read_bid == bid)
+                    _broadcastRegistry[bid] = \
+                        Broadcast(sock_file=broadcast_sock_file)
+                else:
+                    path = utf8_deserializer.loads(infile)
+                    _broadcastRegistry[bid] = Broadcast(path=path)
+
             else:
                 bid = - bid - 1
                 _broadcastRegistry.pop(bid)
 
+        if needs_broadcast_decryption_server:
+            broadcast_sock_file.write(b'1')
+            broadcast_sock_file.close()
+
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
         if eval_type == PythonEvalType.NON_UDF:


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[4/4] spark git commit: [SPARK-25253][PYSPARK] Refactor local connection & auth code

Posted by ir...@apache.org.
[SPARK-25253][PYSPARK] Refactor local connection & auth code

This eliminates some duplication in the code that connects to a server on localhost to talk directly to the JVM.  It also gives consistent IPv6 and error handling.  Two other incidental changes that shouldn't matter:
1) python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator)
2) for `rdd._load_from_socket`, the timeout is only increased after authentication.

Closes #22247 from squito/py_connection_refactor.

Authored-by: Imran Rashid <ir...@cloudera.com>
Signed-off-by: hyukjinkwon <gu...@apache.org>
(cherry picked from commit 38391c9aa8a88fcebb337934f30298a32d91596b)


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a2a54a5f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a2a54a5f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a2a54a5f

Branch: refs/heads/branch-2.3
Commit: a2a54a5f49364a1825932c9f04eb0ff82dd7d465
Parents: 9ac9f36
Author: Imran Rashid <ir...@cloudera.com>
Authored: Wed Aug 29 09:47:38 2018 +0800
Committer: Imran Rashid <ir...@cloudera.com>
Committed: Thu Sep 13 09:19:56 2018 -0500

----------------------------------------------------------------------
 python/pyspark/java_gateway.py | 32 +++++++++++++++++++++++++++++++-
 python/pyspark/rdd.py          | 24 ++----------------------
 python/pyspark/worker.py       |  7 ++-----
 3 files changed, 35 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 0afbe9d..eed866d 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -134,7 +134,7 @@ def launch_gateway(conf=None):
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -145,3 +145,33 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param port
+    :param auth_secret
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)

http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index eadb5ab..3f6a8e6 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@ if sys.version > '3':
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -139,30 +139,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a2a54a5f/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 788b323..812f4b2 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,7 @@ import traceback
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -269,8 +269,5 @@ if __name__ == '__main__':
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org