You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/10/17 19:50:44 UTC
spark git commit: [SPARK-22271][SQL] mean overflows and returns null for some decimal variables
Repository: spark
Updated Branches:
refs/heads/master 75d666b95 -> 28f9f3f22
[SPARK-22271][SQL] mean overflows and returns null for some decimal variables
## What changes were proposed in this pull request?
In Average.scala, it has
```
override lazy val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
def setChild (newchild: Expression) = {
child = newchild
}
```
It is possible that `Cast(count, dt)` will make the precision of the decimal number bigger than 38, and this causes overflow. Since count is an integer and doesn't need a scale, I will cast it using `DecimalType.bounded(38, 0)`.
## How was this patch tested?
In DataFrameSuite, I will add a test case.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Huaxin Gao <hu...@us.ibm.com>
Closes #19496 from huaxingao/spark-22271.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28f9f3f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28f9f3f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28f9f3f2
Branch: refs/heads/master
Commit: 28f9f3f22511e9f2f900764d9bd5b90d2eeee773
Parents: 75d666b
Author: Huaxin Gao <hu...@us.ibm.com>
Authored: Tue Oct 17 12:50:41 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Tue Oct 17 12:50:41 2017 -0700
----------------------------------------------------------------------
.../spark/sql/catalyst/expressions/aggregate/Average.scala | 3 ++-
.../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++
2 files changed, 11 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c423e17..708bdbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
- Cast(Cast(sum, dt) / Cast(count, dt), resultType)
+ Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
+ resultType)
case _ =>
Cast(sum, resultType) / Cast(count, resultType)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 50de2fd..473c355 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
}
+
+ test("SPARK-22271: mean overflows and returns null for some decimal variables") {
+ val d = 0.034567890
+ val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")
+ val result = df.select('DecimalCol cast DecimalType(38, 33))
+ .select(col("DecimalCol")).describe()
+ val mean = result.select("DecimalCol").where($"summary" === "mean")
+ assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org