Posted to commits@spark.apache.org by gu...@apache.org on 2020/04/15 04:58:16 UTC

[spark] branch branch-2.4 updated: [SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 49abdc4  [SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
49abdc4 is described below

commit 49abdc42ff6d52c96f96c8867dc7b089cf9380a1
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Wed Apr 15 13:57:23 2020 +0900

    [SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
    
    ### What changes were proposed in this pull request?
    
    When the `toPandas` API is called on a DataFrame with duplicate column names, for example ones produced by operators like join, it fails with an error like:
    
    ```
    ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
    ```
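    
    This failure can be reproduced with a short sketch like the following (illustrative only; it assumes a local `SparkSession` bound to `spark`, and the data and column name `v` are hypothetical):
    
    ```
    # Two DataFrames sharing the column name `v` (nullable int columns).
    df1 = spark.createDataFrame([(1,), (2,), (3,)], "v int")
    df2 = spark.createDataFrame([(1,), (2,), (3,)], "v int")

    # The join result carries two columns that are both named `v`.
    joined = df1.join(df2, df1["v"] == df2["v"])

    # Before this patch, the non-Arrow `toPandas` path raised the ValueError above.
    pdf = joined.toPandas()
    ```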
    
    This patch fixes the error in the `toPandas` API.
    
    This is a backport of the original patch to branch-2.4.
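    
    For context on why the error occurs: with duplicate labels, `pdf[field.name]` selects a multi-column pandas `DataFrame` instead of a `Series`, so the nullable-integer check ends up evaluating an ambiguous truth value. The patch therefore counts column names with a `Counter` and accesses duplicated ones positionally via `iloc`, rebuilding the result with `insert(..., allow_duplicates=True)`. A standalone pandas sketch of that behavior (illustrative only, not part of the patch):
    
    ```
    import pandas as pd
    from collections import Counter

    # A pandas DataFrame with a duplicated column label.
    pdf = pd.DataFrame([[1, 2], [3, 4]], columns=["v", "v"])
    counter = Counter(pdf.columns)

    # Label-based access on a duplicated name yields a DataFrame, not a Series.
    assert isinstance(pdf["v"], pd.DataFrame)
    # Positional access always yields a single Series.
    assert isinstance(pdf.iloc[:, 0], pd.Series)

    # Rebuild column by column so each duplicated column can be cast independently.
    out = pd.DataFrame()
    for idx, name in enumerate(pdf.columns):
        series = pdf.iloc[:, idx] if counter[name] > 1 else pdf[name]
        out.insert(idx, name, series.astype("int32", copy=False), allow_duplicates=True)
    ```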
    
    ### Why are the changes needed?
    
    To make `toPandas` work on DataFrames with duplicate column names.
    
    ### Does this PR introduce any user-facing change?
    
    Yes. Previously, calling the `toPandas` API on a DataFrame with duplicate column names would fail. After this patch, it produces the correct result.
    
    ### How was this patch tested?
    
    Unit test (`test_to_pandas_on_cross_join`, added to `python/pyspark/sql/tests.py`).
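    
    For reference, the scenario covered by the added test (see the `tests.py` hunk below) can also be exercised interactively; a hedged sketch, assuming a local `SparkSession` bound to `spark`:
    
    ```
    # Mirrors the new test_to_pandas_on_cross_join test case.
    spark.conf.set("spark.sql.crossJoin.enabled", "true")
    pdf = spark.sql("""
        select t1.*, t2.* from (
          select explode(sequence(1, 3)) v
        ) t1 left join (
          select explode(sequence(1, 3)) v
        ) t2
    """).toPandas()
    print(pdf.dtypes)  # with the fix, both `v` columns come back as int32
    ```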
    
    Closes #28219 from viirya/SPARK-31186-2.4.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 python/pyspark/sql/dataframe.py | 40 ++++++++++++++++++++++++++++++++++------
 python/pyspark/sql/tests.py     | 18 ++++++++++++++++++
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a2651d2..b58d976 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -27,6 +27,7 @@ else:
     from itertools import imap as map
     from cgi import escape as html_escape
 
+from collections import Counter
 import warnings
 
 from pyspark import copy_func, since, _NoValue
@@ -2148,9 +2149,16 @@ class DataFrame(object):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = _to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
@@ -2158,11 +2166,31 @@ class DataFrame(object):
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
+            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
+            # return a view or a copy depending on context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0284267..d359e00 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3296,6 +3296,24 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+          select explode(sequence(1, 3)) v
+        ) t1 left join (
+          select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(_have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org