You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2019/03/07 10:57:04 UTC

[GitHub] [spark] HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner

HyukjinKwon commented on a change in pull request #23977: [SPARK-26923][SQL][R] Refactor ArrowRRunner and RRunner to share one BaseRRunner
URL: https://github.com/apache/spark/pull/23977#discussion_r263331923
 
 

 ##########
 File path: core/src/main/scala/org/apache/spark/api/r/RRunner.scala
 ##########
 @@ -44,380 +35,149 @@ private[spark] class RRunner[U](
     isDataFrame: Boolean = false,
     colNames: Array[String] = null,
     mode: Int = RRunnerModes.RDD)
-  extends Logging {
-  protected var bootTime: Double = _
-  private var dataStream: DataInputStream = _
-  val readData = numPartitions match {
-    case -1 =>
-      serializer match {
-        case SerializationFormats.STRING => readStringData _
-        case _ => readByteArrayData _
-      }
-    case _ => readShuffledData _
-  }
-
-  def compute(
-      inputIterator: Iterator[_],
-      partitionIndex: Int): Iterator[U] = {
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
-
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
-    val errThread = RRunner.createRWorker(listenPort)
-
-    // We use two sockets to separate input and output, then it's easy to manage
-    // the lifecycle of them to avoid deadlock.
-    // TODO: optimize it to use one socket
-
-    // the socket used to send out the input of task
-    serverSocket.setSoTimeout(10000)
-    dataStream = try {
-      val inSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(inSocket)
-      startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
-      // the socket used to receive the output of task
-      val outSocket = serverSocket.accept()
-      RRunner.authHelper.authClient(outSocket)
-      val inputStream = new BufferedInputStream(outSocket.getInputStream)
-      new DataInputStream(inputStream)
-    } finally {
-      serverSocket.close()
-    }
-
-    try {
-      newReaderIterator(dataStream, errThread)
-    } catch {
-      case e: Exception =>
-        throw new SparkException("R computation failed with\n " + errThread.getLines(), e)
-    }
-  }
+  extends BaseRRunner[IN, OUT](
+    func,
+    deserializer,
+    serializer,
+    packageNames,
+    broadcastVars,
+    numPartitions,
+    isDataFrame,
+    colNames,
+    mode) {
 
   protected def newReaderIterator(
-      dataStream: DataInputStream, errThread: BufferedStreamThread): Iterator[U] = {
-    new Iterator[U] {
-      def next(): U = {
-        val obj = _nextObj
-        if (hasNext()) {
-          _nextObj = read()
-        }
-        obj
+      dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = {
+    new ReaderIterator(dataStream, errThread) {
+      private val readData = numPartitions match {
+        case -1 =>
+          serializer match {
+            case SerializationFormats.STRING => readStringData _
+            case _ => readByteArrayData _
+          }
+        case _ => readShuffledData _
       }
 
-      private var _nextObj = read()
-
-      def hasNext(): Boolean = {
-        val hasMore = _nextObj != null
-        if (!hasMore) {
-          dataStream.close()
+      private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+        length match {
+          case length if length == 2 =>
+            val hashedKey = dataStream.readInt()
+            val contentPairsLength = dataStream.readInt()
+            val contentPairs = new Array[Byte](contentPairsLength)
+            dataStream.readFully(contentPairs)
+            (hashedKey, contentPairs)
+          case _ => null
         }
-        hasMore
       }
-    }
-  }
 
-  protected def writeData(
-      dataOut: DataOutputStream,
-      printOut: PrintStream,
-      iter: Iterator[_]): Unit = {
-    def writeElem(elem: Any): Unit = {
-      if (deserializer == SerializationFormats.BYTE) {
-        val elemArr = elem.asInstanceOf[Array[Byte]]
-        dataOut.writeInt(elemArr.length)
-        dataOut.write(elemArr)
-      } else if (deserializer == SerializationFormats.ROW) {
-        dataOut.write(elem.asInstanceOf[Array[Byte]])
-      } else if (deserializer == SerializationFormats.STRING) {
-        // write string(for StringRRDD)
-        // scalastyle:off println
-        printOut.println(elem)
-        // scalastyle:on println
+      private def readByteArrayData(length: Int): Array[Byte] = {
+        length match {
+          case length if length > 0 =>
+            val obj = new Array[Byte](length)
+            dataStream.readFully(obj)
+            obj
+          case _ => null
+        }
       }
-    }
 
-    for (elem <- iter) {
-      elem match {
-        case (key, innerIter: Iterator[_]) =>
-          for (innerElem <- innerIter) {
-            writeElem(innerElem)
-          }
-          // Writes key which can be used as a boundary in group-aggregate
-          dataOut.writeByte('r')
-          writeElem(key)
-        case (key, value) =>
-          writeElem(key)
-          writeElem(value)
-        case _ =>
-          writeElem(elem)
+      private def readStringData(length: Int): String = {
+        length match {
+          case length if length > 0 =>
+            SerDe.readStringBytes(dataStream, length)
+          case _ => null
+        }
       }
-    }
-  }
-
-  /**
-   * Start a thread to write RDD data to the R process.
-   */
-  private def startStdinThread(
-      output: OutputStream,
-      iter: Iterator[_],
-      partitionIndex: Int): Unit = {
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty(BUFFER_SIZE.key,
-      BUFFER_SIZE.defaultValueString).toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
 
-    new Thread("writer for R") {
-      override def run(): Unit = {
+      /**
+       * Reads next object from the stream.
+       * When the stream reaches end of data, needs to process the following sections,
+       * and then returns null.
+       */
+      override protected def read(): OUT = {
         try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partitionIndex)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
-
-          dataOut.writeInt(numPartitions)
-          dataOut.writeInt(mode)
-
-          if (isDataFrame) {
-            SerDe.writeObject(dataOut, colNames, jvmObjectTracker = null)
+          val length = dataStream.readInt()
+
+          length match {
+            case SpecialLengths.TIMING_DATA =>
+              // Timing data from R worker
+              val boot = dataStream.readDouble - bootTime
+              val init = dataStream.readDouble
+              val broadcast = dataStream.readDouble
+              val input = dataStream.readDouble
+              val compute = dataStream.readDouble
+              val output = dataStream.readDouble
+              logInfo(
+                ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+                  "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+                  "total = %.3f s").format(
+                  boot,
+                  init,
+                  broadcast,
+                  input,
+                  compute,
+                  output,
+                  boot + init + broadcast + input + compute + output))
+              read()
+            case length if length > 0 =>
+              readData(length).asInstanceOf[OUT]
+            case length if length == 0 =>
 
 Review comment:
   This is actually covered by the existing R tests. (I'll explain from scratch to make sure we're in sync on this.)
   
   Here is the format of the data sent from the R process to the JVM over the socket:
   
   
   ```
   ----------------------------
   | ... # Output data        |
   ----------------------------
   | -1 # Timing data start   |
   | ... # Computed time, etc.|
   ----------------------------
   | 0 # End of stream        |
   ----------------------------
   ```
   
   **Before:**
   
   Previously, RRunner's iterator in this code path stopped iterating when `read()` returned `null`. So the `readData` functions returned `null` when the length read from the stream was `0`.
   
   **After:**
   
   Now, RRunner needs to mark the end of the stream explicitly by setting `eos` to `true`. (This avoids ambiguity in cases where `read()` returns `null` as a meaningful value rather than as an end-of-stream signal.)
   
   In other words, it would also work if we set `eos` to `true` everywhere the previous RRunner returned `null`:
   
   https://github.com/apache/spark/blob/8126d09fb5b969c1e293f1f8c41bec35357f74b5/core/src/main/scala/org/apache/spark/api/r/RRunner.scala#L290
   
   https://github.com/apache/spark/blob/8126d09fb5b969c1e293f1f8c41bec35357f74b5/core/src/main/scala/org/apache/spark/api/r/RRunner.scala#L282
   
   https://github.com/apache/spark/blob/8126d09fb5b969c1e293f1f8c41bec35357f74b5/core/src/main/scala/org/apache/spark/api/r/RRunner.scala#L272
   
   However, explicitly marking the end of the stream in a single place, when `0` is read, should be better, so I decided to do it that way.
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org