Posted to reviews@spark.apache.org by "dtenedor (via GitHub)" <gi...@apache.org> on 2023/07/13 23:43:09 UTC

[GitHub] [spark] dtenedor commented on a diff in pull request #41948: [SPARK-44380][SQL][PYTHON] Support for Python UDTF to analyze in Python

dtenedor commented on code in PR #41948:
URL: https://github.com/apache/spark/pull/41948#discussion_r1263116521


##########
python/pyspark/errors/error_classes.py:
##########
@@ -277,6 +277,11 @@
       "The UDTF '<name>' is invalid. It does not implement the required 'eval' method. Please implement the 'eval' method in '<name>' and try again."
     ]
   },
+  "INVALID_UDTF_RETURN_TYPE" : {
+    "message" : [
+      "The UDTF '<name>' is invalid. It does not specify its return type or implement the required 'analysis' static method. Please specify the return type or implement the 'analyze' static method in '<name>' and try again."

Review Comment:
   ```suggestion
         "The UDTF '<name>' is invalid. It does not specify its return type or implement the required 'analyze' static method. Please specify the return type or implement the 'analyze' static method in '<name>' and try again."
   ```



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -719,6 +726,153 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        rows = func().collect()

Review Comment:
   Can you please also add a test for each of these cases using SQL (via PySpark), to cover the end-to-end query processing steps?
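   For reference, a minimal sketch of what one such SQL-based test could look like, to sit alongside the tests above (the test name is illustrative, and it assumes a no-argument UDTF can be invoked as `test_udtf()` from SQL):
   ```python
   def test_simple_udtf_with_analyze_in_sql(self):
       class TestUDTF:
           @staticmethod
           def analyze() -> StructType:
               return StructType().add("c1", StringType()).add("c2", StringType())

           def eval(self):
               yield "hello", "world"

       # Register the UDTF so it can be referenced by name in a SQL query.
       self.spark.udtf.register("test_udtf", udtf(TestUDTF))
       rows = self.spark.sql("SELECT * FROM test_udtf()").collect()
       self.assertEqual(rows, [Row(c1="hello", c2="world")])
   ```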



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -719,6 +726,153 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], DataType)
+                assert a["value"] is not None
+                assert a["is_table"] is False
+                return StructType().add("a", a["data_type"])
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+
+        df1 = func(lit(1))
+        self.assertEquals(df1.schema, StructType().add("a", IntegerType()))
+        self.assertEqual(df1.collect(), [Row(a=1)])
+
+        df2 = func(lit("x"))
+        self.assertEquals(df2.schema, StructType().add("a", StringType()))
+        self.assertEqual(df2.collect(), [Row(a="x")])
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b) -> StructType:
+                return StructType().add("a", a["data_type"]).add("b", b["data_type"])
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+
+        df = func(lit(1), lit("x"))
+        self.assertEquals(df.schema, StructType().add("a", IntegerType()).add("b", StringType()))
+        self.assertEqual(df.collect(), [Row(a=1, b="x")])
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], StructType)
+                assert a["value"] is None
+                assert a["is_table"] is True
+                return StructType().add("a", a["data_type"][0].dataType)
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_neither_return_type_nor_analyze(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_non_static_analyze(self):
+        class TestUDTF:
+            def analyze(self) -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_analyze_returning_non_struct(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze():
+                return StringType()
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "Output of `analyze` static method of Python UDTFs expects a StructType "
+            "but got: <class 'pyspark.sql.types.StringType'>",
+        ):
+            func().collect()
+
+    def test_udtf_with_analyze_taking_wrong_number_of_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b) -> StructType:
+                return StructType().add("a", a["data_type"]).add("b", b["data_type"])
+
+            def eval(self, a):
+                yield a, a + 1
+
+        func = udtf(TestUDTF)
+
+        with self.assertRaisesRegex(
+            AnalysisException, r"analyze\(\) missing 1 required positional argument: 'b'"

Review Comment:
   Can you update this error message to say something like "The table-valued function <functionName>" instead of "analyze()", to help the user understand what to do?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.
+   */
+  def analyzeInPython(func: PythonFunction, exprs: Seq[Expression]): StructType = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,

Review Comment:
   ```suggestion
       val pickler = new Pickler(useMemo = true,
   ```



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.
+   */
+  def analyzeInPython(func: PythonFunction, exprs: Seq[Expression]): StructType = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+
+    val (worker: Socket, _) =
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+    val releasedOrClosed = new AtomicBoolean(false)
+    try {
+      val dataOut =
+        new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize))
+      val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+      // Python version of driver
+      PythonRDD.writeUTF(pythonVer, dataOut)
+
+      // Send Python UDTF
+      dataOut.writeInt(func.command.length)
+      dataOut.write(func.command.toArray)
+
+      // Send arguments
+      dataOut.writeInt(exprs.length)
+      exprs.foreach { expr =>
+        PythonRDD.writeUTF(expr.dataType.json, dataOut)
+        if (expr.foldable) {
+          dataOut.writeBoolean(true)
+          val obj = pickler.dumps(EvaluatePython.toJava(expr.eval(), expr.dataType))
+          dataOut.writeInt(obj.length)
+          dataOut.write(obj)
+        } else {
+          dataOut.writeBoolean(false)
+        }
+        dataOut.writeBoolean(expr.isInstanceOf[FunctionTableSubqueryArgumentExpression])
+      }
+
+      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+      dataOut.flush()
+
+      // Receive the schema
+      val schema = dataIn.readInt() match {
+        case length if length >= 0 =>
+          val obj = new Array[Byte](length)
+          dataIn.readFully(obj)
+          DataType.fromJson(new String(obj, StandardCharsets.UTF_8)).asInstanceOf[StructType]
+
+        case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+          val exLength = dataIn.readInt()
+          val obj = new Array[Byte](exLength)
+          dataIn.readFully(obj)
+          val msg = new String(obj, StandardCharsets.UTF_8)
+          env.destroyPythonWorker(pythonExec, workerModule, envVars.asScala.toMap, worker)
+          throw new AnalysisException(msg)

Review Comment:
   should we make a new error class for this?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.
+   */
+  def analyzeInPython(func: PythonFunction, exprs: Seq[Expression]): StructType = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)

Review Comment:
   ```suggestion
         valueCompare = false)
   ```



##########
python/pyspark/sql/functions.py:
##########
@@ -15524,9 +15524,10 @@ def udtf(
     ----------
     cls : class
         the Python user-defined table function handler class.
-    returnType : :class:`pyspark.sql.types.StructType` or str
+    returnType : :class:`pyspark.sql.types.StructType` or str, optional
         the return type of the user-defined table function. The value can be either a
         :class:`pyspark.sql.types.StructType` object or a DDL-formatted struct type string.
+        If None, the handler class must provide `analyze` static method.

Review Comment:
   Can you please port the description of how the 'analyze' method works from the PR description to a new section right above "Notes" on L15580? This will give us a place in the code to document how this works.
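   For example, a rough sketch of what that section could say (the wording is illustrative and only reflects the behavior exercised by the tests in this PR):
   ```python
   # Illustrative docstring wording only; not the final text:
   #
   #     If ``returnType`` is None, the handler class must define a static
   #     ``analyze`` method. It is called during query analysis with one
   #     argument per UDTF call argument; each argument is a dict containing:
   #
   #     * "data_type": the argument's :class:`DataType`
   #     * "value": the argument's value if it is a foldable (constant)
   #       expression, otherwise None
   #     * "is_table": True if the argument is a TABLE argument
   #
   #     ``analyze`` must return a :class:`StructType` describing the output
   #     schema of the UDTF.
   ```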



##########
python/pyspark/sql/worker/analyze_udtf.py:
##########
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import os
+import sys
+import traceback
+from typing import IO
+
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    write_int,
+    write_with_length,
+    CPickleSerializer,
+    SpecialLengths,
+    UTF8Deserializer,
+)
+from pyspark.sql.types import StructType, _parse_datatype_json_string
+from pyspark.util import try_simplify_traceback
+from pyspark.worker import read_command
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def main(infile: IO, outfile: IO) -> None:

Review Comment:
   Can you put a high-level comment here describing that this runs in the Python interpreter and is responsible for receiving the call to the 'analyze' method, deserializing the arguments, invoking the method, serializing the result, and sending it back?
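   For example, something along these lines (a sketch only; the protocol details are inferred from the JVM side of this PR and should match the final implementation):
   ```python
   from typing import IO


   def main(infile: IO, outfile: IO) -> None:
       """
       Entry point of the UDTF 'analyze' worker. This runs in a Python worker
       process launched by the JVM during analysis of a UDTF call: it reads the
       pickled UDTF handler and the serialized argument metadata (data type,
       optional constant value, table flag) from `infile`, invokes the handler's
       static `analyze` method with those arguments, and writes the resulting
       StructType back to `outfile` as JSON, reporting any failure back to the
       JVM as an error instead.
       """
   ```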



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.

Review Comment:
   please also mention:
   * who is expected to call this
   * what parts of the Python UDTF in 'func' we inspect (i.e. the arguments)
   * how we serialize and deserialize the arguments, and the 'analyze' result type
   * which error conditions may occur and which exceptions we raise under those circumstances



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -719,6 +726,153 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], DataType)
+                assert a["value"] is not None
+                assert a["is_table"] is False
+                return StructType().add("a", a["data_type"])
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+
+        df1 = func(lit(1))
+        self.assertEquals(df1.schema, StructType().add("a", IntegerType()))
+        self.assertEqual(df1.collect(), [Row(a=1)])
+
+        df2 = func(lit("x"))
+        self.assertEquals(df2.schema, StructType().add("a", StringType()))
+        self.assertEqual(df2.collect(), [Row(a="x")])
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b) -> StructType:
+                return StructType().add("a", a["data_type"]).add("b", b["data_type"])
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+
+        df = func(lit(1), lit("x"))
+        self.assertEquals(df.schema, StructType().add("a", IntegerType()).add("b", StringType()))
+        self.assertEqual(df.collect(), [Row(a=1, b="x")])
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], StructType)
+                assert a["value"] is None
+                assert a["is_table"] is True
+                return StructType().add("a", a["data_type"][0].dataType)
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_neither_return_type_nor_analyze(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_non_static_analyze(self):

Review Comment:
   this says "non-static", but the method looks static and always returns the same schema?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.
+   */
+  def analyzeInPython(func: PythonFunction, exprs: Seq[Expression]): StructType = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+
+    val (worker: Socket, _) =
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+    val releasedOrClosed = new AtomicBoolean(false)

Review Comment:
   Please add a brief comment mentioning why we need an atomic boolean here. What is the concurrency?



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala:
##########
@@ -91,3 +122,104 @@ case class UserDefinedPythonTableFunction(
     Dataset.ofRows(session, udtf)
   }
 }
+
+object UserDefinedPythonTableFunction {
+
+  private[this] val workerModule = "pyspark.sql.worker.analyze_udtf"
+
+  /**
+   * Runs the Python UDTF's `analyze` static method.
+   */
+  def analyzeInPython(func: PythonFunction, exprs: Seq[Expression]): StructType = {
+    val env = SparkEnv.get
+    val bufferSize: Int = env.conf.get(BUFFER_SIZE)
+    val authSocketTimeout = env.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+    val reuseWorker = env.conf.get(PYTHON_WORKER_REUSE)
+    val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
+
+    val envVars = new HashMap[String, String](func.envVars)
+    val pythonExec = func.pythonExec
+    val pythonVer = func.pythonVer
+
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+
+    EvaluatePython.registerPicklers()
+    val pickler = new Pickler(/* useMemo = */ true,
+      /* valueCompare = */ false)
+
+    val (worker: Socket, _) =
+      env.createPythonWorker(pythonExec, workerModule, envVars.asScala.toMap)
+    val releasedOrClosed = new AtomicBoolean(false)
+    try {
+      val dataOut =
+        new DataOutputStream(new BufferedOutputStream(worker.getOutputStream, bufferSize))
+      val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+
+      // Python version of driver
+      PythonRDD.writeUTF(pythonVer, dataOut)
+
+      // Send Python UDTF
+      dataOut.writeInt(func.command.length)
+      dataOut.write(func.command.toArray)
+
+      // Send arguments
+      dataOut.writeInt(exprs.length)
+      exprs.foreach { expr =>
+        PythonRDD.writeUTF(expr.dataType.json, dataOut)
+        if (expr.foldable) {

Review Comment:
   Here we 'eval' the expression whenever it's foldable, even if it's not a literal, and that is supported. Can you update the PR description and any class comments/docs in this PR to mention that we can provide the expression value this way not only for literal expressions, but also for constant, non-literal expressions?
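   For example, a sketch of a test for the constant-but-non-literal case (the test name is illustrative; it assumes a foldable expression such as `lit(1) + lit(1)` reaches `analyze` with its folded value, per the foldable handling above):
   ```python
   def test_udtf_with_analyze_foldable_argument(self):
       class TestUDTF:
           @staticmethod
           def analyze(a) -> StructType:
               # The argument is foldable but not a plain literal, so its
               # evaluated value should still be available here.
               assert a["value"] == 2
               assert a["is_table"] is False
               return StructType().add("a", a["data_type"])

           def eval(self, a):
               yield a,

       func = udtf(TestUDTF)
       df = func(lit(1) + lit(1))
       self.assertEqual(df.collect(), [Row(a=2)])
   ```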



##########
python/pyspark/sql/worker/analyze_udtf.py:
##########
@@ -0,0 +1,137 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import inspect
+import os
+import sys
+import traceback
+from typing import IO
+
+from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.serializers import (
+    read_bool,
+    read_int,
+    write_int,
+    write_with_length,
+    CPickleSerializer,
+    SpecialLengths,
+    UTF8Deserializer,
+)
+from pyspark.sql.types import StructType, _parse_datatype_json_string
+from pyspark.util import try_simplify_traceback
+from pyspark.worker import read_command
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def main(infile: IO, outfile: IO) -> None:
+    try:
+        # Check Python version

Review Comment:
   Optional: this is a big block of code in the `try` block; should we move it into a helper function to reduce indentation?
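   One possible shape (a skeleton only; the helper name is illustrative and the `...` bodies stand in for the existing logic):
   ```python
   def _run_analyze(infile: IO) -> StructType:
       # Current body of the `try` block, dedented one level: check the Python
       # version, read the UDTF handler and its arguments, call `analyze`, and
       # validate that the result is a StructType.
       ...


   def main(infile: IO, outfile: IO) -> None:
       try:
           result = _run_analyze(infile)
           # Existing code that writes `result` back to the JVM.
       except BaseException:
           # Existing error-reporting path, unchanged.
           ...
   ```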



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -719,6 +726,153 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], DataType)
+                assert a["value"] is not None
+                assert a["is_table"] is False
+                return StructType().add("a", a["data_type"])
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+
+        df1 = func(lit(1))
+        self.assertEquals(df1.schema, StructType().add("a", IntegerType()))
+        self.assertEqual(df1.collect(), [Row(a=1)])
+
+        df2 = func(lit("x"))
+        self.assertEquals(df2.schema, StructType().add("a", StringType()))
+        self.assertEqual(df2.collect(), [Row(a="x")])
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b) -> StructType:
+                return StructType().add("a", a["data_type"]).add("b", b["data_type"])
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+
+        df = func(lit(1), lit("x"))
+        self.assertEquals(df.schema, StructType().add("a", IntegerType()).add("b", StringType()))
+        self.assertEqual(df.collect(), [Row(a=1, b="x")])
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], StructType)
+                assert a["value"] is None
+                assert a["is_table"] is True
+                return StructType().add("a", a["data_type"][0].dataType)
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_neither_return_type_nor_analyze(self):
+        class TestUDTF:
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_non_static_analyze(self):
+        class TestUDTF:
+            def analyze(self) -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        with self.assertRaises(PySparkAttributeError) as e:
+            udtf(TestUDTF)
+
+        self.check_error(
+            exception=e.exception,
+            error_class="INVALID_UDTF_RETURN_TYPE",
+            message_parameters={"name": "TestUDTF"},
+        )
+
+    def test_udtf_with_analyze_returning_non_struct(self):

Review Comment:
   other test ideas:
   
   positive tests:
   * call the UDTF with integer types, strings, and array/struct/map types
   * call the UDTF with a relation using the TABLE keyword, where the 'analyze' method returns a relation with the same schema plus an additional column
   * call the UDTF with a constant (but non-literal) scalar integer argument, where the 'analyze' method returns a relation with the same number of columns as the value of that integer (or returns an error unless the integer is between 1 and 10)
   
   negative tests:
   * the 'analyze' method throws an explicit exception
   * the 'analyze' method gets called with a literal `NULL`
   * the 'analyze' method attempts to look up a key from the dictionary that doesn't exist
   
   A rough sketch of the first negative case is included below.
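   Sketch of the 'analyze throws an explicit exception' case (names illustrative; it assumes the Python-side error message is surfaced through the AnalysisException raised by the JVM side of this PR):
   ```python
   def test_udtf_with_analyze_raising_exception(self):
       class TestUDTF:
           @staticmethod
           def analyze(a) -> StructType:
               raise ValueError("invalid argument")

           def eval(self, a):
               yield a,

       func = udtf(TestUDTF)
       with self.assertRaisesRegex(AnalysisException, "invalid argument"):
           func(lit(1)).collect()
   ```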



##########
python/pyspark/sql/tests/test_udtf.py:
##########
@@ -719,6 +726,153 @@ def terminate(self):
         self.assertIn("Evaluate the input row", cls.eval.__doc__)
         self.assertIn("Terminate the UDTF", cls.terminate.__doc__)
 
+    def test_simple_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze() -> StructType:
+                return StructType().add("c1", StringType()).add("c2", StringType())
+
+            def eval(self):
+                yield "hello", "world"
+
+        func = udtf(TestUDTF)
+        rows = func().collect()
+        self.assertEqual(rows, [Row(c1="hello", c2="world")])
+
+    def test_udtf_with_analyze(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], DataType)
+                assert a["value"] is not None
+                assert a["is_table"] is False
+                return StructType().add("a", a["data_type"])
+
+            def eval(self, a):
+                yield a,
+
+        func = udtf(TestUDTF)
+
+        df1 = func(lit(1))
+        self.assertEquals(df1.schema, StructType().add("a", IntegerType()))
+        self.assertEqual(df1.collect(), [Row(a=1)])
+
+        df2 = func(lit("x"))
+        self.assertEquals(df2.schema, StructType().add("a", StringType()))
+        self.assertEqual(df2.collect(), [Row(a="x")])
+
+    def test_udtf_with_analyze_multiple_arguments(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a, b) -> StructType:
+                return StructType().add("a", a["data_type"]).add("b", b["data_type"])
+
+            def eval(self, a, b):
+                yield a, b
+
+        func = udtf(TestUDTF)
+
+        df = func(lit(1), lit("x"))
+        self.assertEquals(df.schema, StructType().add("a", IntegerType()).add("b", StringType()))
+        self.assertEqual(df.collect(), [Row(a=1, b="x")])
+
+    def test_udtf_with_analyze_table_argument(self):
+        class TestUDTF:
+            @staticmethod
+            def analyze(a) -> StructType:
+                assert isinstance(a, dict)
+                assert isinstance(a["data_type"], StructType)
+                assert a["value"] is None
+                assert a["is_table"] is True
+                return StructType().add("a", a["data_type"][0].dataType)
+
+            def eval(self, a: Row):
+                if a["id"] > 5:
+                    yield a["id"],
+
+        func = udtf(TestUDTF)
+        self.spark.udtf.register("test_udtf", func)
+
+        df = self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))")
+        self.assertEqual(df.schema, StructType().add("a", LongType()))
+        self.assertEqual(df.collect(), [Row(a=6), Row(a=7)])
+
+    def test_udtf_with_neither_return_type_nor_analyze(self):

Review Comment:
   can you also add a case with both a return type and an 'analyze' method?
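   A rough sketch of what that could look like, under the assumption that specifying both is meant to be rejected (the exact error class and behavior are design choices in this PR, so this is only illustrative):
   ```python
   def test_udtf_with_both_return_type_and_analyze(self):
       class TestUDTF:
           @staticmethod
           def analyze() -> StructType:
               return StructType().add("c1", StringType())

           def eval(self):
               yield "hello",

       # Hypothetical expectation: the combination is rejected; the actual
       # error class/behavior should follow whatever this PR decides.
       with self.assertRaises(PySparkAttributeError):
           udtf(TestUDTF, returnType="c1 string")
   ```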



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org

