Posted to commits@iceberg.apache.org by ja...@apache.org on 2022/02/10 22:35:52 UTC

[iceberg] 01/03: Spark 3.2: Fix predicate pushdown in row-level operations (#4023)

This is an automated email from the ASF dual-hosted git repository.

jackye pushed a commit to branch 0.13.x
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 5d599e160787c39b2ded153514ef5dc5a74d890e
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Tue Feb 1 12:46:48 2022 -0800

    Spark 3.2: Fix predicate pushdown in row-level operations (#4023)
---
 .../v2/RowLevelCommandScanRelationPushDown.scala   | 32 ++++++++++----
 .../apache/iceberg/spark/extensions/TestMerge.java | 49 ++++++++++++++++++++++
 2 files changed, 72 insertions(+), 9 deletions(-)
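
Note on the change below (annotation, not part of the patch): instead of handing the whole row-level command condition to PushDownUtils in one piece, the rule now splits the condition into its conjuncts and keeps only those whose references are a subset of the target table's attributes before normalizing and pushing them. A minimal, Spark-free sketch of that selection step follows; Expr, Pred, And and splitConjuncts are hypothetical stand-ins for Catalyst's Expression, AttributeSet and PredicateHelper.splitConjunctivePredicates, not real Spark or Iceberg classes.

    // Illustrative model only: shows which conjuncts of a MERGE condition are
    // candidates for pushdown to the target table scan.
    object PushdownSketch extends App {
      sealed trait Expr { def references: Set[String] }
      final case class Pred(text: String, references: Set[String]) extends Expr
      final case class And(left: Expr, right: Expr) extends Expr {
        val references: Set[String] = left.references ++ right.references
      }

      // Stand-in for PredicateHelper.splitConjunctivePredicates: break a
      // conjunction into its individual predicates.
      def splitConjuncts(e: Expr): Seq[Expr] = e match {
        case And(l, r) => splitConjuncts(l) ++ splitConjuncts(r)
        case other     => Seq(other)
      }

      // ON t.id == source.id AND t.dep IN ('software') AND source.id < 10
      val cond =
        And(And(Pred("t.id = source.id", Set("t.id", "source.id")),
                Pred("t.dep IN ('software')", Set("t.dep"))),
            Pred("source.id < 10", Set("source.id")))

      // Stand-in for AttributeSet(tableAttrs): the target table's columns.
      val targetAttrs = Set("t.id", "t.dep")

      // Keep only conjuncts that reference target columns exclusively; these
      // are the filters handed to the scan builder in the patch below.
      val pushable = splitConjuncts(cond).filter(_.references.subsetOf(targetAttrs))
      println(pushable) // List(Pred(t.dep IN ('software'),Set(t.dep)))
    }

With the condition used in the new test, only t.dep IN ('software') survives, which is what lets the scan prune down to the single 'software' partition file that the test asserts on.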

diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
index 52b27d5..4e89b9a 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RowLevelCommandScanRelationPushDown.scala
@@ -20,13 +20,17 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.planning.RewrittenRowLevelCommand
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
+import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 
 object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
@@ -39,16 +43,12 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
       val table = relation.table.asRowLevelOperationTable
       val scanBuilder = table.newScanBuilder(relation.options)
 
-      val filters = command.condition.toSeq
-      val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, relation.output)
-      val (_, normalizedFiltersWithoutSubquery) =
-        normalizedFilters.partition(SubqueryExpression.hasSubquery)
-
-      val (pushedFilters, remainingFilters) = PushDownUtils.pushFilters(
-        scanBuilder, normalizedFiltersWithoutSubquery)
+      val (pushedFilters, remainingFilters) = command.condition match {
+        case Some(cond) => pushFilters(cond, scanBuilder, relation.output)
+        case None => (Nil, Nil)
+      }
 
-      val (scan, output) = PushDownUtils.pruneColumns(
-        scanBuilder, relation, relation.output, Seq.empty)
+      val (scan, output) = PushDownUtils.pruneColumns(scanBuilder, relation, relation.output, Nil)
 
       logInfo(
         s"""
@@ -68,6 +68,20 @@ object RowLevelCommandScanRelationPushDown extends Rule[LogicalPlan] with Predic
       command.withNewRewritePlan(newRewritePlan)
   }
 
+  private def pushFilters(
+      cond: Expression,
+      scanBuilder: ScanBuilder,
+      tableAttrs: Seq[AttributeReference]): (Seq[Filter], Seq[Expression]) = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val filters = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
+    val normalizedFilters = DataSourceStrategy.normalizeExprs(filters, tableAttrs)
+    val (_, normalizedFiltersWithoutSubquery) =
+      normalizedFilters.partition(SubqueryExpression.hasSubquery)
+
+    PushDownUtils.pushFilters(scanBuilder, normalizedFiltersWithoutSubquery)
+  }
+
   private def toOutputAttrs(
       schema: StructType,
       relation: DataSourceV2Relation): Seq[AttributeReference] = {
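
Note (annotation, not part of the patch): after the split above, only t.dep IN ('software') from the test's MERGE condition reaches DataSourceStrategy, because it is the one conjunct that mentions target columns and literals only. Below is a sketch of what that conjunct becomes on the data source side, assuming spark-sql is on the classpath; the object name is made up, but org.apache.spark.sql.sources.In and Filter.references are Spark's public filter API.

    import org.apache.spark.sql.sources.{Filter, In}

    // Sketch: the pushed conjunct as a public source Filter. A column-to-column
    // predicate such as t.id = source.id has no literal side and is not
    // translated into a Filter at all, which is why splitting the conjunction
    // before translation matters.
    object TranslatedFilterSketch extends App {
      val pushed: Filter = In("dep", Array("software"))
      // references lists the columns the filter touches -- here just "dep",
      // the partition column the Iceberg scan builder can prune on.
      println(pushed.references.mkString(", "))
    }
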
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 23ba7ad..6537e31 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -32,6 +32,9 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import org.apache.iceberg.AssertHelpers;
 import org.apache.iceberg.DistributionMode;
+import org.apache.iceberg.Snapshot;
+import org.apache.iceberg.SnapshotSummary;
+import org.apache.iceberg.Table;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
@@ -76,6 +79,52 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
   }
 
   @Test
+  public void testMergeWithStaticPredicatePushDown() {
+    createAndInitTable("id BIGINT, dep STRING");
+
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"software\" }\n" +
+        "{ \"id\": 11, \"dep\": \"software\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }");
+
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    Snapshot snapshot = table.currentSnapshot();
+    String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP);
+    Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount);
+
+    createOrReplaceView("source",
+        "{ \"id\": 1, \"dep\": \"finance\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+    // disable dynamic pruning and rely only on static predicate pushdown
+    withSQLConf(ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> {
+      sql("MERGE INTO %s t USING source " +
+          "ON t.id == source.id AND t.dep IN ('software') AND source.id < 10 " +
+          "WHEN MATCHED AND source.id = 1 THEN " +
+          "  UPDATE SET dep = source.dep " +
+          "WHEN NOT MATCHED THEN " +
+          "  INSERT (dep, id) VALUES (source.dep, source.id)", tableName);
+    });
+
+    table.refresh();
+
+    Snapshot mergeSnapshot = table.currentSnapshot();
+    String deletedDataFilesCount = mergeSnapshot.summary().get(SnapshotSummary.DELETED_FILES_PROP);
+    Assert.assertEquals("Must overwrite only 1 file", "1", deletedDataFilesCount);
+
+    ImmutableList<Object[]> expectedRows = ImmutableList.of(
+        row(1L, "finance"),  // updated
+        row(1L, "hr"),       // kept
+        row(2L, "hardware"), // new
+        row(11L, "software") // kept
+    );
+    assertEquals("Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+  @Test
   public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() {
     createAndInitTable("id INT, dep STRING");