Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/15 03:35:23 UTC

[spark] branch branch-3.4 updated: [SPARK-42765][CONNECT][PYTHON] Enable importing `pandas_udf` from `pyspark.sql.connect.functions`

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new d92e5a5a683 [SPARK-42765][CONNECT][PYTHON] Enable importing `pandas_udf` from `pyspark.sql.connect.functions`
d92e5a5a683 is described below

commit d92e5a5a683a1287681178c4b5862839922dd2fc
Author: Xinrong Meng <xi...@apache.org>
AuthorDate: Wed Mar 15 12:34:53 2023 +0900

    [SPARK-42765][CONNECT][PYTHON] Enable importing `pandas_udf` from `pyspark.sql.connect.functions`
    
    ### What changes were proposed in this pull request?
    Enable users to import pandas_udf via `pyspark.sql.connect.functions.pandas_udf`.
    Previously, only `pyspark.sql.functions.pandas_udf` was supported.
    
    ### Why are the changes needed?
    Usability.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Now users can import pandas_udf via `pyspark.sql.connect.functions.pandas_udf`.
    
    Previously, only `pyspark.sql.functions.pandas_udf` was supported in Connect; calling `pyspark.sql.connect.functions.pandas_udf` raised a `NotImplementedError`, as shown below:
    
    ```python
    >>> pyspark.sql.connect.functions.pandas_udf()
    Traceback (most recent call last):
    ...
    NotImplementedError: pandas_udf() is not implemented.
    ```
    
    Now, `pyspark.sql.connect.functions.pandas_udf` points to `pyspark.sql.functions.pandas_udf`, as shown below:
    ```python
    >>> from pyspark.sql.connect import functions as CF
    >>> from pyspark.sql import functions as SF
    >>> getattr(CF, "pandas_udf")
    <function pandas_udf at 0x7f9c88812700>
    >>> getattr(SF, "pandas_udf")
    <function pandas_udf at 0x7f9c88812700>
    ```
    
    ### How was this patch tested?
    Unit test.
    
    Closes #40388 from xinrong-meng/rmv_path.
    
    Authored-by: Xinrong Meng <xi...@apache.org>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 149e020a5ca88b2db9c56a9d48e0c1c896b57069)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/functions.py                   | 9 +++++----
 python/pyspark/sql/tests/connect/test_connect_function.py | 9 +++------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index c89b0ad3fc0..e8bb06a3903 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -54,6 +54,11 @@ from pyspark.sql.connect.udf import _create_udf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
 
+# The implementation of pandas_udf is embedded in pyspark.sql.functions.pandas_udf
+# for code reuse.
+from pyspark.sql.functions import pandas_udf  # noqa: F401
+
+
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
         ColumnOrName,
@@ -2466,10 +2471,6 @@ def udf(
 udf.__doc__ = pysparkfuncs.udf.__doc__
 
 
-def pandas_udf(*args: Any, **kwargs: Any) -> None:
-    raise NotImplementedError("pandas_udf() is not implemented.")
-
-
 def _test() -> None:
     import sys
     import doctest
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index 599e595af62..a984bba1b66 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -2393,14 +2393,11 @@ class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, S
             sdf.withColumn("A", sfun(sdf.c)).toPandas(),
         )
 
-    def test_unsupported_functions(self):
-        # SPARK-41928: Disable unsupported functions.
-
+    def test_pandas_udf_import(self):
         from pyspark.sql.connect import functions as CF
+        from pyspark.sql import functions as SF
 
-        for f in ("pandas_udf",):
-            with self.assertRaises(NotImplementedError):
-                getattr(CF, f)()
+        self.assert_eq(getattr(CF, "pandas_udf"), getattr(SF, "pandas_udf"))
 
 
 if __name__ == "__main__":

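The re-export added to `functions.py` in the diff above can be illustrated without a Spark installation. The sketch below is only an analogy: `demo_functions` and `demo_connect_functions` are hypothetical stand-ins for `pyspark.sql.functions` and `pyspark.sql.connect.functions`, and the `pandas_udf` defined here is a placeholder, not PySpark's implementation. It shows why an identity check like the one in `test_pandas_udf_import` can succeed: a `from ... import` statement binds the same function object under a second module.

```python
import sys
import types


def pandas_udf(f=None, returnType=None):
    """Placeholder only; the real implementation lives in PySpark."""
    return f


# Hypothetical stand-in for pyspark.sql.functions, which owns the implementation.
demo_functions = types.ModuleType("demo_functions")
demo_functions.pandas_udf = pandas_udf
sys.modules["demo_functions"] = demo_functions

# Hypothetical stand-in for pyspark.sql.connect.functions, which only re-exports it.
demo_connect_functions = types.ModuleType("demo_connect_functions")
exec(
    "from demo_functions import pandas_udf  # noqa: F401",
    demo_connect_functions.__dict__,
)

# Both module attributes are bound to one function object.
same_object = demo_connect_functions.pandas_udf is demo_functions.pandas_udf
print(same_object)
```

Because no wrapper is created, the docstring and signature of the original function are shared by both names, which is presumably the "code reuse" the comment in the diff refers to.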
