You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/10/14 09:07:15 UTC

[spark] branch branch-3.2 updated: [SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new d93d056  [SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero
d93d056 is described below

commit d93d0560db18681c74916a6080f87f4136dde434
Author: gengjiaan <ge...@360.cn>
AuthorDate: Thu Oct 14 17:05:25 2021 +0800

    [SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero
    
    ### What changes were proposed in this pull request?
    When dividing by zero, `DivideYMInterval` and `DivideDTInterval` output
    ```
    java.lang.ArithmeticException
    / by zero
    ```
    But in ANSI mode, `select 2 / 0` outputs
    ```
    org.apache.spark.SparkArithmeticException
    divide by zero
    ```
    The behavior looks inconsistent.
    
    ### Why are the changes needed?
    Make the behavior consistent.
    
    ### Does this PR introduce _any_ user-facing change?
    'Yes'.
    
    ### How was this patch tested?
    New tests.
    
    Closes #33889 from beliefer/SPARK-36632.
    
    Lead-authored-by: gengjiaan <ge...@360.cn>
    Co-authored-by: beliefer <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit de0161a4e85d3125e438a3431285d2fee22c1c65)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/expressions/intervalExpressions.scala | 47 +++++++++++++++++-----
 .../expressions/IntervalExpressionsSuite.scala     |  8 ++--
 .../sql-tests/results/ansi/interval.sql.out        |  8 ++--
 .../resources/sql-tests/results/interval.sql.out   |  8 ++--
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 28 ++++++++++++-
 5 files changed, 75 insertions(+), 24 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index c799c69..4f31708 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -598,6 +598,17 @@ trait IntervalDivide {
       }
     }
   }
+
+  def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
+    case _: DecimalType =>
+      if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
+    case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
+  }
+
+  def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match {
+    case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
+    case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
+  }
 }
 
 // Divide an year-month interval by a numeric
@@ -629,6 +640,7 @@ case class DivideYMInterval(
 
   override def nullSafeEval(interval: Any, num: Any): Any = {
     checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
+    divideByZeroCheck(right.dataType, num)
     evalFunc(interval.asInstanceOf[Int], num)
   }
 
@@ -650,17 +662,24 @@ case class DivideYMInterval(
         // Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
         // Casting to `Int` is safe here.
         s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
            |$checkIntegralDivideOverflow
            |${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
         """.stripMargin)
     case _: DecimalType =>
-      defineCodeGen(ctx, ev, (m, n) =>
-        s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
-        ".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
+      nullSafeCodeGen(ctx, ev, (m, n) =>
+        s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
+           |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
+           |  .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
+         """.stripMargin)
     case _: FractionalType =>
       val math = classOf[DoubleMath].getName
-      defineCodeGen(ctx, ev, (m, n) =>
-        s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
+      nullSafeCodeGen(ctx, ev, (m, n) =>
+        s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
+           |${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
+         """.stripMargin)
   }
 
   override def toString: String = s"($left / $right)"
@@ -696,6 +715,7 @@ case class DivideDTInterval(
 
   override def nullSafeEval(interval: Any, num: Any): Any = {
     checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
+    divideByZeroCheck(right.dataType, num)
     evalFunc(interval.asInstanceOf[Long], num)
   }
 
@@ -711,17 +731,24 @@ case class DivideDTInterval(
            |""".stripMargin
       nullSafeCodeGen(ctx, ev, (m, n) =>
         s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
            |$checkIntegralDivideOverflow
            |${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
         """.stripMargin)
     case _: DecimalType =>
-      defineCodeGen(ctx, ev, (m, n) =>
-        s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
-        ".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()")
+      nullSafeCodeGen(ctx, ev, (m, n) =>
+        s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
+           |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
+           |  .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
+         """.stripMargin)
     case _: FractionalType =>
       val math = classOf[DoubleMath].getName
-      defineCodeGen(ctx, ev, (m, n) =>
-        s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)")
+      nullSafeCodeGen(ctx, ev, (m, n) =>
+        s"""
+           |${divideByZeroCheckCodegen(right.dataType, n)}
+           |${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
+         """.stripMargin)
   }
 
   override def toString: String = s"($left / $right)"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index 12509ef..05f9d0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
 
     Seq(
-      (Period.ofMonths(1), 0) -> "/ by zero",
-      (Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN",
+      (Period.ofMonths(1), 0) -> "divide by zero",
+      (Period.ofMonths(Int.MinValue), 0d) -> "divide by zero",
       (Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
     ).foreach { case ((period, num), expectedErrMsg) =>
       checkExceptionInExpression[ArithmeticException](
@@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
 
     Seq(
-      (Duration.ofDays(1), 0) -> "/ by zero",
-      (Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN",
+      (Duration.ofDays(1), 0) -> "divide by zero",
+      (Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero",
       (Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN"
     ).foreach { case ((period, num), expectedErrMsg) =>
       checkExceptionInExpression[ArithmeticException](
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index c347d31..6a5fa69 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -209,8 +209,8 @@ select interval '2 seconds' / 0
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
 
 
 -- !query
@@ -242,8 +242,8 @@ select interval '2' year / 0
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
 
 
 -- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
index a8fa101..70079da 100644
--- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
@@ -203,8 +203,8 @@ select interval '2 seconds' / 0
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
 
 
 -- !query
@@ -236,8 +236,8 @@ select interval '2' year / 0
 -- !query schema
 struct<>
 -- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
 
 
 -- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index b0cd613..e7ca431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
     }.getCause
     assert(e.isInstanceOf[ArithmeticException])
-    assert(e.getMessage.contains("/ by zero"))
+    assert(e.getMessage.contains("divide by zero"))
+
+    val e2 = intercept[SparkException] {
+      Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
+    }.getCause
+    assert(e2.isInstanceOf[ArithmeticException])
+    assert(e2.getMessage.contains("divide by zero"))
+
+    val e3 = intercept[SparkException] {
+      Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
+    }.getCause
+    assert(e3.isInstanceOf[ArithmeticException])
+    assert(e3.getMessage.contains("divide by zero"))
   }
 
   test("SPARK-34875: divide day-time interval by numeric") {
@@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
     }.getCause
     assert(e.isInstanceOf[ArithmeticException])
-    assert(e.getMessage.contains("/ by zero"))
+    assert(e.getMessage.contains("divide by zero"))
+
+    val e2 = intercept[SparkException] {
+      Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
+    }.getCause
+    assert(e2.isInstanceOf[ArithmeticException])
+    assert(e2.getMessage.contains("divide by zero"))
+
+    val e3 = intercept[SparkException] {
+      Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
+    }.getCause
+    assert(e3.isInstanceOf[ArithmeticException])
+    assert(e3.getMessage.contains("divide by zero"))
   }
 
   test("SPARK-34896: return day-time interval from dates subtraction") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org