You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/07/25 03:28:25 UTC

[spark] branch master updated: [SPARK-39840][SQL][PYTHON] Factor PythonArrowInput out as a symmetry to PythonArrowOutput

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e1467f748e [SPARK-39840][SQL][PYTHON] Factor PythonArrowInput out as a symmetry to PythonArrowOutput
2e1467f748e is described below

commit 2e1467f748e35d5bdfa302f3af2906dc55c9591e
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Mon Jul 25 12:28:11 2022 +0900

    [SPARK-39840][SQL][PYTHON] Factor PythonArrowInput out as a symmetry to PythonArrowOutput
    
    ### What changes were proposed in this pull request?
    
    This PR factors the Arrow input code path out as `PythonArrowInput`, as a symmetry to `PythonArrowOutput`. The current hierarchy is not affected:
    
    ```
        └── BasePythonRunner
            ├── ArrowPythonRunner with PythonArrowOutput with PythonArrowInput
            ├── CoGroupedArrowPythonRunner with PythonArrowOutput
            ├── PythonRunner
            └── PythonUDFRunner
    ```
    
    In addition, this PR factors out `handleMetadataAfterExec` and `handleMetadataBeforeExec`, which contain the logic to send and receive metadata, such as runtime configurations, specific to Arrow in/out.
    
    ### Why are the changes needed?
    
    https://github.com/apache/spark/commit/40485f4656261231251f1d68b5b7d8b3b8600372 factored `PythonArrowOutput` out. It's better to factor `PythonArrowInput` out as well, for consistency.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, this is refactoring.
    
    ### How was this patch tested?
    
    Existing test cases should cover this change.
    
    Closes #37253 from HyukjinKwon/pyarrow-output-trait.
    
    Authored-by: Hyukjin Kwon <gu...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../org/apache/spark/api/python/PythonRunner.scala |  6 +-
 .../sql/execution/python/ArrowPythonRunner.scala   | 82 ++--------------------
 ...owPythonRunner.scala => PythonArrowInput.scala} | 54 ++++++--------
 .../sql/execution/python/PythonArrowOutput.scala   |  7 ++
 4 files changed, 35 insertions(+), 114 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index efe31be897b..5a13674e8bf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -84,9 +84,9 @@ private object BasePythonRunner {
  * functions (from bottom to top).
  */
 private[spark] abstract class BasePythonRunner[IN, OUT](
-    funcs: Seq[ChainedPythonFunctions],
-    evalType: Int,
-    argOffsets: Array[Array[Int]])
+    protected val funcs: Seq[ChainedPythonFunctions],
+    protected val evalType: Int,
+    protected val argOffsets: Array[Array[Int]])
   extends Logging {
 
   require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index 7171c7f7f97..137e2fe93c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -17,21 +17,11 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io._
-import java.net._
-
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
-
-import org.apache.spark._
 import org.apache.spark.api.python._
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.spark.util.Utils
 
 /**
  * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
@@ -40,10 +30,11 @@ class ArrowPythonRunner(
     funcs: Seq[ChainedPythonFunctions],
     evalType: Int,
     argOffsets: Array[Array[Int]],
-    schema: StructType,
-    timeZoneId: String,
-    conf: Map[String, String])
+    protected override val schema: StructType,
+    protected override val timeZoneId: String,
+    protected override val workerConf: Map[String, String])
   extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
+  with PythonArrowInput
   with PythonArrowOutput {
 
   override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -53,69 +44,4 @@ class ArrowPythonRunner(
     bufferSize >= 4,
     "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
-
-  protected override def newWriterThread(
-      env: SparkEnv,
-      worker: Socket,
-      inputIterator: Iterator[Iterator[InternalRow]],
-      partitionIndex: Int,
-      context: TaskContext): WriterThread = {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-        // Write config for the worker as a number of key -> value pairs of strings
-        dataOut.writeInt(conf.size)
-        for ((k, v) <- conf) {
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
-        PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
-
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId)
-        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-          s"stdout writer for $pythonExec", 0, Long.MaxValue)
-        val root = VectorSchemaRoot.create(arrowSchema, allocator)
-
-        Utils.tryWithSafeFinally {
-          val arrowWriter = ArrowWriter.create(root)
-          val writer = new ArrowStreamWriter(root, null, dataOut)
-          writer.start()
-
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
-            }
-
-            arrowWriter.finish()
-            writer.writeBatch()
-            arrowWriter.reset()
-          }
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
-        } {
-          // If we close root and allocator in TaskCompletionListener, there could be a race
-          // condition where the writer thread keeps writing to the VectorSchemaRoot while
-          // it's being closed by the TaskCompletion listener.
-          // Closing root and allocator here is cleaner because root and allocator is owned
-          // by the writer thread and is only visible to the writer thread.
-          //
-          // If the writer thread is interrupted by TaskCompletionListener, it should either
-          // (1) in the try block, in which case it will get an InterruptedException when
-          // performing io, and goes into the finally block or (2) in the finally block,
-          // in which case it will ignore the interruption and close the resources.
-          root.close()
-          allocator.close()
-        }
-      }
-    }
-  }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
similarity index 74%
copy from sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
copy to sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 7171c7f7f97..79365080f8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -14,45 +14,41 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 package org.apache.spark.sql.execution.python
 
-import java.io._
-import java.net._
+import java.io.DataOutputStream
+import java.net.Socket
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
-import org.apache.spark._
-import org.apache.spark.api.python._
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.api.python.{BasePythonRunner, PythonRDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.arrow.ArrowWriter
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils
 
 /**
- * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream.
+ * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from
+ * JVM (an iterator of internal rows) to Python (Arrow).
  */
-class ArrowPythonRunner(
-    funcs: Seq[ChainedPythonFunctions],
-    evalType: Int,
-    argOffsets: Array[Array[Int]],
-    schema: StructType,
-    timeZoneId: String,
-    conf: Map[String, String])
-  extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](funcs, evalType, argOffsets)
-  with PythonArrowOutput {
+private[python] trait PythonArrowInput { self: BasePythonRunner[Iterator[InternalRow], _] =>
+  protected val workerConf: Map[String, String]
+
+  protected val schema: StructType
 
-  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+  protected val timeZoneId: String
 
-  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
-  require(
-    bufferSize >= 4,
-    "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
-      s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
+  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    // Write config for the worker as a number of key -> value pairs of strings
+    stream.writeInt(workerConf.size)
+    for ((k, v) <- workerConf) {
+      PythonRDD.writeUTF(k, stream)
+      PythonRDD.writeUTF(v, stream)
+    }
+  }
 
   protected override def newWriterThread(
       env: SparkEnv,
@@ -63,14 +59,7 @@ class ArrowPythonRunner(
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-
-        // Write config for the worker as a number of key -> value pairs of strings
-        dataOut.writeInt(conf.size)
-        for ((k, v) <- conf) {
-          PythonRDD.writeUTF(k, dataOut)
-          PythonRDD.writeUTF(v, dataOut)
-        }
-
+        handleMetadataBeforeExec(dataOut)
         PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
       }
 
@@ -117,5 +106,4 @@ class ArrowPythonRunner(
       }
     }
   }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 00bab1e9fbf..d06a0d01299 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -37,6 +37,8 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, Column
  */
 private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatch] =>
 
+  protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
+
   protected def newReaderIterator(
       stream: DataInputStream,
       writerThread: WriterThread,
@@ -67,6 +69,11 @@ private[python] trait PythonArrowOutput { self: BasePythonRunner[_, ColumnarBatc
 
       private var batchLoaded = true
 
+      protected override def handleEndOfDataSection(): Unit = {
+        handleMetadataAfterExec(stream)
+        super.handleEndOfDataSection()
+      }
+
       protected override def read(): ColumnarBatch = {
         if (writerThread.exception.isDefined) {
           throw writerThread.exception.get


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org