Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/30 00:34:42 UTC

[spark] branch master updated: [SPARK-42970][CONNECT][PYTHON][TESTS] Reuse pyspark.sql.tests.test_arrow test cases

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 12e7991b5b3 [SPARK-42970][CONNECT][PYTHON][TESTS] Reuse pyspark.sql.tests.test_arrow test cases
12e7991b5b3 is described below

commit 12e7991b5b38302e6496307c7263ad729c82a6cf
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Thu Mar 30 09:34:09 2023 +0900

    [SPARK-42970][CONNECT][PYTHON][TESTS] Reuse pyspark.sql.tests.test_arrow test cases
    
    ### What changes were proposed in this pull request?
    
    Reuses `pyspark.sql.tests.test_arrow` test cases.
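    
    For reference, a condensed sketch of the reuse pattern this change applies (the mixin and class names match the diff below; the test body and base classes are simplified for illustration):
    
    ```python
    import unittest
    
    class ArrowTestsMixin:
        # The shared test bodies live in the mixin, so any unittest base
        # (classic PySpark or Spark Connect) can run them.
        def test_roundtrip(self):
            ...
    
    # Classic PySpark run of the shared tests.
    class ArrowTests(ArrowTestsMixin, unittest.TestCase):
        pass
    
    # Spark Connect parity run; unsupported cases are skipped one by one.
    class ArrowParityTests(ArrowTestsMixin, unittest.TestCase):
        @unittest.skip("Not supported by Spark Connect.")
        def test_roundtrip(self):
            super().test_roundtrip()
    ```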
    
    ### Why are the changes needed?
    
    `test_arrow` is also worth reusing because it contains many tests for `createDataFrame` with pandas and for `toPandas`.
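    
    For illustration, the kind of Arrow-backed round trip those tests exercise (a minimal sketch, assuming an active `SparkSession` bound to `spark`; the sample data is made up):
    
    ```python
    import pandas as pd
    
    pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})
    
    # createDataFrame accepts a pandas DataFrame and uses Arrow when
    # spark.sql.execution.arrow.pyspark.enabled is set to true.
    df = spark.createDataFrame(pdf)
    
    # toPandas converts back to pandas, again via Arrow when enabled.
    result_pdf = df.toPandas()
    ```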
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added the tests.
    
    Closes #40594 from ueshin/issues/SPARK-42970/test_arrow.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 dev/sparktestsupport/modules.py                    |   1 +
 .../pyspark/sql/tests/connect/test_parity_arrow.py | 110 +++++++++++++++++++++
 python/pyspark/sql/tests/test_arrow.py             |  65 +++++++-----
 3 files changed, 149 insertions(+), 27 deletions(-)

diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index f65ef7e3ac0..1a28a644e55 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -755,6 +755,7 @@ pyspark_connect = Module(
         "pyspark.sql.tests.connect.test_connect_basic",
         "pyspark.sql.tests.connect.test_connect_function",
         "pyspark.sql.tests.connect.test_connect_column",
+        "pyspark.sql.tests.connect.test_parity_arrow",
         "pyspark.sql.tests.connect.test_parity_datasources",
         "pyspark.sql.tests.connect.test_parity_errors",
         "pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
new file mode 100644
index 00000000000..f8180d661db
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_arrow import ArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
+    @unittest.skip("Spark Connect does not support Spark Context but the test depends on that.")
+    def test_createDataFrame_empty_partition(self):
+        super().test_createDataFrame_empty_partition()
+
+    @unittest.skip("Spark Connect does not support fallback.")
+    def test_createDataFrame_fallback_disabled(self):
+        super().test_createDataFrame_fallback_disabled()
+
+    @unittest.skip("Spark Connect does not support fallback.")
+    def test_createDataFrame_fallback_enabled(self):
+        super().test_createDataFrame_fallback_enabled()
+
+    def test_createDataFrame_with_incorrect_schema(self):
+        self.check_createDataFrame_with_incorrect_schema()
+
+    # TODO(SPARK-42969): Fix the comparison of the result with Arrow optimization enabled/disabled.
+    @unittest.skip("Fails in Spark Connect, should be enabled.")
+    def test_createDataFrame_with_map_type(self):
+        super().test_createDataFrame_with_map_type()
+
+    # TODO(SPARK-42969): Fix the comparison of the result with Arrow optimization enabled/disabled.
+    @unittest.skip("Fails in Spark Connect, should be enabled.")
+    def test_createDataFrame_with_ndarray(self):
+        super().test_createDataFrame_with_ndarray()
+
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_createDataFrame_with_single_data_type(self):
+        super().test_createDataFrame_with_single_data_type()
+
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_no_partition_frame(self):
+        super().test_no_partition_frame()
+
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_no_partition_toPandas(self):
+        super().test_no_partition_toPandas()
+
+    @unittest.skip("The test uses internal APIs.")
+    def test_pandas_self_destruct(self):
+        super().test_pandas_self_destruct()
+
+    def test_propagates_spark_exception(self):
+        self.check_propagates_spark_exception()
+
+    @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+    def test_toPandas_batch_order(self):
+        super().test_toPandas_batch_order()
+
+    @unittest.skip("Spark Connect does not support Spark Context but the test depends on that.")
+    def test_toPandas_empty_df_arrow_enabled(self):
+        super().test_toPandas_empty_df_arrow_enabled()
+
+    @unittest.skip("Spark Connect does not support fallback.")
+    def test_toPandas_fallback_disabled(self):
+        super().test_toPandas_fallback_disabled()
+
+    @unittest.skip("Spark Connect does not support fallback.")
+    def test_toPandas_fallback_enabled(self):
+        super().test_toPandas_fallback_enabled()
+
+    # TODO(SPARK-42969): Fix the comparison of the result with Arrow optimization enabled/disabled.
+    @unittest.skip("Fails in Spark Connect, should be enabled.")
+    def test_toPandas_with_map_type(self):
+        super().test_toPandas_with_map_type()
+
+    # TODO(SPARK-42969): Fix the comparison of the result with Arrow optimization enabled/disabled.
+    @unittest.skip("Fails in Spark Connect, should be enabled.")
+    def test_toPandas_with_map_type_nulls(self):
+        super().test_toPandas_with_map_type_nulls()
+
+    # TODO(SPARK-42969): Fix the comparison of the result with Arrow optimization enabled/disabled.
+    @unittest.skip("Fails in Spark Connect, should be enabled.")
+    def test_createDataFrame_respect_session_timezone(self):
+        super().test_createDataFrame_respect_session_timezone()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_arrow import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index c61994380e6..0e162f686b9 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -63,17 +63,13 @@ if have_pyarrow:
     import pyarrow as pa  # noqa: F401
 
 
-@unittest.skipIf(
-    not have_pandas or not have_pyarrow,
-    cast(str, pandas_requirement_message or pyarrow_requirement_message),
-)
-class ArrowTests(ReusedSQLTestCase):
+class ArrowTestsMixin:
     @classmethod
     def setUpClass(cls):
         from datetime import date, datetime
         from decimal import Decimal
 
-        super(ArrowTests, cls).setUpClass()
+        super().setUpClass()
         cls.warnings_lock = threading.Lock()
 
         # Synchronize default timezone between Python and Java
@@ -168,7 +164,7 @@ class ArrowTests(ReusedSQLTestCase):
         if cls.tz_prev is not None:
             os.environ["TZ"] = cls.tz_prev
         time.tzset()
-        super(ArrowTests, cls).tearDownClass()
+        super().tearDownClass()
 
     def create_pandas_data_frame(self):
         import numpy as np
@@ -395,6 +391,10 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(pdf.empty)
 
     def test_propagates_spark_exception(self):
+        with QuietTest(self.sc):
+            self.check_propagates_spark_exception()
+
+    def check_propagates_spark_exception(self):
         df = self.spark.range(3).toDF("i")
 
         def raise_exception():
@@ -402,9 +402,9 @@ class ArrowTests(ReusedSQLTestCase):
 
         exception_udf = udf(raise_exception, IntegerType())
         df = df.withColumn("error", exception_udf())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegex(Exception, "My error"):
-                df.toPandas()
+
+        with self.assertRaisesRegex(Exception, "My error"):
+            df.toPandas()
 
     def _createDataFrame_toggle(self, data, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
@@ -459,29 +459,32 @@ class ArrowTests(ReusedSQLTestCase):
         assert_frame_equal(pdf_arrow, pdf)
 
     def test_createDataFrame_with_incorrect_schema(self):
+        with QuietTest(self.sc):
+            self.check_createDataFrame_with_incorrect_schema()
+
+    def check_createDataFrame_with_incorrect_schema(self):
         pdf = self.create_pandas_data_frame()
         fields = list(self.schema)
         fields[5], fields[6] = fields[6], fields[5]  # swap decimal with date
         wrong_schema = StructType(fields)
         with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
-            with QuietTest(self.sc):
-                with self.assertRaises(Exception) as context:
-                    self.spark.createDataFrame(pdf, schema=wrong_schema)
-
-                # the exception provides us with the column that is incorrect
-                exception = context.exception
-                self.assertTrue(hasattr(exception, "args"))
-                self.assertEqual(len(exception.args), 1)
-                self.assertRegex(
-                    exception.args[0],
-                    "with name '7_date_t' " "to Arrow Array \\(decimal128\\(38, 18\\)\\)",
-                )
+            with self.assertRaises(Exception) as context:
+                self.spark.createDataFrame(pdf, schema=wrong_schema)
+
+            # the exception provides us with the column that is incorrect
+            exception = context.exception
+            self.assertTrue(hasattr(exception, "args"))
+            self.assertEqual(len(exception.args), 1)
+            self.assertRegex(
+                exception.args[0],
+                "with name '7_date_t' " "to Arrow Array \\(decimal128\\(38, 18\\)\\)",
+            )
 
-                # the inner exception provides us with the incorrect types
-                exception = exception.__context__
-                self.assertTrue(hasattr(exception, "args"))
-                self.assertEqual(len(exception.args), 1)
-                self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date")
+            # the inner exception provides us with the incorrect types
+            exception = exception.__context__
+            self.assertTrue(hasattr(exception, "args"))
+            self.assertEqual(len(exception.args), 1)
+            self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date")
 
     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
@@ -786,6 +789,14 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertGreater(self.spark.sparkContext.defaultParallelism, len(pdf))
 
 
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class ArrowTests(ArrowTestsMixin, ReusedSQLTestCase):
+    pass
+
+
 @unittest.skipIf(
     not have_pandas or not have_pyarrow,
     cast(str, pandas_requirement_message or pyarrow_requirement_message),

