You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/10/18 06:56:07 UTC

[spark] branch master updated: [SPARK-44649][SQL] Runtime Filter supports passing equivalent creation side expressions

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 103de914a5f [SPARK-44649][SQL] Runtime Filter supports passing equivalent creation side expressions
103de914a5f is described below

commit 103de914a5f96fccbe722663ee69c8ee7d9c8135
Author: Jiaan Geng <be...@163.com>
AuthorDate: Wed Oct 18 14:55:51 2023 +0800

    [SPARK-44649][SQL] Runtime Filter supports passing equivalent creation side expressions
    
    ### What changes were proposed in this pull request?
    Currently, the Spark runtime filter supports a multi-level shuffle join side as the filter creation side. See: https://github.com/apache/spark/pull/39170. Although that feature covers adaptive scenarios and improves performance, other cases still need to be supported.
    
    **Optimization of Expression Transitivity on the Creation Side of Spark Runtime Filter**
    
    **Principle**
    Join conditions are transitive across some joins, for example:
    `Tab1.col1A = Tab2.col2B` and `Tab2.col2B = Tab3.col3C`
    From these, it can be inferred that `Tab1.col1A = Tab3.col3C`.
    
    **Optimization points**
    Currently, the runtime filter's creation-side expression only uses directly joined keys. If the transitivity of join conditions is exploited, runtime filters can be injected in many more scenarios, such as:
    
    ```
    SELECT *
    FROM (
      SELECT *
      FROM tab1
        JOIN tab2 ON tab1.c1 = tab2.c2
      WHERE tab2.a2 = 5
    ) AS a
      JOIN tab3 ON tab3.c3 = a.c1
    ```
    
    Here `tab3.c3` is joined only with `tab1.c1`, not with `tab2.c2`. Although tab2 has a selective filter (`tab2.a2 = 5`), Spark is currently unable to inject a runtime filter into tab3 based on it.
    Once transitivity is taken into account, we know that `tab3.c3` and `tab2.c2` are related, so a runtime filter can still be injected and performance improved.
    
    In the current implementation, Spark only injects a runtime filter into tab1, using a bloom filter built from tab2 with the predicate `tab2.a2 = 5`.
    Because there is no direct join between tab3 and tab2, Spark cannot inject the same bloom filter into tab3.
    However, the SQL above has the join condition `tab3.c3 = a.c1` (i.e. `tab1.c1`) between tab3 and tab1, and also the join condition `tab1.c1 = tab2.c2`. By the transitivity of join conditions we can derive the virtual join condition `tab3.c3 = tab2.c2`, and then inject the bloom filter built from tab2 (with `tab2.a2 = 5`) into tab3.
    
    ### Why are the changes needed?
    Enhance the Spark runtime filter and improve performance.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    This only changes the internal implementation.
    
    ### How was this patch tested?
    New tests.
    Micro benchmark for q75 in TPC-DS.
    **2TB TPC-DS**
    | TPC-DS Query   | Before(Seconds)  | After(Seconds)  | Speedup(Percent)  |
    |  ----  | ----  | ----  | ----  |
    | q75 | 129.664 | 81.562 | 58.98% |
    
    Closes #42317 from beliefer/SPARK-44649.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/optimizer/InjectRuntimeFilter.scala   | 64 +++++++++++++++-------
 .../spark/sql/InjectRuntimeFilterSuite.scala       | 38 +++++++++++--
 2 files changed, 75 insertions(+), 27 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
index 8737082e571..30526bd8106 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InjectRuntimeFilter.scala
@@ -125,14 +125,14 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
    */
   private def extractSelectiveFilterOverScan(
       plan: LogicalPlan,
-      filterCreationSideKey: Expression): Option[LogicalPlan] = {
-    @tailrec
+      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
     def extract(
         p: LogicalPlan,
         predicateReference: AttributeSet,
         hasHitFilter: Boolean,
         hasHitSelectiveFilter: Boolean,
-        currentPlan: LogicalPlan): Option[LogicalPlan] = p match {
+        currentPlan: LogicalPlan,
+        targetKey: Expression): Option[(Expression, LogicalPlan)] = p match {
       case Project(projectList, child) if hasHitFilter =>
         // We need to make sure all expressions referenced by filter predicates are simple
         // expressions.
@@ -143,41 +143,62 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
             referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
             hasHitFilter,
             hasHitSelectiveFilter,
-            currentPlan)
+            currentPlan,
+            targetKey)
         } else {
           None
         }
       case Project(_, child) =>
         assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
-        extract(child, predicateReference, hasHitFilter, hasHitSelectiveFilter, currentPlan)
+        extract(child, predicateReference, hasHitFilter, hasHitSelectiveFilter, currentPlan,
+          targetKey)
       case Filter(condition, child) if isSimpleExpression(condition) =>
         extract(
           child,
           predicateReference ++ condition.references,
           hasHitFilter = true,
           hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition),
-          currentPlan)
-      case ExtractEquiJoinKeys(_, _, _, _, _, left, right, _) =>
+          currentPlan,
+          targetKey)
+      case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
         // Runtime filters use one side of the [[Join]] to build a set of join key values and prune
         // the other side of the [[Join]]. It's also OK to use a superset of the join key values
         // (ignore null values) to do the pruning.
-        if (left.output.exists(_.semanticEquals(filterCreationSideKey))) {
-          extract(left, AttributeSet.empty,
-            hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left)
-        } else if (right.output.exists(_.semanticEquals(filterCreationSideKey))) {
-          extract(right, AttributeSet.empty,
-            hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right)
+        // We assume other rules have already pushed predicates through join if possible.
+        // So the predicate references won't pass on anymore.
+        if (left.output.exists(_.semanticEquals(targetKey))) {
+          extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
+            currentPlan = left, targetKey = targetKey).orElse {
+            // We can also extract from the right side if the join keys are transitive.
+            lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
+              .flatMap { newTargetKey =>
+                extract(right, AttributeSet.empty,
+                  hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right,
+                  targetKey = newTargetKey)
+              }
+          }
+        } else if (right.output.exists(_.semanticEquals(targetKey))) {
+          extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
+            currentPlan = right, targetKey = targetKey).orElse {
+            // We can also extract from the left side if the join keys are transitive.
+            rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
+              .flatMap { newTargetKey =>
+                extract(left, AttributeSet.empty,
+                  hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left,
+                  targetKey = newTargetKey)
+              }
+          }
         } else {
           None
         }
       case _: LeafNode if hasHitSelectiveFilter =>
-        Some(currentPlan)
+        Some((targetKey, currentPlan))
       case _ => None
     }
 
     if (!plan.isStreaming) {
-      extract(plan, AttributeSet.empty,
-        hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = plan)
+      extract(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
+        currentPlan = plan, targetKey = filterCreationSideKey)
     } else {
       None
     }
@@ -239,7 +260,7 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
       filterApplicationSide: LogicalPlan,
       filterCreationSide: LogicalPlan,
       filterApplicationSideKey: Expression,
-      filterCreationSideKey: Expression): Option[LogicalPlan] = {
+      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
     if (findExpressionAndTrackLineageDown(
       filterApplicationSideKey, filterApplicationSide).isDefined &&
       satisfyByteSizeRequirement(filterApplicationSide)) {
@@ -331,8 +352,8 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
             val hasShuffle = isProbablyShuffleJoin(left, right, hint)
             if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left))) {
               extractBeneficialFilterCreatePlan(left, right, l, r).foreach {
-                filterCreationSidePlan =>
-                  newLeft = injectFilter(l, newLeft, r, filterCreationSidePlan)
+                case (filterCreationSideKey, filterCreationSidePlan) =>
+                  newLeft = injectFilter(l, newLeft, filterCreationSideKey, filterCreationSidePlan)
               }
             }
             // Did we actually inject on the left? If not, try on the right
@@ -341,8 +362,9 @@ object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with J
             if (newLeft.fastEquals(oldLeft) && canPruneRight(joinType) &&
               (hasShuffle || probablyHasShuffle(right))) {
               extractBeneficialFilterCreatePlan(right, left, r, l).foreach {
-                filterCreationSidePlan =>
-                  newRight = injectFilter(r, newRight, l, filterCreationSidePlan)
+                case (filterCreationSideKey, filterCreationSidePlan) =>
+                  newRight = injectFilter(
+                    r, newRight, filterCreationSideKey, filterCreationSidePlan)
               }
             }
             if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
index fedfd9ff587..c46e0bfcecb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/InjectRuntimeFilterSuite.scala
@@ -360,7 +360,7 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
     withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
       assertRewroteSemiJoin("select * from bf1 join bf2 join bf3 join bf4 on " +
-        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5")
+        "bf1.c1 = bf2.c2 and bf2.c2 = bf3.c3 and bf3.c3 = bf4.c4 where bf1.a1 = 5", 3)
     }
   }
 
@@ -390,34 +390,60 @@ class InjectRuntimeFilterSuite extends QueryTest with SQLTestUtils with SharedSp
   test("Runtime bloom filter join: two joins") {
     withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "2000") {
+      // bf2 as creation side and inject runtime filter for bf1 and bf3.
       assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
         "and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
-      assertRewroteWithBloomFilter("select * from (select * from bf1 left semi join bf2 on " +
-        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
-      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
-        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
       assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf2.c2 where bf2.a2 = 5", 2)
+      // bf1 and bf2 hasn't shuffle. bf1 as creation side and inject runtime filter for bf3.
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left semi join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      // bf1 as creation side and inject runtime filter for bf2 and bf3.
       assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
         "and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
       assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 join bf3 on " +
         "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf1.a1 = 5", 2)
+      // bf2 as creation side and inject runtime filter for bf1 and bf3(join keys are transitive).
+      assertRewroteWithBloomFilter("select * from (select * from bf1 join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      assertRewroteWithBloomFilter("select * from (select * from bf1 right join bf2 on " +
+        "bf1.c1 = bf2.c2 where bf2.a2 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      // Can't leverage the transitivity of join keys due to runtime filters already exists.
+      // bf2 as creation side and inject runtime filter for bf1.
+      assertRewroteWithBloomFilter("select * from bf1 join bf2 join bf3 on bf1.c1 = bf2.c2 " +
+        "and bf3.c3 = bf1.c1 where bf2.a2 = 5")
+      assertRewroteWithBloomFilter("select * from bf1 left outer join bf2 join bf3 on " +
+        "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5")
+      assertRewroteWithBloomFilter("select * from bf1 right outer join bf2 join bf3 on " +
+        "bf1.c1 = bf2.c2 and bf3.c3 = bf1.c1 where bf2.a2 = 5")
     }
 
     withSQLConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD.key -> "3000",
       SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1200") {
+      // bf1 as creation side and inject runtime filter for bf2 and bf3.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left semi join bf2 on " +
         "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1", 2)
+      // left anti join unsupported. bf1 as creation side and inject runtime filter for bf3.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
         "bf1.c1 = bf2.c2 where bf1.a1 = 5) as a join bf3 on bf3.c3 = a.c1")
+      // bf2 as creation side and inject runtime filter for bf1 and bf3(by passing key).
       assertRewroteWithBloomFilter("select * from (select * from bf1 left semi join bf2 on " +
+        "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 2)
+      // left anti join unsupported.
+      // bf2 as creation side and inject runtime filter for bf3(by passing key).
+      assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
         "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1")
+      // left anti join unsupported and hasn't selective filter.
       assertRewroteWithBloomFilter("select * from (select * from bf1 left anti join bf2 on " +
-        "(bf1.c1 = bf2.c2 and bf2.a2 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
+        "(bf1.c1 = bf2.c2 and bf1.a1 = 5)) as a join bf3 on bf3.c3 = a.c1", 0)
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org