Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/11 23:45:49 UTC

[spark] branch master updated: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3725b13  [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
3725b13 is described below

commit 3725b1324f731d57dc776c256bc1a100ec9e6cd0
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Tue Mar 12 08:45:29 2019 +0900

    [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
    
    ## What changes were proposed in this pull request?
    
    This PR proposes to have one base R runner.
    
    At a high level:

    Previously, `ArrowRRunner` inherited from `RRunner`:
    
    ```
    └── RRunner
        └── ArrowRRunner
    ```
    
    After this PR, there is a single `BaseRRunner`, and both `ArrowRRunner` and `RRunner` inherit from it:
    
    ```
    └── BaseRRunner
        ├── ArrowRRunner
        └── RRunner
    ```
    
    This structure is consistent with the runner hierarchy on the Python side (`BasePythonRunner`).
    
    In more detail:
    
    ```scala
    abstract class BaseRRunner[IN, OUT] {
    
      def compute: Iterator[OUT] = {
        ...
        newWriterThread(...).start()
        ...
        newReaderIterator(...)
        ...
      }
    
      // Makes a thread that writes data from the JVM to the R process
      abstract protected def newWriterThread(..., iter: Iterator[IN], ...): WriterThread

      // Makes an iterator that reads data from the R process into the JVM
      abstract protected def newReaderIterator(...): ReaderIterator
    
      abstract class WriterThread(..., iter: Iterator[IN], ...) extends Thread {
        override def run(): Unit = {
          ...
          writeIteratorToStream(...)
          ...
        }
    
        // The actual logic for writing to the socket stream.
        abstract protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
      }
    
      abstract class ReaderIterator extends Iterator[OUT] {
        override def hasNext: Boolean = {
          ...
          read(...)
          ...
        }
    
        override def next(): OUT = {
          ...
          hasNext
          ...
        }
    
        // The actual logic for reading from the socket stream.
        abstract protected def read(...): OUT
      }
    }
    ```
    
    ```scala
    class [Arrow]RRunner extends BaseRRunner {
      override def newWriterThread(...) {
        new WriterThread(...) {
          override def writeIteratorToStream(...) {
            ...
          }
        }
      }
    
      override def newReaderIterator(...) {
        new ReaderIterator(...) {
          override def read(...) {
            ...
          }
        }
      }
    }
    ```
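    
    To make the wiring concrete, below is a minimal, self-contained sketch of the same template-method pattern. It is not Spark code: `MiniBaseRunner`, `UpperCaseRunner`, and the piped streams are hypothetical stand-ins for the base runner, a concrete runner, and the two task sockets, and there is no real R worker process involved.
    
    ```scala
    import java.io.{BufferedReader, InputStreamReader, PipedInputStream, PipedOutputStream, PrintStream}
    
    // Hypothetical base runner: compute() starts the writer thread and hands back
    // the reader iterator, mirroring the newWriterThread/newReaderIterator split.
    abstract class MiniBaseRunner[IN, OUT] {
      final def compute(input: Iterator[IN]): Iterator[OUT] = {
        val toWorker = new PipedOutputStream()
        val fromWorker = new PipedInputStream(toWorker)  // loopback pipe instead of two sockets
        newWriterThread(new PrintStream(toWorker), input).start()
        newReaderIterator(new BufferedReader(new InputStreamReader(fromWorker)))
      }
    
      protected def newWriterThread(out: PrintStream, input: Iterator[IN]): Thread
      protected def newReaderIterator(in: BufferedReader): Iterator[OUT]
    }
    
    // Hypothetical concrete runner: writes each element as a line, reads lines
    // back and upper-cases them (standing in for work the R process would do).
    class UpperCaseRunner extends MiniBaseRunner[String, String] {
      override protected def newWriterThread(out: PrintStream, input: Iterator[String]): Thread =
        new Thread("writer") {
          override def run(): Unit = {
            input.foreach(line => out.println(line))  // writeIteratorToStream equivalent
            out.close()                               // closing the stream marks end of data
          }
        }
    
      override protected def newReaderIterator(in: BufferedReader): Iterator[String] =
        new Iterator[String] {
          private var nextLine: String = in.readLine()  // read() equivalent
          override def hasNext: Boolean = nextLine != null
          override def next(): String = {
            val line = nextLine
            nextLine = in.readLine()
            line.toUpperCase
          }
        }
    }
    
    object MiniBaseRunnerExample {
      def main(args: Array[String]): Unit = {
        new UpperCaseRunner().compute(Iterator("spark", "arrow", "r")).foreach(println)
        // prints SPARK, ARROW, R
      }
    }
    ```
    
    Here `compute` starts the writer thread and immediately returns the reader iterator, so writing and reading proceed concurrently; this is the same non-blocking hand-off that `BaseRRunner.compute` sets up over the input and output sockets.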
    
    ## How was this patch tested?
    
    Manually tested, and existing tests should cover this change.
    
    Closes #23977 from HyukjinKwon/SPARK-26923.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../api/r/{RRunner.scala => BaseRRunner.scala}     | 302 +++++--------
 .../main/scala/org/apache/spark/api/r/RRDD.scala   |   2 +-
 .../scala/org/apache/spark/api/r/RRunner.scala     | 478 +++++----------------
 .../org/apache/spark/sql/execution/objects.scala   |  24 +-
 .../spark/sql/execution/r/ArrowRRunner.scala       | 140 +++---
 .../sql/execution/r/MapPartitionsRWrapper.scala    |   2 +-
 6 files changed, 309 insertions(+), 639 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
similarity index 55%
copy from core/src/main/scala/org/apache/spark/api/r/RRunner.scala
copy to core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
index 971d11f..f96c521 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
@@ -34,31 +34,23 @@ import org.apache.spark.util.Utils
 /**
  * A helper class to run R UDFs in Spark.
  */
-private[spark] class RRunner[U](
+private[spark] abstract class BaseRRunner[IN, OUT](
     func: Array[Byte],
     deserializer: String,
     serializer: String,
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]],
-    numPartitions: Int = -1,
-    isDataFrame: Boolean = false,
-    colNames: Array[String] = null,
-    mode: Int = RRunnerModes.RDD)
+    numPartitions: Int,
+    isDataFrame: Boolean,
+    colNames: Array[String],
+    mode: Int)
   extends Logging {
   protected var bootTime: Double = _
-  private var dataStream: DataInputStream = _
-  val readData = numPartitions match {
-    case -1 =>
-      serializer match {
-        case SerializationFormats.STRING => readStringData _
-        case _ => readByteArrayData _
-      }
-    case _ => readShuffledData _
-  }
+  protected var dataStream: DataInputStream = _
 
   def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int): Iterator[U] = {
+      inputIterator: Iterator[IN],
+      partitionIndex: Int): Iterator[OUT] = {
     // Timing start
     bootTime = System.currentTimeMillis / 1000.0
 
@@ -68,7 +60,7 @@ private[spark] class RRunner[U](
 
     // The stdout/stderr is shared by multiple tasks, because we use one daemon
     // to launch child process as worker.
-    val errThread = RRunner.createRWorker(listenPort)
+    val errThread = BaseRRunner.createRWorker(listenPort)
 
     // We use two sockets to separate input and output, then it's easy to manage
     // the lifecycle of them to avoid deadlock.
@@ -78,12 +70,12 @@ private[spark] class RRunner[U](
     serverSocket.setSoTimeout(10000)
     dataStream = try {
       val inSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(inSocket)
-      startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+      BaseRRunner.authHelper.authClient(inSocket)
+      newWriterThread(inSocket.getOutputStream(), inputIterator, partitionIndex).start()
 
       // the socket used to receive the output of task
       val outSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(outSocket)
+      BaseRRunner.authHelper.authClient(outSocket)
       val inputStream = new BufferedInputStream(outSocket.getInputStream)
       new DataInputStream(inputStream)
     } finally {
@@ -98,197 +90,127 @@ private[spark] class RRunner[U](
     }
   }
 
+  /**
+   * Creates an iterator that reads data from R process.
+   */
   protected def newReaderIterator(
-      dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = {
-    new Iterator[U] {
-      def next(): U = {
-        val obj = _nextObj
-        if (hasNext()) {
-          _nextObj = read()
-        }
-        obj
-      }
-
-      private var _nextObj = read()
+      dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator
 
-      def hasNext(): Boolean = {
-        val hasMore = _nextObj != null
-        if (!hasMore) {
-          dataStream.close()
-        }
-        hasMore
+  /**
+   * Start a thread to write RDD data to the R process.
+   */
+  protected def newWriterThread(
+      output: OutputStream,
+      iter: Iterator[IN],
+      partitionIndex: Int): WriterThread
+
+  abstract class ReaderIterator(
+      stream: DataInputStream,
+      errThread: BufferedStreamThread)
+    extends Iterator[OUT] {
+
+    private var nextObj: OUT = _
+    // eos should be marked as true when the stream is ended.
+    protected var eos = false
+
+    override def hasNext: Boolean = nextObj != null || {
+      if (!eos) {
+        nextObj = read()
+        hasNext
+      } else {
+        false
       }
     }
-  }
 
-  protected def writeData(
-      dataOut: DataOutputStream,
-      printOut: PrintStream,
-      iter: Iterator[_]): Unit = {
-    def writeElem(elem: Any): Unit = {
-      if (deserializer == SerializationFormats.BYTE) {
-        val elemArr = elem.asInstanceOf[Array[Byte]]
-        dataOut.writeInt(elemArr.length)
-        dataOut.write(elemArr)
-      } else if (deserializer == SerializationFormats.ROW) {
-        dataOut.write(elem.asInstanceOf[Array[Byte]])
-      } else if (deserializer == SerializationFormats.STRING) {
-        // write string(for StringRRDD)
-        // scalastyle:off println
-        printOut.println(elem)
-        // scalastyle:on println
+    override def next(): OUT = {
+      if (hasNext) {
+        val obj = nextObj
+        nextObj = null.asInstanceOf[OUT]
+        obj
+      } else {
+        Iterator.empty.next()
       }
     }
 
-    for (elem <- iter) {
-      elem match {
-        case (key, innerIter: Iterator[_]) =>
-          for (innerElem <- innerIter) {
-            writeElem(innerElem)
-          }
-          // Writes key which can be used as a boundary in group-aggregate
-          dataOut.writeByte('r')
-          writeElem(key)
-        case (key, value) =>
-          writeElem(key)
-          writeElem(value)
-        case _ =>
-          writeElem(elem)
-      }
-    }
+    /**
+     * Reads next object from the stream.
+     * When the stream reaches end of data, needs to process the following sections,
+     * and then returns null.
+     */
+    protected def read(): OUT
   }
 
   /**
-   * Start a thread to write RDD data to the R process.
+   * The thread responsible for writing the iterator to the R process.
    */
-  private def startStdinThread(
+  abstract class WriterThread(
       output: OutputStream,
-      iter: Iterator[_],
-      partitionIndex: Int): Unit = {
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty(BUFFER_SIZE.key,
-      BUFFER_SIZE.defaultValueString).toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
-
-    new Thread("writer for R") {
-      override def run(): Unit = {
-        try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partitionIndex)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
+      iter: Iterator[IN],
+      partitionIndex: Int)
+    extends Thread("writer for R") {
 
-          dataOut.writeInt(numPartitions)
-          dataOut.writeInt(mode)
-
-          if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
-          }
-
-          if (!iter.hasNext) {
-            dataOut.writeInt(0)
-          } else {
-            dataOut.writeInt(1)
-          }
-
-          val printOut = new PrintStream(stream)
+    private val env = SparkEnv.get
+    private val taskContext = TaskContext.get()
+    private val bufferSize = System.getProperty(BUFFER_SIZE.key,
+      BUFFER_SIZE.defaultValueString).toInt
+    private val stream = new BufferedOutputStream(output, bufferSize)
+    protected lazy val dataOut = new DataOutputStream(stream)
+    protected lazy val printOut = new PrintStream(stream)
+
+    override def run(): Unit = {
+      try {
+        SparkEnv.set(env)
+        TaskContext.setTaskContext(taskContext)
+        dataOut.writeInt(partitionIndex)
+
+        SerDe.writeString(dataOut, deserializer)
+        SerDe.writeString(dataOut, serializer)
+
+        dataOut.writeInt(packageNames.length)
+        dataOut.write(packageNames)
+
+        dataOut.writeInt(func.length)
+        dataOut.write(func)
+
+        dataOut.writeInt(broadcastVars.length)
+        broadcastVars.foreach { broadcast =>
+          // TODO(shivaram): Read a Long in R to avoid this cast
+          dataOut.writeInt(broadcast.id.toInt)
+          // TODO: Pass a byte array from R to avoid this cast ?
+          val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+          dataOut.writeInt(broadcastByteArr.length)
+          dataOut.write(broadcastByteArr)
+        }
 
-          writeData(dataOut, printOut, iter)
+        dataOut.writeInt(numPartitions)
+        dataOut.writeInt(mode)
 
-          stream.flush()
-        } catch {
-          // TODO: We should propagate this error to the task thread
-          case e: Exception =>
-            logError("R Writer thread got an exception", e)
-        } finally {
-          Try(output.close())
+        if (isDataFrame) {
+          SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
         }
-      }
-    }.start()
-  }
 
-  private def read(): U = {
-    try {
-      val length = dataStream.readInt()
-
-      length match {
-        case SpecialLengths.TIMING_DATA =>
-          // Timing data from R worker
-          val boot = dataStream.readDouble - bootTime
-          val init = dataStream.readDouble
-          val broadcast = dataStream.readDouble
-          val input = dataStream.readDouble
-          val compute = dataStream.readDouble
-          val output = dataStream.readDouble
-          logInfo(
-            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
-              "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
-              "total = %.3f s").format(
-                boot,
-                init,
-                broadcast,
-                input,
-                compute,
-                output,
-                boot + init + broadcast + input + compute + output))
-          read()
-        case length if length >= 0 =>
-          readData(length).asInstanceOf[U]
-      }
-    } catch {
-      case eof: EOFException =>
-        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
-    }
-  }
+        if (!iter.hasNext) {
+          dataOut.writeInt(0)
+        } else {
+          dataOut.writeInt(1)
+        }
 
-  private def readShuffledData(length: Int): (Int, Array[Byte]) = {
-    length match {
-      case length if length == 2 =>
-        val hashedKey = dataStream.readInt()
-        val contentPairsLength = dataStream.readInt()
-        val contentPairs = new Array[Byte](contentPairsLength)
-        dataStream.readFully(contentPairs)
-        (hashedKey, contentPairs)
-      case _ => null
-    }
-  }
+        writeIteratorToStream(dataOut)
 
-  protected def readByteArrayData(length: Int): Array[Byte] = {
-    length match {
-      case length if length > 0 =>
-        val obj = new Array[Byte](length)
-        dataStream.readFully(obj)
-        obj
-      case _ => null
+        stream.flush()
+      } catch {
+        // TODO: We should propagate this error to the task thread
+        case e: Exception =>
+          logError("R Writer thread got an exception", e)
+      } finally {
+        Try(output.close())
+      }
     }
-  }
 
-  private def readStringData(length: Int): String = {
-    length match {
-      case length if length > 0 =>
-        SerDe.readStringBytes(dataStream, length)
-      case _ => null
-    }
+    /**
+     * Writes input data to the stream connected to the R worker.
+     */
+    protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
   }
 }
 
@@ -327,7 +249,7 @@ private[spark] class BufferedStreamThread(
   }
 }
 
-private[r] object RRunner {
+private[r] object BaseRRunner {
   // Because forking processes from Java is expensive, we prefer to launch
   // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
   // This daemon currently only works on UNIX-based systems now, so we should
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 4a59c3e..07f8405 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -43,7 +43,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
   override def getPartitions: Array[Partition] = parent.partitions
 
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-    val runner = new RRunner[U](
+    val runner = new RRunner[T, U](
       func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
 
     // The parent may be also an RRDD, so we should launch it first.
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 971d11f..0327386 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -18,23 +18,14 @@
 package org.apache.spark.api.r
 
 import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.Arrays
-
-import scala.io.Source
-import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.R._
-import org.apache.spark.util.Utils
 
 /**
  * A helper class to run R UDFs in Spark.
  */
-private[spark] class RRunner[U](
+private[spark] class RRunner[IN, OUT](
     func: Array[Byte],
     deserializer: String,
     serializer: String,
@@ -44,380 +35,149 @@ private[spark] class RRunner[U](
     isDataFrame: Boolean = false,
     colNames: Array[String] = null,
     mode: Int = RRunnerModes.RDD)
-  extends Logging {
-  protected var bootTime: Double = _
-  private var dataStream: DataInputStream = _
-  val readData = numPartitions match {
-    case -1 =>
-      serializer match {
-        case SerializationFormats.STRING => readStringData _
-        case _ => readByteArrayData _
-      }
-    case _ => readShuffledData _
-  }
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int): Iterator[U] = {
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
-
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
-    val errThread = RRunner.createRWorker(listenPort)
-
-    // We use two sockets to separate input and output, then it's easy to manage
-    // the lifecycle of them to avoid deadlock.
-    // TODO: optimize it to use one socket
-
-    // the socket used to send out the input of task
-    serverSocket.setSoTimeout(10000)
-    dataStream = try {
-      val inSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(inSocket)
-      startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
-      // the socket used to receive the output of task
-      val outSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(outSocket)
-      val inputStream = new BufferedInputStream(outSocket.getInputStream)
-      new DataInputStream(inputStream)
-    } finally {
-      serverSocket.close()
-    }
-
-    try {
-      newReaderIterator(dataStream, errThread)
-    } catch {
-      case e: Exception =>
-        throw new SparkException("R computation failed with\n " + errThread.getLines(), e)
-    }
-  }
+  extends BaseRRunner[IN, OUT](
+    func,
+    deserializer,
+    serializer,
+    packageNames,
+    broadcastVars,
+    numPartitions,
+    isDataFrame,
+    colNames,
+    mode) {
 
   protected def newReaderIterator(
-      dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = {
-    new Iterator[U] {
-      def next(): U = {
-        val obj = _nextObj
-        if (hasNext()) {
-          _nextObj = read()
-        }
-        obj
+      dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = {
+    new ReaderIterator(dataStream, errThread) {
+      private val readData = numPartitions match {
+        case -1 =>
+          serializer match {
+            case SerializationFormats.STRING => readStringData _
+            case _ => readByteArrayData _
+          }
+        case _ => readShuffledData _
       }
 
-      private var _nextObj = read()
-
-      def hasNext(): Boolean = {
-        val hasMore = _nextObj != null
-        if (!hasMore) {
-          dataStream.close()
+      private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+        length match {
+          case length if length == 2 =>
+            val hashedKey = dataStream.readInt()
+            val contentPairsLength = dataStream.readInt()
+            val contentPairs = new Array[Byte](contentPairsLength)
+            dataStream.readFully(contentPairs)
+            (hashedKey, contentPairs)
+          case _ => null
         }
-        hasMore
       }
-    }
-  }
 
-  protected def writeData(
-      dataOut: DataOutputStream,
-      printOut: PrintStream,
-      iter: Iterator[_]): Unit = {
-    def writeElem(elem: Any): Unit = {
-      if (deserializer == SerializationFormats.BYTE) {
-        val elemArr = elem.asInstanceOf[Array[Byte]]
-        dataOut.writeInt(elemArr.length)
-        dataOut.write(elemArr)
-      } else if (deserializer == SerializationFormats.ROW) {
-        dataOut.write(elem.asInstanceOf[Array[Byte]])
-      } else if (deserializer == SerializationFormats.STRING) {
-        // write string(for StringRRDD)
-        // scalastyle:off println
-        printOut.println(elem)
-        // scalastyle:on println
+      private def readByteArrayData(length: Int): Array[Byte] = {
+        length match {
+          case length if length > 0 =>
+            val obj = new Array[Byte](length)
+            dataStream.readFully(obj)
+            obj
+          case _ => null
+        }
       }
-    }
 
-    for (elem <- iter) {
-      elem match {
-        case (key, innerIter: Iterator[_]) =>
-          for (innerElem <- innerIter) {
-            writeElem(innerElem)
-          }
-          // Writes key which can be used as a boundary in group-aggregate
-          dataOut.writeByte('r')
-          writeElem(key)
-        case (key, value) =>
-          writeElem(key)
-          writeElem(value)
-        case _ =>
-          writeElem(elem)
+      private def readStringData(length: Int): String = {
+        length match {
+          case length if length > 0 =>
+            SerDe.readStringBytes(dataStream, length)
+          case _ => null
+        }
       }
-    }
-  }
-
-  /**
-   * Start a thread to write RDD data to the R process.
-   */
-  private def startStdinThread(
-      output: OutputStream,
-      iter: Iterator[_],
-      partitionIndex: Int): Unit = {
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty(BUFFER_SIZE.key,
-      BUFFER_SIZE.defaultValueString).toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
 
-    new Thread("writer for R") {
-      override def run(): Unit = {
+      /**
+       * Reads next object from the stream.
+       * When the stream reaches end of data, needs to process the following sections,
+       * and then returns null.
+       */
+      override protected def read(): OUT = {
         try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partitionIndex)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
-
-          dataOut.writeInt(numPartitions)
-          dataOut.writeInt(mode)
-
-          if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
+          val length = dataStream.readInt()
+
+          length match {
+            case SpecialLengths.TIMING_DATA =>
+              // Timing data from R worker
+              val boot = dataStream.readDouble - bootTime
+              val init = dataStream.readDouble
+              val broadcast = dataStream.readDouble
+              val input = dataStream.readDouble
+              val compute = dataStream.readDouble
+              val output = dataStream.readDouble
+              logInfo(
+                ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+                  "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+                  "total = %.3f s").format(
+                  boot,
+                  init,
+                  broadcast,
+                  input,
+                  compute,
+                  output,
+                  boot + init + broadcast + input + compute + output))
+              read()
+            case length if length > 0 =>
+              readData(length).asInstanceOf[OUT]
+            case length if length == 0 =>
+              // End of stream
+              eos = true
+              null.asInstanceOf[OUT]
           }
-
-          if (!iter.hasNext) {
-            dataOut.writeInt(0)
-          } else {
-            dataOut.writeInt(1)
-          }
-
-          val printOut = new PrintStream(stream)
-
-          writeData(dataOut, printOut, iter)
-
-          stream.flush()
         } catch {
-          // TODO: We should propagate this error to the task thread
-          case e: Exception =>
-            logError("R Writer thread got an exception", e)
-        } finally {
-          Try(output.close())
+          case eof: EOFException =>
+            throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
         }
       }
-    }.start()
-  }
-
-  private def read(): U = {
-    try {
-      val length = dataStream.readInt()
-
-      length match {
-        case SpecialLengths.TIMING_DATA =>
-          // Timing data from R worker
-          val boot = dataStream.readDouble - bootTime
-          val init = dataStream.readDouble
-          val broadcast = dataStream.readDouble
-          val input = dataStream.readDouble
-          val compute = dataStream.readDouble
-          val output = dataStream.readDouble
-          logInfo(
-            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
-              "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
-              "total = %.3f s").format(
-                boot,
-                init,
-                broadcast,
-                input,
-                compute,
-                output,
-                boot + init + broadcast + input + compute + output))
-          read()
-        case length if length >= 0 =>
-          readData(length).asInstanceOf[U]
-      }
-    } catch {
-      case eof: EOFException =>
-        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
     }
   }
 
-  private def readShuffledData(length: Int): (Int, Array[Byte]) = {
-    length match {
-      case length if length == 2 =>
-        val hashedKey = dataStream.readInt()
-        val contentPairsLength = dataStream.readInt()
-        val contentPairs = new Array[Byte](contentPairsLength)
-        dataStream.readFully(contentPairs)
-        (hashedKey, contentPairs)
-      case _ => null
-    }
-  }
-
-  protected def readByteArrayData(length: Int): Array[Byte] = {
-    length match {
-      case length if length > 0 =>
-        val obj = new Array[Byte](length)
-        dataStream.readFully(obj)
-        obj
-      case _ => null
-    }
-  }
-
-  private def readStringData(length: Int): String = {
-    length match {
-      case length if length > 0 =>
-        SerDe.readStringBytes(dataStream, length)
-      case _ => null
-    }
-  }
-}
-
-private[spark] object SpecialLengths {
-  val TIMING_DATA = -1
-}
-
-private[spark] object RRunnerModes {
-  val RDD = 0
-  val DATAFRAME_DAPPLY = 1
-  val DATAFRAME_GAPPLY = 2
-}
-
-private[spark] class BufferedStreamThread(
-    in: InputStream,
-    name: String,
-    errBufferSize: Int) extends Thread(name) with Logging {
-  val lines = new Array[String](errBufferSize)
-  var lineIdx = 0
-  override def run() {
-    for (line <- Source.fromInputStream(in).getLines) {
-      synchronized {
-        lines(lineIdx) = line
-        lineIdx = (lineIdx + 1) % errBufferSize
-      }
-      logInfo(line)
-    }
-  }
-
-  def getLines(): String = synchronized {
-    (0 until errBufferSize).filter { x =>
-      lines((x + lineIdx) % errBufferSize) != null
-    }.map { x =>
-      lines((x + lineIdx) % errBufferSize)
-    }.mkString("\n")
-  }
-}
-
-private[r] object RRunner {
-  // Because forking processes from Java is expensive, we prefer to launch
-  // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
-  // This daemon currently only works on UNIX-based systems now, so we should
-  // also fall back to launching workers (worker.R) directly.
-  private[this] var errThread: BufferedStreamThread = _
-  private[this] var daemonChannel: DataOutputStream = _
-
-  private lazy val authHelper = {
-    val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
-    new RAuthHelper(conf)
-  }
-
-  /**
-   * Start a thread to print the process's stderr to ours
-   */
-  private def startStdoutThread(proc: Process): BufferedStreamThread = {
-    val BUFFER_SIZE = 100
-    val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
-    thread.setDaemon(true)
-    thread.start()
-    thread
-  }
-
-  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
-    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
-    // but kept here for backward compatibility.
-    val sparkConf = SparkEnv.get.conf
-    var rCommand = sparkConf.get(SPARKR_COMMAND)
-    rCommand = sparkConf.get(R_COMMAND).orElse(Some(rCommand)).get
-
-    val rConnectionTimeout = sparkConf.get(R_BACKEND_CONNECTION_TIMEOUT)
-    val rOptions = "--vanilla"
-    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
-    val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
-    val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
-    // Unset the R_TESTS environment variable for workers.
-    // This is set by R CMD check as startup.Rs
-    // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
-    // and confuses worker script which tries to load a non-existent file
-    pb.environment().put("R_TESTS", "")
-    pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
-    pb.environment().put("SPARKR_WORKER_PORT", port.toString)
-    pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
-    pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory())
-    pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE")
-    pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret)
-    pb.redirectErrorStream(true)  // redirect stderr into stdout
-    val proc = pb.start()
-    val errThread = startStdoutThread(proc)
-    errThread
-  }
-
   /**
-   * ProcessBuilder used to launch worker R processes.
+   * Start a thread to write RDD data to the R process.
    */
-  def createRWorker(port: Int): BufferedStreamThread = {
-    val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
-    if (!Utils.isWindows && useDaemon) {
-      synchronized {
-        if (daemonChannel == null) {
-          // we expect one connections
-          val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-          val daemonPort = serverSocket.getLocalPort
-          errThread = createRProcess(daemonPort, "daemon.R")
-          // the socket used to send out the input of task
-          serverSocket.setSoTimeout(10000)
-          val sock = serverSocket.accept()
-          try {
-            authHelper.authClient(sock)
-            daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          } finally {
-            serverSocket.close()
+  protected override def newWriterThread(
+      output: OutputStream,
+      iter: Iterator[IN],
+      partitionIndex: Int): WriterThread = {
+    new WriterThread(output, iter, partitionIndex) {
+
+      /**
+       * Writes input data to the stream connected to the R worker.
+       */
+      override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        def writeElem(elem: Any): Unit = {
+          if (deserializer == SerializationFormats.BYTE) {
+            val elemArr = elem.asInstanceOf[Array[Byte]]
+            dataOut.writeInt(elemArr.length)
+            dataOut.write(elemArr)
+          } else if (deserializer == SerializationFormats.ROW) {
+            dataOut.write(elem.asInstanceOf[Array[Byte]])
+          } else if (deserializer == SerializationFormats.STRING) {
+            // write string(for StringRRDD)
+            // scalastyle:off println
+            printOut.println(elem)
+            // scalastyle:on println
           }
         }
-        try {
-          daemonChannel.writeInt(port)
-          daemonChannel.flush()
-        } catch {
-          case e: IOException =>
-            // daemon process died
-            daemonChannel.close()
-            daemonChannel = null
-            errThread = null
-            // fail the current task, retry by scheduler
-            throw e
+
+        for (elem <- iter) {
+          elem match {
+            case (key, innerIter: Iterator[_]) =>
+              for (innerElem <- innerIter) {
+                writeElem(innerElem)
+              }
+              // Writes key which can be used as a boundary in group-aggregate
+              dataOut.writeByte('r')
+              writeElem(key)
+            case (key, value) =>
+              writeElem(key)
+              writeElem(value)
+            case _ =>
+              writeElem(elem)
+          }
         }
-        errThread
       }
-    } else {
-      createRProcess(port, "worker.R")
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d298245..bedfa9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution
 
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+
 import scala.collection.JavaConverters._
 import scala.language.existentials
 
@@ -490,7 +492,7 @@ case class FlatMapGroupsInRExec(
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
       val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
       val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
-      val runner = new RRunner[Array[Byte]](
+      val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]](
         func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
         isDataFrame = true, colNames = inputSchema.fieldNames,
         mode = RRunnerModes.DATAFRAME_GAPPLY)
@@ -548,12 +550,22 @@ case class FlatMapGroupsInRWithArrowExec(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
-      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
-        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
 
-      val groupedByRKey = grouped.map { case (key, rowIter) =>
-        val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
-        (newKey, rowIter)
+      val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]]
+      val groupedByRKey: Iterator[Iterator[InternalRow]] =
+        grouped.map { case (key, rowIter) =>
+          keys.append(rowToRBytes(getKey(key).asInstanceOf[Row]))
+          rowIter
+        }
+
+      val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+        SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) {
+        protected override def bufferedWrite(
+            dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
+          super.bufferedWrite(dataOut)(writeFunc)
+          // Don't forget we're sending keys additionally.
+          keys.foreach(dataOut.write)
+        }
       }
 
       // The communication mechanism is as follows:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index ee1f2e3..a94cb0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -47,7 +47,7 @@ class ArrowRRunner(
     schema: StructType,
     timeZoneId: String,
     mode: Int)
-  extends RRunner[ColumnarBatch](
+  extends BaseRRunner[Iterator[InternalRow], ColumnarBatch](
     func,
     "arrow",
     "arrow",
@@ -58,60 +58,10 @@ class ArrowRRunner(
     schema.fieldNames,
     mode) {
 
-  // TODO: it needs to refactor to share the same code with RRunner, and have separate
-  // ArrowRRunners.
-  private val getNextBatch = {
-    if (mode == RRunnerModes.DATAFRAME_GAPPLY) {
-      // gapply
-      (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
-        val (key, nextBatch) = inputIterator
-          .asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]].next()
-        keys.append(key)
-        nextBatch
-      }
-    } else {
-      // dapply
-      (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
-        inputIterator
-          .asInstanceOf[Iterator[Iterator[InternalRow]]].next()
-      }
-    }
-  }
-
-  protected override def writeData(
-      dataOut: DataOutputStream,
-      printOut: PrintStream,
-      inputIterator: Iterator[_]): Unit = if (inputIterator.hasNext) {
-    val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-    val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-      "stdout writer for R", 0, Long.MaxValue)
-    val root = VectorSchemaRoot.create(arrowSchema, allocator)
+  protected def bufferedWrite(
+      dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
     val out = new ByteArrayOutputStream()
-    val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]]
-
-    Utils.tryWithSafeFinally {
-      val arrowWriter = ArrowWriter.create(root)
-      val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
-      writer.start()
-
-      while (inputIterator.hasNext) {
-        val nextBatch: Iterator[InternalRow] = getNextBatch(inputIterator, keys)
-
-        while (nextBatch.hasNext) {
-          arrowWriter.write(nextBatch.next())
-        }
-
-        arrowWriter.finish()
-        writer.writeBatch()
-        arrowWriter.reset()
-      }
-      writer.end()
-    } {
-      // Don't close root and allocator in TaskCompletionListener to prevent
-      // a race condition. See `ArrowPythonRunner`.
-      root.close()
-      allocator.close()
-    }
+    writeFunc(out)
 
     // Currently, there looks no way to read batch by batch by socket connection in R side,
     // See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at
@@ -119,13 +69,57 @@ class ArrowRRunner(
     val data = out.toByteArray
     dataOut.writeInt(data.length)
     dataOut.write(data)
+  }
 
-    keys.foreach(dataOut.write)
+  protected override def newWriterThread(
+      output: OutputStream,
+      inputIterator: Iterator[Iterator[InternalRow]],
+      partitionIndex: Int): WriterThread = {
+    new WriterThread(output, inputIterator, partitionIndex) {
+
+      /**
+       * Writes input data to the stream connected to the R worker.
+       */
+      override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+        if (inputIterator.hasNext) {
+          val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+          val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+            "stdout writer for R", 0, Long.MaxValue)
+          val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+          bufferedWrite(dataOut) { out =>
+            Utils.tryWithSafeFinally {
+              val arrowWriter = ArrowWriter.create(root)
+              val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
+              writer.start()
+
+              while (inputIterator.hasNext) {
+                val nextBatch: Iterator[InternalRow] = inputIterator.next()
+
+                while (nextBatch.hasNext) {
+                  arrowWriter.write(nextBatch.next())
+                }
+
+                arrowWriter.finish()
+                writer.writeBatch()
+                arrowWriter.reset()
+              }
+              writer.end()
+            } {
+              // Don't close root and allocator in TaskCompletionListener to prevent
+              // a race condition. See `ArrowPythonRunner`.
+              root.close()
+              allocator.close()
+            }
+          }
+        }
+      }
+    }
   }
 
   protected override def newReaderIterator(
-      dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[ColumnarBatch] = {
-    new Iterator[ColumnarBatch] {
+      dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = {
+    new ReaderIterator(dataStream, errThread) {
       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         "stdin reader for R", 0, Long.MaxValue)
 
@@ -141,29 +135,8 @@ class ArrowRRunner(
       }
 
       private var batchLoaded = true
-      private var nextObj: ColumnarBatch = _
-      private var eos = false
-
-      override def hasNext: Boolean = nextObj != null || {
-        if (!eos) {
-          nextObj = read()
-          hasNext
-        } else {
-          false
-        }
-      }
-
-      override def next(): ColumnarBatch = {
-        if (hasNext) {
-          val obj = nextObj
-          nextObj = null.asInstanceOf[ColumnarBatch]
-          obj
-        } else {
-          Iterator.empty.next()
-        }
-      }
 
-      private def read(): ColumnarBatch = try {
+      protected override def read(): ColumnarBatch = try {
         if (reader != null && batchLoaded) {
           batchLoaded = reader.loadNextBatch()
           if (batchLoaded) {
@@ -173,8 +146,8 @@ class ArrowRRunner(
           } else {
             reader.close(false)
             allocator.close()
-            eos = true
-            null
+            // Should read timing data after this.
+            read()
           }
         } else {
           dataStream.readInt() match {
@@ -202,7 +175,9 @@ class ArrowRRunner(
               // Likewise, there looks no way to send each batch in streaming format via socket
               // connection. See ARROW-4512.
               // So, it reads the whole Arrow streaming-formatted binary at once for now.
-              val in = new ByteArrayReadableSeekableByteChannel(readByteArrayData(length))
+              val buffer = new Array[Byte](length)
+              dataStream.readFully(buffer)
+              val in = new ByteArrayReadableSeekableByteChannel(buffer)
               reader = new ArrowStreamReader(in, allocator)
               root = reader.getVectorSchemaRoot
               vectors = root.getFieldVectors.asScala.map { vector =>
@@ -210,6 +185,7 @@ class ArrowRRunner(
               }.toArray[ColumnVector]
               read()
             case length if length == 0 =>
+              // End of stream
               eos = true
               null
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index a62016d..a3a4088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -51,7 +51,7 @@ case class MapPartitionsRWrapper(
       SerializationFormats.BYTE
     }
 
-    val runner = new RRunner[Array[Byte]](
+    val runner = new RRunner[Any, Array[Byte]](
       func, deserializer, serializer, packageNames, broadcastVars,
       isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
     // Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org