You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2021/10/05 09:13:46 UTC

[spark] branch master updated: [SPARK-36926][SQL] Decimal average mistakenly overflow

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fb919af  [SPARK-36926][SQL] Decimal average mistakenly overflow
fb919af is described below

commit fb919afac7e785fbb6d2b2507495437b7536e5f2
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Oct 5 17:13:06 2021 +0800

    [SPARK-36926][SQL] Decimal average mistakenly overflow
    
    ### What changes were proposed in this pull request?
    
    This bug was introduced by https://github.com/apache/spark/pull/33177
    
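    When checking overflow of the sum value in the average function, we should use the `sumDataType` instead of the input decimal type. The intermediate sum can legitimately exceed the precision of the input type even when the final average fits: for example, ten `decimal(12, 2)` values of 9999999999.99 sum to 99999999999.90, which does not fit in `decimal(12, 2)`.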
    
    ### Why are the changes needed?
    
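    To fix a regression: the overflow check on the sum was performed against the input decimal type, so a decimal average could be mistakenly treated as overflowing.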
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the result was wrong before this PR.
    
    ### How was this patch tested?
    
    A new test, "SPARK-36926: decimal average mistakenly overflow", was added to `DataFrameAggregateSuite`.
    
    Closes #34180 from cloud-fan/bug.
    
    Lead-authored-by: Wenchen Fan <we...@databricks.com>
    Co-authored-by: Wenchen Fan <cl...@gmail.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../apache/spark/sql/catalyst/expressions/aggregate/Average.scala   | 4 ++--
 .../test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala   | 6 ++++++
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 7ede3fc..9714a09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -94,10 +94,10 @@ case class Average(
   // If all input are nulls, count will be 0 and we will get null after the division.
   // We can't directly use `/` as it throws an exception under ansi mode.
   override lazy val evaluateExpression = child.dataType match {
-    case d: DecimalType =>
+    case _: DecimalType =>
       DecimalPrecision.decimalAndDecimal()(
         Divide(
-          CheckOverflowInSum(sum, d, !failOnError),
+          CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError),
           count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0a122e..1f8638c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1427,6 +1427,12 @@ class DataFrameAggregateSuite extends QueryTest
     assert (df.schema == expectedSchema)
     checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
   }
+
+  test("SPARK-36926: decimal average mistakenly overflow") {
+    val df = (1 to 10).map(_ => "9999999999.99").toDF("d")
+    val res = df.select($"d".cast("decimal(12, 2)").as("d")).agg(avg($"d").cast("string"))
+    checkAnswer(res, Row("9999999999.990000"))
+  }
 }
 
 case class B(c: Option[Double])

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org