You are viewing a plain text version of this content. (The hyperlink to the canonical version was lost in this plain-text rendering.)
Posted to commits@spark.apache.org by do...@apache.org on 2020/06/22 12:02:08 UTC

[spark] branch branch-3.0 updated: [SPARK-32038][SQL] NormalizeFloatingNumbers should also work on distinct aggregate

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 92f3877  [SPARK-32038][SQL] NormalizeFloatingNumbers should also work on distinct aggregate
92f3877 is described below

commit 92f3877b0f4451100ba2aa6cac6b66900f870951
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Mon Jun 22 04:58:22 2020 -0700

    [SPARK-32038][SQL] NormalizeFloatingNumbers should also work on distinct aggregate
    
    ### What changes were proposed in this pull request?
    
    This patch applies `NormalizeFloatingNumbers` to distinct aggregates to fix a regression in distinct aggregation over NaN values.
    
    ### Why are the changes needed?
    
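    We added the `NormalizeFloatingNumbers` optimization rule in 3.0.0 to normalize special floating-point values (NaN and -0.0). However, the rule was not applied to distinct aggregates, which caused a regression: NaN values with different bit patterns were counted as different values (e.g. `count(distinct score)` over three differently-encoded NaNs returned 3 instead of 1). We need to apply this normalization to distinct aggregates as well to fix it.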
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It fixes a regression in distinct aggregation over NaN values.
    
    ### How was this patch tested?
    
    Added a unit test.
    
    Closes #28876 from viirya/SPARK-32038.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 2e4557f45ce65ad0cf501c1734f2d4a50b00af54)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../apache/spark/sql/execution/SparkStrategies.scala   | 18 ++++++++++++++++++
 .../spark/sql/execution/aggregate/AggUtils.scala       | 16 ++++------------
 .../org/apache/spark/sql/DataFrameAggregateSuite.scala | 16 ++++++++++++++++
 3 files changed, 38 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 12a1a1e..10cd7b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -539,10 +539,28 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               resultExpressions,
               planLater(child))
           } else {
+            // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain
+            // more than one DISTINCT aggregate function, all of those functions will have the
+            // same column expressions. For example, it would be valid for functionsWithDistinct
+            // to be [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but
+            // [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
+            // aggregates have different column expressions.
+            val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children
+            val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
+              // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
+              // because `distinctExpressions` is not extracted during logical phase.
+              NormalizeFloatingNumbers.normalize(e) match {
+                case ne: NamedExpression => ne
+                case other => Alias(other, other.toString)()
+              }
+            }
+
             AggUtils.planAggregateWithOneDistinct(
               normalizedGroupingExpressions,
               functionsWithDistinct,
               functionsWithoutDistinct,
+              distinctExpressions,
+              normalizedNamedDistinctExpressions,
               resultExpressions,
               planLater(child))
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 56a287d..761ac20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -135,20 +135,12 @@ object AggUtils {
       groupingExpressions: Seq[NamedExpression],
       functionsWithDistinct: Seq[AggregateExpression],
       functionsWithoutDistinct: Seq[AggregateExpression],
+      distinctExpressions: Seq[Expression],
+      normalizedNamedDistinctExpressions: Seq[NamedExpression],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
 
-    // functionsWithDistinct is guaranteed to be non-empty. Even though it may contain more than one
-    // DISTINCT aggregate function, all of those functions will have the same column expressions.
-    // For example, it would be valid for functionsWithDistinct to be
-    // [COUNT(DISTINCT foo), MAX(DISTINCT foo)], but [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is
-    // disallowed because those two distinct aggregates have different column expressions.
-    val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children
-    val namedDistinctExpressions = distinctExpressions.map {
-      case ne: NamedExpression => ne
-      case other => Alias(other, other.toString)()
-    }
-    val distinctAttributes = namedDistinctExpressions.map(_.toAttribute)
+    val distinctAttributes = normalizedNamedDistinctExpressions.map(_.toAttribute)
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
     // 1. Create an Aggregate Operator for partial aggregations.
@@ -159,7 +151,7 @@ object AggUtils {
       // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping
       // expressions will be [key, value].
       createAggregate(
-        groupingExpressions = groupingExpressions ++ namedDistinctExpressions,
+        groupingExpressions = groupingExpressions ++ normalizedNamedDistinctExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         resultExpressions = groupingAttributes ++ distinctAttributes ++
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2293d4a..f7438f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1012,4 +1012,20 @@ class DataFrameAggregateSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-32038: NormalizeFloatingNumbers should work on distinct aggregate") {
+    withTempView("view") {
+      val nan1 = java.lang.Float.intBitsToFloat(0x7f800001)
+      val nan2 = java.lang.Float.intBitsToFloat(0x7fffffff)
+
+      Seq(("mithunr", Float.NaN),
+        ("mithunr", nan1),
+        ("mithunr", nan2),
+        ("abellina", 1.0f),
+        ("abellina", 2.0f)).toDF("uid", "score").createOrReplaceTempView("view")
+
+      val df = spark.sql("select uid, count(distinct score) from view group by 1 order by 1 asc")
+      checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org