Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/01 08:59:58 UTC

[spark] branch master updated: [SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ff93ae93d8 [SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names
1ff93ae93d8 is described below

commit 1ff93ae93d87ed22281aa68fb82ea869754f67c1
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Wed Mar 1 16:59:41 2023 +0800

    [SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names
    
    ### What changes were proposed in this pull request?
    
    Fixes `createDataFrame` to handle duplicated column names.
    
    ### Why are the changes needed?
    
    Currently, the following command returns an incorrect result: both `c` columns are filled with the value of the last duplicated column.
    
    ```py
    >>> spark.createDataFrame([(1, 2)], ["c", "c"]).show()
    +---+---+
    |  c|  c|
    +---+---+
    |  2|  2|
    +---+---+
    ```
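    
    For reference, the root cause is that the old conversion keyed converted cells by column name, so duplicated names collapsed into a single dict entry and the last value won. A minimal sketch of that failure mode (the variable names here are illustrative, not taken from the patch):
    
    ```py
    >>> row = (1, 2)
    >>> names = ["c", "c"]
    >>> {col: value for col, value in zip(names, row)}  # later "c" overwrites the earlier one
    {'c': 2}
    ```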
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Duplicated column names now work as expected:
    
    ```py
    >>> spark.createDataFrame([(1, 2)], ["c", "c"]).show()
    +---+---+
    |  c|  c|
    +---+---+
    |  1|  2|
    +---+---+
    ```
    
    ### How was this patch tested?
    
    Enabled the previously skipped test `test_duplicated_column_names`.
    
    Closes #40227 from ueshin/issues/SPARK-41870/dup_cols.
    
    Authored-by: Takuya UESHIN <ue...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/connect/conversion.py           | 42 ++++++++--------------
 .../sql/tests/connect/test_parity_dataframe.py     |  5 ---
 2 files changed, 15 insertions(+), 32 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 40679b80291..7b452de48f6 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -243,36 +243,24 @@ class LocalDataToArrowConversion:
 
         column_names = schema.fieldNames()
 
-        column_convs = {
-            field.name: LocalDataToArrowConversion._create_converter(field.dataType)
-            for field in schema.fields
-        }
+        column_convs = [
+            LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
+        ]
 
-        pylist = []
+        pylist: List[List] = [[] for _ in range(len(column_names))]
 
         for item in data:
-            _dict = {}
-
-            if isinstance(item, dict):
-                for col, value in item.items():
-                    _dict[col] = column_convs[col](value)
-            elif isinstance(item, Row) and hasattr(item, "__fields__"):
-                for col, value in item.asDict(recursive=False).items():
-                    _dict[col] = column_convs[col](value)
-            elif not isinstance(item, Row) and hasattr(item, "__dict__"):
-                for col, value in item.__dict__.items():
-                    print(col, value)
-                    _dict[col] = column_convs[col](value)
-            else:
-                i = 0
-                for value in item:
-                    col = column_names[i]
-                    _dict[col] = column_convs[col](value)
-                    i += 1
-
-            pylist.append(_dict)
-
-        return pa.Table.from_pylist(pylist, schema=pa_schema)
+            if not isinstance(item, Row) and hasattr(item, "__dict__"):
+                item = item.__dict__
+            for i, col in enumerate(column_names):
+                if isinstance(item, dict):
+                    value = item.get(col)
+                else:
+                    value = item[i]
+
+                pylist[i].append(column_convs[i](value))
+
+        return pa.Table.from_arrays(pylist, schema=pa_schema)
 
 
 class ArrowTableToRowsConversion:
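
The patched conversion above is column-oriented: it builds one Python list per schema position and passes them to `pa.Table.from_arrays`, which matches columns to fields by position rather than by name, so duplicated names survive. A minimal, self-contained sketch of the same idea in plain pyarrow (the schema and data here are illustrative, not taken from Spark):

```py
import pyarrow as pa

# Arrow schemas may contain duplicate field names.
schema = pa.schema([pa.field("c", pa.int64()), pa.field("c", pa.int64())])
data = [(1, 2)]

# Build one list per column position, as in the patched conversion loop.
columns = [[] for _ in range(len(schema))]
for row in data:
    for i, value in enumerate(row):
        columns[i].append(value)

# from_arrays pairs each list with the field at the same index, so the
# duplicated "c" columns stay distinct instead of being collapsed.
table = pa.Table.from_arrays(columns, schema=schema)
print(table)  # two separate "c" columns holding [1] and [2]
```
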
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index c5972ac02ae..79626586f73 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -27,11 +27,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_create_dataframe_from_pandas_with_dst(self):
         super().test_create_dataframe_from_pandas_with_dst()
 
-    # TODO(SPARK-41870): Handle duplicate columns in `createDataFrame`
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_duplicated_column_names(self):
-        super().test_duplicated_column_names()
-
     @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
     def test_help_command(self):
         super().test_help_command()

