Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/08 11:16:27 UTC
[spark] branch branch-3.4 updated: [SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new e3f493132d9 [SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects
e3f493132d9 is described below
commit e3f493132d96f0e8ef5f2ad47efeee91c22abb7f
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Feb 8 19:15:45 2023 +0800
[SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects
### What changes were proposed in this pull request?
`CreateDataFrame` should accept objects
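For illustration, a minimal sketch of the behavior this enables (the `sc://localhost` address is a placeholder; `MyObject` mirrors the test helper used in the new unit test below):

    from pyspark.sql import SparkSession

    class MyObject:
        def __init__(self, key, value):
            self.key = key
            self.value = value

    # Assumes a reachable Spark Connect server (placeholder address).
    spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

    # Plain Python objects are accepted; attribute names become columns.
    df = spark.createDataFrame([MyObject(1, "1"), MyObject(2, "2")])
    df.show()
    # +---+-----+
    # |key|value|
    # +---+-----+
    # |  1|    1|
    # |  2|    2|
    # +---+-----+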
### Why are the changes needed?
For parity with vanilla PySpark, whose `createDataFrame` already accepts plain Python objects.
### Does this PR introduce _any_ user-facing change?
Yes, `createDataFrame` on Spark Connect now accepts plain Python objects as input rows.
### How was this patch tested?
Enabled the previously skipped parity unit test and added a new unit test.
Closes #39939 from zhengruifeng/connect_createDF_objects.
Authored-by: Ruifeng Zheng <ru...@apache.org>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
(cherry picked from commit 2fbf57e9d138d03a520726ec21029a49a34035cf)
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/conversion.py               | 11 ++++++++++-
python/pyspark/sql/connect/session.py                  |  4 +++-
python/pyspark/sql/tests/connect/test_connect_basic.py | 17 +++++++++++++++++
python/pyspark/sql/tests/connect/test_parity_types.py  |  5 -----
4 files changed, 30 insertions(+), 7 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 0a723f2977a..4dbdb5db212 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -105,7 +105,9 @@ class LocalDataToArrowConversion:
                 if value is None:
                     return None
                 else:
-                    assert isinstance(value, (tuple, dict)), f"{type(value)} {value}"
+                    assert isinstance(value, (tuple, dict)) or hasattr(
+                        value, "__dict__"
+                    ), f"{type(value)} {value}"
 
                     _dict = {}
                     if isinstance(value, dict):
@@ -116,6 +118,10 @@ class LocalDataToArrowConversion:
                         for k, v in value.asDict(recursive=False).items():
                             assert isinstance(k, str)
                             _dict[k] = field_convs[k](v)
+                    elif not isinstance(value, Row) and hasattr(value, "__dict__"):
+                        for k, v in value.__dict__.items():
+                            assert isinstance(k, str)
+                            _dict[k] = field_convs[k](v)
                     else:
                         i = 0
                         for v in value:
@@ -253,6 +259,9 @@ class LocalDataToArrowConversion:
             elif isinstance(item, Row) and hasattr(item, "__fields__"):
                 for col, value in item.asDict(recursive=False).items():
                     _dict[col] = column_convs[col](value)
+            elif not isinstance(item, Row) and hasattr(item, "__dict__"):
+                for col, value in item.__dict__.items():
+                    _dict[col] = column_convs[col](value)
             else:
                 i = 0
                 for value in item:
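The hunks above give both the struct-field converter and the top-level row converter the same fallback: an input that is neither a dict, a Row, nor a positional sequence, but exposes a `__dict__`, contributes its instance attributes as field names and values. A standalone sketch of that dispatch, with `field_names` and `field_convs` standing in for the converter state the real `LocalDataToArrowConversion` maintains:

    from pyspark.sql import Row

    def to_field_dict(value, field_names, field_convs):
        # field_convs maps field name -> converter; field_names fixes positional order.
        if isinstance(value, dict):
            return {k: field_convs[k](v) for k, v in value.items()}
        elif isinstance(value, Row) and hasattr(value, "__fields__"):
            return {k: field_convs[k](v) for k, v in value.asDict(recursive=False).items()}
        elif not isinstance(value, Row) and hasattr(value, "__dict__"):
            # New in this change: plain objects contribute their attributes.
            return {k: field_convs[k](v) for k, v in value.__dict__.items()}
        else:
            # Tuples and lists are matched to fields by position.
            return {name: field_convs[name](v) for name, v in zip(field_names, value)}

The explicit `not isinstance(value, Row)` guard matters because `Row` instances carry a `__dict__` as well (Python tuple subclasses get one by default); without it, a Row lacking `__fields__` could be routed through the attribute branch instead of the positional one.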
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b93bbffc999..898baa45b03 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -294,7 +294,9 @@ class SparkSession:
             # For dictionaries, we sort the schema in alphabetical order.
             _data = [dict(sorted(d.items())) for d in _data]
 
-        elif not isinstance(_data[0], (Row, tuple, list, dict)):
+        elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
+            _data[0], "__dict__"
+        ):
             # input data can be [1, 2, 3]
             # we need to convert it to [[1], [2], [3]] to be able to infer schema.
             _data = [[d] for d in _data]
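With this change, the normalization step leaves any object that carries a `__dict__` untouched, so schema inference can pick up its attribute names as columns, while bare scalars are still wrapped into single-element rows. A sketch of the rule in isolation (not the actual session code):

    from pyspark.sql import Row

    def normalize(_data):
        if isinstance(_data[0], dict):
            # Sort keys so the inferred schema is in alphabetical order.
            return [dict(sorted(d.items())) for d in _data]
        elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
            _data[0], "__dict__"
        ):
            # Bare scalars like [1, 2, 3] become [[1], [2], [3]] so a
            # single-column schema can be inferred.
            return [[d] for d in _data]
        # Rows, tuples, lists, and plain objects pass through unchanged.
        return _data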
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9068d6f5635..a9beb71545d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import (
 )
 
 from pyspark.testing.sqlutils import (
+    MyObject,
     SQLTestUtils,
     PythonOnlyUDT,
     ExamplePoint,
@@ -840,6 +841,22 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(cdf.schema, sdf.schema)
         self.assertEqual(cdf.collect(), sdf.collect())
 
+    def test_create_df_from_objects(self):
+        data = [MyObject(1, "1"), MyObject(2, "2")]
+
+        # +---+-----+
+        # |key|value|
+        # +---+-----+
+        # |  1|    1|
+        # |  2|    2|
+        # +---+-----+
+
+        cdf = self.connect.createDataFrame(data)
+        sdf = self.spark.createDataFrame(data)
+
+        self.assertEqual(cdf.schema, sdf.schema)
+        self.assertEqual(cdf.collect(), sdf.collect())
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 025e64f2bf0..e966986c152 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -54,11 +54,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_complex_nested_udt_in_df(self):
         super().test_complex_nested_udt_in_df()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_objects(self):
-        super().test_create_dataframe_from_objects()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()