You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/01/31 00:53:33 UTC
[spark] branch branch-2.4 updated: [SPARK-29890][SQL][2.4]
DataFrameNaFunctions.fill should handle duplicate columns
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new 4c3c1d6 [SPARK-29890][SQL][2.4] DataFrameNaFunctions.fill should handle duplicate columns
4c3c1d6 is described below
commit 4c3c1d6b2a3db18682abaa0250f62ec198f36fdb
Author: Terry Kim <yu...@gmail.com>
AuthorDate: Thu Jan 30 16:52:25 2020 -0800
[SPARK-29890][SQL][2.4] DataFrameNaFunctions.fill should handle duplicate columns
(Backport of #26593)
### What changes were proposed in this pull request?
`DataFrameNaFunctions.fill` doesn't handle duplicate columns even when column names are not specified.
```Scala
val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
val df = left.join(right, Seq("col1"))
df.printSchema
df.na.fill("hello").show
```
produces
```
root
 |-- col1: string (nullable = true)
 |-- col2: string (nullable = true)
 |-- col2: string (nullable = true)
org.apache.spark.sql.AnalysisException: Reference 'col2' is ambiguous, could be: col2, col2.;
at org.apache.spark.sql.catalyst.expressions.package$AttributeSeq.resolve(package.scala:259)
at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveQuoted(LogicalPlan.scala:121)
at org.apache.spark.sql.Dataset.resolve(Dataset.scala:221)
at org.apache.spark.sql.Dataset.col(Dataset.scala:1268)
```
The reason for the above failure is that columns are looked up with `Dataset.col()`, which resolves a column by name; if there are multiple columns with the same name, the lookup fails due to ambiguity.
This PR updates `DataFrameNaFunctions.fill` so that, if the columns to fill are not specified, it resolves the ambiguity gracefully by applying `fill` to all the eligible columns. (Note that if the user specifies the columns, it will still fail due to ambiguity.)
### Why are the changes needed?
If column names are not specified, `fill` should not fail due to ambiguity, since it can still apply `fill` to every eligible column.
### Does this PR introduce any user-facing change?
Yes, now the above example displays the following:
```
+----+-----+-----+
|col1| col2| col2|
+----+-----+-----+
|   1|hello|    2|
|   3|    4|hello|
+----+-----+-----+
```
### How was this patch tested?
Added new unit tests.
Closes #27407 from imback82/backport-SPARK-29890.
Authored-by: Terry Kim <yu...@gmail.com>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
.../apache/spark/sql/DataFrameNaFunctions.scala | 62 ++++++++++++++--------
.../spark/sql/DataFrameNaFunctionsSuite.scala | 45 ++++++++++++++++
2 files changed, 85 insertions(+), 22 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 78df89d..e705635 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -131,20 +131,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
- def fill(value: Long): DataFrame = fill(value, df.columns)
+ def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
/**
* Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
* @since 1.3.1
*/
- def fill(value: Double): DataFrame = fill(value, df.columns)
+ def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
/**
* Returns a new `DataFrame` that replaces null values in string columns with `value`.
*
* @since 1.3.1
*/
- def fill(value: String): DataFrame = fill(value, df.columns)
+ def fill(value: String): DataFrame = fillValue(value, outputAttributes)
/**
* Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
@@ -168,7 +168,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.2.0
*/
- def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
+ def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -176,7 +176,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
+ def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
/**
@@ -193,14 +193,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 1.3.1
*/
- def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
+ def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
/**
* Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
*
* @since 2.3.0
*/
- def fill(value: Boolean): DataFrame = fill(value, df.columns)
+ def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
/**
* (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
@@ -208,7 +208,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
*
* @since 2.3.0
*/
- def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
+ def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
/**
* Returns a new `DataFrame` that replaces null values in specified boolean columns.
@@ -434,15 +434,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
/**
* Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+ * It selects a column based on its name.
*/
private def fillCol[T](col: StructField, replacement: T): Column = {
val quotedColName = "`" + col.name + "`"
- val colValue = col.dataType match {
+ fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+ }
+
+ /**
+ * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`.
+ * It uses the given `expr` as a column.
+ */
+ private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
+ val colValue = dataType match {
case DoubleType | FloatType =>
- nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
- case _ => df.col(quotedColName)
+ nanvl(expr, lit(null)) // nanvl only supports these types
+ case _ => expr
}
- coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
+ coalesce(colValue, lit(replacement).cast(dataType)).as(name)
}
/**
@@ -469,12 +478,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${v.getClass.getName} ($v).")
}
+ private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
+ cols.map(name => df.col(name).expr).collect {
+ case a: Attribute => a
+ }
+ }
+
+ private def outputAttributes: Seq[Attribute] = {
+ df.queryExecution.analyzed.output
+ }
+
/**
- * Returns a new `DataFrame` that replaces null or NaN values in specified
- * numeric, string columns. If a specified column is not a numeric, string
- * or boolean column it is ignored.
+ * Returns a new `DataFrame` that replaces null or NaN values in the specified
+ * columns. If a specified column is not a numeric, string or boolean column,
+ * it is ignored.
*/
- private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
+ private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
// the fill[T] which T is Long/Double,
// should apply on all the NumericType Column, for example:
// val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
@@ -488,9 +507,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
s"Unsupported value type ${value.getClass.getName} ($value).")
}
- val columnEquals = df.sparkSession.sessionState.analyzer.resolver
- val projections = df.schema.fields.map { f =>
- val typeMatches = (targetType, f.dataType) match {
+ val projections = outputAttributes.map { col =>
+ val typeMatches = (targetType, col.dataType) match {
case (NumericType, dt) => dt.isInstanceOf[NumericType]
case (StringType, dt) => dt == StringType
case (BooleanType, dt) => dt == BooleanType
@@ -498,10 +516,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
}
// Only fill if the column is part of the cols list.
- if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
- fillCol[T](f, value)
+ if (typeMatches && cols.exists(_.semanticEquals(col))) {
+ fillCol(col.dataType, col.name, Column(col), value)
} else {
- df.col(f.name)
+ Column(col)
}
}
df.select(projections : _*)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 7cf0d25..c1abd1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructType}
class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -239,6 +240,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
}
}
+ test("fill with col(*)") {
+ val df = createDF()
+ // If columns are specified with "*", they are ignored.
+ checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
+ }
+
+ test("fill with nested columns") {
+ val schema = new StructType()
+ .add("c1", new StructType()
+ .add("c1-1", StringType)
+ .add("c1-2", StringType))
+
+ val data = Seq(
+ Row(Row(null, "a2")),
+ Row(Row("b1", "b2")),
+ Row(null))
+
+ val df = spark.createDataFrame(
+ spark.sparkContext.parallelize(data), schema)
+
+ checkAnswer(df.select("c1.c1-1"),
+ Row(null) :: Row("b1") :: Row(null) :: Nil)
+
+ // Nested columns are ignored for fill().
+ checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
+ }
+
test("replace") {
val input = createDF()
@@ -349,4 +377,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
}
+
+ test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
+ val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
+ val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
+ val df = left.join(right, Seq("col1"))
+
+ // If column names are specified, the following fails due to ambiguity.
+ val exception = intercept[AnalysisException] {
+ df.na.fill("hello", Seq("col2"))
+ }
+ assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
+
+ // If column names are not specified, fill() is applied to all the eligible columns.
+ checkAnswer(
+ df.na.fill("hello"),
+ Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org