Posted to commits@spark.apache.org by we...@apache.org on 2021/10/14 09:07:15 UTC
[spark] branch branch-3.2 updated: [SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new d93d056 [SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero
d93d056 is described below
commit d93d0560db18681c74916a6080f87f4136dde434
Author: gengjiaan <ge...@360.cn>
AuthorDate: Thu Oct 14 17:05:25 2021 +0800
[SPARK-36632][SQL] DivideYMInterval and DivideDTInterval should throw the same exception when divide by zero
### What changes were proposed in this pull request?
When dividing by zero, `DivideYMInterval` and `DivideDTInterval` output
```
java.lang.ArithmeticException
/ by zero
```
But in ANSI mode, `select 2 / 0` outputs
```
org.apache.spark.SparkArithmeticException
divide by zero
```
These two behaviors are inconsistent.
### Why are the changes needed?
To make the divide-by-zero behavior consistent.
### Does this PR introduce _any_ user-facing change?
Yes. Dividing a year-month or day-time interval by zero now throws `SparkArithmeticException` with the message "divide by zero" instead of `java.lang.ArithmeticException: / by zero`.
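For illustration, a minimal spark-shell sketch of the new behavior (assuming a session with this patch applied; the expected errors are taken from the updated golden files below):
```
// Both interval flavors now fail the same way `select 2 / 0` does in ANSI mode.
spark.sql("select interval '2' year / 0").collect()
// org.apache.spark.SparkArithmeticException: divide by zero
spark.sql("select interval '2 seconds' / 0").collect()
// org.apache.spark.SparkArithmeticException: divide by zero
```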
### How was this patch tested?
New tests.
Closes #33889 from beliefer/SPARK-36632.
Lead-authored-by: gengjiaan <ge...@360.cn>
Co-authored-by: beliefer <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit de0161a4e85d3125e438a3431285d2fee22c1c65)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/expressions/intervalExpressions.scala | 47 +++++++++++++++++-----
.../expressions/IntervalExpressionsSuite.scala | 8 ++--
.../sql-tests/results/ansi/interval.sql.out | 8 ++--
.../resources/sql-tests/results/interval.sql.out | 8 ++--
.../apache/spark/sql/ColumnExpressionSuite.scala | 28 ++++++++++++-
5 files changed, 75 insertions(+), 24 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
index c799c69..4f31708 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala
@@ -598,6 +598,17 @@ trait IntervalDivide {
}
}
}
+
+ def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
+ case _: DecimalType =>
+ if (num.asInstanceOf[Decimal].isZero) throw QueryExecutionErrors.divideByZeroError()
+ case _ => if (num == 0) throw QueryExecutionErrors.divideByZeroError()
+ }
+
+ def divideByZeroCheckCodegen(dataType: DataType, value: String): String = dataType match {
+ case _: DecimalType => s"if ($value.isZero()) throw QueryExecutionErrors.divideByZeroError();"
+ case _ => s"if ($value == 0) throw QueryExecutionErrors.divideByZeroError();"
+ }
}
// Divide an year-month interval by a numeric
@@ -629,6 +640,7 @@ case class DivideYMInterval(
override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Int], Int.MinValue, right, num)
+ divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Int], num)
}
@@ -650,17 +662,24 @@ case class DivideYMInterval(
// Similarly to non-codegen code. The result of `divide(Int, Long, ...)` must fit to `Int`.
// Casting to `Int` is safe here.
s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = ($javaType)$math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
- defineCodeGen(ctx, ev, (m, n) =>
- s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
- ".setScale(0, java.math.RoundingMode.HALF_UP).intValueExact()")
+ nullSafeCodeGen(ctx, ev, (m, n) =>
+ s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
+ |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
+ | .setScale(0, java.math.RoundingMode.HALF_UP).intValueExact();
+ """.stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
- defineCodeGen(ctx, ev, (m, n) =>
- s"$math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP)")
+ nullSafeCodeGen(ctx, ev, (m, n) =>
+ s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
+ |${ev.value} = $math.roundToInt($m / (double)$n, java.math.RoundingMode.HALF_UP);
+ """.stripMargin)
}
override def toString: String = s"($left / $right)"
@@ -696,6 +715,7 @@ case class DivideDTInterval(
override def nullSafeEval(interval: Any, num: Any): Any = {
checkDivideOverflow(interval.asInstanceOf[Long], Long.MinValue, right, num)
+ divideByZeroCheck(right.dataType, num)
evalFunc(interval.asInstanceOf[Long], num)
}
@@ -711,17 +731,24 @@ case class DivideDTInterval(
|""".stripMargin
nullSafeCodeGen(ctx, ev, (m, n) =>
s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
|$checkIntegralDivideOverflow
|${ev.value} = $math.divide($m, $n, java.math.RoundingMode.HALF_UP);
""".stripMargin)
case _: DecimalType =>
- defineCodeGen(ctx, ev, (m, n) =>
- s"((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()" +
- ".setScale(0, java.math.RoundingMode.HALF_UP).longValueExact()")
+ nullSafeCodeGen(ctx, ev, (m, n) =>
+ s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
+ |${ev.value} = ((new Decimal()).set($m).$$div($n)).toJavaBigDecimal()
+ | .setScale(0, java.math.RoundingMode.HALF_UP).longValueExact();
+ """.stripMargin)
case _: FractionalType =>
val math = classOf[DoubleMath].getName
- defineCodeGen(ctx, ev, (m, n) =>
- s"$math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP)")
+ nullSafeCodeGen(ctx, ev, (m, n) =>
+ s"""
+ |${divideByZeroCheckCodegen(right.dataType, n)}
+ |${ev.value} = $math.roundToLong($m / (double)$n, java.math.RoundingMode.HALF_UP);
+ """.stripMargin)
}
override def toString: String = s"($left / $right)"
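To see why the interpreted-path check above needs a `Decimal`-specific branch, here is a self-contained sketch of `divideByZeroCheck` (substituting a plain `ArithmeticException` for `QueryExecutionErrors.divideByZeroError()` so it runs outside Spark's error framework):
```
import org.apache.spark.sql.types.{DataType, Decimal, DecimalType}

// A Decimal zero does not compare equal to the Int literal 0 under Scala's
// universal equality, so it needs the isZero test; primitive numeric
// divisors are covered by `num == 0`.
def divideByZeroCheck(dataType: DataType, num: Any): Unit = dataType match {
  case _: DecimalType =>
    if (num.asInstanceOf[Decimal].isZero) throw new ArithmeticException("divide by zero")
  case _ =>
    if (num == 0) throw new ArithmeticException("divide by zero")
}

divideByZeroCheck(DecimalType(10, 0), Decimal(0))  // throws
```
The codegen twin, `divideByZeroCheckCodegen`, emits the equivalent Java guard (`$value.isZero()` vs. `$value == 0`) ahead of the division in the generated code, which is why the `DecimalType` and `FractionalType` branches switch from `defineCodeGen` to `nullSafeCodeGen`.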
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
index 12509ef..05f9d0f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala
@@ -412,8 +412,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
Seq(
- (Period.ofMonths(1), 0) -> "/ by zero",
- (Period.ofMonths(Int.MinValue), 0d) -> "input is infinite or NaN",
+ (Period.ofMonths(1), 0) -> "divide by zero",
+ (Period.ofMonths(Int.MinValue), 0d) -> "divide by zero",
(Period.ofMonths(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
@@ -447,8 +447,8 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
Seq(
- (Duration.ofDays(1), 0) -> "/ by zero",
- (Duration.ofMillis(Int.MinValue), 0d) -> "input is infinite or NaN",
+ (Duration.ofDays(1), 0) -> "divide by zero",
+ (Duration.ofMillis(Int.MinValue), 0d) -> "divide by zero",
(Duration.ofSeconds(-100), Float.NaN) -> "input is infinite or NaN"
).foreach { case ((period, num), expectedErrMsg) =>
checkExceptionInExpression[ArithmeticException](
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
index c347d31..6a5fa69 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/interval.sql.out
@@ -209,8 +209,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
-- !query
@@ -242,8 +242,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/interval.sql.out b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
index a8fa101..70079da 100644
--- a/sql/core/src/test/resources/sql-tests/results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/interval.sql.out
@@ -203,8 +203,8 @@ select interval '2 seconds' / 0
-- !query schema
struct<>
-- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
-- !query
@@ -236,8 +236,8 @@ select interval '2' year / 0
-- !query schema
struct<>
-- !query output
-java.lang.ArithmeticException
-/ by zero
+org.apache.spark.SparkArithmeticException
+divide by zero
-- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index b0cd613..e7ca431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -2737,7 +2737,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Period.ofYears(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
- assert(e.getMessage.contains("/ by zero"))
+ assert(e.getMessage.contains("divide by zero"))
+
+ val e2 = intercept[SparkException] {
+ Seq((Period.ofYears(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
+ }.getCause
+ assert(e2.isInstanceOf[ArithmeticException])
+ assert(e2.getMessage.contains("divide by zero"))
+
+ val e3 = intercept[SparkException] {
+ Seq((Period.ofYears(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
+ }.getCause
+ assert(e3.isInstanceOf[ArithmeticException])
+ assert(e3.getMessage.contains("divide by zero"))
}
test("SPARK-34875: divide day-time interval by numeric") {
@@ -2772,7 +2784,19 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
Seq((Duration.ofDays(9999), 0)).toDF("i", "n").select($"i" / $"n").collect()
}.getCause
assert(e.isInstanceOf[ArithmeticException])
- assert(e.getMessage.contains("/ by zero"))
+ assert(e.getMessage.contains("divide by zero"))
+
+ val e2 = intercept[SparkException] {
+ Seq((Duration.ofDays(9999), 0d)).toDF("i", "n").select($"i" / $"n").collect()
+ }.getCause
+ assert(e2.isInstanceOf[ArithmeticException])
+ assert(e2.getMessage.contains("divide by zero"))
+
+ val e3 = intercept[SparkException] {
+ Seq((Duration.ofDays(9999), BigDecimal(0))).toDF("i", "n").select($"i" / $"n").collect()
+ }.getCause
+ assert(e3.isInstanceOf[ArithmeticException])
+ assert(e3.getMessage.contains("divide by zero"))
}
test("SPARK-34896: return day-time interval from dates subtraction") {