Posted to commits@spark.apache.org by we...@apache.org on 2020/08/13 03:56:35 UTC

[spark] branch branch-3.0 updated: [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow of sum aggregation

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 89765f5  [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow of sum aggregation
89765f5 is described below

commit 89765f556f26252aed1add71a9da84209ff03493
Author: Gengliang Wang <ge...@databricks.com>
AuthorDate: Thu Aug 13 03:52:12 2020 +0000

    [SPARK-32018][SQL][FOLLOWUP][3.0] Throw exception on decimal value overflow of sum aggregation
    
    ### What changes were proposed in this pull request?
    
    This is a followup of https://github.com/apache/spark/pull/29125
    In branch 3.0:
    1. for hash aggregation, before https://github.com/apache/spark/pull/29125 a runtime exception was thrown on decimal overflow of sum aggregation; after https://github.com/apache/spark/pull/29125, it could return a wrong result (see the repro sketch after this list).
    2. for sort aggregation, with or without https://github.com/apache/spark/pull/29125, it could return a wrong result on decimal overflow.
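
    A minimal repro sketch, assuming an active `spark` session on branch 3.0 before this patch (column name and values are illustrative):
    ```scala
    import org.apache.spark.sql.functions._

    // 10^19 fits into Decimal(38, 18) (20 integral digits), but summing
    // twelve copies yields 1.2 * 10^20, which does not.
    val df = spark.range(0, 12, 1, 1)
      .select(expr("cast('10000000000000000000' as decimal(38, 18)) as d"))

    // Before this patch, hash aggregation could silently produce a wrong
    // result here instead of signaling the overflow.
    df.agg(sum(col("d"))).show()
    ```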
    
    In the master branch (the future 3.1 release), the problem doesn't exist, since https://github.com/apache/spark/pull/27627 added a flag to the aggregation buffer that marks whether overflow has happened. However, the aggregation buffer is written into streaming checkpoints, so we can't change the aggregation buffer in branch 3.0 to resolve the issue. A rough sketch of that approach follows.
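
    A rough, hypothetical sketch (not Spark's actual code) of the flag-based approach in https://github.com/apache/spark/pull/27627; all names and the simplified precision check are illustrative:
    ```scala
    object SumWithFlagSketch {
      // The buffer carries a flag alongside the sum, so "no input yet" and
      // "sum overflowed to null" can be told apart at evaluation time.
      final case class SumBuffer(sum: Option[BigDecimal], isEmpty: Boolean)

      // Simplified stand-in for Spark's CheckOverflow: more than 38 digits of
      // precision turns the sum into None (i.e. null).
      private def checkOverflow(v: BigDecimal): Option[BigDecimal] =
        if (v.precision <= 38) Some(v) else None

      def update(buf: SumBuffer, input: BigDecimal): SumBuffer =
        SumBuffer(checkOverflow(buf.sum.getOrElse(BigDecimal(0)) + input), isEmpty = false)
    }
    ```
    The extra flag changes the aggregation buffer schema, which is exactly what streaming checkpoints persist; hence this followup takes a different route for branch 3.0.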
    
    As there is no easy way in branch 3.0 to return null or throw an exception on overflow depending on `spark.sql.ansi.enabled`, we have to make a choice here: always throw an exception on decimal value overflow of sum aggregation.

    ### Why are the changes needed?
    
    Avoid returning wrong result in decimal value sum aggregation.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, an exception is now always thrown on decimal value overflow of sum aggregation, instead of a possible wrong result (see the example below).
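
    With this patch, the same overflowing query from the repro sketch above fails fast:
    ```scala
    import org.apache.spark.sql.functions._

    val df = spark.range(0, 12, 1, 1)
      .select(expr("cast('10000000000000000000' as decimal(38, 18)) as d"))

    // Throws a SparkException whose cause message contains
    // "cannot be represented as Decimal(38, 18)".
    df.agg(sum(col("d"))).collect()
    ```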
    
    ### How was this patch tested?
    
    Unit test cases in `DataFrameAggregateSuite`.
    
    Closes #29404 from gengliangwang/fixSum.
    
    Authored-by: Gengliang Wang <ge...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/aggregate/Sum.scala   | 19 +++++++++--
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 37 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index d2daaac..d442549 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -71,23 +71,36 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   )
 
   override lazy val updateExpressions: Seq[Expression] = {
+    val sumWithChild = resultType match {
+      case d: DecimalType =>
+        CheckOverflow(coalesce(sum, zero) + child.cast(sumDataType), d, nullOnOverflow = false)
+      case _ =>
+        coalesce(sum, zero) + child.cast(sumDataType)
+    }
+
     if (child.nullable) {
       Seq(
         /* sum = */
-        coalesce(coalesce(sum, zero) + child.cast(sumDataType), sum)
+        coalesce(sumWithChild, sum)
       )
     } else {
       Seq(
         /* sum = */
-        coalesce(sum, zero) + child.cast(sumDataType)
+        sumWithChild
       )
     }
   }
 
   override lazy val mergeExpressions: Seq[Expression] = {
+    val sumWithRight = resultType match {
+      case d: DecimalType =>
+        CheckOverflow(coalesce(sum.left, zero) + sum.right, d, nullOnOverflow = false)
+
+      case _ => coalesce(sum.left, zero) + sum.right
+    }
     Seq(
       /* sum = */
-      coalesce(coalesce(sum.left, zero) + sum.right, sum.left)
+      coalesce(sumWithRight, sum.left)
     )
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 54327b3..8c0358e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -1044,6 +1045,42 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
+
+  private def exceptionOnDecimalOverflow(df: DataFrame): Unit = {
+    val msg = intercept[SparkException] {
+      df.collect()
+    }.getCause.getMessage
+    assert(msg.contains("cannot be represented as Decimal(38, 18)"))
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at partial aggregate phase") {
+    val decimalString = "1" + "0" * 19
+    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 11, 1, 1))
+    val hashAgg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(hashAgg)
+
+    val sortAgg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("a").as("str"),
+      lit("1").as("key")).groupBy("key")
+      .agg(sum($"d").alias("sumD"), min($"str").alias("minStr")).select($"sumD", $"minStr")
+    exceptionOnDecimalOverflow(sortAgg)
+  }
+
+  test("SPARK-32018: Throw exception on decimal overflow at merge aggregation phase") {
+    val decimalString = "5" + "0" * 19
+    val union = spark.range(0, 1, 1, 1).union(spark.range(0, 1, 1, 1))
+      .union(spark.range(0, 1, 1, 1))
+    val agg = union
+      .select(expr(s"cast('$decimalString' as decimal (38, 18)) as d"), lit("1").as("key"))
+      .groupBy("key")
+      .agg(sum($"d").alias("sumD"))
+      .select($"sumD")
+    exceptionOnDecimalOverflow(agg)
+  }
 }
 
 case class B(c: Option[Double])

