You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ru...@apache.org on 2023/02/24 00:03:53 UTC

[spark] branch branch-3.4 updated: [SPARK-42444][PYTHON] `DataFrame.drop` should handle duplicated columns properly

This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 8dd00d945e4 [SPARK-42444][PYTHON] `DataFrame.drop` should handle duplicated columns properly
8dd00d945e4 is described below

commit 8dd00d945e408764b89f59eda14109fa508f3072
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Fri Feb 24 08:03:06 2023 +0800

    [SPARK-42444][PYTHON] `DataFrame.drop` should handle duplicated columns properly
    
    ### What changes were proposed in this pull request?
    The existing implementation always converts its inputs (which may be columns or column names) to columns; this causes an `AMBIGUOUS_REFERENCE` error when there are several columns with the same name.
    
    On the JVM side, the logic of `drop(column: Column)` and `drop(columnName: String)` differs, so we cannot simply convert every column name to a column via the `col()` method.
    
    When there are multiple columns with the same name (e.g. `name`), users can:
    1. `drop('name')` --- drop all the columns with that name;
    2. `drop(df1.name)` --- drop only the column that comes from the specific DataFrame `df1`.
    
    But if users call `drop(col('name'))`, it fails because the reference is ambiguous.
    
    In PySpark, it is a bit more complex, since users can pass both column names and columns in a single call. This PR drops the columns first, and then the column names.
    
    ### Why are the changes needed?
    bug fix
    
    ```
    >>> from pyspark.sql import Row
    >>> df1 = spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
    >>> df2 = spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")])
    >>> df3 = df1.join(df2, df1.name == df2.name, 'inner')
    >>> df3.show()
    +---+----+------+----+
    |age|name|height|name|
    +---+----+------+----+
    | 16| Bob|    85| Bob|
    | 14| Tom|    80| Tom|
    +---+----+------+----+
    ```
    
    BEFORE
    ```
    >>> df3.drop("name", "age").columns
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.captured.AnalysisException: [AMBIGUOUS_REFERENCE] Reference `name` is ambiguous, could be: [`name`, `name`].
    ```
    
    AFTER
    ```
    >>> df3.drop("name", "age").columns
    ['height']
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    added tests
    
    Closes #40135 from zhengruifeng/py_fix_drop.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
    (cherry picked from commit 0b9ed26e48248aa58642b3626a02dd8c89a01afb)
    Signed-off-by: Ruifeng Zheng <ru...@apache.org>
---
 python/pyspark/sql/dataframe.py                    | 27 +++++++++++++---------
 .../sql/tests/connect/test_parity_dataframe.py     |  5 ++++
 python/pyspark/sql/tests/test_dataframe.py         |  9 ++++++++
 3 files changed, 30 insertions(+), 11 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index fa25d148060..1cd28f0e8b2 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -4923,21 +4923,26 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         | 14|    80|
         +---+------+
         """
-        if len(cols) == 1:
-            col = cols[0]
-            if isinstance(col, str):
-                jdf = self._jdf.drop(col)
-            elif isinstance(col, Column):
-                jdf = self._jdf.drop(col._jc)
+        column_names: List[str] = []
+        java_columns: List[JavaObject] = []
+
+        for c in cols:
+            if isinstance(c, str):
+                column_names.append(c)
+            elif isinstance(c, Column):
+                java_columns.append(c._jc)
             else:
                 raise PySparkTypeError(
                     error_class="NOT_COLUMN_OR_STR",
-                    message_parameters={"arg_name": "col", "arg_type": type(col).__name__},
+                    message_parameters={"arg_name": "col", "arg_type": type(c).__name__},
                 )
-        else:
-            jcols = [_to_java_column(c) for c in cols]
-            first_column, *remaining_columns = jcols
-            jdf = self._jdf.drop(first_column, self._jseq(remaining_columns))
+
+        jdf = self._jdf
+        if len(java_columns) > 0:
+            first_column, *remaining_columns = java_columns
+            jdf = jdf.drop(first_column, self._jseq(remaining_columns))
+        if len(column_names) > 0:
+            jdf = jdf.drop(self._jseq(column_names))
 
         return DataFrame(jdf, self.sparkSession)
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 07cae0fb27d..25fdbebd991 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -142,6 +142,11 @@ class DataFrameParityTests(DataFrameTestsMixin, ReusedConnectTestCase):
     def test_to_pandas_with_duplicated_column_names(self):
         super().test_to_pandas_with_duplicated_column_names()
 
+    # TODO(SPARK-42367): DataFrame.drop should handle duplicated columns properly
+    @unittest.skip("Fails in Spark Connect, should enable.")
+    def test_drop_duplicates_with_ambiguous_reference(self):
+        super().test_drop_duplicates_with_ambiguous_reference()
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 1d52602a96f..610edc0926d 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -144,6 +144,15 @@ class DataFrameTestsMixin:
             message_parameters={"arg_name": "subset", "arg_type": "str"},
         )
 
+    def test_drop_duplicates_with_ambiguous_reference(self):
+        df1 = self.spark.createDataFrame([(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"])
+        df2 = self.spark.createDataFrame([Row(height=80, name="Tom"), Row(height=85, name="Bob")])
+        df3 = df1.join(df2, df1.name == df2.name, "inner")
+
+        self.assertEqual(df3.drop("name", "age").columns, ["height"])
+        self.assertEqual(df3.drop("name", df3.age, "unknown").columns, ["height"])
+        self.assertEqual(df3.drop("name", "age", df3.height).columns, [])
+
     def test_dropna(self):
         schema = StructType(
             [


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org