You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/10/17 19:50:44 UTC

spark git commit: [SPARK-22271][SQL] mean overflows and returns null for some decimal variables

Repository: spark
Updated Branches:
  refs/heads/master 75d666b95 -> 28f9f3f22


[SPARK-22271][SQL] mean overflows and returns null for some decimal variables

## What changes were proposed in this pull request?

Average.scala currently contains the following code:
```
  override lazy val evaluateExpression = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // increase the precision and scale to prevent precision loss
      val dt = DecimalType.bounded(p + 14, s + 4)
      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
    case _ =>
      Cast(sum, resultType) / Cast(count, resultType)
  }

  def setChild (newchild: Expression) = {
    child = newchild
  }

```
It is possible that `Cast(count, dt)` will make the precision of the decimal number bigger than 38, which causes an overflow. Since count is an integer and doesn't need a scale, this change casts it using `DecimalType.bounded(38, 0)` instead.
## How was this patch tested?
Added a test case in DataFrameSuite.

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Huaxin Gao <hu...@us.ibm.com>

Closes #19496 from huaxingao/spark-22271.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/28f9f3f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/28f9f3f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/28f9f3f2

Branch: refs/heads/master
Commit: 28f9f3f22511e9f2f900764d9bd5b90d2eeee773
Parents: 75d666b
Author: Huaxin Gao <hu...@us.ibm.com>
Authored: Tue Oct 17 12:50:41 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Tue Oct 17 12:50:41 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/aggregate/Average.scala  | 3 ++-
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala    | 9 +++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index c423e17..708bdbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
     case DecimalType.Fixed(p, s) =>
       // increase the precision and scale to prevent precision loss
       val dt = DecimalType.bounded(p + 14, s + 4)
-      Cast(Cast(sum, dt) / Cast(count, dt), resultType)
+      Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)),
+        resultType)
     case _ =>
       Cast(sum, resultType) / Cast(count, resultType)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/28f9f3f2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 50de2fd..473c355 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2105,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)),
       Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2)))
   }
+
+  test("SPARK-22271: mean overflows and returns null for some decimal variables") {
+    val d = 0.034567890
+    val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol")
+    val result = df.select('DecimalCol cast DecimalType(38, 33))
+      .select(col("DecimalCol")).describe()
+    val mean = result.select("DecimalCol").where($"summary" === "mean")
+    assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000")))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org