Posted to commits@spark.apache.org by we...@apache.org on 2023/03/01 07:50:59 UTC

[spark] branch branch-3.4 updated: [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 09abea00cb0 [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
09abea00cb0 is described below

commit 09abea00cb0d67336413ac8892617ca824429042
Author: aokolnychyi <ao...@apple.com>
AuthorDate: Wed Mar 1 15:50:10 2023 +0800

    [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
    
    ### What changes were proposed in this pull request?
    
    This PR adds char/varchar length checks for inner fields during resolution when struct fields are reordered.
    
    ### Why are the changes needed?
    
    These checks are needed to handle nested char/varchar columns correctly. Prior to this change, we would lose the raw type information when constructing nested attributes and, as a result, would not insert the proper char/varchar length checks.
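    
    For illustration, a minimal sketch of the affected scenario, mirroring the tests added by this PR (the table name `t` and the Parquet format are placeholder choices): the struct literal lists its fields in a different order than the table schema, so the inner fields must be reordered during resolution, and the length check on the inner char field must survive that reordering.
    
    ```scala
    // Nested CHAR(5) field inside a struct column.
    spark.sql("CREATE TABLE t(s STRUCT<n_c: CHAR(5), n_i: INT>) USING parquet")
    
    // The struct literal lists n_i before n_c, so the inner fields are
    // reordered to match the table schema during resolution.
    val inputDF = spark.sql("SELECT named_struct('n_i', 1, 'n_c', '123456') AS s")
    
    // With this change the write fails with
    // "Exceeds char/varchar type length limitation: 5"; previously the
    // inner length check was not inserted and the oversized value slipped through.
    inputDF.writeTo("t").append()
    ```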
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    This PR comes with tests that would previously fail.
    
    Closes #40206 from aokolnychyi/spark-42611.
    
    Authored-by: aokolnychyi <ao...@apple.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit d7d8af0dbb47e152b280226a7afcf0771b5a5ae8)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/analysis/TableOutputResolver.scala    | 62 ++++++++++++++--------
 .../apache/spark/sql/CharVarcharTestSuite.scala    | 52 ++++++++++++++++++
 2 files changed, 93 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 61d24964d60..e1ee0defa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -36,20 +36,25 @@ object TableOutputResolver {
       byName: Boolean,
       conf: SQLConf): LogicalPlan = {
 
-    if (expected.size < query.output.size) {
-      throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName, expected, query)
+    val actualExpectedCols = expected.map { attr =>
+      attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
+    }
+
+    if (actualExpectedCols.size < query.output.size) {
+      throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
+        tableName, actualExpectedCols, query)
     }
 
     val errors = new mutable.ArrayBuffer[String]()
     val resolved: Seq[NamedExpression] = if (byName) {
-      reorderColumnsByName(query.output, expected, conf, errors += _)
+      reorderColumnsByName(query.output, actualExpectedCols, conf, errors += _)
     } else {
-      if (expected.size > query.output.size) {
+      if (actualExpectedCols.size > query.output.size) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
-          tableName, expected, query)
+          tableName, actualExpectedCols, query)
       }
 
-      query.output.zip(expected).flatMap {
+      query.output.zip(actualExpectedCols).flatMap {
         case (queryExpr, tableAttr) =>
           checkField(tableAttr, queryExpr, byName, conf, err => errors += err, Seq(tableAttr.name))
       }
@@ -254,28 +259,23 @@ object TableOutputResolver {
       addError: String => Unit,
       colPath: Seq[String]): Option[NamedExpression] = {
 
+    val attrTypeHasCharVarchar = CharVarcharUtils.hasCharVarchar(tableAttr.dataType)
+    val attrTypeWithoutCharVarchar = if (attrTypeHasCharVarchar) {
+      CharVarcharUtils.replaceCharVarcharWithString(tableAttr.dataType)
+    } else {
+      tableAttr.dataType
+    }
     val storeAssignmentPolicy = conf.storeAssignmentPolicy
     lazy val outputField = if (tableAttr.dataType.sameType(queryExpr.dataType) &&
       tableAttr.name == queryExpr.name &&
       tableAttr.metadata == queryExpr.metadata) {
       Some(queryExpr)
     } else {
-      val casted = storeAssignmentPolicy match {
-        case StoreAssignmentPolicy.ANSI =>
-          val cast = Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
-            ansiEnabled = true)
-          cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
-          checkCastOverflowInTableInsert(cast, colPath.quoted)
-        case StoreAssignmentPolicy.LEGACY =>
-          Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
-            ansiEnabled = false)
-        case _ =>
-          Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone))
-      }
-      val exprWithStrLenCheck = if (conf.charVarcharAsString) {
+      val casted = cast(queryExpr, attrTypeWithoutCharVarchar, conf, colPath.quoted)
+      val exprWithStrLenCheck = if (conf.charVarcharAsString || !attrTypeHasCharVarchar) {
         casted
       } else {
-        CharVarcharUtils.stringLengthCheck(casted, tableAttr)
+        CharVarcharUtils.stringLengthCheck(casted, tableAttr.dataType)
       }
       // Renaming is needed for handling the following cases like
       // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
@@ -290,7 +290,7 @@ object TableOutputResolver {
       case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
         // run the type check first to ensure type errors are present
         val canWrite = DataType.canWrite(
-          queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, colPath.quoted,
+          queryExpr.dataType, attrTypeWithoutCharVarchar, byName, conf.resolver, colPath.quoted,
           storeAssignmentPolicy, addError)
         if (queryExpr.nullable && !tableAttr.nullable) {
           addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'")
@@ -304,4 +304,24 @@ object TableOutputResolver {
         }
     }
   }
+
+  private def cast(
+      expr: Expression,
+      expectedType: DataType,
+      conf: SQLConf,
+      colName: String): Expression = {
+
+    conf.storeAssignmentPolicy match {
+      case StoreAssignmentPolicy.ANSI =>
+        val cast = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true)
+        cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+        checkCastOverflowInTableInsert(cast, colName)
+
+      case StoreAssignmentPolicy.LEGACY =>
+        Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = false)
+
+      case _ =>
+        Cast(expr, expectedType, Option(conf.sessionLocalTimeZone))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index c0ceebaa9a6..a6c310cd925 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -926,4 +926,56 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
       }
     }
   }
+
+  test("SPARK-42611: check char/varchar length in reordered nested structs") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(s STRUCT<n_c: $typ, n_i: INT>) USING $format")
+
+        val inputDF = sql("SELECT named_struct('n_i', 1, 'n_c', '123456') AS s")
+
+        val e = intercept[RuntimeException](inputDF.writeTo("t").append())
+        assert(e.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within arrays") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(a ARRAY<STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+        val inputDF = sql("SELECT array(named_struct('n_i', 1, 'n_c', '123456')) AS a")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within map keys") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(m MAP<STRUCT<n_c: $typ, n_i: INT>, INT>) USING $format")
+
+        val inputDF = sql("SELECT map(named_struct('n_i', 1, 'n_c', '123456'), 1) AS m")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
+
+  test("SPARK-42611: check char/varchar length in reordered structs within map values") {
+    Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+      withTable("t") {
+        sql(s"CREATE TABLE t(m MAP<INT, STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+        val inputDF = sql("SELECT map(1, named_struct('n_i', 1, 'n_c', '123456')) AS m")
+
+        val e = intercept[SparkException](inputDF.writeTo("t").append())
+        assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+      }
+    }
+  }
 }

