You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by da...@apache.org on 2016/03/28 19:14:34 UTC

spark git commit: [SPARK-12792] [SPARKR] Refactor RRDD to support R UDF.

Repository: spark
Updated Branches:
  refs/heads/master 68c0c460b -> 40984f670


[SPARK-12792] [SPARKR] Refactor RRDD to support R UDF.

Refactor RRDD by separating the common logic for interacting with the R worker into a new class, RRunner, which can be used to evaluate R UDFs.

Now RRDD relies on RRunner for RDD computation, and RRDD could be removed if we want to remove the RDD API in SparkR later.

Author: Sun Rui <ru...@intel.com>

Closes #10947 from sun-rui/SPARK-12792.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40984f67
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40984f67
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40984f67

Branch: refs/heads/master
Commit: 40984f67065eeaea731940008e6677c2323dda3e
Parents: 68c0c46
Author: Sun Rui <ru...@intel.com>
Authored: Mon Mar 28 10:14:28 2016 -0700
Committer: Davies Liu <da...@gmail.com>
Committed: Mon Mar 28 10:14:28 2016 -0700

----------------------------------------------------------------------
 R/pkg/inst/tests/testthat/test_rdd.R            |   8 +
 .../scala/org/apache/spark/api/r/RRDD.scala     | 328 +----------------
 .../scala/org/apache/spark/api/r/RRunner.scala  | 367 +++++++++++++++++++
 3 files changed, 379 insertions(+), 324 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40984f67/R/pkg/inst/tests/testthat/test_rdd.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index 3b0c16b..b6c8e1d 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", {
   expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
   expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
 })
+
+test_that("Test correct concurrency of RRDD.compute()", {
+  rdd <- parallelize(sc, 1:1000, 100)
+  jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row")
+  zrdd <- callJMethod(jrdd, "zip", jrdd)
+  count <- callJMethod(zrdd, "count")
+  expect_equal(count, 1000)
+})

http://git-wip-us.apache.org/repos/asf/spark/blob/40984f67/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 588a57e..606ba6e 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,21 +17,16 @@
 
 package org.apache.spark.api.r
 
-import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.{Arrays, Map => JMap}
+import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
-import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -42,188 +37,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]])
   extends RDD[U](parent) with Logging {
-  protected var dataStream: DataInputStream = _
-  private var bootTime: Double = _
   override def getPartitions: Array[Partition] = parent.partitions
 
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
+    val runner = new RRunner[U](
+      func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
 
     // The parent may be also an RRDD, so we should launch it first.
     val parentIterator = firstParent[T].iterator(partition, context)
 
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
-    val errThread = RRDD.createRWorker(listenPort)
-
-    // We use two sockets to separate input and output, then it's easy to manage
-    // the lifecycle of them to avoid deadlock.
-    // TODO: optimize it to use one socket
-
-    // the socket used to send out the input of task
-    serverSocket.setSoTimeout(10000)
-    val inSocket = serverSocket.accept()
-    startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index)
-
-    // the socket used to receive the output of task
-    val outSocket = serverSocket.accept()
-    val inputStream = new BufferedInputStream(outSocket.getInputStream)
-    dataStream = new DataInputStream(inputStream)
-    serverSocket.close()
-
-    try {
-
-      return new Iterator[U] {
-        def next(): U = {
-          val obj = _nextObj
-          if (hasNext) {
-            _nextObj = read()
-          }
-          obj
-        }
-
-        var _nextObj = read()
-
-        def hasNext(): Boolean = {
-          val hasMore = (_nextObj != null)
-          if (!hasMore) {
-            dataStream.close()
-          }
-          hasMore
-        }
-      }
-    } catch {
-      case e: Exception =>
-        throw new SparkException("R computation failed with\n " + errThread.getLines())
-    }
-  }
-
-  /**
-   * Start a thread to write RDD data to the R process.
-   */
-  private def startStdinThread[T](
-    output: OutputStream,
-    iter: Iterator[T],
-    partition: Int): Unit = {
-
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
-
-    new Thread("writer for R") {
-      override def run(): Unit = {
-        try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partition)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
-
-          dataOut.writeInt(numPartitions)
-
-          if (!iter.hasNext) {
-            dataOut.writeInt(0)
-          } else {
-            dataOut.writeInt(1)
-          }
-
-          val printOut = new PrintStream(stream)
-
-          def writeElem(elem: Any): Unit = {
-            if (deserializer == SerializationFormats.BYTE) {
-              val elemArr = elem.asInstanceOf[Array[Byte]]
-              dataOut.writeInt(elemArr.length)
-              dataOut.write(elemArr)
-            } else if (deserializer == SerializationFormats.ROW) {
-              dataOut.write(elem.asInstanceOf[Array[Byte]])
-            } else if (deserializer == SerializationFormats.STRING) {
-              // write string(for StringRRDD)
-              // scalastyle:off println
-              printOut.println(elem)
-              // scalastyle:on println
-            }
-          }
-
-          for (elem <- iter) {
-            elem match {
-              case (key, value) =>
-                writeElem(key)
-                writeElem(value)
-              case _ =>
-                writeElem(elem)
-            }
-          }
-          stream.flush()
-        } catch {
-          // TODO: We should propogate this error to the task thread
-          case e: Exception =>
-            logError("R Writer thread got an exception", e)
-        } finally {
-          Try(output.close())
-        }
-      }
-    }.start()
-  }
-
-  protected def readData(length: Int): U
-
-  protected def read(): U = {
-    try {
-      val length = dataStream.readInt()
-
-      length match {
-        case SpecialLengths.TIMING_DATA =>
-          // Timing data from R worker
-          val boot = dataStream.readDouble - bootTime
-          val init = dataStream.readDouble
-          val broadcast = dataStream.readDouble
-          val input = dataStream.readDouble
-          val compute = dataStream.readDouble
-          val output = dataStream.readDouble
-          logInfo(
-            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
-             "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
-             "total = %.3f s").format(
-               boot,
-               init,
-               broadcast,
-               input,
-               compute,
-               output,
-               boot + init + broadcast + input + compute + output))
-          read()
-        case length if length >= 0 =>
-          readData(length)
-      }
-    } catch {
-      case eof: EOFException =>
-        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
-    }
+    runner.compute(parentIterator, partition.index, context)
   }
 }
 
@@ -242,19 +65,6 @@ private class PairwiseRRDD[T: ClassTag](
     parent, numPartitions, hashFunc, deserializer,
     SerializationFormats.BYTE, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): (Int, Array[Byte]) = {
-    length match {
-      case length if length == 2 =>
-        val hashedKey = dataStream.readInt()
-        val contentPairsLength = dataStream.readInt()
-        val contentPairs = new Array[Byte](contentPairsLength)
-        dataStream.readFully(contentPairs)
-        (hashedKey, contentPairs)
-      case _ => null
-   }
-  }
-
   lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
@@ -271,17 +81,6 @@ private class RRDD[T: ClassTag](
   extends BaseRRDD[T, Array[Byte]](
     parent, -1, func, deserializer, serializer, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): Array[Byte] = {
-    length match {
-      case length if length > 0 =>
-        val obj = new Array[Byte](length)
-        dataStream.readFully(obj)
-        obj
-      case _ => null
-    }
-  }
-
   lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
@@ -297,55 +96,10 @@ private class StringRRDD[T: ClassTag](
   extends BaseRRDD[T, String](
     parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): String = {
-    length match {
-      case length if length > 0 =>
-        SerDe.readStringBytes(dataStream, length)
-      case _ => null
-    }
-  }
-
   lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
 }
 
-private object SpecialLengths {
-  val TIMING_DATA = -1
-}
-
-private[r] class BufferedStreamThread(
-    in: InputStream,
-    name: String,
-    errBufferSize: Int) extends Thread(name) with Logging {
-  val lines = new Array[String](errBufferSize)
-  var lineIdx = 0
-  override def run() {
-    for (line <- Source.fromInputStream(in).getLines) {
-      synchronized {
-        lines(lineIdx) = line
-        lineIdx = (lineIdx + 1) % errBufferSize
-      }
-      logInfo(line)
-    }
-  }
-
-  def getLines(): String = synchronized {
-    (0 until errBufferSize).filter { x =>
-      lines((x + lineIdx) % errBufferSize) != null
-    }.map { x =>
-      lines((x + lineIdx) % errBufferSize)
-    }.mkString("\n")
-  }
-}
-
 private[r] object RRDD {
-  // Because forking processes from Java is expensive, we prefer to launch
-  // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
-  // This daemon currently only works on UNIX-based systems now, so we should
-  // also fall back to launching workers (worker.R) directly.
-  private[this] var errThread: BufferedStreamThread = _
-  private[this] var daemonChannel: DataOutputStream = _
-
   def createSparkContext(
       master: String,
       appName: String,
@@ -353,7 +107,6 @@ private[r] object RRDD {
       jars: Array[String],
       sparkEnvirMap: JMap[Object, Object],
       sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = {
-
     val sparkConf = new SparkConf().setAppName(appName)
                                    .setSparkHome(sparkHome)
 
@@ -381,83 +134,10 @@ private[r] object RRDD {
   }
 
   /**
-   * Start a thread to print the process's stderr to ours
-   */
-  private def startStdoutThread(proc: Process): BufferedStreamThread = {
-    val BUFFER_SIZE = 100
-    val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
-    thread.setDaemon(true)
-    thread.start()
-    thread
-  }
-
-  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
-    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
-    // but kept here for backward compatibility.
-    val sparkConf = SparkEnv.get.conf
-    var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
-    rCommand = sparkConf.get("spark.r.command", rCommand)
-
-    val rOptions = "--vanilla"
-    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
-    val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
-    val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
-    // Unset the R_TESTS environment variable for workers.
-    // This is set by R CMD check as startup.Rs
-    // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
-    // and confuses worker script which tries to load a non-existent file
-    pb.environment().put("R_TESTS", "")
-    pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
-    pb.environment().put("SPARKR_WORKER_PORT", port.toString)
-    pb.redirectErrorStream(true)  // redirect stderr into stdout
-    val proc = pb.start()
-    val errThread = startStdoutThread(proc)
-    errThread
-  }
-
-  /**
-   * ProcessBuilder used to launch worker R processes.
-   */
-  def createRWorker(port: Int): BufferedStreamThread = {
-    val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
-    if (!Utils.isWindows && useDaemon) {
-      synchronized {
-        if (daemonChannel == null) {
-          // we expect one connections
-          val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-          val daemonPort = serverSocket.getLocalPort
-          errThread = createRProcess(daemonPort, "daemon.R")
-          // the socket used to send out the input of task
-          serverSocket.setSoTimeout(10000)
-          val sock = serverSocket.accept()
-          daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          serverSocket.close()
-        }
-        try {
-          daemonChannel.writeInt(port)
-          daemonChannel.flush()
-        } catch {
-          case e: IOException =>
-            // daemon process died
-            daemonChannel.close()
-            daemonChannel = null
-            errThread = null
-            // fail the current task, retry by scheduler
-            throw e
-        }
-        errThread
-      }
-    } else {
-      createRProcess(port, "worker.R")
-    }
-  }
-
-  /**
    * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is
    * called from R.
    */
   def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
     JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/40984f67/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
new file mode 100644
index 0000000..e8fcada
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -0,0 +1,367 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io._
+import java.net.{InetAddress, ServerSocket}
+import java.util.Arrays
+
+import scala.io.Source
+import scala.util.Try
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.util.Utils
+
+/**
+ * A helper class to run R UDFs in Spark.
+ */
+private[spark] class RRunner[U](
+    func: Array[Byte],
+    deserializer: String,
+    serializer: String,
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    numPartitions: Int = -1)
+  extends Logging {
+  private var bootTime: Double = _
+  private var dataStream: DataInputStream = _
+  val readData = numPartitions match {
+    case -1 =>
+      serializer match {
+        case SerializationFormats.STRING => readStringData _
+        case _ => readByteArrayData _
+      }
+    case _ => readShuffledData _
+  }
+
+  def compute(
+      inputIterator: Iterator[_],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[U] = {
+    // Timing start
+    bootTime = System.currentTimeMillis / 1000.0
+
+    // we expect two connections
+    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
+    val listenPort = serverSocket.getLocalPort()
+
+    // The stdout/stderr is shared by multiple tasks, because we use one daemon
+    // to launch child process as worker.
+    val errThread = RRunner.createRWorker(listenPort)
+
+    // We use two sockets to separate input and output, then it's easy to manage
+    // the lifecycle of them to avoid deadlock.
+    // TODO: optimize it to use one socket
+
+    // the socket used to send out the input of task
+    serverSocket.setSoTimeout(10000)
+    val inSocket = serverSocket.accept()
+    startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+    // the socket used to receive the output of task
+    val outSocket = serverSocket.accept()
+    val inputStream = new BufferedInputStream(outSocket.getInputStream)
+    dataStream = new DataInputStream(inputStream)
+    serverSocket.close()
+
+    try {
+      return new Iterator[U] {
+        def next(): U = {
+          val obj = _nextObj
+          if (hasNext) {
+            _nextObj = read()
+          }
+          obj
+        }
+
+        var _nextObj = read()
+
+        def hasNext(): Boolean = {
+          val hasMore = (_nextObj != null)
+          if (!hasMore) {
+            dataStream.close()
+          }
+          hasMore
+        }
+      }
+    } catch {
+      case e: Exception =>
+        throw new SparkException("R computation failed with\n " + errThread.getLines())
+    }
+  }
+
+  /**
+   * Start a thread to write RDD data to the R process.
+   */
+  private def startStdinThread(
+      output: OutputStream,
+      iter: Iterator[_],
+      partitionIndex: Int): Unit = {
+    val env = SparkEnv.get
+    val taskContext = TaskContext.get()
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val stream = new BufferedOutputStream(output, bufferSize)
+
+    new Thread("writer for R") {
+      override def run(): Unit = {
+        try {
+          SparkEnv.set(env)
+          TaskContext.setTaskContext(taskContext)
+          val dataOut = new DataOutputStream(stream)
+          dataOut.writeInt(partitionIndex)
+
+          SerDe.writeString(dataOut, deserializer)
+          SerDe.writeString(dataOut, serializer)
+
+          dataOut.writeInt(packageNames.length)
+          dataOut.write(packageNames)
+
+          dataOut.writeInt(func.length)
+          dataOut.write(func)
+
+          dataOut.writeInt(broadcastVars.length)
+          broadcastVars.foreach { broadcast =>
+            // TODO(shivaram): Read a Long in R to avoid this cast
+            dataOut.writeInt(broadcast.id.toInt)
+            // TODO: Pass a byte array from R to avoid this cast ?
+            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+            dataOut.writeInt(broadcastByteArr.length)
+            dataOut.write(broadcastByteArr)
+          }
+
+          dataOut.writeInt(numPartitions)
+
+          if (!iter.hasNext) {
+            dataOut.writeInt(0)
+          } else {
+            dataOut.writeInt(1)
+          }
+
+          val printOut = new PrintStream(stream)
+
+          def writeElem(elem: Any): Unit = {
+            if (deserializer == SerializationFormats.BYTE) {
+              val elemArr = elem.asInstanceOf[Array[Byte]]
+              dataOut.writeInt(elemArr.length)
+              dataOut.write(elemArr)
+            } else if (deserializer == SerializationFormats.ROW) {
+              dataOut.write(elem.asInstanceOf[Array[Byte]])
+            } else if (deserializer == SerializationFormats.STRING) {
+              // write string(for StringRRDD)
+              // scalastyle:off println
+              printOut.println(elem)
+              // scalastyle:on println
+            }
+          }
+
+          for (elem <- iter) {
+            elem match {
+              case (key, value) =>
+                writeElem(key)
+                writeElem(value)
+              case _ =>
+                writeElem(elem)
+            }
+          }
+          stream.flush()
+        } catch {
+          // TODO: We should propogate this error to the task thread
+          case e: Exception =>
+            logError("R Writer thread got an exception", e)
+        } finally {
+          Try(output.close())
+        }
+      }
+    }.start()
+  }
+
+  private def read(): U = {
+    try {
+      val length = dataStream.readInt()
+
+      length match {
+        case SpecialLengths.TIMING_DATA =>
+          // Timing data from R worker
+          val boot = dataStream.readDouble - bootTime
+          val init = dataStream.readDouble
+          val broadcast = dataStream.readDouble
+          val input = dataStream.readDouble
+          val compute = dataStream.readDouble
+          val output = dataStream.readDouble
+          logInfo(
+            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+              "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+              "total = %.3f s").format(
+                boot,
+                init,
+                broadcast,
+                input,
+                compute,
+                output,
+                boot + init + broadcast + input + compute + output))
+          read()
+        case length if length >= 0 =>
+          readData(length).asInstanceOf[U]
+      }
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
+    }
+  }
+
+  private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+    length match {
+      case length if length == 2 =>
+        val hashedKey = dataStream.readInt()
+        val contentPairsLength = dataStream.readInt()
+        val contentPairs = new Array[Byte](contentPairsLength)
+        dataStream.readFully(contentPairs)
+        (hashedKey, contentPairs)
+      case _ => null
+    }
+  }
+
+  private def readByteArrayData(length: Int): Array[Byte] = {
+    length match {
+      case length if length > 0 =>
+        val obj = new Array[Byte](length)
+        dataStream.readFully(obj)
+        obj
+      case _ => null
+    }
+  }
+
+  private def readStringData(length: Int): String = {
+    length match {
+      case length if length > 0 =>
+        SerDe.readStringBytes(dataStream, length)
+      case _ => null
+    }
+  }
+}
+
+private object SpecialLengths {
+  val TIMING_DATA = -1
+}
+
+private[r] class BufferedStreamThread(
+    in: InputStream,
+    name: String,
+    errBufferSize: Int) extends Thread(name) with Logging {
+  val lines = new Array[String](errBufferSize)
+  var lineIdx = 0
+  override def run() {
+    for (line <- Source.fromInputStream(in).getLines) {
+      synchronized {
+        lines(lineIdx) = line
+        lineIdx = (lineIdx + 1) % errBufferSize
+      }
+      logInfo(line)
+    }
+  }
+
+  def getLines(): String = synchronized {
+    (0 until errBufferSize).filter { x =>
+      lines((x + lineIdx) % errBufferSize) != null
+    }.map { x =>
+      lines((x + lineIdx) % errBufferSize)
+    }.mkString("\n")
+  }
+}
+
+private[r] object RRunner {
+  // Because forking processes from Java is expensive, we prefer to launch
+  // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
+  // This daemon currently only works on UNIX-based systems now, so we should
+  // also fall back to launching workers (worker.R) directly.
+  private[this] var errThread: BufferedStreamThread = _
+  private[this] var daemonChannel: DataOutputStream = _
+
+  /**
+   * Start a thread to print the process's stderr to ours
+   */
+  private def startStdoutThread(proc: Process): BufferedStreamThread = {
+    val BUFFER_SIZE = 100
+    val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
+    thread.setDaemon(true)
+    thread.start()
+    thread
+  }
+
+  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
+    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
+    // but kept here for backward compatibility.
+    val sparkConf = SparkEnv.get.conf
+    var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
+    rCommand = sparkConf.get("spark.r.command", rCommand)
+
+    val rOptions = "--vanilla"
+    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
+    val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
+    val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
+    // Unset the R_TESTS environment variable for workers.
+    // This is set by R CMD check as startup.Rs
+    // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
+    // and confuses worker script which tries to load a non-existent file
+    pb.environment().put("R_TESTS", "")
+    pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
+    pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+    pb.redirectErrorStream(true)  // redirect stderr into stdout
+    val proc = pb.start()
+    val errThread = startStdoutThread(proc)
+    errThread
+  }
+
+  /**
+   * ProcessBuilder used to launch worker R processes.
+   */
+  def createRWorker(port: Int): BufferedStreamThread = {
+    val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
+    if (!Utils.isWindows && useDaemon) {
+      synchronized {
+        if (daemonChannel == null) {
+          // we expect one connections
+          val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
+          val daemonPort = serverSocket.getLocalPort
+          errThread = createRProcess(daemonPort, "daemon.R")
+          // the socket used to send out the input of task
+          serverSocket.setSoTimeout(10000)
+          val sock = serverSocket.accept()
+          daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          serverSocket.close()
+        }
+        try {
+          daemonChannel.writeInt(port)
+          daemonChannel.flush()
+        } catch {
+          case e: IOException =>
+            // daemon process died
+            daemonChannel.close()
+            daemonChannel = null
+            errThread = null
+            // fail the current task, retry by scheduler
+            throw e
+        }
+        errThread
+      }
+    } else {
+      createRProcess(port, "worker.R")
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org