You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by ao...@apache.org on 2023/01/06 21:52:35 UTC
[iceberg] branch master updated: Spark 3.3: Use regular planning for applicable row-level operations (#6534)
This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 8026442e48 Spark 3.3: Use regular planning for applicable row-level operations (#6534)
8026442e48 is described below
commit 8026442e486c5fde70c855fe88839bc2411a5ba1
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Fri Jan 6 13:52:29 2023 -0800
Spark 3.3: Use regular planning for applicable row-level operations (#6534)
---
.../catalyst/analysis/RewriteMergeIntoTable.scala | 29 +++++++--
.../v2/RowLevelCommandScanRelationPushDown.scala | 12 +++-
.../SparkRowLevelOperationsTestBase.java | 6 +-
.../apache/iceberg/spark/extensions/TestMerge.java | 75 ++++++++++++++++++++++
.../spark/source/SparkCopyOnWriteOperation.java | 2 -
.../apache/iceberg/spark/source/SparkWrite.java | 25 ++++++--
.../iceberg/spark/source/SparkWriteBuilder.java | 2 -
7 files changed, 134 insertions(+), 17 deletions(-)
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 2e720bdd44..ca37f99955 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -22,6 +22,7 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.And
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.AttributeSet
@@ -32,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.MonotonicallyIncreasingID
+import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.plans.FullOuter
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.LeftAnti
@@ -74,7 +76,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* This rule assumes the commands have been fully resolved and all assignments have been aligned.
* That's why it must be run after AlignRowLevelCommandAssignments.
*/
-object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
+object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand with PredicateHelper {
private final val ROW_FROM_SOURCE = "__row_from_source"
private final val ROW_FROM_TARGET = "__row_from_target"
@@ -185,12 +187,14 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs)
val readAttrs = readRelation.output
+ val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
// project an extra column to check if a target row exists after the join
// project a synthetic row ID to perform the cardinality check
val rowFromTarget = Alias(TrueLiteral, ROW_FROM_TARGET)()
val rowId = Alias(MonotonicallyIncreasingID(), ROW_ID)()
val targetTableProjExprs = readAttrs ++ Seq(rowFromTarget, rowId)
- val targetTableProj = Project(targetTableProjExprs, readRelation)
+ val targetTableProj = Project(targetTableProjExprs, Filter(targetCond, readRelation))
// project an extra column to check if a source row exists after the join
val rowFromSource = Alias(TrueLiteral, ROW_FROM_SOURCE)()
@@ -202,7 +206,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
// disable broadcasts for the target table to perform the cardinality check
val joinType = if (notMatchedActions.isEmpty) LeftOuter else FullOuter
val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
- val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint)
+ val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint)
// add an extra matched action to output the original row if none of the actual actions matched
// this is needed to keep target rows that should be copied over
@@ -253,9 +257,11 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
val readRelation = buildRelationWithAttrs(relation, operationTable, rowIdAttrs ++ metadataAttrs)
val readAttrs = readRelation.output
+ val (targetCond, joinCond) = splitMergeCond(cond, readRelation)
+
// project an extra column to check if a target row exists after the join
val targetTableProjExprs = readAttrs :+ Alias(TrueLiteral, ROW_FROM_TARGET)()
- val targetTableProj = Project(targetTableProjExprs, readRelation)
+ val targetTableProj = Project(targetTableProjExprs, Filter(targetCond, readRelation))
// project an extra column to check if a source row exists after the join
val sourceTableProjExprs = source.output :+ Alias(TrueLiteral, ROW_FROM_SOURCE)()
@@ -266,7 +272,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
// also disable broadcasts for the target table to perform the cardinality check
val joinType = if (notMatchedActions.isEmpty) Inner else RightOuter
val joinHint = JoinHint(leftHint = Some(HintInfo(Some(NO_BROADCAST_HASH))), rightHint = None)
- val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(cond), joinHint)
+ val joinPlan = Join(NoStatsUnaryNode(targetTableProj), sourceTableProj, joinType, Some(joinCond), joinHint)
val deleteRowValues = buildDeltaDeleteRowValues(rowAttrs, rowIdAttrs)
val metadataReadAttrs = readAttrs.filterNot(relation.outputSet.contains)
@@ -439,4 +445,17 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
ProjectingInternalRow(schema, projectedOrdinals)
}
+
+ // splits the MERGE condition into a predicate that references columns only from the target table,
+ // which can be pushed down, and a predicate used as a join condition to find matches
+ private def splitMergeCond(
+ cond: Expression,
+ targetTable: LogicalPlan): (Expression, Expression) = {
+
+ val (targetPredicates, joinPredicates) = splitConjunctivePredicates(cond)
+ .partition(_.references.subsetOf(targetTable.outputSet))
+ val targetCond = targetPredicates.reduceOption(And).getOrElse(TrueLiteral)
+ val joinCond = joinPredicates.reduceOption(And).getOrElse(TrueLiteral)
+ (targetCond, joinCond)
+ }
}
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 2be73cb6ee..9ee3035c26 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
+import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.expressions.filter.Predicate
@@ -38,8 +40,16 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
import ExtendedDataSourceV2Implicits._
override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+ // use native Spark planning for delta-based plans and copy-on-write MERGE operations
+ // unlike other commands, these plans have filters that can be pushed down directly
+ case RewrittenRowLevelCommand(command, _: DataSourceV2Relation, rewritePlan)
+ if rewritePlan.isInstanceOf[WriteDelta] || command.isInstanceOf[MergeIntoIcebergTable] =>
+
+ val newRewritePlan = V2ScanRelationPushDown.apply(rewritePlan)
+ command.withNewRewritePlan(newRewritePlan)
+
// push down the filter from the command condition instead of the filter in the rewrite plan,
- // which may be negated for copy-on-write operations
+ // which may be negated for copy-on-write DELETE and UPDATE operations
case RewrittenRowLevelCommand(command, relation: DataSourceV2Relation, rewritePlan) =>
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
index 633b2ee431..3039958d2c 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -170,7 +170,11 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes
}
protected void createAndInitTable(String schema, String jsonData) {
- sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema);
+ createAndInitTable(schema, "", jsonData);
+ }
+
+ protected void createAndInitTable(String schema, String partitioning, String jsonData) {
+ sql("CREATE TABLE %s (%s) USING iceberg %s", tableName, schema, partitioning);
initTable();
if (jsonData != null) {
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 58fbb6241e..c598cb720c 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -18,7 +18,10 @@
*/
package org.apache.iceberg.spark.extensions;
+import static org.apache.iceberg.RowLevelOperationMode.COPY_ON_WRITE;
import static org.apache.iceberg.TableProperties.MERGE_ISOLATION_LEVEL;
+import static org.apache.iceberg.TableProperties.MERGE_MODE;
+import static org.apache.iceberg.TableProperties.MERGE_MODE_DEFAULT;
import static org.apache.iceberg.TableProperties.PARQUET_ROW_GROUP_SIZE_BYTES;
import static org.apache.iceberg.TableProperties.SPLIT_SIZE;
import static org.apache.iceberg.TableProperties.WRITE_DISTRIBUTION_MODE;
@@ -39,6 +42,7 @@ import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
@@ -54,6 +58,7 @@ import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.internal.SQLConf;
import org.assertj.core.api.Assertions;
import org.junit.After;
@@ -85,6 +90,55 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
sql("DROP TABLE IF EXISTS source");
}
+ @Test
+ public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() {
+ createAndInitTable(
+ "id INT, salary INT, dep STRING, sub_dep STRING",
+ "PARTITIONED BY (dep, sub_dep)",
+ "{ \"id\": 1, \"salary\": 100, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n"
+ + "{ \"id\": 6, \"salary\": 600, \"dep\": \"d6\", \"sub_dep\": \"sd6\" }");
+
+ createOrReplaceView(
+ "source",
+ "id INT, salary INT, dep STRING, sub_dep STRING",
+ "{ \"id\": 1, \"salary\": 101, \"dep\": \"d1\", \"sub_dep\": \"sd1\" }\n"
+ + "{ \"id\": 2, \"salary\": 200, \"dep\": \"d2\", \"sub_dep\": \"sd2\" }\n"
+ + "{ \"id\": 3, \"salary\": 300, \"dep\": \"d3\", \"sub_dep\": \"sd3\" }");
+
+ String query =
+ String.format(
+ "MERGE INTO %s AS t USING source AS s "
+ + "ON t.id == s.id AND ((t.dep = 'd1' AND t.sub_dep IN ('sd1', 'sd3')) OR (t.dep = 'd6' AND t.sub_dep IN ('sd2', 'sd6'))) "
+ + "WHEN MATCHED THEN "
+ + " UPDATE SET salary = s.salary "
+ + "WHEN NOT MATCHED THEN "
+ + " INSERT *",
+ tableName);
+
+ Table table = validationCatalog.loadTable(tableIdent);
+
+ if (mode(table) == COPY_ON_WRITE) {
+ checkJoinAndFilterConditions(
+ query,
+ "Join [id], [id], FullOuter",
+ "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))");
+ } else {
+ checkJoinAndFilterConditions(
+ query,
+ "Join [id], [id], RightOuter",
+ "((dep = 'd1' AND sub_dep IN ('sd1', 'sd3')) OR (dep = 'd6' AND sub_dep IN ('sd2', 'sd6')))");
+ }
+
+ assertEquals(
+ "Should have expected rows",
+ ImmutableList.of(
+ row(1, 101, "d1", "sd1"), // updated
+ row(2, 200, "d2", "sd2"), // new
+ row(3, 300, "d3", "sd3"), // new
+ row(6, 600, "d6", "sd6")), // existing
+ sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
@Test
public void testMergeWithStaticPredicatePushDown() {
createAndInitTable("id BIGINT, dep STRING");
@@ -2274,4 +2328,25 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
List<Object[]> result = sql("SELECT * FROM %s ORDER BY id", tableName);
assertEquals("Should correctly add the non-matching rows", expectedRows, result);
}
+
+ private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) {
+ // disable runtime filtering for easier validation
+ withSQLConf(
+ ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"),
+ () -> {
+ SparkPlan sparkPlan = executeAndKeepPlan(() -> sql(query));
+ String planAsString = sparkPlan.toString().replaceAll("#(\\d+L?)", "");
+
+ Assertions.assertThat(planAsString).as("Join should match").contains(join + "\n");
+
+ Assertions.assertThat(planAsString)
+ .as("Pushed filters must match")
+ .contains("[filters=" + icebergFilters + ",");
+ });
+ }
+
+ private RowLevelOperationMode mode(Table table) {
+ String modeName = table.properties().getOrDefault(MERGE_MODE, MERGE_MODE_DEFAULT);
+ return RowLevelOperationMode.fromName(modeName);
+ }
}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
index 72c243fcbc..68c9944044 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java
@@ -24,7 +24,6 @@ import static org.apache.spark.sql.connector.write.RowLevelOperation.Command.UPD
import org.apache.iceberg.IsolationLevel;
import org.apache.iceberg.MetadataColumns;
import org.apache.iceberg.Table;
-import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.expressions.Expressions;
import org.apache.spark.sql.connector.expressions.NamedReference;
@@ -81,7 +80,6 @@ class SparkCopyOnWriteOperation implements RowLevelOperation {
@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
if (lazyWriteBuilder == null) {
- Preconditions.checkState(configuredScan != null, "Write must be configured after scan");
SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, info);
lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel);
}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 0b19fec9fd..f77d96da7f 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -383,7 +383,11 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
}
private List<DataFile> overwrittenFiles() {
- return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+ if (scan == null) {
+ return ImmutableList.of();
+ } else {
+ return scan.tasks().stream().map(FileScanTask::file).collect(Collectors.toList());
+ }
}
private Expression conflictDetectionFilter() {
@@ -415,12 +419,21 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
overwriteFiles.addFile(file);
}
- if (isolationLevel == SERIALIZABLE) {
- commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
- } else if (isolationLevel == SNAPSHOT) {
- commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+ // the scan may be null if the optimizer replaces it with an empty relation (e.g. false cond)
+ // no validation is needed in this case as the command does not depend on the table state
+ if (scan != null) {
+ if (isolationLevel == SERIALIZABLE) {
+ commitWithSerializableIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+ } else if (isolationLevel == SNAPSHOT) {
+ commitWithSnapshotIsolation(overwriteFiles, numOverwrittenFiles, numAddedFiles);
+ } else {
+ throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel);
+ }
+
} else {
- throw new IllegalArgumentException("Unsupported isolation level: " + isolationLevel);
+ commitOperation(
+ overwriteFiles,
+ String.format("overwrite with %d new data files (no validation)", numAddedFiles));
}
}
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
index 6483f13048..55cf7961e9 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java
@@ -86,8 +86,6 @@ class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, Suppo
}
public WriteBuilder overwriteFiles(Scan scan, Command command, IsolationLevel isolationLevel) {
- Preconditions.checkArgument(
- scan instanceof SparkCopyOnWriteScan, "%s is not SparkCopyOnWriteScan", scan);
Preconditions.checkState(!overwriteByFilter, "Cannot overwrite individual files and by filter");
Preconditions.checkState(
!overwriteDynamic, "Cannot overwrite individual files and dynamically");