You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/27 00:35:58 UTC
[spark] branch branch-3.4 updated: [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 31ede7330a3 [SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT
31ede7330a3 is described below
commit 31ede7330a314b18faa591a9313ed31c5c8b63c1
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Mon Mar 27 09:35:33 2023 +0900
[SPARK-42920][CONNECT][PYTHON] Enable tests for UDF with UDT
### What changes were proposed in this pull request?
Enables tests for UDF with UDT.
### Why are the changes needed?
Now that UDFs with UDTs are expected to work, the related tests should be enabled to verify that they actually do.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Enabled/modified the related tests.
Closes #40549 from ueshin/issues/SPARK-42920/udf_with_udt.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
(cherry picked from commit 80f8664e8278335788d8fa1dd00654f3eaec8ed6)
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../pyspark/sql/tests/connect/test_parity_types.py | 4 +--
python/pyspark/sql/tests/test_types.py | 38 +++++++++++-----------
2 files changed, 21 insertions(+), 21 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index a2f81fbf25e..aacf5793b2b 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -84,8 +84,8 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
super().test_infer_schema_upcast_int_to_string()
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
- def test_udf_with_udt(self):
- super().test_udf_with_udt()
+ def test_rdd_with_udt(self):
+ super().test_rdd_with_udt()
@unittest.skip("Requires JVM access.")
def test_udt(self):
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index bee899e928e..5d6476b47f4 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -25,8 +25,7 @@ import sys
import unittest
from pyspark.sql import Row
-from pyspark.sql.functions import col
-from pyspark.sql.udf import UserDefinedFunction
+from pyspark.sql import functions as F
from pyspark.errors import AnalysisException
from pyspark.sql.types import (
ByteType,
@@ -381,7 +380,7 @@ class TypesTestsMixin:
try:
self.spark.sql("set spark.sql.legacy.allowNegativeScaleOfDecimal=true")
df = self.spark.createDataFrame([(1,), (11,)], ["value"])
- ret = df.select(col("value").cast(DecimalType(1, -1))).collect()
+ ret = df.select(F.col("value").cast(DecimalType(1, -1))).collect()
actual = list(map(lambda r: int(r.value), ret))
self.assertEqual(actual, [0, 10])
finally:
@@ -548,8 +547,6 @@ class TypesTestsMixin:
df.collect()
def test_complex_nested_udt_in_df(self):
- from pyspark.sql.functions import udf
-
schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
df = self.spark.createDataFrame(
[(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema
@@ -558,7 +555,7 @@ class TypesTestsMixin:
gd = df.groupby("key").agg({"val": "collect_list"})
gd.collect()
- udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+ udf = F.udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
gd.select(udf(*gd)).collect()
def test_udt_with_none(self):
@@ -667,20 +664,27 @@ class TypesTestsMixin:
def test_udf_with_udt(self):
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
- self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
- udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ udf = F.udf(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
- udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+ udf2 = F.udf(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
- self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
- udf = UserDefinedFunction(lambda p: p.y, DoubleType())
+ udf = F.udf(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
- udf2 = UserDefinedFunction(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
+ udf2 = F.udf(lambda p: PythonOnlyPoint(p.x + 1, p.y + 1), PythonOnlyUDT())
self.assertEqual(PythonOnlyPoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
+ def test_rdd_with_udt(self):
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+
+ row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0))
+ df = self.spark.createDataFrame([row])
+ self.assertEqual(1.0, df.rdd.map(lambda r: r.point.x).first())
+
def test_parquet_with_udt(self):
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df0 = self.spark.createDataFrame([row])
@@ -719,8 +723,6 @@ class TypesTestsMixin:
)
def test_cast_to_string_with_udt(self):
- from pyspark.sql.functions import col
-
row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
schema = StructType(
[
@@ -730,18 +732,16 @@ class TypesTestsMixin:
)
df = self.spark.createDataFrame([row], schema)
- result = df.select(col("point").cast("string"), col("pypoint").cast("string")).head()
+ result = df.select(F.col("point").cast("string"), F.col("pypoint").cast("string")).head()
self.assertEqual(result, Row(point="(1.0, 2.0)", pypoint="[3.0, 4.0]"))
def test_cast_to_udt_with_udt(self):
- from pyspark.sql.functions import col
-
row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0))
df = self.spark.createDataFrame([row])
with self.assertRaises(AnalysisException):
- df.select(col("point").cast(PythonOnlyUDT())).collect()
+ df.select(F.col("point").cast(PythonOnlyUDT())).collect()
with self.assertRaises(AnalysisException):
- df.select(col("python_only_point").cast(ExamplePointUDT())).collect()
+ df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect()
def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org