Posted to commits@iceberg.apache.org by ao...@apache.org on 2023/05/19 20:34:43 UTC
[iceberg] branch master updated: Spark 3.4: Split update into delete and insert for position deltas (#7646)
This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 2f61a08274 Spark 3.4: Split update into delete and insert for position deltas (#7646)
2f61a08274 is described below
commit 2f61a082745739f61fd07c56344afa88363d1d51
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Fri May 19 13:34:35 2023 -0700
Spark 3.4: Split update into delete and insert for position deltas (#7646)
---
.../iceberg/spark/UpdateProjectionBenchmark.java | 213 ++++++++++++++++
.../catalyst/analysis/RewriteMergeIntoTable.scala | 148 ++++-------
.../analysis/RewriteRowLevelIcebergCommand.scala | 96 ++++++++
.../sql/catalyst/analysis/RewriteUpdateTable.scala | 43 +++-
.../sql/catalyst/plans/logical/MergeRows.scala | 2 +-
.../logical/{MergeRows.scala => UpdateRows.scala} | 19 +-
.../v2/ExtendedDataSourceV2Strategy.scala | 4 +
.../execution/datasources/v2/MergeRowsExec.scala | 274 +++++++++++++++------
.../execution/datasources/v2/UpdateRowsExec.scala | 86 +++++++
.../apache/iceberg/spark/extensions/TestMerge.java | 34 +++
.../spark/SparkDistributionAndOrderingUtil.java | 2 +-
.../TestSparkDistributionAndOrderingUtil.java | 70 +++++-
12 files changed, 782 insertions(+), 209 deletions(-)
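
For context, the core of this change is that merge-on-read (position delta) UPDATEs, and the update branch of delta-based MERGE, no longer emit a single UPDATE record per affected row: each updated row is split into a DELETE record for the old copy and an INSERT record for the new copy. The Scala sketch below only illustrates that idea with hypothetical types and field names; it is not the Iceberg/Spark code, which follows in the diff.

    // Minimal sketch of splitting an update into a delete and an insert.
    // DeltaRow and its fields are hypothetical stand-ins, not Iceberg/Spark types.
    object SplitUpdateSketch {
      case class DeltaRow(
          op: String,             // DELETE or INSERT
          value: Option[String],  // data column(s), nulled out for deletes
          file: Option[String],   // position metadata, null for inserts
          pos: Option[Long])

      def splitUpdate(oldFile: String, oldPos: Long, newValue: String): Seq[DeltaRow] =
        Seq(
          // delete: keep only what is needed to locate the old copy of the row
          DeltaRow("DELETE", value = None, file = Some(oldFile), pos = Some(oldPos)),
          // insert: carry the new values; metadata columns are null for new data records
          DeltaRow("INSERT", value = Some(newValue), file = None, pos = None))

      def main(args: Array[String]): Unit =
        splitUpdate("data/file-0001.parquet", 42L, "updated").foreach(println)
    }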
diff --git a/spark/v3.4/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java b/spark/v3.4/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java
new file mode 100644
index 0000000000..d917eae5eb
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/jmh/java/org/apache/iceberg/spark/UpdateProjectionBenchmark.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark;
+
+import static org.apache.spark.sql.functions.current_date;
+import static org.apache.spark.sql.functions.date_add;
+import static org.apache.spark.sql.functions.expr;
+
+import com.google.errorprone.annotations.FormatMethod;
+import com.google.errorprone.annotations.FormatString;
+import java.util.UUID;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.RowLevelOperationMode;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.catalyst.parser.ParseException;
+import org.apache.spark.sql.internal.SQLConf;
+import org.apache.spark.sql.types.StructType;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.annotations.Threads;
+import org.openjdk.jmh.annotations.Warmup;
+
+@Fork(1)
+@State(Scope.Benchmark)
+@Warmup(iterations = 3)
+@Measurement(iterations = 5)
+@BenchmarkMode(Mode.SingleShotTime)
+public class UpdateProjectionBenchmark {
+
+ private static final String TABLE_NAME = "test_table";
+ private static final int NUM_FILES = 5;
+ private static final int NUM_ROWS_PER_FILE = 1_000_000;
+
+ private final Configuration hadoopConf = new Configuration();
+ private SparkSession spark;
+ private long originalSnapshotId;
+
+ @Setup
+ public void setupBenchmark() throws NoSuchTableException, ParseException {
+ setupSpark();
+ initTable();
+ appendData();
+
+ Table table = Spark3Util.loadIcebergTable(spark, TABLE_NAME);
+ this.originalSnapshotId = table.currentSnapshot().snapshotId();
+ }
+
+ @TearDown
+ public void tearDownBenchmark() {
+ tearDownSpark();
+ dropTable();
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void copyOnWriteUpdate10Percent() {
+ runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.1);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void copyOnWriteUpdate30Percent() {
+ runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.3);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void copyOnWriteUpdate75Percent() {
+ runBenchmark(RowLevelOperationMode.COPY_ON_WRITE, 0.75);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void mergeOnRead10Percent() {
+ runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.1);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void mergeOnReadUpdate30Percent() {
+ runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.3);
+ }
+
+ @Benchmark
+ @Threads(1)
+ public void mergeOnReadUpdate75Percent() {
+ runBenchmark(RowLevelOperationMode.MERGE_ON_READ, 0.75);
+ }
+
+ private void runBenchmark(RowLevelOperationMode mode, double updatePercentage) {
+ sql(
+ "ALTER TABLE %s SET TBLPROPERTIES ('%s' '%s')",
+ TABLE_NAME, TableProperties.UPDATE_MODE, mode.modeName());
+
+ int mod = (int) (NUM_ROWS_PER_FILE / (NUM_ROWS_PER_FILE * updatePercentage));
+
+ sql(
+ "UPDATE %s "
+ + "SET intCol = intCol + 10, dateCol = date_add(dateCol, 1) "
+ + "WHERE mod(id, %d) = 0",
+ TABLE_NAME, mod);
+
+ sql(
+ "CALL system.rollback_to_snapshot(table => '%s', snapshot_id => %dL)",
+ TABLE_NAME, originalSnapshotId);
+ }
+
+ private void setupSpark() {
+ this.spark =
+ SparkSession.builder()
+ .config("spark.ui.enabled", false)
+ .config("spark.sql.extensions", IcebergSparkSessionExtensions.class.getName())
+ .config("spark.sql.catalog.spark_catalog", SparkSessionCatalog.class.getName())
+ .config("spark.sql.catalog.spark_catalog.type", "hadoop")
+ .config("spark.sql.catalog.spark_catalog.warehouse", newWarehouseDir())
+ .config(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false")
+ .config(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false")
+ .config(SQLConf.SHUFFLE_PARTITIONS().key(), "2")
+ .master("local")
+ .getOrCreate();
+ }
+
+ private void tearDownSpark() {
+ spark.stop();
+ }
+
+ private void initTable() {
+ sql(
+ "CREATE TABLE %s ( "
+ + " id LONG, intCol INT, floatCol FLOAT, doubleCol DOUBLE, "
+ + " decimalCol DECIMAL(20, 5), dateCol DATE, timestampCol TIMESTAMP, "
+ + " stringCol STRING)"
+ + "USING iceberg "
+ + "TBLPROPERTIES ("
+ + " '%s' '%s',"
+ + " '%s' '%d',"
+ + " '%s' '%d')",
+ TABLE_NAME,
+ TableProperties.UPDATE_DISTRIBUTION_MODE,
+ DistributionMode.NONE.modeName(),
+ TableProperties.SPLIT_OPEN_FILE_COST,
+ Integer.MAX_VALUE,
+ TableProperties.FORMAT_VERSION,
+ 2);
+
+ sql("ALTER TABLE %s WRITE ORDERED BY id", TABLE_NAME);
+ }
+
+ private void dropTable() {
+ sql("DROP TABLE IF EXISTS %s PURGE", TABLE_NAME);
+ }
+
+ private void appendData() throws NoSuchTableException {
+ for (int fileNum = 1; fileNum <= NUM_FILES; fileNum++) {
+ Dataset<Row> inputDF =
+ spark
+ .range(NUM_ROWS_PER_FILE)
+ .withColumn("intCol", expr("CAST(id AS INT)"))
+ .withColumn("floatCol", expr("CAST(id AS FLOAT)"))
+ .withColumn("doubleCol", expr("CAST(id AS DOUBLE)"))
+ .withColumn("decimalCol", expr("CAST(id AS DECIMAL(20, 5))"))
+ .withColumn("dateCol", date_add(current_date(), fileNum))
+ .withColumn("timestampCol", expr("TO_TIMESTAMP(dateCol)"))
+ .withColumn("stringCol", expr("CAST(dateCol AS STRING)"));
+ appendAsFile(inputDF);
+ }
+ }
+
+ private void appendAsFile(Dataset<Row> df) throws NoSuchTableException {
+ // ensure the schema is precise (including nullability)
+ StructType sparkSchema = spark.table(TABLE_NAME).schema();
+ spark.createDataFrame(df.rdd(), sparkSchema).coalesce(1).writeTo(TABLE_NAME).append();
+ }
+
+ private String newWarehouseDir() {
+ return hadoopConf.get("hadoop.tmp.dir") + UUID.randomUUID();
+ }
+
+ @FormatMethod
+ private void sql(@FormatString String query, Object... args) {
+ spark.sql(String.format(query, args));
+ }
+}
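
A note on the benchmark above: the modulus passed to the WHERE clause is derived from the update percentage so that roughly that fraction of ids satisfies mod(id, %d) = 0, and each iteration rolls the table back to the original snapshot so every run starts from the same state. A small sketch of that arithmetic (constants copied from the benchmark; the mapping is only approximate for larger percentages):

    // How the benchmark's modulus maps to an update fraction.
    object UpdateFractionSketch {
      val numRowsPerFile = 1000000  // NUM_ROWS_PER_FILE in the benchmark above

      def modFor(updatePercentage: Double): Int =
        (numRowsPerFile / (numRowsPerFile * updatePercentage)).toInt

      def main(args: Array[String]): Unit =
        Seq(0.1, 0.3, 0.75).foreach { p =>
          // e.g. p = 0.1 -> mod = 10, so ids 0, 10, 20, ... (about 10% of rows) are updated
          println(s"update fraction $p -> WHERE mod(id, ${modFor(p)}) = 0")
        }
    }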
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index ecb1146c38..2a14c3144e 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -20,12 +20,10 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.And
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.IsNotNull
import org.apache.spark.sql.catalyst.expressions.Literal
@@ -66,8 +64,6 @@ import org.apache.spark.sql.connector.write.RowLevelOperationTable
import org.apache.spark.sql.connector.write.SupportsDelta
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
@@ -127,7 +123,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
val joinPlan = Join(source, r, LeftAnti, Some(cond), JoinHint.NONE)
val notMatchedConditions = notMatchedActions.map(actionCondition)
- val notMatchedOutputs = notMatchedActions.map(actionOutput(_, Nil))
+ val notMatchedOutputs = notMatchedActions.map(notMatchedActionOutput(_, Nil))
// merge rows as there are multiple not matched actions
val mergeRows = MergeRows(
@@ -210,13 +206,11 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint)
- // add an extra matched action to output the original row if none of the actual actions matched
- // this is needed to keep target rows that should be copied over
- val matchedConditions = matchedActions.map(actionCondition) :+ TrueLiteral
- val matchedOutputs = matchedActions.map(actionOutput(_, metadataAttrs)) :+ readAttrs
+ val matchedConditions = matchedActions.map(actionCondition)
+ val matchedOutputs = matchedActions.map(matchedActionOutput(_, metadataAttrs))
val notMatchedConditions = notMatchedActions.map(actionCondition)
- val notMatchedOutputs = notMatchedActions.map(actionOutput(_, metadataAttrs))
+ val notMatchedOutputs = notMatchedActions.map(notMatchedActionOutput(_, metadataAttrs))
val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan)
val rowFromTargetAttr = resolveAttrRef(ROW_FROM_TARGET_REF, joinPlan)
@@ -283,14 +277,17 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint)
- val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains)
val matchedConditions = matchedActions.map(actionCondition)
- val matchedOutputs = matchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs))
+ val matchedOutputs = matchedActions.map { action =>
+ matchedDeltaActionOutput(action, rowAttrs, rowIdAttrs, metadataReadAttrs)
+ }
val notMatchedConditions = notMatchedActions.map(actionCondition)
- val notMatchedOutputs = notMatchedActions.map(deltaActionOutput(_, deleteRowValues, metadataReadAttrs))
+ val notMatchedOutputs = notMatchedActions.map { action =>
+ notMatchedDeltaActionOutput(action, metadataReadAttrs)
+ }
val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)()
val rowFromSourceAttr = resolveAttrRef(ROW_FROM_SOURCE_REF, joinPlan)
@@ -315,7 +312,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
// build a plan to write the row delta to the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
+ val projections = buildDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
WriteIcebergDelta(writeRelation, mergeRows, relation, projections)
}
@@ -323,63 +320,77 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
action.condition.getOrElse(TrueLiteral)
}
- private def actionOutput(
+ private def matchedActionOutput(
clause: MergeAction,
- metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+ metadataAttrs: Seq[Attribute]): Seq[Seq[Expression]] = {
clause match {
case u: UpdateAction =>
- u.assignments.map(_.value) ++ metadataAttrs
+ Seq(u.assignments.map(_.value) ++ metadataAttrs)
case _: DeleteAction =>
Nil
+ case other =>
+ throw new AnalysisException(s"Unexpected WHEN MATCHED action: $other")
+ }
+ }
+
+ private def notMatchedActionOutput(
+ clause: MergeAction,
+ metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+
+ clause match {
case i: InsertAction =>
i.assignments.map(_.value) ++ metadataAttrs.map(attr => Literal(null, attr.dataType))
case other =>
- throw new AnalysisException(s"Unexpected action: $other")
+ throw new AnalysisException(s"Unexpected WHEN NOT MATCHED action: $other")
}
}
- private def deltaActionOutput(
+ private def matchedDeltaActionOutput(
action: MergeAction,
- deleteRowValues: Seq[Expression],
- metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): Seq[Seq[Expression]] = {
action match {
case u: UpdateAction =>
- Seq(Literal(UPDATE_OPERATION)) ++ u.assignments.map(_.value) ++ metadataAttrs
+ val delete = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
+ val insert = deltaInsertOutput(u.assignments.map(_.value), metadataAttrs)
+ Seq(delete, insert)
case _: DeleteAction =>
- Seq(Literal(DELETE_OPERATION)) ++ deleteRowValues ++ metadataAttrs
+ val delete = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
+ Seq(delete)
+
+ case other =>
+ throw new AnalysisException(s"Unexpected WHEN MATCHED action: $other")
+ }
+ }
+
+ private def notMatchedDeltaActionOutput(
+ action: MergeAction,
+ metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+ action match {
case i: InsertAction =>
- val metadataAttrValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
- Seq(Literal(INSERT_OPERATION)) ++ i.assignments.map(_.value) ++ metadataAttrValues
+ deltaInsertOutput(i.assignments.map(_.value), metadataAttrs)
case other =>
- throw new AnalysisException(s"Unexpected action: $other")
+ throw new AnalysisException(s"Unexpected WHEN NOT MATCHED action: $other")
}
}
private def buildMergeRowsOutput(
- matchedOutputs: Seq[Seq[Expression]],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
notMatchedOutputs: Seq[Seq[Expression]],
attrs: Seq[Attribute]): Seq[Attribute] = {
- // collect all outputs from matched and not matched actions (ignoring DELETEs)
- val outputs = matchedOutputs.filter(_.nonEmpty) ++ notMatchedOutputs.filter(_.nonEmpty)
-
- // build a correct nullability map for output attributes
- // an attribute is nullable if at least one matched or not matched action may produce null
- val nullabilityMap = attrs.indices.map { index =>
- index -> outputs.exists(output => output(index).nullable)
- }.toMap
-
- attrs.zipWithIndex.map { case (attr, index) =>
- AttributeReference(attr.name, attr.dataType, nullabilityMap(index), attr.metadata)()
- }
+ // collect all outputs from matched and not matched actions (ignoring actions that discard rows)
+ val outputs = matchedOutputs.flatten.filter(_.nonEmpty) ++ notMatchedOutputs.filter(_.nonEmpty)
+ buildMergingOutput(outputs, attrs)
}
private def isCardinalityCheckNeeded(actions: Seq[MergeAction]): Boolean = actions match {
@@ -387,71 +398,18 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with Predicat
case _ => true
}
- private def buildDeltaDeleteRowValues(
- rowAttrs: Seq[Attribute],
- rowIdAttrs: Seq[Attribute]): Seq[Expression] = {
-
- // nullify all row attrs that are not part of the row ID
- val rowIdAttSet = AttributeSet(rowIdAttrs)
- rowAttrs.map {
- case attr if rowIdAttSet.contains(attr) => attr
- case attr => Literal(null, attr.dataType)
- }
- }
-
private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = {
V2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
}
- private def buildMergeDeltaProjections(
+ private def buildDeltaProjections(
mergeRows: MergeRows,
rowAttrs: Seq[Attribute],
rowIdAttrs: Seq[Attribute],
metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
- val outputAttrs = mergeRows.output
-
- val outputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs
- val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION))
- val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION))
-
- val rowProjection = if (rowAttrs.nonEmpty) {
- Some(newLazyProjection(insertAndUpdateOutputs, outputAttrs, rowAttrs))
- } else {
- None
- }
-
- val rowIdProjection = newLazyProjection(updateAndDeleteOutputs, outputAttrs, rowIdAttrs)
-
- val metadataProjection = if (metadataAttrs.nonEmpty) {
- Some(newLazyProjection(updateAndDeleteOutputs, outputAttrs, metadataAttrs))
- } else {
- None
- }
-
- WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
- }
-
- // the projection is done by name, ignoring expr IDs
- private def newLazyProjection(
- outputs: Seq[Seq[Expression]],
- outputAttrs: Seq[Attribute],
- projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
-
- val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name))
-
- val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
- // output attr is nullable if at least one action may produce null for that attr
- // but row ID and metadata attrs are projected only in update/delete actions and
- // row attrs are projected only in insert/update actions
- // that's why the projection schema must rely only on relevant action outputs
- // instead of blindly inheriting the output attr nullability
- val nullable = outputs.exists(output => output(ordinal).nullable)
- StructField(attr.name, attr.dataType, nullable, attr.metadata)
- }
- val schema = StructType(structFields)
-
- ProjectingInternalRow(schema, projectedOrdinals)
+ val outputs = mergeRows.matchedOutputs.flatten ++ mergeRows.notMatchedOutputs
+ buildDeltaProjections(mergeRows, outputs, rowAttrs, rowIdAttrs, metadataAttrs)
}
// splits the MERGE condition into a predicate that references columns only from the target table,
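
One visible consequence of the split in the rewrite above is the type change for matched outputs: a matched action can now contribute zero, one, or two output rows, which is why matchedOutputs becomes Seq[Seq[Seq[Expression]]] instead of Seq[Seq[Expression]]. A simplified sketch of that mapping, with strings standing in for Catalyst expressions (hypothetical helper, for illustration only):

    // Per-action output rows in the delta MERGE rewrite (simplified).
    object MatchedOutputsSketch {
      sealed trait MatchedAction
      case class UpdateAction(newValues: Seq[String]) extends MatchedAction
      case object DeleteAction extends MatchedAction

      def outputRows(action: MatchedAction, deleteRow: Seq[String]): Seq[Seq[String]] =
        action match {
          case UpdateAction(newValues) =>
            // an UPDATE contributes two rows: a delete of the old copy and an insert of the new copy
            Seq("DELETE" +: deleteRow, "INSERT" +: newValues)
          case DeleteAction =>
            // a DELETE contributes a single delete row
            Seq("DELETE" +: deleteRow)
        }

      def main(args: Array[String]): Unit =
        println(outputRows(UpdateAction(Seq("new_value")), deleteRow = Seq("old_id")))
    }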
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
index abadab4e53..0b1871038f 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
@@ -23,12 +23,17 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.V2ExpressionUtils
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.write.RowLevelOperation
import org.apache.spark.sql.connector.write.SupportsDelta
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
@@ -67,6 +72,97 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
ProjectingInternalRow(schema, projectedOrdinals)
}
+ protected def buildDeltaProjections(
+ plan: LogicalPlan,
+ outputs: Seq[Seq[Expression]],
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+
+ val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION))
+ val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION))
+
+ val rowProjection = if (rowAttrs.nonEmpty) {
+ Some(newLazyProjection(insertAndUpdateOutputs, plan.output, rowAttrs))
+ } else {
+ None
+ }
+
+ val rowIdProjection = newLazyProjection(updateAndDeleteOutputs, plan.output, rowIdAttrs)
+
+ val metadataProjection = if (metadataAttrs.nonEmpty) {
+ Some(newLazyProjection(updateAndDeleteOutputs, plan.output, metadataAttrs))
+ } else {
+ None
+ }
+
+ WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
+ }
+
+ // the projection is done by name, ignoring expr IDs
+ private def newLazyProjection(
+ outputs: Seq[Seq[Expression]],
+ outputAttrs: Seq[Attribute],
+ projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
+
+ val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name))
+
+ val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
+ // output attr is nullable if at least one output projection may produce null for that attr
+ // but row ID and metadata attrs are projected only for update/delete records and
+ // row attrs are projected only in insert/update records
+ // that's why the projection schema must rely only on relevant outputs
+ // instead of blindly inheriting the output attr nullability
+ val nullable = outputs.exists(output => output(ordinal).nullable)
+ StructField(attr.name, attr.dataType, nullable, attr.metadata)
+ }
+ val schema = StructType(structFields)
+
+ ProjectingInternalRow(schema, projectedOrdinals)
+ }
+
+ protected def deltaDeleteOutput(
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+ val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
+ Seq(Literal(DELETE_OPERATION)) ++ deleteRowValues ++ metadataAttrs
+ }
+
+ protected def deltaInsertOutput(
+ rowValues: Seq[Expression],
+ metadataAttrs: Seq[Attribute]): Seq[Expression] = {
+ val metadataValues = metadataAttrs.map(attr => Literal(null, attr.dataType))
+ Seq(Literal(INSERT_OPERATION)) ++ rowValues ++ metadataValues
+ }
+
+ private def buildDeltaDeleteRowValues(
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute]): Seq[Expression] = {
+
+ // nullify all row attrs that are not part of the row ID
+ val rowIdAttSet = AttributeSet(rowIdAttrs)
+ rowAttrs.map {
+ case attr if rowIdAttSet.contains(attr) => attr
+ case attr => Literal(null, attr.dataType)
+ }
+ }
+
+ protected def buildMergingOutput(
+ outputs: Seq[Seq[Expression]],
+ attrs: Seq[Attribute]): Seq[Attribute] = {
+
+ // build a correct nullability map for output attributes
+ // an attribute is nullable if at least one output may produce null
+ val nullabilityMap = attrs.indices.map { index =>
+ index -> outputs.exists(output => output(index).nullable)
+ }.toMap
+
+ attrs.zipWithIndex.map { case (attr, index) =>
+ AttributeReference(attr.name, attr.dataType, nullabilityMap(index))()
+ }
+ }
+
protected def resolveRowIdAttrs(
relation: DataSourceV2Relation,
operation: RowLevelOperation): Seq[AttributeReference] = {
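
To illustrate buildDeltaDeleteRowValues above: row attributes that are not part of the row ID are nullified, so a delete record carries only what is needed to identify the original row. A small sketch with hypothetical column names, using strings instead of Catalyst attributes:

    // buildDeltaDeleteRowValues, sketched with plain strings.
    object DeleteRowValuesSketch {
      def main(args: Array[String]): Unit = {
        val rowAttrs = Seq("id", "dep", "salary") // ordinary row attributes
        val rowIdAttrs = Set("id")                // attributes that identify the row

        // non-row-ID attributes are replaced with nulls in the delete output
        val deleteRowValues = rowAttrs.map(a => if (rowIdAttrs.contains(a)) a else "null")
        println(deleteRowValues)                  // List(id, null, null)
      }
    }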
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
index 006040081b..bbe5532185 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteUpdateTable.scala
@@ -21,6 +21,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.EqualNullSafe
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.If
@@ -34,13 +36,16 @@ import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData
import org.apache.spark.sql.catalyst.plans.logical.Union
import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable
+import org.apache.spark.sql.catalyst.plans.logical.UpdateRows
import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
+import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPDATE
import org.apache.spark.sql.connector.write.RowLevelOperationTable
import org.apache.spark.sql.connector.write.SupportsDelta
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
@@ -142,17 +147,45 @@ object RewriteUpdateTable extends RewriteRowLevelIcebergCommand {
// construct a scan relation and include all required metadata columns
val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs)
+ val readAttrs = readRelation.output
+ val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains)
// build a plan for updated records that match the cond
val matchedRowsPlan = Filter(cond, readRelation)
- val updatedRowsPlan = buildUpdateProjection(matchedRowsPlan, assignments)
- val operationType = Alias(Literal(UPDATE_OPERATION), OPERATION_COLUMN)()
- val project = Project(operationType +: updatedRowsPlan.output, updatedRowsPlan)
+ val updatedRowsPlan = updateRows(
+ matchedRowsPlan, assignments, readAttrs,
+ rowAttrs, rowIdAttrs, metadataReadAttrs)
// build a plan to write the row delta to the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildWriteDeltaProjections(project, rowAttrs, rowIdAttrs, metadataAttrs)
- WriteIcebergDelta(writeRelation, project, relation, projections)
+ val projections = buildDeltaProjections(updatedRowsPlan, rowAttrs, rowIdAttrs, metadataAttrs)
+ WriteIcebergDelta(writeRelation, updatedRowsPlan, relation, projections)
+ }
+
+ private def updateRows(
+ matchedRowsPlan: LogicalPlan,
+ assignments: Seq[Assignment],
+ readAttrs: Seq[Attribute],
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): UpdateRows = {
+
+ val delete = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs)
+ val insert = deltaInsertOutput(assignments.map(_.value), metadataAttrs)
+ val outputs = Seq(delete, insert)
+ val operationTypeAttr = AttributeReference(OPERATION_COLUMN, IntegerType, nullable = false)()
+ val updateRowsOutput = buildMergingOutput(outputs, operationTypeAttr +: readAttrs)
+ UpdateRows(delete, insert, updateRowsOutput, matchedRowsPlan)
+ }
+
+ private def buildDeltaProjections(
+ updateRows: UpdateRows,
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+
+ val outputs = Seq(updateRows.deleteOutput, updateRows.insertOutput)
+ buildDeltaProjections(updateRows, outputs, rowAttrs, rowIdAttrs, metadataAttrs)
}
// this method assumes the assignments have been already aligned before
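
Roughly, the UPDATE rewrite above now produces a plan of the following shape for delta-based (merge-on-read) tables; attribute lists are abbreviated and this is only a sketch of the structure, not exact plan output:

    WriteIcebergDelta
    +- UpdateRows(
         deleteOutput = [DELETE_OPERATION, <nullified row attrs>, <metadata attrs>],
         insertOutput = [INSERT_OPERATION, <assignment values>, <null metadata>],
         output = [<operation column>, <read attrs>])
       +- Filter(<update condition>)
          +- <scan relation with row ID and metadata columns>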
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
index 9c17ef02e6..57f8bf3583 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
@@ -28,7 +28,7 @@ case class MergeRows(
isSourceRowPresent: Expression,
isTargetRowPresent: Expression,
matchedConditions: Seq[Expression],
- matchedOutputs: Seq[Seq[Expression]],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
notMatchedConditions: Seq[Expression],
notMatchedOutputs: Seq[Seq[Expression]],
targetOutput: Seq[Expression],
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateRows.scala
similarity index 72%
copy from spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
copy to spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateRows.scala
index 9c17ef02e6..0446f30404 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MergeRows.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/UpdateRows.scala
@@ -24,29 +24,18 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.util.truncatedString
-case class MergeRows(
- isSourceRowPresent: Expression,
- isTargetRowPresent: Expression,
- matchedConditions: Seq[Expression],
- matchedOutputs: Seq[Seq[Expression]],
- notMatchedConditions: Seq[Expression],
- notMatchedOutputs: Seq[Seq[Expression]],
- targetOutput: Seq[Expression],
- performCardinalityCheck: Boolean,
- emitNotMatchedTargetRows: Boolean,
+case class UpdateRows(
+ deleteOutput: Seq[Expression],
+ insertOutput: Seq[Expression],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
- require(targetOutput.nonEmpty || !emitNotMatchedTargetRows)
-
override lazy val producedAttributes: AttributeSet = {
AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
}
- override lazy val references: AttributeSet = child.outputSet
-
override def simpleString(maxFields: Int): String = {
- s"MergeRows${truncatedString(output, "[", ", ", "]", maxFields)}"
+ s"UpdateRows${truncatedString(output, "[", ", ", "]", maxFields)}"
}
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = {
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
index e5ec638c0c..aa81e38f29 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -43,6 +43,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData
import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField
import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields
import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering
+import org.apache.spark.sql.catalyst.plans.logical.UpdateRows
import org.apache.spark.sql.catalyst.plans.logical.WriteIcebergDelta
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.TableCatalog
@@ -104,6 +105,9 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy wi
notMatchedOutputs, targetOutput, performCardinalityCheck, emitNotMatchedTargetRows,
output, planLater(child)) :: Nil
+ case UpdateRows(deleteOutput, insertOutput, output, child) =>
+ UpdateRowsExec(deleteOutput, insertOutput, output, planLater(child)) :: Nil
+
case NoStatsUnaryNode(child) =>
planLater(child) :: Nil
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
index 0f9d0fa7ac..474a417d33 100644
--- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/MergeRowsExec.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeSet
import org.apache.spark.sql.catalyst.expressions.BasePredicate
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.Projection
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -37,7 +38,7 @@ case class MergeRowsExec(
isSourceRowPresent: Expression,
isTargetRowPresent: Expression,
matchedConditions: Seq[Expression],
- matchedOutputs: Seq[Seq[Expression]],
+ matchedOutputs: Seq[Seq[Seq[Expression]]],
notMatchedConditions: Seq[Expression],
notMatchedOutputs: Seq[Seq[Expression]],
targetOutput: Seq[Expression],
@@ -66,115 +67,224 @@ case class MergeRowsExec(
child.execute().mapPartitions(processPartition)
}
- private def createProjection(exprs: Seq[Expression], attrs: Seq[Attribute]): UnsafeProjection = {
- UnsafeProjection.create(exprs, attrs)
+ private def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ UnsafeProjection.create(exprs, child.output)
}
- private def createPredicate(expr: Expression, attrs: Seq[Attribute]): BasePredicate = {
- GeneratePredicate.generate(expr, attrs)
+ private def createPredicate(expr: Expression): BasePredicate = {
+ GeneratePredicate.generate(expr, child.output)
}
- private def applyProjection(
- actions: Seq[(BasePredicate, Option[UnsafeProjection])],
- inputRow: InternalRow): InternalRow = {
+ // This method is responsible for processing an input row to emit the resultant row with an
+ // additional column that indicates whether the row is going to be included in the final
+ // output of merge or not.
+ // 1. Found a target row for which there is no corresponding source row (join condition not met)
+ // - Only project the target columns if we need to output unchanged rows (group-based commands)
+ // 2. Found a source row for which there is no corresponding target row (join condition not met)
+ // - Apply the not matched actions (i.e. INSERT actions) if the not matched conditions are met.
+ // 3. Found a source row for which there is a corresponding target row (join condition met)
+ // - Apply the matched actions (i.e. DELETE or UPDATE actions) if the matched conditions are met.
+ private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val isSourceRowPresentPred = createPredicate(isSourceRowPresent)
+ val isTargetRowPresentPred = createPredicate(isTargetRowPresent)
- // find the first action where the predicate evaluates to true
- // if there are overlapping conditions in actions, use the first matching action
- // in the example below, when id = 5, both actions match but the first one is applied
- // WHEN MATCHED AND id > 1 AND id < 10 UPDATE *
- // WHEN MATCHED AND id = 5 OR id = 21 DELETE
+ val matchedActions = matchedConditions.zip(matchedOutputs).map { case (cond, outputs) =>
+ outputs match {
+ case Seq(output1, output2) =>
+ Split(createPredicate(cond), createProjection(output1), createProjection(output2))
+ case Seq(output) =>
+ Project(createPredicate(cond), createProjection(output))
+ case Nil =>
+ Project(createPredicate(cond), EmptyProjection)
+ }
+ }
- val pair = actions.find {
- case (predicate, _) => predicate.eval(inputRow)
+ val notMatchedActions = notMatchedConditions.zip(notMatchedOutputs).map { case (cond, output) =>
+ Project(createPredicate(cond), createProjection(output))
}
- // apply the projection to produce an output row, or return null to suppress this row
- pair match {
- case Some((_, Some(projection))) =>
- projection.apply(inputRow)
- case _ =>
- null
+ val projectTargetCols = createProjection(targetOutput)
+
+ val cardinalityCheck = if (performCardinalityCheck) {
+ val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
+ assert(rowIdOrdinal != -1, "Cannot find row ID attr")
+ BitmapCardinalityCheck(rowIdOrdinal)
+ } else {
+ EmptyCardinalityCheck
}
+
+ val mergeIterator = if (matchedActions.exists(_.isInstanceOf[Split])) {
+ new SplittingMergeRowIterator(
+ rowIterator, cardinalityCheck, isTargetRowPresentPred,
+ matchedActions, notMatchedActions)
+ } else {
+ new MergeRowIterator(
+ rowIterator, cardinalityCheck, isTargetRowPresentPred, isSourceRowPresentPred,
+ projectTargetCols, matchedActions.asInstanceOf[Seq[Project]], notMatchedActions)
+ }
+
+ // null indicates a record must be discarded
+ mergeIterator.filter(_ != null)
}
- private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
- val inputAttrs = child.output
+ trait Action {
+ def cond: BasePredicate
+ }
+
+ case class Project(cond: BasePredicate, proj: Projection) extends Action {
+ def apply(row: InternalRow): InternalRow = proj.apply(row)
+ }
- val isSourceRowPresentPred = createPredicate(isSourceRowPresent, inputAttrs)
- val isTargetRowPresentPred = createPredicate(isTargetRowPresent, inputAttrs)
+ case class Split(cond: BasePredicate, proj: Projection, otherProj: Projection) extends Action {
+ def projectRow(row: InternalRow): InternalRow = proj.apply(row)
+ def projectExtraRow(row: InternalRow): InternalRow = otherProj.apply(row)
+ }
- val matchedPreds = matchedConditions.map(createPredicate(_, inputAttrs))
- val matchedProjs = matchedOutputs.map {
- case output if output.nonEmpty => Some(createProjection(output, inputAttrs))
- case _ => None
- }
- val matchedPairs = matchedPreds zip matchedProjs
+ object EmptyProjection extends Projection {
+ override def apply(row: InternalRow): InternalRow = null
+ }
- val notMatchedPreds = notMatchedConditions.map(createPredicate(_, inputAttrs))
- val notMatchedProjs = notMatchedOutputs.map {
- case output if output.nonEmpty => Some(createProjection(output, inputAttrs))
- case _ => None
- }
- val nonMatchedPairs = notMatchedPreds zip notMatchedProjs
-
- val projectTargetCols = createProjection(targetOutput, inputAttrs)
-
- // This method is responsible for processing a input row to emit the resultant row with an
- // additional column that indicates whether the row is going to be included in the final
- // output of merge or not.
- // 1. Found a target row for which there is no corresponding source row (join condition not met)
- // - Only project the target columns if we need to output unchanged rows
- // 2. Found a source row for which there is no corresponding target row (join condition not met)
- // - Apply the not matched actions (i.e INSERT actions) if non match conditions are met.
- // 3. Found a source row for which there is a corresponding target row (join condition met)
- // - Apply the matched actions (i.e DELETE or UPDATE actions) if match conditions are met.
- def processRow(inputRow: InternalRow): InternalRow = {
- if (emitNotMatchedTargetRows && !isSourceRowPresentPred.eval(inputRow)) {
- projectTargetCols.apply(inputRow)
- } else if (!isTargetRowPresentPred.eval(inputRow)) {
- applyProjection(nonMatchedPairs, inputRow)
+ class MergeRowIterator(
+ private val rowIterator: Iterator[InternalRow],
+ private val cardinalityCheck: CardinalityCheck,
+ private val isTargetRowPresentPred: BasePredicate,
+ private val isSourceRowPresentPred: BasePredicate,
+ private val targetTableProj: Projection,
+ private val matchedActions: Seq[Project],
+ private val notMatchedActions: Seq[Project])
+ extends Iterator[InternalRow] {
+
+ override def hasNext: Boolean = rowIterator.hasNext
+
+ override def next(): InternalRow = {
+ val row = rowIterator.next()
+
+ val isSourceRowPresent = isSourceRowPresentPred.eval(row)
+ val isTargetRowPresent = isTargetRowPresentPred.eval(row)
+
+ if (isTargetRowPresent && isSourceRowPresent) {
+ cardinalityCheck.execute(row)
+ applyMatchedActions(row)
+ } else if (isSourceRowPresent) {
+ applyNotMatchedActions(row)
+ } else if (emitNotMatchedTargetRows && isTargetRowPresent) {
+ targetTableProj.apply(row)
} else {
- applyProjection(matchedPairs, inputRow)
+ null
}
}
- val matchedRowIds = new Roaring64Bitmap()
+ private def applyMatchedActions(row: InternalRow): InternalRow = {
+ for (action <- matchedActions) {
+ if (action.cond.eval(row)) {
+ return action.apply(row)
+ }
+ }
- def processRowWithCardinalityCheck(rowIdOrdinal: Int)(inputRow: InternalRow): InternalRow = {
- val isSourceRowPresent = isSourceRowPresentPred.eval(inputRow)
- val isTargetRowPresent = isTargetRowPresentPred.eval(inputRow)
+ if (emitNotMatchedTargetRows) targetTableProj.apply(row) else null
+ }
- if (isSourceRowPresent && isTargetRowPresent) {
- val currentRowId = inputRow.getLong(rowIdOrdinal)
- if (matchedRowIds.contains(currentRowId)) {
- throw new SparkException(
- "The ON search condition of the MERGE statement matched a single row from " +
- "the target table with multiple rows of the source table. This could result " +
- "in the target row being operated on more than once with an update or delete " +
- "operation and is not allowed.")
+ private def applyNotMatchedActions(row: InternalRow): InternalRow = {
+ for (action <- notMatchedActions) {
+ if (action.cond.eval(row)) {
+ return action.apply(row)
}
- matchedRowIds.add(currentRowId)
}
- if (emitNotMatchedTargetRows && !isSourceRowPresent) {
- projectTargetCols.apply(inputRow)
- } else if (!isTargetRowPresent) {
- applyProjection(nonMatchedPairs, inputRow)
+ null
+ }
+ }
+
+ /**
+ * An iterator that splits updates into deletes and inserts.
+ *
+ * Each input row that represents an update becomes two output rows: a delete and an insert.
+ */
+ class SplittingMergeRowIterator(
+ private val rowIterator: Iterator[InternalRow],
+ private val cardinalityCheck: CardinalityCheck,
+ private val isTargetRowPresentPred: BasePredicate,
+ private val matchedActions: Seq[Action],
+ private val notMatchedActions: Seq[Project])
+ extends Iterator[InternalRow] {
+
+ var cachedExtraRow: InternalRow = _
+
+ override def hasNext: Boolean = cachedExtraRow != null || rowIterator.hasNext
+
+ override def next(): InternalRow = {
+ if (cachedExtraRow != null) {
+ val extraRow = cachedExtraRow
+ cachedExtraRow = null
+ return extraRow
+ }
+
+ val row = rowIterator.next()
+
+ // it should be OK to just check if the target row exists
+ // as this iterator is only used for delta-based row-level plans
+ // that are rewritten using an inner or right outer join
+ if (isTargetRowPresentPred.eval(row)) {
+ cardinalityCheck.execute(row)
+ applyMatchedActions(row)
} else {
- applyProjection(matchedPairs, inputRow)
+ applyNotMatchedActions(row)
}
}
- val processFunc: InternalRow => InternalRow = if (performCardinalityCheck) {
- val rowIdOrdinal = child.output.indexWhere(attr => conf.resolver(attr.name, ROW_ID))
- assert(rowIdOrdinal != -1, "Cannot find row ID attr")
- processRowWithCardinalityCheck(rowIdOrdinal)
- } else {
- processRow
+ private def applyMatchedActions(row: InternalRow): InternalRow = {
+ for (action <- matchedActions) {
+ if (action.cond.eval(row)) {
+ action match {
+ case split: Split =>
+ cachedExtraRow = split.projectExtraRow(row)
+ return split.projectRow(row)
+ case project: Project =>
+ return project.apply(row)
+ }
+ }
+ }
+
+ null
}
- rowIterator
- .map(processFunc)
- .filter(row => row != null)
+ private def applyNotMatchedActions(row: InternalRow): InternalRow = {
+ for (action <- notMatchedActions) {
+ if (action.cond.eval(row)) {
+ return action.apply(row)
+ }
+ }
+
+ null
+ }
+ }
+
+ sealed trait CardinalityCheck {
+
+ def execute(inputRow: InternalRow): Unit
+
+ protected def fail(): Unit = {
+ throw new SparkException(
+ "The ON search condition of the MERGE statement matched a single row from " +
+ "the target table with multiple rows of the source table. This could result " +
+ "in the target row being operated on more than once with an update or delete " +
+ "operation and is not allowed.")
+ }
+ }
+
+ object EmptyCardinalityCheck extends CardinalityCheck {
+ def execute(inputRow: InternalRow): Unit = {}
+ }
+
+ case class BitmapCardinalityCheck(rowIdOrdinal: Int) extends CardinalityCheck {
+ private val matchedRowIds = new Roaring64Bitmap()
+
+ override def execute(inputRow: InternalRow): Unit = {
+ val currentRowId = inputRow.getLong(rowIdOrdinal)
+ if (matchedRowIds.contains(currentRowId)) {
+ fail()
+ }
+ matchedRowIds.add(currentRowId)
+ }
}
}
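
Two details of MergeRowsExec above are worth calling out: updates from Split actions are emitted as two rows (the delete immediately, the insert cached for the next call), and the cardinality check now sits behind a small interface so both iterators can share it. A stripped-down sketch of that check, using a plain Set where the real code uses a Roaring64Bitmap:

    // Fail if the same target row ID is matched by more than one source row.
    class CardinalityCheckSketch {
      private val seenRowIds = scala.collection.mutable.Set.empty[Long]

      def execute(rowId: Long): Unit = {
        if (!seenRowIds.add(rowId)) {
          throw new IllegalStateException(
            "the ON search condition of the MERGE statement matched a single row from " +
              "the target table with multiple rows of the source table")
        }
      }
    }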
diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/UpdateRowsExec.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/UpdateRowsExec.scala
new file mode 100644
index 0000000000..ef4ad37339
--- /dev/null
+++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/UpdateRowsExec.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.UnaryExecNode
+
+case class UpdateRowsExec(
+ deleteOutput: Seq[Expression],
+ insertOutput: Seq[Expression],
+ output: Seq[Attribute],
+ child: SparkPlan) extends UnaryExecNode {
+
+ @transient override lazy val producedAttributes: AttributeSet = {
+ AttributeSet(output.filterNot(attr => inputSet.contains(attr)))
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ s"UpdateRowsExec${truncatedString(output, "[", ", ", "]", maxFields)}"
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitions(processPartition)
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = {
+ copy(child = newChild)
+ }
+
+ private def processPartition(rowIterator: Iterator[InternalRow]): Iterator[InternalRow] = {
+ val deleteProj = createProjection(deleteOutput)
+ val insertProj = createProjection(insertOutput)
+ new UpdateAsDeleteAndInsertRowIterator(rowIterator, deleteProj, insertProj)
+ }
+
+ private def createProjection(exprs: Seq[Expression]): UnsafeProjection = {
+ UnsafeProjection.create(exprs, child.output)
+ }
+
+ class UpdateAsDeleteAndInsertRowIterator(
+ private val inputRows: Iterator[InternalRow],
+ private val deleteProj: UnsafeProjection,
+ private val insertProj: UnsafeProjection)
+ extends Iterator[InternalRow] {
+
+ var cachedInsertRow: InternalRow = _
+
+ override def hasNext: Boolean = cachedInsertRow != null || inputRows.hasNext
+
+ override def next(): InternalRow = {
+ if (cachedInsertRow != null) {
+ val insertRow = cachedInsertRow
+ cachedInsertRow = null
+ return insertRow
+ }
+
+ val row = inputRows.next()
+ cachedInsertRow = insertProj.apply(row)
+ deleteProj.apply(row)
+ }
+ }
+}
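
The iterator in UpdateRowsExec above drives the split at execution time: for every input row it returns the delete projection immediately and caches the insert projection for the next call, so hasNext stays true until the cached insert has been emitted. A generic sketch of that emission order, with plain functions in place of UnsafeProjections:

    // Each input element yields its delete image first, then its insert image.
    class SplitIteratorSketch[T](input: Iterator[T], deleteImage: T => T, insertImage: T => T)
      extends Iterator[T] {

      private var cachedInsert: Option[T] = None

      override def hasNext: Boolean = cachedInsert.isDefined || input.hasNext

      override def next(): T = cachedInsert match {
        case Some(insert) =>
          cachedInsert = None
          insert
        case None =>
          val row = input.next()
          cachedInsert = Some(insertImage(row))
          deleteImage(row)
      }
    }

    // new SplitIteratorSketch(Iterator("a", "b"), (r: String) => s"DELETE $r", (r: String) => s"INSERT $r")
    //   emits: DELETE a, INSERT a, DELETE b, INSERT b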
diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index c008017d50..7222777d25 100644
--- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -645,6 +645,40 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
sql("SELECT * FROM %s ORDER BY id", selectTarget()));
}
+ @Test
+ public void testMergeWithOneMatchingBranchButMultipleSourceRowsForTargetRow() {
+ createAndInitTable(
+ "id INT, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView(
+ "source",
+ "id INT, dep STRING",
+ "{ \"id\": 1, \"state\": \"on\" }\n"
+ + "{ \"id\": 1, \"state\": \"off\" }\n"
+ + "{ \"id\": 10, \"state\": \"on\" }");
+
+ String errorMsg = "a single row from the target table with multiple rows of the source table";
+ Assertions.assertThatThrownBy(
+ () ->
+ sql(
+ "MERGE INTO %s AS t USING source AS s "
+ + "ON t.id == s.id "
+ + "WHEN MATCHED AND t.id = 6 THEN "
+ + " DELETE "
+ + "WHEN NOT MATCHED THEN "
+ + " INSERT (id, dep) VALUES (s.id, 'unknown')",
+ commitTarget()))
+ .cause()
+ .isInstanceOf(SparkException.class)
+ .hasMessageContaining(errorMsg);
+
+ assertEquals(
+ "Target should be unchanged",
+ ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")),
+ sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget()));
+ }
+
@Test
public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() {
createAndInitTable(
diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java
index f2c8f6e26c..3180419b9b 100644
--- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java
+++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SparkDistributionAndOrderingUtil.java
@@ -218,7 +218,7 @@ public class SparkDistributionAndOrderingUtil {
}
public static SortOrder[] buildPositionDeltaOrdering(Table table, Command command) {
- if (command == DELETE || command == UPDATE) {
+ if (command == DELETE) {
return POSITION_DELETE_ORDERING;
} else {
// all metadata columns like _spec_id, _file, _pos will be null for new data records
diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java
index c6b1eaeceb..4fff992e87 100644
--- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java
+++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/TestSparkDistributionAndOrderingUtil.java
@@ -1596,11 +1596,22 @@ public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatal
Table table = validationCatalog.loadTable(tableIdent);
+ SortOrder[] expectedOrdering =
+ new SortOrder[] {
+ Expressions.sort(
+ Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING)
+ };
+
checkPositionDeltaDistributionAndOrdering(
- table,
- UPDATE,
- SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION,
- SPEC_ID_PARTITION_FILE_POSITION_ORDERING);
+ table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, expectedOrdering);
}
@Test
@@ -1615,8 +1626,22 @@ public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatal
table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_NONE).commit();
+ SortOrder[] expectedOrdering =
+ new SortOrder[] {
+ Expressions.sort(
+ Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING)
+ };
+
checkPositionDeltaDistributionAndOrdering(
- table, UPDATE, UNSPECIFIED_DISTRIBUTION, SPEC_ID_PARTITION_FILE_POSITION_ORDERING);
+ table, UPDATE, UNSPECIFIED_DISTRIBUTION, expectedOrdering);
}
@Test
@@ -1631,11 +1656,22 @@ public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatal
table.updateProperties().set(UPDATE_DISTRIBUTION_MODE, WRITE_DISTRIBUTION_MODE_HASH).commit();
+ SortOrder[] expectedOrdering =
+ new SortOrder[] {
+ Expressions.sort(
+ Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING)
+ };
+
checkPositionDeltaDistributionAndOrdering(
- table,
- UPDATE,
- SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION,
- SPEC_ID_PARTITION_FILE_POSITION_ORDERING);
+ table, UPDATE, SPEC_ID_PARTITION_CLUSTERED_DISTRIBUTION, expectedOrdering);
}
@Test
@@ -1652,8 +1688,22 @@ public class TestSparkDistributionAndOrderingUtil extends SparkTestBaseWithCatal
Distribution expectedDistribution = Distributions.ordered(SPEC_ID_PARTITION_FILE_ORDERING);
+ SortOrder[] expectedOrdering =
+ new SortOrder[] {
+ Expressions.sort(
+ Expressions.column(MetadataColumns.SPEC_ID.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.PARTITION_COLUMN_NAME), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.FILE_PATH.name()), SortDirection.ASCENDING),
+ Expressions.sort(
+ Expressions.column(MetadataColumns.ROW_POSITION.name()), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.column("date"), SortDirection.ASCENDING),
+ Expressions.sort(Expressions.bucket(8, "data"), SortDirection.ASCENDING)
+ };
+
checkPositionDeltaDistributionAndOrdering(
- table, UPDATE, expectedDistribution, SPEC_ID_PARTITION_FILE_POSITION_ORDERING);
+ table, UPDATE, expectedDistribution, expectedOrdering);
}
// ==================================================================================