Posted to commits@spark.apache.org by we...@apache.org on 2021/07/27 05:58:18 UTC

[spark] branch branch-3.2 updated: [SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 14328e0  [SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
14328e0 is described below

commit 14328e043d0233800869d5435291b3c0d4a65aa1
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Jul 27 13:57:05 2021 +0800

    [SPARK-36247][SQL] Check string length for char/varchar and apply type coercion in UPDATE/MERGE command
    
    ### What changes were proposed in this pull request?
    
    We added char/varchar support in 3.1, but the string length check was only applied to INSERT, not to UPDATE/MERGE. This PR fixes that, and also adds the missing type coercion for UPDATE/MERGE.
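
    For illustration, the intended behavior as a hedged sketch (no built-in source supports UPDATE/MERGE yet, so the `testcat` catalog and `foo` provider below are hypothetical v2 placeholders):

        // Assumes `spark` is an active SparkSession and `testcat` is a v2
        // catalog whose tables accept the UPDATE command.
        spark.sql("CREATE TABLE testcat.t (c1 CHAR(5), c2 VARCHAR(5)) USING foo")
        // Over-length strings now fail the write-side length check, as with INSERT:
        spark.sql("UPDATE testcat.t SET c1 = 'too long'")  // error: exceeds char(5)
        // Mismatched value types are now coerced with an ANSI cast:
        spark.sql("UPDATE testcat.t SET c2 = 1")           // int cast to string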
    
    ### Why are the changes needed?
    
    Complete the char/varchar support and make UPDATE/MERGE easier to use by applying type coercion.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT. No built-in source supports UPDATE/MERGE, so an end-to-end test is not applicable here.
    
    Closes #33468 from cloud-fan/char.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 068f8d434ad9a6651006151de521d0799db8af52)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 43 +++++++++---
 .../spark/sql/catalyst/analysis/unresolved.scala   |  1 +
 .../catalyst/expressions/namedExpressions.scala    |  7 ++
 .../sql/catalyst/plans/logical/v2Commands.scala    | 17 ++++-
 .../spark/sql/catalyst/util/CharVarcharUtils.scala |  2 +-
 .../execution/command/PlanResolutionSuite.scala    | 76 ++++++++++++++++++++--
 6 files changed, 130 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ed7ad7f..ee7b342 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3300,6 +3300,41 @@ class Analyzer(override val catalogManager: CatalogManager)
         } else {
           v2Write
         }
+
+      case u: UpdateTable if !u.skipSchemaResolution && u.resolved =>
+        resolveAssignments(u)
+
+      case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved =>
+        resolveAssignments(m)
+    }
+
+    private def resolveAssignments(p: LogicalPlan): LogicalPlan = {
+      p.transformExpressions {
+        case assignment: Assignment =>
+          val nullHandled = if (!assignment.key.nullable && assignment.value.nullable) {
+            AssertNotNull(assignment.value)
+          } else {
+            assignment.value
+          }
+          val casted = if (assignment.key.dataType != nullHandled.dataType) {
+            AnsiCast(nullHandled, assignment.key.dataType)
+          } else {
+            nullHandled
+          }
+          val rawKeyType = assignment.key.transform {
+            case a: AttributeReference =>
+              CharVarcharUtils.getRawType(a.metadata).map(a.withDataType).getOrElse(a)
+          }.dataType
+          val finalValue = if (CharVarcharUtils.hasCharVarchar(rawKeyType)) {
+            CharVarcharUtils.stringLengthCheck(casted, rawKeyType)
+          } else {
+            casted
+          }
+          val cleanedKey = assignment.key.transform {
+            case a: AttributeReference => CharVarcharUtils.cleanAttrMetadata(a)
+          }
+          Assignment(cleanedKey, finalValue)
+      }
     }
   }
 
@@ -4218,14 +4253,6 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
     }
   }
 
-  private def padOuterRefAttrCmp(outerAttr: Attribute, attr: Attribute): Seq[Expression] = {
-    val Seq(r, newAttr) = CharVarcharUtils.addPaddingInStringComparison(Seq(outerAttr, attr))
-    val newOuterRef = r.transform {
-      case ar: Attribute if ar.semanticEquals(outerAttr) => OuterReference(ar)
-    }
-    Seq(newOuterRef, newAttr)
-  }
-
   private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = {
     if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr
   }
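
To make the new resolveAssignments rule easier to follow, here is a self-contained toy sketch of the same pipeline, using simplified stand-ins for the Catalyst classes (Attr, Lit, Cast and friends below are hypothetical, not the real expressions):

    // Order of operations mirrored from resolveAssignments above:
    // null guard -> ANSI cast -> char/varchar length check.
    sealed trait Expr { def dataType: String; def nullable: Boolean }
    case class Attr(name: String, dataType: String, nullable: Boolean,
        rawType: Option[String] = None) extends Expr
    case class Lit(v: Any, dataType: String, nullable: Boolean = false) extends Expr
    case class AssertNotNull(child: Expr) extends Expr {
      def dataType: String = child.dataType; def nullable: Boolean = false
    }
    case class Cast(child: Expr, dataType: String) extends Expr {
      def nullable: Boolean = child.nullable
    }
    case class LengthCheck(child: Expr, rawType: String) extends Expr {
      def dataType: String = "string"; def nullable: Boolean = child.nullable
    }

    def resolveAssignment(key: Attr, value: Expr): Expr = {
      // 1. Guard non-nullable target columns against nullable values.
      val nullHandled =
        if (!key.nullable && value.nullable) AssertNotNull(value) else value
      // 2. Coerce the value to the target column's type (AnsiCast in the real rule).
      val casted =
        if (key.dataType != nullHandled.dataType) Cast(nullHandled, key.dataType)
        else nullHandled
      // 3. Apply the length check if the column was declared char/varchar
      //    (recovered from attribute metadata in the real rule).
      key.rawType.fold(casted)(LengthCheck(casted, _))
    }

    // UPDATE ... SET c2 = 1, where c2 was declared VARCHAR(5):
    resolveAssignment(
      Attr("c2", "string", nullable = true, rawType = Some("varchar(5)")),
      Lit(1, "int"))
    // => LengthCheck(Cast(Lit(1,"int"),"string"), "varchar(5)")
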
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 29d5410..9f05367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -168,6 +168,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
   override def withName(newName: String): UnresolvedAttribute = UnresolvedAttribute.quoted(newName)
   override def withMetadata(newMetadata: Metadata): Attribute = this
   override def withExprId(newExprId: ExprId): UnresolvedAttribute = this
+  override def withDataType(newType: DataType): Attribute = this
   final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_ATTRIBUTE)
 
   override def toString: String = s"'$name"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2b8265f..ae2c66c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -123,6 +123,7 @@ abstract class Attribute extends LeafExpression with NamedExpression with NullIn
   def withName(newName: String): Attribute
   def withMetadata(newMetadata: Metadata): Attribute
   def withExprId(newExprId: ExprId): Attribute
+  def withDataType(newType: DataType): Attribute
 
   override def toAttribute: Attribute = this
   def newInstance(): Attribute
@@ -339,6 +340,10 @@ case class AttributeReference(
     AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
   }
 
+  override def withDataType(newType: DataType): Attribute = {
+    AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
+  }
+
   override protected final def otherCopyArgs: Seq[AnyRef] = {
     exprId :: qualifier :: Nil
   }
@@ -395,6 +400,8 @@ case class PrettyAttribute(
   override def exprId: ExprId = throw new UnsupportedOperationException
   override def withExprId(newExprId: ExprId): Attribute =
     throw new UnsupportedOperationException
+  override def withDataType(newType: DataType): Attribute =
+    throw new UnsupportedOperationException
   override def nullable: Boolean = true
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index 3d88d62..fa897a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -425,6 +425,12 @@ case class UpdateTable(
   override def child: LogicalPlan = table
   override protected def withNewChildInternal(newChild: LogicalPlan): UpdateTable =
     copy(table = newChild)
+
+  def skipSchemaResolution: Boolean = table match {
+    case r: NamedRelation => r.skipSchemaResolution
+    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+    case _ => false
+  }
 }
 
 /**
@@ -437,6 +443,13 @@ case class MergeIntoTable(
     matchedActions: Seq[MergeAction],
     notMatchedActions: Seq[MergeAction]) extends BinaryCommand with SupportsSubquery {
   def duplicateResolved: Boolean = targetTable.outputSet.intersect(sourceTable.outputSet).isEmpty
+
+  def skipSchemaResolution: Boolean = targetTable match {
+    case r: NamedRelation => r.skipSchemaResolution
+    case SubqueryAlias(_, r: NamedRelation) => r.skipSchemaResolution
+    case _ => false
+  }
+
   override def left: LogicalPlan = targetTable
   override def right: LogicalPlan = sourceTable
   override protected def withNewChildrenInternal(
@@ -466,7 +479,7 @@ case class UpdateAction(
       newChildren: IndexedSeq[Expression]): UpdateAction =
     copy(
       condition = if (condition.isDefined) Some(newChildren.head) else None,
-      assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class UpdateStarAction(condition: Option[Expression]) extends MergeAction {
@@ -485,7 +498,7 @@ case class InsertAction(
       newChildren: IndexedSeq[Expression]): InsertAction =
     copy(
       condition = if (condition.isDefined) Some(newChildren.head) else None,
-      assignments = newChildren.tail.asInstanceOf[Seq[Assignment]])
+      assignments = newChildren.takeRight(assignments.length).asInstanceOf[Seq[Assignment]])
 }
 
 case class InsertStarAction(condition: Option[Expression]) extends MergeAction {
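
A note on the tail -> takeRight change above: when `condition` is None, `newChildren` contains only the assignments, so `tail` silently dropped the first one. A minimal Scala REPL sketch of the off-by-one:

    val assignments = IndexedSeq("a1", "a2")
    val newChildren = assignments               // condition = None, nothing prepended
    newChildren.tail                            // Vector("a2")       -- first assignment lost
    newChildren.takeRight(assignments.length)   // Vector("a1", "a2") -- correct
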
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index a566775..3094b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -149,7 +149,7 @@ object CharVarcharUtils extends Logging {
     }.getOrElse(expr)
   }
 
-  private def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
+  def stringLengthCheck(expr: Expression, dt: DataType): Expression = {
     dt match {
       case CharType(length) =>
         StaticInvoke(
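
For context, stringLengthCheck (made public above so the new analyzer rule can call it) wraps the value in a StaticInvoke of the write-side checks in CharVarcharCodegenUtils. A rough Scala sketch of their documented semantics (a model, not the actual implementation):

    // Hedged sketch: trailing spaces beyond the limit are trimmed; anything
    // else over the limit is an error. CHAR values are space-padded as well.
    def varcharWriteSideCheck(s: String, limit: Int): String =
      if (s.length <= limit) s
      else if (s.drop(limit).forall(_ == ' ')) s.take(limit)
      else throw new RuntimeException(s"Exceeds varchar($limit) length limitation")

    def charWriteSideCheck(s: String, limit: Int): String =
      varcharWriteSideCheck(s, limit).padTo(limit, ' ').mkString
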
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index e714fb3..25a8c4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -28,7 +28,8 @@ import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.{AnsiCast, AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral}
+import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTableAsSelect, CreateTableStatement, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -75,6 +76,13 @@ class PlanResolutionSuite extends AnalysisTest {
     t
   }
 
+  private val charVarcharTable: Table = {
+    val t = mock(classOf[Table])
+    when(t.schema()).thenReturn(new StructType().add("c1", "char(5)").add("c2", "varchar(5)"))
+    when(t.partitioning()).thenReturn(Array.empty[Transform])
+    t
+  }
+
   private val v1Table: V1Table = {
     val t = mock(classOf[CatalogTable])
     when(t.schema).thenReturn(new StructType()
@@ -109,6 +117,7 @@ class PlanResolutionSuite extends AnalysisTest {
         case "tab" => table
         case "tab1" => table1
         case "tab2" => table2
+        case "charvarchar" => charVarcharTable
         case name => throw new NoSuchTableException(name)
       }
     })
@@ -1058,12 +1067,33 @@ class PlanResolutionSuite extends AnalysisTest {
       }
     }
 
-    val sql = "UPDATE non_existing SET id=1"
-    val parsed = parseAndResolve(sql)
-    parsed match {
+    val sql1 = "UPDATE non_existing SET id=1"
+    val parsed1 = parseAndResolve(sql1)
+    parsed1 match {
       case u: UpdateTable =>
         assert(u.table.isInstanceOf[UnresolvedRelation])
-      case _ => fail("Expect UpdateTable, but got:\n" + parsed.treeString)
+      case _ => fail("Expect UpdateTable, but got:\n" + parsed1.treeString)
+    }
+
+    val sql2 = "UPDATE testcat.charvarchar SET c1='a', c2=1"
+    val parsed2 = parseAndResolve(sql2)
+    parsed2 match {
+      case u: UpdateTable =>
+        assert(u.assignments.length == 2)
+        u.assignments(0).value match {
+          case s: StaticInvoke =>
+            assert(s.arguments.length == 2)
+            assert(s.functionName == "charTypeWriteSideCheck")
+          case other => fail("Expect StaticInvoke, but got: " + other)
+        }
+        u.assignments(1).value match {
+          case s: StaticInvoke =>
+            assert(s.arguments.length == 2)
+            assert(s.arguments.head.isInstanceOf[AnsiCast])
+            assert(s.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect StaticInvoke, but got: " + other)
+        }
+      case _ => fail("Expect UpdateTable, but got:\n" + parsed2.treeString)
     }
   }
 
@@ -1568,6 +1598,42 @@ class PlanResolutionSuite extends AnalysisTest {
     val e3 = intercept[AnalysisException](parseAndResolve(sql3))
     assert(e3.message.contains(
       "cannot resolve s in MERGE command given columns [testcat.tab2.i, testcat.tab2.x]"))
+
+    val sql4 =
+      """
+        |MERGE INTO testcat.charvarchar
+        |USING testcat.tab2
+        |ON 1 = 1
+        |WHEN MATCHED THEN UPDATE SET c1='a', c2=1
+        |WHEN NOT MATCHED THEN INSERT (c1, c2) VALUES ('b', 2)
+        |""".stripMargin
+    val parsed4 = parseAndResolve(sql4)
+    parsed4 match {
+      case m: MergeIntoTable =>
+        assert(m.matchedActions.length == 1)
+        m.matchedActions.head match {
+          case UpdateAction(_, Seq(
+          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+            assert(s1.arguments.length == 2)
+            assert(s1.functionName == "charTypeWriteSideCheck")
+            assert(s2.arguments.length == 2)
+            assert(s2.arguments.head.isInstanceOf[AnsiCast])
+            assert(s2.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect UpdateAction, but got: " + other)
+        }
+        assert(m.notMatchedActions.length == 1)
+        m.notMatchedActions.head match {
+          case InsertAction(_, Seq(
+          Assignment(_, s1: StaticInvoke), Assignment(_, s2: StaticInvoke))) =>
+            assert(s1.arguments.length == 2)
+            assert(s1.functionName == "charTypeWriteSideCheck")
+            assert(s2.arguments.length == 2)
+            assert(s2.arguments.head.isInstanceOf[AnsiCast])
+            assert(s2.functionName == "varcharTypeWriteSideCheck")
+          case other => fail("Expect InsertAction, but got: " + other)
+        }
+      case other => fail("Expect MergeIntoTable, but got:\n" + other.treeString)
+    }
   }
 
   test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") {
