Posted to commits@spark.apache.org by gu...@apache.org on 2020/04/15 04:58:16 UTC
[spark] branch branch-2.4 updated: [SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 49abdc4 [SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
49abdc4 is described below
commit 49abdc42ff6d52c96f96c8867dc7b089cf9380a1
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Apr 15 13:57:23 2020 +0900
[SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
### What changes were proposed in this pull request?
When the `toPandas` API is called on a DataFrame with duplicate column names, for example one produced by operators like join, it fails with an error like:
```
ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
```
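The underlying cause is a pandas indexing detail: when column names are duplicated, label access such as `pdf[field.name]` selects a DataFrame rather than a Series, so the old nullability check `pdf[field.name].isnull().any()` yields a boolean Series, and using that Series as an `if` condition raises the error above. A minimal pandas-only sketch of the failure mode (illustrative, not Spark code):
```
import pandas as pd

# Two columns share the name 'v', as after a self-join.
pdf = pd.DataFrame([[1, 1], [2, 2]], columns=['v', 'v'])

selected = pdf['v']              # a 2-column DataFrame, not a Series
check = selected.isnull().any()  # a boolean Series, one entry per 'v' column
if check:                        # ValueError: The truth value of a Series is ambiguous.
    pass
```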
This patch fixes the error in the `toPandas` API.
This is a backport of the original patch to branch-2.4.
### Why are the changes needed?
To make `toPandas` work on DataFrames with duplicate column names.
### Does this PR introduce any user-facing change?
Yes. Previously, calling the `toPandas` API on a DataFrame with duplicate column names would fail. After this patch, it produces the correct result.
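For illustration, a minimal reproducer sketch; it assumes an existing SparkSession named `spark` on branch-2.4 with the default non-Arrow path, and enables the cross-join flag because the join below has no condition:
```
spark.conf.set("spark.sql.crossJoin.enabled", "true")

df = spark.sql("""
    select t1.*, t2.* from (
        select explode(sequence(1, 3)) v
    ) t1 left join (
        select explode(sequence(1, 3)) v
    ) t2
""")

# Without the fix: ValueError (truth value of a Series is ambiguous).
# With the fix: a 9-row pandas DataFrame with two int32 columns named 'v'.
pdf = df.toPandas()
```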
### How was this patch tested?
Unit test.
Closes #28219 from viirya/SPARK-31186-2.4.
Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
Signed-off-by: HyukjinKwon <gu...@apache.org>
---
python/pyspark/sql/dataframe.py | 40 ++++++++++++++++++++++++++++++++++------
python/pyspark/sql/tests.py | 18 ++++++++++++++++++
2 files changed, 52 insertions(+), 6 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a2651d2..b58d976 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -27,6 +27,7 @@ else:
     from itertools import imap as map
     from cgi import escape as html_escape
 
+from collections import Counter
 import warnings
 
 from pyspark import copy_func, since, _NoValue
@@ -2148,9 +2149,16 @@ class DataFrame(object):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
 
-        dtype = {}
-        for field in self.schema:
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
             pandas_type = _to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
@@ -2158,11 +2166,31 @@ class DataFrame(object):
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                     not(isinstance(field.dataType, IntegralType) and field.nullable and
-                        pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                        pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
+            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
+            # return a view or a copy depending by context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
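The two techniques in the diff above can be shown in isolation with plain pandas: positional access via `iloc` always returns a single Series even under duplicate names, and `DataFrame.insert(..., allow_duplicates=True)` places a converted Series at a fixed position where plain label assignment would be ambiguous. A standalone sketch (column names and target dtypes are made up for illustration):
```
import numpy as np
import pandas as pd

pdf = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]], columns=['v', 'v'])
target_dtypes = [np.int32, np.int32]   # one corrected dtype per column position

converted = pd.DataFrame()
for i, t in enumerate(target_dtypes):
    series = pdf.iloc[:, i]            # positional access: always a single Series
    if t is not None:
        series = series.astype(t, copy=False)
    # `insert` copies the data, but it is the safe way to add a duplicated name.
    converted.insert(i, 'v', series, allow_duplicates=True)

print(converted.dtypes)                # int32 for both 'v' columns
```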
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0284267..d359e00 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3296,6 +3296,24 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+            select explode(sequence(1, 3)) v
+        ) t1 left join (
+            select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(_have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):