You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by ao...@apache.org on 2023/01/06 21:52:35 UTC

[iceberg] branch master updated: Spark 3.3: Use regular planning for applicable row-level operations (#6534)

This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 8026442e48 Spark 3.3: Use regular planning for applicable row-level operations (#6534)
8026442e48 is described below

commit 8026442e486c5fde70c855fe88839bc2411a5ba1
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Fri Jan 6 13:52:29 2023 -0800

    Spark 3.3: Use regular planning for applicable row-level operations (#6534)
---
 .../catalyst/analysis/RewriteMergeIntoTable.scala  | 29 +++++++--
 .../v2/RowLevelCommandScanRelationPushDown.scala   | 12 +++-
 .../SparkRowLevelOperationsTestBase.java           |  6 +-
 .../apache/iceberg/spark/extensions/TestMerge.java | 75 ++++++++++++++++++++++
 .../spark/source/SparkCopyOnWriteOperation.java    |  2 -
 .../apache/iceberg/spark/source/SparkWrite.java    | 25 ++++++--
 .../iceberg/spark/source/SparkWriteBuilder.java    |  2 -
 7 files changed, 134 insertions(+), 17 deletions(-)

diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 2e720bdd44..ca37f99955 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ProjectingInternalRow
 import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.And
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.AttributeSet
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
 import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.plans.FullOuter
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.LeftAnti
@@ -74,7 +76,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  * This rule assumes the commands have been fully resolved and all assignments have been aligned.
  * That's why it must be run after AlignRowLevelCommandAssignments.
  */
-object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
+object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with PredicateHelper {
 
   private final val ROW_FROM_SOURCE = "__row_from_source"
   private final val ROW_FROM_TARGET = "__row_from_target"
@@ -185,12 +187,14 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
     val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)
     val readAttrs = readRelation.output
 
+    val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
     // project an extra column to check if a target row exists after the join
     // project a synthetic row ID to perform the cardinality check
     val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)()
     val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)()
     val targetTableProjExprs = readAttrs ++ Seq(rowFromTarget, rowId)
-    val targetTableProj = Project(targetTableProjExprs, readRelation)
+    val targetTableProj = Project(targetTableProjExprs, Filter(targetCond, readRelation))
 
     // project an extra column to check if a source row exists after the join
     val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)()
@@ -202,7 +206,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
     // disable broadcasts for the target table to perform the cardinality check
     val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
     val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
-    val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint)
+    val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint)
 
     // add an extra matched action to output the original row if none of the actual actions matched
     // this is needed to keep target rows that should be copied over
@@ -253,9 +257,11 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
     val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs)
     val readAttrs = readRelation.output
 
+    val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
     // project an extra column to check if a target row exists after the join
     val targetTableProjExprs = readAttrs :+ Alias(TrueLiteral, ROW_FROM_TARGET)()
-    val targetTableProj = Project(targetTableProjExprs, readRelation)
+    val targetTableProj = Project(targetTableProjExprs, Filter(targetCond, readRelation))
 
     // project an extra column to check if a source row exists after the join
     val sourceTableProjExprs = source.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)()
@@ -266,7 +272,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
     // also disable broadcasts for the target table to perform the cardinality check
     val joinType = if (notMatchedActions.isEmpty) Inner else RightOuter
     val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
-    val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint)
+    val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint)
 
     val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
     val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains)
@@ -439,4 +445,17 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
 
     ProjectingInternalRow(schema, projectedOrdinals)
   }
+
+  // splits the conjuncts of the MERGE condition into a predicate that references only target table
+  // columns, which can be pushed down, and a join condition; an empty side becomes TrueLiteral
+  private def splitMergeCond(
+      cond: Expression,
+      targetTable: LogicalPlan): (Expression, Expression) = {
+
+    val (targetPredicates, joinPredicates) = splitConjunctivePredicates(cond)
+      .partition(_.references.subsetOf(targetTable.outputSet))
+    val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral)
+    val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral)
+    (targetCond, joinCond)
+  }
 }
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 2be73cb6ee..9ee3035c26 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
+import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -38,8 +40,16 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
   import ExtendedDataSourceV2Implicits._
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    // use native Spark planning for delta-based plans and copy-on-write MERGE operations
+    // unlike other commands, these plans have filters that can be pushed down directly
+    case RewrittenRowLevelCommand(command, _: DataSourceV2Relation, rewritePlan)
+        if rewritePlan.isInstanceOf[WriteDelta] || command.isInstanceOf[MergeIntoIcebergTable] =>
+
+      val newRewritePlan = V2ScanRelationPushDown.apply(rewritePlan)
+      command.withNewRewritePlan(newRewritePlan)
+
     // push down the filter from the command condition instead of the filter in the rewrite plan,
-    // which may be negated for copy-on-write operations
+    // which may be negated for copy-on-write DELETE and UPDATE operations
     case RewrittenRowLevelCommand(command, relation: DataSourceV2Relation, rewritePlan) =>
       val table = relation.table.asRowLevelOperationTable
       val scanBuilder = table.newScanBuilder(relation.options)
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
index 633b2ee431..3039958d2c 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -170,7 +170,11 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes
   }
 
   protected void createAndInitTable(String schema, String jsonData) {
-    sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema);
+    createAndInitTable(schema, "", jsonData);
+  }
+
+  protected void createAndInitTable(String schema, String partitioning, String jsonData) {
+    sql("CREATE TABLE %s (%s) USING iceberg %s", tableName, schema, partitioning);
     initTable();
 
     if (jsonData != null) {
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 58fbb6241e..c598cb720c 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -18,7 +18,10 @@
  */
 package org.apache.iceberg.spark.extensions;
 
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
 import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL;
+import static org.apache.iceberg.TableProperties.MERGE_MODE;
+import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT;
 import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
 import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
 import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
@@ -39,6 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.iceberg.AssertHelpers;
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.RowLevelOperationMode;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.SnapshotSummary;
 import org.apache.iceberg.Table;
@@ -54,6 +58,7 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.execution.SparkPlan;
 import org.apache.spark.sql.internal.SQLConf;
 import org.assertj.core.api.Assertions;
 import org.junit.After;
@@ -85,6 +90,55 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
     sql("DROP TABLE IF EXISTS source");
   }
 
+  @Test
+  public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() {
+    createAndInitTable(
+        "id INT, salary INT, dep STRING, sub_dep STRING",
+        "PARTITIONED BY (dep, sub_dep)",
+        "{ \"id\": 1, \"salary\": 100, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n"
+            + "{ \"id\": 6, \"salary\": 600, \"dep\": \"d6\", \"sub_dep\": \"sd6\" }");
+
+    createOrReplaceView(
+        "source",
+        "id INT, salary INT, dep STRING, sub_dep STRING",
+        "{ \"id\": 1, \"salary\": 101, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n"
+            + "{ \"id\": 2, \"salary\": 200, \"dep\": \"d2\", \"sub_dep\": \"sd2\" }\n"
+            + "{ \"id\": 3, \"salary\": 300, \"dep\": \"d3\", \"sub_dep\": \"sd3\"  }");
+
+    String query =
+        String.format(
+            "MERGE INTO %s AS t USING source AS s "
+                + "ON t.id == s.id AND ((t.dep = 'd1' AND t.sub_dep IN ('sd1', 'sd3')) OR (t.dep = 'd6' AND t.sub_dep IN ('sd2', 'sd6'))) "
+                + "WHEN MATCHED THEN "
+                + "  UPDATE SET salary = s.salary "
+                + "WHEN NOT MATCHED THEN "
+                + "  INSERT *",
+            tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    if (mode(table) == COPY_ON_WRITE) {
+      checkJoinAndFilterConditions(
+          query,
+          "Join [id], [id], FullOuter",
+          "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))");
+    } else {
+      checkJoinAndFilterConditions(
+          query,
+          "Join [id], [id], RightOuter",
+          "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))");
+    }
+
+    assertEquals(
+        "Should have expected rows",
+        ImmutableList.of(
+            row(1, 101, "d1", "sd1"), // matched on id, salary updated from source
+            row(2, 200, "d2", "sd2"), // not matched, inserted from source
+            row(3, 300, "d3", "sd3"), // not matched, inserted from source
+            row(6, 600, "d6", "sd6")), // target row without a source match, kept as is
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
   @Test
   public void testMergeWithStaticPredicatePushDown() {
     createAndInitTable("id BIGINT, dep STRING");
@@ -2274,4 +2328,25 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
     List<Object[]> result = sql("SELECT * FROM %s ORDER BY id", tableName);
     assertEquals("Should correctly add the non-matching rows", expectedRows, result);
   }
+
+  private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) {
+    // disable runtime filtering (dynamic partition pruning) for easier plan validation
+    withSQLConf(
+        ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"),
+        () -> {
+          SparkPlan sparkPlan = executeAndKeepPlan(() -> sql(query));
+          String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", "");
+
+          Assertions.assertThat(planAsString).as("Join should match").contains(join + "\n");
+
+          Assertions.assertThat(planAsString)
+              .as("Pushed filters must match")
+              .contains("[filters=" + icebergFilters + ",");
+        });
+  }
+
+  private RowLevelOperationMode mode(Table table) {
+    String modeName = table.properties().getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT);
+    return RowLevelOperationMode.fromName(modeName);
+  }
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
index 72c243fcbc..68c9944044 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
@@ -24,7 +24,6 @@ import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPD
 import org.apache.iceberg.IsolationLevel;
 import org.apache.iceberg.MetadataColumns;
 import org.apache.iceberg.Table;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.connector.expressions.Expressions;
 import org.apache.spark.sql.connector.expressions.NamedReference;
@@ -81,7 +80,6 @@ class SparkCopyOnWriteOperation implements RowLevelOperation {
   @Override
   public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
     if (lazyWriteBuilder == null) {
-      Preconditions.checkState(configuredScan != null, "Write must be configured after scan");
       SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, info);
       lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel);
     }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 0b19fec9fd..f77d96da7f 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -383,7 +383,11 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
     }
 
     private List<DataFile> overwrittenFiles() {
-      return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+      if (scan == null) {
+        return ImmutableList.of();
+      } else {
+        return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+      }
     }
 
     private Expression conflictDetectionFilter() {
@@ -415,12 +419,21 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
         overwriteFiles.addFile(file);
       }
 
-      if (isolationLevel == SERIALIZABLE) {
-        commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
-      } else if (isolationLevel == SNAPSHOT) {
-        commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+      // the scan may be null if the optimizer replaces it with an empty relation (e.g. false cond)
+      // no validation is needed in this case as the command does not depend on the table state
+      if (scan != null) {
+        if (isolationLevel == SERIALIZABLE) {
+          commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+        } else if (isolationLevel == SNAPSHOT) {
+          commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+        } else {
+          throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel);
+        }
+
       } else {
-        throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel);
+        commitOperation(
+            overwriteFiles,
+            String.format("overwrite with %d new data files (no validation)", numAddedFiles));
       }
     }
 
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
index 6483f13048..55cf7961e9 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
@@ -86,8 +86,6 @@ class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, Suppo
   }
 
   public WriteBuilder overwriteFiles(Scan scan, Command command, IsolationLevel isolationLevel) {
-    Preconditions.checkArgument(
-        scan instanceof SparkCopyOnWriteScan, "%s is not SparkCopyOnWriteScan", scan);
     Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual files and by filter");
     Preconditions.checkState(
         !overwriteDynamic, "Cannot overwrite individual files and dynamically");