You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2022/07/19 09:03:04 UTC

[spark] branch master updated: [SPARK-39792][SQL] Add DecimalDivideWithOverflowCheck for decimal average

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 46071a9caa2 [SPARK-39792][SQL] Add DecimalDivideWithOverflowCheck for decimal average
46071a9caa2 is described below

commit 46071a9caa26b991bdd5bb0a3505a0ba76d16d0e
Author: ulysses-you <ul...@gmail.com>
AuthorDate: Tue Jul 19 02:02:52 2022 -0700

    [SPARK-39792][SQL] Add DecimalDivideWithOverflowCheck for decimal average
    
    ### What changes were proposed in this pull request?
    
    Add a new expression `DecimalDivideWithOverflowCheck` to replace the previous combination of `CheckOverflowInSum` + `Divide` + `Cast`.
    
    ### Why are the changes needed?
    
    If the result data type is decimal, Average first computes the result using the default precision and scale of `Divide`, then casts it to the result data type, which can lose precision. Instead, we should compute the result directly in the result data type to avoid the precision loss; this also saves one unnecessary cast.
    
    In addition, the overflow check should be applied to the result of the division rather than to the dividend.
    
    ### Does this PR introduce _any_ user-facing change?
    
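    Yes, a small bug fix: decimal averages are now computed directly in the result type, without the intermediate precision loss.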
    
    ### How was this patch tested?
    
    Added a new test and updated the expected results of existing tests.
    
    Closes #37207 from ulysses-you/average-decimal.
    
    Authored-by: ulysses-you <ul...@gmail.com>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../catalyst/expressions/aggregate/Average.scala   | 11 ++-
 .../catalyst/expressions/decimalExpressions.scala  | 78 +++++++++++++++++++++-
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  4 ++
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    |  2 +-
 .../sql/hive/execution/AggregationQuerySuite.scala |  2 +-
 5 files changed, 91 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 864ec7055f3..e64f76bdb0a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -88,9 +88,14 @@ abstract class AverageBase
   // We can't directly use `/` as it throws an exception under ansi mode.
   protected def getEvaluateExpression(queryContext: String) = child.dataType match {
     case _: DecimalType =>
-      Divide(
-        CheckOverflowInSum(sum, sumDataType.asInstanceOf[DecimalType], !useAnsiAdd, queryContext),
-        count.cast(DecimalType.LongDecimal), failOnError = false).cast(resultType)
+      If(EqualTo(count, Literal(0L)),
+        Literal(null, resultType),
+        DecimalDivideWithOverflowCheck(
+          sum,
+          count.cast(DecimalType.LongDecimal),
+          resultType.asInstanceOf[DecimalType],
+          queryContext,
+          !useAnsiAdd))
     case _: YearMonthIntervalType =>
       If(EqualTo(count, Literal(0L)),
         Literal(null, YearMonthIntervalType()), DivideYMInterval(sum, count))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index 7335763c253..2dd60a9d9ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
@@ -247,3 +247,79 @@ case class DecimalAddNoOverflowCheck(
       newLeft: Expression, newRight: Expression): DecimalAddNoOverflowCheck =
     copy(left = newLeft, right = newRight)
 }
+
+/**
+ * A decimal divide expression used only internally by Avg; it computes `left / right`
+ * directly in `dataType`, rounding ROUND_HALF_UP. `right` (count in avg) should never be null,
+ * and a zero count is not handled here (Average guards it with `If(count == 0)`).
+ * When nullOnOverflow is true the cases below evaluate to null; otherwise it fails if:
+ *   - left (sum in avg) is null because the sum exceeded the max precision 38, or
+ *   - the result of the division overflows `dataType`.
+ */
+case class DecimalDivideWithOverflowCheck(
+    left: Expression,
+    right: Expression,
+    override val dataType: DecimalType,
+    avgQueryContext: String,
+    nullOnOverflow: Boolean)
+  extends BinaryExpression with ExpectsInputTypes with SupportQueryContext {
+  override def nullable: Boolean = nullOnOverflow  // when false, null sum / overflow throw
+  override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType)
+  override def initQueryContext(): String = avgQueryContext
+  def decimalMethod: String = "$div"  // Decimal method invoked by the generated code
+
+  override def eval(input: InternalRow): Any = {
+    val value1 = left.eval(input)
+    if (value1 == null) {
+      if (nullOnOverflow)  {
+        null
+      } else {
+        throw QueryExecutionErrors.overflowInSumOfDecimalError(queryContext)
+      }
+    } else {
+      val value2 = right.eval(input)  // count; documented as never null
+      dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, value2).asInstanceOf[Decimal]
+        .toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow,
+          queryContext)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val errorContextCode = if (nullOnOverflow) {  // errCtx registered only if we may throw
+      "\"\""
+    } else {
+      ctx.addReferenceObj("errCtx", queryContext)
+    }
+    val nullHandling = if (nullOnOverflow) {
+      ""
+    } else {
+      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
+    }
+
+    val eval1 = left.genCode(ctx)
+    val eval2 = right.genCode(ctx)  // eval2.isNull is never checked: count is assumed non-null
+
+    // scalastyle:off line.size.limit
+    val code =
+      code"""
+         |${eval1.code}
+         |${eval2.code}
+         |boolean ${ev.isNull} = ${eval1.isNull};
+         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (${eval1.isNull}) {
+         |  $nullHandling
+         |} else {
+         |  ${ev.value} = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
+         |      ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
+         |  ${ev.isNull} = ${ev.value} == null;
+         |}
+      """.stripMargin
+    // scalastyle:on line.size.limit
+    ev.copy(code = code)
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: Expression, newRight: Expression): Expression = {
+    copy(left = newLeft, right = newRight)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 296adcaa3e8..81a9294df39 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -322,6 +322,10 @@ class DataFrameAggregateSuite extends QueryTest
       decimalData.agg(
         avg($"a" cast DecimalType(10, 2)), sum_distinct($"a" cast DecimalType(10, 2))),
       Row(new java.math.BigDecimal(2), new java.math.BigDecimal(6)) :: Nil)
+
+    checkAnswer(
+      emptyTestData.agg(avg($"key" cast DecimalType(10, 0))),
+      Row(null))
   }
 
   test("null average") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index ac1c59ae01e..7e772c0febb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -2166,7 +2166,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         df.queryExecution.optimizedPlan.collect {
           case _: DataSourceV2ScanRelation =>
             val expected_plan_fragment =
-              "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+              "PushedAggregates: [COUNT(PRICE), SUM(PRICE)]"
             checkKeywordsExistsInExplain(df, expected_plan_fragment)
         }
         if (ansiEnabled) {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index e560c2ea32a..e63cdddd81c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -1023,7 +1023,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       ("a", BigDecimal("11.9999999988"))).toDF("text", "number")
     val agg1 = df.groupBy($"text").agg(avg($"number").as("avg_res"))
     val agg2 = agg1.groupBy($"text").agg(sum($"avg_res"))
-    checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142860000")))
+    checkAnswer(agg2, Row("a", BigDecimal("11.9999999994857142857143")))
   }
 
   test("SPARK-29122: hash-based aggregates for unfixed-length decimals in the interpreter mode") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org