Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/08 11:16:27 UTC

[spark] branch branch-3.4 updated: [SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new e3f493132d9 [SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects
e3f493132d9 is described below

commit e3f493132d96f0e8ef5f2ad47efeee91c22abb7f
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Wed Feb 8 19:15:45 2023 +0800

    [SPARK-42381][CONNECT][PYTHON] `CreateDataFrame` should accept objects
    
    ### What changes were proposed in this pull request?
    Make `createDataFrame` in Spark Connect accept plain Python objects (instances exposing `__dict__`): both the local-data-to-Arrow conversion and the schema-inference path now read fields from an object's `__dict__`.
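    
    A rough, standalone sketch of the conversion idea (simplified; the `Datum`
    class and the `extract_fields` helper below are illustrative only, not
    part of the patch): values that are not `Row`, `tuple`, or `dict` but
    expose a `__dict__` are read field by field from that `__dict__`.
    
        from pyspark.sql import Row
    
        class Datum:
            def __init__(self, key, value):
                self.key = key
                self.value = value
    
        def extract_fields(value):
            # Dicts and Rows keep their existing handling; plain objects fall
            # back to their instance attributes; other iterables (tuples,
            # lists) are read positionally.
            if isinstance(value, dict):
                return dict(value)
            elif isinstance(value, Row):
                return value.asDict(recursive=False)
            elif hasattr(value, "__dict__"):
                return dict(value.__dict__)
            else:
                return {i: v for i, v in enumerate(value)}
    
        extract_fields(Datum(1, "1"))  # {'key': 1, 'value': '1'}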
    
    ### Why are the changes needed?
    For parity with regular PySpark, whose `createDataFrame` already accepts plain Python objects.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `createDataFrame` in Spark Connect now accepts plain Python objects (anything exposing `__dict__`).
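    
    A minimal usage sketch of the new behavior, assuming an existing Spark
    Connect session `spark` (the `Item` class below is a hypothetical
    example; any object exposing `__dict__` should work):
    
        class Item:
            def __init__(self, key, value):
                self.key = key
                self.value = value
    
        df = spark.createDataFrame([Item(1, "1"), Item(2, "2")])
        df.show()
        # +---+-----+
        # |key|value|
        # +---+-----+
        # |  1|    1|
        # |  2|    2|
        # +---+-----+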
    
    ### How was this patch tested?
    Enabled the previously skipped parity test `test_create_dataframe_from_objects` and added a new unit test `test_create_df_from_objects`.
    
    Closes #39939 from zhengruifeng/connect_createDF_objects.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit 2fbf57e9d138d03a520726ec21029a49a34035cf)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/conversion.py               | 11 ++++++++++-
 python/pyspark/sql/connect/session.py                  |  4 +++-
 python/pyspark/sql/tests/connect/test_connect_basic.py | 17 +++++++++++++++++
 python/pyspark/sql/tests/connect/test_parity_types.py  |  5 -----
 4 files changed, 30 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 0a723f2977a..4dbdb5db212 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -105,7 +105,9 @@ class LocalDataToArrowConversion:
                 if value is None:
                     return None
                 else:
-                    assert isinstance(value, (tuple, dict)), f"{type(value)} {value}"
+                    assert isinstance(value, (tuple, dict)) or hasattr(
+                        value, "__dict__"
+                    ), f"{type(value)} {value}"
 
                     _dict = {}
                     if isinstance(value, dict):
@@ -116,6 +118,10 @@ class LocalDataToArrowConversion:
                         for k, v in value.asDict(recursive=False).items():
                             assert isinstance(k, str)
                             _dict[k] = field_convs[k](v)
+                    elif not isinstance(value, Row) and hasattr(value, "__dict__"):
+                        for k, v in value.__dict__.items():
+                            assert isinstance(k, str)
+                            _dict[k] = field_convs[k](v)
                     else:
                         i = 0
                         for v in value:
@@ -253,6 +259,9 @@
             elif isinstance(item, Row) and hasattr(item, "__fields__"):
                 for col, value in item.asDict(recursive=False).items():
                     _dict[col] = column_convs[col](value)
+            elif not isinstance(item, Row) and hasattr(item, "__dict__"):
+                for col, value in item.__dict__.items():
+                    _dict[col] = column_convs[col](value)
             else:
                 i = 0
                 for value in item:
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index b93bbffc999..898baa45b03 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -294,7 +294,9 @@ class SparkSession:
                 # For dictionaries, we sort the schema in alphabetical order.
                 _data = [dict(sorted(d.items())) for d in _data]
 
-            elif not isinstance(_data[0], (Row, tuple, list, dict)):
+            elif not isinstance(_data[0], (Row, tuple, list, dict)) and not hasattr(
+                _data[0], "__dict__"
+            ):
                 # input data can be [1, 2, 3]
                 # we need to convert it to [[1], [2], [3]] to be able to infer schema.
                 _data = [[d] for d in _data]
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 9068d6f5635..a9beb71545d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -37,6 +37,7 @@ from pyspark.sql.types import (
 )
 
 from pyspark.testing.sqlutils import (
+    MyObject,
     SQLTestUtils,
     PythonOnlyUDT,
     ExamplePoint,
@@ -840,6 +841,22 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.assertEqual(cdf.schema, sdf.schema)
             self.assertEqual(cdf.collect(), sdf.collect())
 
+    def test_create_df_from_objects(self):
+        data = [MyObject(1, "1"), MyObject(2, "2")]
+
+        # +---+-----+
+        # |key|value|
+        # +---+-----+
+        # |  1|    1|
+        # |  2|    2|
+        # +---+-----+
+
+        cdf = self.connect.createDataFrame(data)
+        sdf = self.spark.createDataFrame(data)
+
+        self.assertEqual(cdf.schema, sdf.schema)
+        self.assertEqual(cdf.collect(), sdf.collect())
+
     def test_simple_explain_string(self):
         df = self.connect.read.table(self.tbl_name).limit(10)
         result = df._explain_string()
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 025e64f2bf0..e966986c152 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -54,11 +54,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
     def test_complex_nested_udt_in_df(self):
         super().test_complex_nested_udt_in_df()
 
-    # TODO(SPARK-42020): createDataFrame with UDT
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_create_dataframe_from_objects(self):
-        super().test_create_dataframe_from_objects()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_create_dataframe_schema_mismatch(self):
         super().test_create_dataframe_schema_mismatch()

