Posted to commits@spark.apache.org by gu...@apache.org on 2019/03/11 23:45:49 UTC
[spark] branch master updated: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3725b13 [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
3725b13 is described below
commit 3725b1324f731d57dc776c256bc1a100ec9e6cd0
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Tue Mar 12 08:45:29 2019 +0900
[SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
## What changes were proposed in this pull request?
This PR proposes to have one base R runner.
At a high level, the change is as follows.
Previously, `ArrowRRunner` inherited from `RRunner`:
```
└── RRunner
    └── ArrowRRunner
```
After this PR, there is one `BaseRRunner`, and both `ArrowRRunner` and `RRunner` inherit from it:
```
└── BaseRRunner
    ├── ArrowRRunner
    └── RRunner
```
This structure is consistent with the Python runners.
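For reference, the Python side already follows this shape: a single `BasePythonRunner` owns the skeleton and the concrete runners extend it (a rough sketch; the exact set of Python-side subclasses listed here is from memory and may not be exhaustive):
```
└── BasePythonRunner
    ├── PythonRunner
    ├── PythonUDFRunner
    └── ArrowPythonRunner
```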
In more detail, see below:
```scala
class BaseRRunner[IN, OUT] {
  def compute: Iterator[OUT] = {
    ...
    newWriterThread(...).start()
    ...
    newReaderIterator(...)
    ...
  }

  // Makes a thread that writes data from the JVM to the R process.
  abstract protected def newWriterThread(..., iter: Iterator[IN], ...): WriterThread

  // Makes an iterator that reads data from the R process into the JVM.
  abstract protected def newReaderIterator(...): ReaderIterator

  abstract class WriterThread(..., iter: Iterator[IN], ...) extends Thread {
    override def run(): Unit = {
      ...
      writeIteratorToStream(...)
      ...
    }

    // The actual logic that writes to the socket stream.
    abstract protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
  }

  abstract class ReaderIterator extends Iterator[OUT] {
    override def hasNext(): Boolean = {
      ...
      read(...)
      ...
    }

    override def next(): OUT = {
      ...
      hasNext()
      ...
    }

    // The actual logic that reads from the socket stream.
    abstract protected def read(...): OUT
  }
}
```
```scala
class [Arrow]RRunner extends BaseRRunner {
  override def newWriterThread(...) {
    new WriterThread(...) {
      override def writeIteratorToStream(...) {
        ...
      }
    }
  }

  override def newReaderIterator(...) {
    new ReaderIterator(...) {
      override def read(...) {
        ...
      }
    }
  }
}
```
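To make the division of responsibilities concrete, here is a self-contained toy sketch of the same template-method split, not Spark code: all names (`ToyBaseRunner`, `ToyIntRunner`, the in-process pipe standing in for the worker sockets) are invented for illustration. The base class owns `compute`, which starts the writer thread and hands back a reader iterator; a concrete runner only supplies how to write the input and how to read the output.
```scala
import java.io.{DataInputStream, DataOutputStream, PipedInputStream, PipedOutputStream}

abstract class ToyBaseRunner[IN, OUT] {
  // Template method, mirroring BaseRRunner.compute above:
  // start the writer thread, then return the reader iterator.
  def compute(input: Iterator[IN]): Iterator[OUT] = {
    val sink = new PipedOutputStream()
    val source = new PipedInputStream(sink)
    newWriterThread(new DataOutputStream(sink), input).start()
    newReaderIterator(new DataInputStream(source))
  }

  // Subclasses decide how input is written to the "worker" side.
  protected def newWriterThread(out: DataOutputStream, input: Iterator[IN]): Thread

  // Subclasses decide how output is read back.
  protected def newReaderIterator(in: DataInputStream): Iterator[OUT]
}

// A concrete runner that simply round-trips integers through the pipe.
class ToyIntRunner extends ToyBaseRunner[Int, Int] {
  override protected def newWriterThread(out: DataOutputStream, input: Iterator[Int]): Thread =
    new Thread("toy writer") {
      override def run(): Unit = {
        input.foreach(out.writeInt)
        out.writeInt(-1) // sentinel: end of stream
        out.close()
      }
    }

  override protected def newReaderIterator(in: DataInputStream): Iterator[Int] =
    new Iterator[Int] {
      private var nextObj = in.readInt()
      override def hasNext: Boolean = nextObj != -1
      override def next(): Int = { val v = nextObj; nextObj = in.readInt(); v }
    }
}

object ToyRunnerDemo extends App {
  // Prints List(1, 2, 3)
  println(new ToyIntRunner().compute(Iterator(1, 2, 3)).toList)
}
```
In the actual change, `RRunner` and `ArrowRRunner` play the role of `ToyIntRunner`, overriding `newWriterThread`/`newReaderIterator` over the two worker sockets instead of an in-process pipe.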
## How was this patch tested?
Manually tested, and existing tests should cover this change.
Closes #23977 from HyukjinKwon/SPARK-26923.
Authored-by: Hyukjin Kwon <gu...@apache.org>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../api/r/{RRunner.scala => BaseRRunner.scala} | 302 +++++--------
.../main/scala/org/apache/spark/api/r/RRDD.scala | 2 +-
.../scala/org/apache/spark/api/r/RRunner.scala | 478 +++++----------------
.../org/apache/spark/sql/execution/objects.scala | 24 +-
.../spark/sql/execution/r/ArrowRRunner.scala | 140 +++---
.../sql/execution/r/MapPartitionsRWrapper.scala | 2 +-
6 files changed, 309 insertions(+), 639 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
similarity index 55%
copy from core/src/main/scala/org/apache/spark/api/r/RRunner.scala
copy to core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
index 971d11f..f96c521 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/BaseRRunner.scala
@@ -34,31 +34,23 @@ import org.apache.spark.util.Utils
/**
* A helper class to run R UDFs in Spark.
*/
-private[spark] class RRunner[U](
+private[spark] abstract class BaseRRunner[IN, OUT](
func: Array[Byte],
deserializer: String,
serializer: String,
packageNames: Array[Byte],
broadcastVars: Array[Broadcast[Object]],
- numPartitions: Int = -1,
- isDataFrame: Boolean = false,
- colNames: Array[String] = null,
- mode: Int = RRunnerModes.RDD)
+ numPartitions: Int,
+ isDataFrame: Boolean,
+ colNames: Array[String],
+ mode: Int)
extends Logging {
protected var bootTime: Double = _
- private var dataStream: DataInputStream = _
- val readData = numPartitions match {
- case -1 =>
- serializer match {
- case SerializationFormats.STRING => readStringData _
- case _ => readByteArrayData _
- }
- case _ => readShuffledData _
- }
+ protected var dataStream: DataInputStream = _
def compute(
- inputIterator: Iterator[_],
- partitionIndex: Int): Iterator[U] = {
+ inputIterator: Iterator[IN],
+ partitionIndex: Int): Iterator[OUT] = {
// Timing start
bootTime = System.currentTimeMillis / 1000.0
@@ -68,7 +60,7 @@ private[spark] class RRunner[U](
// The stdout/stderr is shared by multiple tasks, because we use one daemon
// to launch child process as worker.
- val errThread = RRunner.createRWorker(listenPort)
+ val errThread = BaseRRunner.createRWorker(listenPort)
// We use two sockets to separate input and output, then it's easy to manage
// the lifecycle of them to avoid deadlock.
@@ -78,12 +70,12 @@ private[spark] class RRunner[U](
serverSocket.setSoTimeout(10000)
dataStream = try {
val inSocket = serverSocket.accept()
- RRunner.authHelper.authClient(inSocket)
- startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+ BaseRRunner.authHelper.authClient(inSocket)
+ newWriterThread(inSocket.getOutputStream(), inputIterator, partitionIndex).start()
// the socket used to receive the output of task
val outSocket = serverSocket.accept()
- RRunner.authHelper.authClient(outSocket)
+ BaseRRunner.authHelper.authClient(outSocket)
val inputStream = new BufferedInputStream(outSocket.getInputStream)
new DataInputStream(inputStream)
} finally {
@@ -98,197 +90,127 @@ private[spark] class RRunner[U](
}
}
+ /**
+ * Creates an iterator that reads data from R process.
+ */
protected def newReaderIterator(
- dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = {
- new Iterator[U] {
- def next(): U = {
- val obj = _nextObj
- if (hasNext()) {
- _nextObj = read()
- }
- obj
- }
-
- private var _nextObj = read()
+ dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator
- def hasNext(): Boolean = {
- val hasMore = _nextObj != null
- if (!hasMore) {
- dataStream.close()
- }
- hasMore
+ /**
+ * Start a thread to write RDD data to the R process.
+ */
+ protected def newWriterThread(
+ output: OutputStream,
+ iter: Iterator[IN],
+ partitionIndex: Int): WriterThread
+
+ abstract class ReaderIterator(
+ stream: DataInputStream,
+ errThread: BufferedStreamThread)
+ extends Iterator[OUT] {
+
+ private var nextObj: OUT = _
+ // eos should be marked as true when the stream is ended.
+ protected var eos = false
+
+ override def hasNext: Boolean = nextObj != null || {
+ if (!eos) {
+ nextObj = read()
+ hasNext
+ } else {
+ false
}
}
- }
- protected def writeData(
- dataOut: DataOutputStream,
- printOut: PrintStream,
- iter: Iterator[_]): Unit = {
- def writeElem(elem: Any): Unit = {
- if (deserializer == SerializationFormats.BYTE) {
- val elemArr = elem.asInstanceOf[Array[Byte]]
- dataOut.writeInt(elemArr.length)
- dataOut.write(elemArr)
- } else if (deserializer == SerializationFormats.ROW) {
- dataOut.write(elem.asInstanceOf[Array[Byte]])
- } else if (deserializer == SerializationFormats.STRING) {
- // write string(for StringRRDD)
- // scalastyle:off println
- printOut.println(elem)
- // scalastyle:on println
+ override def next(): OUT = {
+ if (hasNext) {
+ val obj = nextObj
+ nextObj = null.asInstanceOf[OUT]
+ obj
+ } else {
+ Iterator.empty.next()
}
}
- for (elem <- iter) {
- elem match {
- case (key, innerIter: Iterator[_]) =>
- for (innerElem <- innerIter) {
- writeElem(innerElem)
- }
- // Writes key which can be used as a boundary in group-aggregate
- dataOut.writeByte('r')
- writeElem(key)
- case (key, value) =>
- writeElem(key)
- writeElem(value)
- case _ =>
- writeElem(elem)
- }
- }
+ /**
+ * Reads next object from the stream.
+ * When the stream reaches end of data, needs to process the following sections,
+ * and then returns null.
+ */
+ protected def read(): OUT
}
/**
- * Start a thread to write RDD data to the R process.
+ * The thread responsible for writing the iterator to the R process.
*/
- private def startStdinThread(
+ abstract class WriterThread(
output: OutputStream,
- iter: Iterator[_],
- partitionIndex: Int): Unit = {
- val env = SparkEnv.get
- val taskContext = TaskContext.get()
- val bufferSize = System.getProperty(BUFFER_SIZE.key,
- BUFFER_SIZE.defaultValueString).toInt
- val stream = new BufferedOutputStream(output, bufferSize)
-
- new Thread("writer for R") {
- override def run(): Unit = {
- try {
- SparkEnv.set(env)
- TaskContext.setTaskContext(taskContext)
- val dataOut = new DataOutputStream(stream)
- dataOut.writeInt(partitionIndex)
-
- SerDe.writeString(dataOut, deserializer)
- SerDe.writeString(dataOut, serializer)
-
- dataOut.writeInt(packageNames.length)
- dataOut.write(packageNames)
-
- dataOut.writeInt(func.length)
- dataOut.write(func)
-
- dataOut.writeInt(broadcastVars.length)
- broadcastVars.foreach { broadcast =>
- // TODO(shivaram): Read a Long in R to avoid this cast
- dataOut.writeInt(broadcast.id.toInt)
- // TODO: Pass a byte array from R to avoid this cast ?
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
- dataOut.writeInt(broadcastByteArr.length)
- dataOut.write(broadcastByteArr)
- }
+ iter: Iterator[IN],
+ partitionIndex: Int)
+ extends Thread("writer for R") {
- dataOut.writeInt(numPartitions)
- dataOut.writeInt(mode)
-
- if (isDataFrame) {
- SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
- }
-
- if (!iter.hasNext) {
- dataOut.writeInt(0)
- } else {
- dataOut.writeInt(1)
- }
-
- val printOut = new PrintStream(stream)
+ private val env = SparkEnv.get
+ private val taskContext = TaskContext.get()
+ private val bufferSize = System.getProperty(BUFFER_SIZE.key,
+ BUFFER_SIZE.defaultValueString).toInt
+ private val stream = new BufferedOutputStream(output, bufferSize)
+ protected lazy val dataOut = new DataOutputStream(stream)
+ protected lazy val printOut = new PrintStream(stream)
+
+ override def run(): Unit = {
+ try {
+ SparkEnv.set(env)
+ TaskContext.setTaskContext(taskContext)
+ dataOut.writeInt(partitionIndex)
+
+ SerDe.writeString(dataOut, deserializer)
+ SerDe.writeString(dataOut, serializer)
+
+ dataOut.writeInt(packageNames.length)
+ dataOut.write(packageNames)
+
+ dataOut.writeInt(func.length)
+ dataOut.write(func)
+
+ dataOut.writeInt(broadcastVars.length)
+ broadcastVars.foreach { broadcast =>
+ // TODO(shivaram): Read a Long in R to avoid this cast
+ dataOut.writeInt(broadcast.id.toInt)
+ // TODO: Pass a byte array from R to avoid this cast ?
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(broadcastByteArr.length)
+ dataOut.write(broadcastByteArr)
+ }
- writeData(dataOut, printOut, iter)
+ dataOut.writeInt(numPartitions)
+ dataOut.writeInt(mode)
- stream.flush()
- } catch {
- // TODO: We should propagate this error to the task thread
- case e: Exception =>
- logError("R Writer thread got an exception", e)
- } finally {
- Try(output.close())
+ if (isDataFrame) {
+ SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
}
- }
- }.start()
- }
- private def read(): U = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case SpecialLengths.TIMING_DATA =>
- // Timing data from R worker
- val boot = dataStream.readDouble - bootTime
- val init = dataStream.readDouble
- val broadcast = dataStream.readDouble
- val input = dataStream.readDouble
- val compute = dataStream.readDouble
- val output = dataStream.readDouble
- logInfo(
- ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
- "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
- "total = %.3f s").format(
- boot,
- init,
- broadcast,
- input,
- compute,
- output,
- boot + init + broadcast + input + compute + output))
- read()
- case length if length >= 0 =>
- readData(length).asInstanceOf[U]
- }
- } catch {
- case eof: EOFException =>
- throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
- }
- }
+ if (!iter.hasNext) {
+ dataOut.writeInt(0)
+ } else {
+ dataOut.writeInt(1)
+ }
- private def readShuffledData(length: Int): (Int, Array[Byte]) = {
- length match {
- case length if length == 2 =>
- val hashedKey = dataStream.readInt()
- val contentPairsLength = dataStream.readInt()
- val contentPairs = new Array[Byte](contentPairsLength)
- dataStream.readFully(contentPairs)
- (hashedKey, contentPairs)
- case _ => null
- }
- }
+ writeIteratorToStream(dataOut)
- protected def readByteArrayData(length: Int): Array[Byte] = {
- length match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- dataStream.readFully(obj)
- obj
- case _ => null
+ stream.flush()
+ } catch {
+ // TODO: We should propagate this error to the task thread
+ case e: Exception =>
+ logError("R Writer thread got an exception", e)
+ } finally {
+ Try(output.close())
+ }
}
- }
- private def readStringData(length: Int): String = {
- length match {
- case length if length > 0 =>
- SerDe.readStringBytes(dataStream, length)
- case _ => null
- }
+ /**
+ * Writes input data to the stream connected to the R worker.
+ */
+ protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
}
}
@@ -327,7 +249,7 @@ private[spark] class BufferedStreamThread(
}
}
-private[r] object RRunner {
+private[r] object BaseRRunner {
// Because forking processes from Java is expensive, we prefer to launch
// a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
// This daemon currently only works on UNIX-based systems now, so we should
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 4a59c3e..07f8405 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -43,7 +43,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
override def getPartitions: Array[Partition] = parent.partitions
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
- val runner = new RRunner[U](
+ val runner = new RRunner[T, U](
func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
// The parent may be also an RRDD, so we should launch it first.
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 971d11f..0327386 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -18,23 +18,14 @@
package org.apache.spark.api.r
import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.Arrays
-
-import scala.io.Source
-import scala.util.Try
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.R._
-import org.apache.spark.util.Utils
/**
* A helper class to run R UDFs in Spark.
*/
-private[spark] class RRunner[U](
+private[spark] class RRunner[IN, OUT](
func: Array[Byte],
deserializer: String,
serializer: String,
@@ -44,380 +35,149 @@ private[spark] class RRunner[U](
isDataFrame: Boolean = false,
colNames: Array[String] = null,
mode: Int = RRunnerModes.RDD)
- extends Logging {
- protected var bootTime: Double = _
- private var dataStream: DataInputStream = _
- val readData = numPartitions match {
- case -1 =>
- serializer match {
- case SerializationFormats.STRING => readStringData _
- case _ => readByteArrayData _
- }
- case _ => readShuffledData _
- }
-
- def compute(
- inputIterator: Iterator[_],
- partitionIndex: Int): Iterator[U] = {
- // Timing start
- bootTime = System.currentTimeMillis / 1000.0
-
- // we expect two connections
- val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
- val listenPort = serverSocket.getLocalPort()
-
- // The stdout/stderr is shared by multiple tasks, because we use one daemon
- // to launch child process as worker.
- val errThread = RRunner.createRWorker(listenPort)
-
- // We use two sockets to separate input and output, then it's easy to manage
- // the lifecycle of them to avoid deadlock.
- // TODO: optimize it to use one socket
-
- // the socket used to send out the input of task
- serverSocket.setSoTimeout(10000)
- dataStream = try {
- val inSocket = serverSocket.accept()
- RRunner.authHelper.authClient(inSocket)
- startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
- // the socket used to receive the output of task
- val outSocket = serverSocket.accept()
- RRunner.authHelper.authClient(outSocket)
- val inputStream = new BufferedInputStream(outSocket.getInputStream)
- new DataInputStream(inputStream)
- } finally {
- serverSocket.close()
- }
-
- try {
- newReaderIterator(dataStream, errThread)
- } catch {
- case e: Exception =>
- throw new SparkException("R computation failed with\n " + errThread.getLines(), e)
- }
- }
+ extends BaseRRunner[IN, OUT](
+ func,
+ deserializer,
+ serializer,
+ packageNames,
+ broadcastVars,
+ numPartitions,
+ isDataFrame,
+ colNames,
+ mode) {
protected def newReaderIterator(
- dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = {
- new Iterator[U] {
- def next(): U = {
- val obj = _nextObj
- if (hasNext()) {
- _nextObj = read()
- }
- obj
+ dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = {
+ new ReaderIterator(dataStream, errThread) {
+ private val readData = numPartitions match {
+ case -1 =>
+ serializer match {
+ case SerializationFormats.STRING => readStringData _
+ case _ => readByteArrayData _
+ }
+ case _ => readShuffledData _
}
- private var _nextObj = read()
-
- def hasNext(): Boolean = {
- val hasMore = _nextObj != null
- if (!hasMore) {
- dataStream.close()
+ private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null
}
- hasMore
}
- }
- }
- protected def writeData(
- dataOut: DataOutputStream,
- printOut: PrintStream,
- iter: Iterator[_]): Unit = {
- def writeElem(elem: Any): Unit = {
- if (deserializer == SerializationFormats.BYTE) {
- val elemArr = elem.asInstanceOf[Array[Byte]]
- dataOut.writeInt(elemArr.length)
- dataOut.write(elemArr)
- } else if (deserializer == SerializationFormats.ROW) {
- dataOut.write(elem.asInstanceOf[Array[Byte]])
- } else if (deserializer == SerializationFormats.STRING) {
- // write string(for StringRRDD)
- // scalastyle:off println
- printOut.println(elem)
- // scalastyle:on println
+ private def readByteArrayData(length: Int): Array[Byte] = {
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj)
+ obj
+ case _ => null
+ }
}
- }
- for (elem <- iter) {
- elem match {
- case (key, innerIter: Iterator[_]) =>
- for (innerElem <- innerIter) {
- writeElem(innerElem)
- }
- // Writes key which can be used as a boundary in group-aggregate
- dataOut.writeByte('r')
- writeElem(key)
- case (key, value) =>
- writeElem(key)
- writeElem(value)
- case _ =>
- writeElem(elem)
+ private def readStringData(length: Int): String = {
+ length match {
+ case length if length > 0 =>
+ SerDe.readStringBytes(dataStream, length)
+ case _ => null
+ }
}
- }
- }
-
- /**
- * Start a thread to write RDD data to the R process.
- */
- private def startStdinThread(
- output: OutputStream,
- iter: Iterator[_],
- partitionIndex: Int): Unit = {
- val env = SparkEnv.get
- val taskContext = TaskContext.get()
- val bufferSize = System.getProperty(BUFFER_SIZE.key,
- BUFFER_SIZE.defaultValueString).toInt
- val stream = new BufferedOutputStream(output, bufferSize)
- new Thread("writer for R") {
- override def run(): Unit = {
+ /**
+ * Reads next object from the stream.
+ * When the stream reaches end of data, needs to process the following sections,
+ * and then returns null.
+ */
+ override protected def read(): OUT = {
try {
- SparkEnv.set(env)
- TaskContext.setTaskContext(taskContext)
- val dataOut = new DataOutputStream(stream)
- dataOut.writeInt(partitionIndex)
-
- SerDe.writeString(dataOut, deserializer)
- SerDe.writeString(dataOut, serializer)
-
- dataOut.writeInt(packageNames.length)
- dataOut.write(packageNames)
-
- dataOut.writeInt(func.length)
- dataOut.write(func)
-
- dataOut.writeInt(broadcastVars.length)
- broadcastVars.foreach { broadcast =>
- // TODO(shivaram): Read a Long in R to avoid this cast
- dataOut.writeInt(broadcast.id.toInt)
- // TODO: Pass a byte array from R to avoid this cast ?
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
- dataOut.writeInt(broadcastByteArr.length)
- dataOut.write(broadcastByteArr)
- }
-
- dataOut.writeInt(numPartitions)
- dataOut.writeInt(mode)
-
- if (isDataFrame) {
- SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
+ val length = dataStream.readInt()
+
+ length match {
+ case SpecialLengths.TIMING_DATA =>
+ // Timing data from R worker
+ val boot = dataStream.readDouble - bootTime
+ val init = dataStream.readDouble
+ val broadcast = dataStream.readDouble
+ val input = dataStream.readDouble
+ val compute = dataStream.readDouble
+ val output = dataStream.readDouble
+ logInfo(
+ ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+ "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+ "total = %.3f s").format(
+ boot,
+ init,
+ broadcast,
+ input,
+ compute,
+ output,
+ boot + init + broadcast + input + compute + output))
+ read()
+ case length if length > 0 =>
+ readData(length).asInstanceOf[OUT]
+ case length if length == 0 =>
+ // End of stream
+ eos = true
+ null.asInstanceOf[OUT]
}
-
- if (!iter.hasNext) {
- dataOut.writeInt(0)
- } else {
- dataOut.writeInt(1)
- }
-
- val printOut = new PrintStream(stream)
-
- writeData(dataOut, printOut, iter)
-
- stream.flush()
} catch {
- // TODO: We should propagate this error to the task thread
- case e: Exception =>
- logError("R Writer thread got an exception", e)
- } finally {
- Try(output.close())
+ case eof: EOFException =>
+ throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
}
}
- }.start()
- }
-
- private def read(): U = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case SpecialLengths.TIMING_DATA =>
- // Timing data from R worker
- val boot = dataStream.readDouble - bootTime
- val init = dataStream.readDouble
- val broadcast = dataStream.readDouble
- val input = dataStream.readDouble
- val compute = dataStream.readDouble
- val output = dataStream.readDouble
- logInfo(
- ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
- "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
- "total = %.3f s").format(
- boot,
- init,
- broadcast,
- input,
- compute,
- output,
- boot + init + broadcast + input + compute + output))
- read()
- case length if length >= 0 =>
- readData(length).asInstanceOf[U]
- }
- } catch {
- case eof: EOFException =>
- throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
}
}
- private def readShuffledData(length: Int): (Int, Array[Byte]) = {
- length match {
- case length if length == 2 =>
- val hashedKey = dataStream.readInt()
- val contentPairsLength = dataStream.readInt()
- val contentPairs = new Array[Byte](contentPairsLength)
- dataStream.readFully(contentPairs)
- (hashedKey, contentPairs)
- case _ => null
- }
- }
-
- protected def readByteArrayData(length: Int): Array[Byte] = {
- length match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- dataStream.readFully(obj)
- obj
- case _ => null
- }
- }
-
- private def readStringData(length: Int): String = {
- length match {
- case length if length > 0 =>
- SerDe.readStringBytes(dataStream, length)
- case _ => null
- }
- }
-}
-
-private[spark] object SpecialLengths {
- val TIMING_DATA = -1
-}
-
-private[spark] object RRunnerModes {
- val RDD = 0
- val DATAFRAME_DAPPLY = 1
- val DATAFRAME_GAPPLY = 2
-}
-
-private[spark] class BufferedStreamThread(
- in: InputStream,
- name: String,
- errBufferSize: Int) extends Thread(name) with Logging {
- val lines = new Array[String](errBufferSize)
- var lineIdx = 0
- override def run() {
- for (line <- Source.fromInputStream(in).getLines) {
- synchronized {
- lines(lineIdx) = line
- lineIdx = (lineIdx + 1) % errBufferSize
- }
- logInfo(line)
- }
- }
-
- def getLines(): String = synchronized {
- (0 until errBufferSize).filter { x =>
- lines((x + lineIdx) % errBufferSize) != null
- }.map { x =>
- lines((x + lineIdx) % errBufferSize)
- }.mkString("\n")
- }
-}
-
-private[r] object RRunner {
- // Because forking processes from Java is expensive, we prefer to launch
- // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
- // This daemon currently only works on UNIX-based systems now, so we should
- // also fall back to launching workers (worker.R) directly.
- private[this] var errThread: BufferedStreamThread = _
- private[this] var daemonChannel: DataOutputStream = _
-
- private lazy val authHelper = {
- val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
- new RAuthHelper(conf)
- }
-
- /**
- * Start a thread to print the process's stderr to ours
- */
- private def startStdoutThread(proc: Process): BufferedStreamThread = {
- val BUFFER_SIZE = 100
- val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
- thread.setDaemon(true)
- thread.start()
- thread
- }
-
- private def createRProcess(port: Int, script: String): BufferedStreamThread = {
- // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
- // but kept here for backward compatibility.
- val sparkConf = SparkEnv.get.conf
- var rCommand = sparkConf.get(SPARKR_COMMAND)
- rCommand = sparkConf.get(R_COMMAND).orElse(Some(rCommand)).get
-
- val rConnectionTimeout = sparkConf.get(R_BACKEND_CONNECTION_TIMEOUT)
- val rOptions = "--vanilla"
- val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
- val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
- val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
- // Unset the R_TESTS environment variable for workers.
- // This is set by R CMD check as startup.Rs
- // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
- // and confuses worker script which tries to load a non-existent file
- pb.environment().put("R_TESTS", "")
- pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
- pb.environment().put("SPARKR_WORKER_PORT", port.toString)
- pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
- pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory())
- pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE")
- pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret)
- pb.redirectErrorStream(true) // redirect stderr into stdout
- val proc = pb.start()
- val errThread = startStdoutThread(proc)
- errThread
- }
-
/**
- * ProcessBuilder used to launch worker R processes.
+ * Start a thread to write RDD data to the R process.
*/
- def createRWorker(port: Int): BufferedStreamThread = {
- val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
- if (!Utils.isWindows && useDaemon) {
- synchronized {
- if (daemonChannel == null) {
- // we expect one connections
- val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
- val daemonPort = serverSocket.getLocalPort
- errThread = createRProcess(daemonPort, "daemon.R")
- // the socket used to send out the input of task
- serverSocket.setSoTimeout(10000)
- val sock = serverSocket.accept()
- try {
- authHelper.authClient(sock)
- daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
- } finally {
- serverSocket.close()
+ protected override def newWriterThread(
+ output: OutputStream,
+ iter: Iterator[IN],
+ partitionIndex: Int): WriterThread = {
+ new WriterThread(output, iter, partitionIndex) {
+
+ /**
+ * Writes input data to the stream connected to the R worker.
+ */
+ override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+ def writeElem(elem: Any): Unit = {
+ if (deserializer == SerializationFormats.BYTE) {
+ val elemArr = elem.asInstanceOf[Array[Byte]]
+ dataOut.writeInt(elemArr.length)
+ dataOut.write(elemArr)
+ } else if (deserializer == SerializationFormats.ROW) {
+ dataOut.write(elem.asInstanceOf[Array[Byte]])
+ } else if (deserializer == SerializationFormats.STRING) {
+ // write string(for StringRRDD)
+ // scalastyle:off println
+ printOut.println(elem)
+ // scalastyle:on println
}
}
- try {
- daemonChannel.writeInt(port)
- daemonChannel.flush()
- } catch {
- case e: IOException =>
- // daemon process died
- daemonChannel.close()
- daemonChannel = null
- errThread = null
- // fail the current task, retry by scheduler
- throw e
+
+ for (elem <- iter) {
+ elem match {
+ case (key, innerIter: Iterator[_]) =>
+ for (innerElem <- innerIter) {
+ writeElem(innerElem)
+ }
+ // Writes key which can be used as a boundary in group-aggregate
+ dataOut.writeByte('r')
+ writeElem(key)
+ case (key, value) =>
+ writeElem(key)
+ writeElem(value)
+ case _ =>
+ writeElem(elem)
+ }
}
- errThread
}
- } else {
- createRProcess(port, "worker.R")
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index d298245..bedfa9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -490,7 +492,7 @@ case class FlatMapGroupsInRExec(
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
- val runner = new RRunner[Array[Byte]](
+ val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]](
func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
isDataFrame = true, colNames = inputSchema.fieldNames,
mode = RRunnerModes.DATAFRAME_GAPPLY)
@@ -548,12 +550,22 @@ case class FlatMapGroupsInRWithArrowExec(
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
- val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
- SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY)
- val groupedByRKey = grouped.map { case (key, rowIter) =>
- val newKey = rowToRBytes(getKey(key).asInstanceOf[Row])
- (newKey, rowIter)
+ val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]]
+ val groupedByRKey: Iterator[Iterator[InternalRow]] =
+ grouped.map { case (key, rowIter) =>
+ keys.append(rowToRBytes(getKey(key).asInstanceOf[Row]))
+ rowIter
+ }
+
+ val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
+ SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_GAPPLY) {
+ protected override def bufferedWrite(
+ dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
+ super.bufferedWrite(dataOut)(writeFunc)
+ // Don't forget we're sending keys additionally.
+ keys.foreach(dataOut.write)
+ }
}
// The communication mechanism is as follows:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
index ee1f2e3..a94cb0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala
@@ -47,7 +47,7 @@ class ArrowRRunner(
schema: StructType,
timeZoneId: String,
mode: Int)
- extends RRunner[ColumnarBatch](
+ extends BaseRRunner[Iterator[InternalRow], ColumnarBatch](
func,
"arrow",
"arrow",
@@ -58,60 +58,10 @@ class ArrowRRunner(
schema.fieldNames,
mode) {
- // TODO: it needs to refactor to share the same code with RRunner, and have separate
- // ArrowRRunners.
- private val getNextBatch = {
- if (mode == RRunnerModes.DATAFRAME_GAPPLY) {
- // gapply
- (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
- val (key, nextBatch) = inputIterator
- .asInstanceOf[Iterator[(Array[Byte], Iterator[InternalRow])]].next()
- keys.append(key)
- nextBatch
- }
- } else {
- // dapply
- (inputIterator: Iterator[_], keys: collection.mutable.ArrayBuffer[Array[Byte]]) => {
- inputIterator
- .asInstanceOf[Iterator[Iterator[InternalRow]]].next()
- }
- }
- }
-
- protected override def writeData(
- dataOut: DataOutputStream,
- printOut: PrintStream,
- inputIterator: Iterator[_]): Unit = if (inputIterator.hasNext) {
- val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
- val allocator = ArrowUtils.rootAllocator.newChildAllocator(
- "stdout writer for R", 0, Long.MaxValue)
- val root = VectorSchemaRoot.create(arrowSchema, allocator)
+ protected def bufferedWrite(
+ dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = {
val out = new ByteArrayOutputStream()
- val keys = collection.mutable.ArrayBuffer.empty[Array[Byte]]
-
- Utils.tryWithSafeFinally {
- val arrowWriter = ArrowWriter.create(root)
- val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
- writer.start()
-
- while (inputIterator.hasNext) {
- val nextBatch: Iterator[InternalRow] = getNextBatch(inputIterator, keys)
-
- while (nextBatch.hasNext) {
- arrowWriter.write(nextBatch.next())
- }
-
- arrowWriter.finish()
- writer.writeBatch()
- arrowWriter.reset()
- }
- writer.end()
- } {
- // Don't close root and allocator in TaskCompletionListener to prevent
- // a race condition. See `ArrowPythonRunner`.
- root.close()
- allocator.close()
- }
+ writeFunc(out)
// Currently, there looks no way to read batch by batch by socket connection in R side,
// See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at
@@ -119,13 +69,57 @@ class ArrowRRunner(
val data = out.toByteArray
dataOut.writeInt(data.length)
dataOut.write(data)
+ }
- keys.foreach(dataOut.write)
+ protected override def newWriterThread(
+ output: OutputStream,
+ inputIterator: Iterator[Iterator[InternalRow]],
+ partitionIndex: Int): WriterThread = {
+ new WriterThread(output, inputIterator, partitionIndex) {
+
+ /**
+ * Writes input data to the stream connected to the R worker.
+ */
+ override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
+ if (inputIterator.hasNext) {
+ val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+ "stdout writer for R", 0, Long.MaxValue)
+ val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+ bufferedWrite(dataOut) { out =>
+ Utils.tryWithSafeFinally {
+ val arrowWriter = ArrowWriter.create(root)
+ val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out))
+ writer.start()
+
+ while (inputIterator.hasNext) {
+ val nextBatch: Iterator[InternalRow] = inputIterator.next()
+
+ while (nextBatch.hasNext) {
+ arrowWriter.write(nextBatch.next())
+ }
+
+ arrowWriter.finish()
+ writer.writeBatch()
+ arrowWriter.reset()
+ }
+ writer.end()
+ } {
+ // Don't close root and allocator in TaskCompletionListener to prevent
+ // a race condition. See `ArrowPythonRunner`.
+ root.close()
+ allocator.close()
+ }
+ }
+ }
+ }
+ }
}
protected override def newReaderIterator(
- dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[ColumnarBatch] = {
- new Iterator[ColumnarBatch] {
+ dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = {
+ new ReaderIterator(dataStream, errThread) {
private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
"stdin reader for R", 0, Long.MaxValue)
@@ -141,29 +135,8 @@ class ArrowRRunner(
}
private var batchLoaded = true
- private var nextObj: ColumnarBatch = _
- private var eos = false
-
- override def hasNext: Boolean = nextObj != null || {
- if (!eos) {
- nextObj = read()
- hasNext
- } else {
- false
- }
- }
-
- override def next(): ColumnarBatch = {
- if (hasNext) {
- val obj = nextObj
- nextObj = null.asInstanceOf[ColumnarBatch]
- obj
- } else {
- Iterator.empty.next()
- }
- }
- private def read(): ColumnarBatch = try {
+ protected override def read(): ColumnarBatch = try {
if (reader != null && batchLoaded) {
batchLoaded = reader.loadNextBatch()
if (batchLoaded) {
@@ -173,8 +146,8 @@ class ArrowRRunner(
} else {
reader.close(false)
allocator.close()
- eos = true
- null
+ // Should read timing data after this.
+ read()
}
} else {
dataStream.readInt() match {
@@ -202,7 +175,9 @@ class ArrowRRunner(
// Likewise, there looks no way to send each batch in streaming format via socket
// connection. See ARROW-4512.
// So, it reads the whole Arrow streaming-formatted binary at once for now.
- val in = new ByteArrayReadableSeekableByteChannel(readByteArrayData(length))
+ val buffer = new Array[Byte](length)
+ dataStream.readFully(buffer)
+ val in = new ByteArrayReadableSeekableByteChannel(buffer)
reader = new ArrowStreamReader(in, allocator)
root = reader.getVectorSchemaRoot
vectors = root.getFieldVectors.asScala.map { vector =>
@@ -210,6 +185,7 @@ class ArrowRRunner(
}.toArray[ColumnVector]
read()
case length if length == 0 =>
+ // End of stream
eos = true
null
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
index a62016d..a3a4088 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/MapPartitionsRWrapper.scala
@@ -51,7 +51,7 @@ case class MapPartitionsRWrapper(
SerializationFormats.BYTE
}
- val runner = new RRunner[Array[Byte]](
+ val runner = new RRunner[Any, Array[Byte]](
func, deserializer, serializer, packageNames, broadcastVars,
isDataFrame = true, colNames = colNames, mode = RRunnerModes.DATAFRAME_DAPPLY)
// Partition index is ignored. Dataset has no support for mapPartitionsWithIndex.