Posted to commits@spark.apache.org by yu...@apache.org on 2022/09/22 03:15:05 UTC

[spark] branch master updated: [SPARK-40487][SQL] Make defaultJoin in BroadcastNestedLoopJoinExec running in parallel

This is an automated email from the ASF dual-hosted git repository.

yumwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new db51ec6b9c4 [SPARK-40487][SQL] Make defaultJoin in BroadcastNestedLoopJoinExec running in parallel
db51ec6b9c4 is described below

commit db51ec6b9c4669a94e55d52add01c9e568d2bbf3
Author: Xingchao, Zhang <xi...@ebay.com>
AuthorDate: Thu Sep 22 11:14:11 2022 +0800

    [SPARK-40487][SQL] Make defaultJoin in BroadcastNestedLoopJoinExec running in parallel
    
    ### What changes were proposed in this pull request?
    
    Currently, the defaultJoin method in BroadcastNestedLoopJoinExec collects notMatchedBroadcastRows first and then collects matchedStreamRows. Because the bitset of matched broadcast rows is merged with an eager RDD fold (an action), the first step runs as a separate job that finishes before the second is even scheduled; the two steps could run in parallel instead of serially.
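
    A minimal sketch of the difference (ordinary Spark user code, not the
    changed operator itself; java.util.BitSet stands in for Spark's internal
    BitSet, and all names here are illustrative):

        import java.util.BitSet
        import org.apache.spark.sql.SparkSession

        object SerialVsParallelSketch {
          def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder().master("local[2]")
              .appName("sketch").getOrCreate()
            val sc = spark.sparkContext
            val streamRdd = sc.parallelize(0 until 100, numSlices = 4)

            // Each stream partition reports which broadcast-side slots it matched.
            val bitsRdd = streamRdd.mapPartitions { iter =>
              val bits = new BitSet(10)
              iter.foreach(i => bits.set(i % 10))
              Iterator(bits)
            }

            // Before: fold is an eager action, so this job must finish before
            // anything downstream is even scheduled -- two jobs, run serially.
            val eager = bitsRdd.fold(new BitSet(10)) { (a, b) => a.or(b); a }
            println(s"matched slots: ${eager.cardinality()}")

            // After: fold lazily inside a single partition, so the merged bitset
            // stays an RDD that can be unioned with the join's other branch and
            // computed in the same job, in parallel with it.
            val lazyBits = bitsRdd
              .repartition(1)
              .mapPartitions(it => Iterator(it.fold(new BitSet(10)) { (a, b) => a.or(b); a }))

            val otherBranch = streamRdd.map(i => s"row-$i")
            sc.union(otherBranch, lazyBits.map(_.toString)).count() // one job
            spark.stop()
          }
        }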
    
    ### Why are the changes needed?
    Make defaultJoin in BroadcastNestedLoopJoinExec run in parallel, so computing the unmatched broadcast rows no longer blocks the matched-stream-rows scan.
    
    ### Does this PR introduce _any_ user-facing change?
    NO.
    
    ### How was this patch tested?
    Unit test (a new case added to JoinSuite, shown in the diff below).
    
    Closes #37930 from xingchaozh/SPARK-40487.
    
    Authored-by: Xingchao, Zhang <xi...@ebay.com>
    Signed-off-by: Yuming Wang <yu...@ebay.com>
---
 .../joins/BroadcastNestedLoopJoinExec.scala        | 43 ++++++++++++++--------
 .../scala/org/apache/spark/sql/JoinSuite.scala     | 15 ++++++++
 2 files changed, 42 insertions(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index 23b5b614369..84c0cd127f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -286,21 +286,25 @@ case class BroadcastNestedLoopJoinExec(
    */
   private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
     val streamRdd = streamed.execute()
-    val matchedBroadcastRows = getMatchedBroadcastRowsBitSet(streamRdd, relation)
-    val notMatchedBroadcastRows: Seq[InternalRow] = {
-      val nulls = new GenericInternalRow(streamed.output.size)
-      val buf: CompactBuffer[InternalRow] = new CompactBuffer()
-      val joinedRow = new JoinedRow
-      joinedRow.withLeft(nulls)
-      var i = 0
-      val buildRows = relation.value
-      while (i < buildRows.length) {
-        if (!matchedBroadcastRows.get(i)) {
-          buf += joinedRow.withRight(buildRows(i)).copy()
+    def notMatchedBroadcastRows: RDD[InternalRow] = {
+      getMatchedBroadcastRowsBitSetRDD(streamRdd, relation)
+        .repartition(1)
+        .mapPartitions(iter => Seq(iter.fold(new BitSet(relation.value.length))(_ | _)).toIterator)
+        .flatMap { matchedBroadcastRows =>
+          val nulls = new GenericInternalRow(streamed.output.size)
+          val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+          val joinedRow = new JoinedRow
+          joinedRow.withLeft(nulls)
+          var i = 0
+          val buildRows = relation.value
+          while (i < buildRows.length) {
+            if (!matchedBroadcastRows.get(i)) {
+              buf += joinedRow.withRight(buildRows(i)).copy()
+            }
+            i += 1
+          }
+          buf.iterator
         }
-        i += 1
-      }
-      buf
     }
 
     val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
@@ -330,7 +334,7 @@ case class BroadcastNestedLoopJoinExec(
 
     sparkContext.union(
       matchedStreamRows,
-      sparkContext.makeRDD(notMatchedBroadcastRows)
+      notMatchedBroadcastRows
     )
   }
 
@@ -342,6 +346,13 @@ case class BroadcastNestedLoopJoinExec(
   private def getMatchedBroadcastRowsBitSet(
       streamRdd: RDD[InternalRow],
       relation: Broadcast[Array[InternalRow]]): BitSet = {
+    getMatchedBroadcastRowsBitSetRDD(streamRdd, relation)
+      .fold(new BitSet(relation.value.length))(_ | _)
+  }
+
+  private def getMatchedBroadcastRowsBitSetRDD(
+      streamRdd: RDD[InternalRow],
+      relation: Broadcast[Array[InternalRow]]): RDD[BitSet] = {
     val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
       val buildRows = relation.value
       val matched = new BitSet(buildRows.length)
@@ -359,7 +370,7 @@ case class BroadcastNestedLoopJoinExec(
       Seq(matched).iterator
     }
 
-    matchedBuildRows.fold(new BitSet(relation.value.length))(_ | _)
+    matchedBuildRows
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f41944d2ed5..6dd34d41cf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -1440,4 +1440,19 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
       }
     }
   }
+
+  test("SPARK-40487: Make defaultJoin in BroadcastNestedLoopJoinExec running in parallel") {
+    withTable("t1", "t2") {
+      spark.range(5, 15).toDF("k").write.saveAsTable("t1")
+      spark.range(4, 8).toDF("k").write.saveAsTable("t2")
+
+      val queryBuildLeft = "SELECT /*+ BROADCAST(t1) */ *  FROM t1 LEFT JOIN t2 ON t1.k < t2.k"
+      val result1 = sql(queryBuildLeft)
+
+      val queryBuildRight = "SELECT /*+ BROADCAST(t2) */ *  FROM t1 LEFT JOIN t2 ON t1.k < t2.k"
+      val result2 = sql(queryBuildRight)
+
+      checkAnswer(result1, result2)
+    }
+  }
 }
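
For reference, here is a tiny standalone illustration (plain Scala, no Spark
needed; java.util.BitSet standing in for Spark's internal BitSet, and the data
made up) of why OR-folding the per-partition bitsets is sufficient to find the
broadcast rows that still need null-extension:

    import java.util.BitSet

    object UnmatchedRowsSketch extends App {
      val buildRows = Array("b0", "b1", "b2", "b3")

      // Each stream partition reports which build rows it matched.
      val perPartition = Seq(Set(0, 2), Set(2, 3)).map { matched =>
        val bits = new BitSet(buildRows.length)
        matched.foreach(i => bits.set(i))
        bits
      }

      // OR-fold: a build row is matched iff at least one partition matched it,
      // so merging the bitsets loses no information.
      val global = perPartition.fold(new BitSet(buildRows.length)) { (a, b) => a.or(b); a }

      // The complement identifies the rows to emit with a null stream side.
      val unmatched = buildRows.indices.filterNot(i => global.get(i))
      assert(unmatched == Seq(1)) // only b1 was never matched
      println(unmatched.map(buildRows(_))) // Vector(b1)
    }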


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org