Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/10 06:08:39 UTC

[spark] branch master updated: [SPARK-27102][R][PYTHON][CORE] Remove the references to Python's Scala codes in R's Scala codes

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 28d0030  [SPARK-27102][R][PYTHON][CORE] Remove the references to Python's Scala codes in R's Scala codes
28d0030 is described below

commit 28d003097b114623b891e0ae5dbe1709a54da891
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Sun Mar 10 15:08:23 2019 +0900

    [SPARK-27102][R][PYTHON][CORE] Remove the references to Python's Scala codes in R's Scala codes
    
    ## What changes were proposed in this pull request?
    
    Currently, R's Scala code refers to Python's Scala code for the sake of code deduplication. This is a bit odd: for instance, when an exception is raised on the R side, the stack trace shows Python-related code paths, which makes debugging confusing. There should instead be one shared code base that both the R and Python sides use.
    
    This PR proposes:
    
    1. Introduce `SocketAuthServer` (the former `PythonServer`, moved out of `PythonRDD.scala`) so that `PythonRDD` and `RRDD` can share it; see the sketch after this list.
    2. Move `readRDDFromFile` and `readRDDFromInputStream` into `JavaRDD` (the record format they read is sketched at the end of this section).
    3. Reuse `RAuthHelper` and remove the duplicate `RSocketAuthHelper` from `RRDD.scala`.
    4. Rename `getEncryptionEnabled` to `isEncryptionEnabled` while I am here.
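    
    As a minimal sketch of point 1 (the class and package names below are illustrative, not part of this change), a parallelize server for an external process now only needs to subclass the shared `SocketAuthServer`, which is essentially what `RParallelizeServer` and `PythonParallelizeServer` are after this diff:
    
    ```scala
    // Hypothetical subclass; the real ones live under org.apache.spark.api so
    // they can call the private[api] JavaRDD helpers.
    package org.apache.spark.api.example
    
    import java.net.Socket
    
    import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
    import org.apache.spark.security.SocketAuthServer
    
    private[spark] class ExampleParallelizeServer(sc: JavaSparkContext, parallelism: Int)
        extends SocketAuthServer[JavaRDD[Array[Byte]]]("example-parallelize-server") {
    
      // Invoked once for the single authenticated connection; the resulting RDD
      // is retrieved on the driver afterwards via getResult().
      override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
        JavaRDD.readRDDFromInputStream(sc.sc, sock.getInputStream(), parallelism)
      }
    }
    ```
    
    Constructing the server binds a local socket in a daemon thread; `(port, secret)` are handed to the external process, which connects, authenticates, and streams its data, and the driver then collects the result with `getResult()`.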
    
    After this change, the code under the following paths:
    
    - `sql/core/src/main/scala/org/apache/spark/sql/api/r`
    - `core/src/main/scala/org/apache/spark/api/r`
    - `mllib/src/main/scala/org/apache/spark/ml/r`
    
    no longer refers to Python's Scala code.
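    
    For context on point 2, `JavaRDD.readRDDFromInputStream` consumes a stream of length-prefixed byte records and stops at EOF. A minimal sketch of a writer producing that framing (the `FramingSketch` object is a made-up name for illustration, not part of this change):
    
    ```scala
    import java.io.{ByteArrayOutputStream, DataOutputStream}
    
    object FramingSketch {
      // Each record is framed as [4-byte big-endian length][payload bytes];
      // closing the stream (EOF) is what ends the server-side read loop.
      def writeFramed(records: Seq[Array[Byte]]): Array[Byte] = {
        val bos = new ByteArrayOutputStream()
        val out = new DataOutputStream(bos)
        records.foreach { rec =>
          out.writeInt(rec.length)  // matches din.readInt() in readRDDFromInputStream
          out.write(rec)            // matches din.readFully(obj)
        }
        out.flush()
        bos.toByteArray
      }
    }
    ```
    
    Each payload is an opaque serialized element; the JVM side simply turns the collected `Array[Byte]` records into an RDD via `sc.parallelize(objs, parallelism)`.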
    
    ## How was this patch tested?
    
    Existing tests should cover this.
    
    Closes #24023 from HyukjinKwon/SPARK-27102.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 R/pkg/R/context.R                                  |   2 +-
 R/pkg/tests/fulltests/test_Serde.R                 |   2 +-
 .../scala/org/apache/spark/api/java/JavaRDD.scala  |  33 +++++
 .../org/apache/spark/api/python/PythonRDD.scala    | 135 +++------------------
 .../org/apache/spark/api/python/PythonUtils.scala  |   2 +-
 .../main/scala/org/apache/spark/api/r/RRDD.scala   |  28 ++---
 .../main/scala/org/apache/spark/api/r/RUtils.scala |   5 +-
 .../apache/spark/security/SocketAuthHelper.scala   |  18 ++-
 .../apache/spark/security/SocketAuthServer.scala   | 108 +++++++++++++++++
 .../apache/spark/api/python/PythonRDDSuite.scala   |   4 +-
 python/pyspark/context.py                          |   2 +-
 11 files changed, 188 insertions(+), 151 deletions(-)

diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 1c064a6..6191536 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -175,7 +175,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
   if (objectSize < sizeLimit) {
     jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
   } else {
-    if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) {
+    if (callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)) {
       connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
       # the length of slices here is the parallelism to use in the jvm's sc.parallelize()
       parallelism <- as.integer(numSlices)
diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
index 1525bdb..e01f6ee 100644
--- a/R/pkg/tests/fulltests/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -138,7 +138,7 @@ test_that("createDataFrame large objects", {
                                     enableHiveSupport = FALSE))
 
     sc <- getSparkContext()
-    actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)
+    actual <- callJStatic("org.apache.spark.api.r.RUtils", "isEncryptionEnabled", sc)
     expected <- as.logical(encryptionEnabled)
     expect_equal(actual, expected)
 
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 41b5cab..6f01822 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.api.java
 
+import java.io.{DataInputStream, EOFException, FileInputStream, InputStream}
+
+import scala.collection.mutable
 import scala.language.implicitConversions
 import scala.reflect.ClassTag
 
@@ -213,4 +216,34 @@ object JavaRDD {
   implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
 
   implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
+
+  private[api] def readRDDFromFile(
+      sc: JavaSparkContext,
+      filename: String,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  }
+
+  private[api] def readRDDFromInputStream(
+      sc: SparkContext,
+      in: InputStream,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    val din = new DataInputStream(in)
+    try {
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
+      try {
+        while (true) {
+          val length = din.readInt()
+          val obj = new Array[Byte](length)
+          din.readFully(obj)
+          objs += obj
+        }
+      } catch {
+        case eof: EOFException => // No-op
+      }
+      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
+    } finally {
+      din.close()
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 0937a63..5b492b1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -42,7 +42,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 import org.apache.spark.util._
 
 
@@ -171,32 +171,18 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
   }
 
-  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
-  JavaRDD[Array[Byte]] = {
-    readRDDFromInputStream(sc.sc, new FileInputStream(filename), parallelism)
+  def readRDDFromFile(
+      sc: JavaSparkContext,
+      filename: String,
+      parallelism: Int): JavaRDD[Array[Byte]] = {
+    JavaRDD.readRDDFromFile(sc, filename, parallelism)
   }
 
   def readRDDFromInputStream(
       sc: SparkContext,
       in: InputStream,
       parallelism: Int): JavaRDD[Array[Byte]] = {
-    val din = new DataInputStream(in)
-    try {
-      val objs = new mutable.ArrayBuffer[Array[Byte]]
-      try {
-        while (true) {
-          val length = din.readInt()
-          val obj = new Array[Byte](length)
-          din.readFully(obj)
-          objs += obj
-        }
-      } catch {
-        case eof: EOFException => // No-op
-      }
-      JavaRDD.fromRDD(sc.parallelize(objs, parallelism))
-    } finally {
-      din.close()
-    }
+    JavaRDD.readRDDFromInputStream(sc, in, parallelism)
   }
 
   def setupBroadcast(path: String): PythonBroadcast = {
@@ -430,21 +416,7 @@ private[spark] object PythonRDD extends Logging {
    */
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    serveToStream(threadName, authHelper)(writeFunc)
-  }
-
-  private[spark] def serveToStream(
-      threadName: String, authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit)
-    : Array[Any] = {
-    val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { s =>
-      val out = new BufferedOutputStream(s.getOutputStream())
-      Utils.tryWithSafeFinally {
-        writeFunc(out)
-      } {
-        out.close()
-      }
-    }
-    Array(port, secret)
+    SocketAuthHelper.serveToStream(threadName, authHelper)(writeFunc)
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
@@ -666,8 +638,8 @@ private[spark] class PythonAccumulatorV2(
 private[spark] class PythonBroadcast(@transient var path: String) extends Serializable
     with Logging {
 
-  private var encryptionServer: PythonServer[Unit] = null
-  private var decryptionServer: PythonServer[Unit] = null
+  private var encryptionServer: SocketAuthServer[Unit] = null
+  private var decryptionServer: SocketAuthServer[Unit] = null
 
   /**
    * Read data from disks, then copy it to `out`
@@ -712,7 +684,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
   }
 
   def setupEncryptionServer(): Array[Any] = {
-    encryptionServer = new PythonServer[Unit]("broadcast-encrypt-server") {
+    encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
       override def handleConnection(sock: Socket): Unit = {
         val env = SparkEnv.get
         val in = sock.getInputStream()
@@ -725,7 +697,7 @@ private[spark] class PythonBroadcast(@transient var path: String) extends Serial
   }
 
   def setupDecryptionServer(): Array[Any] = {
-    decryptionServer = new PythonServer[Unit]("broadcast-decrypt-server-for-driver") {
+    decryptionServer = new SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
       override def handleConnection(sock: Socket): Unit = {
         val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream()))
         Utils.tryWithSafeFinally {
@@ -821,90 +793,12 @@ private[spark] object DechunkedInputStream {
 }
 
 /**
- * Creates a server in the jvm to communicate with python for handling one batch of data, with
- * authentication and error handling.
- */
-private[spark] abstract class PythonServer[T](
-    authHelper: SocketAuthHelper,
-    threadName: String) {
-
-  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
-  def this(threadName: String) = this(SparkEnv.get, threadName)
-
-  val (port, secret) = PythonServer.setupOneConnectionServer(authHelper, threadName) { sock =>
-    promise.complete(Try(handleConnection(sock)))
-  }
-
-  /**
-   * Handle a connection which has already been authenticated.  Any error from this function
-   * will clean up this connection and the entire server, and get propogated to [[getResult]].
-   */
-  def handleConnection(sock: Socket): T
-
-  val promise = Promise[T]()
-
-  /**
-   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result.  If
-   * handleConnection throws an exception, this will throw an exception which includes the original
-   * exception as a cause.
-   */
-  def getResult(): T = {
-    getResult(Duration.Inf)
-  }
-
-  def getResult(wait: Duration): T = {
-    ThreadUtils.awaitResult(promise.future, wait)
-  }
-
-}
-
-private[spark] object PythonServer {
-
-  /**
-   * Create a socket server and run user function on the socket in a background thread.
-   *
-   * The socket server can only accept one connection, or close if no connection
-   * in 15 seconds.
-   *
-   * The thread will terminate after the supplied user function, or if there are any exceptions.
-   *
-   * If you need to get a result of the supplied function, create a subclass of [[PythonServer]]
-   *
-   * @return The port number of a local socket and the secret for authentication.
-   */
-  def setupOneConnectionServer(
-      authHelper: SocketAuthHelper,
-      threadName: String)
-      (func: Socket => Unit): (Int, String) = {
-    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
-    // Close the socket if no connection in 15 seconds
-    serverSocket.setSoTimeout(15000)
-
-    new Thread(threadName) {
-      setDaemon(true)
-      override def run(): Unit = {
-        var sock: Socket = null
-        try {
-          sock = serverSocket.accept()
-          authHelper.authClient(sock)
-          func(sock)
-        } finally {
-          JavaUtils.closeQuietly(serverSocket)
-          JavaUtils.closeQuietly(sock)
-        }
-      }
-    }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
-  }
-}
-
-/**
  * Sends decrypted broadcast data to python worker.  See [[PythonRunner]] for entire protocol.
  */
 private[spark] class EncryptedPythonBroadcastServer(
     val env: SparkEnv,
     val idsAndFiles: Seq[(Long, String)])
-    extends PythonServer[Unit]("broadcast-decrypt-server") with Logging {
+    extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
 
   override def handleConnection(socket: Socket): Unit = {
     val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream()))
@@ -942,7 +836,7 @@ private[spark] class EncryptedPythonBroadcastServer(
  * over a socket.  This is used in preference to writing data to a file when encryption is enabled.
  */
 private[spark] abstract class PythonRDDServer
-    extends PythonServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
+    extends SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
 
   def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
     val in = sock.getInputStream()
@@ -961,4 +855,3 @@ private[spark] class PythonParallelizeServer(sc: SparkContext, parallelism: Int)
     PythonRDD.readRDDFromInputStream(sc, input, parallelism)
   }
 }
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index b6b0cac..ab1bf69 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -76,7 +76,7 @@ private[spark] object PythonUtils {
     jm.asScala.toMap
   }
 
-  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = {
+  def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
     sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 04fc6e1..4a59c3e 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,9 +17,8 @@
 
 package org.apache.spark.api.r
 
-import java.io.{DataInputStream, File, OutputStream}
+import java.io.{File, OutputStream}
 import java.net.Socket
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
@@ -27,11 +26,10 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.{PythonRDD, PythonServer}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -163,12 +161,12 @@ private[spark] object RRDD {
    */
   def createRDDFromFile(jsc: JavaSparkContext, fileName: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
-    PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
+    JavaRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
 
   private[spark] def serveToStream(
       threadName: String)(writeFunc: OutputStream => Unit): Array[Any] = {
-    PythonRDD.serveToStream(threadName, new RSocketAuthHelper())(writeFunc)
+    SocketAuthHelper.serveToStream(threadName, new RAuthHelper(SparkEnv.get.conf))(writeFunc)
   }
 }
 
@@ -177,23 +175,11 @@ private[spark] object RRDD {
  * over a socket. This is used in preference to writing data to a file when encryption is enabled.
  */
 private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
-    extends PythonServer[JavaRDD[Array[Byte]]](
-      new RSocketAuthHelper(), "sparkr-parallelize-server") {
+    extends SocketAuthServer[JavaRDD[Array[Byte]]](
+      new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
 
   override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
     val in = sock.getInputStream()
-    PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism)
-  }
-}
-
-private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) {
-  override protected def readUtf8(s: Socket): String = {
-    val din = new DataInputStream(s.getInputStream())
-    val len = din.readInt()
-    val bytes = new Array[Byte](len)
-    din.readFully(bytes)
-    // The R code adds a null terminator to serialized strings, so ignore it here.
-    assert(bytes(bytes.length - 1) == 0) // sanity check.
-    new String(bytes, 0, bytes.length - 1, UTF_8)
+    JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index 6832223..5a43302 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -22,7 +22,6 @@ import java.util.Arrays
 
 import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.api.java.JavaSparkContext
-import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.internal.config._
 
 private[spark] object RUtils {
@@ -108,5 +107,7 @@ private[spark] object RUtils {
     }
   }
 
-  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc)
+  def isEncryptionEnabled(sc: JavaSparkContext): Boolean = {
+    sc.conf.get(org.apache.spark.internal.config.IO_ENCRYPTION_ENABLED)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index ea38ccb..3a107c0 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.security
 
-import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, OutputStream}
 import java.net.Socket
 import java.nio.charset.StandardCharsets.UTF_8
 
@@ -115,3 +115,19 @@ private[spark] class SocketAuthHelper(conf: SparkConf) {
   }
 
 }
+
+private[spark] object SocketAuthHelper {
+  def serveToStream(
+      threadName: String,
+      authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): Array[Any] = {
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { s =>
+      val out = new BufferedOutputStream(s.getOutputStream())
+      Utils.tryWithSafeFinally {
+        writeFunc(out)
+      } {
+        out.close()
+      }
+    }
+    Array(port, secret)
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
new file mode 100644
index 0000000..c65c8fd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.net.{InetAddress, ServerSocket, Socket}
+
+import scala.concurrent.Promise
+import scala.concurrent.duration.Duration
+import scala.language.existentials
+import scala.util.Try
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.ThreadUtils
+
+
+/**
+ * Creates a server in the JVM to communicate with external processes (e.g., Python and R) for
+ * handling one batch of data, with authentication and error handling.
+ */
+private[spark] abstract class SocketAuthServer[T](
+    authHelper: SocketAuthHelper,
+    threadName: String) {
+
+  def this(env: SparkEnv, threadName: String) = this(new SocketAuthHelper(env.conf), threadName)
+  def this(threadName: String) = this(SparkEnv.get, threadName)
+
+  private val promise = Promise[T]()
+
+  val (port, secret) = SocketAuthServer.setupOneConnectionServer(authHelper, threadName) { sock =>
+    promise.complete(Try(handleConnection(sock)))
+  }
+
+  /**
+   * Handle a connection which has already been authenticated.  Any error from this function
+   * will clean up this connection and the entire server, and get propagated to [[getResult]].
+   */
+  def handleConnection(sock: Socket): T
+
+  /**
+   * Blocks indefinitely for [[handleConnection]] to finish, and returns that result.  If
+   * handleConnection throws an exception, this will throw an exception which includes the original
+   * exception as a cause.
+   */
+  def getResult(): T = {
+    getResult(Duration.Inf)
+  }
+
+  def getResult(wait: Duration): T = {
+    ThreadUtils.awaitResult(promise.future, wait)
+  }
+
+}
+
+private[spark] object SocketAuthServer {
+
+  /**
+   * Create a socket server and run user function on the socket in a background thread.
+   *
+   * The socket server can only accept one connection, or close if no connection
+   * in 15 seconds.
+   *
+   * The thread will terminate after the supplied user function, or if there are any exceptions.
+   *
+   * If you need to get a result of the supplied function, create a subclass of [[SocketAuthServer]]
+   *
+   * @return The port number of a local socket and the secret for authentication.
+   */
+  def setupOneConnectionServer(
+      authHelper: SocketAuthHelper,
+      threadName: String)
+      (func: Socket => Unit): (Int, String) = {
+    val serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+    // Close the socket if no connection in 15 seconds
+    serverSocket.setSoTimeout(15000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run(): Unit = {
+        var sock: Socket = null
+        try {
+          sock = serverSocket.accept()
+          authHelper.authClient(sock)
+          func(sock)
+        } finally {
+          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(sock)
+        }
+      }
+    }.start()
+    (serverSocket.getLocalPort, authHelper.secret)
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 6f9b583..e2ec50f 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -24,7 +24,7 @@ import java.nio.charset.StandardCharsets
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 
 class PythonRDDSuite extends SparkFunSuite {
 
@@ -59,7 +59,7 @@ class PythonRDDSuite extends SparkFunSuite {
   }
 
   class ExceptionPythonServer(authHelper: SocketAuthHelper)
-      extends PythonServer[Unit](authHelper, "error-server") {
+      extends SocketAuthServer[Unit](authHelper, "error-server") {
 
     override def handleConnection(sock: Socket): Unit = {
       throw new Exception("exception within handleConnection")
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5a4bd57..63c3043 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -204,7 +204,7 @@ class SparkContext(object):
         # If encryption is enabled, we need to setup a server in the jvm to read broadcast
         # data via a socket.
         # scala's mangled names w/ $ in them require special treatment.
-        self._encryption_enabled = self._jvm.PythonUtils.getEncryptionEnabled(self._jsc)
+        self._encryption_enabled = self._jvm.PythonUtils.isEncryptionEnabled(self._jsc)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]

