You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2022/08/11 03:14:35 UTC
[spark] branch master updated: [SPARK-39895][SQL][PYTHON] Support multiple column drop
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 69f402ad808 [SPARK-39895][SQL][PYTHON] Support multiple column drop
69f402ad808 is described below
commit 69f402ad8085bc50837128645b61168e7d5244b9
Author: Santosh Pingale <pi...@gmail.com>
AuthorDate: Thu Aug 11 12:14:23 2022 +0900
[SPARK-39895][SQL][PYTHON] Support multiple column drop
### What changes were proposed in this pull request?
The PySpark `DataFrame.drop` method has the following signature:
`def drop(self, *cols: "ColumnOrName") -> "DataFrame":`
However, when we pass multiple `Column` objects to `drop`, it raises a `TypeError`:
`each col in the param list should be a string`
*Minimal reproducible example:*
```python
values = [("id_1", 5, 9), ("id_2", 5, 1), ("id_3", 4, 3), ("id_1", 3, 3), ("id_2", 4, 3)]
df = spark.createDataFrame(values, "id string, point int, count int")
df.drop(df.point, df.count)
```
It fails with the following traceback:
```
/spark/python/lib/pyspark.zip/pyspark/sql/dataframe.py in drop(self, *cols)
2537 for col in cols:
2538 if not isinstance(col, str):
-> 2539 raise TypeError("each col in the param list should be a string")
2540 jdf = self._jdf.drop(self._jseq(cols))
2541
TypeError: each col in the param list should be a string
```
### Why are the changes needed?
Based on its type signature, `DataFrame.drop` should accept multiple columns (as names or as `Column` objects) in a single call, but it currently does not.
### Does this PR introduce _any_ user-facing change?
Yes. It fixes a mismatch between the declared argument types and the runtime type checks in the PySpark `drop` API.
### How was this patch tested?
Added the missing tests for regression testing. The CI pipeline on the fork and the CI here will run them.
Closes #37335 from santosh-d3vpl3x/master.
Lead-authored-by: Santosh Pingale <pi...@gmail.com>
Co-authored-by: Hyukjin Kwon <gu...@gmail.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
python/pyspark/sql/dataframe.py | 7 +++---
python/pyspark/sql/tests/test_dataframe.py | 9 ++++++++
.../main/scala/org/apache/spark/sql/Dataset.scala | 26 ++++++++++++++++++----
.../org/apache/spark/sql/DataFrameSuite.scala | 10 +++++++++
4 files changed, 44 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 41ac701a332..8ab3ed35578 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -3380,10 +3380,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
else:
raise TypeError("col should be a string or a Column")
else:
- for col in cols:
- if not isinstance(col, str):
- raise TypeError("each col in the param list should be a string")
- jdf = self._jdf.drop(self._jseq(cols))
+ jcols = [_to_java_column(c) for c in cols]
+ first_column, *remaining_columns = jcols
+ jdf = self._jdf.drop(first_column, self._jseq(remaining_columns))
return DataFrame(jdf, self.sparkSession)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 987ff91402d..58b0a28e9f2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -87,6 +87,15 @@ class DataFrameTests(ReusedSQLTestCase):
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
+ def test_drop(self):
+ df = self.spark.createDataFrame([("A", 50, "Y"), ("B", 60, "Y")], ["name", "age", "active"])
+ self.assertEqual(df.drop("active").columns, ["name", "age"])
+ self.assertEqual(df.drop("active", "nonexistent_column").columns, ["name", "age"])
+ self.assertEqual(df.drop("name", "age", "active").columns, [])
+ self.assertEqual(df.drop(col("name")).columns, ["age", "active"])
+ self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
+ self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
+
def test_drop_duplicates(self):
# SPARK-36034 test that drop duplicates throws a type error when in correct type provided
df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)], ["name", "age"])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3f0cef33b5f..18aea40f556 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2857,7 +2857,9 @@ class Dataset[T] private[sql](
}
/**
- * Returns a new Dataset with a column dropped.
+ * Returns a new Dataset with column dropped.
+ *
+ * This method can only be used to drop top level column.
* This version of drop accepts a [[Column]] rather than a name.
* This is a no-op if the Dataset doesn't have a column
* with an equivalent expression.
@@ -2866,15 +2868,31 @@ class Dataset[T] private[sql](
* @since 2.0.0
*/
def drop(col: Column): DataFrame = {
- val expression = col match {
+ drop(col, Seq.empty : _*)
+ }
+
+ /**
+ * Returns a new Dataset with columns dropped.
+ *
+ * This method can only be used to drop top level columns.
+ * This is a no-op if the Dataset doesn't have a columns
+ * with an equivalent expression.
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ @scala.annotation.varargs
+ def drop(col: Column, cols: Column*): DataFrame = {
+ val allColumns = col +: cols
+ val expressions = (for (col <- allColumns) yield col match {
case Column(u: UnresolvedAttribute) =>
queryExecution.analyzed.resolveQuoted(
u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
case Column(expr: Expression) => expr
- }
+ })
val attrs = this.logicalPlan.output
val colsAfterDrop = attrs.filter { attr =>
- !attr.semanticEquals(expression)
+ expressions.forall(expression => !attr.semanticEquals(expression))
}.map(attr => Column(attr))
select(colsAfterDrop : _*)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index bdcaa9f3b0e..74b01b691b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -826,6 +826,16 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "value"))
}
+ test("SPARK-39895: drop two column references") {
+ val col = Column("key")
+ val randomCol = Column("random")
+ val df = testData.drop(col, randomCol)
+ checkAnswer(
+ df,
+ testData.collect().map(x => Row(x.getString(1))).toSeq)
+ assert(df.schema.map(_.name) === Seq("value"))
+ }
+
test("drop unknown column with same name with column reference") {
val col = Column("key")
val df = testData.drop(col)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org