You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/08/11 03:14:35 UTC

[spark] branch master updated: [SPARK-39895][SQL][PYTHON] Support multiple column drop

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 69f402ad808 [SPARK-39895][SQL][PYTHON] Support multiple column drop
69f402ad808 is described below

commit 69f402ad8085bc50837128645b61168e7d5244b9
Author: Santosh Pingale <pi...@gmail.com>
AuthorDate: Thu Aug 11 12:14:23 2022 +0900

    [SPARK-39895][SQL][PYTHON] Support multiple column drop
    
    ### What changes were proposed in this pull request?
    PySpark's `DataFrame.drop` has the following signature:
    `def drop(self, *cols: "ColumnOrName") -> "DataFrame":`
    
    However, when we try to pass multiple `Column` arguments to `drop`, it raises a `TypeError`:
    
    `each col in the param list should be a string`
    
    *Minimal reproducible example:*
    ```python
    values = [("id_1", 5, 9), ("id_2", 5, 1), ("id_3", 4, 3), ("id_1", 3, 3), ("id_2", 4, 3)]
    df = spark.createDataFrame(values, "id string, point int, count int")
    df.drop(df.point, df.count)
    ```
    
    It produces the following error:
    ```
    /spark/python/lib/pyspark.zip/pyspark/sql/dataframe.py in drop(self, *cols)
    2537 for col in cols:
    2538 if not isinstance(col, str):
    -> 2539 raise TypeError("each col in the param list should be a string")
    2540 jdf = self._jdf.drop(self._jseq(cols))
    2541
    
    TypeError: each col in the param list should be a string
    ```
    
    ### Why are the changes needed?
    The type hints on `DataFrame.drop` (`*cols: "ColumnOrName"`) indicate that multiple `Column` arguments are accepted, but in practice only multiple string names were supported.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, it fixes issues related to type conformance in the PySpark API: `DataFrame.drop` now accepts multiple `Column` arguments, as its signature declares.
    
    ### How was this patch tested?
    Added missing tests for regression testing. CI Pipeline on fork and CI here will test them.
    
    Closes #37335 from santosh-d3vpl3x/master.
    
    Lead-authored-by: Santosh Pingale <pi...@gmail.com>
    Co-authored-by: Hyukjin Kwon <gu...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 python/pyspark/sql/dataframe.py                    |  7 +++---
 python/pyspark/sql/tests/test_dataframe.py         |  9 ++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 26 ++++++++++++++++++----
 .../org/apache/spark/sql/DataFrameSuite.scala      | 10 +++++++++
 4 files changed, 44 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 41ac701a332..8ab3ed35578 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -3380,10 +3380,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
             else:
                 raise TypeError("col should be a string or a Column")
         else:
-            for col in cols:
-                if not isinstance(col, str):
-                    raise TypeError("each col in the param list should be a string")
-            jdf = self._jdf.drop(self._jseq(cols))
+            jcols = [_to_java_column(c) for c in cols]
+            first_column, *remaining_columns = jcols
+            jdf = self._jdf.drop(first_column, self._jseq(remaining_columns))
 
         return DataFrame(jdf, self.sparkSession)
 
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 987ff91402d..58b0a28e9f2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -87,6 +87,15 @@ class DataFrameTests(ReusedSQLTestCase):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+    def test_drop(self):
+        df = self.spark.createDataFrame([("A", 50, "Y"), ("B", 60, "Y")], ["name", "age", "active"])
+        self.assertEqual(df.drop("active").columns, ["name", "age"])
+        self.assertEqual(df.drop("active", "nonexistent_column").columns, ["name", "age"])
+        self.assertEqual(df.drop("name", "age", "active").columns, [])
+        self.assertEqual(df.drop(col("name")).columns, ["age", "active"])
+        self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
+        self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
+
     def test_drop_duplicates(self):
         # SPARK-36034 test that drop duplicates throws a type error when in correct type provided
         df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3f0cef33b5f..18aea40f556 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2857,7 +2857,9 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * Returns a new Dataset with a column dropped.
+   * Returns a new Dataset with column dropped.
+   *
+   * This method can only be used to drop top level column.
    * This version of drop accepts a [[Column]] rather than a name.
    * This is a no-op if the Dataset doesn't have a column
    * with an equivalent expression.
@@ -2866,15 +2868,31 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   def drop(col: Column): DataFrame = {
-    val expression = col match {
+    drop(col, Seq.empty : _*)
+  }
+
+  /**
+   * Returns a new Dataset with columns dropped.
+   *
+   * This method can only be used to drop top level columns.
+   * This is a no-op if the Dataset doesn't have a columns
+   * with an equivalent expression.
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def drop(col: Column, cols: Column*): DataFrame = {
+    val allColumns = col +: cols
+    val expressions = (for (col <- allColumns) yield col match {
       case Column(u: UnresolvedAttribute) =>
         queryExecution.analyzed.resolveQuoted(
           u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
       case Column(expr: Expression) => expr
-    }
+    })
     val attrs = this.logicalPlan.output
     val colsAfterDrop = attrs.filter { attr =>
-      !attr.semanticEquals(expression)
+      expressions.forall(expression => !attr.semanticEquals(expression))
     }.map(attr => Column(attr))
     select(colsAfterDrop : _*)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bdcaa9f3b0e..74b01b691b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -826,6 +826,16 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "value"))
   }
 
+  test("SPARK-39895: drop two column references") {
+    val col = Column("key")
+    val randomCol = Column("random")
+    val df = testData.drop(col, randomCol)
+    checkAnswer(
+      df,
+      testData.collect().map(x => Row(x.getString(1))).toSeq)
+    assert(df.schema.map(_.name) === Seq("value"))
+  }
+
   test("drop unknown column with same name with column reference") {
     val col = Column("key")
     val df = testData.drop(col)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org