You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/21 01:43:34 UTC

[spark] branch master updated: [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 61035129a35 [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
61035129a35 is described below

commit 61035129a354d0b31c66908106238b12b1f2f7b0
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Mar 21 09:43:11 2023 +0800

    [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
    
    ### What changes were proposed in this pull request?
    
    Fix `DataFrame.toPandas()` to handle timezone and map types properly.
    
    ### Why are the changes needed?
    
    Currently `DataFrame.toPandas()` handles neither the timezone of timestamp types nor map types properly.
    
    For example:
    
    ```py
    >>> schema = StructType().add("ts", TimestampType())
    >>> spark.createDataFrame([(datetime(1969, 1, 1, 1, 1, 1),), (datetime(2012, 3, 3, 3, 3, 3),), (datetime(2100, 4, 4, 4, 4, 4),)], schema).toPandas()
                             ts
    0 1969-01-01 01:01:01-08:00
    1 2012-03-03 03:03:03-08:00
    2 2100-04-04 03:04:04-08:00
    ```
    
    which should be:
    
    ```py
                       ts
    0 1969-01-01 01:01:01
    1 2012-03-03 03:03:03
    2 2100-04-04 04:04:04
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    The result of `DataFrame.toPandas()` for timestamp and map types will now be the same as in PySpark.
    
    ### How was this patch tested?
    
    Enabled the related tests.
    
    Closes #40497 from ueshin/issues/SPARK-42875/timestamp.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/client.py               | 24 ++++++---
 .../sql/tests/connect/test_parity_dataframe.py     | 20 +++----
 python/pyspark/sql/tests/test_dataframe.py         | 61 +++++++++++++---------
 3 files changed, 60 insertions(+), 45 deletions(-)

diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 090d239fbb4..53fa97372a7 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -71,7 +71,8 @@ from pyspark.sql.connect.expressions import (
     CommonInlineUserDefinedFunction,
     JavaUDF,
 )
-from pyspark.sql.types import DataType, StructType
+from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert_map_items_to_dict
+from pyspark.sql.types import DataType, MapType, StructType, TimestampType
 from pyspark.rdd import PythonEvalType
 
 
@@ -637,12 +638,23 @@ class SparkConnectClient(object):
         logger.info(f"Executing plan {self._proto_to_string(plan)}")
         req = self._execute_plan_request_with_metadata()
         req.plan.CopyFrom(plan)
-        table, _, metrics, observed_metrics, _ = self._execute_and_fetch(req)
+        table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
         assert table is not None
-        column_names = table.column_names
-        table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
-        pdf = table.to_pandas()
-        pdf.columns = column_names
+        pdf = table.rename_columns([f"col_{i}" for i in range(len(table.column_names))]).to_pandas()
+        pdf.columns = table.column_names
+
+        schema = schema or types.from_arrow_schema(table.schema)
+        assert schema is not None and isinstance(schema, StructType)
+
+        for field, pa_field in zip(schema, table.schema):
+            if isinstance(field.dataType, TimestampType):
+                assert pa_field.type.tz is not None
+                pdf[field.name] = _check_series_localize_timestamps(
+                    pdf[field.name], pa_field.type.tz
+                )
+            elif isinstance(field.dataType, MapType):
+                pdf[field.name] = _convert_map_items_to_dict(pdf[field.name])
+
         if len(metrics) > 0:
             pdf.attrs["metrics"] = metrics
         if len(observed_metrics) > 0:
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 31dee6a19d2..ae812b4ca55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -22,11 +22,6 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
 
 
 class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_pandas_with_dst(self):
-        super().test_create_dataframe_from_pandas_with_dst()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_help_command(self):
         super().test_help_command()
@@ -87,26 +82,25 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_to_local_iterator_prefetch(self):
         super().test_to_local_iterator_prefetch()
 
-    # TODO(SPARK-41884): DataFrame `toPandas` parity in return types
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_to_pandas(self):
-        super().test_to_pandas()
-
     def test_to_pandas_for_array_of_struct(self):
         # Spark Connect's implementation is based on Arrow.
         super().check_to_pandas_for_array_of_struct(True)
 
-    # TODO(SPARK-41834): Implement SparkSession.conf
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_from_null_dataframe(self):
-        super().test_to_pandas_from_null_dataframe()
+        self.check_to_pandas_from_null_dataframe()
 
     def test_to_pandas_on_cross_join(self):
         self.check_to_pandas_on_cross_join()
 
+    def test_to_pandas_from_empty_dataframe(self):
+        self.check_to_pandas_from_empty_dataframe()
+
     def test_to_pandas_with_duplicated_column_names(self):
         self.check_to_pandas_with_duplicated_column_names()
 
+    def test_to_pandas_from_mixed_dataframe(self):
+        self.check_to_pandas_from_mixed_dataframe()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index bd2f1cb75b7..cb209f472bf 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1186,6 +1186,12 @@ class DataFrameTestsMixin:
 
     @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
     def test_to_pandas_from_empty_dataframe(self):
+        is_arrow_enabled = [True, False]
+        for value in is_arrow_enabled:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+                self.check_to_pandas_from_empty_dataframe()
+
+    def check_to_pandas_from_empty_dataframe(self):
         # SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
         # SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes
         # when arrow is enabled
@@ -1204,15 +1210,18 @@ class DataFrameTestsMixin:
             CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz,
             INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
             """
+        dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
+        dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
+        self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
+    def test_to_pandas_from_null_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
-                dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
-                self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+                self.check_to_pandas_from_null_dataframe()
 
-    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
-    def test_to_pandas_from_null_dataframe(self):
+    def check_to_pandas_from_null_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
         # SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes
         # using arrow
@@ -1231,25 +1240,28 @@ class DataFrameTestsMixin:
             CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz,
             INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
             """
+        pdf = self.spark.sql(sql).toPandas()
+        types = pdf.dtypes
+        self.assertEqual(types[0], np.float64)
+        self.assertEqual(types[1], np.float64)
+        self.assertEqual(types[2], np.float64)
+        self.assertEqual(types[3], np.float64)
+        self.assertEqual(types[4], np.float32)
+        self.assertEqual(types[5], np.float64)
+        self.assertEqual(types[6], object)
+        self.assertEqual(types[7], object)
+        self.assertTrue(np.can_cast(np.datetime64, types[8]))
+        self.assertTrue(np.can_cast(np.datetime64, types[9]))
+        self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
+    def test_to_pandas_from_mixed_dataframe(self):
         is_arrow_enabled = [True, False]
         for value in is_arrow_enabled:
             with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                pdf = self.spark.sql(sql).toPandas()
-                types = pdf.dtypes
-                self.assertEqual(types[0], np.float64)
-                self.assertEqual(types[1], np.float64)
-                self.assertEqual(types[2], np.float64)
-                self.assertEqual(types[3], np.float64)
-                self.assertEqual(types[4], np.float32)
-                self.assertEqual(types[5], np.float64)
-                self.assertEqual(types[6], object)
-                self.assertEqual(types[7], object)
-                self.assertTrue(np.can_cast(np.datetime64, types[8]))
-                self.assertTrue(np.can_cast(np.datetime64, types[9]))
-                self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+                self.check_to_pandas_from_mixed_dataframe()
 
-    @unittest.skipIf(not have_pandas, pandas_requirement_message)  # type: ignore
-    def test_to_pandas_from_mixed_dataframe(self):
+    def check_to_pandas_from_mixed_dataframe(self):
         # SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
         # SPARK-30537 test that toPandas() on a dataframe with some nulls has correct dtypes
         # using arrow
@@ -1270,12 +1282,9 @@ class DataFrameTestsMixin:
         FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
                     (NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
         """
-        is_arrow_enabled = [True, False]
-        for value in is_arrow_enabled:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
-                pdf_with_some_nulls = self.spark.sql(sql).toPandas()
-                pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
-                self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
+        pdf_with_some_nulls = self.spark.sql(sql).toPandas()
+        pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
+        self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow or pyarrow_version_less_than_minimum("2.0.0"),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org