Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/21 01:43:55 UTC
[spark] branch branch-3.4 updated: [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 602aaff4462 [SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
602aaff4462 is described below
commit 602aaff44621decaff1c303eac91037c2376aa52
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Tue Mar 21 09:43:11 2023 +0800
[SPARK-42875][CONNECT][PYTHON] Fix toPandas to handle timezone and map types properly
### What changes were proposed in this pull request?
Fix `DataFrame.toPandas()` to handle timezone and map types properly.
### Why are the changes needed?
Currently `DataFrame.toPandas()` doesn't properly handle the timezone for timestamp types, or map types.
For example:
```py
>>> schema = StructType().add("ts", TimestampType())
>>> spark.createDataFrame([(datetime(1969, 1, 1, 1, 1, 1),), (datetime(2012, 3, 3, 3, 3, 3),), (datetime(2100, 4, 4, 4, 4, 4),)], schema).toPandas()
                          ts
0  1969-01-01 01:01:01-08:00
1  2012-03-03 03:03:03-08:00
2  2100-04-04 03:04:04-08:00
```
which should be:
```py
                    ts
0  1969-01-01 01:01:01
1  2012-03-03 03:03:03
2  2100-04-04 04:04:04
```
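The tz-aware values above are what you get when the Arrow result, whose timestamp columns carry the session timezone, is converted to pandas verbatim. A minimal pandas-only sketch of the mismatch and of the localize-then-drop fix (the timezone below is an assumed session timezone, purely for illustration):
```py
import pandas as pd

tz = "America/Los_Angeles"  # assumed session timezone, for illustration only

# tz-aware values, as produced by converting the Arrow result verbatim
aware = pd.Series(pd.to_datetime(["2012-03-03 03:03:03"])).dt.tz_localize(tz)
print(aware.dtype)  # datetime64[ns, America/Los_Angeles]

# the fix: render in the session timezone, then drop the tz info so the
# values are tz-naive, matching PySpark's toPandas() output
naive = aware.dt.tz_convert(tz).dt.tz_localize(None)
print(naive.dtype)  # datetime64[ns]
```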
### Does this PR introduce _any_ user-facing change?
Yes. The result of `DataFrame.toPandas()` for timestamp and map types will now match PySpark's.
### How was this patch tested?
Enabled the related tests.
Closes #40497 from ueshin/issues/SPARK-42875/timestamp.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
(cherry picked from commit 61035129a354d0b31c66908106238b12b1f2f7b0)
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/client.py | 24 ++++++---
.../sql/tests/connect/test_parity_dataframe.py | 20 +++----
python/pyspark/sql/tests/test_dataframe.py | 61 +++++++++++++---------
3 files changed, 60 insertions(+), 45 deletions(-)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index 090d239fbb4..53fa97372a7 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -71,7 +71,8 @@ from pyspark.sql.connect.expressions import (
CommonInlineUserDefinedFunction,
JavaUDF,
)
-from pyspark.sql.types import DataType, StructType
+from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert_map_items_to_dict
+from pyspark.sql.types import DataType, MapType, StructType, TimestampType
from pyspark.rdd import PythonEvalType
@@ -637,12 +638,23 @@ class SparkConnectClient(object):
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- table, _, metrics, observed_metrics, _ = self._execute_and_fetch(req)
+ table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req)
assert table is not None
- column_names = table.column_names
- table = table.rename_columns([f"col_{i}" for i in range(len(column_names))])
- pdf = table.to_pandas()
- pdf.columns = column_names
+ pdf = table.rename_columns([f"col_{i}" for i in range(len(table.column_names))]).to_pandas()
+ pdf.columns = table.column_names
+
+ schema = schema or types.from_arrow_schema(table.schema)
+ assert schema is not None and isinstance(schema, StructType)
+
+ for field, pa_field in zip(schema, table.schema):
+ if isinstance(field.dataType, TimestampType):
+ assert pa_field.type.tz is not None
+ pdf[field.name] = _check_series_localize_timestamps(
+ pdf[field.name], pa_field.type.tz
+ )
+ elif isinstance(field.dataType, MapType):
+ pdf[field.name] = _convert_map_items_to_dict(pdf[field.name])
+
if len(metrics) > 0:
pdf.attrs["metrics"] = metrics
if len(observed_metrics) > 0:
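For reference, a standalone sketch of the post-processing this hunk adds, using only pyarrow and pandas; the timezone is an assumed session timezone, and the inline comments only approximate what the pyspark helpers do:
```py
import pandas as pd
import pyarrow as pa

tz = "America/Los_Angeles"  # assumed session timezone, for illustration only

table = pa.table({
    "ts": pa.array([pd.Timestamp("2012-03-03 03:03:03", tz=tz)],
                   type=pa.timestamp("us", tz=tz)),
    "m": pa.array([[("a", 1)]], type=pa.map_(pa.string(), pa.int64())),
})
pdf = table.to_pandas()

# timestamp columns: drop the session timezone so the values are tz-naive,
# roughly what _check_series_localize_timestamps does
pdf["ts"] = pdf["ts"].dt.tz_convert(tz).dt.tz_localize(None)

# map columns: Arrow hands pandas a sequence of (key, value) pairs per row;
# PySpark returns Python dicts, roughly what _convert_map_items_to_dict does
pdf["m"] = pdf["m"].apply(lambda kv: dict(kv) if kv is not None else None)

print(pdf)  # ts is tz-naive; m is {'a': 1} rather than [('a', 1)]
```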
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 31dee6a19d2..ae812b4ca55 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -22,11 +22,6 @@ from pyspark.testing.connectutils import ReusedConnectTestCase
class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_create_dataframe_from_pandas_with_dst(self):
- super().test_create_dataframe_from_pandas_with_dst()
-
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_help_command(self):
super().test_help_command()
@@ -87,26 +82,25 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_to_local_iterator_prefetch(self):
super().test_to_local_iterator_prefetch()
- # TODO(SPARK-41884): DataFrame `toPandas` parity in return types
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_to_pandas(self):
- super().test_to_pandas()
-
def test_to_pandas_for_array_of_struct(self):
# Spark Connect's implementation is based on Arrow.
super().check_to_pandas_for_array_of_struct(True)
- # TODO(SPARK-41834): Implement SparkSession.conf
- @unittest.skip("Fails in Spark Connect, should enable.")
def test_to_pandas_from_null_dataframe(self):
- super().test_to_pandas_from_null_dataframe()
+ self.check_to_pandas_from_null_dataframe()
def test_to_pandas_on_cross_join(self):
self.check_to_pandas_on_cross_join()
+ def test_to_pandas_from_empty_dataframe(self):
+ self.check_to_pandas_from_empty_dataframe()
+
def test_to_pandas_with_duplicated_column_names(self):
self.check_to_pandas_with_duplicated_column_names()
+ def test_to_pandas_from_mixed_dataframe(self):
+ self.check_to_pandas_from_mixed_dataframe()
+
if __name__ == "__main__":
import unittest
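Since Spark Connect's `toPandas()` is always Arrow-based, the parity tests above call the shared `check_*` bodies directly rather than toggling the Arrow conf. A hypothetical condensed view of the pattern (base-class wiring omitted):
```py
class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
    def test_to_pandas_from_null_dataframe(self):
        # no Arrow conf loop: Spark Connect is Arrow-based by construction
        self.check_to_pandas_from_null_dataframe()
```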
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index bd2f1cb75b7..cb209f472bf 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1186,6 +1186,12 @@ class DataFrameTestsMixin:
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
def test_to_pandas_from_empty_dataframe(self):
+ is_arrow_enabled = [True, False]
+ for value in is_arrow_enabled:
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
+ self.check_to_pandas_from_empty_dataframe()
+
+ def check_to_pandas_from_empty_dataframe(self):
# SPARK-29188 test that toPandas() on an empty dataframe has the correct dtypes
# SPARK-30537 test that toPandas() on an empty dataframe has the correct dtypes
# when arrow is enabled
@@ -1204,15 +1210,18 @@ class DataFrameTestsMixin:
CAST('2019-01-01' AS TIMESTAMP_NTZ) AS timestamp_ntz,
INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
"""
+ dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
+ dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
+ self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+
+ @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
+ def test_to_pandas_from_null_dataframe(self):
is_arrow_enabled = [True, False]
for value in is_arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
- dtypes_when_nonempty_df = self.spark.sql(sql).toPandas().dtypes
- dtypes_when_empty_df = self.spark.sql(sql).filter("False").toPandas().dtypes
- self.assertTrue(np.all(dtypes_when_empty_df == dtypes_when_nonempty_df))
+ self.check_to_pandas_from_null_dataframe()
- @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
- def test_to_pandas_from_null_dataframe(self):
+ def check_to_pandas_from_null_dataframe(self):
# SPARK-29188 test that toPandas() on a dataframe with only nulls has correct dtypes
# SPARK-30537 test that toPandas() on a dataframe with only nulls has correct dtypes
# using arrow
@@ -1231,25 +1240,28 @@ class DataFrameTestsMixin:
CAST(NULL AS TIMESTAMP_NTZ) AS timestamp_ntz,
INTERVAL '1563:04' MINUTE TO SECOND AS day_time_interval
"""
+ pdf = self.spark.sql(sql).toPandas()
+ types = pdf.dtypes
+ self.assertEqual(types[0], np.float64)
+ self.assertEqual(types[1], np.float64)
+ self.assertEqual(types[2], np.float64)
+ self.assertEqual(types[3], np.float64)
+ self.assertEqual(types[4], np.float32)
+ self.assertEqual(types[5], np.float64)
+ self.assertEqual(types[6], object)
+ self.assertEqual(types[7], object)
+ self.assertTrue(np.can_cast(np.datetime64, types[8]))
+ self.assertTrue(np.can_cast(np.datetime64, types[9]))
+ self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+
+ @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
+ def test_to_pandas_from_mixed_dataframe(self):
is_arrow_enabled = [True, False]
for value in is_arrow_enabled:
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
- pdf = self.spark.sql(sql).toPandas()
- types = pdf.dtypes
- self.assertEqual(types[0], np.float64)
- self.assertEqual(types[1], np.float64)
- self.assertEqual(types[2], np.float64)
- self.assertEqual(types[3], np.float64)
- self.assertEqual(types[4], np.float32)
- self.assertEqual(types[5], np.float64)
- self.assertEqual(types[6], object)
- self.assertEqual(types[7], object)
- self.assertTrue(np.can_cast(np.datetime64, types[8]))
- self.assertTrue(np.can_cast(np.datetime64, types[9]))
- self.assertTrue(np.can_cast(np.timedelta64, types[10]))
+ self.check_to_pandas_from_mixed_dataframe()
- @unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
- def test_to_pandas_from_mixed_dataframe(self):
+ def check_to_pandas_from_mixed_dataframe(self):
# SPARK-29188 test that toPandas() on a dataframe with some nulls has correct dtypes
# SPARK-30537 test that toPandas() on a dataframe with some nulls has correct dtypes
# using arrow
@@ -1270,12 +1282,9 @@ class DataFrameTestsMixin:
FROM VALUES (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
(NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL)
"""
- is_arrow_enabled = [True, False]
- for value in is_arrow_enabled:
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": value}):
- pdf_with_some_nulls = self.spark.sql(sql).toPandas()
- pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
- self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
+ pdf_with_some_nulls = self.spark.sql(sql).toPandas()
+ pdf_with_only_nulls = self.spark.sql(sql).filter("tinyint is null").toPandas()
+ self.assertTrue(np.all(pdf_with_only_nulls.dtypes == pdf_with_some_nulls.dtypes))
@unittest.skipIf(
not have_pandas or not have_pyarrow or pyarrow_version_less_than_minimum("2.0.0"),
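On the classic PySpark side, each refactored test still exercises both Arrow code paths by wrapping its `check_*` body in the conf toggle. A condensed sketch of the `test_*`/`check_*` split these hunks introduce:
```py
def test_to_pandas_from_mixed_dataframe(self):
    # run the shared check under both the Arrow and non-Arrow paths
    for arrow_enabled in [True, False]:
        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": arrow_enabled}):
            self.check_to_pandas_from_mixed_dataframe()
```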