Posted to commits@spark.apache.org by we...@apache.org on 2021/07/27 05:58:18 UTC
[spark] branch branch-3.2 updated: [SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 14328e0 [SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
14328e0 is described below
commit 14328e043d0233800869d5435291b3c0d4a65aa1
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Jul 27 13:57:05 2021 +0800
[SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
### What changes were proposed in this pull request?
We added char/varchar support in 3.1, but the string length check is only applied to INSERT, not to UPDATE/MERGE. This PR fixes that, and also adds the missing type coercion for the assignments in UPDATE/MERGE.
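As an illustration, a hedged sketch of what the analyzer now does for such a command. The table name `testcat.t` and the existence of a v2 source that accepts UPDATE are assumptions (no built-in source supports it); the column types mirror the `c1 CHAR(5)`, `c2 VARCHAR(5)` table used in the new tests:
```scala
// Hypothetical: assumes a v2 catalog "testcat" whose tables accept UPDATE.
// The target table is declared as (c1 CHAR(5), c2 VARCHAR(5)).
spark.sql("UPDATE testcat.t SET c1 = 'too long', c2 = 1")
// With this change the analyzer rewrites both assignments:
//  - the value for c1 is wrapped in a write-side length check, so the
//    over-length string now fails instead of being written unchecked;
//  - the integer 1 assigned to c2 is first ANSI-cast to the string type,
//    then wrapped in the varchar length check.
```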
### Why are the changes needed?
To complete the char/varchar support and to make UPDATE/MERGE easier to use by applying type coercion to the assignment values.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit tests. No built-in source supports UPDATE/MERGE, so an end-to-end test is not applicable here.
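For reference, a condensed sketch of the assertion pattern added to PlanResolutionSuite (see the full diff below); `parseAndResolve` and the mocked `testcat.charvarchar` table are test helpers defined in that suite:
```scala
val plan = parseAndResolve("UPDATE testcat.charvarchar SET c1='a', c2=1")
plan match {
  case u: UpdateTable =>
    val Seq(a1, a2) = u.assignments
    // both assignment values end up wrapped in the write-side length checks
    assert(a1.value.asInstanceOf[StaticInvoke].functionName == "charTypeWriteSideCheck")
    val s2 = a2.value.asInstanceOf[StaticInvoke]
    assert(s2.functionName == "varcharTypeWriteSideCheck")
    // the int literal is coerced to the varchar column's string type first
    assert(s2.arguments.head.isInstanceOf[AnsiCast])
  case other => fail("Expect UpdateTable, but got:\n" + other.treeString)
}
```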
Closes #33468 from cloud-fan/char.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 068f8d434ad9a6651006151de521d0799db8af52)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 43 +++++++++---
.../spark/sql/catalyst/analysis/unresolved.scala | 1 +
.../catalyst/expressions/namedExpressions.scala | 7 ++
.../sql/catalyst/plans/logical/v2Commands.scala | 17 ++++-
.../spark/sql/catalyst/util/CharVarcharUtils.scala | 2 +-
.../execution/command/PlanResolutionSuite.scala | 76 ++++++++++++++++++++--
6 files changed, 130 insertions(+), 16 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ed7ad7f..ee7b342 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
v2Write
}
+
+ case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
+ resolveAssignments(u)
+
+ case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
+ resolveAssignments(m)
+ }
+
+ private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
+ p.transformExpressions {
+ case assignment: Assignment =>
+ val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+ AssertNotNull(assignment.value)
+ } else {
+ assignment.value
+ }
+ val casted = if (assignment.key.dataType != nullHandled.dataType) {
+ AnsiCast(nullHandled, assignment.key.dataType)
+ } else {
+ nullHandled
+ }
+ val rawKeyType = assignment.key.transform {
+ case a: AttributeReference =>
+ CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
+ }.dataType
+ val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
+ CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
+ } else {
+ casted
+ }
+ val cleanedKey = assignment.key.transform {
+ case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
+ }
+ Assignment(cleanedKey, finalValue)
+ }
}
}
@@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}
}
- private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
- val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
- val newOuterRef = r.transform {
- case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
- }
- Seq(newOuterRef, newAttr)
- }
-
private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 29d5410..9f05367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
override def withMetadata(newMetadata: Metadata): Attribute = this
override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
+ override def withDataType(newType: DataType): Attribute = this
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
override def toString: String = s"'$name"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2b8265f..ae2c66c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
def withName(newName: String): Attribute
def withMetadata(newMetadata: Metadata): Attribute
def withExprId(newExprId: ExprId): Attribute
+ def withDataType(newType: DataType): Attribute
override def toAttribute: Attribute = this
def newInstance(): Attribute
@@ -339,6 +340,10 @@ case class AttributeReference(
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
}
+ override def withDataType(newType: DataType): Attribute = {
+ AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
+ }
+
override protected final def otherCopyArgs: Seq[AnyRef] = {
exprId :: qualifier :: Nil
}
@@ -395,6 +400,8 @@ case class PrettyAttribute(
override def exprId: ExprId = throw new UnsupportedOperationException
override def withExprId(newExprId: ExprId): Attribute =
throw new UnsupportedOperationException
+ override def withDataType(newType: DataType): Attribute =
+ throw new UnsupportedOperationException
override def nullable: Boolean = true
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 3d88d62..fa897a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -425,6 +425,12 @@ case class UpdateTable(
override def child: LogicalPlan = table
override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
copy(table = newChild)
+
+ def skipSchemaResolution: Boolean = table match {
+ case r: NamedRelation => r.skipSchemaResolution
+ case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+ case _ => false
+ }
}
/**
@@ -437,6 +443,13 @@ case class MergeIntoTable(
matchedActions: Seq[MergeAction],
notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
+
+ def skipSchemaResolution: Boolean = targetTable match {
+ case r: NamedRelation => r.skipSchemaResolution
+ case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+ case _ => false
+ }
+
override def left: LogicalPlan = targetTable
override def right: LogicalPlan = sourceTable
override protected def withNewChildrenInternal(
@@ -466,7 +479,7 @@ case class UpdateAction(
newChildren: IndexedSeq[Expression]): UpdateAction =
copy(
condition = if (condition.isDefined) Some(newChildren.head) else None,
- assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+ assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
}
case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
@@ -485,7 +498,7 @@ case class InsertAction(
newChildren: IndexedSeq[Expression]): InsertAction =
copy(
condition = if (condition.isDefined) Some(newChildren.head) else None,
- assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+ assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
}
case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index a566775..3094b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
}.getOrElse(expr)
}
- private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
+ def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
dt match {
case CharType(length) =>
StaticInvoke(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index e714fb3..25a8c4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
t
}
+ private val charVarcharTable: Table = {
+ val t = mock(classOf[Table])
+ when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
+ when(t.partitioning()).thenReturn(Array.empty[Transform])
+ t
+ }
+
private val v1Table: V1Table = {
val t = mock(classOf[CatalogTable])
when(t.schema).thenReturn(new StructType()
@@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
case "tab" => table
case "tab1" => table1
case "tab2" => table2
+ case "charvarchar" => charVarcharTable
case name => throw new NoSuchTableException(name)
}
})
@@ -1058,12 +1067,33 @@ class PlanResolutionSuite extends AnalysisTest {
}
}
- val sql = "UPDATE non_existing SET id=1"
- val parsed = parseAndResolve(sql)
- parsed match {
+ val sql1 = "UPDATE non_existing SET id=1"
+ val parsed1 = parseAndResolve(sql1)
+ parsed1 match {
case u: UpdateTable =>
assert(u.table.isInstanceOf[UnresolvedRelation])
- case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
+ case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
+ }
+
+ val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
+ val parsed2 = parseAndResolve(sql2)
+ parsed2 match {
+ case u: UpdateTable =>
+ assert(u.assignments.length == 2)
+ u.assignments(0).value match {
+ case s: StaticInvoke =>
+ assert(s.arguments.length == 2)
+ assert(s.functionName == "charTypeWriteSideCheck")
+ case other => fail("Expect StaticInvoke, but got: " + other)
+ }
+ u.assignments(1).value match {
+ case s: StaticInvoke =>
+ assert(s.arguments.length == 2)
+ assert(s.arguments.head.isInstanceOf[AnsiCast])
+ assert(s.functionName == "varcharTypeWriteSideCheck")
+ case other => fail("Expect StaticInvoke, but got: " + other)
+ }
+ case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
}
}
@@ -1568,6 +1598,42 @@ class PlanResolutionSuite extends AnalysisTest {
val e3 = intercept[AnalysisException](parseAndResolve(sql3))
assert(e3.message.contains(
"cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
+
+ val sql4 =
+ """
+ |MERGE INTO testcat.charvarchar
+ |USING testcat.tab2
+ |ON 1 = 1
+ |WHEN MATCHED THEN UPDATE SET c1='a', c2=1
+ |WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
+ |""".stripMargin
+ val parsed4 = parseAndResolve(sql4)
+ parsed4 match {
+ case m: MergeIntoTable =>
+ assert(m.matchedActions.length == 1)
+ m.matchedActions.head match {
+ case UpdateAction(_, Seq(
+ Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+ assert(s1.arguments.length == 2)
+ assert(s1.functionName == "charTypeWriteSideCheck")
+ assert(s2.arguments.length == 2)
+ assert(s2.arguments.head.isInstanceOf[AnsiCast])
+ assert(s2.functionName == "varcharTypeWriteSideCheck")
+ case other => fail("Expect UpdateAction, but got: " + other)
+ }
+ assert(m.notMatchedActions.length == 1)
+ m.notMatchedActions.head match {
+ case InsertAction(_, Seq(
+ Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+ assert(s1.arguments.length == 2)
+ assert(s1.functionName == "charTypeWriteSideCheck")
+ assert(s2.arguments.length == 2)
+ assert(s2.arguments.head.isInstanceOf[AnsiCast])
+ assert(s2.functionName == "varcharTypeWriteSideCheck")
+ case other => fail("Expect UpdateAction, but got: " + other)
+ }
+ case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+ }
}
test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {