You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2021/10/05 09:13:46 UTC
[spark] branch master updated: [SPARK-36926][SQL] Decimal average
mistakenly overflow
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new fb919af [SPARK-36926][SQL] Decimal average mistakenly overflow
fb919af is described below
commit fb919afac7e785fbb6d2b2507495437b7536e5f2
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Oct 5 17:13:06 2021 +0800
[SPARK-36926][SQL] Decimal average mistakenly overflow
### What changes were proposed in this pull request?
This bug was introduced by https://github.com/apache/spark/pull/33177
When checking overflow of the sum value in the average function, we should use the `sumDataType` instead of the input decimal type.
### Why are the changes needed?
Fix a regression.
### Does this PR introduce _any_ user-facing change?
Yes, the result was wrong before this PR.
### How was this patch tested?
A new test was added.
Closes #34180 from cloud-fan/bug.
Lead-authored-by: Wenchen Fan <we...@databricks.com>
Co-authored-by: Wenchen Fan <cl...@gmail.com>
Signed-off-by: Gengliang Wang <ge...@apache.org>
---
.../apache/spark/sql/catalyst/expressions/aggregate/Average.scala | 4 ++--
.../test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 6 ++++++
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 7ede3fc..9714a09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -94,10 +94,10 @@ case class Average(
// If all input are nulls, count will be 0 and we will get null after the division.
// We can't directly use `/` as it throws an exception under ansi mode.
override lazy val evaluateExpression = child.dataType match {
- case d: DecimalType =>
+ case _: DecimalType =>
DecimalPrecision.decimalAndDecimal()(
Divide(
- CheckOverflowInSum(sum, d, !failOnError),
+ CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !failOnError),
count.cast(DecimalType.LongDecimal), failOnError = false)).cast(resultType)
case _: YearMonthIntervalType =>
If(EqualTo(count, Literal(0L)),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0a122e..1f8638c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1427,6 +1427,12 @@ class DataFrameAggregateSuite extends QueryTest
assert (df.schema == expectedSchema)
checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
}
+
+ test("SPARK-36926: decimal average mistakenly overflow") {
+ val df = (1 to 10).map(_ => "9999999999.99").toDF("d")
+ val res = df.select($"d".cast("decimal(12, 2)").as("d")).agg(avg($"d").cast("string"))
+ checkAnswer(res, Row("9999999999.990000"))
+ }
}
case class B(c: Option[Double])
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org