You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@iceberg.apache.org by ja...@apache.org on 2022/02/10 22:35:52 UTC
[iceberg] 01/03: Spark 3.2: Fix predicate pushdown in row-level operations (#4023)
This is an automated email from the ASF dual-hosted git repository.
jackye pushed a commit to branch 0.13.x
in repository https://gitbox.apache.org/repos/asf/iceberg.git
commit 5d599e160787c39b2ded153514ef5dc5a74d890e
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Tue Feb 1 12:46:48 2022 -0800
Spark 3.2: Fix predicate pushdown in row-level operations (#4023)
---
.../v2/RowLevelCommandScanRelationPushDown.scala | 32 ++++++++++----
.../apache/iceberg/spark/extensions/TestMerge.java | 49 ++++++++++++++++++++++
2 files changed, 72 insertions(+), 9 deletions(-)
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 52b27d5..4e89b9a 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -20,13 +20,17 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.PredicateHelper
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,16 +43,12 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
val table = relation.table.asRowLevelOperationTable
val scanBuilder = table.newScanBuilder(relation.options)
- val filters = command.condition.toSeq
- val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
- val (_, normalizedFiltersWithoutSubquery) =
- normalizedFilters.partition(SubqueryExpression.hasSubquery)
-
- val (pushedFilters, remainingFilters) = PushDownUtils.pushFilters(
- scanBuilder, normalizedFiltersWithoutSubquery)
+ val (pushedFilters, remainingFilters) = command.condition match {
+ case Some(cond) => pushFilters(cond, scanBuilder, relation.output)
+ case None => (Nil, Nil)
+ }
- val (scan, output) = PushDownUtils.pruneColumns(
- scanBuilder, relation, relation.output, Seq.empty)
+ val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil)
logInfo(
s"""
@@ -68,6 +68,20 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
command.withNewRewritePlan(newRewritePlan)
}
+  /**
+   * Pushes the parts of a row-level command's condition that can be evaluated against the
+   * target table alone into the given scan builder.
+   *
+   * The condition is split into its conjuncts. Only conjuncts whose references are all
+   * target table attributes are kept, because a conjunct that also references other columns
+   * (for example a MERGE join key such as `t.id == source.id`) cannot be evaluated by a scan
+   * of the target table alone. The kept conjuncts are normalized against `tableAttrs`, and any
+   * conjunct containing a subquery is dropped before pushdown.
+   *
+   * @param cond the command condition to split and push
+   * @param scanBuilder the scan builder that receives the pushable conjuncts
+   * @param tableAttrs the output attributes of the target relation
+   * @return the result of `PushDownUtils.pushFilters` for the pushable conjuncts; the caller
+   *         binds it as `(pushedFilters, remainingFilters)`
+   */
+  private def pushFilters(
+      cond: Expression,
+      scanBuilder: ScanBuilder,
+      tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Expression]) = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val targetOnlyConjuncts = splitConjunctivePredicates(cond)
+      .filter(conjunct => conjunct.references.subsetOf(tableAttrSet))
+    // filterNot keeps the same elements, in the same order, as the second half of partition
+    val pushableConjuncts = DataSourceStrategy
+      .normalizeExprs(targetOnlyConjuncts, tableAttrs)
+      .filterNot(SubqueryExpression.hasSubquery)
+
+    PushDownUtils.pushFilters(scanBuilder, pushableConjuncts)
+  }
+
private def toOutputAttrs(
schema: StructType,
relation: DataSourceV2Relation): Seq[AttributeReference] = {
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 23ba7ad..6537e31 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -32,6 +32,9 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.iceberg.AssertHelpers;
import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.SnapshotSummary;
+import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
@@ -76,6 +79,52 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
}
@Test
+  public void testMergeWithStaticPredicatePushDown() {
+    // Target table partitioned by dep; the append below must leave exactly two data files.
+    createAndInitTable("id BIGINT, dep STRING");
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    String targetRows =
+        "{ \"id\": 1, \"dep\": \"software\" }\n" +
+        "{ \"id\": 11, \"dep\": \"software\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }";
+    append(tableName, targetRows);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Snapshot beforeMerge = table.currentSnapshot();
+    Assert.assertEquals("Must have 2 files before MERGE", "2",
+        beforeMerge.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP));
+
+    String sourceRows =
+        "{ \"id\": 1, \"dep\": \"finance\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }";
+    createOrReplaceView("source", sourceRows);
+
+    // In the ON clause, only "t.dep IN ('software')" references target columns alone.
+    // Dynamic partition pruning is switched off, so any file skipping checked below
+    // must come from static pushdown of that conjunct.
+    String mergeSql = "MERGE INTO %s t USING source " +
+        "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " +
+        "WHEN MATCHED AND source.id = 1 THEN " +
+        " UPDATE SET dep = source.dep " +
+        "WHEN NOT MATCHED THEN " +
+        " INSERT (dep, id) VALUES (source.dep, source.id)";
+    withSQLConf(
+        ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"),
+        () -> sql(mergeSql, tableName));
+
+    table.refresh();
+    Snapshot afterMerge = table.currentSnapshot();
+    Assert.assertEquals("Must overwrite only 1 file", "1",
+        afterMerge.summary().get(SnapshotSummary.DELETED_FILES_PROP));
+
+    ImmutableList<Object[]> expected = ImmutableList.of(
+        row(1L, "finance"), // updated
+        row(1L, "hr"), // kept
+        row(2L, "hardware"), // new
+        row(11L, "software") // kept
+    );
+    assertEquals("Output should match", expected, sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+ @Test
public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() {
createAndInitTable("id INT, dep STRING");