Posted to commits@spark.apache.org by we...@apache.org on 2023/03/01 07:50:59 UTC
[spark] branch branch-3.4 updated: [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 09abea00cb0 [SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
09abea00cb0 is described below
commit 09abea00cb0d67336413ac8892617ca824429042
Author: aokolnychyi <ao...@apple.com>
AuthorDate: Wed Mar 1 15:50:10 2023 +0800
[SPARK-42611][SQL] Insert char/varchar length checks for inner fields during resolution
### What changes were proposed in this pull request?
This PR adds char/varchar length checks for inner fields during resolution when struct fields are reordered.
### Why are the changes needed?
These checks are needed to handle nested char/varchar columns correctly. Prior to this change, the raw type information was lost when constructing nested attributes, so the proper char/varchar length checks were never inserted.
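For illustration, a minimal reproduction in the style of the tests added below (the `format` placeholder stands in for any DSv2 source; table and column names are illustrative only):
```scala
// Sketch mirroring the new test cases: the inner field `n_c` is CHAR(5),
// but the query supplies the struct fields in the opposite order, so the
// analyzer must reorder them during resolution.
sql(s"CREATE TABLE t(s STRUCT<n_c: CHAR(5), n_i: INT>) USING $format")
val inputDF = sql("SELECT named_struct('n_i', 1, 'n_c', '123456') AS s")
// Before this change, reordering dropped the raw char/varchar type and the
// oversized value was written silently; with it, the append fails with
// "Exceeds char/varchar type length limitation: 5".
inputDF.writeTo("t").append()
```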
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
This PR comes with tests that would previously fail.
Closes #40206 from aokolnychyi/spark-42611.
Authored-by: aokolnychyi <ao...@apple.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit d7d8af0dbb47e152b280226a7afcf0771b5a5ae8)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/analysis/TableOutputResolver.scala | 62 ++++++++++++++--------
.../apache/spark/sql/CharVarcharTestSuite.scala | 52 ++++++++++++++++++
2 files changed, 93 insertions(+), 21 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 61d24964d60..e1ee0defa23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -36,20 +36,25 @@ object TableOutputResolver {
byName: Boolean,
conf: SQLConf): LogicalPlan = {
- if (expected.size < query.output.size) {
- throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName, expected, query)
+ val actualExpectedCols = expected.map { attr =>
+ attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
+ }
+
+ if (actualExpectedCols.size < query.output.size) {
+ throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
+ tableName, actualExpectedCols, query)
}
val errors = new mutable.ArrayBuffer[String]()
val resolved: Seq[NamedExpression] = if (byName) {
- reorderColumnsByName(query.output, expected, conf, errors += _)
+ reorderColumnsByName(query.output, actualExpectedCols, conf, errors += _)
} else {
- if (expected.size > query.output.size) {
+ if (actualExpectedCols.size > query.output.size) {
throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
- tableName, expected, query)
+ tableName, actualExpectedCols, query)
}
- query.output.zip(expected).flatMap {
+ query.output.zip(actualExpectedCols).flatMap {
case (queryExpr, tableAttr) =>
checkField(tableAttr, queryExpr, byName, conf, err => errors += err, Seq(tableAttr.name))
}
@@ -254,28 +259,23 @@ object TableOutputResolver {
addError: String => Unit,
colPath: Seq[String]): Option[NamedExpression] = {
+ val attrTypeHasCharVarchar = CharVarcharUtils.hasCharVarchar(tableAttr.dataType)
+ val attrTypeWithoutCharVarchar = if (attrTypeHasCharVarchar) {
+ CharVarcharUtils.replaceCharVarcharWithString(tableAttr.dataType)
+ } else {
+ tableAttr.dataType
+ }
val storeAssignmentPolicy = conf.storeAssignmentPolicy
lazy val outputField = if (tableAttr.dataType.sameType(queryExpr.dataType) &&
tableAttr.name == queryExpr.name &&
tableAttr.metadata == queryExpr.metadata) {
Some(queryExpr)
} else {
- val casted = storeAssignmentPolicy match {
- case StoreAssignmentPolicy.ANSI =>
- val cast = Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
- ansiEnabled = true)
- cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
- checkCastOverflowInTableInsert(cast, colPath.quoted)
- case StoreAssignmentPolicy.LEGACY =>
- Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone),
- ansiEnabled = false)
- case _ =>
- Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone))
- }
- val exprWithStrLenCheck = if (conf.charVarcharAsString) {
+ val casted = cast(queryExpr, attrTypeWithoutCharVarchar, conf, colPath.quoted)
+ val exprWithStrLenCheck = if (conf.charVarcharAsString || !attrTypeHasCharVarchar) {
casted
} else {
- CharVarcharUtils.stringLengthCheck(casted, tableAttr)
+ CharVarcharUtils.stringLengthCheck(casted, tableAttr.dataType)
}
// Renaming is needed for handling the following cases like
// 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2
@@ -290,7 +290,7 @@ object TableOutputResolver {
case StoreAssignmentPolicy.STRICT | StoreAssignmentPolicy.ANSI =>
// run the type check first to ensure type errors are present
val canWrite = DataType.canWrite(
- queryExpr.dataType, tableAttr.dataType, byName, conf.resolver, colPath.quoted,
+ queryExpr.dataType, attrTypeWithoutCharVarchar, byName, conf.resolver, colPath.quoted,
storeAssignmentPolicy, addError)
if (queryExpr.nullable && !tableAttr.nullable) {
addError(s"Cannot write nullable values to non-null column '${colPath.quoted}'")
@@ -304,4 +304,24 @@ object TableOutputResolver {
}
}
}
+
+ private def cast(
+ expr: Expression,
+ expectedType: DataType,
+ conf: SQLConf,
+ colName: String): Expression = {
+
+ conf.storeAssignmentPolicy match {
+ case StoreAssignmentPolicy.ANSI =>
+ val cast = Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = true)
+ cast.setTagValue(Cast.BY_TABLE_INSERTION, ())
+ checkCastOverflowInTableInsert(cast, colName)
+
+ case StoreAssignmentPolicy.LEGACY =>
+ Cast(expr, expectedType, Option(conf.sessionLocalTimeZone), ansiEnabled = false)
+
+ case _ =>
+ Cast(expr, expectedType, Option(conf.sessionLocalTimeZone))
+ }
+ }
}
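For orientation, the refactored code above relies on char/varchar types surviving analysis only as entries in the field metadata of a `StringType` attribute. A minimal sketch of that round-trip, assuming Spark's internal `CharVarcharUtils` API (the helper names `restoreRawType` and `typeForCast` are illustrative and not part of this commit):
```scala
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types.{DataType, StructField}

// Char/varchar columns are erased to StringType during analysis; the raw
// type is kept in field metadata so writers can re-apply length checks.
def restoreRawType(field: StructField): DataType =
  CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)

// Strip char/varchar before casting, matching what checkField now does
// with attrTypeWithoutCharVarchar.
def typeForCast(dt: DataType): DataType =
  if (CharVarcharUtils.hasCharVarchar(dt)) {
    CharVarcharUtils.replaceCharVarcharWithString(dt)
  } else {
    dt
  }
```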
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index c0ceebaa9a6..a6c310cd925 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -926,4 +926,56 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite
}
}
}
+
+ test("SPARK-42611: check char/varchar length in reordered nested structs") {
+ Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+ withTable("t") {
+ sql(s"CREATE TABLE t(s STRUCT<n_c: $typ, n_i: INT>) USING $format")
+
+ val inputDF = sql("SELECT named_struct('n_i', 1, 'n_c', '123456') AS s")
+
+ val e = intercept[RuntimeException](inputDF.writeTo("t").append())
+ assert(e.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+ }
+ }
+ }
+
+ test("SPARK-42611: check char/varchar length in reordered structs within arrays") {
+ Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+ withTable("t") {
+ sql(s"CREATE TABLE t(a ARRAY<STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+ val inputDF = sql("SELECT array(named_struct('n_i', 1, 'n_c', '123456')) AS a")
+
+ val e = intercept[SparkException](inputDF.writeTo("t").append())
+ assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+ }
+ }
+ }
+
+ test("SPARK-42611: check char/varchar length in reordered structs within map keys") {
+ Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+ withTable("t") {
+ sql(s"CREATE TABLE t(m MAP<STRUCT<n_c: $typ, n_i: INT>, INT>) USING $format")
+
+ val inputDF = sql("SELECT map(named_struct('n_i', 1, 'n_c', '123456'), 1) AS m")
+
+ val e = intercept[SparkException](inputDF.writeTo("t").append())
+ assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+ }
+ }
+ }
+
+ test("SPARK-42611: check char/varchar length in reordered structs within map values") {
+ Seq("CHAR(5)", "VARCHAR(5)").foreach { typ =>
+ withTable("t") {
+ sql(s"CREATE TABLE t(m MAP<INT, STRUCT<n_c: $typ, n_i: INT>>) USING $format")
+
+ val inputDF = sql("SELECT map(1, named_struct('n_i', 1, 'n_c', '123456')) AS m")
+
+ val e = intercept[SparkException](inputDF.writeTo("t").append())
+ assert(e.getCause.getMessage.contains("Exceeds char/varchar type length limitation: 5"))
+ }
+ }
+ }
}