You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/28 04:01:34 UTC
[spark] branch master updated: [SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a9f20c12f81 [SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
a9f20c12f81 is described below
commit a9f20c12f81e8832123ea8ee87213e12846a69f9
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Feb 28 13:01:18 2023 +0900
[SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
### What changes were proposed in this pull request?
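Enables more parity tests related to `functions`: `test_inverse_trig_functions`, `test_lit_list`, and `test_sorting_functions_with_column` are no longer skipped for Spark Connect. To support this, Spark Connect's `lit` now raises `PySparkValueError` (`COLUMN_IN_LIST`) when given a list containing a `Column`, and `SQLTestUtils.assert_close` now asserts instead of returning a boolean.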
### Why are the changes needed?
Some `functions` tests were still skipped under Spark Connect; they should be enabled to verify parity with regular PySpark.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By modifying and enabling the related tests in `test_functions.py` and `test_parity_functions.py`.
Closes #40203 from ueshin/issues/SPARK-42612/tests.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/connect/functions.py | 4 +++
.../pyspark/sql/tests/connect/test_connect_plan.py | 3 --
.../sql/tests/connect/test_parity_functions.py | 16 ++-------
python/pyspark/sql/tests/test_functions.py | 42 ++++++++++++++--------
python/pyspark/testing/sqlutils.py | 2 +-
5 files changed, 36 insertions(+), 31 deletions(-)
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 87dfe90107d..268774e3211 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -224,6 +224,10 @@ def lit(col: Any) -> Column:
if isinstance(col, Column):
return col
elif isinstance(col, list):
+ if any(isinstance(c, Column) for c in col):
+ raise PySparkValueError(
+ error_class="COLUMN_IN_LIST", message_parameters={"func_name": "lit"}
+ )
return array(*[lit(c) for c in col])
elif isinstance(col, np.ndarray) and col.ndim == 1:
if _from_numpy_type(col.dtype) is None:
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 2de51189c4d..8c09b9cfaa5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -838,9 +838,6 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
p = multi_type_lit.to_plan(None)
self.assertIsNotNone(p)
- lit_list_plan = lit([lit(10), lit("str")]).to_plan(None)
- self.assertIsNotNone(lit_list_plan)
-
def test_column_alias(self) -> None:
# SPARK-40809: Support for Column Aliases
col0 = col("a").alias("martin")
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index a69e47effe4..747f9a1b287 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -38,23 +38,13 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
def test_input_file_name_reset_for_rdd(self):
super().test_input_file_name_reset_for_rdd()
- # TODO(SPARK-41901): Parity in String representation of Column
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_inverse_trig_functions(self):
- super().test_inverse_trig_functions()
-
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_lit_list(self):
- super().test_lit_list()
-
def test_raise_error(self):
self.check_raise_error(SparkConnectException)
- # Comparing column type of connect and pyspark
- @unittest.skip("Fails in Spark Connect, should enable.")
def test_sorting_functions_with_column(self):
- super().test_sorting_functions_with_column()
+ from pyspark.sql.connect.column import Column
+
+ self.check_sorting_functions_with_column(Column)
if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 3aec7cc42de..44f1b9a4d13 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -318,19 +318,29 @@ class FunctionsTestsMixin:
)
def test_inverse_trig_functions(self):
- from pyspark.sql import functions
+ df = self.spark.createDataFrame([Row(a=i * 0.2, b=i * -0.2) for i in range(10)])
- funs = [
- (functions.acosh, "ACOSH"),
- (functions.asinh, "ASINH"),
- (functions.atanh, "ATANH"),
- ]
+ def check(trig, inv, y_axis_symmetrical):
+ SQLTestUtils.assert_close(
+ [n * 0.2 for n in range(10)],
+ df.select(inv(trig(df.a))).collect(),
+ )
+ if y_axis_symmetrical:
+ SQLTestUtils.assert_close(
+ [n * 0.2 for n in range(10)],
+ df.select(inv(trig(df.b))).collect(),
+ )
+ else:
+ SQLTestUtils.assert_close(
+ [n * -0.2 for n in range(10)],
+ df.select(inv(trig(df.b))).collect(),
+ )
- cols = ["a", functions.col("a")]
+ from pyspark.sql import functions
- for f, alias in funs:
- for c in cols:
- self.assertIn(f"{alias}(a)", repr(f(c)))
+ check(functions.cosh, functions.acosh, y_axis_symmetrical=True)
+ check(functions.sinh, functions.asinh, y_axis_symmetrical=False)
+ check(functions.tanh, functions.atanh, y_axis_symmetrical=False)
def test_reciprocal_trig_functions(self):
# SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
@@ -578,9 +588,13 @@ class FunctionsTestsMixin:
self.assertRaises(TypeError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
def test_sorting_functions_with_column(self):
- from pyspark.sql import functions
from pyspark.sql.column import Column
+ self.check_sorting_functions_with_column(Column)
+
+ def check_sorting_functions_with_column(self, tpe):
+ from pyspark.sql import functions
+
funs = [
functions.asc_nulls_first,
functions.asc_nulls_last,
@@ -592,17 +606,17 @@ class FunctionsTestsMixin:
for fun in funs:
for _expr in exprs:
res = fun(_expr)
- self.assertIsInstance(res, Column)
+ self.assertIsInstance(res, tpe)
self.assertIn(f"""'x {fun.__name__.replace("_", " ").upper()}'""", str(res))
for _expr in exprs:
res = functions.asc(_expr)
- self.assertIsInstance(res, Column)
+ self.assertIsInstance(res, tpe)
self.assertIn("""'x ASC NULLS FIRST'""", str(res))
for _expr in exprs:
res = functions.desc(_expr)
- self.assertIsInstance(res, Column)
+ self.assertIsInstance(res, tpe)
self.assertIn("""'x DESC NULLS LAST'""", str(res))
def test_sort_with_nulls_order(self):
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 46585cfdab0..937ad491479 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -251,7 +251,7 @@ class SQLTestUtils:
def assert_close(a, b):
c = [j[0] for j in b]
diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k] for k, v in enumerate(a)]
- return sum(diff) == len(a)
+ assert sum(diff) == len(a), f"sum: {sum(diff)}, len: {len(a)}"
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils):
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org