Posted to commits@hudi.apache.org by xu...@apache.org on 2021/11/05 14:50:30 UTC

[hudi] branch master updated: [HUDI-2471] Add support ignoring case in merge into (#3700)

This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/hudi.git


The following commit(s) were added to refs/heads/master by this push:
     new 844346c  [HUDI-2471] Add support ignoring case in merge into (#3700)
844346c is described below

commit 844346c3ab100a857b547137aca003c45e523eb9
Author: 董可伦 <do...@inspur.com>
AuthorDate: Fri Nov 5 22:50:16 2021 +0800

    [HUDI-2471] Add support ignoring case in merge into (#3700)
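
    With this change, column names in merge into statements are matched through
    the session resolver, so matching honors spark.sql.caseSensitive and is
    case-insensitive by default. A minimal Spark SQL sketch of the behavior
    (table and column names are illustrative, mirroring the new tests):

        -- target table h0 declared with columns (ID int, name string, price double, TS long)
        merge into h0
        using (select 1 as id, 'a1' as name, 10 as PRICE, 1000 as ts) s0
        on s0.id = h0.ID
        when matched then update set *
        when not matched then insert *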
---
 .../spark/sql/hudi/analysis/HoodieAnalysis.scala   |  29 +++---
 .../hudi/command/MergeIntoHoodieTableCommand.scala |  31 ++++--
 .../spark/sql/hudi/TestMergeIntoTable2.scala       | 111 +++++++++++++++++++++
 3 files changed, 147 insertions(+), 24 deletions(-)

diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
index 09e0314..87cbb8a 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala
@@ -125,6 +125,7 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
     case mergeInto @ MergeIntoTable(target, source, mergeCondition, matchedActions, notMatchedActions)
       if isHoodieTable(target, sparkSession) && target.resolved =>
 
+      val resolver = sparkSession.sessionState.conf.resolver
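+      // The resolver honors spark.sql.caseSensitive (case-insensitive by default).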
       val resolvedSource = analyzer.execute(source)
       def isInsertOrUpdateStar(assignments: Seq[Assignment]): Boolean = {
         if (assignments.isEmpty) {
@@ -161,23 +162,21 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
         val resolvedCondition = condition.map(resolveExpressionFrom(resolvedSource)(_))
         val resolvedAssignments = if (isInsertOrUpdateStar(assignments)) {
           // An empty assignments list means 'insert *' or 'update set *'
-          val resolvedSourceOutputWithoutMetaFields = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-          val targetOutputWithoutMetaFields = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
-          val resolvedSourceColumnNamesWithoutMetaFields = resolvedSourceOutputWithoutMetaFields.map(_.name)
-          val targetColumnNamesWithoutMetaFields = targetOutputWithoutMetaFields.map(_.name)
+          val resolvedSourceOutput = resolvedSource.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+          val targetOutput = target.output.filter(attr => !HoodieSqlUtils.isMetaField(attr.name))
+          val resolvedSourceColumnNames = resolvedSourceOutput.map(_.name)
 
-          if(targetColumnNamesWithoutMetaFields.toSet.subsetOf(resolvedSourceColumnNamesWithoutMetaFields.toSet)){
+          if (targetOutput.forall(attr => resolvedSourceColumnNames.exists(resolver(_, attr.name)))) {
             // If the source table's columns contain all of the target table's columns,
             // we assign each source field to the matching target field by column name.
-            val sourceColNameAttrMap = resolvedSourceOutputWithoutMetaFields.map(attr => (attr.name, attr)).toMap
-            targetOutputWithoutMetaFields.map(targetAttr => {
-              val sourceAttr = sourceColNameAttrMap(targetAttr.name)
+            targetOutput.map(targetAttr => {
+              val sourceAttr = resolvedSourceOutput.find(f => resolver(f.name, targetAttr.name)).get
               Assignment(targetAttr, sourceAttr)
             })
           } else {
             // Otherwise, assign the source fields to the target fields by position.
-            targetOutputWithoutMetaFields
-              .zip(resolvedSourceOutputWithoutMetaFields)
+            targetOutput
+              .zip(resolvedSourceOutput)
               .map { case (targetAttr, sourceAttr) => Assignment(targetAttr, sourceAttr) }
           }
         } else {
@@ -214,8 +213,9 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
           }.toMap
 
           // Validate that there are no unknown target attributes.
+          val targetColumnNames = removeMetaFields(target.output).map(_.name)
           val unKnowTargets = target2Values.keys
-            .filterNot(removeMetaFields(target.output).map(_.name).contains(_))
+            .filterNot(name => targetColumnNames.exists(resolver(_, name)))
           if (unKnowTargets.nonEmpty) {
             throw new AnalysisException(s"Cannot find target attributes: ${unKnowTargets.mkString(",")}.")
           }
@@ -224,19 +224,20 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi
           // e.g. if the update action is missing the 'id' attribute, we add an "id = target.id" assignment to it.
           val newAssignments = removeMetaFields(target.output)
             .map(attr => {
+              val valueOption = target2Values.find(f => resolver(f._1, attr.name))
               // TODO support partial update for MOR.
-              if (!target2Values.contains(attr.name) && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
+              if (valueOption.isEmpty && targetTableType == MOR_TABLE_TYPE_OPT_VAL) {
                 throw new AnalysisException(s"Missing specify the value for target field: '${attr.name}' in merge into update action" +
                   s" for MOR table. Currently we cannot support partial update for MOR," +
                   s" please complete all the target fields just like '...update set id = s0.id, name = s0.name ....'")
               }
               if (preCombineField.isDefined && preCombineField.get.equalsIgnoreCase(attr.name)
-                  && !target2Values.contains(attr.name)) {
+                  && valueOption.isEmpty) {
                 throw new AnalysisException(s"Missing specify value for the preCombineField:" +
                   s" ${preCombineField.get} in merge-into update action. You should add" +
                   s" '... update set ${preCombineField.get} = xx....' to the when-matched clause.")
               }
-              Assignment(attr, target2Values.getOrElse(attr.name, attr))
+              Assignment(attr, if (valueOption.isEmpty) attr else valueOption.get._2)
             })
           UpdateAction(resolvedCondition, newAssignments)
         case DeleteAction(condition) =>
diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
index dd1be20..5ec15ce 100644
--- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
+++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/command/MergeIntoHoodieTableCommand.scala
@@ -27,6 +27,7 @@ import org.apache.hudi.hive.ddl.HiveSyncMode
 import org.apache.hudi.{AvroConversionUtils, DataSourceWriteOptions, HoodieSparkSqlWriter, HoodieWriterUtils, SparkAdapterSupport}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Cast, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.command.RunnableCommand
@@ -90,6 +91,7 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
   * TODO Currently non-equality conditions are not supported.
    */
   private lazy val targetKey2SourceExpression: Map[String, Expression] = {
+    val resolver = sparkSession.sessionState.conf.resolver
     val conditions = splitByAnd(mergeInto.mergeCondition)
     val allEqs = conditions.forall(p => p.isInstanceOf[EqualTo])
     if (!allEqs) {
@@ -101,11 +103,11 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     val target2Source = conditions.map(_.asInstanceOf[EqualTo])
       .map {
         case EqualTo(left: AttributeReference, right)
-          if targetAttrs.indexOf(left) >= 0 => // left is the target field
-          left.name -> right
+          if targetAttrs.exists(f => attributeEqual(f, left, resolver)) => // left is the target field
+            targetAttrs.find(f => resolver(f.name, left.name)).get.name -> right
         case EqualTo(left, right: AttributeReference)
-          if targetAttrs.indexOf(right) >= 0 => // right is the target field
-          right.name -> left
+          if targetAttrs.exists(f => attributeEqual(f, right, resolver)) => // right is the target field
+            targetAttrs.find(f => resolver(f.name, right.name)).get.name -> left
         case eq =>
           throw new AnalysisException(s"Invalidate Merge-On condition: ${eq.sql}." +
             "The validate condition should be 'targetColumn = sourceColumnExpression', e.g." +
@@ -196,16 +198,25 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
   }
 
   private def isEqualToTarget(targetColumnName: String, sourceExpression: Expression): Boolean = {
-    val sourceColNameMap = sourceDFOutput.map(attr => (attr.name.toLowerCase, attr.name)).toMap
+    val sourceColumnNames = sourceDFOutput.map(_.name)
+    val resolver = sparkSession.sessionState.conf.resolver
 
     sourceExpression match {
-      case attr: AttributeReference if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
-      case Cast(attr: AttributeReference, _, _) if sourceColNameMap(attr.name.toLowerCase).equals(targetColumnName) => true
+      case attr: AttributeReference if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
+      case Cast(attr: AttributeReference, _, _) if sourceColumnNames.find(resolver(_, attr.name)).get.equals(targetColumnName) => true
       case _=> false
     }
   }
 
   /**
+   * Compare an [[Attribute]] to another; return true if they have the same column name (per the resolver) and the same exprId.
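+   * e.g. with the default case-insensitive resolver, "ID" and "id" count as the same
+   * name, but the attributes must also share an exprId to be considered equal.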
+   */
+  private def attributeEqual(
+      attr: Attribute, other: Attribute, resolver: Resolver): Boolean = {
+    resolver(attr.name, other.name) && attr.exprId == other.exprId
+  }
+
+  /**
   * Execute the update and delete actions. All matched and not-matched actions will
   * execute in one upsert write operation. We push down the matched condition and assignment
    * expressions to the ExpressionPayload#combineAndGetUpdateValue and the not matched
@@ -361,9 +372,9 @@ case class MergeIntoHoodieTableCommand(mergeInto: MergeIntoTable) extends Runnab
     mergeInto.targetTable.output
       .filterNot(attr => isMetaField(attr.name))
       .map(attr => {
-        val assignment = attr2Assignment.getOrElse(attr,
-          throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
-        castIfNeeded(assignment, attr.dataType, sparkSession.sqlContext.conf)
+        val assignment = attr2Assignment.find(f => attributeEqual(f._1, attr, sparkSession.sessionState.conf.resolver))
+          .getOrElse(throw new IllegalArgumentException(s"Cannot find related assignment for field: ${attr.name}"))
+        castIfNeeded(assignment._2, attr.dataType, sparkSession.sqlContext.conf)
       })
   }
 
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
index 153eacf..bf73251 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestMergeIntoTable2.scala
@@ -432,4 +432,115 @@ class TestMergeIntoTable2 extends TestHoodieSqlBase {
     }
   }
 
+  test("Test ignoring case") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create table
+      spark.sql(
+        s"""
+           |create table $tableName (
+           |  ID int,
+           |  name string,
+           |  price double,
+           |  TS long,
+           |  DT string
+           |) using hudi
+           | location '${tmp.getCanonicalPath}/$tableName'
+           | options (
+           |  primaryKey ='ID',
+           |  preCombineField = 'TS'
+           | )
+       """.stripMargin)
+
+      // First merge with an extra input field 'flag' (insert a new record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 10 as PRICE, 1000 as ts, '2021-05-05' as dt, '1' as flag
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched and flag = '1' then update set
+           | id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched and flag = '1' then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 10.0, 1000, "2021-05-05")
+      )
+
+      // Second merge (update the record)
+      spark.sql(
+        s"""
+           | merge into $tableName
+           | using (
+           |  select 1 as id, 'a1' as name, 20 as PRICE, '2021-05-05' as dt, 1001 as ts
+           | ) s0
+           | on s0.id = $tableName.id
+           | when matched then update set
+           | id = s0.id, name = s0.name, PRICE = s0.price, ts = s0.ts, dt = s0.dt
+           | when not matched then insert *
+       """.stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 20.0, 1001, "2021-05-05")
+      )
+
+      // Test ignoring case when matching column names
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as name, 1111 as ts, '2021-05-05' as dt, 111 as PRICE union all
+           |  select 2 as id, 'a2' as name, 1112 as ts, '2021-05-05' as dt, 112 as PRICE
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+           |""".stripMargin)
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05"),
+        Seq(2, "a2", 112.0, 1112, "2021-05-05")
+      )
+    }
+  }
+
+  test("Test ignoring case for MOR table") {
+    withTempDir { tmp =>
+      val tableName = generateTableName
+      // Create a partitioned MOR table.
+      spark.sql(
+        s"""
+           | create table $tableName (
+           |  ID int,
+           |  NAME string,
+           |  price double,
+           |  TS long,
+           |  dt string
+           | ) using hudi
+           | options (
+           |  type = 'mor',
+           |  primaryKey = 'ID',
+           |  preCombineField = 'TS'
+           | )
+           | partitioned by(dt)
+           | location '${tmp.getCanonicalPath}/$tableName'
+         """.stripMargin)
+
+      // Test ignoring case when matching column names
+      spark.sql(
+        s"""
+           | merge into $tableName as t0
+           | using (
+           |  select 1 as id, 'a1' as NAME, 1111 as ts, '2021-05-05' as DT, 111 as price
+           | ) as s0
+           | on t0.id = s0.id
+           | when matched then update set *
+           | when not matched then insert *
+         """.stripMargin
+      )
+      checkAnswer(s"select id, name, price, ts, dt from $tableName")(
+        Seq(1, "a1", 111.0, 1111, "2021-05-05")
+      )
+    }
+  }
+
 }