You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/28 04:01:34 UTC

[spark] branch master updated: [SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a9f20c12f81 [SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
a9f20c12f81 is described below

commit a9f20c12f81e8832123ea8ee87213e12846a69f9
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Feb 28 13:01:18 2023 +0900

    [SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
    
    ### What changes were proposed in this pull request?
    
    Enables more parity tests related to `functions`.
    
    ### Why are the changes needed?
    
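    Several `functions` tests are still skipped in the Spark Connect parity suite; they should be enabled so that Connect is tested against the same expectations as regular PySpark.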
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Modified and enabled the related tests.
    
    Closes #40203 from ueshin/issues/SPARK-42612/tests.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/connect/functions.py            |  4 +++
 .../pyspark/sql/tests/connect/test_connect_plan.py |  3 --
 .../sql/tests/connect/test_parity_functions.py     | 16 ++-------
 python/pyspark/sql/tests/test_functions.py         | 42 ++++++++++++++--------
 python/pyspark/testing/sqlutils.py                 |  2 +-
 5 files changed, 36 insertions(+), 31 deletions(-)

diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 87dfe90107d..268774e3211 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -224,6 +224,10 @@ def lit(col: Any) -> Column:
     if isinstance(col, Column):
         return col
     elif isinstance(col, list):
+        if any(isinstance(c, Column) for c in col):
+            raise PySparkValueError(
+                error_class="COLUMN_IN_LIST", message_parameters={"func_name": "lit"}
+            )
         return array(*[lit(c) for c in col])
     elif isinstance(col, np.ndarray) and col.ndim == 1:
         if _from_numpy_type(col.dtype) is None:
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 2de51189c4d..8c09b9cfaa5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -838,9 +838,6 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
         p = multi_type_lit.to_plan(None)
         self.assertIsNotNone(p)
 
-        lit_list_plan = lit([lit(10), lit("str")]).to_plan(None)
-        self.assertIsNotNone(lit_list_plan)
-
     def test_column_alias(self) -> None:
         # SPARK-40809: Support for Column Aliases
         col0 = col("a").alias("martin")
diff --git a/python/pyspark/sql/tests/connect/test_parity_functions.py b/python/pyspark/sql/tests/connect/test_parity_functions.py
index a69e47effe4..747f9a1b287 100644
--- a/python/pyspark/sql/tests/connect/test_parity_functions.py
+++ b/python/pyspark/sql/tests/connect/test_parity_functions.py
@@ -38,23 +38,13 @@ class FunctionsParityTests(FunctionsTestsMixin, ReusedConnectTestCase):
     def test_input_file_name_reset_for_rdd(self):
         super().test_input_file_name_reset_for_rdd()
 
-    # TODO(SPARK-41901): Parity in String representation of Column
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_inverse_trig_functions(self):
-        super().test_inverse_trig_functions()
-
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_lit_list(self):
-        super().test_lit_list()
-
     def test_raise_error(self):
         self.check_raise_error(SparkConnectException)
 
-    # Comparing column type of connect and pyspark
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_sorting_functions_with_column(self):
-        super().test_sorting_functions_with_column()
+        from pyspark.sql.connect.column import Column
+
+        self.check_sorting_functions_with_column(Column)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 3aec7cc42de..44f1b9a4d13 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -318,19 +318,29 @@ class FunctionsTestsMixin:
         )
 
     def test_inverse_trig_functions(self):
-        from pyspark.sql import functions
+        df = self.spark.createDataFrame([Row(a=i * 0.2, b=i * -0.2) for i in range(10)])
 
-        funs = [
-            (functions.acosh, "ACOSH"),
-            (functions.asinh, "ASINH"),
-            (functions.atanh, "ATANH"),
-        ]
+        def check(trig, inv, y_axis_symmetrical):
+            SQLTestUtils.assert_close(
+                [n * 0.2 for n in range(10)],
+                df.select(inv(trig(df.a))).collect(),
+            )
+            if y_axis_symmetrical:
+                SQLTestUtils.assert_close(
+                    [n * 0.2 for n in range(10)],
+                    df.select(inv(trig(df.b))).collect(),
+                )
+            else:
+                SQLTestUtils.assert_close(
+                    [n * -0.2 for n in range(10)],
+                    df.select(inv(trig(df.b))).collect(),
+                )
 
-        cols = ["a", functions.col("a")]
+        from pyspark.sql import functions
 
-        for f, alias in funs:
-            for c in cols:
-                self.assertIn(f"{alias}(a)", repr(f(c)))
+        check(functions.cosh, functions.acosh, y_axis_symmetrical=True)
+        check(functions.sinh, functions.asinh, y_axis_symmetrical=False)
+        check(functions.tanh, functions.atanh, y_axis_symmetrical=False)
 
     def test_reciprocal_trig_functions(self):
         # SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
@@ -578,9 +588,13 @@ class FunctionsTestsMixin:
         self.assertRaises(TypeError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
 
     def test_sorting_functions_with_column(self):
-        from pyspark.sql import functions
         from pyspark.sql.column import Column
 
+        self.check_sorting_functions_with_column(Column)
+
+    def check_sorting_functions_with_column(self, tpe):
+        from pyspark.sql import functions
+
         funs = [
             functions.asc_nulls_first,
             functions.asc_nulls_last,
@@ -592,17 +606,17 @@ class FunctionsTestsMixin:
         for fun in funs:
             for _expr in exprs:
                 res = fun(_expr)
-                self.assertIsInstance(res, Column)
+                self.assertIsInstance(res, tpe)
                 self.assertIn(f"""'x {fun.__name__.replace("_", " ").upper()}'""", str(res))
 
         for _expr in exprs:
             res = functions.asc(_expr)
-            self.assertIsInstance(res, Column)
+            self.assertIsInstance(res, tpe)
             self.assertIn("""'x ASC NULLS FIRST'""", str(res))
 
         for _expr in exprs:
             res = functions.desc(_expr)
-            self.assertIsInstance(res, Column)
+            self.assertIsInstance(res, tpe)
             self.assertIn("""'x DESC NULLS LAST'""", str(res))
 
     def test_sort_with_nulls_order(self):
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index 46585cfdab0..937ad491479 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -251,7 +251,7 @@ class SQLTestUtils:
     def assert_close(a, b):
         c = [j[0] for j in b]
         diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k] for k, v in enumerate(a)]
-        return sum(diff) == len(a)
+        assert sum(diff) == len(a), f"sum: {sum(diff)}, len: {len(a)}"
 
 
 class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org