Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/27 00:35:58 UTC

[spark] branch branch-3.4 updated: [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 31ede7330a3 [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT
31ede7330a3 is described below

commit 31ede7330a314b18faa591a9313ed31c5c8b63c1
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Mar 27 09:35:33 2023 +0900

    [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT
    
    ### What changes were proposed in this pull request?
    
    Enables tests for UDF with UDT.
    
    ### Why are the changes needed?
    
    Now that UDF with UDT is expected to work under Spark Connect, the related tests should be enabled to verify that it does.
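
    For context, a minimal hedged sketch of the pattern these tests exercise (not code from this PR): a Python UDF can both receive UDT values as plain Python objects and return them, provided its declared return type is the UDT. `Point`/`PointUDT` below are hypothetical stand-ins for the `ExamplePoint`/`ExamplePointUDT` and `PythonOnlyPoint`/`PythonOnlyUDT` helpers used in `test_types.py`.

    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    from pyspark.sql.types import DoubleType, StructField, StructType, UserDefinedType


    class Point:
        """Plain Python value carried by the UDT."""

        def __init__(self, x: float, y: float):
            self.x, self.y = x, y


    class PointUDT(UserDefinedType):
        """Minimal Python-only UDT stored as struct<x: double, y: double>."""

        @classmethod
        def sqlType(cls):
            return StructType(
                [StructField("x", DoubleType()), StructField("y", DoubleType())]
            )

        @classmethod
        def module(cls):
            # Where Point can be found when deserializing; "__main__" assumes
            # this sketch runs as a single-file script.
            return "__main__"

        def serialize(self, obj):
            return (obj.x, obj.y)

        def deserialize(self, datum):
            return Point(datum[0], datum[1])


    spark = SparkSession.builder.getOrCreate()

    schema = StructType(
        [StructField("label", DoubleType()), StructField("point", PointUDT())]
    )
    df = spark.createDataFrame([(1.0, Point(1.0, 2.0))], schema)

    # UDT in, plain double out.
    get_y = F.udf(lambda p: p.y, DoubleType())
    # UDT in, UDT out: the declared return type is the UDT instance itself.
    shift = F.udf(lambda p: Point(p.x + 1.0, p.y + 1.0), PointUDT())

    row = df.select(get_y(df.point).alias("y"), shift(df.point).alias("moved")).first()
    print(row.y, row.moved.x, row.moved.y)  # 2.0 2.0 3.0
    ```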
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Enabled/modified the related tests.
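
    As a hedged way to reproduce this locally (assuming a PySpark development environment with a Spark build and the Spark Connect dependencies available), the re-enabled parity test can be run on its own with plain `unittest`; the module and class names come from the diff below:

    ```python
    import unittest

    # TypesParityTests lives in the test file touched by this PR.
    from pyspark.sql.tests.connect.test_parity_types import TypesParityTests

    suite = unittest.TestSuite()
    suite.addTest(TypesParityTests("test_udf_with_udt"))
    unittest.TextTestRunner(verbosity=2).run(suite)
    ```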
    
    Closes #40549 from ueshin/issues/SPARK-42920/udf_with_udt.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 80f8664e8278335788d8fa1dd00654f3eaec8ed6)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../pyspark/sql/tests/connect/test_parity_types.py |  4 +--
 python/pyspark/sql/tests/test_types.py             | 38 +++++++++++-----------
 2 files changed, 21 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index a2f81fbf25e..aacf5793b2b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -84,8 +84,8 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
         super().test_infer_schema_upcast_int_to_string()
 
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
-    def test_udf_with_udt(self):
-        super().test_udf_with_udt()
+    def test_rdd_with_udt(self):
+        super().test_rdd_with_udt()
 
     @unittest.skip("Requires JVM access.")
     def test_udt(self):
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index bee899e928e..5d6476b47f4 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -25,8 +25,7 @@ import sys
 import unittest
 
 from pyspark.sql import Row
-from pyspark.sql.functions import col
-from pyspark.sql.udf import UserDefinedFunction
+from pyspark.sql import functions as F
 from pyspark.errors import AnalysisException
 from pyspark.sql.types import (
     ByteType,
@@ -381,7 +380,7 @@ class TypesTestsMixin:
         try:
             self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
             df = self.spark.createDataFrame([(1,), (11,)], ["value"])
-            ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
+            ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect()
             actual = list(map(lambda r: int(r.value), ret))
             self.assertEqual(actual, [0, 10])
         finally:
@@ -548,8 +547,6 @@ class TypesTestsMixin:
         df.collect()
 
     def test_complex_nested_udt_in_df(self):
-        from pyspark.sql.functions import udf
-
         schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
         df = self.spark.createDataFrame(
             [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema
@@ -558,7 +555,7 @@ class TypesTestsMixin:
 
         gd = df.groupby("key").agg({"val": "collect_list"})
         gd.collect()
-        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+        udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
         gd.select(udf(*gd)).collect()
 
     def test_udt_with_none(self):
@@ -667,20 +664,27 @@ class TypesTestsMixin:
     def test_udf_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
-        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        udf = F.udf(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        udf2 = F.udf(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
         self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
         row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
-        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
-        udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+        udf = F.udf(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
-        udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+        udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
         self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
 
+    def test_rdd_with_udt(self):
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+
+        row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+        df = self.spark.createDataFrame([row])
+        self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+
     def test_parquet_with_udt(self):
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df0 = self.spark.createDataFrame([row])
@@ -719,8 +723,6 @@ class TypesTestsMixin:
         )
 
     def test_cast_to_string_with_udt(self):
-        from pyspark.sql.functions import col
-
         row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
         schema = StructType(
             [
@@ -730,18 +732,16 @@ class TypesTestsMixin:
         )
         df = self.spark.createDataFrame([row], schema)
 
-        result = df.select(col("point").cast("string"), col("pypoint").cast("string")).head()
+        result = df.select(F.col("point").cast("string"), F.col("pypoint").cast("string")).head()
         self.assertEqual(result, Row(point="(1.0, 2.0)", pypoint="[3.0, 4.0]"))
 
     def test_cast_to_udt_with_udt(self):
-        from pyspark.sql.functions import col
-
         row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
         df = self.spark.createDataFrame([row])
         with self.assertRaises(AnalysisException):
-            df.select(col("point").cast(PythonOnlyUDT())).collect()
+            df.select(F.col("point").cast(PythonOnlyUDT())).collect()
         with self.assertRaises(AnalysisException):
-            df.select(col("python_only_point").cast(ExamplePointUDT())).collect()
+            df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()
 
     def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)

