Posted to commits@spark.apache.org by ue...@apache.org on 2023/10/13 00:02:55 UTC

[spark] branch master updated: [SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable

This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 280f6b33110d [SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
280f6b33110d is described below

commit 280f6b33110d707ebee6fec6e5bafa45b45213ae
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Thu Oct 12 17:02:41 2023 -0700

    [SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
    
    ### What changes were proposed in this pull request?
    
    Currently, the `analyzeInPython` method in the `UserDefinedPythonTableFunction` object can start a Python process on the driver and run a Python function in that process. This PR refactors this logic into a reusable runner class, `PythonPlannerRunner` (a simplified sketch of the resulting pattern appears just before the diff below).
    
    ### Why are the changes needed?
    
    To make the code more reusable.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43340 from allisonwang-db/spark-45505-refactor-analyze-in-py.
    
    Authored-by: allisonwang-db <al...@databricks.com>
    Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
 python/pyspark/sql/worker/analyze_udtf.py          |   6 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    |  18 +-
 .../sql/execution/python/PythonPlannerRunner.scala | 177 ++++++++++++
 .../python/UserDefinedPythonFunction.scala         | 321 +++++++--------------
 4 files changed, 286 insertions(+), 236 deletions(-)
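
The new base class follows a template-method shape: `PythonPlannerRunner.runInPython()` owns the worker lifecycle and stream framing, while each use case only overrides `workerModule`, `writeToPython`, and `receiveFromPython`. The standalone Scala sketch below mirrors only that shape; `ToyPlannerRunner`, `EchoRunner`, and the in-memory echo round-trip are illustrative stand-ins, and the real hooks additionally receive a `Pickler` and talk to an actual Python worker process, as the diff shows.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

// Template-method shape of the new runner: the base class owns the request/response
// framing; each use case only fills in the two hooks. This toy replaces the real
// Python-worker round-trip with an in-memory echo so the pattern runs standalone.
abstract class ToyPlannerRunner[T] {
  protected val workerModule: String                            // e.g. "pyspark.sql.worker.analyze_udtf"
  protected def writeToPython(dataOut: DataOutputStream): Unit  // serialize the request
  protected def receiveFromPython(dataIn: DataInputStream): T   // decode the response

  // Stand-in for PythonPlannerRunner.runInPython(): write the request, "send" it,
  // and decode the reply.
  def runInPython(): T = {
    val out = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(out)
    writeToPython(dataOut)
    dataOut.flush()
    // The real runner ships these bytes to a Python worker process; here we simply
    // echo them back to show where receiveFromPython plugs in.
    val dataIn = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    receiveFromPython(dataIn)
  }
}

// A concrete use case becomes a small subclass, mirroring
// UserDefinedPythonTableFunctionAnalyzeRunner in this commit.
class EchoRunner(message: String) extends ToyPlannerRunner[String] {
  override protected val workerModule = "toy.echo"
  override protected def writeToPython(dataOut: DataOutputStream): Unit =
    dataOut.writeUTF(message)
  override protected def receiveFromPython(dataIn: DataInputStream): String =
    dataIn.readUTF()
}

object ToyPlannerRunnerDemo {
  def main(args: Array[String]): Unit =
    println(new EchoRunner("analyze request").runInPython())  // prints "analyze request"
}
```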

diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index a6aa381eb14a..9e84b880fc96 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -98,14 +98,14 @@ def main(infile: IO, outfile: IO) -> None:
     """
     Runs the Python UDTF's `analyze` static method.
 
-    This process will be invoked from `UserDefinedPythonTableFunction.analyzeInPython` in JVM
-    and receive the Python UDTF and its arguments for the `analyze` static method,
+    This process will be invoked from `UserDefinedPythonTableFunctionAnalyzeRunner.runInPython`
+    in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
     and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
     """
     try:
         check_python_version(infile)
 
-        memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1"))
+        memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
         setup_memory_limits(memory_limit_mb)
 
         setup_spark_files(infile)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 12ec9e911d31..000694f6f1bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3008,14 +3008,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
-  val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
-    buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")
-      .doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " +
-        "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " +
-        "limited to this amount. If not set, Spark will not limit Python's " +
-        "memory use and it is up to the application to avoid exceeding the overhead memory space " +
-        "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " +
-        "and actual resource is not limited on MacOS.")
+  val PYTHON_PLANNER_EXEC_MEMORY =
+    buildConf("spark.sql.planner.pythonExecution.memory")
+      .doc("Specifies the memory allocation for executing Python code in Spark driver, in MiB. " +
+        "When set, it caps the memory for Python execution to the specified amount. " +
+        "If not set, Spark will not limit Python's memory usage and it is up to the application " +
+        "to avoid exceeding the overhead memory space shared with other non-JVM processes.\n" +
+        "Note: Windows does not support resource limiting and actual resource is not limited " +
+        "on MacOS.")
       .version("4.0.0")
       .bytesConf(ByteUnit.MiB)
       .createOptional
@@ -5157,7 +5157,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def pysparkWorkerPythonExecutable: Option[String] =
     getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
 
-  def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY)
+  def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY)
 
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
new file mode 100644
index 000000000000..183b96bb982c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.SelectionKey
+import java.util.HashMap
+
+import scala.jdk.CollectionConverters._
+
+import net.razorvine.pickle.Pickler
+
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.DirectByteBufferOutputStream
+
+/**
+ * A helper class to run Python functions in Spark driver.
+ */
+abstract class PythonPlannerRunner[T](func: PythonFunction) {
+
+  protected val workerModule: String
+
+  protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit
+
+  protected def receiveFromPython(dataIn: DataInputStream): T
+
+  def runInPython(): T = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+    val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
+
+    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+    val pythonIncludes = func.pythonIncludes.asScala.toSet
+    val broadcastVars = func.broadcastVars.asScala.toSeq
+    val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
+
+    envVars.put("SPARK_LOCAL_DIRS", localdir)
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    workerMemoryMb.foreach { memoryMb =>
+      envVars.put("PYSPARK_PLANNER_MEMORY_MB", memoryMb.toString)
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+
+    val (worker: PythonWorker, _) =
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+    var releasedOrClosed = false
+    val bufferStream = new DirectByteBufferOutputStream()
+    try {
+      val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
+
+      PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+      PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
+      PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
+
+      writeToPython(dataOut, pickler)
+
+      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+      dataOut.flush()
+
+      val dataIn = new DataInputStream(new BufferedInputStream(
+        new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
+
+      val res = receiveFromPython(dataIn)
+
+      PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
+      Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
+
+      dataIn.readInt() match {
+        case SpecialLengths.END_OF_STREAM if reuseWorker =>
+          env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+        case _ =>
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+      }
+      releasedOrClosed = true
+
+      res
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+    } finally {
+      try {
+        bufferStream.close()
+      } finally {
+        if (!releasedOrClosed) {
+          // An error happened. Force to close the worker.
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+        }
+      }
+    }
+  }
+
+  /**
+   * A wrapper of the non-blocking IO to write to/read from the worker.
+   *
+   * Since we use non-blocking IO to communicate with workers; see SPARK-44705,
+   * a wrapper is needed to do IO with the worker.
+   * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
+   * and only supports to write all at once and then read all.
+   */
+  private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {
+
+    private[this] val temp = new Array[Byte](1)
+
+    override def read(): Int = {
+      val n = read(temp)
+      if (n <= 0) {
+        -1
+      } else {
+        // Signed byte to unsigned integer
+        temp(0) & 0xff
+      }
+    }
+
+    override def read(b: Array[Byte], off: Int, len: Int): Int = {
+      val buf = ByteBuffer.wrap(b, off, len)
+      var n = 0
+      while (n == 0) {
+        worker.selector.select()
+        if (worker.selectionKey.isReadable) {
+          n = worker.channel.read(buf)
+        }
+        if (worker.selectionKey.isWritable) {
+          var acceptsInput = true
+          while (acceptsInput && buffer.hasRemaining) {
+            val n = worker.channel.write(buffer)
+            acceptsInput = n > 0
+          }
+          if (!buffer.hasRemaining) {
+            // We no longer have any data to write to the socket.
+            worker.selectionKey.interestOps(SelectionKey.OP_READ)
+          }
+        }
+      }
+      n
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index d8d3cc9b7fc4..f2f952f079e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,28 +17,19 @@
 
 package org.apache.spark.sql.execution.python
 
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
-import java.nio.ByteBuffer
-import java.nio.channels.SelectionKey
-import java.util.HashMap
+import java.io.{DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters._
 
 import net.razorvine.pickle.Pickler
 
-import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
-import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python._
+import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths}
 import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, SortOrder, UnresolvedPolymorphicPythonUDTF}
 import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation}
 import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.DirectByteBufferOutputStream
 
 /**
  * A user-defined Python function. This is used by the Python API.
@@ -141,13 +132,17 @@ case class UserDefinedPythonTableFunction(
           case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
           case _ => false
         }
+        val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => {
+          val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(func, exprs, tableArgs)
+          runner.runInPython()
+        }
         UnresolvedPolymorphicPythonUDTF(
           name = name,
           func = func,
           children = exprs,
           evalType = pythonEvalType,
           udfDeterministic = udfDeterministic,
-          resolveElementMetadata = UserDefinedPythonTableFunction.analyzeInPython(_, _, tableArgs))
+          resolveElementMetadata = runAnalyzeInPython)
     }
     Generate(
       udtf,
@@ -166,228 +161,106 @@ case class UserDefinedPythonTableFunction(
   }
 }
 
-object UserDefinedPythonTableFunction {
-
-  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
-
-  /**
-   * Runs the Python UDTF's `analyze` static method.
-   *
-   * When the Python UDTF is defined without a static return type,
-   * the analyzer will call this while resolving table-valued functions.
-   *
-   * This expects the Python UDTF to have `analyze` static method that take arguments:
-   *
-   * - The number and order of arguments are the same as the UDTF inputs
-   * - Each argument is an `AnalyzeArgument`, containing:
-   *   - data_type: DataType
-   *   - value: Any: if the argument is foldable; otherwise None
-   *   - is_table: bool: True if the argument is TABLE
-   *
-   * and that return an `AnalyzeResult`.
-   *
-   * It serializes/deserializes the data types via JSON,
-   * and the values for the case the argument is foldable are pickled.
-   *
-   * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON"
-   * will be thrown when an exception is raised in Python.
-   */
-  def analyzeInPython(
-      func: PythonFunction,
-      exprs: Seq[Expression],
-      tableArgs: Seq[Boolean]): PythonUDTFAnalyzeResult = {
-    val env = SparkEnv.get
-    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
-    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
-    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
-    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
-    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
-    val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory
-
-    val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
-    val envVars = new HashMap[String, String](func.envVars)
-    val pythonExec = func.pythonExec
-    val pythonVer = func.pythonVer
-    val pythonIncludes = func.pythonIncludes.asScala.toSet
-    val broadcastVars = func.broadcastVars.asScala.toSeq
-    val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
-
-    envVars.put("SPARK_LOCAL_DIRS", localdir)
-    if (reuseWorker) {
-      envVars.put("SPARK_REUSE_WORKER", "1")
-    }
-    if (simplifiedTraceback) {
-      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
-    }
-    workerMemoryMb.foreach { memoryMb =>
-      envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString)
-    }
-    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
-    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-
-    envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
-
-    EvaluatePython.registerPicklers()
-    val pickler = new Pickler(/* useMemo = */ true,
-      /* valueCompare = */ false)
-
-    val (worker: PythonWorker, _) =
-      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
-    var releasedOrClosed = false
-    val bufferStream = new DirectByteBufferOutputStream()
-    try {
-      val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
-
-      PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
-      PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
-      PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
-
-      // Send Python UDTF
-      PythonWorkerUtils.writePythonFunction(func, dataOut)
-
-      // Send arguments
-      dataOut.writeInt(exprs.length)
-      exprs.zip(tableArgs).foreach { case (expr, is_table) =>
-        PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut)
-        if (expr.foldable) {
-          dataOut.writeBoolean(true)
-          val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
-          PythonWorkerUtils.writeBytes(obj, dataOut)
-        } else {
-          dataOut.writeBoolean(false)
-        }
-        dataOut.writeBoolean(is_table)
-        // If the expr is NamedArgumentExpression, send its name.
-        expr match {
-          case NamedArgumentExpression(key, _) =>
-            dataOut.writeBoolean(true)
-            PythonWorkerUtils.writeUTF(key, dataOut)
-          case _ =>
-            dataOut.writeBoolean(false)
-        }
-      }
-
-      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-      dataOut.flush()
-
-      val dataIn = new DataInputStream(new BufferedInputStream(
-        new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
-
-      // Receive the schema or an exception raised in Python worker.
-      val length = dataIn.readInt()
-      if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
-        val msg = PythonWorkerUtils.readUTF(dataIn)
-        throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
-      }
-
-      val schema = DataType.fromJson(
-        PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
-
-      // Receive the pickled AnalyzeResult buffer, if any.
-      val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
-
-      // Receive whether the "with single partition" property is requested.
-      val withSinglePartition = dataIn.readInt() == 1
-      // Receive the list of requested partitioning columns, if any.
-      val partitionByColumns = ArrayBuffer.empty[Expression]
-      val numPartitionByColumns = dataIn.readInt()
-      for (_ <- 0 until numPartitionByColumns) {
-        val columnName = PythonWorkerUtils.readUTF(dataIn)
-        partitionByColumns.append(UnresolvedAttribute(columnName))
-      }
-      // Receive the list of requested ordering columns, if any.
-      val orderBy = ArrayBuffer.empty[SortOrder]
-      val numOrderByItems = dataIn.readInt()
-      for (_ <- 0 until numOrderByItems) {
-        val columnName = PythonWorkerUtils.readUTF(dataIn)
-        val direction = if (dataIn.readInt() == 1) Ascending else Descending
-        val overrideNullsFirst = dataIn.readInt()
-        overrideNullsFirst match {
-          case 0 =>
-            orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction))
-          case 1 => orderBy.append(
-            SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty))
-          case 2 => orderBy.append(
-            SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty))
-        }
+/**
+ * Runs the Python UDTF's `analyze` static method.
+ *
+ * When the Python UDTF is defined without a static return type,
+ * the analyzer will call this while resolving table-valued functions.
+ *
+ * This expects the Python UDTF to have `analyze` static method that take arguments:
+ *
+ * - The number and order of arguments are the same as the UDTF inputs
+ * - Each argument is an `AnalyzeArgument`, containing:
+ *   - data_type: DataType
+ *   - value: Any: if the argument is foldable; otherwise None
+ *   - is_table: bool: True if the argument is TABLE
+ *
+ * and that return an `AnalyzeResult`.
+ *
+ * It serializes/deserializes the data types via JSON,
+ * and the values for the case the argument is foldable are pickled.
+ *
+ * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON"
+ * will be thrown when an exception is raised in Python.
+ */
+class UserDefinedPythonTableFunctionAnalyzeRunner(
+    func: PythonFunction,
+    exprs: Seq[Expression],
+    tableArgs: Seq[Boolean]) extends PythonPlannerRunner[PythonUDTFAnalyzeResult](func) {
+
+  override val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+    // Send Python UDTF
+    PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+    // Send arguments
+    dataOut.writeInt(exprs.length)
+    exprs.zip(tableArgs).foreach { case (expr, is_table) =>
+      PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut)
+      if (expr.foldable) {
+        dataOut.writeBoolean(true)
+        val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
+        PythonWorkerUtils.writeBytes(obj, dataOut)
+      } else {
+        dataOut.writeBoolean(false)
       }
-
-      PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
-      Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
-
-      dataIn.readInt() match {
-        case SpecialLengths.END_OF_STREAM if reuseWorker =>
-          env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+      dataOut.writeBoolean(is_table)
+      // If the expr is NamedArgumentExpression, send its name.
+      expr match {
+        case NamedArgumentExpression(key, _) =>
+          dataOut.writeBoolean(true)
+          PythonWorkerUtils.writeUTF(key, dataOut)
         case _ =>
-          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
-      }
-      releasedOrClosed = true
-
-      PythonUDTFAnalyzeResult(
-        schema = schema,
-        withSinglePartition = withSinglePartition,
-        partitionByExpressions = partitionByColumns.toSeq,
-        orderByExpressions = orderBy.toSeq,
-        pickledAnalyzeResult = pickledAnalyzeResult)
-    } catch {
-      case eof: EOFException =>
-        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
-    } finally {
-      try {
-        bufferStream.close()
-      } finally {
-        if (!releasedOrClosed) {
-          // An error happened. Force to close the worker.
-          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
-        }
+          dataOut.writeBoolean(false)
       }
     }
   }
 
-  /**
-   * A wrapper of the non-blocking IO to write to/read from the worker.
-   *
-   * Since we use non-blocking IO to communicate with workers; see SPARK-44705,
-   * a wrapper is needed to do IO with the worker.
-   * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
-   * and only supports to write all at once and then read all.
-   */
-  private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {
+  override protected def receiveFromPython(dataIn: DataInputStream): PythonUDTFAnalyzeResult = {
+    // Receive the schema or an exception raised in Python worker.
+    val length = dataIn.readInt()
+    if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+      val msg = PythonWorkerUtils.readUTF(dataIn)
+      throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
+    }
 
-    private[this] val temp = new Array[Byte](1)
+    val schema = DataType.fromJson(
+      PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
 
-    override def read(): Int = {
-      val n = read(temp)
-      if (n <= 0) {
-        -1
-      } else {
-        // Signed byte to unsigned integer
-        temp(0) & 0xff
-      }
-    }
+    // Receive the pickled AnalyzeResult buffer, if any.
+    val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
 
-    override def read(b: Array[Byte], off: Int, len: Int): Int = {
-      val buf = ByteBuffer.wrap(b, off, len)
-      var n = 0
-      while (n == 0) {
-        worker.selector.select()
-        if (worker.selectionKey.isReadable) {
-          n = worker.channel.read(buf)
-        }
-        if (worker.selectionKey.isWritable) {
-          var acceptsInput = true
-          while (acceptsInput && buffer.hasRemaining) {
-            val n = worker.channel.write(buffer)
-            acceptsInput = n > 0
-          }
-          if (!buffer.hasRemaining) {
-            // We no longer have any data to write to the socket.
-            worker.selectionKey.interestOps(SelectionKey.OP_READ)
-          }
-        }
+    // Receive whether the "with single partition" property is requested.
+    val withSinglePartition = dataIn.readInt() == 1
+    // Receive the list of requested partitioning columns, if any.
+    val partitionByColumns = ArrayBuffer.empty[Expression]
+    val numPartitionByColumns = dataIn.readInt()
+    for (_ <- 0 until numPartitionByColumns) {
+      val columnName = PythonWorkerUtils.readUTF(dataIn)
+      partitionByColumns.append(UnresolvedAttribute(columnName))
+    }
+    // Receive the list of requested ordering columns, if any.
+    val orderBy = ArrayBuffer.empty[SortOrder]
+    val numOrderByItems = dataIn.readInt()
+    for (_ <- 0 until numOrderByItems) {
+      val columnName = PythonWorkerUtils.readUTF(dataIn)
+      val direction = if (dataIn.readInt() == 1) Ascending else Descending
+      val overrideNullsFirst = dataIn.readInt()
+      overrideNullsFirst match {
+        case 0 =>
+          orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction))
+        case 1 => orderBy.append(
+          SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty))
+        case 2 => orderBy.append(
+          SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty))
       }
-      n
     }
+    PythonUDTFAnalyzeResult(
+      schema = schema,
+      withSinglePartition = withSinglePartition,
+      partitionByExpressions = partitionByColumns.toSeq,
+      orderByExpressions = orderBy.toSeq,
+      pickledAnalyzeResult = pickledAnalyzeResult)
   }
 }
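
For completeness, a minimal sketch of how the renamed driver-side memory cap would be set from an application, assuming a Spark build that contains this commit. `PlannerMemoryConfDemo` is a hypothetical name; the config key `spark.sql.planner.pythonExecution.memory` and the `PYSPARK_PLANNER_MEMORY_MB` environment variable come from the diff above.

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: shows where the renamed driver-side memory cap is configured.
// The value is exported to the planner worker as PYSPARK_PLANNER_MEMORY_MB and
// enforced by setup_memory_limits in pyspark/sql/worker/analyze_udtf.py.
object PlannerMemoryConfDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("planner-memory-conf-demo")
      // Formerly spark.sql.analyzer.pythonUDTF.analyzeInPython.memory.
      .config("spark.sql.planner.pythonExecution.memory", "512m") // cap at 512 MiB
      .getOrCreate()

    // From here, resolving a Python UDTF declared without a static return type
    // runs UserDefinedPythonTableFunctionAnalyzeRunner.runInPython() on the
    // driver, subject to the cap above.
    spark.stop()
  }
}
```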

