Posted to commits@spark.apache.org by ge...@apache.org on 2023/02/15 05:57:30 UTC

[spark] branch master updated: [SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cb463fb40e8 [SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows
cb463fb40e8 is described below

commit cb463fb40e8f663b7e3019c8d8560a3490c241d0
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Tue Feb 14 21:57:13 2023 -0800

    [SPARK-42427][SQL] ANSI MODE: Conv should return an error if the internal conversion overflows
    
    ### What changes were proposed in this pull request?
    
    In ANSI SQL mode, the function `conv()` should return an error if the internal conversion overflows.
    For example, before the change:
    ```
    > select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
    18446744073709551615
    ```
    After the change:
    ```
    > select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
    org.apache.spark.SparkArithmeticException: [ARITHMETIC_OVERFLOW] Overflow in function conv(). If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    == SQL(line 1, position 8) ==
    select conv('fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', 16, 10)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ```
    
    ### Why are the changes needed?
    
    Consistent with other SQL functions, this PR surfaces the overflow errors of `conv()` to users under ANSI SQL mode, instead of silently returning an unexpected number.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the `conv()` function will return an error if the internal conversion overflows.
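    
    As the error message itself points out, setting `spark.sql.ansi.enabled` to `false` restores the previous wrap-around behavior. A quick illustration (output taken from the non-ANSI golden file in this patch):
    ```
    > SET spark.sql.ansi.enabled=false;
    > SELECT conv('92233720368547758070', 10, 16);
    FFFFFFFFFFFFFFFF
    ```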
    
    ### How was this patch tested?
    
    Unit tests, plus new SQL golden-file tests in `math.sql`.
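    
    For anyone reproducing locally, one way to run the touched suites from a Spark dev checkout (standard sbt invocations; suite names taken from the diff below):
    ```
    build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.expressions.MathExpressionsSuite"
    build/sbt "catalyst/testOnly org.apache.spark.sql.catalyst.util.NumberConverterSuite"
    build/sbt "sql/testOnly org.apache.spark.sql.SQLQueryTestSuite -- -z math.sql"
    ```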
    
    Closes #40001 from gengliangwang/fixConv.
    
    Authored-by: Gengliang Wang <ge...@apache.org>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../sql/catalyst/expressions/mathExpressions.scala | 27 +++++-
 .../spark/sql/catalyst/util/NumberConverter.scala  | 30 +++++--
 .../spark/sql/errors/QueryExecutionErrors.scala    |  4 +
 .../expressions/MathExpressionsSuite.scala         | 55 +++++++++----
 .../sql/catalyst/util/NumberConverterSuite.scala   | 15 +++-
 .../src/test/resources/sql-tests/inputs/math.sql   | 10 ++-
 .../resources/sql-tests/results/ansi/math.sql.out  | 96 ++++++++++++++++++++++
 .../test/resources/sql-tests/results/math.sql.out  | 48 +++++++++++
 8 files changed, 257 insertions(+), 28 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index f6bef9c6cc2..dcc821a24ea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -434,8 +434,18 @@ case class Acosh(child: Expression)
   """,
   since = "1.5.0",
   group = "math_funcs")
-case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
+case class Conv(
+    numExpr: Expression,
+    fromBaseExpr: Expression,
+    toBaseExpr: Expression,
+    ansiEnabled: Boolean = SQLConf.get.ansiEnabled)
+  extends TernaryExpression
+    with ImplicitCastInputTypes
+    with NullIntolerant
+    with SupportQueryContext {
+
+  def this(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) =
+    this(numExpr, fromBaseExpr, toBaseExpr, ansiEnabled = SQLConf.get.ansiEnabled)
 
   override def first: Expression = numExpr
   override def second: Expression = fromBaseExpr
@@ -448,14 +458,17 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
     NumberConverter.convert(
       num.asInstanceOf[UTF8String].trim().getBytes,
       fromBase.asInstanceOf[Int],
-      toBase.asInstanceOf[Int])
+      toBase.asInstanceOf[Int],
+      ansiEnabled,
+      getContextOrNull())
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val numconv = NumberConverter.getClass.getName.stripSuffix("$")
+    val context = getContextOrNullCode(ctx, ansiEnabled)
     nullSafeCodeGen(ctx, ev, (num, from, to) =>
       s"""
-       ${ev.value} = $numconv.convert($num.trim().getBytes(), $from, $to);
+       ${ev.value} = $numconv.convert($num.trim().getBytes(), $from, $to, $ansiEnabled, $context);
        if (${ev.value} == null) {
          ${ev.isNull} = true;
        }
@@ -466,6 +479,12 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre
   override protected def withNewChildrenInternal(
       newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
     copy(numExpr = newFirst, fromBaseExpr = newSecond, toBaseExpr = newThird)
+
+  override def initQueryContext(): Option[SQLQueryContext] = if (ansiEnabled) {
+    Some(origin.context)
+  } else {
+    None
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
index 68a1ba25423..59765cde1f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import org.apache.spark.sql.catalyst.trees.SQLQueryContext
+import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.unsafe.types.UTF8String
 
 object NumberConverter {
@@ -47,7 +49,12 @@ object NumberConverter {
    * @param fromPos is the first element that should be considered
    * @return the result should be treated as an unsigned 64-bit integer.
    */
-  private def encode(radix: Int, fromPos: Int, value: Array[Byte]): Long = {
+  private def encode(
+      radix: Int,
+      fromPos: Int,
+      value: Array[Byte],
+      ansiEnabled: Boolean,
+      context: SQLQueryContext): Long = {
     var v: Long = 0L
     // bound will always be positive since radix >= 2
     // Note that: -1 is equivalent to 11111111...1111 which is the largest unsigned long value
@@ -57,7 +64,11 @@ object NumberConverter {
       // if v < 0, which means its bit representation starts with 1, so v * radix will cause
       // overflow since radix is at least 2
       if (v < 0) {
-        return -1
+        if (ansiEnabled) {
+          throw QueryExecutionErrors.overflowInConvError(context)
+        } else {
+          return -1
+        }
       }
       // check if v greater than bound
       // if v is greater than bound, v * radix + radix will cause overflow.
@@ -67,7 +78,11 @@ object NumberConverter {
         // will start with 0) and we can easily check for overflow by checking
         // (-1 - value(i)) / radix < v or not
         if (java.lang.Long.divideUnsigned(-1 - value(i), radix) < v) {
-          return -1
+          if (ansiEnabled) {
+            throw QueryExecutionErrors.overflowInConvError(context)
+          } else {
+            return -1
+          }
         }
       }
       v = v * radix + value(i)
@@ -114,7 +129,12 @@ object NumberConverter {
    * unsigned, otherwise it is signed.
    * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
    */
-  def convert(n: Array[Byte], fromBase: Int, toBase: Int ): UTF8String = {
+  def convert(
+      n: Array[Byte],
+      fromBase: Int,
+      toBase: Int,
+      ansiEnabled: Boolean,
+      context: SQLQueryContext): UTF8String = {
     if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
         || Math.abs(toBase) < Character.MIN_RADIX
         || Math.abs(toBase) > Character.MAX_RADIX) {
@@ -135,7 +155,7 @@ object NumberConverter {
     char2byte(fromBase, temp.length - n.length + first, temp)
 
     // Do the conversion by going through a 64 bit integer
-    v = encode(fromBase, temp.length - n.length + first, temp)
+    v = encode(fromBase, temp.length - n.length + first, temp, ansiEnabled, context)
 
     if (negative && toBase > 0) {
       if (v < 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index fd3809ccd31..2bafa2e2c03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -310,6 +310,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
     arithmeticOverflowError("Overflow in integral divide", "try_divide", context)
   }
 
+  def overflowInConvError(context: SQLQueryContext): ArithmeticException = {
+    arithmeticOverflowError("Overflow in function conv()", context = context)
+  }
+
   def mapSizeExceedArraySizeWhenZipMapError(size: Int): SparkRuntimeException = {
     new SparkRuntimeException(
       errorClass = "_LEGACY_ERROR_TEMP_2003",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index be0375af094..437f7ddee01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -158,22 +158,45 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
   }
 
   test("conv") {
-    checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2)), "11")
-    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F")
-    checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1")
-    checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48")
-    checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null)
-    checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null)
-    checkEvaluation(Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType)), null)
-    checkEvaluation(
-      Conv(Literal("1234"), Literal(10), Literal(37)), null)
-    checkEvaluation(
-      Conv(Literal(""), Literal(10), Literal(16)), null)
-    checkEvaluation(
-      Conv(Literal("9223372036854775807"), Literal(36), Literal(16)), "FFFFFFFFFFFFFFFF")
-    // If there is an invalid digit in the number, the longest valid prefix should be converted.
-    checkEvaluation(
-      Conv(Literal("11abc"), Literal(10), Literal(16)), "B")
+    Seq(true, false).foreach { ansiEnabled =>
+      checkEvaluation(Conv(Literal("3"), Literal(10), Literal(2), ansiEnabled), "11")
+      checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16), ansiEnabled), "-F")
+      checkEvaluation(
+        Conv(Literal("-15"), Literal(10), Literal(16), ansiEnabled), "FFFFFFFFFFFFFFF1")
+      checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16), ansiEnabled), "3A48")
+      checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16), ansiEnabled),
+        null)
+      checkEvaluation(
+        Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16), ansiEnabled), null)
+      checkEvaluation(
+        Conv(Literal("3"), Literal(16), Literal.create(null, IntegerType), ansiEnabled), null)
+      checkEvaluation(
+        Conv(Literal("1234"), Literal(10), Literal(37), ansiEnabled), null)
+      checkEvaluation(
+        Conv(Literal(""), Literal(10), Literal(16), ansiEnabled), null)
+
+      // If there is an invalid digit in the number, the longest valid prefix should be converted.
+      checkEvaluation(
+        Conv(Literal("11abc"), Literal(10), Literal(16), ansiEnabled), "B")
+    }
+  }
+
+  test("conv overflow") {
+    Seq(
+      ("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF"),
+      ("92233720368547758070", 10, 16, "FFFFFFFFFFFFFFFF"),
+      ("-92233720368547758070", 10, 16, "FFFFFFFFFFFFFFFF"),
+      ("100000000000000000000000000000000000000000000000000000000000000000", 2, 10,
+        "18446744073709551615"),
+      ("100000000000000000000000000000000000000000000000000000000000000000", 2, 8,
+        "1777777777777777777777")
+    ).foreach { case (numExpr, fromBase, toBase, expected) =>
+      checkEvaluation(
+       Conv(Literal(numExpr), Literal(fromBase), Literal(toBase), ansiEnabled = false), expected)
+      checkExceptionInExpression[SparkArithmeticException](
+        Conv(Literal(numExpr), Literal(fromBase), Literal(toBase), ansiEnabled = true),
+        "Overflow in function conv()")
+    }
   }
 
   test("e") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
index eb257b79756..c634c5b739b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala
@@ -27,7 +27,18 @@ import org.apache.spark.unsafe.types.UTF8String
 class NumberConverterSuite extends SparkFunSuite {
 
   private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = {
-    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) ===
+    Seq(true, false).foreach { ansiEnabled =>
+      checkConv(n, fromBase, toBase, expected, ansiEnabled)
+    }
+  }
+
+  private[this] def checkConv(
+      n: String,
+      fromBase: Int,
+      toBase: Int,
+      expected: String,
+      ansiEnabled: Boolean): Unit = {
+    assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase, ansiEnabled, null) ===
       UTF8String.fromString(expected))
   }
 
@@ -36,7 +47,7 @@ class NumberConverterSuite extends SparkFunSuite {
     checkConv("-15", 10, -16, "-F")
     checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1")
     checkConv("big", 36, 16, "3A48")
-    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF")
+    checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF", ansiEnabled = false)
     checkConv("11abc", 10, 16, "B")
   }
 
diff --git a/sql/core/src/test/resources/sql-tests/inputs/math.sql b/sql/core/src/test/resources/sql-tests/inputs/math.sql
index 46ee9fdb2d5..96fb0eeef7a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/math.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/math.sql
@@ -68,4 +68,12 @@ SELECT bround(525L, -1);
 SELECT bround(525L, -2);
 SELECT bround(525L, -3);
 SELECT bround(9223372036854775807L, -1);
-SELECT bround(-9223372036854775808L, -1);
\ No newline at end of file
+SELECT bround(-9223372036854775808L, -1);
+
+-- Conv
+SELECT conv('100', 2, 10);
+SELECT conv(-10, 16, -10);
+SELECT conv('9223372036854775808', 10, 16);
+SELECT conv('92233720368547758070', 10, 16);
+SELECT conv('9223372036854775807', 36, 10);
+SELECT conv('-9223372036854775807', 36, 10);
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
index e24b23bbacd..8cd1536d7f7 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/math.sql.out
@@ -701,3 +701,99 @@ org.apache.spark.SparkArithmeticException
     "fragment" : "bround(-9223372036854775808L, -1)"
   } ]
 }
+
+
+-- !query
+SELECT conv('100', 2, 10)
+-- !query schema
+struct<conv(100, 2, 10):string>
+-- !query output
+4
+
+
+-- !query
+SELECT conv(-10, 16, -10)
+-- !query schema
+struct<conv(-10, 16, -10):string>
+-- !query output
+-16
+
+
+-- !query
+SELECT conv('9223372036854775808', 10, 16)
+-- !query schema
+struct<conv(9223372036854775808, 10, 16):string>
+-- !query output
+8000000000000000
+
+
+-- !query
+SELECT conv('92233720368547758070', 10, 16)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow in function conv()"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 43,
+    "fragment" : "conv('92233720368547758070', 10, 16)"
+  } ]
+}
+
+
+-- !query
+SELECT conv('9223372036854775807', 36, 10)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow in function conv()"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 42,
+    "fragment" : "conv('9223372036854775807', 36, 10)"
+  } ]
+}
+
+
+-- !query
+SELECT conv('-9223372036854775807', 36, 10)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkArithmeticException
+{
+  "errorClass" : "ARITHMETIC_OVERFLOW",
+  "sqlState" : "22003",
+  "messageParameters" : {
+    "alternative" : "",
+    "config" : "\"spark.sql.ansi.enabled\"",
+    "message" : "Overflow in function conv()"
+  },
+  "queryContext" : [ {
+    "objectType" : "",
+    "objectName" : "",
+    "startIndex" : 8,
+    "stopIndex" : 43,
+    "fragment" : "conv('-9223372036854775807', 36, 10)"
+  } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/math.sql.out b/sql/core/src/test/resources/sql-tests/results/math.sql.out
index 4721d66b6fc..d3df5cb9335 100644
--- a/sql/core/src/test/resources/sql-tests/results/math.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/math.sql.out
@@ -445,3 +445,51 @@ SELECT bround(-9223372036854775808L, -1)
 struct<bround(-9223372036854775808, -1):bigint>
 -- !query output
 9223372036854775806
+
+
+-- !query
+SELECT conv('100', 2, 10)
+-- !query schema
+struct<conv(100, 2, 10):string>
+-- !query output
+4
+
+
+-- !query
+SELECT conv(-10, 16, -10)
+-- !query schema
+struct<conv(-10, 16, -10):string>
+-- !query output
+-16
+
+
+-- !query
+SELECT conv('9223372036854775808', 10, 16)
+-- !query schema
+struct<conv(9223372036854775808, 10, 16):string>
+-- !query output
+8000000000000000
+
+
+-- !query
+SELECT conv('92233720368547758070', 10, 16)
+-- !query schema
+struct<conv(92233720368547758070, 10, 16):string>
+-- !query output
+FFFFFFFFFFFFFFFF
+
+
+-- !query
+SELECT conv('9223372036854775807', 36, 10)
+-- !query schema
+struct<conv(9223372036854775807, 36, 10):string>
+-- !query output
+18446744073709551615
+
+
+-- !query
+SELECT conv('-9223372036854775807', 36, 10)
+-- !query schema
+struct<conv(-9223372036854775807, 36, 10):string>
+-- !query output
+18446744073709551615
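
For reference, the overflow detection added to `NumberConverter.encode` in this patch reduces to a single unsigned 64-bit invariant: `v * radix + digit` stays representable only while `v <= (2^64 - 1 - digit) / radix`, with both sides compared as unsigned values. The following is a minimal standalone Scala sketch of that invariant (the name `encodeUnsigned` is illustrative, not Spark API); it folds the patch's two checks into one unsigned comparison and returns `None` where the ANSI path throws:

```
import java.lang.Long.{compareUnsigned, divideUnsigned, toUnsignedString}

object EncodeOverflowSketch {
  // Folds base-`radix` digits (each 0 <= d < radix) into a Long treated as an
  // unsigned 64-bit integer; returns None instead of throwing on overflow.
  def encodeUnsigned(digits: Seq[Int], radix: Int): Option[Long] = {
    require(radix >= 2 && radix <= 36, "radix out of supported range")
    var v = 0L
    val it = digits.iterator
    while (it.hasNext) {
      val d = it.next()
      // -1L is 0xFFFFFFFFFFFFFFFF, i.e. 2^64 - 1 read as unsigned, so the
      // largest v that can still absorb digit d is (2^64 - 1 - d) / radix.
      if (compareUnsigned(v, divideUnsigned(-1L - d, radix)) > 0) {
        return None // v * radix + d would exceed 2^64 - 1
      }
      v = v * radix + d
    }
    Some(v)
  }

  def main(args: Array[String]): Unit = {
    val sixteenFs = Seq.fill(16)(15)   // "ffffffffffffffff" base 16: exactly 2^64 - 1
    val seventeenFs = Seq.fill(17)(15) // one digit more: must overflow
    println(encodeUnsigned(sixteenFs, 16).map(v => toUnsignedString(v))) // Some(18446744073709551615)
    println(encodeUnsigned(seventeenFs, 16))                             // None
  }
}
```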

