Posted to commits@spark.apache.org by ge...@apache.org on 2022/07/27 03:27:07 UTC

[spark] branch master updated: [SPARK-39865][SQL] Show proper error messages on the overflow errors of table insert

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d5dbe7d4e9e [SPARK-39865][SQL] Show proper error messages on the overflow errors of table insert
d5dbe7d4e9e is described below

commit d5dbe7d4e9e0e46a514c363efaac15f37d07857c
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Tue Jul 26 20:26:53 2022 -0700

    [SPARK-39865][SQL] Show proper error messages on the overflow errors of table insert
    
    ### What changes were proposed in this pull request?
    
    In Spark 3.3, the error message of ANSI CAST was improved. However, table insertion uses the same CAST expression:
    ```
    > create table tiny(i tinyint);
    > insert into tiny values (1000);
    
    org.apache.spark.SparkArithmeticException[CAST_OVERFLOW]: The value 1000 of the type "INT" cannot be cast to "TINYINT" due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
    ```
    
    Showing the hint `If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error` doesn't help at all: under the ANSI store assignment policy, table insertion builds an ANSI cast regardless of that flag. This PR fixes the error message. After the changes, the error message for this example becomes:
    ```
    org.apache.spark.SparkArithmeticException: [CAST_OVERFLOW_IN_TABLE_INSERT] Fail to insert a value of "INT" type into the "TINYINT" type column `i` due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead.
    ```
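
    For reference, a minimal sketch of the recommended `try_cast` workaround (spark-shell; `USING parquet` and the inserted value are illustrative):
    ```
    // try_cast returns NULL on overflow, so the insert succeeds with a NULL row.
    spark.sql("CREATE TABLE tiny(i TINYINT) USING parquet")
    spark.sql("INSERT INTO tiny SELECT try_cast(1000 AS TINYINT)")
    spark.sql("SELECT i FROM tiny").show()  // i: null
    ```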
    ### Why are the changes needed?
    
    Show proper error messages on overflow errors during table insertion. The current message is confusing.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. After the changes, Spark shows proper error messages on overflow errors during table insertion.
    
    ### How was this patch tested?
    
    Unit tests
    
    Closes #37283 from gengliangwang/insertionOverflow.
    
    Authored-by: Gengliang Wang <ge...@apache.org>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 core/src/main/resources/error/error-classes.json   |  6 +++
 .../catalyst/analysis/TableOutputResolver.scala    | 23 ++++++++++-
 .../spark/sql/catalyst/expressions/Cast.scala      | 44 ++++++++++++++++++++++
 .../spark/sql/errors/QueryExecutionErrors.scala    | 15 ++++++++
 .../sql/errors/QueryExecutionAnsiErrorsSuite.scala | 21 ++++++++++-
 .../org/apache/spark/sql/sources/InsertSuite.scala | 20 +++++-----
 6 files changed, 117 insertions(+), 12 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 29ca280719e..9d35b1a1a69 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -59,6 +59,12 @@
     ],
     "sqlState" : "22005"
   },
+  "CAST_OVERFLOW_IN_TABLE_INSERT" : {
+    "message" : [
+      "Fail to insert a value of <sourceType> type into the <targetType> type column <columnName> due to an overflow. Use `try_cast` on the input value to tolerate overflow and return NULL instead."
+    ],
+    "sqlState" : "22005"
+  },
   "CONCURRENT_QUERY" : {
     "message" : [
       "Another instance of this query was just started by a concurrent session."
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index aca99b001d2..b9e3c380216 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
-import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegralType, MapType, StructType}
 
 object TableOutputResolver {
   def resolveOutputColumns(
@@ -220,6 +220,21 @@ object TableOutputResolver {
     }
   }
 
+  private def containsIntegralOrDecimalType(dt: DataType): Boolean = dt match {
+    case _: IntegralType | _: DecimalType => true
+    case a: ArrayType => containsIntegralOrDecimalType(a.elementType)
+    case m: MapType =>
+      containsIntegralOrDecimalType(m.keyType) || containsIntegralOrDecimalType(m.valueType)
+    case s: StructType =>
+      s.fields.exists(sf => containsIntegralOrDecimalType(sf.dataType))
+    case _ => false
+  }
+
+  private def canCauseCastOverflow(cast: Cast): Boolean = {
+    containsIntegralOrDecimalType(cast.dataType) &&
+      !Cast.canUpCast(cast.child.dataType, cast.dataType)
+  }
+
   private def checkField(
       tableAttr: Attribute,
       queryExpr: NamedExpression,
@@ -238,7 +253,11 @@ object TableOutputResolver {
           val cast = Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
             ansiEnabled = true)
           cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-          cast
+          if (canCauseCastOverflow(cast)) {
+            CheckOverflowInTableInsert(cast, tableAttr.name)
+          } else {
+            cast
+          }
         case StoreAssignmentPolicy.LEGACY =>
           Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
             ansiEnabled = false)
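
For intuition: `Cast.canUpCast` returns true only for lossless widenings, so the wrapper is added only where an overflow is actually possible. A minimal spark-shell sketch of the gating (type pairs chosen for illustration):

```
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types.{IntegerType, LongType}

// Lossless widening: overflow impossible, so checkField keeps the plain Cast.
Cast.canUpCast(IntegerType, LongType)  // true

// Narrowing: may overflow, so the Cast is wrapped in
// CheckOverflowInTableInsert(cast, columnName).
Cast.canUpCast(LongType, IntegerType)  // false
```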
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 82de2a0de14..0ba651b5650 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -21,6 +21,7 @@ import java.time.{ZoneId, ZoneOffset}
 import java.util.Locale
 import java.util.concurrent.TimeUnit._
 
+import org.apache.spark.SparkArithmeticException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -2360,3 +2361,46 @@ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: S
 
   override protected def withNewChildInternal(newChild: Expression): UpCast = copy(child = newChild)
 }
+
+/**
+ * Casting a numeric value as another numeric type in store assignment. It can capture the
+ * arithmetic errors and show proper error messages to users.
+ */
+case class CheckOverflowInTableInsert(child: Cast, columnName: String) extends UnaryExpression {
+  override protected def withNewChildInternal(newChild: Expression): Expression =
+    copy(child = newChild.asInstanceOf[Cast])
+
+  override def eval(input: InternalRow): Any = try {
+    child.eval(input)
+  } catch {
+    case e: SparkArithmeticException =>
+      throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
+        child.child.dataType,
+        child.dataType,
+        columnName)
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val exceptionClass = classOf[SparkArithmeticException].getCanonicalName
+    val fromDt =
+      ctx.addReferenceObj("from", child.child.dataType, child.child.dataType.getClass.getName)
+    val toDt = ctx.addReferenceObj("to", child.dataType, child.dataType.getClass.getName)
+    val col = ctx.addReferenceObj("colName", columnName, "java.lang.String")
+    // scalastyle:off line.size.limit
+    ev.copy(code = code"""
+      boolean ${ev.isNull} = true;
+      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+      try {
+        ${childGen.code}
+        ${ev.isNull} = ${childGen.isNull};
+        ${ev.value} = ${childGen.value};
+      } catch ($exceptionClass e) {
+        throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert($fromDt, $toDt, $col);
+      }"""
+    )
+    // scalastyle:on line.size.limit
+  }
+
+  override def dataType: DataType = child.dataType
+}
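
A minimal spark-shell sketch of the new expression's runtime behavior (the literal value and column name are illustrative):

```
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Cast, CheckOverflowInTableInsert, Literal}
import org.apache.spark.sql.types.ByteType

// An ANSI cast of INT 1000 to TINYINT raises CAST_OVERFLOW...
val cast = Cast(Literal(1000), ByteType, ansiEnabled = true)

// ...which the wrapper catches and rethrows as the table-insert-specific error.
CheckOverflowInTableInsert(cast, "i").eval(InternalRow.empty)
// org.apache.spark.SparkArithmeticException: [CAST_OVERFLOW_IN_TABLE_INSERT]
//   Fail to insert a value of "INT" type into the "TINYINT" type column `i` ...
```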
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index e0b08df940d..80918a9d8ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -93,6 +93,21 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase {
       summary = "")
   }
 
+  def castingCauseOverflowErrorInTableInsert(
+      from: DataType,
+      to: DataType,
+      columnName: String): ArithmeticException = {
+    new SparkArithmeticException(
+      errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT",
+      messageParameters = Array(
+        toSQLType(from),
+        toSQLType(to),
+        toSQLId(columnName)),
+      context = None,
+      summary = ""
+    )
+  }
+
   def cannotChangeDecimalPrecisionError(
       value: Decimal,
       decimalPrecision: Int,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
index 31ee6fcde94..8d7359e449d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionAnsiErrorsSuite.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.errors
 
-import org.apache.spark.{SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkConf, SparkDateTimeException, SparkNoSuchElementException, SparkNumberFormatException}
+import org.apache.spark._
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.internal.SQLConf
 
@@ -150,4 +150,23 @@ class QueryExecutionAnsiErrorsSuite extends QueryTest with QueryErrorsSuiteBase
         "ansiConfig" -> ansiConf)
     )
   }
+
+  test("CAST_OVERFLOW_IN_TABLE_INSERT: overflow during table insertion") {
+    Seq("TINYINT", "SMALLINT", "INT", "BIGINT", "DECIMAL(7,2)").foreach { targetType =>
+      val tableName = "overflowTable"
+      withTable(tableName) {
+        sql(s"CREATE TABLE $tableName(i $targetType) USING parquet")
+        checkError(
+          exception = intercept[SparkException] {
+            sql(s"insert into $tableName values 12345678901234567890D")
+          }.getCause.getCause.getCause.asInstanceOf[SparkThrowable],
+          errorClass = "CAST_OVERFLOW_IN_TABLE_INSERT",
+          parameters = Map(
+            "sourceType" -> "\"DOUBLE\"",
+            "targetType" -> ("\"" + targetType + "\""),
+            "columnName" -> "`i`")
+        )
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 725141eeeeb..7497aa66fa6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -717,18 +717,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       withTable("t") {
         sql("create table t(b int) using parquet")
         val outOfRangeValue1 = (Int.MaxValue + 1L).toString
+        val expectedMsg = "Fail to insert a value of \"BIGINT\" type into the \"INT\" type column" +
+          " `b` due to an overflow."
         var msg = intercept[SparkException] {
           sql(s"insert into t values($outOfRangeValue1)")
         }.getCause.getMessage
-        assert(msg.contains(
-          s"""The value ${outOfRangeValue1}L of the type "BIGINT" cannot be cast to "INT""""))
+        assert(msg.contains(expectedMsg))
 
         val outOfRangeValue2 = (Int.MinValue - 1L).toString
         msg = intercept[SparkException] {
           sql(s"insert into t values($outOfRangeValue2)")
         }.getCause.getMessage
-        assert(msg.contains(
-          s"""The value ${outOfRangeValue2}L of the type "BIGINT" cannot be cast to "INT""""))
+        assert(msg.contains(expectedMsg))
       }
     }
   }
@@ -739,18 +739,18 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       withTable("t") {
         sql("create table t(b long) using parquet")
         val outOfRangeValue1 = Math.nextUp(Long.MaxValue)
+        val expectedMsg = "Fail to insert a value of \"DOUBLE\" type into the \"BIGINT\" type " +
+          "column `b` due to an overflow."
         var msg = intercept[SparkException] {
           sql(s"insert into t values(${outOfRangeValue1}D)")
         }.getCause.getMessage
-        assert(msg.contains(
-          s"""The value ${outOfRangeValue1}D of the type "DOUBLE" cannot be cast to "BIGINT""""))
+        assert(msg.contains(expectedMsg))
 
         val outOfRangeValue2 = Math.nextDown(Long.MinValue)
         msg = intercept[SparkException] {
           sql(s"insert into t values(${outOfRangeValue2}D)")
         }.getCause.getMessage
-        assert(msg.contains(
-          s"""The value ${outOfRangeValue2}D of the type "DOUBLE" cannot be cast to "BIGINT""""))
+        assert(msg.contains(expectedMsg))
       }
     }
   }
@@ -761,10 +761,12 @@ class InsertSuite extends DataSourceTest with SharedSparkSession {
       withTable("t") {
         sql("create table t(b decimal(3,2)) using parquet")
         val outOfRangeValue = "123.45"
+        val expectedMsg = "Fail to insert a value of \"DECIMAL(5,2)\" type into the " +
+          "\"DECIMAL(3,2)\" type column `b` due to an overflow."
         val msg = intercept[SparkException] {
           sql(s"insert into t values(${outOfRangeValue})")
         }.getCause.getMessage
-        assert(msg.contains("cannot be represented as Decimal(3, 2)"))
+        assert(msg.contains(expectedMsg))
       }
     }
   }

