You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/01/31 00:53:33 UTC

[spark] branch branch-2.4 updated: [SPARK-29890][SQL][2.4] DataFrameNaFunctions.fill should handle duplicate columns

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new 4c3c1d6  [SPARK-29890][SQL][2.4] DataFrameNaFunctions.fill should handle duplicate columns
4c3c1d6 is described below

commit 4c3c1d6b2a3db18682abaa0250f62ec198f36fdb
Author: Terry Kim <yu...@gmail.com>
AuthorDate: Thu Jan 30 16:52:25 2020 -0800

    [SPARK-29890][SQL][2.4] DataFrameNaFunctions.fill should handle duplicate columns
    
    (Backport of #26593)
    
    ### What changes were proposed in this pull request?
    
    `DataFrameNaFunctions.fill` doesn't handle duplicate columns even when column names are not specified.
    
    ```Scala
    val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
    val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
    val df = left.join(right, Seq("col1"))
    df.printSchema
    df.na.fill("hello").show
    ```
    produces
    ```
    root
     |-- col1: string (nullable = true)
     |-- col2: string (nullable = true)
     |-- col2: string (nullable = true)
    
    org.apache.spark.sql.AnalysisException: Reference 'col2' is ambiguous, could be: col2, col2.;
      at org.apache.spark.sql.catalyst.expressions.package$AttributeSeq.resolve(package.scala:259)
      at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.resolveQuoted(LogicalPlan.scala:121)
      at org.apache.spark.sql.Dataset.resolve(Dataset.scala:221)
      at org.apache.spark.sql.Dataset.col(Dataset.scala:1268)
    ```
    The reason for the above failure is that columns are looked up with `Dataset.col()`, which tries to resolve a column by name; if there are multiple columns with the same name, the lookup fails due to ambiguity.
    
    This PR updates `DataFrameNaFunctions.fill` so that, if the columns to fill are not specified, it resolves the ambiguity gracefully by applying `fill` to all the eligible columns. (Note that if the user specifies the columns, `fill` will still fail due to ambiguity.)
    
    ### Why are the changes needed?
    
    If column names are not specified, `fill` should not fail due to ambiguity since it should still be able to apply `fill` to the eligible columns.
    
    ### Does this PR introduce any user-facing change?
    
    Yes, now the above example displays the following:
    ```
    +----+-----+-----+
    |col1| col2| col2|
    +----+-----+-----+
    |   1|hello|    2|
    |   3|    4|hello|
    +----+-----+-----+
    
    ```
    
    ### How was this patch tested?
    
    Added new unit tests.
    
    Closes #27407 from imback82/backport-SPARK-29890.
    
    Authored-by: Terry Kim <yu...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../apache/spark/sql/DataFrameNaFunctions.scala    | 62 ++++++++++++++--------
 .../spark/sql/DataFrameNaFunctionsSuite.scala      | 45 ++++++++++++++++
 2 files changed, 85 insertions(+), 22 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index 78df89d..e705635 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -131,20 +131,20 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 2.2.0
    */
-  def fill(value: Long): DataFrame = fill(value, df.columns)
+  def fill(value: Long): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`.
    * @since 1.3.1
    */
-  def fill(value: Double): DataFrame = fill(value, df.columns)
+  def fill(value: Double): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null values in string columns with `value`.
    *
    * @since 1.3.1
    */
-  def fill(value: String): DataFrame = fill(value, df.columns)
+  def fill(value: String): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns.
@@ -168,7 +168,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 2.2.0
    */
-  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Long, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified
@@ -176,7 +176,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Double, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
 
   /**
@@ -193,14 +193,14 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 1.3.1
    */
-  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: String, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * Returns a new `DataFrame` that replaces null values in boolean columns with `value`.
    *
    * @since 2.3.0
    */
-  def fill(value: Boolean): DataFrame = fill(value, df.columns)
+  def fill(value: Boolean): DataFrame = fillValue(value, outputAttributes)
 
   /**
    * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified
@@ -208,7 +208,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
    *
    * @since 2.3.0
    */
-  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, cols)
+  def fill(value: Boolean, cols: Seq[String]): DataFrame = fillValue(value, toAttributes(cols))
 
   /**
    * Returns a new `DataFrame` that replaces null values in specified boolean columns.
@@ -434,15 +434,24 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
 
   /**
    * Returns a [[Column]] expression that replaces null value in `col` with `replacement`.
+   * It selects a column based on its name.
    */
   private def fillCol[T](col: StructField, replacement: T): Column = {
     val quotedColName = "`" + col.name + "`"
-    val colValue = col.dataType match {
+    fillCol(col.dataType, col.name, df.col(quotedColName), replacement)
+  }
+
+  /**
+   * Returns a [[Column]] expression that replaces null value in `expr` with `replacement`.
+   * It uses the given `expr` as a column.
+   */
+  private def fillCol[T](dataType: DataType, name: String, expr: Column, replacement: T): Column = {
+    val colValue = dataType match {
       case DoubleType | FloatType =>
-        nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types
-      case _ => df.col(quotedColName)
+        nanvl(expr, lit(null)) // nanvl only supports these types
+      case _ => expr
     }
-    coalesce(colValue, lit(replacement).cast(col.dataType)).as(col.name)
+    coalesce(colValue, lit(replacement).cast(dataType)).as(name)
   }
 
   /**
@@ -469,12 +478,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
       s"Unsupported value type ${v.getClass.getName} ($v).")
   }
 
+  private def toAttributes(cols: Seq[String]): Seq[Attribute] = {
+    cols.map(name => df.col(name).expr).collect {
+      case a: Attribute => a
+    }
+  }
+
+  private def outputAttributes: Seq[Attribute] = {
+    df.queryExecution.analyzed.output
+  }
+
   /**
-   * Returns a new `DataFrame` that replaces null or NaN values in specified
-   * numeric, string columns. If a specified column is not a numeric, string
-   * or boolean column it is ignored.
+   * Returns a new `DataFrame` that replaces null or NaN values in the specified
+   * columns. If a specified column is not a numeric, string or boolean column,
+   * it is ignored.
    */
-  private def fillValue[T](value: T, cols: Seq[String]): DataFrame = {
+  private def fillValue[T](value: T, cols: Seq[Attribute]): DataFrame = {
     // the fill[T] which T is  Long/Double,
     // should apply on all the NumericType Column, for example:
     // val input = Seq[(java.lang.Integer, java.lang.Double)]((null, 164.3)).toDF("a","b")
@@ -488,9 +507,8 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
         s"Unsupported value type ${value.getClass.getName} ($value).")
     }
 
-    val columnEquals = df.sparkSession.sessionState.analyzer.resolver
-    val projections = df.schema.fields.map { f =>
-      val typeMatches = (targetType, f.dataType) match {
+    val projections = outputAttributes.map { col =>
+      val typeMatches = (targetType, col.dataType) match {
         case (NumericType, dt) => dt.isInstanceOf[NumericType]
         case (StringType, dt) => dt == StringType
         case (BooleanType, dt) => dt == BooleanType
@@ -498,10 +516,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) {
           throw new IllegalArgumentException(s"$targetType is not matched at fillValue")
       }
       // Only fill if the column is part of the cols list.
-      if (typeMatches && cols.exists(col => columnEquals(f.name, col))) {
-        fillCol[T](f, value)
+      if (typeMatches && cols.exists(_.semanticEquals(col))) {
+        fillCol(col.dataType, col.name, Column(col), value)
       } else {
-        df.col(f.name)
+        Column(col)
       }
     }
     df.select(projections : _*)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index 7cf0d25..c1abd1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{StringType, StructType}
 
 class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -239,6 +240,33 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("fill with col(*)") {
+    val df = createDF()
+    // If columns are specified with "*", they are ignored.
+    checkAnswer(df.na.fill("new name", Seq("*")), df.collect())
+  }
+
+  test("fill with nested columns") {
+    val schema = new StructType()
+      .add("c1", new StructType()
+        .add("c1-1", StringType)
+        .add("c1-2", StringType))
+
+    val data = Seq(
+      Row(Row(null, "a2")),
+      Row(Row("b1", "b2")),
+      Row(null))
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(data), schema)
+
+    checkAnswer(df.select("c1.c1-1"),
+      Row(null) :: Row("b1") :: Row(null) :: Nil)
+
+    // Nested columns are ignored for fill().
+    checkAnswer(df.na.fill("a1", Seq("c1.c1-1")), data)
+  }
+
   test("replace") {
     val input = createDF()
 
@@ -349,4 +377,21 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) ::
       Row(0, 0L, 0.toShort, 0.toByte, Float.NaN, Double.NaN) :: Nil)
   }
+
+  test("SPARK-29890: duplicate names are allowed for fill() if column names are not specified.") {
+    val left = Seq(("1", null), ("3", "4")).toDF("col1", "col2")
+    val right = Seq(("1", "2"), ("3", null)).toDF("col1", "col2")
+    val df = left.join(right, Seq("col1"))
+
+    // If column names are specified, the following fails due to ambiguity.
+    val exception = intercept[AnalysisException] {
+      df.na.fill("hello", Seq("col2"))
+    }
+    assert(exception.getMessage.contains("Reference 'col2' is ambiguous"))
+
+    // If column names are not specified, fill() is applied to all the eligible columns.
+    checkAnswer(
+      df.na.fill("hello"),
+      Row("1", "hello", "2") :: Row("3", "4", "hello") :: Nil)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org