Posted to commits@spark.apache.org by ge...@apache.org on 2023/02/15 05:57:30 UTC
[spark] branch master updated: [SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows
This is an automated email from the ASF dual-hosted git repository.
gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new cb463fb40e8 [SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows
cb463fb40e8 is described below
commit cb463fb40e8f663b7e3019c8d8560a3490c241d0
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Tue Feb 14 21:57:13 2023 -0800
[SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows
### What changes were proposed in this pull request?
In ANSI SQL mode, the function `conv()` should return an error if the internal conversion overflows.
For example, before the change:
```
> select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
18446744073709551615
```
After the change:
```
> select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
org.apache.spark.SparkArithmeticException: [ARITHMETIC_OVERFLOW] Overflow in function conv(). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
### Why are the changes needed?
Consistent with other SQL functions, this PR surfaces the overflow error of `conv()` to users under ANSI SQL mode, instead of silently returning an unexpected number.
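For context on where the pre-ANSI value 18446744073709551615 comes from: the internal encoder signals overflow by returning -1L, and -1L rendered as an unsigned 64-bit integer is 2^64 - 1. A minimal Scala sketch (illustrative only, not Spark code):
```
// -1L has all 64 bits set; read as an unsigned value it is 2^64 - 1,
// which is exactly the 18446744073709551615 seen before this change.
println(java.lang.Long.toUnsignedString(-1L)) // 18446744073709551615
```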
### Does this PR introduce _any_ user-facing change?
Yes. Under ANSI SQL mode, the function `conv()` now returns an error if the internal conversion overflows.
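A quick way to see both behaviors from a Scala spark-shell session (a sketch assuming a running `SparkSession` named `spark`; `spark.sql.ansi.enabled` is a runtime-settable conf):
```
// With ANSI mode on, the overflow now raises SparkArithmeticException.
spark.conf.set("spark.sql.ansi.enabled", "true")
// spark.sql("SELECT conv('92233720368547758070', 10, 16)").show()  // throws ARITHMETIC_OVERFLOW

// With ANSI mode off, the legacy wrap-around result is kept.
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT conv('92233720368547758070', 10, 16)").show()     // FFFFFFFFFFFFFFFF
```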
### How was this patch tested?
Unit tests.
Closes #40001 from gengliangwang/fixConv.
Authored-by: Gengliang Wang <ge...@apache.org>
Signed-off-by: Gengliang Wang <ge...@apache.org>
---
.../sql/catalyst/expressions/mathExpressions.scala | 27 +++++-
.../spark/sql/catalyst/util/NumberConverter.scala | 30 +++++--
.../spark/sql/errors/QueryExecutionErrors.scala | 4 +
.../expressions/MathExpressionsSuite.scala | 55 +++++++++----
.../sql/catalyst/util/NumberConverterSuite.scala | 15 +++-
.../src/test/resources/sql-tests/inputs/math.sql | 10 ++-
.../resources/sql-tests/results/ansi/math.sql.out | 96 ++++++++++++++++++++++
.../test/resources/sql-tests/results/math.sql.out | 48 +++++++++++
8 files changed, 257 insertions(+), 28 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index f6bef9c6cc2..dcc821a24ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -434,8 +434,18 @@ case class Acosh(child: Expression)
""",
since = "1.5.0",
group = "math_funcs")
-case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
+case class Conv(
+ numExpr: Expression,
+ fromBaseExpr: Expression,
+ toBaseExpr: Expression,
+ ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
+ extends TernaryExpression
+ with ImplicitCastInputTypes
+ with NullIntolerant
+ with SupportQueryContext {
+
+ def this(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) =
+ this(numExpr, fromBaseExpr, toBaseExpr, ansiEnabled = SQLConf.get.ansiEnabled)
override def first: Expression = numExpr
override def second: Expression = fromBaseExpr
@@ -448,14 +458,17 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
NumberConverter.convert(
num.asInstanceOf[UTF8String].trim().getBytes,
fromBase.asInstanceOf[Int],
- toBase.asInstanceOf[Int])
+ toBase.asInstanceOf[Int],
+ ansiEnabled,
+ getContextOrNull())
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val numconv = NumberConverter.getClass.getName.stripSuffix("$")
+ val context = getContextOrNullCode(ctx, ansiEnabled)
nullSafeCodeGen(ctx, ev, (num, from, to) =>
s"""
- ${ev.value} = $numconv.convert($num.trim().getBytes(), $from, $to);
+ ${ev.value} = $numconv.convert($num.trim().getBytes(), $from, $to, $ansiEnabled, $context);
if (${ev.value} == null) {
${ev.isNull} = true;
}
@@ -466,6 +479,12 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
override protected def withNewChildrenInternal(
newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird)
+
+ override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) {
+ Some(origin.context)
+ } else {
+ None
+ }
}
@ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
index 68a1ba25423..59765cde1f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.unsafe.types.UTF8String
object NumberConverter {
@@ -47,7 +49,12 @@ object NumberConverter {
* @param fromPos is the first element that should be considered
* @return the result should be treated as an unsigned 64-bit integer.
*/
- private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
+ private def encode(
+ radix: Int,
+ fromPos: Int,
+ value: Array[Byte],
+ ansiEnabled: Boolean,
+ context: SQLQueryContext): Long = {
var v: Long = 0L
// bound will always be positive since radix >= 2
// Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value
@@ -57,7 +64,11 @@ object NumberConverter {
// if v < 0, which mean its bit presentation starts with 1, so v * radix will cause
// overflow since radix is greater than 2
if (v < 0) {
- return -1
+ if (ansiEnabled) {
+ throw QueryExecutionErrors.overflowInConvError(context)
+ } else {
+ return -1
+ }
}
// check if v greater than bound
// if v is greater than bound, v * radix + radix will cause overflow.
@@ -67,7 +78,11 @@ object NumberConverter {
// will start with 0) and we can easily checking for overflow by checking
// (-1 - value(i)) / radix < v or not
if (java.lang.Long.divideUnsigned(-1 - value(i), radix) < v) {
- return -1
+ if (ansiEnabled) {
+ throw QueryExecutionErrors.overflowInConvError(context)
+ } else {
+ return -1
+ }
}
}
v = v * radix + value(i)
@@ -114,7 +129,12 @@ object NumberConverter {
* unsigned, otherwise it is signed.
* NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv
*/
- def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = {
+ def convert(
+ n: Array[Byte],
+ fromBase: Int,
+ toBase: Int,
+ ansiEnabled: Boolean,
+ context: SQLQueryContext): UTF8String = {
if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
|| Math.abs(toBase) < Character.MIN_RADIX
|| Math.abs(toBase) > Character.MAX_RADIX) {
@@ -135,7 +155,7 @@ object NumberConverter {
char2byte(fromBase, temp.length - n.length + first, temp)
// Do the conversion by going through a 64 bit integer
- v = encode(fromBase, temp.length - n.length + first, temp)
+ v = encode(fromBase, temp.length - n.length + first, temp, ansiEnabled, context)
if (negative && toBase > 0) {
if (v < 0) {
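The overflow detection above hinges on one unsigned-arithmetic fact: -1L, viewed as unsigned, is 2^64 - 1, so the unsigned quotient (-1 - digit) / radix is the largest accumulator value that can still absorb one more `v * radix + digit` step. A standalone sketch of that check (illustrative names such as `UnsignedEncodeSketch` and `encodeUnsigned`, not the Spark implementation):
```
object UnsignedEncodeSketch {
  // Accumulate base-`radix` digits into an unsigned 64-bit value; return None
  // instead of silently wrapping, which is what ANSI mode turns into an error.
  def encodeUnsigned(radix: Int, digits: Array[Int]): Option[Long] = {
    var v = 0L
    var i = 0
    while (i < digits.length) {
      val d = digits(i)
      // -1L is 2^64 - 1 when read as unsigned, so this quotient is the largest
      // accumulator that still fits after one more `v * radix + d` step.
      val bound = java.lang.Long.divideUnsigned(-1L - d, radix.toLong)
      if (java.lang.Long.compareUnsigned(v, bound) > 0) {
        return None // the next step would overflow 64 bits
      }
      v = v * radix + d
      i += 1
    }
    Some(v)
  }

  def main(args: Array[String]): Unit = {
    val sixteenFs = Array.fill(16)(15)   // 0xFFFFFFFFFFFFFFFF fits exactly in 64 bits
    val seventeenFs = Array.fill(17)(15) // one more hex digit overflows
    println(encodeUnsigned(16, sixteenFs).map(java.lang.Long.toUnsignedString)) // Some(18446744073709551615)
    println(encodeUnsigned(16, seventeenFs))                                    // None
  }
}
```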
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index fd3809ccd31..2bafa2e2c03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -310,6 +310,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
}
+ def overflowInConvError(context: SQLQueryContext): ArithmeticException = {
+ arithmeticOverflowError("Overflow in function conv()", context = context)
+ }
+
def mapSizeExceedArraySizeWhenZipMapError(size: Int): SparkRuntimeException = {
new SparkRuntimeException(
errorClass = "_LEGACY_ERROR_TEMP_2003",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index be0375af094..437f7ddee01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -158,22 +158,45 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("conv") {
- checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
- checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
- checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
- checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
- checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
- checkEvaluation(
- Conv(Literal("1234"), Literal(10), Literal(37)), null)
- checkEvaluation(
- Conv(Literal(""), Literal(10), Literal(16)), null)
- checkEvaluation(
- Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
- // If there is an invalid digit in the number, the longest valid prefix should be converted.
- checkEvaluation(
- Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
+ Seq(true, false).foreach { ansiEnabled =>
+ checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2), ansiEnabled), "11")
+ checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16), ansiEnabled), "-F")
+ checkEvaluation(
+ Conv(Literal("-15"), Literal(10), Literal(16), ansiEnabled), "FFFFFFFFFFFFFFF1")
+ checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16), ansiEnabled), "3A48")
+ checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16), ansiEnabled),
+ null)
+ checkEvaluation(
+ Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16), ansiEnabled), null)
+ checkEvaluation(
+ Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType), ansiEnabled), null)
+ checkEvaluation(
+ Conv(Literal("1234"), Literal(10), Literal(37), ansiEnabled), null)
+ checkEvaluation(
+ Conv(Literal(""), Literal(10), Literal(16), ansiEnabled), null)
+
+ // If there is an invalid digit in the number, the longest valid prefix should be converted.
+ checkEvaluation(
+ Conv(Literal("11abc"), Literal(10), Literal(16), ansiEnabled), "B")
+ }
+ }
+
+ test("conv overflow") {
+ Seq(
+ ("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF"),
+ ("92233720368547758070", 10, 16, "FFFFFFFFFFFFFFFF"),
+ ("-92233720368547758070", 10, 16, "FFFFFFFFFFFFFFFF"),
+ ("100000000000000000000000000000000000000000000000000000000000000000", 2, 10,
+ "18446744073709551615"),
+ ("100000000000000000000000000000000000000000000000000000000000000000", 2, 8,
+ "1777777777777777777777")
+ ).foreach { case (numExpr, fromBase, toBase, expected) =>
+ checkEvaluation(
+ Conv(Literal(numExpr), Literal(fromBase), Literal(toBase), ansiEnabled = false), expected)
+ checkExceptionInExpression[SparkArithmeticException](
+ Conv(Literal(numExpr), Literal(fromBase), Literal(toBase), ansiEnabled = true),
+ "Overflow in function conv()")
+ }
}
test("e") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
index eb257b79756..c634c5b739b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
@@ -27,7 +27,18 @@ import org.apache.spark.unsafe.types.UTF8String
class NumberConverterSuite extends SparkFunSuite {
private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
- assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
+ Seq(true, false).foreach { ansiEnabled =>
+ checkConv(n, fromBase, toBase, expected, ansiEnabled)
+ }
+ }
+
+ private[this] def checkConv(
+ n: String,
+ fromBase: Int,
+ toBase: Int,
+ expected: String,
+ ansiEnabled: Boolean): Unit = {
+ assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase, ansiEnabled, null) ===
UTF8String.fromString(expected))
}
@@ -36,7 +47,7 @@ class NumberConverterSuite extends SparkFunSuite {
checkConv("-15", 10, -16, "-F")
checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
checkConv("big", 36, 16, "3A48")
- checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
+ checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF", ansiEnabled = false)
checkConv("11abc", 10, 16, "B")
}
diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql b/sql/core/src/test/resources/sql-tests/inputs/math.sql
index 46ee9fdb2d5..96fb0eeef7a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/math.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -68,4 +68,12 @@ SELECT bround(525L, -1);
SELECT bround(525L, -2);
SELECT bround(525L, -3);
SELECT bround(9223372036854775807L, -1);
-SELECT bround(-9223372036854775808L, -1);
\ No newline at end of file
+SELECT bround(-9223372036854775808L, -1);
+
+-- Conv
+SELECT conv('100', 2, 10);
+SELECT conv(-10, 16, -10);
+SELECT conv('9223372036854775808', 10, 16);
+SELECT conv('92233720368547758070', 10, 16);
+SELECT conv('9223372036854775807', 36, 10);
+SELECT conv('-9223372036854775807', 36, 10);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
index e24b23bbacd..8cd1536d7f7 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -701,3 +701,99 @@ org.apache.spark.SparkArithmeticException
"fragment" : "bround(-9223372036854775808L, -1)"
} ]
}
+
+
+-- !query
+SELECT conv('100', 2, 10)
+-- !query schema
+struct<conv(100, 2, 10):string>
+-- !query output
+4
+
+
+-- !query
+SELECT conv(-10, 16, -10)
+-- !query schema
+struct<conv(-10, 16, -10):string>
+-- !query output
+-16
+
+
+-- !query
+SELECT conv('9223372036854775808', 10, 16)
+-- !query schema
+struct<conv(9223372036854775808, 10, 16):string>
+-- !query output
+8000000000000000
+
+
+-- !query
+SELECT conv('92233720368547758070', 10, 16)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow in function conv()"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 43,
+ "fragment" : "conv('92233720368547758070', 10, 16)"
+ } ]
+}
+
+
+-- !query
+SELECT conv('9223372036854775807', 36, 10)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow in function conv()"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 42,
+ "fragment" : "conv('9223372036854775807', 36, 10)"
+ } ]
+}
+
+
+-- !query
+SELECT conv('-9223372036854775807', 36, 10)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "ARITHMETIC_OVERFLOW",
+ "sqlState" : "22003",
+ "messageParameters" : {
+ "alternative" : "",
+ "config" : "\"spark.sql.ansi.enabled\"",
+ "message" : "Overflow in function conv()"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 43,
+ "fragment" : "conv('-9223372036854775807', 36, 10)"
+ } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out b/sql/core/src/test/resources/sql-tests/results/math.sql.out
index 4721d66b6fc..d3df5cb9335 100644
--- a/sql/core/src/test/resources/sql-tests/results/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out
@@ -445,3 +445,51 @@ SELECT bround(-9223372036854775808L, -1)
struct<bround(-9223372036854775808, -1):bigint>
-- !query output
9223372036854775806
+
+
+-- !query
+SELECT conv('100', 2, 10)
+-- !query schema
+struct<conv(100, 2, 10):string>
+-- !query output
+4
+
+
+-- !query
+SELECT conv(-10, 16, -10)
+-- !query schema
+struct<conv(-10, 16, -10):string>
+-- !query output
+-16
+
+
+-- !query
+SELECT conv('9223372036854775808', 10, 16)
+-- !query schema
+struct<conv(9223372036854775808, 10, 16):string>
+-- !query output
+8000000000000000
+
+
+-- !query
+SELECT conv('92233720368547758070', 10, 16)
+-- !query schema
+struct<conv(92233720368547758070, 10, 16):string>
+-- !query output
+FFFFFFFFFFFFFFFF
+
+
+-- !query
+SELECT conv('9223372036854775807', 36, 10)
+-- !query schema
+struct<conv(9223372036854775807, 36, 10):string>
+-- !query output
+18446744073709551615
+
+
+-- !query
+SELECT conv('-9223372036854775807', 36, 10)
+-- !query schema
+struct<conv(-9223372036854775807, 36, 10):string>
+-- !query output
+18446744073709551615