Posted to commits@spark.apache.org by ue...@apache.org on 2023/10/13 00:02:55 UTC
[spark] branch master updated: [SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 280f6b33110d [SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
280f6b33110d is described below
commit 280f6b33110d707ebee6fec6e5bafa45b45213ae
Author: allisonwang-db <al...@databricks.com>
AuthorDate: Thu Oct 12 17:02:41 2023 -0700
[SPARK-45505][PYTHON] Refactor analyzeInPython to make it reusable
### What changes were proposed in this pull request?
Currently, the `analyzeInPython` method in the `UserDefinedPythonTableFunction` object starts a Python process on the driver and runs a Python function in that process. This PR refactors this logic into a reusable runner class.
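For illustration, a minimal sketch of the resulting pattern, assuming the abstraction introduced in the diff below (`PythonPlannerRunner`, `workerModule`, `writeToPython`, `receiveFromPython`, `runInPython`); the subclass name and worker module here are hypothetical, not part of this PR. A planner-side runner only declares which Python worker module to launch and the two halves of the wire protocol, while `runInPython()` handles worker creation, memory limits, and cleanup:

```scala
package org.apache.spark.sql.execution.python

import java.io.{DataInputStream, DataOutputStream}

import net.razorvine.pickle.Pickler

import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils}

// Hypothetical subclass for illustration; only the PythonPlannerRunner API
// (workerModule, writeToPython, receiveFromPython, runInPython) comes from this PR.
class EchoPlannerRunner(func: PythonFunction)
    extends PythonPlannerRunner[String](func) {

  // Python module to launch on the driver (hypothetical module name).
  override val workerModule: String = "pyspark.sql.worker.echo"

  override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
    // Send the Python function; the base class has already written the Python
    // version, Spark files and broadcasts before this hook runs.
    PythonWorkerUtils.writePythonFunction(func, dataOut)
  }

  override protected def receiveFromPython(dataIn: DataInputStream): String = {
    // Read back a single UTF-8 string produced by the worker module.
    PythonWorkerUtils.readUTF(dataIn)
  }
}

// Usage: val result = new EchoPlannerRunner(func).runInPython()
```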
### Why are the changes needed?
To make the code more reusable.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43340 from allisonwang-db/spark-45505-refactor-analyze-in-py.
Authored-by: allisonwang-db <al...@databricks.com>
Signed-off-by: Takuya UESHIN <ue...@databricks.com>
---
python/pyspark/sql/worker/analyze_udtf.py | 6 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 18 +-
.../sql/execution/python/PythonPlannerRunner.scala | 177 ++++++++++++
.../python/UserDefinedPythonFunction.scala | 321 +++++++--------------
4 files changed, 286 insertions(+), 236 deletions(-)
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index a6aa381eb14a..9e84b880fc96 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -98,14 +98,14 @@ def main(infile: IO, outfile: IO) -> None:
"""
Runs the Python UDTF's `analyze` static method.
- This process will be invoked from `UserDefinedPythonTableFunction.analyzeInPython` in JVM
- and receive the Python UDTF and its arguments for the `analyze` static method,
+ This process will be invoked from `UserDefinedPythonTableFunctionAnalyzeRunner.runInPython`
+ in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
"""
try:
check_python_version(infile)
- memory_limit_mb = int(os.environ.get("PYSPARK_UDTF_ANALYZER_MEMORY_MB", "-1"))
+ memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
setup_spark_files(infile)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 12ec9e911d31..000694f6f1bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3008,14 +3008,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
- val PYTHON_TABLE_UDF_ANALYZER_MEMORY =
- buildConf("spark.sql.analyzer.pythonUDTF.analyzeInPython.memory")
- .doc("The amount of memory to be allocated to PySpark for Python UDTF analyzer, in MiB " +
- "unless otherwise specified. If set, PySpark memory for Python UDTF analyzer will be " +
- "limited to this amount. If not set, Spark will not limit Python's " +
- "memory use and it is up to the application to avoid exceeding the overhead memory space " +
- "shared with other non-JVM processes.\nNote: Windows does not support resource limiting " +
- "and actual resource is not limited on MacOS.")
+ val PYTHON_PLANNER_EXEC_MEMORY =
+ buildConf("spark.sql.planner.pythonExecution.memory")
+ .doc("Specifies the memory allocation for executing Python code in Spark driver, in MiB. " +
+ "When set, it caps the memory for Python execution to the specified amount. " +
+ "If not set, Spark will not limit Python's memory usage and it is up to the application " +
+ "to avoid exceeding the overhead memory space shared with other non-JVM processes.\n" +
+ "Note: Windows does not support resource limiting and actual resource is not limited " +
+ "on MacOS.")
.version("4.0.0")
.bytesConf(ByteUnit.MiB)
.createOptional
@@ -5157,7 +5157,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pysparkWorkerPythonExecutable: Option[String] =
getConf(SQLConf.PYSPARK_WORKER_PYTHON_EXECUTABLE)
- def pythonUDTFAnalyzerMemory: Option[Long] = getConf(PYTHON_TABLE_UDF_ANALYZER_MEMORY)
+ def pythonPlannerExecMemory: Option[Long] = getConf(PYTHON_PLANNER_EXEC_MEMORY)
def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
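As a hedged usage sketch of the renamed conf above (the value is interpreted in MiB per `bytesConf(ByteUnit.MiB)`; the master and app name below are made up for the example):

```scala
import org.apache.spark.sql.SparkSession

// Cap driver-side Python planner/analyzer processes at 512 MiB.
// When the conf is unset, no cap is applied (it is .createOptional above).
val spark = SparkSession.builder()
  .master("local[*]")                                        // hypothetical
  .appName("python-planner-memory-demo")                     // hypothetical
  .config("spark.sql.planner.pythonExecution.memory", "512") // 512 MiB
  .getOrCreate()
```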
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
new file mode 100644
index 000000000000..183b96bb982c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
+import java.nio.ByteBuffer
+import java.nio.channels.SelectionKey
+import java.util.HashMap
+
+import scala.jdk.CollectionConverters._
+
+import net.razorvine.pickle.Pickler
+
+import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
+import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
+import org.apache.spark.internal.config.BUFFER_SIZE
+import org.apache.spark.internal.config.Python._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.DirectByteBufferOutputStream
+
+/**
+ * A helper class to run Python functions in Spark driver.
+ */
+abstract class PythonPlannerRunner[T](func: PythonFunction) {
+
+ protected val workerModule: String
+
+ protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit
+
+ protected def receiveFromPython(dataIn: DataInputStream): T
+
+ def runInPython(): T = {
+ val env = SparkEnv.get
+ val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+ val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+ val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+ val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+ val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+ val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
+
+ val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+ val envVars = new HashMap[String, String](func.envVars)
+ val pythonExec = func.pythonExec
+ val pythonVer = func.pythonVer
+ val pythonIncludes = func.pythonIncludes.asScala.toSet
+ val broadcastVars = func.broadcastVars.asScala.toSeq
+ val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
+
+ envVars.put("SPARK_LOCAL_DIRS", localdir)
+ if (reuseWorker) {
+ envVars.put("SPARK_REUSE_WORKER", "1")
+ }
+ if (simplifiedTraceback) {
+ envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+ }
+ workerMemoryMb.foreach { memoryMb =>
+ envVars.put("PYSPARK_PLANNER_MEMORY_MB", memoryMb.toString)
+ }
+ envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+ envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+ envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
+
+ EvaluatePython.registerPicklers()
+ val pickler = new Pickler(/* useMemo = */ true,
+ /* valueCompare = */ false)
+
+ val (worker: PythonWorker, _) =
+ env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+ var releasedOrClosed = false
+ val bufferStream = new DirectByteBufferOutputStream()
+ try {
+ val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
+
+ PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
+ PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
+ PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
+
+ writeToPython(dataOut, pickler)
+
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+ dataOut.flush()
+
+ val dataIn = new DataInputStream(new BufferedInputStream(
+ new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
+
+ val res = receiveFromPython(dataIn)
+
+ PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
+ Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
+
+ dataIn.readInt() match {
+ case SpecialLengths.END_OF_STREAM if reuseWorker =>
+ env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+ case _ =>
+ env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+ }
+ releasedOrClosed = true
+
+ res
+ } catch {
+ case eof: EOFException =>
+ throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
+ } finally {
+ try {
+ bufferStream.close()
+ } finally {
+ if (!releasedOrClosed) {
+ // An error happened. Force to close the worker.
+ env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+ }
+ }
+ }
+ }
+
+ /**
+ * A wrapper of the non-blocking IO to write to/read from the worker.
+ *
+ * Since we use non-blocking IO to communicate with workers; see SPARK-44705,
+ * a wrapper is needed to do IO with the worker.
+ * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
+ * and only supports to write all at once and then read all.
+ */
+ private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {
+
+ private[this] val temp = new Array[Byte](1)
+
+ override def read(): Int = {
+ val n = read(temp)
+ if (n <= 0) {
+ -1
+ } else {
+ // Signed byte to unsigned integer
+ temp(0) & 0xff
+ }
+ }
+
+ override def read(b: Array[Byte], off: Int, len: Int): Int = {
+ val buf = ByteBuffer.wrap(b, off, len)
+ var n = 0
+ while (n == 0) {
+ worker.selector.select()
+ if (worker.selectionKey.isReadable) {
+ n = worker.channel.read(buf)
+ }
+ if (worker.selectionKey.isWritable) {
+ var acceptsInput = true
+ while (acceptsInput && buffer.hasRemaining) {
+ val n = worker.channel.write(buffer)
+ acceptsInput = n > 0
+ }
+ if (!buffer.hasRemaining) {
+ // We no longer have any data to write to the socket.
+ worker.selectionKey.interestOps(SelectionKey.OP_READ)
+ }
+ }
+ }
+ n
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index d8d3cc9b7fc4..f2f952f079e2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -17,28 +17,19 @@
package org.apache.spark.sql.execution.python
-import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, InputStream}
-import java.nio.ByteBuffer
-import java.nio.channels.SelectionKey
-import java.util.HashMap
+import java.io.{DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler
-import org.apache.spark.{JobArtifactSet, SparkEnv, SparkException}
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorker, PythonWorkerUtils, SpecialLengths}
-import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python._
+import org.apache.spark.api.python.{PythonEvalType, PythonFunction, PythonWorkerUtils, SpecialLengths}
import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, FunctionTableSubqueryArgumentExpression, NamedArgumentExpression, NullsFirst, NullsLast, PythonUDAF, PythonUDF, PythonUDTF, PythonUDTFAnalyzeResult, SortOrder, UnresolvedPolymorphicPythonUDTF}
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan, NamedParametersSupport, OneRowRelation}
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.util.DirectByteBufferOutputStream
/**
* A user-defined Python function. This is used by the Python API.
@@ -141,13 +132,17 @@ case class UserDefinedPythonTableFunction(
case NamedArgumentExpression(_, _: FunctionTableSubqueryArgumentExpression) => true
case _ => false
}
+ val runAnalyzeInPython = (func: PythonFunction, exprs: Seq[Expression]) => {
+ val runner = new UserDefinedPythonTableFunctionAnalyzeRunner(func, exprs, tableArgs)
+ runner.runInPython()
+ }
UnresolvedPolymorphicPythonUDTF(
name = name,
func = func,
children = exprs,
evalType = pythonEvalType,
udfDeterministic = udfDeterministic,
- resolveElementMetadata = UserDefinedPythonTableFunction.analyzeInPython(_, _, tableArgs))
+ resolveElementMetadata = runAnalyzeInPython)
}
Generate(
udtf,
@@ -166,228 +161,106 @@ case class UserDefinedPythonTableFunction(
}
}
-object UserDefinedPythonTableFunction {
-
- private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
-
- /**
- * Runs the Python UDTF's `analyze` static method.
- *
- * When the Python UDTF is defined without a static return type,
- * the analyzer will call this while resolving table-valued functions.
- *
- * This expects the Python UDTF to have `analyze` static method that take arguments:
- *
- * - The number and order of arguments are the same as the UDTF inputs
- * - Each argument is an `AnalyzeArgument`, containing:
- * - data_type: DataType
- * - value: Any: if the argument is foldable; otherwise None
- * - is_table: bool: True if the argument is TABLE
- *
- * and that return an `AnalyzeResult`.
- *
- * It serializes/deserializes the data types via JSON,
- * and the values for the case the argument is foldable are pickled.
- *
- * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON"
- * will be thrown when an exception is raised in Python.
- */
- def analyzeInPython(
- func: PythonFunction,
- exprs: Seq[Expression],
- tableArgs: Seq[Boolean]): PythonUDTFAnalyzeResult = {
- val env = SparkEnv.get
- val bufferSize: Int = env.conf.get(BUFFER_SIZE)
- val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
- val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
- val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
- val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
- val workerMemoryMb = SQLConf.get.pythonUDTFAnalyzerMemory
-
- val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
- val envVars = new HashMap[String, String](func.envVars)
- val pythonExec = func.pythonExec
- val pythonVer = func.pythonVer
- val pythonIncludes = func.pythonIncludes.asScala.toSet
- val broadcastVars = func.broadcastVars.asScala.toSeq
- val maybeAccumulator = Option(func.accumulator).map(_.copyAndReset())
-
- envVars.put("SPARK_LOCAL_DIRS", localdir)
- if (reuseWorker) {
- envVars.put("SPARK_REUSE_WORKER", "1")
- }
- if (simplifiedTraceback) {
- envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
- }
- workerMemoryMb.foreach { memoryMb =>
- envVars.put("PYSPARK_UDTF_ANALYZER_MEMORY_MB", memoryMb.toString)
- }
- envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
- envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-
- envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
-
- EvaluatePython.registerPicklers()
- val pickler = new Pickler(/* useMemo = */ true,
- /* valueCompare = */ false)
-
- val (worker: PythonWorker, _) =
- env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
- var releasedOrClosed = false
- val bufferStream = new DirectByteBufferOutputStream()
- try {
- val dataOut = new DataOutputStream(new BufferedOutputStream(bufferStream, bufferSize))
-
- PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
- PythonWorkerUtils.writeSparkFiles(jobArtifactUUID, pythonIncludes, dataOut)
- PythonWorkerUtils.writeBroadcasts(broadcastVars, worker, env, dataOut)
-
- // Send Python UDTF
- PythonWorkerUtils.writePythonFunction(func, dataOut)
-
- // Send arguments
- dataOut.writeInt(exprs.length)
- exprs.zip(tableArgs).foreach { case (expr, is_table) =>
- PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut)
- if (expr.foldable) {
- dataOut.writeBoolean(true)
- val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
- PythonWorkerUtils.writeBytes(obj, dataOut)
- } else {
- dataOut.writeBoolean(false)
- }
- dataOut.writeBoolean(is_table)
- // If the expr is NamedArgumentExpression, send its name.
- expr match {
- case NamedArgumentExpression(key, _) =>
- dataOut.writeBoolean(true)
- PythonWorkerUtils.writeUTF(key, dataOut)
- case _ =>
- dataOut.writeBoolean(false)
- }
- }
-
- dataOut.writeInt(SpecialLengths.END_OF_STREAM)
- dataOut.flush()
-
- val dataIn = new DataInputStream(new BufferedInputStream(
- new WorkerInputStream(worker, bufferStream.toByteBuffer), bufferSize))
-
- // Receive the schema or an exception raised in Python worker.
- val length = dataIn.readInt()
- if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
- val msg = PythonWorkerUtils.readUTF(dataIn)
- throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
- }
-
- val schema = DataType.fromJson(
- PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
-
- // Receive the pickled AnalyzeResult buffer, if any.
- val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
-
- // Receive whether the "with single partition" property is requested.
- val withSinglePartition = dataIn.readInt() == 1
- // Receive the list of requested partitioning columns, if any.
- val partitionByColumns = ArrayBuffer.empty[Expression]
- val numPartitionByColumns = dataIn.readInt()
- for (_ <- 0 until numPartitionByColumns) {
- val columnName = PythonWorkerUtils.readUTF(dataIn)
- partitionByColumns.append(UnresolvedAttribute(columnName))
- }
- // Receive the list of requested ordering columns, if any.
- val orderBy = ArrayBuffer.empty[SortOrder]
- val numOrderByItems = dataIn.readInt()
- for (_ <- 0 until numOrderByItems) {
- val columnName = PythonWorkerUtils.readUTF(dataIn)
- val direction = if (dataIn.readInt() == 1) Ascending else Descending
- val overrideNullsFirst = dataIn.readInt()
- overrideNullsFirst match {
- case 0 =>
- orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction))
- case 1 => orderBy.append(
- SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty))
- case 2 => orderBy.append(
- SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty))
- }
+/**
+ * Runs the Python UDTF's `analyze` static method.
+ *
+ * When the Python UDTF is defined without a static return type,
+ * the analyzer will call this while resolving table-valued functions.
+ *
+ * This expects the Python UDTF to have `analyze` static method that take arguments:
+ *
+ * - The number and order of arguments are the same as the UDTF inputs
+ * - Each argument is an `AnalyzeArgument`, containing:
+ * - data_type: DataType
+ * - value: Any: if the argument is foldable; otherwise None
+ * - is_table: bool: True if the argument is TABLE
+ *
+ * and that return an `AnalyzeResult`.
+ *
+ * It serializes/deserializes the data types via JSON,
+ * and the values for the case the argument is foldable are pickled.
+ *
+ * `AnalysisException` with the error class "TABLE_VALUED_FUNCTION_FAILED_TO_ANALYZE_IN_PYTHON"
+ * will be thrown when an exception is raised in Python.
+ */
+class UserDefinedPythonTableFunctionAnalyzeRunner(
+ func: PythonFunction,
+ exprs: Seq[Expression],
+ tableArgs: Seq[Boolean]) extends PythonPlannerRunner[PythonUDTFAnalyzeResult](func) {
+
+ override val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+ override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = {
+ // Send Python UDTF
+ PythonWorkerUtils.writePythonFunction(func, dataOut)
+
+ // Send arguments
+ dataOut.writeInt(exprs.length)
+ exprs.zip(tableArgs).foreach { case (expr, is_table) =>
+ PythonWorkerUtils.writeUTF(expr.dataType.json, dataOut)
+ if (expr.foldable) {
+ dataOut.writeBoolean(true)
+ val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
+ PythonWorkerUtils.writeBytes(obj, dataOut)
+ } else {
+ dataOut.writeBoolean(false)
}
-
- PythonWorkerUtils.receiveAccumulatorUpdates(maybeAccumulator, dataIn)
- Option(func.accumulator).foreach(_.merge(maybeAccumulator.get))
-
- dataIn.readInt() match {
- case SpecialLengths.END_OF_STREAM if reuseWorker =>
- env.releasePythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+ dataOut.writeBoolean(is_table)
+ // If the expr is NamedArgumentExpression, send its name.
+ expr match {
+ case NamedArgumentExpression(key, _) =>
+ dataOut.writeBoolean(true)
+ PythonWorkerUtils.writeUTF(key, dataOut)
case _ =>
- env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
- }
- releasedOrClosed = true
-
- PythonUDTFAnalyzeResult(
- schema = schema,
- withSinglePartition = withSinglePartition,
- partitionByExpressions = partitionByColumns.toSeq,
- orderByExpressions = orderBy.toSeq,
- pickledAnalyzeResult = pickledAnalyzeResult)
- } catch {
- case eof: EOFException =>
- throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
- } finally {
- try {
- bufferStream.close()
- } finally {
- if (!releasedOrClosed) {
- // An error happened. Force to close the worker.
- env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
- }
+ dataOut.writeBoolean(false)
}
}
}
- /**
- * A wrapper of the non-blocking IO to write to/read from the worker.
- *
- * Since we use non-blocking IO to communicate with workers; see SPARK-44705,
- * a wrapper is needed to do IO with the worker.
- * This is a port and simplified version of `PythonRunner.ReaderInputStream`,
- * and only supports to write all at once and then read all.
- */
- private class WorkerInputStream(worker: PythonWorker, buffer: ByteBuffer) extends InputStream {
+ override protected def receiveFromPython(dataIn: DataInputStream): PythonUDTFAnalyzeResult = {
+ // Receive the schema or an exception raised in Python worker.
+ val length = dataIn.readInt()
+ if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
+ val msg = PythonWorkerUtils.readUTF(dataIn)
+ throw QueryCompilationErrors.tableValuedFunctionFailedToAnalyseInPythonError(msg)
+ }
- private[this] val temp = new Array[Byte](1)
+ val schema = DataType.fromJson(
+ PythonWorkerUtils.readUTF(length, dataIn)).asInstanceOf[StructType]
- override def read(): Int = {
- val n = read(temp)
- if (n <= 0) {
- -1
- } else {
- // Signed byte to unsigned integer
- temp(0) & 0xff
- }
- }
+ // Receive the pickled AnalyzeResult buffer, if any.
+ val pickledAnalyzeResult: Array[Byte] = PythonWorkerUtils.readBytes(dataIn)
- override def read(b: Array[Byte], off: Int, len: Int): Int = {
- val buf = ByteBuffer.wrap(b, off, len)
- var n = 0
- while (n == 0) {
- worker.selector.select()
- if (worker.selectionKey.isReadable) {
- n = worker.channel.read(buf)
- }
- if (worker.selectionKey.isWritable) {
- var acceptsInput = true
- while (acceptsInput && buffer.hasRemaining) {
- val n = worker.channel.write(buffer)
- acceptsInput = n > 0
- }
- if (!buffer.hasRemaining) {
- // We no longer have any data to write to the socket.
- worker.selectionKey.interestOps(SelectionKey.OP_READ)
- }
- }
+ // Receive whether the "with single partition" property is requested.
+ val withSinglePartition = dataIn.readInt() == 1
+ // Receive the list of requested partitioning columns, if any.
+ val partitionByColumns = ArrayBuffer.empty[Expression]
+ val numPartitionByColumns = dataIn.readInt()
+ for (_ <- 0 until numPartitionByColumns) {
+ val columnName = PythonWorkerUtils.readUTF(dataIn)
+ partitionByColumns.append(UnresolvedAttribute(columnName))
+ }
+ // Receive the list of requested ordering columns, if any.
+ val orderBy = ArrayBuffer.empty[SortOrder]
+ val numOrderByItems = dataIn.readInt()
+ for (_ <- 0 until numOrderByItems) {
+ val columnName = PythonWorkerUtils.readUTF(dataIn)
+ val direction = if (dataIn.readInt() == 1) Ascending else Descending
+ val overrideNullsFirst = dataIn.readInt()
+ overrideNullsFirst match {
+ case 0 =>
+ orderBy.append(SortOrder(UnresolvedAttribute(columnName), direction))
+ case 1 => orderBy.append(
+ SortOrder(UnresolvedAttribute(columnName), direction, NullsFirst, Seq.empty))
+ case 2 => orderBy.append(
+ SortOrder(UnresolvedAttribute(columnName), direction, NullsLast, Seq.empty))
}
- n
}
+ PythonUDTFAnalyzeResult(
+ schema = schema,
+ withSinglePartition = withSinglePartition,
+ partitionByExpressions = partitionByColumns.toSeq,
+ orderByExpressions = orderBy.toSeq,
+ pickledAnalyzeResult = pickledAnalyzeResult)
}
}