Posted to commits@spark.apache.org by gu...@apache.org on 2023/03/30 23:47:25 UTC

[spark] branch branch-3.4 updated: [SPARK-42969][CONNECT][TESTS] Fix the comparison of results with Arrow optimization enabled/disabled

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 68fa8cafc3c [SPARK-42969][CONNECT][TESTS] Fix the comparison of results with Arrow optimization enabled/disabled
68fa8cafc3c is described below

commit 68fa8cafc3cfcdc043920ca8544c24ac88f0a63c
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Fri Mar 31 08:45:44 2023 +0900

    [SPARK-42969][CONNECT][TESTS] Fix the comparison of results with Arrow optimization enabled/disabled
    
    Fixes the comparison of results with Arrow optimization enabled/disabled.
    
    In `test_arrow`, there are a number of comparisons between DataFrames built with Arrow optimization enabled and disabled.
    
    These should be changed to compare against the expected values instead, so that the tests can be reused for the Spark Connect parity tests.
    
    No user-facing change.
    
    Tested by updating the existing tests.
    
    Closes #40612 from ueshin/issues/SPARK-42969/test_arrow.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
    (cherry picked from commit 35503a535771d257b517e7ddf2adfaefefd97dad)
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../pyspark/sql/tests/connect/test_parity_arrow.py |  47 +++--
 python/pyspark/sql/tests/test_arrow.py             | 202 +++++++++++++--------
 2 files changed, 163 insertions(+), 86 deletions(-)
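
Below is a minimal, standalone sketch (not part of the commit) of the pattern the diff applies: each test now toggles spark.sql.execution.arrow.pyspark.enabled itself and checks the result against a known expected value, instead of comparing the Arrow-enabled output against the Arrow-disabled output. The local session setup here is an assumption for illustration; the example data mirrors check_create_data_frame_to_pandas_day_time_internal in the updated test_arrow.py, where the real tests obtain their session from the shared test base classes.

import datetime

import pandas as pd
from pandas.testing import assert_frame_equal
from pyspark.sql import SparkSession

# Illustrative local session; the real tests reuse a session from their
# test base class rather than creating one.
spark = SparkSession.builder.master("local[1]").getOrCreate()

# Known expected value (a DayTimeInterval column, as in the updated test).
origin = pd.DataFrame({"a": [datetime.timedelta(microseconds=123)]})
df = spark.createDataFrame(origin)

for arrow_enabled in [True, False]:
    # Toggle the Arrow optimization explicitly for each run ...
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", str(arrow_enabled).lower())
    # ... and compare against the expected input, not against the other mode's output.
    assert_frame_equal(origin, df.toPandas())

spark.stop()

Comparing against a fixed expected value is what lets the check_* bodies be shared with Spark Connect, where only one execution mode is exercised per test.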

diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index f8180d661db..8953b2f8d98 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -37,19 +37,20 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_createDataFrame_with_incorrect_schema(self):
         self.check_createDataFrame_with_incorrect_schema()
 
-    # TODO(SPARK-42969): Fix the comparison the result with Arrow optimization enabled/disabled.
+    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_with_map_type(self):
-        super().test_createDataFrame_with_map_type()
+        self.check_createDataFrame_with_map_type(True)
 
-    # TODO(SPARK-42969): Fix the comparison the result with Arrow optimization enabled/disabled.
+    # TODO(SPARK-42983): len() of unsized object
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_with_ndarray(self):
-        super().test_createDataFrame_with_ndarray()
+        self.check_createDataFrame_with_ndarray(True)
 
+    # TODO(SPARK-42984): ValueError not raised
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_with_single_data_type(self):
-        super().test_createDataFrame_with_single_data_type()
+        self.check_createDataFrame_with_single_data_type()
 
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_no_partition_frame(self):
@@ -70,9 +71,20 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_batch_order(self):
         super().test_toPandas_batch_order()
 
-    @unittest.skip("Spark Connect does not support Spark Context but the test depends on that.")
     def test_toPandas_empty_df_arrow_enabled(self):
-        super().test_toPandas_empty_df_arrow_enabled()
+        self.check_toPandas_empty_df_arrow_enabled(True)
+
+    def test_create_data_frame_to_pandas_timestamp_ntz(self):
+        self.check_create_data_frame_to_pandas_timestamp_ntz(True)
+
+    def test_create_data_frame_to_pandas_day_time_internal(self):
+        self.check_create_data_frame_to_pandas_day_time_internal(True)
+
+    def test_toPandas_respect_session_timezone(self):
+        self.check_toPandas_respect_session_timezone(True)
+
+    def test_toPandas_with_array_type(self):
+        self.check_toPandas_with_array_type(True)
 
     @unittest.skip("Spark Connect does not support fallback.")
     def test_toPandas_fallback_disabled(self):
@@ -82,20 +94,29 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_toPandas_fallback_enabled(self):
         super().test_toPandas_fallback_enabled()
 
-    # TODO(SPARK-42969): Fix the comparison the result with Arrow optimization enabled/disabled.
+    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type(self):
-        super().test_toPandas_with_map_type()
+        self.check_toPandas_with_map_type(True)
 
-    # TODO(SPARK-42969): Fix the comparison the result with Arrow optimization enabled/disabled.
+    # TODO(SPARK-42982): INVALID_COLUMN_OR_FIELD_DATA_TYPE
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_toPandas_with_map_type_nulls(self):
-        super().test_toPandas_with_map_type_nulls()
+        self.check_toPandas_with_map_type_nulls(True)
 
-    # TODO(SPARK-42969): Fix the comparison the result with Arrow optimization enabled/disabled.
+    # TODO(SPARK-42985): Respect session timezone
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_createDataFrame_respect_session_timezone(self):
-        super().test_createDataFrame_respect_session_timezone()
+        self.check_createDataFrame_respect_session_timezone(True)
+
+    def test_createDataFrame_with_array_type(self):
+        self.check_createDataFrame_with_array_type(True)
+
+    def test_createDataFrame_with_int_col_names(self):
+        self.check_createDataFrame_with_int_col_names(True)
+
+    def test_timestamp_nat(self):
+        self.check_timestamp_nat(True)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 751b19d3882..95100ac359c 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -218,6 +218,11 @@ class ArrowTestsMixin:
                     df.toPandas()
 
     def test_toPandas_empty_df_arrow_enabled(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_empty_df_arrow_enabled(arrow_enabled)
+
+    def check_toPandas_empty_df_arrow_enabled(self, arrow_enabled):
         # SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes
         # when arrow is enabled
         from datetime import date
@@ -238,7 +243,7 @@ class ArrowTestsMixin:
                 StructField("L", DayTimeIntervalType(0, 3), True),
             ]
         )
-        df = self.spark.createDataFrame(self.spark.sparkContext.emptyRDD(), schema=schema)
+        df = self.spark.createDataFrame([], schema=schema)
         non_empty_df = self.spark.createDataFrame(
             [
                 (
@@ -258,11 +263,10 @@ class ArrowTestsMixin:
             schema=schema,
         )
 
-        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
-        pdf_non_empty, pdf_arrow_non_empty = self._toPandas_arrow_toggle(non_empty_df)
-        assert_frame_equal(pdf, pdf_arrow)
-        self.assertTrue(pdf_arrow.dtypes.equals(pdf_arrow_non_empty.dtypes))
-        self.assertTrue(pdf_arrow.dtypes.equals(pdf_non_empty.dtypes))
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            pdf = df.toPandas()
+            pdf_non_empty = non_empty_df.toPandas()
+        self.assertTrue(pdf.dtypes.equals(pdf_non_empty.dtypes))
 
     def test_null_conversion(self):
         df_null = self.spark.createDataFrame(
@@ -288,6 +292,11 @@ class ArrowTestsMixin:
         assert_frame_equal(expected, pdf_arrow)
 
     def test_create_data_frame_to_pandas_timestamp_ntz(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_create_data_frame_to_pandas_timestamp_ntz(arrow_enabled)
+
+    def check_create_data_frame_to_pandas_timestamp_ntz(self, arrow_enabled):
         # SPARK-36626: Test TimestampNTZ in createDataFrame and toPandas
         with self.sql_conf({"spark.sql.session.timeZone": "America/Los_Angeles"}):
             origin = pd.DataFrame({"a": [datetime.datetime(2012, 2, 2, 2, 2, 2)]})
@@ -296,11 +305,16 @@ class ArrowTestsMixin:
             )
             df.selectExpr("assert_true('2012-02-02 02:02:02' == CAST(a AS STRING))").collect()
 
-            pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                pdf = df.toPandas()
             assert_frame_equal(origin, pdf)
-            assert_frame_equal(pdf, pdf_arrow)
 
     def test_create_data_frame_to_pandas_day_time_internal(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_create_data_frame_to_pandas_day_time_internal(arrow_enabled)
+
+    def check_create_data_frame_to_pandas_day_time_internal(self, arrow_enabled):
         # SPARK-37279: Test DayTimeInterval in createDataFrame and toPandas
         origin = pd.DataFrame({"a": [datetime.timedelta(microseconds=123)]})
         df = self.spark.createDataFrame(origin)
@@ -308,22 +322,27 @@ class ArrowTestsMixin:
             assert_true(lit("INTERVAL '0 00:00:00.000123' DAY TO SECOND") == df.a.cast("string"))
         ).collect()
 
-        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            pdf = df.toPandas()
         assert_frame_equal(origin, pdf)
-        assert_frame_equal(pdf, pdf_arrow)
 
     def test_toPandas_respect_session_timezone(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_respect_session_timezone(arrow_enabled)
+
+    def check_toPandas_respect_session_timezone(self, arrow_enabled):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
 
         timezone = "America/Los_Angeles"
         with self.sql_conf({"spark.sql.session.timeZone": timezone}):
-            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-            assert_frame_equal(pdf_arrow_la, pdf_la)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                pdf_la = df.toPandas()
 
         timezone = "America/New_York"
         with self.sql_conf({"spark.sql.session.timeZone": timezone}):
-            pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
-            assert_frame_equal(pdf_arrow_ny, pdf_ny)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                pdf_ny = df.toPandas()
 
             self.assertFalse(pdf_ny.equals(pdf_la))
 
@@ -420,22 +439,25 @@ class ArrowTestsMixin:
         self.assertEqual(df_no_arrow.collect(), df_arrow.collect())
 
     def test_createDataFrame_respect_session_timezone(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_respect_session_timezone(arrow_enabled)
+
+    def check_createDataFrame_respect_session_timezone(self, arrow_enabled):
         from datetime import timedelta
 
         pdf = self.create_pandas_data_frame()
         timezone = "America/Los_Angeles"
         with self.sql_conf({"spark.sql.session.timeZone": timezone}):
-            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-            result_la = df_no_arrow_la.collect()
-            result_arrow_la = df_arrow_la.collect()
-            self.assertEqual(result_la, result_arrow_la)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                df_la = self.spark.createDataFrame(pdf, schema=self.schema)
+            result_la = df_la.collect()
 
         timezone = "America/New_York"
         with self.sql_conf({"spark.sql.session.timeZone": timezone}):
-            df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
-            result_ny = df_no_arrow_ny.collect()
-            result_arrow_ny = df_arrow_ny.collect()
-            self.assertEqual(result_ny, result_arrow_ny)
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                df_ny = self.spark.createDataFrame(pdf, schema=self.schema)
+            result_ny = df_ny.collect()
 
             self.assertNotEqual(result_ny, result_la)
 
@@ -492,8 +514,11 @@ class ArrowTestsMixin:
 
     def test_createDataFrame_with_single_data_type(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
-                self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")
+            self.check_createDataFrame_with_single_data_type()
+
+    def check_createDataFrame_with_single_data_type(self):
+        with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
+            self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int").collect()
 
     def test_createDataFrame_does_not_modify_input(self):
         # Some series get converted for Spark to consume, this makes sure input is unchanged
@@ -514,6 +539,11 @@ class ArrowTestsMixin:
         self.assertEqual(self.schema, schema_rt)
 
     def test_createDataFrame_with_ndarray(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_with_ndarray(arrow_enabled)
+
+    def check_createDataFrame_with_ndarray(self, arrow_enabled):
         dtypes = ["tinyint", "smallint", "int", "bigint", "float", "double"]
         expected_dtypes = (
             [[("value", t)] for t in dtypes]
@@ -523,66 +553,70 @@ class ArrowTestsMixin:
         arrs = self.create_np_arrs
 
         for arr, dtypes in zip(arrs, expected_dtypes):
-            df, df_arrow = self._createDataFrame_toggle(arr)
-            self.assertEqual(df.dtypes, df_arrow.dtypes)
-            self.assertEqual(df_arrow.dtypes, dtypes)
-            self.assertEqual(df.collect(), df_arrow.collect())
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+                df = self.spark.createDataFrame(arr)
+            self.assertEqual(df.dtypes, dtypes)
+            np.array_equal(np.array(df.collect()), arr)
 
         with self.assertRaisesRegex(ValueError, "NumPy array input should be of 1 or 2 dimensions"):
             self.spark.createDataFrame(np.array(0))
 
     def test_createDataFrame_with_array_type(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_with_array_type(arrow_enabled)
+
+    def check_createDataFrame_with_array_type(self, arrow_enabled):
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [["x", "y"], ["y", "z"]]})
-        df, df_arrow = self._createDataFrame_toggle(pdf)
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            df = self.spark.createDataFrame(pdf)
         result = df.collect()
-        result_arrow = df_arrow.collect()
         expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
         for r in range(len(expected)):
             for e in range(len(expected[r])):
-                self.assertTrue(
-                    expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]
-                )
+                self.assertTrue(expected[r][e] == result[r][e])
 
     def test_toPandas_with_array_type(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_with_array_type(arrow_enabled)
+
+    def check_toPandas_with_array_type(self, arrow_enabled):
         expected = [([1, 2], ["x", "y"]), ([3, 4], ["y", "z"])]
         array_schema = StructType(
             [StructField("a", ArrayType(IntegerType())), StructField("b", ArrayType(StringType()))]
         )
         df = self.spark.createDataFrame(expected, schema=array_schema)
-        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        pdf = df.toPandas()
         result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
-        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
         for r in range(len(expected)):
             for e in range(len(expected[r])):
-                self.assertTrue(
-                    expected[r][e] == result_arrow[r][e] and result[r][e] == result_arrow[r][e]
-                )
+                self.assertTrue(expected[r][e] == result[r][e])
 
     def test_createDataFrame_with_map_type(self):
+        with QuietTest(self.sc):
+            for arrow_enabled in [True, False]:
+                with self.subTest(arrow_enabled=arrow_enabled):
+                    self.check_createDataFrame_with_map_type(arrow_enabled)
+
+    def check_createDataFrame_with_map_type(self, arrow_enabled):
         map_data = [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]
 
         pdf = pd.DataFrame({"id": [0, 1, 2, 3, 4], "m": map_data})
         schema = "id long, m map<string, long>"
 
-        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(pdf, schema=schema)
-
-        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-            with QuietTest(self.sc):
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
                 with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
-                    self.spark.createDataFrame(pdf, schema=schema)
-        else:
-            df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+                    self.spark.createDataFrame(pdf, schema=schema).collect()
+            else:
+                df = self.spark.createDataFrame(pdf, schema=schema)
 
-            result = df.collect()
-            result_arrow = df_arrow.collect()
+                result = df.collect()
 
-            self.assertEqual(len(result), len(result_arrow))
-            for row, row_arrow in zip(result, result_arrow):
-                i, m = row
-                _, m_arrow = row_arrow
-                self.assertEqual(m, map_data[i])
-                self.assertEqual(m_arrow, map_data[i])
+                for row in result:
+                    i, m = row
+                    self.assertEqual(m, map_data[i])
 
     def test_createDataFrame_with_string_dtype(self):
         # SPARK-34521: spark.createDataFrame does not support Pandas StringDtype extension type
@@ -606,45 +640,62 @@ class ArrowTestsMixin:
             assert_frame_equal(pandas_df, df.toPandas(), check_dtype=False)
 
     def test_toPandas_with_map_type(self):
-        pdf = pd.DataFrame(
+        with QuietTest(self.sc):
+            for arrow_enabled in [True, False]:
+                with self.subTest(arrow_enabled=arrow_enabled):
+                    self.check_toPandas_with_map_type(arrow_enabled)
+
+    def check_toPandas_with_map_type(self, arrow_enabled):
+        origin = pd.DataFrame(
             {"id": [0, 1, 2, 3], "m": [{}, {"a": 1}, {"a": 1, "b": 2}, {"a": 1, "b": 2, "c": 3}]}
         )
 
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
+            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
 
-        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-            with QuietTest(self.sc):
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
                 with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
                     df.toPandas()
-        else:
-            pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
-            assert_frame_equal(pdf_arrow, pdf_non)
+            else:
+                pdf = df.toPandas()
+                assert_frame_equal(origin, pdf)
 
     def test_toPandas_with_map_type_nulls(self):
-        pdf = pd.DataFrame(
+        with QuietTest(self.sc):
+            for arrow_enabled in [True, False]:
+                with self.subTest(arrow_enabled=arrow_enabled):
+                    self.check_toPandas_with_map_type_nulls(arrow_enabled)
+
+    def check_toPandas_with_map_type_nulls(self, arrow_enabled):
+        origin = pd.DataFrame(
             {"id": [0, 1, 2, 3, 4], "m": [{"a": 1}, {"b": 2, "c": 3}, {}, None, {"d": None}]}
         )
 
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
-            df = self.spark.createDataFrame(pdf, schema="id long, m map<string, long>")
+            df = self.spark.createDataFrame(origin, schema="id long, m map<string, long>")
 
-        if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
-            with QuietTest(self.sc):
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            if arrow_enabled and LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
                 with self.assertRaisesRegex(Exception, "MapType.*only.*pyarrow 2.0.0"):
                     df.toPandas()
-        else:
-            pdf_non, pdf_arrow = self._toPandas_arrow_toggle(df)
-            assert_frame_equal(pdf_arrow, pdf_non)
+            else:
+                pdf = df.toPandas()
+                assert_frame_equal(origin, pdf)
 
     def test_createDataFrame_with_int_col_names(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_createDataFrame_with_int_col_names(arrow_enabled)
+
+    def check_createDataFrame_with_int_col_names(self, arrow_enabled):
         import numpy as np
 
         pdf = pd.DataFrame(np.random.rand(4, 2))
-        df, df_arrow = self._createDataFrame_toggle(pdf)
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            df = self.spark.createDataFrame(pdf)
         pdf_col_names = [str(c) for c in pdf.columns]
         self.assertEqual(pdf_col_names, df.columns)
-        self.assertEqual(pdf_col_names, df_arrow.columns)
 
     def test_createDataFrame_fallback_enabled(self):
         ts = datetime.datetime(2015, 11, 1, 0, 30)
@@ -690,12 +741,17 @@ class ArrowTestsMixin:
 
     # Regression test for SPARK-28003
     def test_timestamp_nat(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_timestamp_nat(arrow_enabled)
+
+    def check_timestamp_nat(self, arrow_enabled):
         dt = [pd.NaT, pd.Timestamp("2019-06-11"), None] * 100
         pdf = pd.DataFrame({"time": dt})
-        df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf)
+        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
+            df = self.spark.createDataFrame(pdf)
 
-        assert_frame_equal(pdf, df_no_arrow.toPandas())
-        assert_frame_equal(pdf, df_arrow.toPandas())
+        assert_frame_equal(pdf, df.toPandas())
 
     def test_toPandas_batch_order(self):
         def delay_first_part(partition_index, iterator):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org