Posted to commits@spark.apache.org by gu...@apache.org on 2018/09/26 03:18:49 UTC

[2/2] spark git commit: [SPARKR] Match pyspark features in SparkR communication protocol

[SPARKR] Match pyspark features in SparkR communication protocol


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ef361682
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ef361682
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ef361682

Branch: refs/heads/branch-2.2
Commit: ef36168258b8ad15362312e0562794f4f07322d0
Parents: 8ad6693
Author: hyukjinkwon <gu...@apache.org>
Authored: Mon Sep 24 19:25:02 2018 +0800
Committer: hyukjinkwon <gu...@apache.org>
Committed: Wed Sep 26 10:50:46 2018 +0800

----------------------------------------------------------------------
 R/pkg/R/context.R                               | 43 ++++++++++++++------
 R/pkg/tests/fulltests/test_Serde.R              | 32 +++++++++++++++
 R/pkg/tests/fulltests/test_sparkSQL.R           | 12 ------
 .../scala/org/apache/spark/api/r/RRDD.scala     | 33 ++++++++++++++-
 .../scala/org/apache/spark/api/r/RUtils.scala   |  4 ++
 5 files changed, 98 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/R/context.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 50856e3..c1a12f5 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -168,18 +168,30 @@ parallelize <- function(sc, coll, numSlices = 1) {
   # 2-tuples of raws
   serializedSlices <- lapply(slices, serialize, connection = NULL)
 
-  # The PRC backend cannot handle arguments larger than 2GB (INT_MAX)
+  # The RPC backend cannot handle arguments larger than 2GB (INT_MAX)
   # If serialized data is safely less than that threshold we send it over the RPC channel.
   # Otherwise, we write it to a file and send the file name
   if (objectSize < sizeLimit) {
     jrdd <- callJStatic("org.apache.spark.api.r.RRDD", "createRDDFromArray", sc, serializedSlices)
   } else {
-    fileName <- writeToTempFile(serializedSlices)
-    jrdd <- tryCatch(callJStatic(
-        "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
-      finally = {
-        file.remove(fileName)
-    })
+    if (callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)) {
+      # the length of slices here is the parallelism to use in the jvm's sc.parallelize()
+      parallelism <- as.integer(numSlices)
+      jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, parallelism)
+      authSecret <- callJMethod(jserver, "secret")
+      port <- callJMethod(jserver, "port")
+      conn <- socketConnection(port = port, blocking = TRUE, open = "wb", timeout = 1500)
+      doServerAuth(conn, authSecret)
+      writeToConnection(serializedSlices, conn)
+      jrdd <- callJMethod(jserver, "getResult")
+    } else {
+      fileName <- writeToTempFile(serializedSlices)
+      jrdd <- tryCatch(callJStatic(
+          "org.apache.spark.api.r.RRDD", "createRDDFromFile", sc, fileName, as.integer(numSlices)),
+        finally = {
+          file.remove(fileName)
+      })
+    }
   }
 
   RDD(jrdd, "byte")
@@ -195,14 +207,21 @@ getMaxAllocationLimit <- function(sc) {
   ))
 }
 
+writeToConnection <- function(serializedSlices, conn) {
+  tryCatch({
+    for (slice in serializedSlices) {
+      writeBin(as.integer(length(slice)), conn, endian = "big")
+      writeBin(slice, conn, endian = "big")
+    }
+  }, finally = {
+    close(conn)
+  })
+}
+
 writeToTempFile <- function(serializedSlices) {
   fileName <- tempfile()
   conn <- file(fileName, "wb")
-  for (slice in serializedSlices) {
-    writeBin(as.integer(length(slice)), conn, endian = "big")
-    writeBin(slice, conn, endian = "big")
-  }
-  close(conn)
+  writeToConnection(serializedSlices, conn)
   fileName
 }
 

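For reference, the framing produced by writeToConnection() above is the same whether the slices end up in a temp file or on the authenticated socket: each serialized slice is written as a 4-byte big-endian length followed by the raw bytes. The sketch below is a hypothetical R reader for that framing, shown only to illustrate the format; it is not part of this commit.

  # Read back slices framed as <4-byte big-endian length><raw bytes>,
  # i.e. the format written by writeToConnection() above.
  readSlices <- function(conn) {
    slices <- list()
    repeat {
      len <- readBin(conn, what = "integer", n = 1L, endian = "big")
      if (length(len) == 0) break  # end of stream
      slices[[length(slices) + 1L]] <- readBin(conn, what = "raw", n = len)
    }
    slices
  }

  # Usage, e.g. against a file produced by writeToTempFile():
  #   conn <- file(fileName, "rb"); slices <- readSlices(conn); close(conn)
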
http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/tests/fulltests/test_Serde.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_Serde.R b/R/pkg/tests/fulltests/test_Serde.R
index 6bbd201..092f9b8 100644
--- a/R/pkg/tests/fulltests/test_Serde.R
+++ b/R/pkg/tests/fulltests/test_Serde.R
@@ -77,3 +77,35 @@ test_that("SerDe of list of lists", {
 })
 
 sparkR.session.stop()
+
+# Note that this test should be at the end of tests since the configurations used here are not
+# specific to sessions, and the Spark context is restarted.
+test_that("createDataFrame large objects", {
+  for (encryptionEnabled in list("true", "false")) {
+    # To simulate a large object scenario, we set spark.r.maxAllocationLimit to a smaller value
+    conf <- list(spark.r.maxAllocationLimit = "100",
+                 spark.io.encryption.enabled = encryptionEnabled)
+
+    suppressWarnings(sparkR.session(master = sparkRTestMaster,
+                                    sparkConfig = conf,
+                                    enableHiveSupport = FALSE))
+
+    sc <- getSparkContext()
+    actual <- callJStatic("org.apache.spark.api.r.RUtils", "getEncryptionEnabled", sc)
+    expected <- as.logical(encryptionEnabled)
+    expect_equal(actual, expected)
+
+    tryCatch({
+      # suppress warnings from dot in the field names. See also SPARK-21536.
+      df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
+      expect_equal(getNumPartitions(df), 3)
+      expect_equal(dim(df), dim(iris))
+
+      df <- createDataFrame(cars, numPartitions = 3)
+      expect_equal(collect(df), cars)
+    },
+    finally = {
+      sparkR.stop()
+    })
+  }
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/R/pkg/tests/fulltests/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index f774554..f2b1c1d 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -298,18 +298,6 @@ test_that("create DataFrame from RDD", {
   unsetHiveContext()
 })
 
-test_that("createDataFrame uses files for large objects", {
-  # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value
-  conf <- callJMethod(sparkSession, "conf")
-  callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100")
-  df <- suppressWarnings(createDataFrame(iris, numPartitions = 3))
-  expect_equal(getNumPartitions(df), 3)
-
-  # Resetting the conf back to default value
-  callJMethod(conf, "set", "spark.r.maxAllocationLimit", toString(.Machine$integer.max / 10))
-  expect_equal(dim(df), dim(iris))
-})
-
 test_that("read/write csv as DataFrame", {
   if (windows_with_hadoop()) {
     csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 295355c..1dc61c7 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.api.r
 
-import java.io.File
+import java.io.{DataInputStream, File}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
@@ -25,10 +27,11 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.PythonRDD
+import org.apache.spark.api.python.{PythonRDD, PythonServer}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -163,3 +166,29 @@ private[r] object RRDD {
     PythonRDD.readRDDFromFile(jsc, fileName, parallelism)
   }
 }
+
+/**
+ * Helper for making RDD[Array[Byte]] from some R data, by reading the data from R
+ * over a socket. This is used in preference to writing data to a file when encryption is enabled.
+ */
+private[spark] class RParallelizeServer(sc: JavaSparkContext, parallelism: Int)
+    extends PythonServer[JavaRDD[Array[Byte]]](
+      new RSocketAuthHelper(), "sparkr-parallelize-server") {
+
+  override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
+    val in = sock.getInputStream()
+    PythonRDD.readRDDFromInputStream(sc.sc, in, parallelism)
+  }
+}
+
+private[spark] class RSocketAuthHelper extends SocketAuthHelper(SparkEnv.get.conf) {
+  override protected def readUtf8(s: Socket): String = {
+    val din = new DataInputStream(s.getInputStream())
+    val len = din.readInt()
+    val bytes = new Array[Byte](len)
+    din.readFully(bytes)
+    // The R code adds a null terminator to serialized strings, so ignore it here.
+    assert(bytes(bytes.length - 1) == 0) // sanity check.
+    new String(bytes, 0, bytes.length - 1, UTF_8)
+  }
+}
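
The readUtf8 override above expects each string from R framed as a 4-byte big-endian length (which counts a trailing null byte) followed by the null-terminated UTF-8 bytes, matching how R's writeBin() serializes character values. Below is a minimal R sketch of a matching writer; writeUtf8String is a hypothetical name used only for illustration (on the R side this job is presumably handled by the existing string serialization helpers that doServerAuth() relies on).

  # Write a string the way RSocketAuthHelper.readUtf8 expects:
  # <4-byte big-endian length including the null terminator><UTF-8 bytes><0x00>.
  # writeBin() on a character value appends the C-style null terminator itself.
  writeUtf8String <- function(conn, value) {
    utf8 <- enc2utf8(value)
    writeBin(as.integer(nchar(utf8, type = "bytes") + 1L), conn, endian = "big")
    writeBin(utf8, conn, endian = "big", useBytes = TRUE)
  }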

http://git-wip-us.apache.org/repos/asf/spark/blob/ef361682/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
index fdd8cf6..9bf35af 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RUtils.scala
@@ -21,6 +21,8 @@ import java.io.File
 import java.util.Arrays
 
 import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.api.java.JavaSparkContext
+import org.apache.spark.api.python.PythonUtils
 
 private[spark] object RUtils {
   // Local path where R binary packages built from R source code contained in the spark
@@ -104,4 +106,6 @@ private[spark] object RUtils {
       case e: Exception => false
     }
   }
+
+  def getEncryptionEnabled(sc: JavaSparkContext): Boolean = PythonUtils.getEncryptionEnabled(sc)
 }

