Posted to reviews@spark.apache.org by "dtenedor (via GitHub)" <gi...@apache.org> on 2023/05/27 00:04:48 UTC

[GitHub] [spark] dtenedor commented on a diff in pull request #41316: [SPARK-43798][SQL][PYTHON] Support Python user-defined table functions

dtenedor commented on code in PR #41316:
URL: https://github.com/apache/spark/pull/41316#discussion_r1207487409


##########
python/pyspark/sql/functions.py:
##########
@@ -10403,6 +10405,15 @@ def udf(
         return _create_py_udf(f=f, returnType=returnType, useArrow=useArrow)
 
 
+def udtf(

Review Comment:
   This is awesome! We will want a docstring to document it. Maybe we could start with something like:
   
   ```
       """Creates a user-defined table function (UDTF).
   
       .. versionadded:: 3.5.0
   
       Parameters
       ----------
       <fill this in>
   
       Examples
       --------
       >>> # Implement the UDTF class
       >>> class TestUDTF:
       ...   def __init__(self):
       ...     <logic>
       ...   def eval(self, *args):
       ...     yield "hello", "world"
       ...   def terminate(self):
       ...     <logic>
       ...
       >>>
       >>> # Create the UDTF
       >>> from pyspark.sql.functions import udtf
       >>>
       >>> test_udtf = udtf(TestUDTF, returnType="c1: string, c2: string")
       >>>
       >>> # Invoke the UDTF
       >>> test_udtf().show()
       +-----+-----+
       |   c1|   c2|
       +-----+-----+
       |hello|world|
       +-----+-----+
   
       >>> # Register the UDTF
       >>> _ = spark.udtf.register(name="test_udtf", f=test_udtf)
       >>>
       >>> # Invoke the UDTF in SQL
       >>> spark.sql("SELECT * FROM test_udtf()").show()
       +-----+-----+
       |   c1|   c2|
       +-----+-----+
       |hello|world|
       +-----+-----+
     
       Notes
       -----
       User-defined table functions are considered opaque to the optimizer by default.
       As a result, operations such as filters from WHERE clauses or limits from
       LIMIT/OFFSET clauses that appear after the UDTF call execute on the UDTF's
       result relation. By the same token, any relation passed as input to a UDTF
       is planned as a full table scan unless filtering or other logic is written
       explicitly in a table subquery surrounding that input relation.
   ```
   



##########
python/pyspark/worker.py:
##########
@@ -456,6 +456,54 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+def read_udtf(pickleSer, infile, eval_type):

Review Comment:
   should we add a high-level comment here explaining how this logic works?
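One possible starting point for that comment is a plain-Python model of the flow in `read_udtf` as it appears in this diff: instantiate the handler once, pick each call's arguments out of the input row by `arg_offsets`, and materialize everything `eval` yields for that row into one tuple. This is a sketch only: `evaluate_rows` is a hypothetical name, input rows are assumed to be plain tuples, and the return-type conversion and serialization steps are omitted.

```python
from typing import Any, Iterable, Iterator, Sequence, Tuple


def evaluate_rows(
    handler: type,
    arg_offsets: Sequence[int],
    rows: Iterable[Sequence[Any]],
) -> Iterator[Tuple[Any, ...]]:
    """Hypothetical model of the per-row evaluation set up by read_udtf."""
    # One handler instance serves every row this generator processes.
    udtf = handler()
    if not callable(getattr(udtf, "eval", None)):
        raise RuntimeError("Python UDTF must implement the eval method.")
    for row in rows:
        # Select this call's arguments from the input row by offset.
        args = [row[o] for o in arg_offsets]
        # Materialize every row eval yields: one tuple of output rows per input row.
        yield tuple(udtf.eval(*args))


class SimpleUDTF:
    def eval(self, a, b):
        yield a, b, a + b
        yield a, b, a - b


print(list(evaluate_rows(SimpleUDTF, [0, 1], [(1, 2)])))
# [((1, 2, 3), (1, 2, -1))]
```

The "one tuple per input row" shape mirrors the `mapper`/`func` pair in the diff, which is why the worker comment calls the result an iterator of iterators.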



##########
python/pyspark/sql/udtf.py:
##########
@@ -0,0 +1,209 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+User-defined table function related classes and functions
+"""
+import sys
+from typing import Type, TYPE_CHECKING, Optional, Union
+
+from py4j.java_gateway import JavaObject
+
+from pyspark.sql.column import _to_java_column, _to_seq
+from pyspark.sql.types import StructType, _parse_datatype_string
+from pyspark.sql.udf import _wrap_function
+
+if TYPE_CHECKING:
+    from pyspark.sql._typing import ColumnOrName
+    from pyspark.sql.dataframe import DataFrame
+    from pyspark.sql.session import SparkSession
+
+__all__ = ["UDTFRegistration"]
+
+
+def _create_udtf(
+    f: Type,
+    returnType: Union[StructType, str],
+    name: Optional[str] = None,
+    deterministic: bool = True,
+) -> "UserDefinedTableFunction":
+    """Create a Python UDTF."""
+    udtf_obj = UserDefinedTableFunction(
+        f, returnType=returnType, name=name, deterministic=deterministic
+    )
+    return udtf_obj
+
+
+class UserDefinedTableFunction:
+    """
+    User-defined table function in Python
+
+    .. versionadded:: 3.5
+
+    Notes
+    -----
+    The constructor of this class is not supposed to be directly called.
+    Use :meth:`pyspark.sql.functions.udtf` to create this instance.
+
+    This API is evolving.
+    """
+
+    def __init__(
+        self,
+        func: Type,
+        returnType: Union[StructType, str],
+        name: Optional[str] = None,
+        deterministic: bool = True,
+    ):
+        if not isinstance(func, type):
+            raise TypeError(
+                f"Invalid user-defined table function: the function handler "
+                f"must be a class, but got {type(func)}."
+            )
+
+        # TODO: add more checks for invalid user-defined table functions.

Review Comment:
   maybe we could check that it contains a method named `eval` with the expected arguments as well?
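A sketch of what such creation-time checks could look like, assuming they run on the class before any instance exists. `validate_udtf_handler` is a hypothetical helper, not an existing PySpark function, and treating `terminate` as optional is an assumption. At creation time only the presence of `eval` can be verified; the number of arguments is only known at the call site.

```python
from typing import Any


def validate_udtf_handler(func: Any) -> None:
    """Hypothetical creation-time checks for a UDTF handler class."""
    if not isinstance(func, type):
        raise TypeError(
            "Invalid user-defined table function: the function handler "
            f"must be a class, but got {type(func)}."
        )
    # eval is mandatory: it produces the output rows for each input row.
    if not callable(getattr(func, "eval", None)):
        raise TypeError(
            "Invalid user-defined table function: the handler class "
            f"{func.__name__} must define an 'eval' method."
        )
    # terminate is assumed optional, but if present it must be callable.
    terminate = getattr(func, "terminate", None)
    if terminate is not None and not callable(terminate):
        raise TypeError(
            "Invalid user-defined table function: 'terminate' on "
            f"{func.__name__} must be a method."
        )


class GoodUDTF:
    def eval(self, a: int):
        yield (a,)


validate_udtf_handler(GoodUDTF)  # passes silently
```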



##########
python/pyspark/worker.py:
##########
@@ -456,6 +456,54 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+def read_udtf(pickleSer, infile, eval_type):
+    num_udtfs = read_int(infile)
+    if num_udtfs != 1:
+        raise RuntimeError("Got more than 1 UDTF")
+
+    # See `PythonUDFRunner.writeUDFs`.
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for _ in range(num_arg)]
+    num_chained_funcs = read_int(infile)
+    if num_chained_funcs != 1:
+        raise RuntimeError("Got more than 1 chained UDTF")
+
+    handler, return_type = read_command(pickleSer, infile)
+    if not isinstance(handler, type):
+        raise RuntimeError(f"UDTF handler must be a class, but got {type(handler)}.")
+
+    # Instantiate the UDTF class.
+    try:
+        udtf = handler()
+    except Exception as e:
+        raise RuntimeError(f"Failed to init the UDTF handler: {str(e)}") from None
+
+    # Wrap the eval method.
+    if not hasattr(udtf, "eval"):
+        raise RuntimeError("Python UDTF must implement the eval method.")
+
+    def wrap_udtf(f, return_type):
+        if return_type.needConversion():
+            toInternal = return_type.toInternal
+            return lambda *a: map(toInternal, f(*a))
+        else:
+            return lambda *a: f(*a)
+
+    f = wrap_udtf(getattr(udtf, "eval"), return_type)
+
+    def mapper(a):
+        results = tuple(f(*[a[o] for o in arg_offsets]))
+        return results
+
+    # Return an iterator of iterators.
+    def func(_, it):
+        return map(mapper, it)
+
+    ser = BatchedSerializer(CPickleSerializer(), 100)

Review Comment:
   what does this 100 represent? Could we leave a comment and/or move it to a separate named constant?
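For reference, the second argument of `pyspark.serializers.BatchedSerializer` is its batch size: how many objects are grouped into one list before that list is pickled. As far as this diff shows, each object here is one input row's tuple of UDTF results (the output of `mapper`). A minimal sketch of the named-constant version and of the grouping it controls; `UDTF_OUTPUT_BATCH_SIZE` and `batched` are hypothetical names, and `batched` only mirrors the idea rather than PySpark's implementation.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

# Hypothetical constant: number of per-input-row result tuples pickled
# together in one batch when the worker writes UDTF output back to the JVM.
# In the worker this might read:
#   ser = BatchedSerializer(CPickleSerializer(), UDTF_OUTPUT_BATCH_SIZE)
UDTF_OUTPUT_BATCH_SIZE = 100


def batched(items: Iterable[T], batch_size: int = UDTF_OUTPUT_BATCH_SIZE) -> Iterator[List[T]]:
    """Group items into lists of at most batch_size elements."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    it = iter(items)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


print([len(b) for b in batched(range(250))])  # [100, 100, 50]
```

A larger batch means fewer, bigger pickle payloads per round trip; a named constant makes that trade-off visible where it is chosen.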



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonUDTFExec.scala:
##########
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+
+import net.razorvine.pickle.Unpickler
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A physical plan that evaluates a [[PythonUDTF]].

Review Comment:
   should we explain what the `requiredChildOutput` and `resultAttrs` represent?
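To make the question concrete: assuming this operator follows the same convention as Spark's `Generate` (output = `requiredChildOutput ++ resultAttrs`), `requiredChildOutput` would be the child columns that operators above still need, and `resultAttrs` the columns produced by the UDTF according to its return type. The hypothetical Python model below only illustrates that column layout for a lateral join; `lateral_udtf_rows` and the `udtf` name in the SQL comment are made up, and this is not how the physical operator is implemented.

```python
from typing import Any, Callable, Dict, Iterable, List, Sequence, Tuple


def lateral_udtf_rows(
    child_rows: Iterable[Dict[str, Any]],
    required_child_output: Sequence[str],
    result_attrs: Sequence[str],
    udtf_eval: Callable[..., Iterable[Tuple[Any, ...]]],
    arg_columns: Sequence[str],
) -> Tuple[List[str], List[Tuple[Any, ...]]]:
    """Hypothetical model of output = requiredChildOutput ++ resultAttrs."""
    output_columns = list(required_child_output) + list(result_attrs)
    output_rows: List[Tuple[Any, ...]] = []
    for row in child_rows:
        # Child columns that still need to flow to operators above the UDTF.
        kept = tuple(row[c] for c in required_child_output)
        # Each UDTF result row is appended after the kept child columns.
        for result in udtf_eval(*(row[c] for c in arg_columns)):
            if len(result) != len(result_attrs):
                raise ValueError("UDTF result width does not match resultAttrs")
            output_rows.append(kept + tuple(result))
    return output_columns, output_rows


# SELECT t.a, f.* FROM t, LATERAL udtf(a, b) f: only `a` is required from t.
t = [{"a": 0, "b": 1}, {"a": 1, "b": 2}]
cols, rows = lateral_udtf_rows(
    t, ["a"], ["x", "y", "z"], lambda a, b: [(a, b, a + b)], ["a", "b"]
)
print(cols)  # ['a', 'x', 'y', 'z']
print(rows)  # [(0, 0, 1, 1), (1, 1, 2, 3)]
```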



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonUDTFSuite.scala:
##########
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.sql.{IntegratedUDFTestUtils, QueryTest, Row}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.test.SharedSparkSession
+
+class PythonUDTFSuite extends QueryTest with SharedSparkSession {
+
+  import testImplicits._
+
+  import IntegratedUDFTestUtils._
+
+  val pythonTestUDTF = TestPythonUDTF(name = "pyUDTF")
+
+  test("Simple PythonUDTF") {
+    // scalastyle:off assume
+    assume(shouldTestPythonUDFs)
+    // scalastyle:on assume
+    val df = pythonTestUDTF(spark, lit(1), lit(2))
+    checkAnswer(df, Seq(Row(1, 2, -1), Row(1, 2, 1), Row(1, 2, 3)))
+  }
+
+  test("PythonUDTF with lateral join") {
+    // scalastyle:off assume
+    assume(shouldTestPythonUDFs)
+    // scalastyle:on assume
+    withTempView("t") {
+      val func = createUserDefinedPythonTableFunction("testUDTF")
+      spark.udtf.registerPython("testUDTF", func)
+      Seq((0, 1), (1, 2)).toDF("a", "b").createOrReplaceTempView("t")
+      checkAnswer(
+        sql("SELECT f.* FROM t, LATERAL testUDTF(a, b) f"),

Review Comment:
   more test ideas:
   
   * call the UDTF with a combination of scalar constant values (e.g. `lit(1)`) and correlated references to columns from the left side of the lateral join
   * UDTF with a scalar python UDF call as one of the arg(s)
   * test a UDTF whose `terminate` method yields additional rows
   * call the UDTF with a number or types of input arguments that it does not expect: do we throw an appropriate exception?
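For the first idea, the expected answer can be computed in plain Python from the `SimpleUDTF` script used by `IntegratedUDFTestUtils`. The query shape is hypothetical (a correlated reference to `a` from `t` plus a scalar constant as the second argument), and `expected_lateral_rows` is a hypothetical helper for building the `checkAnswer` rows.

```python
from typing import List, Tuple


class SimpleUDTF:
    # Same eval logic as the SimpleUDTF test script (the _count field is omitted).
    def eval(self, a: int, b: int):
        yield a, b, a + b
        yield a, b, a - b
        yield a, b, b - a


def expected_lateral_rows(
    t: List[Tuple[int, int]], const_b: int
) -> List[Tuple[int, int, int]]:
    """Expected rows for a hypothetical query of the form
    SELECT f.* FROM t, LATERAL testUDTF(a, <const_b>) f,
    i.e. a correlated column reference plus a scalar constant."""
    udtf = SimpleUDTF()
    return [row for a, _b in t for row in udtf.eval(a, const_b)]


print(expected_lateral_rows([(0, 1), (1, 2)], 2))
# [(0, 2, 2), (0, 2, -2), (0, 2, 2), (1, 2, 3), (1, 2, -1), (1, 2, 1)]
```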



##########
sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala:
##########
@@ -191,6 +191,59 @@ object IntegratedUDFTestUtils extends SQLHelper {
     throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
   }
 
+  private def createPythonUDTF(funcName: String, pythonScript: String): Array[Byte] = {
+    if (shouldTestPythonUDFs) {
+      var binaryPandasFunc: Array[Byte] = null
+      withTempPath { codePath =>
+        Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+        withTempPath { path =>
+          Process(
+            Seq(
+              pythonExec,
+              "-c",
+              "from pyspark.serializers import CloudPickleSerializer; " +
+                s"f = open('$path', 'wb');" +
+                s"exec(open('$codePath', 'r').read());" +
+                "f.write(CloudPickleSerializer().dumps(" +
+                s"($funcName, returnType)))"),
+            None,
+            "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+          binaryPandasFunc = Files.readAllBytes(path.toPath)
+        }
+      }
+      assert(binaryPandasFunc != null)
+      binaryPandasFunc
+    } else {
+      throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
+    }
+  }
+
+  private lazy val pythonTableFunc: Array[Byte] = {
+    val script =
+      """
+        |from pyspark.sql.types import StructType, StructField, IntegerType
+        |returnType = StructType([
+        |  StructField("a", IntegerType()),
+        |  StructField("b", IntegerType()),
+        |  StructField("c", IntegerType()),
+        |])
+        |class SimpleUDTF:
+        |    def __init__(self):
+        |        self._count = 0
+        |
+        |    def eval(self, a: int, b: int):
+        |        self._count += 1
+        |        yield a, b, a + b
+        |        yield a, b, a - b
+        |        yield a, b, b - a
+        |
+        |    def terminate(self):
+        |        self._count = 0

Review Comment:
   do we plan to reuse the class instance across evaluations of multiple input partitions? I imagine it would be simpler for users if we did not: then they would not need resetting logic like this, and could simply assume that the `__init__` method freshly initializes any class members as needed.
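A sketch of the simpler contract described above, using a hypothetical driver `run_partition` that creates a fresh handler instance per partition, calls `eval` for each row, and then calls `terminate`. Whether `terminate` may yield rows is an assumption here (it is one of the test ideas raised in this review), and `CountingUDTF` is a made-up example. Because `__init__` runs once per partition, `terminate` needs no reset logic.

```python
from typing import Any, Iterable, Iterator, Sequence, Tuple


def run_partition(handler: type, rows: Iterable[Sequence[Any]]) -> Iterator[Tuple[Any, ...]]:
    """Hypothetical per-partition driver: a fresh instance per partition."""
    udtf = handler()  # __init__ runs once per partition, so no reset is needed
    for row in rows:
        yield from udtf.eval(*row)
    terminate = getattr(udtf, "terminate", None)
    if terminate is not None:
        # Assumption: terminate may return an iterable of final rows, or None.
        result = terminate()
        if result is not None:
            yield from result


class CountingUDTF:
    def __init__(self):
        self._count = 0

    def eval(self, a: int, b: int):
        self._count += 1
        yield a, b, self._count

    def terminate(self):
        # Emits one summary row; no need to reset self._count afterwards.
        yield -1, -1, self._count


for part in [[(1, 2), (3, 4)], [(5, 6)]]:
    print(list(run_partition(CountingUDTF, part)))
# [(1, 2, 1), (3, 4, 2), (-1, -1, 2)]
# [(5, 6, 1), (-1, -1, 1)]
```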



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala:
##########
@@ -171,6 +186,18 @@ case class ArrowEvalPython(
     copy(child = newChild)
 }
 
+/**
+ * A logical plan that evaluates a [[PythonUDTF]].

Review Comment:
   should we explain what the `requiredChildOutput` and `resultAttrs` represent here?



##########
python/pyspark/worker.py:
##########
@@ -456,6 +456,54 @@ def assign_cols_by_name(runner_conf):
     )
 
 
+def read_udtf(pickleSer, infile, eval_type):
+    num_udtfs = read_int(infile)
+    if num_udtfs != 1:
+        raise RuntimeError("Got more than 1 UDTF")
+
+    # See `PythonUDFRunner.writeUDFs`.
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for _ in range(num_arg)]
+    num_chained_funcs = read_int(infile)
+    if num_chained_funcs != 1:
+        raise RuntimeError("Got more than 1 chained UDTF")
+
+    handler, return_type = read_command(pickleSer, infile)
+    if not isinstance(handler, type):
+        raise RuntimeError(f"UDTF handler must be a class, but got {type(handler)}.")
+
+    # Instantiate the UDTF class.
+    try:
+        udtf = handler()
+    except Exception as e:
+        raise RuntimeError(f"Failed to init the UDTF handler: {str(e)}") from None
+
+    # Wrap the eval method.
+    if not hasattr(udtf, "eval"):

Review Comment:
   should we also check that the "eval" method has the expected number and type of input arguments?
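The argument count can be checked with `inspect.signature(...).bind(...)` once `num_arg` has been read. A minimal sketch, with `check_eval_arity` as a hypothetical helper; it checks arity only, since validating types would require mapping Python annotations to Spark SQL types, which is left out here.

```python
import inspect
from typing import Any


def check_eval_arity(udtf: Any, num_arg: int) -> None:
    """Hypothetical check that udtf.eval accepts num_arg positional arguments."""
    eval_method = getattr(udtf, "eval", None)
    if not callable(eval_method):
        raise RuntimeError("Python UDTF must implement the eval method.")
    try:
        sig = inspect.signature(eval_method)  # bound method: 'self' is excluded
    except (TypeError, ValueError):
        return  # signature not introspectable; skip the check
    try:
        # Bind placeholder values; only the count matters, not the values.
        sig.bind(*([None] * num_arg))
    except TypeError as e:
        raise RuntimeError(
            f"UDTF eval method {sig} cannot be called with {num_arg} argument(s): {e}"
        ) from None


class TwoArgUDTF:
    def eval(self, a: int, b: int):
        yield a, b


check_eval_arity(TwoArgUDTF(), 2)  # passes silently
```

Variadic signatures such as `eval(self, *args)` accept any count, so they pass this check by design.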



##########
sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala:
##########
@@ -191,6 +191,59 @@ object IntegratedUDFTestUtils extends SQLHelper {
     throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
   }
 
+  private def createPythonUDTF(funcName: String, pythonScript: String): Array[Byte] = {
+    if (shouldTestPythonUDFs) {

Review Comment:
   reverse the logic and throw the exception early, so the rest of the block can be dedented?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


---------------------------------------------------------------------
For additional commands, e-mail: reviews-help@spark.apache.org