Posted to commits@iceberg.apache.org by bl...@apache.org on 2021/01/16 03:48:04 UTC

[iceberg] branch master updated: Spark: Refactor RewriteRowLevelOperationHelper to support MERGE operations (#2097)

This is an automated email from the ASF dual-hosted git repository.

blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new a576929  Spark: Refactor RewriteRowLevelOperationHelper to support MERGE operations (#2097)
a576929 is described below

commit a5769291e82307f6dfda823b55c220e87c6c6a9b
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Sat Jan 16 05:47:46 2021 +0200

    Spark: Refactor RewriteRowLevelOperationHelper to support MERGE operations (#2097)
---
 .../sql/catalyst/optimizer/RewriteDelete.scala     |  3 +-
 .../utils/RewriteRowLevelOperationHelper.scala     | 41 ++++++++++++++--------
 2 files changed, 29 insertions(+), 15 deletions(-)

diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala
index ee2f5a5..e86f21f 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDelete.scala
@@ -61,7 +61,8 @@ case class RewriteDelete(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRo
       val writeInfo = newWriteInfo(r.schema)
       val mergeBuilder = r.table.asMergeable.newMergeBuilder("delete", writeInfo)
 
-      val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, cond)
+      val matchingRowsPlanBuilder = (scanRelation: DataSourceV2ScanRelation) => Filter(cond, scanRelation)
+      val scanPlan = buildScanPlan(r.table, r.output, mergeBuilder, cond, matchingRowsPlanBuilder)
 
       val remainingRowFilter = Not(EqualNullSafe(cond, Literal(true, BooleanType)))
       val remainingRowsPlan = Filter(remainingRowFilter, scanPlan)
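
The delete rewrite above now hands buildScanPlan a matching-rows plan builder, here just a plain Filter over the scan relation. A MERGE rewrite can reuse the same helper by handing it a join-based builder instead. The sketch below is illustrative only and is not part of this commit: the rule name RewriteMergeInto, the pattern on MergeIntoTable, and the names source and mergeCond are assumptions; the buildScanPlan call and the helper methods mirror what RewriteDelete does above.

    import org.apache.spark.sql.catalyst.plans.Inner
    import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan, MergeIntoTable}
    import org.apache.spark.sql.catalyst.rules.Rule
    import org.apache.spark.sql.catalyst.utils.RewriteRowLevelOperationHelper
    import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
    import org.apache.spark.sql.internal.SQLConf

    // Hypothetical MERGE rewrite rule (names assumed); asMergeable and newWriteInfo are the
    // same helpers RewriteDelete uses above, with their imports omitted here.
    case class RewriteMergeInto(conf: SQLConf) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper {

      override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        case m @ MergeIntoTable(target: DataSourceV2Relation, source, mergeCond, _, _) =>
          val writeInfo = newWriteInfo(target.schema)
          val mergeBuilder = target.table.asMergeable.newMergeBuilder("merge", writeInfo)

          // For MERGE, the rows that identify affected files come from joining the target scan
          // with the source on the merge condition, not from a plain Filter as in DELETE.
          val matchingRowsPlanBuilder = (scanRelation: DataSourceV2ScanRelation) =>
            Join(scanRelation, source, Inner, Some(mergeCond), JoinHint.NONE)

          val scanPlan = buildScanPlan(target.table, target.output, mergeBuilder, mergeCond, matchingRowsPlanBuilder)

          // Planning the matched/not matched actions on top of scanPlan is out of scope here.
          m.copy(targetTable = scanPlan)
      }
    }
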
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/utils/RewriteRowLevelOperationHelper.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala
similarity index 73%
rename from spark3-extensions/src/main/scala/org/apache/spark/sql/utils/RewriteRowLevelOperationHelper.scala
rename to spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala
index 58a1dcd..f7ad083 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/utils/RewriteRowLevelOperationHelper.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/utils/RewriteRowLevelOperationHelper.scala
@@ -24,16 +24,17 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.AttributeSet
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
 import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter
-import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.connector.catalog.Table
 import org.apache.spark.sql.connector.iceberg.read.SupportsFileFilter
 import org.apache.spark.sql.connector.iceberg.write.MergeBuilder
+import org.apache.spark.sql.connector.read.ScanBuilder
 import org.apache.spark.sql.connector.write.LogicalWriteInfo
 import org.apache.spark.sql.connector.write.LogicalWriteInfoImpl
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -51,28 +52,41 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {
 
   protected def buildScanPlan(
       table: Table,
-      output: Seq[AttributeReference],
+      tableAttrs: Seq[AttributeReference],
       mergeBuilder: MergeBuilder,
-      cond: Expression): LogicalPlan = {
+      cond: Expression,
+      matchingRowsPlanBuilder: DataSourceV2ScanRelation => LogicalPlan): LogicalPlan = {
 
     val scanBuilder = mergeBuilder.asScanBuilder
 
-    val predicates = splitConjunctivePredicates(cond)
-    val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, output)
-    PushDownUtils.pushFilters(scanBuilder, normalizedPredicates)
+    pushFilters(scanBuilder, cond, tableAttrs)
 
     val scan = scanBuilder.build()
-    val scanRelation = DataSourceV2ScanRelation(table, scan, toOutputAttrs(scan.readSchema(), output))
+    val outputAttrs = toOutputAttrs(scan.readSchema(), tableAttrs)
+    val scanRelation = DataSourceV2ScanRelation(table, scan, outputAttrs)
 
     scan match {
       case filterable: SupportsFileFilter =>
-        val matchingFilePlan = buildFileFilterPlan(cond, scanRelation)
+        val matchingFilePlan = buildFileFilterPlan(matchingRowsPlanBuilder(scanRelation))
         DynamicFileFilter(scanRelation, matchingFilePlan, filterable)
       case _ =>
         scanRelation
     }
   }
 
+  private def pushFilters(
+      scanBuilder: ScanBuilder,
+      cond: Expression,
+      tableAttrs: Seq[AttributeReference]): Unit = {
+
+    val tableAttrSet = AttributeSet(tableAttrs)
+    val predicates = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
+    if (predicates.nonEmpty) {
+      val normalizedPredicates = DataSourceStrategy.normalizeExprs(predicates, tableAttrs)
+      PushDownUtils.pushFilters(scanBuilder, normalizedPredicates)
+    }
+  }
+
   protected def toDataSourceFilters(predicates: Seq[Expression]): Array[sources.Filter] = {
     predicates.flatMap { p =>
       val translatedFilter = DataSourceStrategy.translateFilter(p, supportNestedPredicatePushdown = true)
@@ -88,10 +102,9 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {
     LogicalWriteInfoImpl(queryId = uuid.toString, schema, CaseInsensitiveStringMap.empty)
   }
 
-  private def buildFileFilterPlan(cond: Expression, scanRelation: DataSourceV2ScanRelation): LogicalPlan = {
-    val matchingFilter = Filter(cond, scanRelation)
-    val fileAttr = findOutputAttr(matchingFilter, FILE_NAME_COL)
-    val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingFilter)
+  private def buildFileFilterPlan(matchingRowsPlan: LogicalPlan): LogicalPlan = {
+    val fileAttr = findOutputAttr(matchingRowsPlan, FILE_NAME_COL)
+    val agg = Aggregate(Seq(fileAttr), Seq(fileAttr), matchingRowsPlan)
     Project(Seq(findOutputAttr(agg, FILE_NAME_COL)), agg)
   }
 
@@ -101,8 +114,8 @@ trait RewriteRowLevelOperationHelper extends PredicateHelper with Logging {
     }
   }
 
-  protected def toOutputAttrs(schema: StructType, output: Seq[AttributeReference]): Seq[AttributeReference] = {
-    val nameToAttr = output.map(_.name).zip(output).toMap
+  protected def toOutputAttrs(schema: StructType, attrs: Seq[AttributeReference]): Seq[AttributeReference] = {
+    val nameToAttr = attrs.map(_.name).zip(attrs).toMap
     schema.toAttributes.map {
       a => nameToAttr.get(a.name) match {
         case Some(ref) =>
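
The extracted pushFilters now pushes down only those conjuncts of the condition whose references fall entirely within the table attributes, which is what a MERGE condition needs, since it can also reference source columns that the target scan knows nothing about. A minimal standalone sketch of that filtering, with a made-up object name and column names (not part of the commit), is:

    import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, EqualTo, GreaterThan, Literal, PredicateHelper}
    import org.apache.spark.sql.types.IntegerType

    // Illustration only: which conjuncts survive the subsetOf check used by pushFilters.
    object PushableConjunctsExample extends PredicateHelper {
      def main(args: Array[String]): Unit = {
        val targetId = AttributeReference("id", IntegerType)()      // target table column
        val sourceId = AttributeReference("src_id", IntegerType)()  // column from the MERGE source
        val tableAttrSet = AttributeSet(Seq(targetId))

        // id = src_id AND id > 5
        val cond = And(EqualTo(targetId, sourceId), GreaterThan(targetId, Literal(5)))

        // Only id > 5 references table attributes exclusively, so only it would be pushed.
        val pushable = splitConjunctivePredicates(cond).filter(_.references.subsetOf(tableAttrSet))
        pushable.foreach(println)
      }
    }
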