You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/03/01 09:00:18 UTC
[spark] branch branch-3.4 updated: [SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names
This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 0cebb4b2159 [SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names
0cebb4b2159 is described below
commit 0cebb4b215981bc5209f1ca5112d2c580421510e
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Wed Mar 1 16:59:41 2023 +0800
[SPARK-41870][CONNECT][PYTHON] Fix createDataFrame to handle duplicated column names
### What changes were proposed in this pull request?
Fixes `createDataFrame` to handle duplicated column names.
### Why are the changes needed?
Currently, the following command returns an incorrect result:
```py
>>> spark.createDataFrame([(1, 2)], ["c", "c"]).show()
+---+---+
| c| c|
+---+---+
| 2| 2|
+---+---+
```
### Does this PR introduce _any_ user-facing change?
Yes. Duplicated column names will now work correctly:
```py
>>> spark.createDataFrame([(1, 2)], ["c", "c"]).show()
+---+---+
| c| c|
+---+---+
| 1| 2|
+---+---+
```
### How was this patch tested?
Enabled the related test.
Closes #40227 from ueshin/issues/SPARK-41870/dup_cols.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
(cherry picked from commit 1ff93ae93d87ed22281aa68fb82ea869754f67c1)
Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
python/pyspark/sql/connect/conversion.py | 42 ++++++++--------------
.../sql/tests/connect/test_parity_dataframe.py | 5 ---
2 files changed, 15 insertions(+), 32 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py b/python/pyspark/sql/connect/conversion.py
index 40679b80291..7b452de48f6 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -243,36 +243,24 @@ class LocalDataToArrowConversion:
column_names = schema.fieldNames()
- column_convs = {
- field.name: LocalDataToArrowConversion._create_converter(field.dataType)
- for field in schema.fields
- }
+ column_convs = [
+ LocalDataToArrowConversion._create_converter(field.dataType) for field in schema.fields
+ ]
- pylist = []
+ pylist: List[List] = [[] for _ in range(len(column_names))]
for item in data:
- _dict = {}
-
- if isinstance(item, dict):
- for col, value in item.items():
- _dict[col] = column_convs[col](value)
- elif isinstance(item, Row) and hasattr(item, "__fields__"):
- for col, value in item.asDict(recursive=False).items():
- _dict[col] = column_convs[col](value)
- elif not isinstance(item, Row) and hasattr(item, "__dict__"):
- for col, value in item.__dict__.items():
- print(col, value)
- _dict[col] = column_convs[col](value)
- else:
- i = 0
- for value in item:
- col = column_names[i]
- _dict[col] = column_convs[col](value)
- i += 1
-
- pylist.append(_dict)
-
- return pa.Table.from_pylist(pylist, schema=pa_schema)
+ if not isinstance(item, Row) and hasattr(item, "__dict__"):
+ item = item.__dict__
+ for i, col in enumerate(column_names):
+ if isinstance(item, dict):
+ value = item.get(col)
+ else:
+ value = item[i]
+
+ pylist[i].append(column_convs[i](value))
+
+ return pa.Table.from_arrays(pylist, schema=pa_schema)
class ArrowTableToRowsConversion:
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index c5972ac02ae..79626586f73 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -27,11 +27,6 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
def test_create_dataframe_from_pandas_with_dst(self):
super().test_create_dataframe_from_pandas_with_dst()
- # TODO(SPARK-41870): Handle duplicate columns in `createDataFrame`
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_duplicated_column_names(self):
- super().test_duplicated_column_names()
-
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_help_command(self):
super().test_help_command()
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org