Posted to issues@spark.apache.org by "Robert Joseph Evans (Jira)" <ji...@apache.org> on 2021/10/16 12:31:00 UTC

[jira] [Created] (SPARK-37024) Even more decimal overflow issues in average

Robert Joseph Evans created SPARK-37024:
-------------------------------------------

             Summary: Even more decimal overflow issues in average
                 Key: SPARK-37024
                 URL: https://issues.apache.org/jira/browse/SPARK-37024
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.2.0
            Reporter: Robert Joseph Evans


As part of trying to accelerate the {{Decimal}} average aggregation on a [GPU|https://nvidia.github.io/spark-rapids/], I noticed a few issues around overflow. I think all of them can be fixed by replacing {{Average}} with explicit {{Sum}}, {{Count}}, and {{Divide}} operations for decimals instead of doing them implicitly, but the extra checks would come with a performance cost.

This is related to SPARK-35955, but goes quite a bit beyond it.
 # There are no ANSI overflow checks on the summation portion of average.
 # Whether overflow is detected (and a null inserted) on summation depends on code generation and parallelism.
 # If the input decimal precision is 11 or below all overflow checks are disabled, and the answer is wrong instead of null on overflow.
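
A minimal sketch of the proposed fix, in plain Scala with {{java.math.BigDecimal}} rather than Spark's own {{Decimal}} and aggregate classes: the average is computed as an explicit sum, count, and divide, and the running sum is checked against a {{DECIMAL(p + 10, s)}} sum type (capped at 38 digits) after every addition, instead of only when a value happens to be written into an {{UnsafeRow}}. The object {{ExplicitDecimalAvg}}, its {{avg}} signature, keeping the input scale in the sum, and HALF_UP rounding on the final divide are all hypothetical choices for illustration, not Spark's API. The ANSI error message is copied from the {{SUM}} failure shown below.

{code:scala}
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Hypothetical sketch, not Spark code: AVG over DECIMAL(p, s) done as an
// explicit SUM, COUNT and divide, checking the sum after every addition.
object ExplicitDecimalAvg {
  val MaxPrecision = 38 // Spark's maximum decimal precision

  // ANSI mode raises an error; otherwise overflow becomes null (None).
  private def overflow(ansi: Boolean): Option[JBigDecimal] =
    if (ansi) throw new ArithmeticException("Overflow in sum of decimals.")
    else None

  def avg(values: Iterator[JBigDecimal], precision: Int, scale: Int,
          ansi: Boolean): Option[JBigDecimal] = {
    val sumPrecision = math.min(precision + 10, MaxPrecision)   // SUM buffer
    val resultPrecision = math.min(precision + 4, MaxPrecision) // AVG output
    val resultScale = math.min(scale + 4, MaxPrecision)

    var sum = JBigDecimal.ZERO.setScale(scale)
    var count = 0L
    while (values.hasNext) {
      // add() is exact; precision() counts the digits of the unscaled value.
      sum = sum.add(values.next().setScale(scale))
      count += 1
      if (sum.precision() > sumPrecision) return overflow(ansi)
    }
    if (count == 0) None // AVG over no rows is null
    else {
      val result = sum.divide(JBigDecimal.valueOf(count), resultScale, RoundingMode.HALF_UP)
      if (result.precision() > resultPrecision) overflow(ansi) else Some(result)
    }
  }
}
{code}

With the {{DECIMAL(32, 0)}} maximum value, 1,000,000 rows average to 99999999999999999999999999999999.0000, while 1,000,001 rows give {{None}} (or an {{ArithmeticException}} with {{ansi = true}}) at the same row count every time. The price is one precision check per row.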

*Details:*

*there are no ANSI overflow checks on the summation portion.*
{code:scala}
scala> spark.conf.set("spark.sql.ansi.enabled", "true")

scala> spark.time(spark.range(2000001)
    .repartition(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+------+
|avg(v)|
+------+
|null  |
+------+

Time taken: 622 ms

scala> spark.time(spark.range(2000001)
    .repartition(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v)")
    .show(truncate = false))
21/10/16 06:08:00 ERROR Executor: Exception in task 0.0 in stage 15.0 (TID 19)
java.lang.ArithmeticException: Overflow in sum of decimals.
...
{code}
*nulls are inserted on summation overflow differently depending on code generation and parallelism.*

Because there are no explicit overflow checks when doing the sum, a user can get very inconsistent results as to when a null is inserted on overflow. The checks only take place when the {{Decimal}} value is converted and stored into an {{UnsafeRow}}, which happens when the values are shuffled, or after each operation if code gen is disabled. For a {{DECIMAL(32, 0)}}, you can add 1,000,000 max values before the summation overflows.
{code:scala}
scala> spark.time(spark.range(1000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
|s                                     |c      |sum_div_count                          |a                                    |
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
|99999999999999999999999999999999000000|1000000|99999999999999999999999999999999.000000|99999999999999999999999999999999.0000|
+--------------------------------------+-------+---------------------------------------+-------------------------------------+
Time taken: 241 ms

scala> spark.time(spark.range(2000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+-------------------------------------+
|s   |c      |sum_div_count|a                                    |
+----+-------+-------------+-------------------------------------+
|null|2000000|null         |99999999999999999999999999999999.0000|
+----+-------+-------------+-------------------------------------+
Time taken: 228 ms

scala> spark.time(spark.range(3000000)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+----+
|s   |c      |sum_div_count|a   |
+----+-------+-------------+----+
|null|3000000|null         |null|
+----+-------+-------------+----+
Time taken: 347 ms

scala> spark.conf.set("spark.sql.codegen.wholeStage", "false")
scala> spark.time(spark.range(1000001)
    .coalesce(2)
    .selectExpr("id", "CAST('99999999999999999999999999999999' AS DECIMAL(32, 0)) as v")
    .selectExpr("SUM(v) as s", "COUNT(v) as c", "AVG(v) as a")
    .selectExpr("s", "c", "s/c as sum_div_count", "a")
    .show(truncate = false))
+----+-------+-------------+----+
|s   |c      |sum_div_count|a   |
+----+-------+-------------+----+
|null|1000001|null         |null|
+----+-------+-------------+----+
Time taken: 310 ms
{code}
With code gen disabled, the limit is enforced at 1,000,001 entries, just like with sum. With code gen enabled, it depends on the number of upstream tasks and the order of the data, which means that if I change the size of the cluster I am running on, I might get different results from one run to another.
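
To make these limits concrete, the number of maximum-value rows that fit before the sum overflows can be computed directly: a {{DECIMAL(p, s)}} value has at most p unscaled digits, and the sum type has at most min(p + 10, 38). This is a back-of-the-envelope sketch in plain Scala; {{SumHeadroom.maxRowsBeforeOverflow}} is a hypothetical helper, and it assumes the sum keeps the input scale.

{code:scala}
// Hypothetical helper, not Spark code: how many copies of the largest
// DECIMAL(p, s) value fit in the sum type DECIMAL(min(p + 10, 38), s).
// The scale cancels out because input and sum share it.
object SumHeadroom {
  def maxRowsBeforeOverflow(precision: Int): BigInt = {
    val sumPrecision = math.min(precision + 10, 38)
    val maxInput = BigInt(10).pow(precision) - 1  // largest unscaled input
    val maxSum = BigInt(10).pow(sumPrecision) - 1 // largest unscaled sum
    maxSum / maxInput // floor division, both operands positive
  }

  def main(args: Array[String]): Unit = {
    println(maxRowsBeforeOverflow(32)) // 1000000 (the 38 digit cap applies)
    println(maxRowsBeforeOverflow(11)) // 10000000000 (the full +10 headroom)
  }
}
{code}

The first value matches the 1,000,001-row threshold seen with code gen disabled; the second is the 10 billion rows that the +10 precision is meant to guarantee.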

*if the input decimal precision is 11 or below all overflow checks are disabled.*

When the precision of a {{Decimal}} in an average is 11 or below, the average is done in terms of a {{Double}}. The logic is that a {{Double}} has 53 bits of precision (just under 16 decimal digits), which should give us {{MAX_DOUBLE_DIGITS = 15}} digits of decimal precision. The main problem is that the 15 digit limit is compared to the output precision, which is input precision + 4, instead of the summation precision, which is input [precision + 10|https://github.com/apache/spark/blob/67b547aa1cb7aeaf0f6e1f1017d21f582dacf697/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L73]. So instead of getting a full 10 billion entries before an overflow happens, in the worst case I was able to see it happen at just 90,100 entries.
{code:scala}
scala> spark.time(spark.range(90100)
    .repartition(1)
    .selectExpr("id", "CAST('99999999999' AS DECIMAL(11, 0)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+----------------+
|avg(v)          |
+----------------+
|99999999999.0003|
+----------------+
Time taken: 86 ms
{code}
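The drift in that result can be reproduced without Spark. This sketch assumes, as described above, that the unscaled values are added into a single {{Double}} in row order (a simplification; Spark's actual code path may differ in detail). A {{Double}} represents every integer exactly only up to 2 to the power 53 (9,007,199,254,740,992); past that point, each addition of the odd value 99999999999 has to be rounded. {{DoubleAvgDrift}} is a hypothetical name.

{code:scala}
// Hypothetical simulation, not Spark code: sum DECIMAL(11, 0) max values in a
// Double accumulator and compare with the exact Long sum.
object DoubleAvgDrift {
  val MaxDec11 = 99999999999L // largest DECIMAL(11, 0) value (exact as a Double)
  val TwoTo53 = 1L << 53      // beyond this, not every integer is a Double

  // First row count whose running sum exceeds 2^53.
  val firstRowPast2To53: Long = TwoTo53 / MaxDec11 + 1

  // Add `rows` copies of MaxDec11 into a Double, one row at a time.
  def doubleSum(rows: Int): Double = {
    var acc = 0.0
    var i = 0
    while (i < rows) { acc += MaxDec11.toDouble; i += 1 }
    acc
  }

  // Exact sum; 90,100 * 99999999999 still fits in a Long.
  def exactSum(rows: Int): Long = rows.toLong * MaxDec11

  def main(args: Array[String]): Unit = {
    val rows = 90100
    val error = doubleSum(rows).toLong - exactSum(rows)
    println(firstRowPast2To53)     // 90072
    println(error)                 // 28
    println(error.toDouble / rows) // about 3.1E-4, the .0003 in the AVG output above
  }
}
{code}

The first rounding error appears at row 90,073. From there, round-half-to-even happens to pick the larger neighbor on every addition of this value, so the error grows by 1 per row instead of averaging out.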
{{Sum}} has a similar optimization, but it does its calculations as a long, and only when the input precision is 8 or below (so that the summation precision of input precision + 10 fits in 18 digits). {{Sum}} guarantees 10 billion entries before overflowing (which is what the +10 precision is for), and in practice it can sum over 90 billion max values in the worst case before the overflow reaches a point where sum can no longer detect it and produces an incorrect answer. Personally, I am much more comfortable with the risk of my data containing 90 billion max values than ~100,000 of them.
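
A quick check of those two numbers, assuming the long-based {{Sum}} path on {{DECIMAL(8, s)}} input keeps an unscaled {{Long}} running sum that is checked against 18 digits when it is converted back to a decimal. Under that assumption, overflow past 18 digits is caught until the {{Long}} itself wraps around. {{LongSumBounds}} is a hypothetical name, and the sketch models only the arithmetic.

{code:scala}
// Hypothetical back-of-the-envelope check for the long-based Sum path on
// DECIMAL(8, s) input, whose sum type DECIMAL(18, s) fits in a Long.
object LongSumBounds {
  val MaxDec8 = 99999999L             // largest unscaled DECIMAL(8, s) value
  val MaxDec18 = 999999999999999999L  // largest unscaled DECIMAL(18, s) value

  // Rows guaranteed to fit in the DECIMAL(18, s) result: the "10 billion".
  val guaranteedRows: Long = MaxDec18 / MaxDec8

  // Rows before the Long accumulator itself wraps; up to here the
  // 18-digit check can still see the overflow.
  val detectableRows: Long = Long.MaxValue / MaxDec8

  def main(args: Array[String]): Unit = {
    println(guaranteedRows) // 10000000100
    println(detectableRows) // a little over 92 billion
    // One more row wraps silently to a negative number; plain Long
    // multiplication does not throw (Math.multiplyExact would).
    println((detectableRows + 1) * MaxDec8 < 0) // true
  }
}
{code}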

*Performance cost*

However, switching to SUM/COUNT would come with a performance cost. First, there is the general cost of the overflow checks that sum does.
{code:scala}
scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(8, 2)) as v")
    .selectExpr("ROUND(SUM(v)/COUNT(v), 6) as sum_cnt_avg")
    .show(truncate = false))
+-------------+                                                                 
|sum_cnt_avg  |
+-------------+
|999999.990000|
+-------------+

Time taken: 2201 ms

scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(8, 2)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+-------------+                                                                 
|avg(v)       |
+-------------+
|999999.994967|
+-------------+

Time taken: 1845 ms
{code}
In addition, averages on {{Decimal}} values with a precision of 9, 10, or 11 would lose the optimization of being computed as longs/doubles.
{code:scala}
scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(9, 2)) as v")
    .selectExpr("ROUND(SUM(v)/COUNT(v), 6) as sum_cnt_avg")
    .show(truncate = false))
+-------------+                                                                 
|sum_cnt_avg  |
+-------------+
|999999.990000|
+-------------+
Time taken: 13252 ms

scala> spark.time(spark.range(Int.MaxValue)
    .selectExpr("id", "CAST('999999.99' AS DECIMAL(9, 2)) as v")
    .selectExpr("AVG(v)")
    .show(truncate = false))
+-------------+                                                                 
|avg(v)       |
+-------------+
|999999.994967|
+-------------+
Time taken: 1928 ms
{code}
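
The precision thresholds behind both optimizations fit in a few lines. This sketch assumes the rules described in this issue: the long-based {{Sum}} path applies when the summation precision p + 10 fits in 18 digits (the most decimal digits a long always holds), and the double-based average applies when the output precision p + 4 fits in {{MAX_DOUBLE_DIGITS = 15}}. {{AggregatePath}} and its methods are hypothetical names that mirror those comparisons, not Spark's actual code.

{code:scala}
// Hypothetical mirror of the two precision checks discussed in this issue.
object AggregatePath {
  val MaxLongDigits = 18   // every 18-digit unscaled value fits in a Long
  val MaxDoubleDigits = 15 // MAX_DOUBLE_DIGITS

  def sumUsesLong(p: Int): Boolean = p + 10 <= MaxLongDigits
  def avgUsesDouble(p: Int): Boolean = p + 4 <= MaxDoubleDigits

  def main(args: Array[String]): Unit = {
    // Precisions that are fast today via AVG but would lose the
    // optimization if AVG were rewritten as SUM / COUNT.
    val losing = (1 to 38).filter(p => avgUsesDouble(p) && !sumUsesLong(p))
    println(losing) // Vector(9, 10, 11)
  }
}
{code}

This is consistent with the benchmarks: {{DECIMAL(8, 2)}} stays on the long path under SUM/COUNT, while {{DECIMAL(9, 2)}} falls back to full decimal arithmetic.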

--
This message was sent by Atlassian Jira
(v8.3.4#803005)